Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement set variables for http query #122

Merged
merged 2 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions surrealdb/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ class RequestData:

class Connection:
_queues: Dict[int, dict]
_namespace: str | None
_database: str | None
_auth_token: str | None
_namespace: str | None = None
_database: str | None = None
_auth_token: str | None = None

def __init__(
self,
Expand Down
19 changes: 16 additions & 3 deletions surrealdb/connection_http.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import threading
from typing import Any
from typing import Any, Tuple

import requests

Expand All @@ -8,7 +8,7 @@


class HTTPConnection(Connection):
_request_variables: dict[str, Any]
_request_variables: dict[str, Any] = {}
_request_variables_lock = threading.Lock()

async def use(self, namespace: str, database: str) -> None:
Expand All @@ -21,7 +21,8 @@ async def set(self, key: str, value):

async def unset(self, key: str):
with self._request_variables_lock:
del self._request_variables[key]
if self._request_variables.get(key) is not None:
del self._request_variables[key]

async def connect(self) -> None:
if self._base_url is None:
Expand All @@ -34,6 +35,15 @@ async def connect(self) -> None:
"connection failed. check server is up and base url is correct"
)

def _prepare_query_method_params(self, params: Tuple) -> Tuple:
query, variables = params
variables = (
{**variables, **self._request_variables}
if variables
else self._request_variables.copy()
)
return query, variables

async def _make_request(self, request_data: RequestData):
if self._namespace is None:
raise SurrealDbConnectionError("namespace not set")
Expand All @@ -51,6 +61,9 @@ async def _make_request(self, request_data: RequestData):
if self._auth_token is not None:
headers["Authorization"] = f"Bearer {self._auth_token}"

if request_data.method.lower() == "query":
request_data.params = self._prepare_query_method_params(request_data.params)

request_payload = self._encoder(
{
"id": request_data.id,
Expand Down
33 changes: 33 additions & 0 deletions tests/unit/test_http_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,36 @@ async def asyncSetUp(self):
async def test_send(self):
await self.http_con.use('test', 'test')
_ = await self.http_con.send('signin', {'user': 'root', 'pass': 'root'})

async def test_prepare_query_params(self):
query_params = ("SOME SQL QUERY;", {
"key1": "key1"
})
await self.http_con.set("key2", "key2")
await self.http_con.set("key3", "key3")

params = self.http_con._prepare_query_method_params(query_params)
self.assertEqual(query_params[0], params[0])
self.assertEqual({
"key1": "key1",
"key2": "key2",
"key3": "key3",
}, params[1])

await self.http_con.unset("key3")

params = self.http_con._prepare_query_method_params(query_params)
self.assertEqual(query_params[0], params[0])
self.assertEqual({
"key1": "key1",
"key2": "key2",
}, params[1])

await self.http_con.unset("key1") # variable key not part of prev set variables

params = self.http_con._prepare_query_method_params(query_params)
self.assertEqual(query_params[0], params[0])
self.assertEqual({
"key1": "key1",
"key2": "key2",
}, params[1])