Skip to content

Commit 3ddc9a8

Browse files
committed
Allow field overrides via Annotated
1 parent dbe138b commit 3ddc9a8

File tree

3 files changed

+106
-3
lines changed

3 files changed

+106
-3
lines changed

src/cattrs/gen/__init__.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from ._consts import AttributeOverride, already_generating, neutral
3434
from ._generics import generate_mapping
3535
from ._lc import generate_unique_filename
36-
from ._shared import find_structure_handler
36+
from ._shared import find_structure_handler, get_fields_annotated_by
3737

3838
if TYPE_CHECKING:
3939
from ..converters import BaseConverter
@@ -264,6 +264,10 @@ def make_dict_unstructure_fn(
264264

265265
working_set.add(cl)
266266

267+
# Merge overrides provided via Annotated with kwargs
268+
annotated_overrides = get_fields_annotated_by(cl, AttributeOverride)
269+
annotated_overrides.update(kwargs)
270+
267271
try:
268272
return make_dict_unstructure_fn_from_attrs(
269273
attrs,
@@ -274,7 +278,7 @@ def make_dict_unstructure_fn(
274278
_cattrs_use_linecache=_cattrs_use_linecache,
275279
_cattrs_use_alias=_cattrs_use_alias,
276280
_cattrs_include_init_false=_cattrs_include_init_false,
277-
**kwargs,
281+
**annotated_overrides,
278282
)
279283
finally:
280284
working_set.remove(cl)

src/cattrs/gen/_shared.py

+28-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Any
3+
from typing import TYPE_CHECKING, Any, TypeVar, get_type_hints
44

55
from attrs import NOTHING, Attribute, Factory
66

@@ -10,8 +10,10 @@
1010
from ..fns import raise_error
1111

1212
if TYPE_CHECKING:
13+
from collections.abc import Mapping
1314
from ..converters import BaseConverter
1415

16+
T = TypeVar("T")
1517

1618
def find_structure_handler(
1719
a: Attribute, type: Any, c: BaseConverter, prefer_attrs_converters: bool = False
@@ -62,3 +64,28 @@ def handler(v, _, _h=handler):
6264
except RecursionError:
6365
# This means we're dealing with a reference cycle, so use late binding.
6466
return c.structure
67+
68+
69+
def get_fields_annotated_by(cls: type, annotation_type: type[T] | T) -> dict[str, T]:
70+
type_hints = get_type_hints(cls, include_extras=True)
71+
# Support for both AttributeOverride and AttributeOverride()
72+
annotation_type_ = annotation_type if isinstance(annotation_type, type) else type(annotation_type)
73+
74+
# First pass of filtering to get only fields with annotations
75+
fields_with_annotations = (
76+
(field_name, param_spec.__metadata__)
77+
for field_name, param_spec in type_hints.items()
78+
if hasattr(param_spec, "__metadata__")
79+
)
80+
81+
# Now that we have fields with ANY annotations, we need to remove unwanted annotations.
82+
fields_with_specific_annotation = (
83+
(
84+
field_name,
85+
next((a for a in annotations if isinstance(a, annotation_type_)), None),
86+
)
87+
for field_name, annotations in fields_with_annotations
88+
)
89+
90+
# We still might have some `None` values from previous filtering.
91+
return {field_name: annotation for field_name, annotation in fields_with_specific_annotation if annotation}

tests/test_annotated_overrides.py

+72
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from typing import Annotated
2+
3+
import attrs
4+
import pytest
5+
6+
from cattrs.gen._shared import get_fields_annotated_by
7+
8+
9+
class NotThere: ...
10+
11+
12+
class IgnoreMe:
13+
def __init__(self, why: str | None = None):
14+
self.why = why
15+
16+
17+
class FindMe:
18+
def __init__(self, taint: str):
19+
self.taint = taint
20+
21+
22+
class EmptyClassExample:
23+
pass
24+
25+
26+
class PureClassExample:
27+
id: Annotated[int, FindMe("red")]
28+
name: Annotated[str, FindMe]
29+
30+
31+
class MultipleAnnotationsExample:
32+
id: Annotated[int, FindMe("red"), IgnoreMe()]
33+
name: Annotated[str, IgnoreMe()]
34+
surface: Annotated[str, IgnoreMe("sorry"), FindMe("shiny")]
35+
36+
37+
@attrs.define
38+
class AttrsClassExample:
39+
id: int = attrs.field(default=0)
40+
color: Annotated[str, FindMe("blue")] = attrs.field(default="red")
41+
config: Annotated[dict, FindMe("required")] = attrs.field(factory=dict)
42+
43+
44+
class PureClassInheritanceExample(PureClassExample):
45+
include: dict
46+
exclude: Annotated[dict, FindMe("boring things")]
47+
extras: Annotated[dict, FindMe]
48+
49+
50+
@pytest.mark.parametrize(
51+
"klass,expected",
52+
[
53+
(EmptyClassExample, {}),
54+
(PureClassExample, {"id": isinstance}),
55+
(AttrsClassExample, {"color": isinstance, "config": isinstance}),
56+
(MultipleAnnotationsExample, {"id": isinstance, "surface": isinstance}),
57+
(PureClassInheritanceExample, {"id": isinstance, "exclude": isinstance}),
58+
],
59+
)
60+
@pytest.mark.parametrize("instantiate", [True, False])
61+
def test_gets_annotated_types(klass, expected, instantiate: bool):
62+
annotated = get_fields_annotated_by(klass, FindMe("irrelevant") if instantiate else FindMe)
63+
64+
assert set(annotated.keys()) == set(expected.keys()), "Too many or too few annotations"
65+
assert all(
66+
assertion_func(annotated[field_name], FindMe) for field_name, assertion_func in expected.items()
67+
), "Unexpected type of annotation"
68+
69+
70+
def test_empty_result_for_missing_annotation():
71+
annotated = get_fields_annotated_by(MultipleAnnotationsExample, NotThere)
72+
assert not annotated, "No annotation should be found."

0 commit comments

Comments
 (0)