Skip to content

Commit

Permalink
Updated tests
Browse files Browse the repository at this point in the history
  • Loading branch information
remade committed Nov 14, 2024
1 parent 9b8a429 commit 90a209a
Show file tree
Hide file tree
Showing 23 changed files with 200 additions and 59 deletions.
2 changes: 1 addition & 1 deletion surrealdb/VERSION.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
VERSION='0.0.2'
VERSION='0.1.0'
27 changes: 25 additions & 2 deletions surrealdb/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,28 @@
from surrealdb.async_surrealdb import AsyncSurrealDB
from surrealdb.surrealdb import SurrealDB
from surrealdb.errors import SurrealDbError, SurrealDbConnectionError
from surrealdb.errors import SurrealDbError, SurrealDbConnectionError, SurrealDbDecodeError, SurrealDbEncodeError

__all__ = ("SurrealDB", "AsyncSurrealDB", "SurrealDbError", "SurrealDbConnectionError")
from surrealdb.data.models import Patch, QueryResponse, GraphQLOptions
from surrealdb.data.types.duration import Duration
from surrealdb.data.types.future import Future
from surrealdb.data.types.geometry import Geometry, GeometryPoint, GeometryLine, GeometryPolygon, GeometryMultiPoint, \
GeometryMultiLine, GeometryMultiPolygon, GeometryCollection
from surrealdb.data.types.range import Bound, BoundIncluded, BoundExcluded, Range
from surrealdb.data.types.record_id import RecordID
from surrealdb.data.types.table import Table

