Skip to content

Add support to secondary tables relationships #218

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 29 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
de06188
first fix versoin, working only if the items has the same id
Ckk3 Nov 19, 2024
0cd732d
bring back the first version, still missin the different ids logic!
Ckk3 Nov 19, 2024
401cd65
fix: now query can pickup related_model and self_model id
Ckk3 Nov 22, 2024
6f644e3
fix: not working with different ids
Ckk3 Nov 22, 2024
ca9bc1c
add nes tests
Ckk3 Nov 23, 2024
eb852ce
add tests
Ckk3 Nov 23, 2024
5770379
Fix mypy erros, still missing some tests
Ckk3 Nov 24, 2024
be77996
update code to work with sqlalchemy 1.4
Ckk3 Nov 24, 2024
fb6a580
remove old code that only works with sqlalchemy 2
Ckk3 Nov 24, 2024
0fb61bb
add seconday tables tests in test_loader
Ckk3 Nov 24, 2024
03a5438
add new tests to loadar and start mapper tests
Ckk3 Nov 26, 2024
a575650
add mapper tests
Ckk3 Nov 28, 2024
beaa3f9
refactor conftest
Ckk3 Nov 30, 2024
8a65328
refactor test_loader
Ckk3 Nov 30, 2024
9d76061
refactor test_mapper
Ckk3 Nov 30, 2024
91c24c5
run autopep
Ckk3 Nov 30, 2024
1cd8df4
run autopep
Ckk3 Nov 30, 2024
e96f179
separate test
Ckk3 Nov 30, 2024
4b6516b
fix lint
Ckk3 Nov 30, 2024
9b079d4
add release file
Ckk3 Nov 30, 2024
4baa7ae
refactor tests
Ckk3 Nov 30, 2024
33d7758
refactor loader
Ckk3 Nov 30, 2024
2a53474
fix release
Ckk3 Nov 30, 2024
d04af46
update pre-commit to work with python 3.8
Ckk3 Jan 26, 2025
3f7f13d
update loader.py
Ckk3 Jan 26, 2025
ff3e419
updated mapper
Ckk3 Jan 26, 2025
6752231
fix lint
Ckk3 Jan 26, 2025
0cd68d2
remote autopep8 from dev container because it give problems when work…
Ckk3 Jan 26, 2025
0745c64
fix lint
Ckk3 Jan 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
"python.pythonPath": "/usr/local/bin/python",
"python.linting.enabled": true,
"python.linting.pylintEnabled": true,
"python.formatting.autopep8Path": "/usr/local/py-utils/bin/autopep8",
"python.formatting.blackPath": "/usr/local/py-utils/bin/black",
"python.formatting.yapfPath": "/usr/local/py-utils/bin/yapf",
"python.linting.banditPath": "/usr/local/py-utils/bin/bandit",
Expand Down
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
repos:
- repo: https://github.com/psf/black
rev: 24.4.2
rev: 24.8.0 # Do not update this repository; it is pinned for compatibility with python 3.8
hooks:
- id: black
exclude: ^tests/\w+/snapshots/

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.5
rev: v0.9.3
hooks:
- id: ruff
exclude: ^tests/\w+/snapshots/
Expand All @@ -24,7 +24,7 @@ repos:
files: '^docs/.*\.mdx?$'

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
rev: v5.0.0
hooks:
- id: trailing-whitespace
- id: check-merge-conflict
Expand All @@ -33,7 +33,7 @@ repos:
- id: check-toml

- repo: https://github.com/adamchainz/blacken-docs
rev: 1.16.0
rev: 1.18.0 # Do not update this repository; it is pinned for compatibility with python 3.8
hooks:
- id: blacken-docs
args: [--skip-errors]
3 changes: 3 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Release type: patch

Add support for secondary table relationships in SQLAlchemy mapper, addressing a bug and enhancing the loader to handle these relationships efficiently.
8 changes: 8 additions & 0 deletions src/strawberry_sqlalchemy_mapper/exc.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,11 @@ def __init__(self, model):
f"Model `{model}` is not polymorphic or is not the base model of its "
+ "inheritance chain, and thus cannot be used as an interface."
)


class InvalidLocalRemotePairs(Exception):
def __init__(self, relationship_name):
super().__init__(
f"The `local_remote_pairs` for the relationship `{relationship_name}` is invalid or missing. "
+ "This is likely an issue with the library. Please report this error to the maintainers."
)
105 changes: 94 additions & 11 deletions src/strawberry_sqlalchemy_mapper/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from sqlalchemy.orm import RelationshipProperty, Session
from strawberry.dataloader import DataLoader

from strawberry_sqlalchemy_mapper.exc import InvalidLocalRemotePairs


class StrawberrySQLAlchemyLoader:
"""
Expand Down Expand Up @@ -45,13 +47,22 @@ def __init__(
"One of bind or async_bind_factory must be set for loader to function properly."
)

async def _scalars_all(self, *args, **kwargs):
async def _scalars_all(self, *args, query_secondary_tables=False, **kwargs):
# query_secondary_tables explanation:
# We need to retrieve values from both the self_model and related_model.
# To achieve this, we must disable the default SQLAlchemy optimization
# that returns only related_model values.
# This is necessary because we use the keys variable
# to match both related_model and self_model.
if self._async_bind_factory:
async with self._async_bind_factory() as bind:
if query_secondary_tables:
return (await bind.execute(*args, **kwargs)).all()
return (await bind.scalars(*args, **kwargs)).all()
else:
assert self._bind is not None
return self._bind.scalars(*args, **kwargs).all()
assert self._bind is not None
if query_secondary_tables:
return self._bind.execute(*args, **kwargs).all()
return self._bind.scalars(*args, **kwargs).all()

def loader_for(self, relationship: RelationshipProperty) -> DataLoader:
"""
Expand All @@ -63,14 +74,81 @@ def loader_for(self, relationship: RelationshipProperty) -> DataLoader:
related_model = relationship.entity.entity

