|
| 1 | +from typing import Annotated |
| 2 | + |
| 3 | +import attrs |
| 4 | +import pytest |
| 5 | + |
| 6 | +from cattrs.gen._shared import get_fields_annotated_by |
| 7 | + |
| 8 | + |
| 9 | +class NotThere: ... |
| 10 | + |
| 11 | + |
| 12 | +class IgnoreMe: |
| 13 | + def __init__(self, why: str | None = None): |
| 14 | + self.why = why |
| 15 | + |
| 16 | + |
| 17 | +class FindMe: |
| 18 | + def __init__(self, taint: str): |
| 19 | + self.taint = taint |
| 20 | + |
| 21 | + |
| 22 | +class EmptyClassExample: |
| 23 | + pass |
| 24 | + |
| 25 | + |
| 26 | +class PureClassExample: |
| 27 | + id: Annotated[int, FindMe("red")] |
| 28 | + name: Annotated[str, FindMe] |
| 29 | + |
| 30 | + |
| 31 | +class MultipleAnnotationsExample: |
| 32 | + id: Annotated[int, FindMe("red"), IgnoreMe()] |
| 33 | + name: Annotated[str, IgnoreMe()] |
| 34 | + surface: Annotated[str, IgnoreMe("sorry"), FindMe("shiny")] |
| 35 | + |
| 36 | + |
| 37 | +@attrs.define |
| 38 | +class AttrsClassExample: |
| 39 | + id: int = attrs.field(default=0) |
| 40 | + color: Annotated[str, FindMe("blue")] = attrs.field(default="red") |
| 41 | + config: Annotated[dict, FindMe("required")] = attrs.field(factory=dict) |
| 42 | + |
| 43 | + |
| 44 | +class PureClassInheritanceExample(PureClassExample): |
| 45 | + include: dict |
| 46 | + exclude: Annotated[dict, FindMe("boring things")] |
| 47 | + extras: Annotated[dict, FindMe] |
| 48 | + |
| 49 | + |
| 50 | +@pytest.mark.parametrize( |
| 51 | + "klass,expected", |
| 52 | + [ |
| 53 | + (EmptyClassExample, {}), |
| 54 | + (PureClassExample, {"id": isinstance}), |
| 55 | + (AttrsClassExample, {"color": isinstance, "config": isinstance}), |
| 56 | + (MultipleAnnotationsExample, {"id": isinstance, "surface": isinstance}), |
| 57 | + (PureClassInheritanceExample, {"id": isinstance, "exclude": isinstance}), |
| 58 | + ], |
| 59 | +) |
| 60 | +@pytest.mark.parametrize("instantiate", [True, False]) |
| 61 | +def test_gets_annotated_types(klass, expected, instantiate: bool): |
| 62 | + annotated = get_fields_annotated_by(klass, FindMe("irrelevant") if instantiate else FindMe) |
| 63 | + |
| 64 | + assert set(annotated.keys()) == set(expected.keys()), "Too many or too few annotations" |
| 65 | + assert all( |
| 66 | + assertion_func(annotated[field_name], FindMe) for field_name, assertion_func in expected.items() |
| 67 | + ), "Unexpected type of annotation" |
| 68 | + |
| 69 | + |
| 70 | +def test_empty_result_for_missing_annotation(): |
| 71 | + annotated = get_fields_annotated_by(MultipleAnnotationsExample, NotThere) |
| 72 | + assert not annotated, "No annotation should be found." |
0 commit comments