diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index f7a2e74..b9dc5f5 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -10,7 +10,6 @@ "python.pythonPath": "/usr/local/bin/python", "python.linting.enabled": true, "python.linting.pylintEnabled": true, - "python.formatting.autopep8Path": "/usr/local/py-utils/bin/autopep8", "python.formatting.blackPath": "/usr/local/py-utils/bin/black", "python.formatting.yapfPath": "/usr/local/py-utils/bin/yapf", "python.linting.banditPath": "/usr/local/py-utils/bin/bandit", diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0bc9e71..9c48c96 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,12 +1,12 @@ repos: - repo: https://github.com/psf/black - rev: 24.4.2 + rev: 24.8.0 # Do not update this repository; it is pinned for compatibility with python 3.8 hooks: - id: black exclude: ^tests/\w+/snapshots/ - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.4.5 + rev: v0.9.3 hooks: - id: ruff exclude: ^tests/\w+/snapshots/ @@ -24,7 +24,7 @@ repos: files: '^docs/.*\.mdx?$' - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: trailing-whitespace - id: check-merge-conflict @@ -33,7 +33,7 @@ repos: - id: check-toml - repo: https://github.com/adamchainz/blacken-docs - rev: 1.16.0 + rev: 1.18.0 # Do not update this repository; it is pinned for compatibility with python 3.8 hooks: - id: blacken-docs args: [--skip-errors] diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000..e962e56 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,3 @@ +Release type: patch + +Add support for secondary table relationships in SQLAlchemy mapper, addressing a bug and enhancing the loader to handle these relationships efficiently. \ No newline at end of file diff --git a/src/strawberry_sqlalchemy_mapper/exc.py b/src/strawberry_sqlalchemy_mapper/exc.py index df4c8f1..eb8388e 100644 --- a/src/strawberry_sqlalchemy_mapper/exc.py +++ b/src/strawberry_sqlalchemy_mapper/exc.py @@ -36,3 +36,11 @@ def __init__(self, model): f"Model `{model}` is not polymorphic or is not the base model of its " + "inheritance chain, and thus cannot be used as an interface." ) + + +class InvalidLocalRemotePairs(Exception): + def __init__(self, relationship_name): + super().__init__( + f"The `local_remote_pairs` for the relationship `{relationship_name}` is invalid or missing. " + + "This is likely an issue with the library. Please report this error to the maintainers." + ) diff --git a/src/strawberry_sqlalchemy_mapper/loader.py b/src/strawberry_sqlalchemy_mapper/loader.py index 40047e0..b05f160 100644 --- a/src/strawberry_sqlalchemy_mapper/loader.py +++ b/src/strawberry_sqlalchemy_mapper/loader.py @@ -18,6 +18,8 @@ from sqlalchemy.orm import RelationshipProperty, Session from strawberry.dataloader import DataLoader +from strawberry_sqlalchemy_mapper.exc import InvalidLocalRemotePairs + class StrawberrySQLAlchemyLoader: """ @@ -45,13 +47,22 @@ def __init__( "One of bind or async_bind_factory must be set for loader to function properly." ) - async def _scalars_all(self, *args, **kwargs): + async def _scalars_all(self, *args, query_secondary_tables=False, **kwargs): + # query_secondary_tables explanation: + # We need to retrieve values from both the self_model and related_model. + # To achieve this, we must disable the default SQLAlchemy optimization + # that returns only related_model values. + # This is necessary because we use the keys variable + # to match both related_model and self_model. if self._async_bind_factory: async with self._async_bind_factory() as bind: + if query_secondary_tables: + return (await bind.execute(*args, **kwargs)).all() return (await bind.scalars(*args, **kwargs)).all() - else: - assert self._bind is not None - return self._bind.scalars(*args, **kwargs).all() + assert self._bind is not None + if query_secondary_tables: + return self._bind.execute(*args, **kwargs).all() + return self._bind.scalars(*args, **kwargs).all() def loader_for(self, relationship: RelationshipProperty) -> DataLoader: """ @@ -63,14 +74,81 @@ def loader_for(self, relationship: RelationshipProperty) -> DataLoader: related_model = relationship.entity.entity async def load_fn(keys: List[Tuple]) -> List[Any]: - query = select(related_model).filter( - tuple_( - *[remote for _, remote in relationship.local_remote_pairs or []] - ).in_(keys) + def _build_normal_relationship_query(related_model, relationship, keys): + return select(related_model).filter( + tuple_( + *[ + remote + for _, remote in relationship.local_remote_pairs or [] + ] + ).in_(keys) + ) + + def _build_relationship_with_secondary_table_query( + related_model, relationship, keys + ): + # Use another query when relationship uses a secondary table + self_model = relationship.parent.entity + + if not relationship.local_remote_pairs: + raise InvalidLocalRemotePairs( + f"{related_model.__name__} -- {self_model.__name__}" + ) + + self_model_key_label = str( + relationship.local_remote_pairs[0][1].key + ) + related_model_key_label = str( + relationship.local_remote_pairs[1][1].key + ) + + self_model_key = str(relationship.local_remote_pairs[0][0].key) + related_model_key = str(relationship.local_remote_pairs[1][0].key) + + remote_to_use = relationship.local_remote_pairs[0][1] + query_keys = tuple([item[0] for item in keys]) + + # This query returns rows in this format -> (self_model.key, related_model) + return ( + select( + getattr(self_model, self_model_key).label( + self_model_key_label + ), + related_model, + ) + .join( + relationship.secondary, + getattr(relationship.secondary.c, related_model_key_label) + == getattr(related_model, related_model_key), + ) + .join( + self_model, + getattr(relationship.secondary.c, self_model_key_label) + == getattr(self_model, self_model_key), + ) + .filter(remote_to_use.in_(query_keys)) + ) + + query = ( + _build_normal_relationship_query(related_model, relationship, keys) + if relationship.secondary is None + else _build_relationship_with_secondary_table_query( + related_model, relationship, keys + ) ) + if relationship.order_by: query = query.order_by(*relationship.order_by) - rows = await self._scalars_all(query) + + if relationship.secondary is not None: + # We need to retrieve values from both the self_model and related_model. + # To achieve this, we must disable the default SQLAlchemy optimization + # that returns only related_model values. + # This is necessary because we use the keys variable + # to match both related_model and self_model. + rows = await self._scalars_all(query, query_secondary_tables=True) + else: + rows = await self._scalars_all(query) def group_by_remote_key(row: Any) -> Tuple: return tuple( @@ -82,8 +160,13 @@ def group_by_remote_key(row: Any) -> Tuple: ) grouped_keys: Mapping[Tuple, List[Any]] = defaultdict(list) - for row in rows: - grouped_keys[group_by_remote_key(row)].append(row) + if relationship.secondary is None: + for row in rows: + grouped_keys[group_by_remote_key(row)].append(row) + else: + for row in rows: + grouped_keys[(row[0],)].append(row[1]) + if relationship.uselist: return [grouped_keys[key] for key in keys] else: diff --git a/src/strawberry_sqlalchemy_mapper/mapper.py b/src/strawberry_sqlalchemy_mapper/mapper.py index 1d8a888..20bf28c 100644 --- a/src/strawberry_sqlalchemy_mapper/mapper.py +++ b/src/strawberry_sqlalchemy_mapper/mapper.py @@ -82,6 +82,7 @@ from strawberry_sqlalchemy_mapper.exc import ( HybridPropertyNotAnnotated, InterfaceModelNotPolymorphic, + InvalidLocalRemotePairs, UnsupportedAssociationProxyTarget, UnsupportedColumnType, UnsupportedDescriptorType, @@ -387,7 +388,7 @@ def _convert_relationship_to_strawberry_type( if relationship.uselist: # Use list if excluding relay pagination if use_list: - return List[ForwardRef(type_name)] # type: ignore + return List[ForwardRef(type_name)] # type: ignore return self._connection_type_for(type_name) else: @@ -500,13 +501,30 @@ async def resolve(self, info: Info): if relationship.key not in instance_state.unloaded: related_objects = getattr(self, relationship.key) else: - relationship_key = tuple( - [ + if relationship.secondary is None: + relationship_key = tuple( getattr(self, local.key) for local, _ in relationship.local_remote_pairs or [] if local.key - ] - ) + ) + else: + # If has a secondary table, gets only the first ID as additional IDs require a separate query + if not relationship.local_remote_pairs: + raise InvalidLocalRemotePairs( + f"{relationship.entity.entity.__name__} -- {relationship.parent.entity.__name__}" + ) + + local_remote_pairs_secondary_table_local = ( + relationship.local_remote_pairs[0][0] + ) + relationship_key = tuple( + [ + getattr( + self, str(local_remote_pairs_secondary_table_local.key) + ), + ] + ) + if any(item is None for item in relationship_key): if relationship.uselist: return [] @@ -536,7 +554,9 @@ def connection_resolver_for( if relationship.uselist and not use_list: return self.make_connection_wrapper_resolver( relationship_resolver, - self.model_to_type_or_interface_name(relationship.entity.entity), # type: ignore[arg-type] + self.model_to_type_or_interface_name( + relationship.entity.entity # type: ignore[arg-type] + ), ) else: return relationship_resolver @@ -785,7 +805,8 @@ def convert(type_: Any) -> Any: # ignore inherited `is_type_of` if "is_type_of" not in type_.__dict__: type_.is_type_of = ( - lambda obj, info: type(obj) == model or type(obj) == type_ + lambda obj, info: type(obj) == model # noqa: E721 + or type(obj) == type_ # noqa: E721 ) # Default querying methods for relay diff --git a/tests/conftest.py b/tests/conftest.py index 7c600f3..1d23424 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -111,3 +111,463 @@ def async_sessionmaker(async_engine): @pytest.fixture def base(): return orm.declarative_base() + + +@pytest.fixture +def default_employee_department_join_table(base): + EmployeeDepartmentJoinTable = sqlalchemy.Table( + "employee_department_join_table", + base.metadata, + sqlalchemy.Column("employee_id", sqlalchemy.ForeignKey("employee.id"), primary_key=True), + sqlalchemy.Column("department_id", sqlalchemy.ForeignKey("department.id"), primary_key=True), + ) + + +@pytest.fixture +def secondary_tables(base, default_employee_department_join_table): + class Employee(base): + __tablename__ = "employee" + id = sqlalchemy.Column(sqlalchemy.Integer, autoincrement=True, primary_key=True, nullable=False) + name = sqlalchemy.Column(sqlalchemy.String, nullable=False) + role = sqlalchemy.Column(sqlalchemy.String, nullable=True) + department = orm.relationship( + "Department", + secondary="employee_department_join_table", + back_populates="employees", + ) + + class Department(base): + __tablename__ = "department" + id = sqlalchemy.Column(sqlalchemy.Integer, autoincrement=True, primary_key=True) + name = sqlalchemy.Column(sqlalchemy.String, nullable=True) + employees = orm.relationship( + "Employee", + secondary="employee_department_join_table", + back_populates="department", + ) + + return Employee, Department + + +@pytest.fixture +def secondary_tables_with_another_foreign_key(base): + EmployeeDepartmentJoinTable = sqlalchemy.Table( + "employee_department_join_table", + base.metadata, + sqlalchemy.Column("employee_name", sqlalchemy.ForeignKey("employee.name"), primary_key=True), + sqlalchemy.Column("department_id", sqlalchemy.ForeignKey( + "department.id"), primary_key=True), + ) + + class Employee(base): + __tablename__ = "employee" + id = sqlalchemy.Column(sqlalchemy.Integer, autoincrement=True, nullable=False) + name = sqlalchemy.Column(sqlalchemy.String, nullable=False, primary_key=True) + role = sqlalchemy.Column(sqlalchemy.String, nullable=True) + department = orm.relationship( + "Department", + secondary="employee_department_join_table", + back_populates="employees", + ) + + class Department(base): + __tablename__ = "department" + id = sqlalchemy.Column(sqlalchemy.Integer, autoincrement=True, primary_key=True) + name = sqlalchemy.Column(sqlalchemy.String, nullable=True) + employees = orm.relationship( + "Employee", + secondary="employee_department_join_table", + back_populates="department", + ) + + return Employee, Department + + +@pytest.fixture +def secondary_tables_with_more_secondary_tables(base, default_employee_department_join_table): + EmployeeBuildingJoinTable = sqlalchemy.Table( + "employee_building_join_table", + base.metadata, + sqlalchemy.Column("employee_id", sqlalchemy.ForeignKey("employee.id"), primary_key=True), + sqlalchemy.Column("building_id", sqlalchemy.ForeignKey("building.id"), primary_key=True), + ) + + class Employee(base): + __tablename__ = "employee" + id = sqlalchemy.Column(sqlalchemy.Integer, autoincrement=True, primary_key=True) + name = sqlalchemy.Column(sqlalchemy.String, nullable=False) + role = sqlalchemy.Column(sqlalchemy.String, nullable=True) + department = orm.relationship( + "Department", + secondary="employee_department_join_table", + back_populates="employees", + ) + building = orm.relationship( + "Building", + secondary="employee_building_join_table", + back_populates="employees", + ) + + class Department(base): + __tablename__ = "department" + id = sqlalchemy.Column(sqlalchemy.Integer, autoincrement=True, primary_key=True) + name = sqlalchemy.Column(sqlalchemy.String, nullable=False) + employees = orm.relationship( + "Employee", + secondary="employee_department_join_table", + back_populates="department", + ) + + class Building(base): + __tablename__ = "building" + id = sqlalchemy.Column(sqlalchemy.Integer, autoincrement=True, primary_key=True) + name = sqlalchemy.Column(sqlalchemy.String, nullable=False) + employees = orm.relationship( + "Employee", + secondary="employee_building_join_table", + back_populates="building", + ) + + return Employee, Department, Building + + +@pytest.fixture +def secondary_tables_with_use_list_false(base, default_employee_department_join_table): + class Employee(base): + __tablename__ = "employee" + id = sqlalchemy.Column(sqlalchemy.Integer, autoincrement=True, primary_key=True) + name = sqlalchemy.Column(sqlalchemy.String, nullable=False) + role = sqlalchemy.Column(sqlalchemy.String, nullable=True) + department = orm.relationship( + "Department", + secondary="employee_department_join_table", + back_populates="employees", + ) + + class Department(base): + __tablename__ = "department" + id = sqlalchemy.Column(sqlalchemy.Integer, autoincrement=True, primary_key=True) + name = sqlalchemy.Column(sqlalchemy.String, nullable=False) + employees = orm.relationship( + "Employee", + secondary="employee_department_join_table", + back_populates="department", + uselist=False + ) + + return Employee, Department + + +@pytest.fixture +def secondary_tables_with_normal_relationship(base, default_employee_department_join_table): + class Employee(base): + __tablename__ = "employee" + id = sqlalchemy.Column(sqlalchemy.Integer, autoincrement=True, primary_key=True) + name = sqlalchemy.Column(sqlalchemy.String, nullable=False) + role = sqlalchemy.Column(sqlalchemy.String, nullable=True) + department = orm.relationship( + "Department", + secondary="employee_department_join_table", + back_populates="employees", + ) + building_id = sqlalchemy.Column(sqlalchemy.Integer, sqlalchemy.ForeignKey("building.id")) + building = orm.relationship( + "Building", + back_populates="employees", + ) + + class Department(base): + __tablename__ = "department" + id = sqlalchemy.Column(sqlalchemy.Integer, autoincrement=True, primary_key=True) + name = sqlalchemy.Column(sqlalchemy.String, nullable=False) + employees = orm.relationship( + "Employee", + secondary="employee_department_join_table", + back_populates="department", + ) + + class Building(base): + __tablename__ = "building" + id = sqlalchemy.Column(sqlalchemy.Integer, autoincrement=True, primary_key=True) + name = sqlalchemy.Column(sqlalchemy.String, nullable=False) + employees = orm.relationship( + "Employee", + back_populates="building", + ) + + return Employee, Department, Building + + +@pytest.fixture +def expected_schema_from_secondary_tables(): + return ''' + type Department { + id: Int! + name: String + employees: EmployeeConnection! + } + + type DepartmentConnection { + """Pagination data for this connection""" + pageInfo: PageInfo! + edges: [DepartmentEdge!]! + } + + type DepartmentEdge { + """A cursor for use in pagination""" + cursor: String! + + """The item at the end of the edge""" + node: Department! + } + + type Employee { + id: Int! + name: String! + role: String + department: DepartmentConnection! + } + + type EmployeeConnection { + """Pagination data for this connection""" + pageInfo: PageInfo! + edges: [EmployeeEdge!]! + } + + type EmployeeEdge { + """A cursor for use in pagination""" + cursor: String! + + """The item at the end of the edge""" + node: Employee! + } + + """Information to aid in pagination.""" + type PageInfo { + """When paginating forwards, are there more items?""" + hasNextPage: Boolean! + + """When paginating backwards, are there more items?""" + hasPreviousPage: Boolean! + + """When paginating backwards, the cursor to continue.""" + startCursor: String + + """When paginating forwards, the cursor to continue.""" + endCursor: String + } + + type Query { + departments: [Department!]! + } + ''' + + +@pytest.fixture +def expected_schema_from_secondary_tables_with_more_secondary_tables(): + return ''' + type Building { + id: Int! + name: String! + employees: EmployeeConnection! + } + + type BuildingConnection { + """Pagination data for this connection""" + pageInfo: PageInfo! + edges: [BuildingEdge!]! + } + + type BuildingEdge { + """A cursor for use in pagination""" + cursor: String! + + """The item at the end of the edge""" + node: Building! + } + + type Department { + id: Int! + name: String! + employees: EmployeeConnection! + } + + type DepartmentConnection { + """Pagination data for this connection""" + pageInfo: PageInfo! + edges: [DepartmentEdge!]! + } + + type DepartmentEdge { + """A cursor for use in pagination""" + cursor: String! + + """The item at the end of the edge""" + node: Department! + } + + type Employee { + id: Int! + name: String! + role: String + department: DepartmentConnection! + building: BuildingConnection! + } + + type EmployeeConnection { + """Pagination data for this connection""" + pageInfo: PageInfo! + edges: [EmployeeEdge!]! + } + + type EmployeeEdge { + """A cursor for use in pagination""" + cursor: String! + + """The item at the end of the edge""" + node: Employee! + } + + """Information to aid in pagination.""" + type PageInfo { + """When paginating forwards, are there more items?""" + hasNextPage: Boolean! + + """When paginating backwards, are there more items?""" + hasPreviousPage: Boolean! + + """When paginating backwards, the cursor to continue.""" + startCursor: String + + """When paginating forwards, the cursor to continue.""" + endCursor: String + } + + type Query { + departments: [Department!]! + } + ''' + + +@pytest.fixture +def expected_schema_from_secondary_tables_with_more_secondary_tables_with_use_list_false(): + return ''' + type Department { + id: Int! + name: String! + employees: Employee + } + + type DepartmentConnection { + """Pagination data for this connection""" + pageInfo: PageInfo! + edges: [DepartmentEdge!]! + } + + type DepartmentEdge { + """A cursor for use in pagination""" + cursor: String! + + """The item at the end of the edge""" + node: Department! + } + + type Employee { + id: Int! + name: String! + role: String + department: DepartmentConnection! + } + + """Information to aid in pagination.""" + type PageInfo { + """When paginating forwards, are there more items?""" + hasNextPage: Boolean! + + """When paginating backwards, are there more items?""" + hasPreviousPage: Boolean! + + """When paginating backwards, the cursor to continue.""" + startCursor: String + + """When paginating forwards, the cursor to continue.""" + endCursor: String + } + + type Query { + departments: [Department!]! + } + ''' + + +@pytest.fixture +def expected_schema_from_secondary_tables_with_more_secondary_tables_with__with_normal_relationship(): + return ''' + type Building { + id: Int! + name: String! + employees: EmployeeConnection! + } + + type Department { + id: Int! + name: String! + employees: EmployeeConnection! + } + + type DepartmentConnection { + """Pagination data for this connection""" + pageInfo: PageInfo! + edges: [DepartmentEdge!]! + } + + type DepartmentEdge { + """A cursor for use in pagination""" + cursor: String! + + """The item at the end of the edge""" + node: Department! + } + + type Employee { + id: Int! + name: String! + role: String + buildingId: Int + department: DepartmentConnection! + building: Building + } + + type EmployeeConnection { + """Pagination data for this connection""" + pageInfo: PageInfo! + edges: [EmployeeEdge!]! + } + + type EmployeeEdge { + """A cursor for use in pagination""" + cursor: String! + + """The item at the end of the edge""" + node: Employee! + } + + """Information to aid in pagination.""" + type PageInfo { + """When paginating forwards, are there more items?""" + hasNextPage: Boolean! + + """When paginating backwards, are there more items?""" + hasPreviousPage: Boolean! + + """When paginating backwards, the cursor to continue.""" + startCursor: String + + """When paginating forwards, the cursor to continue.""" + endCursor: String + } + + type Query { + departments: [Department!]! + } + ''' diff --git a/tests/relay/test_connection.py b/tests/relay/test_connection.py index 31160c0..e0251e4 100644 --- a/tests/relay/test_connection.py +++ b/tests/relay/test_connection.py @@ -7,7 +7,7 @@ from sqlalchemy.ext.asyncio.engine import AsyncEngine from sqlalchemy.orm import sessionmaker from strawberry import relay -from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper, connection +from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper, connection, StrawberrySQLAlchemyLoader from strawberry_sqlalchemy_mapper.relay import KeysetConnection @@ -37,7 +37,8 @@ class Fruit(relay.Node): @strawberry.type class Query: - fruits: relay.ListConnection[Fruit] = connection(sessionmaker=sessionmaker) + fruits: relay.ListConnection[Fruit] = connection( + sessionmaker=sessionmaker) schema = strawberry.Schema(query=Query) @@ -74,7 +75,8 @@ class Fruit(relay.Node): @strawberry.type class Query: - fruits: relay.ListConnection[Fruit] = connection(sessionmaker=sessionmaker) + fruits: relay.ListConnection[Fruit] = connection( + sessionmaker=sessionmaker) schema = strawberry.Schema(query=Query) @@ -259,7 +261,8 @@ class Fruit(relay.Node): @strawberry.type class Query: - fruits: relay.ListConnection[Fruit] = connection(sessionmaker=sessionmaker) + fruits: relay.ListConnection[Fruit] = connection( + sessionmaker=sessionmaker) schema = strawberry.Schema(query=Query) @@ -319,7 +322,8 @@ class Fruit(relay.Node): @strawberry.type class Query: - fruits: relay.ListConnection[Fruit] = connection(sessionmaker=sessionmaker) + fruits: relay.ListConnection[Fruit] = connection( + sessionmaker=sessionmaker) schema = strawberry.Schema(query=Query) @@ -381,7 +385,8 @@ class Fruit(relay.Node): @strawberry.type class Query: - fruits: relay.ListConnection[Fruit] = connection(sessionmaker=sessionmaker) + fruits: relay.ListConnection[Fruit] = connection( + sessionmaker=sessionmaker) schema = strawberry.Schema(query=Query) @@ -441,7 +446,8 @@ class Fruit(relay.Node): @strawberry.type class Query: - fruits: relay.ListConnection[Fruit] = connection(sessionmaker=sessionmaker) + fruits: relay.ListConnection[Fruit] = connection( + sessionmaker=sessionmaker) schema = strawberry.Schema(query=Query) @@ -467,7 +473,8 @@ class Query: session.commit() result = schema.execute_sync( - query, {"first": 1, "before": relay.to_base64("arrayconnection", 2)} + query, {"first": 1, "before": relay.to_base64( + "arrayconnection", 2)} ) assert result.errors is None @@ -755,3 +762,162 @@ class Query: }, } } + + +# TODO Investigate this test +@pytest.mark.skip("This test is currently failing because the Query with relay.ListConnection generates two DepartmentConnection, which violates the schema's expectations. After investigation, it appears this issue is related to the Relay implementation rather than the secondary table issue. We'll address this later. Additionally, note that the `result.data` may be incorrect in this test.") +@pytest.mark.asyncio +async def test_query_with_secondary_table_with_values_list( + secondary_tables, + base, + async_engine, + async_sessionmaker +): + async with async_engine.begin() as conn: + await conn.run_sync(base.metadata.create_all) + + mapper = StrawberrySQLAlchemyMapper() + EmployeeModel, DepartmentModel = secondary_tables + + @mapper.type(DepartmentModel) + class Department(): + pass + + @mapper.type(EmployeeModel) + class Employee(): + pass + + @strawberry.type + class Query: + departments: relay.ListConnection[Department] = connection( + sessionmaker=async_sessionmaker) + + mapper.finalize() + schema = strawberry.Schema(query=Query) + + query = """\ + query { + departments { + edges { + node { + id + name + employees { + edges { + node { + id + name + role + department { + edges { + node { + id + name + } + } + } + } + } + } + } + } + } + } + """ + + # Create test data + async with async_sessionmaker(expire_on_commit=False) as session: + department1 = DepartmentModel(id=10, name="Department Test 1") + department2 = DepartmentModel(id=3, name="Department Test 2") + e1 = EmployeeModel(id=1, name="John", role="Developer") + e2 = EmployeeModel(id=5, name="Bill", role="Doctor") + e3 = EmployeeModel(id=4, name="Maria", role="Teacher") + department1.employees.append(e1) + department1.employees.append(e2) + department2.employees.append(e3) + session.add_all([department1, department2, e1, e2, e3]) + await session.commit() + + result = await schema.execute(query, context_value={ + "sqlalchemy_loader": StrawberrySQLAlchemyLoader( + async_bind_factory=async_sessionmaker + ) + }) + assert result.errors is None + assert result.data == { + "departments": { + "edges": [ + { + "node": { + "id": 10, + "name": "Department Test 1", + "employees": { + "edges": [ + { + "node": { + "id": 5, + "name": "Bill", + "role": "Doctor", + "department": { + "edges": [ + { + "node": { + "id": 10, + "name": "Department Test 1" + } + } + ] + } + } + }, + { + "node": { + "id": 1, + "name": "John", + "role": "Developer", + "department": { + "edges": [ + { + "node": { + "id": 10, + "name": "Department Test 1" + } + } + ] + } + } + } + ] + } + } + }, + { + "node": { + "id": 3, + "name": "Department Test 2", + "employees": { + "edges": [ + { + "node": { + "id": 4, + "name": "Maria", + "role": "Teacher", + "department": { + "edges": [ + { + "node": { + "id": 3, + "name": "Department Test 2" + } + } + ] + } + } + } + ] + } + } + } + ] + } + } diff --git a/tests/test_loader.py b/tests/test_loader.py index df33189..6eb404e 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -2,6 +2,7 @@ from sqlalchemy import Column, ForeignKey, Integer, String, Table from sqlalchemy.orm import relationship from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyLoader +from strawberry_sqlalchemy_mapper.exc import InvalidLocalRemotePairs pytest_plugins = ("pytest_asyncio",) @@ -26,38 +27,6 @@ class Department(base): return Employee, Department -@pytest.fixture -def secondary_tables(base): - EmployeeDepartmentJoinTable = Table( - "employee_department_join_table", - base.metadata, - Column("employee_id", ForeignKey("employee.e_id"), primary_key=True), - Column("department_id", ForeignKey("department.d_id"), primary_key=True), - ) - - class Employee(base): - __tablename__ = "employee" - e_id = Column(Integer, autoincrement=True, primary_key=True) - name = Column(String, nullable=False) - departments = relationship( - "Department", - secondary="employee_department_join_table", - back_populates="employees", - ) - - class Department(base): - __tablename__ = "department" - d_id = Column(Integer, autoincrement=True, primary_key=True) - name = Column(String, nullable=False) - employees = relationship( - "Employee", - secondary="employee_department_join_table", - back_populates="departments", - ) - - return Employee, Department - - def test_loader_init(): loader = StrawberrySQLAlchemyLoader(bind=None) assert loader._bind is None @@ -146,36 +115,156 @@ async def test_loader_with_async_session( assert {e.name for e in employees} == {"e1"} -@pytest.mark.xfail +def create_default_data_on_secondary_table_tests(session, Employee, Department): + e1 = Employee(name="e1", id=1) + e2 = Employee(name="e2", id=2) + d1 = Department(name="d1") + d2 = Department(name="d2") + d3 = Department(name="d3") + session.add_all([e1, e2, d1, d2, d3]) + session.flush() + + e1.department.append(d1) + e1.department.append(d2) + e2.department.append(d2) + return e1, e2, d1, d2, d3 + + @pytest.mark.asyncio -async def test_loader_for_secondary(engine, base, sessionmaker, secondary_tables): +async def test_loader_for_secondary_table(engine, base, sessionmaker, secondary_tables): Employee, Department = secondary_tables base.metadata.create_all(engine) with sessionmaker() as session: - e1 = Employee(name="e1") - e2 = Employee(name="e2") - d1 = Department(name="d1") - d2 = Department(name="d2") - session.add(e1) - session.add(e2) - session.add(d1) - session.add(d2) - session.flush() + e1, _, _, _, _ = create_default_data_on_secondary_table_tests(session=session, Employee=Employee, Department=Department) + session.commit() - e1.departments.append(d1) - e1.departments.append(d2) - e2.departments.append(d2) + base_loader = StrawberrySQLAlchemyLoader(bind=session) + loader = base_loader.loader_for(Employee.department.property) + + key = tuple( + [ + getattr( + e1, str(Employee.department.property.local_remote_pairs[0][0].key)), + ] + ) + + departments = await loader.load(key) + assert {d.name for d in departments} == {"d1", "d2"} + + +@pytest.mark.asyncio +async def test_loader_for_secondary_tables_with_another_foreign_key(engine, base, sessionmaker, secondary_tables_with_another_foreign_key): + Employee, Department = secondary_tables_with_another_foreign_key + base.metadata.create_all(engine) + + with sessionmaker() as session: + e1, _, _, _, _ = create_default_data_on_secondary_table_tests(session=session, Employee=Employee, Department=Department) session.commit() base_loader = StrawberrySQLAlchemyLoader(bind=session) - loader = base_loader.loader_for(Employee.departments.property) + loader = base_loader.loader_for(Employee.department.property) key = tuple( [ - getattr(e1, local.key) - for local, _ in Employee.departments.property.local_remote_pairs + getattr( + e1, str(Employee.department.property.local_remote_pairs[0][0].key)), ] ) + departments = await loader.load(key) assert {d.name for d in departments} == {"d1", "d2"} + + +@pytest.mark.asyncio +async def test_loader_for_secondary_tables_with_more_secondary_tables(engine, base, sessionmaker, secondary_tables_with_more_secondary_tables): + Employee, Department, Building = secondary_tables_with_more_secondary_tables + base.metadata.create_all(engine) + + with sessionmaker() as session: + e1, e2, _, _, _ = create_default_data_on_secondary_table_tests(session=session, Employee=Employee, Department=Department) + + b1 = Building(id=2, name="Building 1") + b1.employees.append(e1) + b1.employees.append(e2) + session.add(b1) + session.commit() + + base_loader = StrawberrySQLAlchemyLoader(bind=session) + loader = base_loader.loader_for(Employee.department.property) + + key = tuple( + [ + getattr( + e1, str(Employee.department.property.local_remote_pairs[0][0].key)), + ] + ) + + departments = await loader.load(key) + assert {d.name for d in departments} == {"d1", "d2"} + + +@pytest.mark.asyncio +async def test_loader_for_secondary_tables_with_use_list_false(engine, base, sessionmaker, secondary_tables_with_use_list_false): + Employee, Department = secondary_tables_with_use_list_false + base.metadata.create_all(engine) + + with sessionmaker() as session: + e1, _, _, _, _ = create_default_data_on_secondary_table_tests(session=session, Employee=Employee, Department=Department) + session.commit() + + base_loader = StrawberrySQLAlchemyLoader(bind=session) + loader = base_loader.loader_for(Employee.department.property) + + key = tuple( + [ + getattr( + e1, str(Employee.department.property.local_remote_pairs[0][0].key)), + ] + ) + + departments = await loader.load(key) + assert {d.name for d in departments} == {"d1"} + + +@pytest.mark.asyncio +async def test_loader_for_secondary_tables_with_normal_relationship(engine, base, sessionmaker, secondary_tables_with_normal_relationship): + Employee, Department, Building = secondary_tables_with_normal_relationship + base.metadata.create_all(engine) + + with sessionmaker() as session: + e1, e2, _, _, _ = create_default_data_on_secondary_table_tests(session=session, Employee=Employee, Department=Department) + + b1 = Building(id=2, name="Building 1") + b1.employees.append(e1) + b1.employees.append(e2) + session.add(b1) + session.commit() + + base_loader = StrawberrySQLAlchemyLoader(bind=session) + loader = base_loader.loader_for(Employee.department.property) + + key = tuple( + [ + getattr( + e1, str(Employee.department.property.local_remote_pairs[0][0].key)), + ] + ) + + departments = await loader.load(key) + assert {d.name for d in departments} == {"d1", "d2"} + + +@pytest.mark.asyncio +async def test_loader_for_secondary_tables_should_raise_exception_if_relationship_dont_has_local_remote_pairs(engine, base, sessionmaker, secondary_tables_with_normal_relationship): + Employee, Department, Building = secondary_tables_with_normal_relationship + base.metadata.create_all(engine) + + with sessionmaker() as session: + base_loader = StrawberrySQLAlchemyLoader(bind=session) + + Employee.department.property.local_remote_pairs = [] + loader = base_loader.loader_for(Employee.department.property) + + with pytest.raises(expected_exception=InvalidLocalRemotePairs): + await loader.load((1,)) diff --git a/tests/test_mapper.py b/tests/test_mapper.py index 7875b3a..d75636e 100644 --- a/tests/test_mapper.py +++ b/tests/test_mapper.py @@ -379,3 +379,123 @@ def departments(self) -> Department: ... } ''' assert str(schema) == textwrap.dedent(expected).strip() + + +def test_relationships_schema_with_secondary_tables(secondary_tables, mapper, expected_schema_from_secondary_tables): + EmployeeModel, DepartmentModel = secondary_tables + + @mapper.type(EmployeeModel) + class Employee: + pass + + @mapper.type(DepartmentModel) + class Department: + pass + + @strawberry.type + class Query: + @strawberry.field + def departments(self) -> List[Department]: ... + + mapper.finalize() + schema = strawberry.Schema(query=Query) + + assert str(schema) == textwrap.dedent(expected_schema_from_secondary_tables).strip() + + +def test_relationships_schema_with_secondary_tables_with_another_foreign_key(secondary_tables_with_another_foreign_key, mapper, expected_schema_from_secondary_tables): + EmployeeModel, DepartmentModel = secondary_tables_with_another_foreign_key + + @mapper.type(EmployeeModel) + class Employee: + pass + + @mapper.type(DepartmentModel) + class Department: + pass + + @strawberry.type + class Query: + @strawberry.field + def departments(self) -> List[Department]: ... + + mapper.finalize() + schema = strawberry.Schema(query=Query) + + assert str(schema) == textwrap.dedent(expected_schema_from_secondary_tables).strip() + + +def test_relationships_schema_with_secondary_tables_with_more_secondary_tables(secondary_tables_with_more_secondary_tables, mapper, expected_schema_from_secondary_tables_with_more_secondary_tables): + EmployeeModel, DepartmentModel, BuildingModel = secondary_tables_with_more_secondary_tables + + @mapper.type(EmployeeModel) + class Employee: + pass + + @mapper.type(DepartmentModel) + class Department: + pass + + @mapper.type(BuildingModel) + class Building: + pass + + @strawberry.type + class Query: + @strawberry.field + def departments(self) -> List[Department]: ... + + mapper.finalize() + schema = strawberry.Schema(query=Query) + + assert str(schema) == textwrap.dedent(expected_schema_from_secondary_tables_with_more_secondary_tables).strip() + + +def test_relationships_schema_with_secondary_tables_with_use_list_false(secondary_tables_with_use_list_false, mapper, expected_schema_from_secondary_tables_with_more_secondary_tables_with_use_list_false): + EmployeeModel, DepartmentModel = secondary_tables_with_use_list_false + + @mapper.type(EmployeeModel) + class Employee: + pass + + @mapper.type(DepartmentModel) + class Department: + pass + + + @strawberry.type + class Query: + @strawberry.field + def departments(self) -> List[Department]: ... + + mapper.finalize() + schema = strawberry.Schema(query=Query) + + assert str(schema) == textwrap.dedent(expected_schema_from_secondary_tables_with_more_secondary_tables_with_use_list_false).strip() + + +def test_relationships_schema_with_secondary_tables_with_normal_relationship(secondary_tables_with_normal_relationship, mapper, expected_schema_from_secondary_tables_with_more_secondary_tables_with__with_normal_relationship): + EmployeeModel, DepartmentModel, BuildingModel = secondary_tables_with_normal_relationship + + @mapper.type(EmployeeModel) + class Employee: + pass + + @mapper.type(DepartmentModel) + class Department: + pass + + @mapper.type(BuildingModel) + class Building(): + pass + + + @strawberry.type + class Query: + @strawberry.field + def departments(self) -> List[Department]: ... + + mapper.finalize() + schema = strawberry.Schema(query=Query) + + assert str(schema) == textwrap.dedent(expected_schema_from_secondary_tables_with_more_secondary_tables_with__with_normal_relationship).strip() diff --git a/tests/test_secondary_tables_query.py b/tests/test_secondary_tables_query.py new file mode 100644 index 0000000..72a78c5 --- /dev/null +++ b/tests/test_secondary_tables_query.py @@ -0,0 +1,874 @@ +from typing import List + +import pytest +import strawberry +from sqlalchemy import select +from strawberry import relay +from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper, connection, StrawberrySQLAlchemyLoader + + +@pytest.fixture +def default_query_secondary_table(): + return """ + query { + departments { + id + name + employees { + edges { + node { + id + name + role + department { + edges { + node { + id + name + } + } + } + } + } + } + } + } + """ + + +def created_default_secondary_table_data(session, EmployeeModel, DepartmentModel): + department1 = DepartmentModel(id=10, name="Department Test 1") + department2 = DepartmentModel(id=3, name="Department Test 2") + e1 = EmployeeModel(id=1, name="John", role="Developer") + e2 = EmployeeModel(id=5, name="Bill", role="Doctor") + e3 = EmployeeModel(id=4, name="Maria", role="Teacher") + department1.employees.append(e1) + department1.employees.append(e2) + department2.employees.append(e3) + session.add_all([department1, department2, e1, e2, e3]) + return e1, e2, e3, department1, department2 + + +@pytest.mark.asyncio +async def test_query_with_secondary_table_with_values_list_without_list_connection( + secondary_tables, + base, + async_engine, + async_sessionmaker, + default_query_secondary_table +): + async with async_engine.begin() as conn: + await conn.run_sync(base.metadata.create_all) + + mapper = StrawberrySQLAlchemyMapper() + EmployeeModel, DepartmentModel = secondary_tables + + @mapper.type(DepartmentModel) + class Department(): + pass + + @mapper.type(EmployeeModel) + class Employee(): + pass + + @strawberry.type + class Query: + @strawberry.field + async def departments(self) -> List[Department]: + async with async_sessionmaker() as session: + result = await session.execute(select(DepartmentModel)) + return result.scalars().all() + + mapper.finalize() + schema = strawberry.Schema(query=Query) + + + # Create test data + async with async_sessionmaker(expire_on_commit=False) as session: + created_default_secondary_table_data(session=session, EmployeeModel=EmployeeModel, DepartmentModel=DepartmentModel) + await session.commit() + + result = await schema.execute(default_query_secondary_table, context_value={ + "sqlalchemy_loader": StrawberrySQLAlchemyLoader( + async_bind_factory=async_sessionmaker + ) + }) + assert result.errors is None + assert result.data == { + "departments": [ + { + "id": 10, + "name": "Department Test 1", + "employees": { + "edges": [ + { + "node": { + "id": 5, + "name": "Bill", + "role": "Doctor", + "department": { + "edges": [ + { + "node": { + "id": 10, + "name": "Department Test 1" + } + } + ] + } + } + }, + { + "node": { + "id": 1, + "name": "John", + "role": "Developer", + "department": { + "edges": [ + { + "node": { + "id": 10, + "name": "Department Test 1" + } + } + ] + } + } + } + ] + } + }, + { + "id": 3, + "name": "Department Test 2", + "employees": { + "edges": [ + { + "node": { + "id": 4, + "name": "Maria", + "role": "Teacher", + "department": { + "edges": [ + { + "node": { + "id": 3, + "name": "Department Test 2" + } + } + ] + } + } + } + ] + } + } + ] + } + + +@pytest.mark.asyncio +async def test_query_with_secondary_table_with_values_list_with_foreign_key_different_than_id( + secondary_tables_with_another_foreign_key, + base, + async_engine, + async_sessionmaker, + default_query_secondary_table +): + async with async_engine.begin() as conn: + await conn.run_sync(base.metadata.create_all) + + mapper = StrawberrySQLAlchemyMapper() + EmployeeModel, DepartmentModel = secondary_tables_with_another_foreign_key + + @mapper.type(DepartmentModel) + class Department(): + pass + + @mapper.type(EmployeeModel) + class Employee(): + pass + + @strawberry.type + class Query: + @strawberry.field + async def departments(self) -> List[Department]: + async with async_sessionmaker() as session: + result = await session.execute(select(DepartmentModel)) + return result.scalars().all() + + mapper.finalize() + schema = strawberry.Schema(query=Query) + + # Create test data + async with async_sessionmaker(expire_on_commit=False) as session: + created_default_secondary_table_data(session=session, EmployeeModel=EmployeeModel, DepartmentModel=DepartmentModel) + await session.commit() + + result = await schema.execute(default_query_secondary_table, context_value={ + "sqlalchemy_loader": StrawberrySQLAlchemyLoader( + async_bind_factory=async_sessionmaker + ) + }) + assert result.errors is None + assert result.data == { + "departments": [ + { + "id": 10, + "name": "Department Test 1", + "employees": { + "edges": [ + { + "node": { + "id": 5, + "name": "Bill", + "role": "Doctor", + "department": { + "edges": [ + { + "node": { + "id": 10, + "name": "Department Test 1" + } + } + ] + } + } + }, + { + "node": { + "id": 1, + "name": "John", + "role": "Developer", + "department": { + "edges": [ + { + "node": { + "id": 10, + "name": "Department Test 1" + } + } + ] + } + } + } + ] + } + }, + { + "id": 3, + "name": "Department Test 2", + "employees": { + "edges": [ + { + "node": { + "id": 4, + "name": "Maria", + "role": "Teacher", + "department": { + "edges": [ + { + "node": { + "id": 3, + "name": "Department Test 2" + } + } + ] + } + } + } + ] + } + } + ] + } + + +@pytest.mark.asyncio +async def test_query_with_secondary_tables_with_more_than_2_colluns_values_list( + secondary_tables_with_more_secondary_tables, + base, + async_engine, + async_sessionmaker +): + async with async_engine.begin() as conn: + await conn.run_sync(base.metadata.create_all) + + mapper = StrawberrySQLAlchemyMapper() + EmployeeModel, DepartmentModel, BuildingModel = secondary_tables_with_more_secondary_tables + + @mapper.type(DepartmentModel) + class Department(): + pass + + @mapper.type(EmployeeModel) + class Employee(): + pass + + @mapper.type(BuildingModel) + class Building(): + pass + + @strawberry.type + class Query: + @strawberry.field + async def departments(self) -> List[Department]: + async with async_sessionmaker() as session: + result = await session.execute(select(DepartmentModel)) + return result.scalars().all() + + mapper.finalize() + schema = strawberry.Schema(query=Query) + + query = """ + query { + departments { + id + name + employees { + edges { + node { + id + name + role + department { + edges { + node { + id + name + } + } + }, + building { + edges { + node { + id + name + } + } + } + } + } + } + } + } + """ + + # Create test data + async with async_sessionmaker(expire_on_commit=False) as session: + building = BuildingModel(id=2, name="Building 1") + e1, e2, e3, _, _ = created_default_secondary_table_data(session=session, EmployeeModel=EmployeeModel, DepartmentModel=DepartmentModel) + building.employees.append(e1) + building.employees.append(e2) + building.employees.append(e3) + session.add(building) + await session.commit() + + result = await schema.execute(query, context_value={ + "sqlalchemy_loader": StrawberrySQLAlchemyLoader( + async_bind_factory=async_sessionmaker + ) + }) + assert result.errors is None + assert result.data == { + "departments": [ + { + "id": 10, + "name": "Department Test 1", + "employees": { + "edges": [ + { + "node": { + "id": 5, + "name": "Bill", + "role": "Doctor", + "department": { + "edges": [ + { + "node": { + "id": 10, + "name": "Department Test 1" + } + } + ] + }, + "building": { + "edges": [ + { + "node": { + "id": 2, + "name": "Building 1" + } + } + ] + } + } + }, + { + "node": { + "id": 1, + "name": "John", + "role": "Developer", + "department": { + "edges": [ + { + "node": { + "id": 10, + "name": "Department Test 1" + } + } + ] + }, + "building": { + "edges": [ + { + "node": { + "id": 2, + "name": "Building 1" + } + } + ] + } + } + } + ] + } + }, + { + "id": 3, + "name": "Department Test 2", + "employees": { + "edges": [ + { + "node": { + "id": 4, + "name": "Maria", + "role": "Teacher", + "department": { + "edges": [ + { + "node": { + "id": 3, + "name": "Department Test 2" + } + } + ] + }, + "building": { + "edges": [ + { + "node": { + "id": 2, + "name": "Building 1" + } + } + ] + } + } + } + ] + } + } + ] + } + + +@pytest.mark.asyncio +async def test_query_with_secondary_table( + secondary_tables_with_use_list_false, + base, + async_engine, + async_sessionmaker +): + async with async_engine.begin() as conn: + await conn.run_sync(base.metadata.create_all) + mapper = StrawberrySQLAlchemyMapper() + EmployeeModel, DepartmentModel = secondary_tables_with_use_list_false + + @mapper.type(DepartmentModel) + class Department(): + pass + + @mapper.type(EmployeeModel) + class Employee(): + pass + + @strawberry.type + class Query: + employees: relay.ListConnection[Employee] = connection( + sessionmaker=async_sessionmaker) + + mapper.finalize() + schema = strawberry.Schema(query=Query) + + query = """ + query { + employees { + edges { + node { + id + name + role + department { + edges { + node { + id + name + employees { + id + name + role + } + } + } + } + } + } + } + } + """ + + # Create test data + async with async_sessionmaker(expire_on_commit=False) as session: + department = DepartmentModel(name="Department Test") + e1 = EmployeeModel(name="John", role="Developer") + e2 = EmployeeModel(name="Bill", role="Doctor") + e3 = EmployeeModel(name="Maria", role="Teacher") + e1.department.append(department) + session.add_all([department, e1, e2, e3]) + await session.commit() + + result = await schema.execute(query, context_value={ + "sqlalchemy_loader": StrawberrySQLAlchemyLoader( + async_bind_factory=async_sessionmaker + ) + }) + assert result.errors is None + assert result.data == { + 'employees': { + 'edges': [ + { + 'node': { + 'id': 1, + 'name': 'John', + 'role': 'Developer', + 'department': { + 'edges': [ + { + 'node': { + 'id': 1, + 'name': 'Department Test', + 'employees': { + 'id': 1, + 'name': 'John', + 'role': 'Developer' + } + } + } + ] + } + } + }, + { + 'node': { + 'id': 2, + 'name': 'Bill', + 'role': 'Doctor', + 'department': { + 'edges': [] + } + } + }, + { + 'node': { + 'id': 3, + 'name': 'Maria', + 'role': 'Teacher', + 'department': { + 'edges': [] + } + } + } + ] + } + } + + +@pytest.mark.asyncio +async def test_query_with_secondary_table_without_list_connection( + secondary_tables_with_use_list_false, + base, + async_engine, + async_sessionmaker +): + async with async_engine.begin() as conn: + await conn.run_sync(base.metadata.create_all) + + mapper = StrawberrySQLAlchemyMapper() + EmployeeModel, DepartmentModel = secondary_tables_with_use_list_false + + @mapper.type(DepartmentModel) + class Department(): + pass + + @mapper.type(EmployeeModel) + class Employee(): + pass + + @strawberry.type + class Query: + @strawberry.field + async def employees(self) -> List[Employee]: + async with async_sessionmaker() as session: + result = await session.execute(select(EmployeeModel)) + return result.scalars().all() + + mapper.finalize() + schema = strawberry.Schema(query=Query) + + query = """ + query { + employees { + id + name + role + department { + edges { + node { + id + name + employees { + id + name + role + } + } + } + } + } + } + """ + + # Create test data + async with async_sessionmaker(expire_on_commit=False) as session: + department = DepartmentModel(name="Department Test") + e1 = EmployeeModel(name="John", role="Developer") + e2 = EmployeeModel(name="Bill", role="Doctor") + e3 = EmployeeModel(name="Maria", role="Teacher") + e1.department.append(department) + session.add_all([department, e1, e2, e3]) + await session.commit() + + result = await schema.execute(query, context_value={ + "sqlalchemy_loader": StrawberrySQLAlchemyLoader( + async_bind_factory=async_sessionmaker + ) + }) + assert result.errors is None + assert result.data == { + 'employees': [ + { + 'id': 1, + 'name': 'John', + 'role': 'Developer', + 'department': { + 'edges': [ + { + 'node': { + 'id': 1, + 'name': 'Department Test', + 'employees': { + 'id': 1, + 'name': 'John', + 'role': 'Developer' + } + } + } + ] + } + }, + { + 'id': 2, + 'name': 'Bill', + 'role': 'Doctor', + 'department': { + 'edges': [] + } + }, + { + 'id': 3, + 'name': 'Maria', + 'role': 'Teacher', + 'department': { + 'edges': [] + } + } + ] + } + + +@pytest.mark.asyncio +async def test_query_with_secondary_table_with_values_list_and_normal_relationship( + secondary_tables_with_normal_relationship, + base, + async_engine, + async_sessionmaker +): + async with async_engine.begin() as conn: + await conn.run_sync(base.metadata.create_all) + + mapper = StrawberrySQLAlchemyMapper() + EmployeeModel, DepartmentModel, BuildingModel = secondary_tables_with_normal_relationship + + @mapper.type(DepartmentModel) + class Department(): + pass + + @mapper.type(EmployeeModel) + class Employee(): + pass + + @mapper.type(BuildingModel) + class Building(): + pass + + @strawberry.type + class Query: + @strawberry.field + async def departments(self) -> List[Department]: + async with async_sessionmaker() as session: + result = await session.execute(select(DepartmentModel)) + return result.scalars().all() + + mapper.finalize() + schema = strawberry.Schema(query=Query) + + query = """ + query { + departments { + id + name + employees { + edges { + node { + id + name + role + department { + edges { + node { + id + name + } + } + }, + building { + id + name + } + } + } + } + } + } + """ + + # Create test data + async with async_sessionmaker(expire_on_commit=False) as session: + e1, e2, e3, _, _ = created_default_secondary_table_data(session=session, EmployeeModel=EmployeeModel, DepartmentModel=DepartmentModel) + building = BuildingModel(id=2, name="Building 1") + building.employees.append(e1) + building.employees.append(e2) + building.employees.append(e3) + session.add(building) + await session.commit() + + result = await schema.execute(query, context_value={ + "sqlalchemy_loader": StrawberrySQLAlchemyLoader( + async_bind_factory=async_sessionmaker + ) + }) + assert result.errors is None + assert result.data == { + "departments": [ + { + "id": 10, + "name": "Department Test 1", + "employees": { + "edges": [ + { + "node": { + "id": 5, + "name": "Bill", + "role": "Doctor", + "department": { + "edges": [ + { + "node": { + "id": 10, + "name": "Department Test 1" + } + } + ] + }, + "building": { + "id": 2, + "name": "Building 1" + } + } + }, + { + "node": { + "id": 1, + "name": "John", + "role": "Developer", + "department": { + "edges": [ + { + "node": { + "id": 10, + "name": "Department Test 1" + } + } + ] + }, + "building": { + "id": 2, + "name": "Building 1" + } + } + } + ] + } + }, + { + "id": 3, + "name": "Department Test 2", + "employees": { + "edges": [ + { + "node": { + "id": 4, + "name": "Maria", + "role": "Teacher", + "department": { + "edges": [ + { + "node": { + "id": 3, + "name": "Department Test 2" + } + } + ] + }, + "building": { + "id": 2, + "name": "Building 1" + } + } + } + ] + } + } + ] + }