Skip to content

Commit 777fc26

Browse files
manzttlambert03
andauthored
feat(experimental): Add command system for calling Python from JS (#453)
Co-authored-by: Talley Lambert <[email protected]>
1 parent cbc69db commit 777fc26

File tree

11 files changed

+1553
-1737
lines changed

11 files changed

+1553
-1737
lines changed

.changeset/nervous-goats-rescue.md

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
---
2+
"anywidget": patch
3+
"@anywidget/types": patch
4+
---
5+
6+
Add experimental `invoke` API to call Python functions from the front end and
7+
await the response.
8+
9+
This removes a lot of boilerplate required for this pattern. The API is
10+
experimental and opt-in only. Subclasses must use the `command` to register
11+
functions.
12+
13+
```py
14+
class Widget(anywidget.AnyWidget):
15+
_esm = """
16+
export default {
17+
async render({ model, el, experimental }) {
18+
let [msg, buffers] = await experimental.invoke("_echo", "hello, world");
19+
console.log(msg); // "HELLO, WORLD"
20+
},
21+
};
22+
"""
23+
24+
@anywidget.experimental.command
25+
def _echo(self, msg, buffers):
26+
# upper case the message
27+
return msg.upper(), buffers
28+
```

anywidget/_descriptor.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,7 @@
5757
from ._protocols import CommMessage
5858

5959
class _GetState(Protocol):
60-
def __call__(self, obj: Any, include: set[str] | None) -> dict:
61-
...
60+
def __call__(self, obj: Any, include: set[str] | None) -> dict: ...
6261

6362
# catch all for types that can be serialized ... too hard to actually type
6463
Serializable: TypeAlias = Any
@@ -205,12 +204,10 @@ def __set_name__(self, owner: type, name: str) -> None:
205204
self._name = name
206205

207206
@overload
208-
def __get__(self, instance: None, owner: type) -> MimeBundleDescriptor:
209-
...
207+
def __get__(self, instance: None, owner: type) -> MimeBundleDescriptor: ...
210208

211209
@overload
212-
def __get__(self, instance: object, owner: type) -> ReprMimeBundle:
213-
...
210+
def __get__(self, instance: object, owner: type) -> ReprMimeBundle: ...
214211

215212
def __get__(
216213
self, instance: object | None, owner: type

anywidget/_protocols.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Sequence
3+
from typing import TYPE_CHECKING, Any, Callable, Sequence
44

55
from typing_extensions import Literal, Protocol, TypedDict
66

@@ -40,7 +40,7 @@ class CommMessage(TypedDict):
4040

4141

4242
class MimeReprCallable(Protocol):
43-
"""Protocol for _repr_mimebundle.
43+
"""Protocol for _repr_mimebundle_.
4444
4545
https://ipython.readthedocs.io/en/stable/config/integrating.html#more-powerful-methods
4646
@@ -52,11 +52,20 @@ class MimeReprCallable(Protocol):
5252

5353
def __call__(
5454
self, include: Sequence[str], exclude: Sequence[str]
55-
) -> dict | tuple[dict, dict]:
56-
...
55+
) -> dict | tuple[dict, dict]: ...
5756

5857

5958
class AnywidgetProtocol(Protocol):
6059
"""Anywidget classes have a MimeBundleDescriptor at `_repr_mimebundle_`."""
6160

6261
_repr_mimebundle_: MimeBundleDescriptor
62+
63+
64+
class WidgetBase(Protocol):
65+
"""Widget subclasses with a custom message reducer."""
66+
67+
def send(self, msg: str | dict | list, buffers: list[bytes]) -> None: ...
68+
69+
def on_msg(
70+
self, callback: Callable[[Any, str | list | dict, list[bytes]], None]
71+
) -> None: ...

anywidget/experimental.py

+63
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import contextlib
34
import dataclasses
45
import pathlib
56
import typing
@@ -8,6 +9,9 @@
89

910
from ._descriptor import MimeBundleDescriptor
1011

12+
if typing.TYPE_CHECKING: # pragma: no cover
13+
from ._protocols import WidgetBase
14+
1115
__all__ = ["dataclass", "widget", "MimeBundleDescriptor"]
1216

1317
_T = typing.TypeVar("_T")
@@ -103,3 +107,62 @@ def _decorator(cls: T) -> T:
103107
return cls
104108

105109
return _decorator(cls) if cls is not None else _decorator # type: ignore
110+
111+
112+
_ANYWIDGET_COMMAND = "_anywidget_command"
113+
114+
_AnyWidgetCommand = typing.Callable[
115+
[typing.Any, typing.Any, typing.List[bytes]],
116+
typing.Tuple[typing.Any, typing.List[bytes]],
117+
]
118+
119+
120+
def command(cmd: _AnyWidgetCommand) -> _AnyWidgetCommand:
121+
"""Mark a function as a command for anywidget."""
122+
setattr(cmd, _ANYWIDGET_COMMAND, True)
123+
return cmd
124+
125+
126+
_AnyWidgetCommandBound = typing.Callable[
127+
[typing.Any, typing.List[bytes]], typing.Tuple[typing.Any, typing.List[bytes]]
128+
]
129+
130+
131+
def _collect_commands(widget: WidgetBase) -> dict[str, _AnyWidgetCommandBound]:
132+
cmds: dict[str, _AnyWidgetCommandBound] = {}
133+
for attr_name in dir(widget):
134+
# suppressing silly assertion erro from ipywidgets _staticproperty
135+
# ref: https://github.com/jupyter-widgets/ipywidgets/blob/b78de43e12ff26e4aa16e6e4c6844a7c82a8ee1c/python/ipywidgets/ipywidgets/widgets/widget.py#L291-L297
136+
with contextlib.suppress(AssertionError):
137+
attr = getattr(widget, attr_name)
138+
if callable(attr) and getattr(attr, _ANYWIDGET_COMMAND, False):
139+
cmds[attr_name] = attr
140+
return cmds
141+
142+
143+
def _register_anywidget_commands(
144+
widget: WidgetBase,
145+
) -> None:
146+
"""Register a custom message reducer for a widget if it implements the protocol."""
147+
# Only add the callback if the widget has any commands.
148+
cmds = _collect_commands(widget)
149+
if len(cmds) == 0:
150+
return None
151+
152+
def handle_anywidget_command(
153+
self: WidgetBase, msg: str | list | dict, buffers: list[bytes]
154+
) -> None:
155+
if not isinstance(msg, dict) or msg.get("kind") != "anywidget-command":
156+
return
157+
cmd = cmds[msg["name"]]
158+
response, buffers = cmd(msg["msg"], buffers)
159+
self.send(
160+
{
161+
"id": msg["id"],
162+
"kind": "anywidget-command-response",
163+
"response": response,
164+
},
165+
buffers,
166+
)
167+
168+
widget.on_msg(handle_anywidget_command)

anywidget/widget.py

+2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
try_file_contents,
1818
)
1919
from ._version import __version__
20+
from .experimental import _register_anywidget_commands
2021

