Skip to content

Commit 2330122

Browse files
authored
Add support for typing.Annotated (#721)
1 parent 29ae3e1 commit 2330122

File tree

7 files changed

+115
-31
lines changed

7 files changed

+115
-31
lines changed

examples/wiring/example.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,27 @@
22

33
from dependency_injector import containers, providers
44
from dependency_injector.wiring import Provide, inject
5+
from typing import Annotated
56

67

7-
class Service:
8-
...
8+
class Service: ...
99

1010

1111
class Container(containers.DeclarativeContainer):
1212

1313
service = providers.Factory(Service)
1414

1515

16+
# You can place marker on parameter default value
1617
@inject
17-
def main(service: Service = Provide[Container.service]) -> None:
18-
...
18+
def main(service: Service = Provide[Container.service]) -> None: ...
19+
20+
21+
# Also, you can place marker with typing.Annotated
22+
@inject
23+
def main_with_annotated(
24+
service: Annotated[Service, Provide[Container.service]]
25+
) -> None: ...
1926

2027

2128
if __name__ == "__main__":

requirements-dev.txt

+1
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,6 @@ numpy
1818
scipy
1919
boto3
2020
mypy_boto3_s3
21+
typing_extensions
2122

2223
-r requirements-ext.txt

src/dependency_injector/wiring.py

+36-11
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,21 @@ class GenericMeta(type): ...
3737
else:
3838
GenericAlias = None
3939

40+
if sys.version_info >= (3, 9):
41+
from typing import Annotated, get_args, get_origin
42+
else:
43+
try:
44+
from typing_extensions import Annotated, get_args, get_origin
45+
except ImportError:
46+
Annotated = object()
47+
48+
# For preventing NameError. Never executes
49+
def get_args(hint):
50+
return ()
51+
52+
def get_origin(tp):
53+
return None
54+
4055

4156
try:
4257
import fastapi.params
@@ -572,6 +587,24 @@ def _unpatch_attribute(patched: PatchedAttribute) -> None:
572587
setattr(patched.member, patched.name, patched.marker)
573588

574589

590+
def _extract_marker(parameter: inspect.Parameter) -> Optional["_Marker"]:
591+
if get_origin(parameter.annotation) is Annotated:
592+
marker = get_args(parameter.annotation)[1]
593+
else:
594+
marker = parameter.default
595+
596+
if not isinstance(marker, _Marker) and not _is_fastapi_depends(marker):
597+
return None
598+
599+
if _is_fastapi_depends(marker):
600+
marker = marker.dependency
601+
602+
if not isinstance(marker, _Marker):
603+
return None
604+
605+
return marker
606+
607+
575608
def _fetch_reference_injections( # noqa: C901
576609
fn: Callable[..., Any],
577610
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
@@ -596,18 +629,10 @@ def _fetch_reference_injections( # noqa: C901
596629
injections = {}
597630
closing = {}
598631
for parameter_name, parameter in signature.parameters.items():
599-
if not isinstance(parameter.default, _Marker) and not _is_fastapi_depends(
600-
parameter.default
601-
):
602-
continue
632+
marker = _extract_marker(parameter)
603633

604-
marker = parameter.default
605-
606-
if _is_fastapi_depends(marker):
607-
marker = marker.dependency
608-
609-
if not isinstance(marker, _Marker):
610-
continue
634+
if marker is None:
635+
continue
611636

612637
if isinstance(marker, Closing):
613638
marker = marker.provider

tests/unit/samples/wiringfastapi/web.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
import sys
22

3+
from typing_extensions import Annotated
4+
35
from fastapi import FastAPI, Depends
4-
from fastapi import Request # See: https://github.com/ets-labs/python-dependency-injector/issues/398
6+
from fastapi import (
7+
Request,
8+
) # See: https://github.com/ets-labs/python-dependency-injector/issues/398
59
from fastapi.security import HTTPBasic, HTTPBasicCredentials
610
from dependency_injector import containers, providers
711
from dependency_injector.wiring import inject, Provide
@@ -28,11 +32,16 @@ async def index(service: Service = Depends(Provide[Container.service])):
2832
return {"result": result}
2933

3034

35+
@app.api_route("/annotated")
36+
@inject
37+
async def annotated(service: Annotated[Service, Depends(Provide[Container.service])]):
38+
result = await service.process()
39+
return {"result": result}
40+
41+
3142
@app.get("/auth")
3243
@inject
33-
def read_current_user(
34-
credentials: HTTPBasicCredentials = Depends(security)
35-
):
44+
def read_current_user(credentials: HTTPBasicCredentials = Depends(security)):
3645
return {"username": credentials.username, "password": credentials.password}
3746

3847

tests/unit/samples/wiringflask/web.py

+9
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing_extensions import Annotated
2+
13
from flask import Flask, jsonify, request, current_app, session, g
24
from dependency_injector import containers, providers
35
from dependency_injector.wiring import inject, Provide
@@ -26,5 +28,12 @@ def index(service: Service = Provide[Container.service]):
2628
return jsonify({"result": result})
2729

2830

31+
@app.route("/annotated")
32+
@inject
33+
def annotated(service: Annotated[Service, Provide[Container.service]]):
34+
result = service.process()
35+
return jsonify({"result": result})
36+
37+
2938
container = Container()
3039
container.wire(modules=[__name__])

tests/unit/wiring/test_fastapi_py36.py

+21-4
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,17 @@
44

55
# Runtime import to avoid syntax errors in samples on Python < 3.5 and reach top-dir
66
import os
7+
78
_SAMPLES_DIR = os.path.abspath(
8-
os.path.sep.join((
9-
os.path.dirname(__file__),
10-
"../samples/",
11-
)),
9+
os.path.sep.join(
10+
(
11+
os.path.dirname(__file__),
12+
"../samples/",
13+
)
14+
),
1215
)
1316
import sys
17+
1418
sys.path.append(_SAMPLES_DIR)
1519

1620

@@ -37,6 +41,19 @@ async def process(self):
3741
assert response.json() == {"result": "Foo"}
3842

3943

44+
@mark.asyncio
45+
async def test_depends_with_annotated(async_client: AsyncClient):
46+
class ServiceMock:
47+
async def process(self):
48+
return "Foo"
49+
50+
with web.container.service.override(ServiceMock()):
51+
response = await async_client.get("/")
52+
53+
assert response.status_code == 200
54+
assert response.json() == {"result": "Foo"}
55+
56+
4057
@mark.asyncio
4158
async def test_depends_injection(async_client: AsyncClient):
4259
response = await async_client.get("/auth", auth=("john_smith", "secret"))

tests/unit/wiring/test_flask_py36.py

+24-8
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,25 @@
22

33
# Runtime import to avoid syntax errors in samples on Python < 3.5 and reach top-dir
44
import os
5+
56
_TOP_DIR = os.path.abspath(
6-
os.path.sep.join((
7-
os.path.dirname(__file__),
8-
"../",
9-
)),
7+
os.path.sep.join(
8+
(
9+
os.path.dirname(__file__),
10+
"../",
11+
)
12+
),
1013
)
1114
_SAMPLES_DIR = os.path.abspath(
12-
os.path.sep.join((
13-
os.path.dirname(__file__),
14-
"../samples/",
15-
)),
15+
os.path.sep.join(
16+
(
17+
os.path.dirname(__file__),
18+
"../samples/",
19+
)
20+
),
1621
)
1722
import sys
23+
1824
sys.path.append(_TOP_DIR)
1925
sys.path.append(_SAMPLES_DIR)
2026

@@ -29,3 +35,13 @@ def test_wiring_with_flask():
2935

3036
assert response.status_code == 200
3137
assert json.loads(response.data) == {"result": "OK"}
38+
39+
40+
def test_wiring_with_annotated():
41+
client = web.app.test_client()
42+
43+
with web.app.app_context():
44+
response = client.get("/annotated")
45+
46+
assert response.status_code == 200
47+
assert json.loads(response.data) == {"result": "OK"}

0 commit comments

Comments
 (0)