diff --git a/src/desert/_make.py b/src/desert/_make.py index 4c4892e..66d5a9f 100644 --- a/src/desert/_make.py +++ b/src/desert/_make.py @@ -220,8 +220,8 @@ def field_for_schema( if default is not marshmallow.missing: desert_metadata.setdefault("dump_default", default) - desert_metadata.setdefault("allow_none", True) desert_metadata.setdefault("load_default", default) + desert_metadata.setdefault("allow_none", True) field = None @@ -235,9 +235,20 @@ def field_for_schema( field.metadata.update(metadata) return field + field_args = { + k: v + for k, v in desert_metadata.items() + if k + in [ + "dump_default", + "load_default", + "allow_none", + ] + } + # Base types if not field and typ in _native_to_marshmallow: - field = _native_to_marshmallow[typ](dump_default=default) + field = _native_to_marshmallow[typ](**field_args) # Generic types origin = typing_inspect.get_origin(typ) @@ -253,16 +264,18 @@ def field_for_schema( collections.abc.Sequence, collections.abc.MutableSequence, ): - field = marshmallow.fields.List(field_for_schema(arguments[0])) + field = marshmallow.fields.List( + field_for_schema(arguments[0]), **field_args + ) if origin in (tuple, t.Tuple) and Ellipsis not in arguments: field = marshmallow.fields.Tuple( # type: ignore[no-untyped-call] - tuple(field_for_schema(arg) for arg in arguments) + tuple(field_for_schema(arg) for arg in arguments), **field_args ) elif origin in (tuple, t.Tuple) and Ellipsis in arguments: - field = VariadicTuple( - field_for_schema(only(arg for arg in arguments if arg != Ellipsis)) + field_for_schema(only(arg for arg in arguments if arg != Ellipsis)), + **field_args, ) elif origin in ( dict, @@ -275,47 +288,40 @@ def field_for_schema( field = marshmallow.fields.Dict( keys=field_for_schema(arguments[0]), values=field_for_schema(arguments[1]), + **field_args, ) elif typing_inspect.is_optional_type(typ): [subtyp] = (t for t in arguments if t is not NoneType) # Treat optional types as types with a None default - metadata[_DESERT_SENTINEL]["dump_default"] = metadata.get( - "dump_default", None - ) - metadata[_DESERT_SENTINEL]["load_default"] = metadata.get( - "load_default", None - ) - metadata[_DESERT_SENTINEL]["required"] = False - - field = field_for_schema(subtyp, metadata=metadata, default=None) - field.dump_default = None - field.load_default = None - field.allow_none = True + metadata[_DESERT_SENTINEL]["allow_none"] = True + if default is marshmallow.missing: + default = None + field = field_for_schema(subtyp, metadata=metadata, default=default) elif typing_inspect.is_union_type(typ): subfields = [field_for_schema(subtyp) for subtyp in arguments] import marshmallow_union - field = marshmallow_union.Union(subfields) + field = marshmallow_union.Union(subfields, **field_args) # t.NewType returns a function with a __supertype__ attribute newtype_supertype = getattr(typ, "__supertype__", None) if newtype_supertype and typing_inspect.is_new_type(typ): metadata.setdefault("description", typ.__name__) - field = field_for_schema(newtype_supertype, default=default) + field = field_for_schema(newtype_supertype, metadata=metadata, default=default) # enumerations if type(typ) is enum.EnumMeta: import marshmallow_enum - field = marshmallow_enum.EnumField(typ, metadata=metadata) + field = marshmallow_enum.EnumField(typ, metadata=metadata, **field_args) # Nested dataclasses forward_reference = getattr(typ, "__forward_arg__", None) if field is None: nested = forward_reference or class_schema(typ) - field = marshmallow.fields.Nested(nested) + field = marshmallow.fields.Nested(nested, **field_args) field.metadata.update(metadata) @@ -350,11 +356,11 @@ def _get_field_default( if isinstance(field, dataclasses.Field): # misc: https://github.com/python/mypy/issues/10750 # comparison-overlap: https://github.com/python/typeshed/pull/5900 - if field.default_factory != dataclasses.MISSING: - return dataclasses.MISSING - if field.default is dataclasses.MISSING: - return marshmallow.missing - return field.default + if field.default_factory is not dataclasses.MISSING: + return field.default_factory + if field.default is not dataclasses.MISSING: + return field.default + return marshmallow.missing elif isinstance(field, attr.Attribute): if field.default == attr.NOTHING: return marshmallow.missing diff --git a/tests/test_make.py b/tests/test_make.py index 9a2991c..9f1a1d9 100644 --- a/tests/test_make.py +++ b/tests/test_make.py @@ -224,6 +224,17 @@ class A: assert data == A(None) # type: ignore[call-arg] +def test_optional_default(module: DataclassModule) -> None: + """Setting an optional type allows passing None.""" + + @module.dataclass + class A: + x: t.Optional[int] = 1 + + data = desert.schema_class(A)().load({}) + assert data == A(1) # type: ignore[call-arg] + + def test_custom_field(module: DataclassModule) -> None: @module.dataclass class A: