Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support to secondary tables relationships #218

Open
wants to merge 29 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
de06188
first fix versoin, working only if the items has the same id
Ckk3 Nov 19, 2024
0cd732d
bring back the first version, still missin the different ids logic!
Ckk3 Nov 19, 2024
401cd65
fix: now query can pickup related_model and self_model id
Ckk3 Nov 22, 2024
6f644e3
fix: not working with different ids
Ckk3 Nov 22, 2024
ca9bc1c
add nes tests
Ckk3 Nov 23, 2024
eb852ce
add tests
Ckk3 Nov 23, 2024
5770379
Fix mypy erros, still missing some tests
Ckk3 Nov 24, 2024
be77996
update code to work with sqlalchemy 1.4
Ckk3 Nov 24, 2024
fb6a580
remove old code that only works with sqlalchemy 2
Ckk3 Nov 24, 2024
0fb61bb
add seconday tables tests in test_loader
Ckk3 Nov 24, 2024
03a5438
add new tests to loadar and start mapper tests
Ckk3 Nov 26, 2024
a575650
add mapper tests
Ckk3 Nov 28, 2024
beaa3f9
refactor conftest
Ckk3 Nov 30, 2024
8a65328
refactor test_loader
Ckk3 Nov 30, 2024
9d76061
refactor test_mapper
Ckk3 Nov 30, 2024
91c24c5
run autopep
Ckk3 Nov 30, 2024
1cd8df4
run autopep
Ckk3 Nov 30, 2024
e96f179
separate test
Ckk3 Nov 30, 2024
4b6516b
fix lint
Ckk3 Nov 30, 2024
9b079d4
add release file
Ckk3 Nov 30, 2024
4baa7ae
refactor tests
Ckk3 Nov 30, 2024
33d7758
refactor loader
Ckk3 Nov 30, 2024
2a53474
fix release
Ckk3 Nov 30, 2024
d04af46
update pre-commit to work with python 3.8
Ckk3 Jan 26, 2025
3f7f13d
update loader.py
Ckk3 Jan 26, 2025
ff3e419
updated mapper
Ckk3 Jan 26, 2025
6752231
fix lint
Ckk3 Jan 26, 2025
0cd68d2
remote autopep8 from dev container because it give problems when work…
Ckk3 Jan 26, 2025
0745c64
fix lint
Ckk3 Jan 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
update loader.py
Ckk3 committed Jan 26, 2025

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
commit 3f7f13db25f4a4111df22a75194d7380fe0d866e
85 changes: 51 additions & 34 deletions src/strawberry_sqlalchemy_mapper/loader.py
Original file line number Diff line number Diff line change
@@ -11,14 +11,15 @@
Tuple,
Union,
)
from strawberry_sqlalchemy_mapper.exc import InvalidLocalRemotePairs

from sqlalchemy import select, tuple_
from sqlalchemy.engine.base import Connection
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession
from sqlalchemy.orm import RelationshipProperty, Session
from strawberry.dataloader import DataLoader

from strawberry_sqlalchemy_mapper.exc import InvalidLocalRemotePairs


class StrawberrySQLAlchemyLoader:
"""
@@ -46,17 +47,22 @@ def __init__(
"One of bind or async_bind_factory must be set for loader to function properly."
)

async def _scalars_all(self, *args, disabled_optimization_to_secondary_tables=False, **kwargs):
async def _scalars_all(self, *args, query_secondary_tables=False, **kwargs):
# query_secondary_tables explanation:
# We need to retrieve values from both the self_model and related_model.
# To achieve this, we must disable the default SQLAlchemy optimization
# that returns only related_model values.
# This is necessary because we use the keys variable
# to match both related_model and self_model.
if self._async_bind_factory:
async with self._async_bind_factory() as bind:
if disabled_optimization_to_secondary_tables is True:
if query_secondary_tables:
return (await bind.execute(*args, **kwargs)).all()
return (await bind.scalars(*args, **kwargs)).all()
else:
assert self._bind is not None
if disabled_optimization_to_secondary_tables is True:
return self._bind.execute(*args, **kwargs).all()
return self._bind.scalars(*args, **kwargs).all()
assert self._bind is not None
if query_secondary_tables:
return self._bind.execute(*args, **kwargs).all()
return self._bind.scalars(*args, **kwargs).all()

def loader_for(self, relationship: RelationshipProperty) -> DataLoader:
"""
@@ -71,27 +77,33 @@ async def load_fn(keys: List[Tuple]) -> List[Any]:
def _build_normal_relationship_query(related_model, relationship, keys):
return select(related_model).filter(
tuple_(
*[remote for _, remote in relationship.local_remote_pairs or []]
*[
remote
for _, remote in relationship.local_remote_pairs or []
]
).in_(keys)
)