__all__ = (
"SurrealDB", "AsyncSurrealDB",
"SurrealDbError", "SurrealDbConnectionError", "SurrealDbDecodeError", "SurrealDbEncodeError",

"Patch", "QueryResponse", "GraphQLOptions",

"Duration",
"Future",
"Geometry", "GeometryPoint", "GeometryLine", "GeometryPolygon", "GeometryMultiPoint",
"GeometryMultiLine", "GeometryMultiPolygon", "GeometryCollection",
"Bound", "BoundIncluded", "BoundExcluded", "Range",
"RecordID",
"Table"

)
16 changes: 16 additions & 0 deletions surrealdb/async_surrealdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,4 +233,20 @@ async def merge(self, thing: Union[str, RecordID, Table], data: dict) -> Union[L
"""
return await self.__connection.send('update', thing, data)

async def live(self, thing: Union[str, Table], diff: Optional[bool] = False) -> Union[List[dict], dict]:
"""
Live initiates a live query for a specified table name.
:param thing: The Table tquery.
:param diff: If set to true, live notifications will contain an array of JSON Patches instead of the entire record
:return: the live query uuid
"""
return await self.__connection.send('live', thing, diff)

async def kill(self, live_query_id: str) -> None:
"""
This kills an active live query
:param live_query_id: The Table or Record ID to merge into.
"""
return await self.__connection.send('kill', live_query_id)
26 changes: 16 additions & 10 deletions surrealdb/connection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import secrets
import string
from typing import Optional, Tuple
import logging

from typing import Optional, Tuple
from surrealdb.constants import REQUEST_ID_LENGTH
from surrealdb.data.cbor import encode, decode
from surrealdb.errors import SurrealDbConnectionError
Expand All @@ -11,13 +12,14 @@ class Connection:
def __init__(
self,
base_url: str,
namespace: Optional[str] = None,
database: Optional[str] = None
logger: logging.Logger,
):
self._auth_token = None
self._namespace = None
self._database = None

self._base_url = base_url
self._namespace = namespace
self._database = database
self._logger = logger

async def use(self, namespace: str, database: str) -> None:
pass
Expand All @@ -38,12 +40,13 @@ async def unset(self, key: str):
pass

async def send(self, method: str, *params):
# print("Request: ", method, params)
req_id = request_id(REQUEST_ID_LENGTH)
request_data = {
'id': request_id(REQUEST_ID_LENGTH),
'id': req_id,
'method': method,
'params': params
}
self._logger.debug(f"Request {req_id}:", request_data)

try:
successful, response_data = await self._make_request(encode(request_data))
Expand All @@ -55,11 +58,14 @@ async def send(self, method: str, *params):
error_msg = response.get("error").get("message")
raise SurrealDbConnectionError(error_msg)

# print("Response: ", response)
# print("Result: ", response_data.hex())
# print("------------------------------------------------------------------------------------------------------------------------")
self._logger.debug(f"Response {req_id}:", response_data.hex())
self._logger.debug(f"Decoded Result {req_id}:", response)
self._logger.debug("----------------------------------------------------------------------------------")

return response.get("result")
except Exception as e:
self._logger.debug(f"Error {req_id}:", e)
self._logger.debug("----------------------------------------------------------------------------------")
raise e

def set_token(self, token: Optional[str] = None) -> None:
Expand Down
9 changes: 7 additions & 2 deletions surrealdb/connection_factory.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from urllib.parse import urlparse

from surrealdb.connection import Connection
Expand All @@ -8,14 +9,18 @@


def create_connection_factory(url: str) -> Connection:
logger: logging.Logger = logging.getLogger(__name__)

parsed_url = urlparse(url)
if parsed_url.scheme not in ALLOWED_CONNECTION_SCHEMES:
raise SurrealDbConnectionError("invalid scheme. allowed schemes are", "".join(ALLOWED_CONNECTION_SCHEMES))

if parsed_url.scheme in WS_CONNECTION_SCHEMES:
return WebsocketConnection(url)
logger.debug("websocket url detected, creating a websocket connection")
return WebsocketConnection(url, logger)

if parsed_url.scheme in HTTP_CONNECTION_SCHEMES:
return HTTPConnection(url)
logger.debug("http url detected, creating a http connection")
return HTTPConnection(url, logger)

raise Exception('no connection type available')
1 change: 1 addition & 0 deletions surrealdb/connection_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ async def connect(self) -> None:

response = requests.get(self._base_url + '/health')
if response.status_code != 200:
self._logger.debug("HTTP health check successful")
raise SurrealDbConnectionError('connection failed. check server is up and base url is correct')

async def _make_request(self, request_payload: bytes) -> Tuple[bool, bytes]:
Expand Down
8 changes: 5 additions & 3 deletions surrealdb/connection_ws.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from typing import Optional, Tuple
import logging
from typing import Tuple
from websockets import Subprotocol, ConnectionClosed

from surrealdb.connection import Connection
from surrealdb.errors import SurrealDbConnectionError
from websockets.asyncio.client import connect


class WebsocketConnection(Connection):
def __init__(self, base_url: str, namespace: Optional[str] = None, database: Optional[str] = None):
super().__init__(base_url, namespace, database)
def __init__(self, base_url: str, logger: logging.Logger):
super().__init__(base_url, logger)

self._ws = None

Expand Down
4 changes: 2 additions & 2 deletions surrealdb/constants.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
REQUEST_ID_LENGTH = 10

AUTH_TOKEN_KEY = "auth_token"

ALLOWED_CONNECTION_SCHEMES = ['http', 'https', 'ws', 'wss']
HTTP_CONNECTION_SCHEMES = ['http', 'https']
WS_CONNECTION_SCHEMES = ['ws', 'wss']

DEFAULT_CONNECTION_URL = "http://127.0.0.1:8000"

UNSUPPORTED_HTTP_METHODS = ["kill", "live"]
29 changes: 27 additions & 2 deletions surrealdb/data/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from dataclasses import dataclass
from typing import Any, Dict, List
from dataclasses import dataclass, field
from typing import Any, Dict, List, Union

from surrealdb import Table, RecordID


@dataclass
Expand Down Expand Up @@ -31,3 +33,26 @@ class QueryResponse:
time: str
status: str
result: List[Dict[str, Any]]


@dataclass
class GraphQLOptions:
"""
Represents the options parameter for graphql method.
Attributes:
pretty: (optional, default false): A boolean indicating whether the output should be pretty-printed.
format: (optional, default "json"): The response format. Currently, only "json" is supported.
"""

pretty: bool = field(default=False)
format: str = field(default="json")


def table_or_record_id(resource_str: str) -> Union[Table, RecordID]:
table, record_id = resource_str.split(":") if ":" in resource_str else (resource_str, None)
if record_id is not None:
return RecordID(table, record_id)

return Table(table)

27 changes: 25 additions & 2 deletions surrealdb/data/types/geometry.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import List, Any, Tuple
from typing import List, Tuple


class Geometry:
Expand All @@ -17,7 +17,7 @@ class GeometryPoint(Geometry):
latitude: float

def __repr__(self):
return 'GeometryPoint(longitude={self.longitude}, latitude={self.latitude})'.format(self=self)
return f'{self.__class__.__name__}(longitude={self.longitude}, latitude={self.latitude})'.format(self=self)

def get_coordinates(self) -> Tuple[float, float]:
return self.longitude, self.latitude
Expand All @@ -27,6 +27,7 @@ def parse_coordinates(coordinates):
return GeometryPoint(coordinates[0], coordinates[1])


@dataclass
class GeometryLine(Geometry):

def __init__(self, point1: GeometryPoint, point2: GeometryPoint, *other_points: GeometryPoint):
Expand All @@ -35,54 +36,73 @@ def __init__(self, point1: GeometryPoint, point2: GeometryPoint, *other_points:
def get_coordinates(self) -> List[Tuple[float, float]]:
return [point.get_coordinates() for point in self.geometry_points]

def __repr__(self):
return f'{self.__class__.__name__}({", ".join(repr(geo) for geo in self.geometry_points)})'

@staticmethod
def parse_coordinates(coordinates):
return GeometryLine(*[GeometryPoint.parse_coordinates(point) for point in coordinates])


@dataclass
class GeometryPolygon(Geometry):
def __init__(self, line1, line2, *other_lines: GeometryLine):
self.geometry_lines = [line1, line2] + list(other_lines)

def get_coordinates(self) -> List[List[Tuple[float, float]]]:
return [line.get_coordinates() for line in self.geometry_lines]

def __repr__(self):
return f'{self.__class__.__name__}({", ".join(repr(geo) for geo in self.geometry_lines)})'

@staticmethod
def parse_coordinates(coordinates):
return GeometryPolygon(*[GeometryLine.parse_coordinates(line) for line in coordinates])


@dataclass
class GeometryMultiPoint(Geometry):
def __init__(self, *geometry_points: GeometryPoint):
self.geometry_points = geometry_points

def get_coordinates(self) -> List[Tuple[float, float]]:
return [point.get_coordinates() for point in self.geometry_points]

def __repr__(self):
return f'{self.__class__.__name__}({", ".join(repr(geo) for geo in self.geometry_points)})'

@staticmethod
def parse_coordinates(coordinates):
return GeometryMultiPoint(*[GeometryPoint.parse_coordinates(point) for point in coordinates])


@dataclass
class GeometryMultiLine(Geometry):
def __init__(self, *geometry_lines: GeometryLine):
self.geometry_lines = geometry_lines

def get_coordinates(self) -> List[List[Tuple[float, float]]]:
return [line.get_coordinates() for line in self.geometry_lines]

def __repr__(self):
return f'{self.__class__.__name__}({", ".join(repr(geo) for geo in self.geometry_lines)})'

@staticmethod
def parse_coordinates(coordinates):
return GeometryMultiLine(*[GeometryLine.parse_coordinates(line) for line in coordinates])


@dataclass
class GeometryMultiPolygon(Geometry):
def __init__(self, *geometry_polygons: GeometryPolygon):
self.geometry_polygons = geometry_polygons

def get_coordinates(self) -> List[List[List[Tuple[float, float]]]]:
return [polygon.get_coordinates() for polygon in self.geometry_polygons]

def __repr__(self):
return f'{self.__class__.__name__}({", ".join(repr(geo) for geo in self.geometry_polygons)})'

@staticmethod
def parse_coordinates(coordinates):
return GeometryMultiPolygon(*[GeometryPolygon.parse_coordinates(polygon) for polygon in coordinates])
Expand All @@ -93,3 +113,6 @@ class GeometryCollection:

def __init__(self, *geometries: Geometry):
self.geometries = geometries

def __repr__(self):
return f'{self.__class__.__name__}({", ".join(repr(geo) for geo in self.geometries)})'
2 changes: 2 additions & 0 deletions surrealdb/data/types/range.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ def __init__(self):
pass


@dataclass
class BoundIncluded(Bound):
def __init__(self, value):
super().__init__()
self.value = value


@dataclass
class BoundExcluded(Bound):
def __init__(self, value):
super().__init__()
Expand Down
14 changes: 13 additions & 1 deletion surrealdb/data/types/record_id.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,22 @@
from dataclasses import dataclass


@dataclass
class RecordID:
def __init__(self, table_name: str, identifier):
self.table_name = table_name
self.id = identifier

def __repr__(self) -> str:
def __str__(self) -> str:
return f"{self.table_name}:{self.id}"

def __repr__(self) -> str:
return f'{self.__class__.__name__}(table_name={self.table_name}, record_id={self.id})'.format(self=self)

@staticmethod
def parse(record_str: str):
if ":" not in record_str:
raise ValueError('invalid string provided for parse. the expected string format is "table_name:record_id"')

table, record_id = record_str.split(":")
return RecordID(table, record_id)
6 changes: 6 additions & 0 deletions surrealdb/data/types/table.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
class Table:
def __init__(self, table_name: str):
self.table_name = table_name

def __str__(self) -> str:
return f"{self.table_name}"

def __repr__(self) -> str:
return f"{self.table_name}"
4 changes: 2 additions & 2 deletions surrealdb/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ class SurrealDbConnectionError(SurrealDbError):

class SurrealDbDecodeError(SurrealDbError):
"""
Exceptions from connections
Exceptions from Decoding responses
"""


class SurrealDbEncodeError(SurrealDbError):
"""
Exceptions from connections
Exceptions from encoding requests
"""
Loading

0 comments on commit 90a209a

Please sign in to comment.