2122

2223
class AnyWidget(ipywidgets.DOMWidget): # type: ignore [misc]
@@ -57,6 +58,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
5758

5859
self.add_traits(**anywidget_traits)
5960
super().__init__(*args, **kwargs)
61+
_register_anywidget_commands(self)
6062

6163
def __init_subclass__(cls, **kwargs: dict) -> None:
6264
"""Coerces _esm and _css to FileContents if they are files."""

packages/anywidget/package.json

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
"@anywidget/types": "workspace:~",
2525
"@anywidget/vite": "workspace:~",
2626
"@jupyter-widgets/base": "^2 || ^3 || ^4 || ^5 || ^6",
27+
"@lukeed/uuid": "^2.0.1",
2728
"solid-js": "^1.8.16"
2829
},
2930
"devDependencies": {

packages/anywidget/src/widget.js

+62-14
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
1-
import {
2-
createEffect,
3-
createResource,
4-
createRoot,
5-
createSignal,
6-
} from "solid-js";
1+
import * as uuid from "@lukeed/uuid";
2+
import * as solid from "solid-js";
73

84
/**
95
* @typedef AnyWidget
@@ -242,6 +238,50 @@ function throw_anywidget_error(source) {
242238
throw source;
243239
}
244240

241+
/**
242+
* @template T
243+
* @param {import("@anywidget/types").AnyModel} model
244+
* @param {string} name
245+
* @param {any} [msg]
246+
* @param {DataView[]} [buffers]
247+
* @param {{ timeout?: number }} [options]
248+
* @return {Promise<[T, DataView[]]>}
249+
*/
250+
export function invoke(
251+
model,
252+
name,
253+
msg,
254+
buffers = [],
255+
{ timeout = 3000 } = {},
256+
) {
257+
// crypto.randomUUID() is not available in non-secure contexts (i.e., http://)
258+
// so we use simple (non-secure) polyfill.
259+
let id = uuid.v4();
260+
return new Promise((resolve, reject) => {
261+
let timer = setTimeout(() => {
262+
reject(new Error(`Promise timed out after ${timeout} ms`));
263+
model.off("msg:custom", handler);
264+
}, timeout);
265+
266+
/**
267+
* @param {{ id: string, kind: "anywidget-command-response", response: T }} msg
268+
* @param {DataView[]} buffers
269+
*/
270+
function handler(msg, buffers) {
271+
if (!(msg.id === id)) return;
272+
clearTimeout(timer);
273+
resolve([msg.response, buffers]);
274+
model.off("msg:custom", handler);
275+
}
276+
model.on("msg:custom", handler);
277+
model.send(
278+
{ id, kind: "anywidget-command", name, msg },
279+
undefined,
280+
buffers,
281+
);
282+
});
283+
}
284+
245285
class Runtime {
246286
/** @type {() => void} */
247287
#disposer = () => {};
@@ -253,34 +293,38 @@ class Runtime {
253293

254294
/** @param {import("@jupyter-widgets/base").DOMWidgetModel} model */
255295
constructor(model) {
256-
this.#disposer = createRoot((dispose) => {
257-
let [css, set_css] = createSignal(model.get("_css"));
296+
this.#disposer = solid.createRoot((dispose) => {
297+
let [css, set_css] = solid.createSignal(model.get("_css"));
258298
model.on("change:_css", () => {
259299
let id = model.get("_anywidget_id");
260300
console.debug(`[anywidget] css hot updated: ${id}`);
261301
set_css(model.get("_css"));
262302
});
263-
createEffect(() => {
303+
solid.createEffect(() => {
264304
let id = model.get("_anywidget_id");
265305
load_css(css(), id);
266306
});
267307

268308
/** @type {import("solid-js").Signal<string>} */
269-
let [esm, setEsm] = createSignal(model.get("_esm"));
309+
let [esm, setEsm] = solid.createSignal(model.get("_esm"));
270310
model.on("change:_esm", async () => {
271311
let id = model.get("_anywidget_id");
272312
console.debug(`[anywidget] esm hot updated: ${id}`);
273313
setEsm(model.get("_esm"));
274314
});
275315
/** @type {void | (() => import("vitest").Awaitable<void>)} */
276316
let cleanup;
277-
this.#widget_result = createResource(esm, async (update) => {
317+
this.#widget_result = solid.createResource(esm, async (update) => {
278318
await safe_cleanup(cleanup, "initialize");
279319
try {
280320
model.off(null, null, INITIALIZE_MARKER);
281321
let widget = await load_widget(update);
282322
cleanup = await widget.initialize?.({
283323
model: model_proxy(model, INITIALIZE_MARKER),
324+
experimental: {
325+
// @ts-expect-error - bind isn't working
326+
invoke: invoke.bind(null, model),
327+
},
284328
});
285329
return ok(widget);
286330
} catch (e) {
@@ -302,11 +346,11 @@ class Runtime {
302346
*/
303347
async create_view(view) {
304348
let model = view.model;
305-
let disposer = createRoot((dispose) => {
349+
let disposer = solid.createRoot((dispose) => {
306350
/** @type {void | (() => import("vitest").Awaitable<void>)} */
307351
let cleanup;
308352
let resource =
309-
createResource(this.#widget_result, async (widget_result) => {
353+
solid.createResource(this.#widget_result, async (widget_result) => {
310354
cleanup?.();
311355
// Clear all previous event listeners from this hook.
312356
model.off(null, null, view);
@@ -319,12 +363,16 @@ class Runtime {
319363
cleanup = await widget.render?.({
320364
model: model_proxy(model, view),
321365
el: view.el,
366+
experimental: {
367+
// @ts-expect-error - bind isn't working
368+
invoke: invoke.bind(null, model),
369+
},
322370
});
323371
} catch (e) {
324372
throw_anywidget_error(e);
325373
}
326374
})[0];
327-
createEffect(() => {
375+
solid.createEffect(() => {
328376
if (resource.error) {
329377
// TODO: Show error in the view?
330378
}

packages/types/index.ts

+10
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,18 @@ export interface AnyModel<T extends ObjectHash = ObjectHash> {
4343
widget_manager: IWidgetManager;
4444
}
4545

46+
type Experimental = {
47+
invoke: <T>(
48+
name: string,
49+
msg?: any,
50+
buffers?: DataView[],
51+
) => Promise<[T, DataView[]]>;
52+
};
53+
4654
export interface RenderProps<T extends ObjectHash = ObjectHash> {
4755
model: AnyModel<T>;
4856
el: HTMLElement;
57+
experimental: Experimental;
4958
}
5059

5160
export interface Render<T extends ObjectHash = ObjectHash> {
@@ -54,6 +63,7 @@ export interface Render<T extends ObjectHash = ObjectHash> {
5463

5564
export interface InitializeProps<T extends ObjectHash = ObjectHash> {
5665
model: AnyModel<T>;
66+
experimental: Experimental;
5767
}
5868

5969
export interface Initialize<T extends ObjectHash = ObjectHash> {

0 commit comments

Comments
 (0)