From 068b81c1547876c6177174f9a5be777136938d7b Mon Sep 17 00:00:00 2001 From: dem4ply Date: Thu, 30 Jan 2020 10:15:29 -0600 Subject: [PATCH] restore the behavior of parameter many in the serializers --- rest_marshmallow/__init__.py | 13 ++++++++++--- tests/test_marshmallow.py | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/rest_marshmallow/__init__.py b/rest_marshmallow/__init__.py index 71b9706..0d72391 100644 --- a/rest_marshmallow/__init__.py +++ b/rest_marshmallow/__init__.py @@ -2,7 +2,8 @@ from marshmallow import Schema as MarshmallowSchema from marshmallow import fields # noqa from marshmallow.exceptions import ValidationError as MarshmallowValidationError -from rest_framework.serializers import BaseSerializer, ValidationError +from rest_framework.serializers import BaseSerializer, ValidationError, ListSerializer +from types import MethodType IS_MARSHMALLOW_LT_3 = int(marshmallow.__version__.split('.')[0]) < 3 @@ -15,13 +16,19 @@ ) +def _dump(self, obj, *args, many=None, **kwargs): + return [self.child.dump(o) for o in obj] + + class Schema(BaseSerializer, MarshmallowSchema): def __new__(cls, *args, **kwargs): # We're overriding the DRF implementation here, because ListSerializer # clashes with Nested implementation. - kwargs.pop('many', False) - return super(Schema, cls).__new__(cls, *args, **kwargs) + result = super(Schema, cls).__new__(cls, *args, **kwargs) + if isinstance(result, ListSerializer): + result.dump = MethodType(_dump, result) + return result def __init__(self, *args, **kwargs): schema_kwargs = { diff --git a/tests/test_marshmallow.py b/tests/test_marshmallow.py index e6197ec..8891749 100644 --- a/tests/test_marshmallow.py +++ b/tests/test_marshmallow.py @@ -80,6 +80,21 @@ def test_deserialize(): assert serializer.validated_data == {'number': 123, 'text': 'abc'} +def test_deserialize_many(): + data = [ + {'number': 123, 'text': 'abc'}, + {'number': 123, 'text': 'abc'}, + {'number': 123, 'text': 'abc'}, + ] + serializer = ExampleSerializer(data=data, many=True) + assert serializer.is_valid() + assert serializer.validated_data == [ + {'number': 123, 'text': 'abc'}, + {'number': 123, 'text': 'abc'}, + {'number': 123, 'text': 'abc'}, + ] + + def test_deserialize_validation_failed(): data = {'number': 'abc', 'text': 'abc'} serializer = ExampleSerializer(data=data) @@ -105,6 +120,27 @@ def test_create(): assert serializer.data == {'number': 123, 'text': 'abc'} +def test_create_many(): + data = [ + {'number': 123, 'text': 'abc'}, + {'number': 123, 'text': 'abc'}, + {'number': 123, 'text': 'abc'}, + ] + serializer = ExampleSerializer(data=data, many=True) + assert serializer.is_valid() + instances = serializer.save() + assert instances + for instance in instances: + assert isinstance(instance, Object) + assert instance.number == 123 + assert instance.text == 'abc' + assert serializer.data == [ + {'number': 123, 'text': 'abc'}, + {'number': 123, 'text': 'abc'}, + {'number': 123, 'text': 'abc'}, + ] + + def test_update(): instance = Object(number=123, text='abc') data = {'number': 456, 'text': 'def'}