5
5
from distutils .util import strtobool
6
6
from typing import Union , Type , Optional , Sequence , Dict , Any , List
7
7
8
+ import typing
9
+
8
10
from .field_mapper import FieldMapper
9
11
from .exceptions import CsvValueError
10
12
@@ -29,6 +31,20 @@ def _verify_duplicate_header_items(header):
29
31
)
30
32
31
33
34
+ def is_union_type (t ):
35
+ if hasattr (t , "__origin__" ) and t .__origin__ is Union :
36
+ return True
37
+
38
+ return False
39
+
40
+
41
+ def get_args (t ):
42
+ if hasattr (t , "__args__" ):
43
+ return t .__args__
44
+
45
+ return tuple ()
46
+
47
+
32
48
class DataclassReader :
33
49
def __init__ (
34
50
self ,
@@ -61,6 +77,8 @@ def __init__(
61
77
if validate_header :
62
78
_verify_duplicate_header_items (self ._reader .fieldnames )
63
79
80
+ self .type_hints = typing .get_type_hints (cls )
81
+
64
82
def _get_optional_fields (self ):
65
83
return [
66
84
field .name
@@ -175,20 +193,12 @@ def _process_row(self, row):
175
193
values .append (None )
176
194
continue
177
195
178
- field_type = field .type
179
- # Special handling for Optional (Union of a single real type and None)
180
- if (
181
- # The first part of the condition is for Python < 3.8
182
- type (field_type ).__name__ == "_Union"
183
- # The second part of the condition is for Python >= 3.8
184
- or "__origin__" in field_type .__dict__
185
- and field_type .__origin__ is Union
186
- ):
187
- real_types = [
188
- t for t in field_type .__args__ if t is not type (None ) # noqa: E721
189
- ]
190
- if len (real_types ) == 1 :
191
- field_type = real_types [0 ]
196
+ field_type = self .type_hints [field .name ]
197
+
198
+ if is_union_type (field_type ):
199
+ type_args = [x for x in get_args (field_type ) if x is not type (None )]
200
+ if len (type_args ) == 1 :
201
+ field_type = type_args [0 ]
192
202
193
203
if field_type is datetime :
194
204
try :
0 commit comments