|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from functools import partial |
| 4 | + |
1 | 5 | import pytest
|
2 |
| -from anyio import Event, create_task_group, move_on_after, sleep |
| 6 | +from anyio import Event, fail_after |
3 | 7 | from pycrdt import Array, Doc, Map
|
4 | 8 | from websockets import connect
|
5 | 9 |
|
6 | 10 | from pycrdt_websocket import WebsocketProvider
|
7 | 11 |
|
8 | 12 |
|
9 |
| -class YTest: |
10 |
| - def __init__(self, ydoc: Doc, timeout: float = 1.0): |
11 |
| - self.ydoc = ydoc |
| 13 | +class Change: |
| 14 | + def __init__(self, event, timeout, ydata, sid, key): |
| 15 | + self.event = event |
12 | 16 | self.timeout = timeout
|
13 |
| - self.ydoc["_test"] = self.ytest = Map() |
14 |
| - self.clock = -1.0 |
| 17 | + self.ydata = ydata |
| 18 | + self.sid = sid |
| 19 | + self.key = key |
15 | 20 |
|
16 |
| - def run_clock(self): |
17 |
| - self.clock = max(self.clock, 0.0) |
18 |
| - self.ytest["clock"] = self.clock |
| 21 | + async def wait(self): |
| 22 | + with fail_after(self.timeout): |
| 23 | + await self.event.wait() |
| 24 | + self.ydata.unobserve(self.sid) |
| 25 | + if self.key is None: |
| 26 | + return |
| 27 | + return self.ydata[self.key] |
19 | 28 |
|
20 |
| - async def clock_run(self): |
21 |
| - change = Event() |
22 | 29 |
|
23 |
| - def callback(event): |
24 |
| - if "clock" in event.keys: |
25 |
| - clk = event.keys["clock"]["newValue"] |
26 |
| - if clk > self.clock: |
27 |
| - self.clock = clk + 1.0 |
28 |
| - change.set() |
| 30 | +def callback(change_event, key, event): |
| 31 | + if key is None or key in event.keys: |
| 32 | + change_event.set() |
29 | 33 |
|
30 |
| - subscription_id = self.ytest.observe(callback) |
31 |
| - async with create_task_group(): |
32 |
| - with move_on_after(self.timeout): |
33 |
| - await change.wait() |
34 | 34 |
|
35 |
| - self.ytest.unobserve(subscription_id) |
| 35 | +def watch(ydata, key: str | None = None, timeout: float = 1.0): |
| 36 | + change_event = Event() |
| 37 | + sid = ydata.observe(partial(callback, change_event, key)) |
| 38 | + return Change(change_event, timeout, ydata, sid, key) |
36 | 39 |
|
37 | 40 |
|
38 | 41 | @pytest.mark.anyio
|
39 | 42 | @pytest.mark.parametrize("yjs_client", "0", indirect=True)
|
40 | 43 | async def test_pycrdt_yjs_0(yws_server, yjs_client):
|
41 | 44 | ydoc = Doc()
|
42 |
| - ytest = YTest(ydoc) |
43 | 45 | async with connect("ws://127.0.0.1:1234/my-roomname") as websocket, WebsocketProvider(
|
44 | 46 | ydoc, websocket
|
45 | 47 | ):
|
46 | 48 | ydoc["map"] = ymap = Map()
|
47 |
| - # set a value in "in" |
48 | 49 | for v_in in range(10):
|
49 | 50 | ymap["in"] = float(v_in)
|
50 |
| - ytest.run_clock() |
51 |
| - await ytest.clock_run() |
52 |
| - v_out = ymap["out"] |
| 51 | + v_out = await watch(ymap, "out").wait() |
53 | 52 | assert v_out == v_in + 1.0
|
54 | 53 |
|
55 | 54 |
|
56 | 55 | @pytest.mark.anyio
|
57 | 56 | @pytest.mark.parametrize("yjs_client", "1", indirect=True)
|
58 | 57 | async def test_pycrdt_yjs_1(yws_server, yjs_client):
|
59 |
| - # wait for the JS client to connect |
60 |
| - tt, dt = 0, 0.1 |
61 |
| - while True: |
62 |
| - await sleep(dt) |
63 |
| - if "/my-roomname" in yws_server.rooms: |
64 |
| - break |
65 |
| - tt += dt |
66 |
| - if tt >= 1: |
67 |
| - raise RuntimeError("Timeout waiting for client to connect") |
68 |
| - ydoc = yws_server.rooms["/my-roomname"].ydoc |
69 |
| - ytest = YTest(ydoc) |
70 |
| - ytest.run_clock() |
71 |
| - await ytest.clock_run() |
| 58 | + ydoc = Doc() |
72 | 59 | ydoc["cells"] = ycells = Array()
|
73 | 60 | ydoc["state"] = ystate = Map()
|
74 |
| - assert ycells.to_py() == [{"metadata": {"foo": "bar"}, "source": "1 + 2"}] |
75 |
| - assert ystate.to_py() == {"state": {"dirty": False}} |
| 61 | + ycells_change = watch(ycells) |
| 62 | + ystate_change = watch(ystate) |
| 63 | + async with connect("ws://127.0.0.1:1234/my-roomname") as websocket, WebsocketProvider( |
| 64 | + ydoc, websocket |
| 65 | + ): |
| 66 | + await ycells_change.wait() |
| 67 | + await ystate_change.wait() |
| 68 | + assert ycells.to_py() == [{"metadata": {"foo": "bar"}, "source": "1 + 2"}] |
| 69 | + assert ystate.to_py() == {"state": {"dirty": False}} |
0 commit comments