Skip to content

Commit

Permalink
Tests
Browse files Browse the repository at this point in the history
  • Loading branch information
remade committed Nov 13, 2024
1 parent 4df3d4b commit bb294eb
Show file tree
Hide file tree
Showing 14 changed files with 299 additions and 358 deletions.
39 changes: 39 additions & 0 deletions del-requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
anyio==3.7.0
bleach==6.0.0
cbor2==5.6.5
certifi==2023.5.7
charset-normalizer==3.2.0
coverage==7.2.7
docutils==0.20.1
exceptiongroup==1.1.1
h11==0.14.0
httpcore==0.17.2
httpx==0.24.1
idna==3.4
importlib-metadata==6.8.0
jaraco.classes==3.3.0
keyring==24.2.0
markdown-it-py==3.0.0
maturin==1.1.0
mdurl==0.1.2
more-itertools==10.0.0
pkginfo==1.9.6
pydantic==1.10.9
Pygments==2.15.1
readme-renderer==40.0
requests==2.31.0
requests-toolbelt==1.0.0
rfc3986==2.0.0
rich==13.4.2
semantic-version==2.10.0
setuptools-rust==1.6.0
six==1.16.0
sniffio==1.3.0
tomli==2.0.1
twine==4.0.2
typing_extensions==4.6.3
urllib3==2.0.4
webencodings==0.5.1
websocket-client==1.8.0
websockets==10.4
zipp==3.16.2
9 changes: 4 additions & 5 deletions examples/notebook_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,13 @@
"source": [
"from surrealdb import AsyncSurrealDB\n",
"\n",
"db = AsyncSurrealDB(\"ws://localhost:8000/database/namespace\")\n",
"db = AsyncSurrealDB(\"ws://localhost:8000\")\n",
"\n",
"await db.connect()\n",
"\n",
"await db.signin({\n",
" \"username\": \"root\",\n",
" \"password\": \"root\",\n",
"})"
"await db.use(\"test\", \"test\")\n",
"\n",
"await db.sign_in(\"root\", \"root\")"
]
},
{
Expand Down
44 changes: 6 additions & 38 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,39 +1,7 @@
anyio==3.7.0
bleach==6.0.0
cbor2==5.6.5
certifi==2023.5.7
charset-normalizer==3.2.0
coverage==7.2.7
docutils==0.20.1
exceptiongroup==1.1.1
h11==0.14.0
httpcore==0.17.2
httpx==0.24.1
idna==3.4
importlib-metadata==6.8.0
jaraco.classes==3.3.0
keyring==24.2.0
markdown-it-py==3.0.0
maturin==1.1.0
mdurl==0.1.2
more-itertools==10.0.0
pkginfo==1.9.6
pydantic==1.10.9
Pygments==2.15.1
readme-renderer==40.0
requests==2.31.0
requests-toolbelt==1.0.0
rfc3986==2.0.0
rich==13.4.2
semantic-version==2.10.0
setuptools-rust==1.6.0
six==1.16.0
sniffio==1.3.0
tomli==2.0.1
twine==4.0.2
typing_extensions==4.6.3
urllib3==2.0.4
webencodings==0.5.1
websocket-client==1.8.0
websockets==10.4
zipp==3.16.2
docker_py==1.10.6
Requests==2.32.3
setuptools==63.2.0
setuptools==65.5.1
setuptools_rust==1.6.0
websockets==14.1
8 changes: 4 additions & 4 deletions surrealdb/async_surrealdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from surrealdb.constants import DEFAULT_CONNECTION_URL
from surrealdb.connection_factory import create_connection_factory
from surrealdb.data import Table, RecordID, Patch
from surrealdb.data import Table, RecordID, Patch, QueryResponse

_Self = TypeVar('_Self', bound='AsyncSurrealDB')

Expand Down Expand Up @@ -133,10 +133,10 @@ async def version(self) -> str:
return await self.__connection.send('version')

async def set(self, name: str, value) -> None:
await self.__connection.send('let', name, value)
await self.__connection.set(name, value)

async def unset(self, name: str) -> None:
await self.__connection.send('unset', name)
await self.__connection.unset(name)

async def select(self, what: Union[str, Table, RecordID]) -> Union[List[dict], dict]:
"""
Expand All @@ -148,7 +148,7 @@ async def select(self, what: Union[str, Table, RecordID]) -> Union[List[dict], d
"""
return await self.__connection.send('select', what)

async def query(self, query: str, variables: dict = {}) -> List[dict]:
async def query(self, query: str, variables: dict = {}) -> List[QueryResponse]:
"""
Queries sends a custom SurrealQL query.
Expand Down
10 changes: 7 additions & 3 deletions surrealdb/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ async def close(self) -> None:
async def _make_request(self, request_payload: bytes) -> Tuple[bool, bytes]:
pass

async def set(self, key: str, value):
pass

async def unset(self, key: str):
pass

async def send(self, method: str, *params):
# print("Request: ", method, params)
request_data = {
Expand All @@ -49,15 +55,13 @@ async def send(self, method: str, *params):
error_msg = response.get("error").get("message")
raise SurrealDbConnectionError(error_msg)

# print("Response: ", response)
# print("Result: ", response_data.hex())
# print("Result: ", response.get("result"))
# print("------------------------------------------------------------------------------------------------------------------------")
return response.get("result")
except Exception as e:
raise e



def set_token(self, token: Optional[str] = None) -> None:
self._auth_token = token

Expand Down
1 change: 0 additions & 1 deletion surrealdb/connection_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
class HTTPConnection(Connection):

async def use(self, namespace: str, database: str) -> None:
print(namespace, database)
self._namespace = namespace
self._database = database

Expand Down
14 changes: 10 additions & 4 deletions surrealdb/connection_ws.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import websockets

from typing import Optional, Tuple
from websockets import Subprotocol, ConnectionClosed
from surrealdb.connection import Connection
from surrealdb.errors import SurrealDbConnectionError
from websockets.asyncio.client import connect


class WebsocketConnection(Connection):
Expand All @@ -13,7 +13,7 @@ def __init__(self, base_url: str, namespace: Optional[str] = None, database: Opt

async def connect(self):
try:
self._ws = await websockets.connect(self._base_url + "/rpc", subprotocols=['cbor'])
self._ws = await connect(self._base_url + "/rpc", subprotocols=[Subprotocol('cbor')])
except Exception as e:
raise SurrealDbConnectionError('cannot connect db server', e)

Expand All @@ -23,6 +23,12 @@ async def use(self, namespace: str, database: str) -> None:

await self.send("use", namespace, database)

async def set(self, key: str, value):
await self.send("let", key, value)

async def unset(self, key: str):
await self.send("unset", key)

async def close(self):
if self._ws:
await self._ws.close()
Expand All @@ -32,7 +38,7 @@ async def _make_request(self, request_payload: bytes) -> Tuple[bool, bytes]:
await self._ws.send(request_payload)
response = await self._ws.recv()
return True, response
except websockets.ConnectionClosed as e:
except ConnectionClosed as e:
raise SurrealDbConnectionError(e)
except Exception as e:
raise e
2 changes: 1 addition & 1 deletion surrealdb/data/types/record_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ def __init__(self, table_name: str, identifier):
self.id = identifier

def __repr__(self) -> str:
return "".join([self.table_name, ":", self.id])
return f"{self.table_name}:{self.id}"


4 changes: 2 additions & 2 deletions surrealdb/surrealdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,11 @@ def version(self) -> str:

def set(self, name: str, value) -> None:
loop_manager = AsyncioRuntime()
loop_manager.loop.run_until_complete(self.__connection.send('let', name, value))
loop_manager.loop.run_until_complete(self.__connection.set(name, value))

def unset(self, name: str) -> None:
loop_manager = AsyncioRuntime()
loop_manager.loop.run_until_complete(self.__connection.send('unset', name))
loop_manager.loop.run_until_complete(self.__connection.unset(name))

def select(self, what: Union[str, Table, RecordID]) -> Union[List[dict], dict]:
"""
Expand Down
62 changes: 24 additions & 38 deletions tests/integration/async/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,52 +2,38 @@
Handles the integration tests for logging into the database.
"""

import asyncio
import os
from unittest import TestCase, main
from unittest import IsolatedAsyncioTestCase, main

from surrealdb import AsyncSurrealDB
from surrealdb import AsyncSurrealDB, SurrealDbConnectionError
from tests.integration.connection_params import TestConnectionParams


class TestAsyncAuth(TestCase):
def setUp(self):
class TestAsyncAuth(IsolatedAsyncioTestCase):
async def asyncSetUp(self):
self.params = TestConnectionParams()
self.db = AsyncSurrealDB(self.params.url)

def tearDown(self):
pass

async def login(self, username: str, password: str):
await self.db.connect()
outcome = await self.db.sign_in(username, password)
return outcome

def test_login_success(self):
outcome = asyncio.run(self.login("root", "root"))
self.assertEqual(None, outcome)

def test_login_wrong_password(self):
with self.assertRaises(RuntimeError) as context:
asyncio.run(self.login("root", "wrong"))

if os.environ.get("CONNECTION_PROTOCOL", "http") == "http":
self.assertEqual(True, "(401 Unauthorized)" in str(context.exception))
else:
self.assertEqual(
'"There was a problem with authentication"', str(context.exception)
)

def test_login_wrong_username(self):
with self.assertRaises(RuntimeError) as context:
asyncio.run(self.login("wrong", "root"))

if os.environ.get("CONNECTION_PROTOCOL", "http") == "http":
self.assertEqual(True, "(401 Unauthorized)" in str(context.exception))
else:
self.assertEqual(
'"There was a problem with authentication"', str(context.exception)
)
await self.db.use(self.params.namespace, self.params.database)

async def asyncTearDown(self):
await self.db.close()

async def test_login_success(self):
outcome = await self.db.sign_in("root", "root")
self.assertNotEqual(None, outcome)

async def test_login_wrong_password(self):
with self.assertRaises(SurrealDbConnectionError) as context:
await self.db.sign_in("root", "wrong")

self.assertEqual(True, "There was a problem with authentication" in str(context.exception))

async def test_login_wrong_username(self):
with self.assertRaises(SurrealDbConnectionError) as context:
await self.db.sign_in("wrong", "root")

self.assertEqual(True, "There was a problem with authentication" in str(context.exception))


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit bb294eb

Please sign in to comment.