def _build_relationship_with_secondary_table_query(related_model, relationship, keys):

def _build_relationship_with_secondary_table_query(
related_model, relationship, keys
):
# Use another query when relationship uses a secondary table
self_model = relationship.parent.entity

if not relationship.local_remote_pairs:
raise InvalidLocalRemotePairs(
f"{related_model.__name__} -- {self_model.__name__}")
f"{related_model.__name__} -- {self_model.__name__}"
)

self_model_key_label = str(
relationship.local_remote_pairs[0][1].key)
relationship.local_remote_pairs[0][1].key
)
related_model_key_label = str(
relationship.local_remote_pairs[1][1].key)
relationship.local_remote_pairs[1][1].key
)

self_model_key = str(
relationship.local_remote_pairs[0][0].key)
related_model_key = str(
relationship.local_remote_pairs[1][0].key)
self_model_key = str(relationship.local_remote_pairs[0][0].key)
related_model_key = str(relationship.local_remote_pairs[1][0].key)

remote_to_use = relationship.local_remote_pairs[0][1]
query_keys = tuple([item[0] for item in keys])
@@ -100,35 +112,41 @@ def _build_relationship_with_secondary_table_query(related_model, relationship,
return (
select(
getattr(self_model, self_model_key).label(
self_model_key_label),
related_model
self_model_key_label
),
related_model,
)
.join(
relationship.secondary,
getattr(relationship.secondary.c,
related_model_key_label) == getattr(related_model, related_model_key)
getattr(relationship.secondary.c, related_model_key_label)
== getattr(related_model, related_model_key),
)
.join(
self_model,
getattr(relationship.secondary.c,
self_model_key_label) == getattr(self_model, self_model_key)
)
.filter(
remote_to_use.in_(query_keys)
getattr(relationship.secondary.c, self_model_key_label)
== getattr(self_model, self_model_key),
)
.filter(remote_to_use.in_(query_keys))
)

def _build_query(*args):
return _build_normal_relationship_query(*args) if relationship.secondary is None else _build_relationship_with_secondary_table_query(*args)

query = _build_query(related_model, relationship, keys)
query = (
_build_normal_relationship_query(related_model, relationship, keys)
if relationship.secondary is None
else _build_relationship_with_secondary_table_query(
related_model, relationship, keys
)
)

if relationship.order_by:
query = query.order_by(*relationship.order_by)

if relationship.secondary is not None:
# We need to retrieve values from both the self_model and related_model. To achieve this, we must disable the default SQLAlchemy optimization that returns only related_model values. This is necessary because we use the keys variable to match both related_model and self_model.
rows = await self._scalars_all(query, disabled_optimization_to_secondary_tables=True)
# We need to retrieve values from both the self_model and related_model.
# To achieve this, we must disable the default SQLAlchemy optimization
# that returns only related_model values.
# This is necessary because we use the keys variable
# to match both related_model and self_model.
rows = await self._scalars_all(query, query_secondary_tables=True)
else:
rows = await self._scalars_all(query)

@@ -159,4 +177,3 @@ def group_by_remote_key(row: Any) -> Tuple:

self._loaders[relationship] = DataLoader(load_fn=load_fn)
return self._loaders[relationship]