Skip to content

Commit

Permalink
Issue fix for Notification queue (#121)
Browse files Browse the repository at this point in the history
  • Loading branch information
remade authored Nov 25, 2024
1 parent ae48005 commit 4e89316
Show file tree
Hide file tree
Showing 8 changed files with 36 additions and 32 deletions.
12 changes: 7 additions & 5 deletions surrealdb/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from typing import Dict, Tuple
from surrealdb.constants import REQUEST_ID_LENGTH
from surrealdb.data.cbor import encode, decode
from asyncio import Queue


Expand All @@ -34,7 +33,12 @@ def __init__(
self,
base_url: str,
logger: logging.Logger,
encoder,
decoder,
):
self._encoder = encoder
self._decoder = decoder

self._locks = {
ResponseType.SEND: threading.Lock(),
ResponseType.NOTIFICATION: threading.Lock(),
Expand All @@ -58,7 +62,7 @@ async def connect(self) -> None:
async def close(self) -> None:
pass

async def _make_request(self, request_data: RequestData, encoder, decoder):
async def _make_request(self, request_data: RequestData):
pass

async def set(self, key: str, value):
Expand Down Expand Up @@ -104,9 +108,7 @@ async def send(self, method: str, *params):
self._logger.debug(f"Request {request_data.id}:", request_data)

try:
result = await self._make_request(
request_data, encoder=encode, decoder=decode
)
result = await self._make_request(request_data)

self._logger.debug(f"Result {request_data.id}:", result)
self._logger.debug(
Expand Down
12 changes: 7 additions & 5 deletions surrealdb/connection_clib.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ class sr_notification_t(ctypes.Structure):


class CLibConnection(Connection):
def __init__(self, base_url: str, logger: logging.Logger):
super().__init__(base_url, logger)
def __init__(self, base_url: str, logger: logging.Logger, encoder, decoder):
super().__init__(base_url, logger, encoder, decoder)

lib_path = get_lib_path()
self._lib = ctypes.CDLL(lib_path)
Expand Down Expand Up @@ -194,8 +194,8 @@ async def set(self, key: str, value):
async def unset(self, key: str):
await self.send("unset", key)

async def _make_request(self, request_data: RequestData, encoder, decoder):
request_payload = encoder(
async def _make_request(self, request_data: RequestData):
request_payload = self._encoder(
{
"id": request_data.id,
"method": request_data.method,
Expand Down Expand Up @@ -226,4 +226,6 @@ async def _make_request(self, request_data: RequestData, encoder, decoder):

# Free the allocated byte array returned by the C library
self._lib.sr_free_byte_arr(c_res_ptr, result)
return True, response
response_data = self._decoder(response)

return response_data
7 changes: 4 additions & 3 deletions surrealdb/connection_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from surrealdb.connection_clib import CLibConnection
from surrealdb.connection_http import HTTPConnection
from surrealdb.connection_ws import WebsocketConnection
from surrealdb.data.cbor import encode, decode
from surrealdb.errors import SurrealDbConnectionError


Expand All @@ -26,14 +27,14 @@ def create_connection_factory(url: str) -> Connection:

if parsed_url.scheme in WS_CONNECTION_SCHEMES:
logger.debug("websocket url detected, creating a websocket connection")
return WebsocketConnection(url, logger)
return WebsocketConnection(url, logger, encoder=encode, decoder=decode)

if parsed_url.scheme in HTTP_CONNECTION_SCHEMES:
logger.debug("http url detected, creating a http connection")
return HTTPConnection(url, logger)
return HTTPConnection(url, logger, encoder=encode, decoder=decode)

if parsed_url.scheme in CLIB_CONNECTION_SCHEMES:
logger.debug("embedded url detected, creating a clib connection")
return CLibConnection(url, logger)
return CLibConnection(url, logger, encoder=encode, decoder=decode)

raise Exception("no connection type available")
6 changes: 3 additions & 3 deletions surrealdb/connection_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ async def connect(self) -> None:
"connection failed. check server is up and base url is correct"
)

async def _make_request(self, request_data: RequestData, encoder, decoder):
async def _make_request(self, request_data: RequestData):
if self._namespace is None:
raise SurrealDbConnectionError("namespace not set")

Expand All @@ -51,7 +51,7 @@ async def _make_request(self, request_data: RequestData, encoder, decoder):
if self._auth_token is not None:
headers["Authorization"] = f"Bearer {self._auth_token}"

request_payload = encoder(
request_payload = self._encoder(
{
"id": request_data.id,
"method": request_data.method,
Expand All @@ -62,7 +62,7 @@ async def _make_request(self, request_data: RequestData, encoder, decoder):
response = requests.post(
f"{self._base_url}/rpc", data=request_payload, headers=headers
)
response_data = decoder(response.content)
response_data = self._decoder(response.content)

if 200 > response.status_code > 299 or response_data.get("error"):
raise SurrealDbConnectionError(response_data.get("error").get("message"))
Expand Down
22 changes: 9 additions & 13 deletions surrealdb/connection_ws.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,18 @@
import asyncio
import logging
from asyncio import Task

from websockets import Subprotocol, ConnectionClosed, connect
from websockets.asyncio.client import ClientConnection

from surrealdb.connection import Connection, ResponseType, RequestData
from surrealdb.constants import WS_REQUEST_TIMEOUT
from surrealdb.data.cbor import decode
from surrealdb.errors import SurrealDbConnectionError


class WebsocketConnection(Connection):
_ws: ClientConnection
_receiver_task: Task

def __init__(self, base_url: str, logger: logging.Logger):
super().__init__(base_url, logger)

# self._ws = None
# self._receiver_task = None

async def connect(self):
try:
self._ws = await connect(
Expand Down Expand Up @@ -49,8 +41,8 @@ async def close(self):
if self._ws:
await self._ws.close()

async def _make_request(self, request_data: RequestData, encoder, decoder):
request_payload = encoder(
async def _make_request(self, request_data: RequestData):
request_payload = self._encoder(
{
"id": request_data.id,
"method": request_data.method,
Expand Down Expand Up @@ -85,14 +77,18 @@ async def _make_request(self, request_data: RequestData, encoder, decoder):

async def listen_to_ws(self, ws):
async for message in ws:
response_data = decode(message)
response_data = self._decoder(message)

response_id = response_data.get("id")
if response_id:
queue = self.get_response_queue(ResponseType.SEND, response_id)
await queue.put(response_data)
continue

live_id = response_data.get("result").get("id")
queue = self.get_response_queue(ResponseType.NOTIFICATION, live_id)
live_id = response_data.get("result").get("id") # returned as uuid
queue = self.get_response_queue(ResponseType.NOTIFICATION, str(live_id))
if queue is None:
self._logger.error(f"No notification queue set for {live_id}")
continue

await queue.put(response_data.get("result"))
3 changes: 2 additions & 1 deletion tests/unit/test_clib_connection.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from unittest import IsolatedAsyncioTestCase
from logging import getLogger
from surrealdb.connection_clib import CLibConnection
from surrealdb.data.cbor import encode, decode


class TestCLibConnection(IsolatedAsyncioTestCase):
async def asyncSetUp(self):
self.logger = getLogger(__name__)

self.clib = CLibConnection(base_url='surrealkv://', logger=self.logger)
self.clib = CLibConnection(base_url='surrealkv://', logger=self.logger, encoder=encode, decoder=decode)
await self.clib.connect()

async def test_send(self):
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/test_http_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
from unittest import IsolatedAsyncioTestCase

from surrealdb.connection_http import HTTPConnection
from surrealdb.data.cbor import encode, decode


class TestHTTPConnection(IsolatedAsyncioTestCase):
async def asyncSetUp(self):
logger = logging.getLogger(__name__)

self.http_con = HTTPConnection(base_url='http://localhost:8000', logger=logger)
self.http_con = HTTPConnection(base_url='http://localhost:8000', logger=logger, encoder=encode, decoder=decode)
await self.http_con.connect()

async def test_send(self):
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/test_ws_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
from unittest import IsolatedAsyncioTestCase

from surrealdb.connection_ws import WebsocketConnection
from surrealdb.data.cbor import encode, decode


class TestWSConnection(IsolatedAsyncioTestCase):
async def asyncSetUp(self):
logger = logging.getLogger(__name__)
self.ws_con = WebsocketConnection(base_url='ws://localhost:8000', logger=logger)
self.ws_con = WebsocketConnection(base_url='ws://localhost:8000', logger=logger, encoder=encode, decoder=decode)
await self.ws_con.connect()

async def test_send(self):
Expand Down

0 comments on commit 4e89316

Please sign in to comment.