async def load_fn(keys: List[Tuple]) -> List[Any]:
query = select(related_model).filter(
tuple_(
*[remote for _, remote in relationship.local_remote_pairs or []]
).in_(keys)
def _build_normal_relationship_query(related_model, relationship, keys):
return select(related_model).filter(
tuple_(
*[
remote
for _, remote in relationship.local_remote_pairs or []
]
).in_(keys)
)

def _build_relationship_with_secondary_table_query(
related_model, relationship, keys
):
# Use another query when relationship uses a secondary table
self_model = relationship.parent.entity

if not relationship.local_remote_pairs:
raise InvalidLocalRemotePairs(
f"{related_model.__name__} -- {self_model.__name__}"
)

self_model_key_label = str(
relationship.local_remote_pairs[0][1].key
)
related_model_key_label = str(
relationship.local_remote_pairs[1][1].key
)

self_model_key = str(relationship.local_remote_pairs[0][0].key)
related_model_key = str(relationship.local_remote_pairs[1][0].key)

remote_to_use = relationship.local_remote_pairs[0][1]
query_keys = tuple([item[0] for item in keys])

# This query returns rows in this format -> (self_model.key, related_model)
return (
select(
getattr(self_model, self_model_key).label(
self_model_key_label
),
related_model,
)
.join(
relationship.secondary,
getattr(relationship.secondary.c, related_model_key_label)
== getattr(related_model, related_model_key),
)
.join(
self_model,
getattr(relationship.secondary.c, self_model_key_label)
== getattr(self_model, self_model_key),
)
.filter(remote_to_use.in_(query_keys))
)

query = (
_build_normal_relationship_query(related_model, relationship, keys)
if relationship.secondary is None
else _build_relationship_with_secondary_table_query(
related_model, relationship, keys
)
)

if relationship.order_by:
query = query.order_by(*relationship.order_by)
rows = await self._scalars_all(query)

if relationship.secondary is not None:
# We need to retrieve values from both the self_model and related_model.
# To achieve this, we must disable the default SQLAlchemy optimization
# that returns only related_model values.
# This is necessary because we use the keys variable
# to match both related_model and self_model.
rows = await self._scalars_all(query, query_secondary_tables=True)
else:
rows = await self._scalars_all(query)

def group_by_remote_key(row: Any) -> Tuple:
return tuple(
Expand All @@ -82,8 +160,13 @@ def group_by_remote_key(row: Any) -> Tuple:
)

grouped_keys: Mapping[Tuple, List[Any]] = defaultdict(list)
for row in rows:
grouped_keys[group_by_remote_key(row)].append(row)
if relationship.secondary is None:
for row in rows:
grouped_keys[group_by_remote_key(row)].append(row)
else:
for row in rows:
grouped_keys[(row[0],)].append(row[1])

if relationship.uselist:
return [grouped_keys[key] for key in keys]
else:
Expand Down
35 changes: 28 additions & 7 deletions src/strawberry_sqlalchemy_mapper/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
from strawberry_sqlalchemy_mapper.exc import (
HybridPropertyNotAnnotated,
InterfaceModelNotPolymorphic,
InvalidLocalRemotePairs,
UnsupportedAssociationProxyTarget,
UnsupportedColumnType,
UnsupportedDescriptorType,
Expand Down Expand Up @@ -387,7 +388,7 @@ def _convert_relationship_to_strawberry_type(
if relationship.uselist:
# Use list if excluding relay pagination
if use_list:
return List[ForwardRef(type_name)] # type: ignore
return List[ForwardRef(type_name)] # type: ignore

return self._connection_type_for(type_name)
else:
Expand Down Expand Up @@ -500,13 +501,30 @@ async def resolve(self, info: Info):
if relationship.key not in instance_state.unloaded:
related_objects = getattr(self, relationship.key)
else:
relationship_key = tuple(
[
if relationship.secondary is None:
relationship_key = tuple(
getattr(self, local.key)
for local, _ in relationship.local_remote_pairs or []
if local.key
]
)
)
else:
# If has a secondary table, gets only the first ID as additional IDs require a separate query
if not relationship.local_remote_pairs:
raise InvalidLocalRemotePairs(
f"{relationship.entity.entity.__name__} -- {relationship.parent.entity.__name__}"
)

local_remote_pairs_secondary_table_local = (
relationship.local_remote_pairs[0][0]
)
relationship_key = tuple(
[
getattr(
self, str(local_remote_pairs_secondary_table_local.key)
),
]
)

if any(item is None for item in relationship_key):
if relationship.uselist:
return []
Expand Down Expand Up @@ -536,7 +554,9 @@ def connection_resolver_for(
if relationship.uselist and not use_list:
return self.make_connection_wrapper_resolver(
relationship_resolver,
self.model_to_type_or_interface_name(relationship.entity.entity), # type: ignore[arg-type]
self.model_to_type_or_interface_name(
relationship.entity.entity # type: ignore[arg-type]
),
)
else:
return relationship_resolver
Expand Down Expand Up @@ -785,7 +805,8 @@ def convert(type_: Any) -> Any:
# ignore inherited `is_type_of`
if "is_type_of" not in type_.__dict__:
type_.is_type_of = (
lambda obj, info: type(obj) == model or type(obj) == type_
lambda obj, info: type(obj) == model # noqa: E721
or type(obj) == type_ # noqa: E721
)

# Default querying methods for relay
Expand Down
Loading
Loading