Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Isra/optional default #269

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 33 additions & 27 deletions src/desert/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,8 @@ def field_for_schema(

if default is not marshmallow.missing:
desert_metadata.setdefault("dump_default", default)
desert_metadata.setdefault("allow_none", True)
desert_metadata.setdefault("load_default", default)
desert_metadata.setdefault("allow_none", True)

field = None

Expand All @@ -235,9 +235,20 @@ def field_for_schema(
field.metadata.update(metadata)
return field

field_args = {
k: v
for k, v in desert_metadata.items()
if k
in [
"dump_default",
"load_default",
"allow_none",
]
}

# Base types
if not field and typ in _native_to_marshmallow:
field = _native_to_marshmallow[typ](dump_default=default)
field = _native_to_marshmallow[typ](**field_args)

# Generic types
origin = typing_inspect.get_origin(typ)
Expand All @@ -253,16 +264,18 @@ def field_for_schema(
collections.abc.Sequence,
collections.abc.MutableSequence,
):
field = marshmallow.fields.List(field_for_schema(arguments[0]))
field = marshmallow.fields.List(
field_for_schema(arguments[0]), **field_args
)

if origin in (tuple, t.Tuple) and Ellipsis not in arguments:
field = marshmallow.fields.Tuple( # type: ignore[no-untyped-call]
tuple(field_for_schema(arg) for arg in arguments)
tuple(field_for_schema(arg) for arg in arguments), **field_args
)
elif origin in (tuple, t.Tuple) and Ellipsis in arguments:

field = VariadicTuple(
field_for_schema(only(arg for arg in arguments if arg != Ellipsis))
field_for_schema(only(arg for arg in arguments if arg != Ellipsis)),
**field_args,
)
elif origin in (
dict,
Expand All @@ -275,47 +288,40 @@ def field_for_schema(
field = marshmallow.fields.Dict(
keys=field_for_schema(arguments[0]),
values=field_for_schema(arguments[1]),
**field_args,
)
elif typing_inspect.is_optional_type(typ):
[subtyp] = (t for t in arguments if t is not NoneType)
# Treat optional types as types with a None default
metadata[_DESERT_SENTINEL]["dump_default"] = metadata.get(
"dump_default", None
)
metadata[_DESERT_SENTINEL]["load_default"] = metadata.get(
"load_default", None
)
metadata[_DESERT_SENTINEL]["required"] = False

field = field_for_schema(subtyp, metadata=metadata, default=None)
field.dump_default = None
field.load_default = None
field.allow_none = True
metadata[_DESERT_SENTINEL]["allow_none"] = True
if default is marshmallow.missing:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would argue that this behavior is quite counter intuitive, but if we want to keep backward compatibility this is needed.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want an optional value, then do the same as all python code and do x: Optional[int] = None

default = None
field = field_for_schema(subtyp, metadata=metadata, default=default)

elif typing_inspect.is_union_type(typ):
subfields = [field_for_schema(subtyp) for subtyp in arguments]
import marshmallow_union

field = marshmallow_union.Union(subfields)
field = marshmallow_union.Union(subfields, **field_args)

# t.NewType returns a function with a __supertype__ attribute
newtype_supertype = getattr(typ, "__supertype__", None)
if newtype_supertype and typing_inspect.is_new_type(typ):
metadata.setdefault("description", typ.__name__)
field = field_for_schema(newtype_supertype, default=default)
field = field_for_schema(newtype_supertype, metadata=metadata, default=default)

# enumerations
if type(typ) is enum.EnumMeta:
import marshmallow_enum

field = marshmallow_enum.EnumField(typ, metadata=metadata)
field = marshmallow_enum.EnumField(typ, metadata=metadata, **field_args)

# Nested dataclasses
forward_reference = getattr(typ, "__forward_arg__", None)

if field is None:
nested = forward_reference or class_schema(typ)
field = marshmallow.fields.Nested(nested)
field = marshmallow.fields.Nested(nested, **field_args)

field.metadata.update(metadata)

Expand Down Expand Up @@ -350,11 +356,11 @@ def _get_field_default(
if isinstance(field, dataclasses.Field):
# misc: https://github.com/python/mypy/issues/10750
# comparison-overlap: https://github.com/python/typeshed/pull/5900
if field.default_factory != dataclasses.MISSING:
return dataclasses.MISSING
if field.default is dataclasses.MISSING:
return marshmallow.missing
return field.default
if field.default_factory is not dataclasses.MISSING:
return field.default_factory
if field.default is not dataclasses.MISSING:
return field.default
return marshmallow.missing
elif isinstance(field, attr.Attribute):
if field.default == attr.NOTHING:
return marshmallow.missing
Expand Down
11 changes: 11 additions & 0 deletions tests/test_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,17 @@ class A:
assert data == A(None) # type: ignore[call-arg]


def test_optional_default(module: DataclassModule) -> None:
"""Setting an optional type allows passing None."""

@module.dataclass
class A:
x: t.Optional[int] = 1

data = desert.schema_class(A)().load({})
assert data == A(1) # type: ignore[call-arg]


def test_custom_field(module: DataclassModule) -> None:
@module.dataclass
class A:
Expand Down