diff --git a/src/MongoDB.Bson/Serialization/Serializers/DiscriminatedInterfaceSerializer.cs b/src/MongoDB.Bson/Serialization/Serializers/DiscriminatedInterfaceSerializer.cs index ca2fef6fe07..4d8f3c35394 100644 --- a/src/MongoDB.Bson/Serialization/Serializers/DiscriminatedInterfaceSerializer.cs +++ b/src/MongoDB.Bson/Serialization/Serializers/DiscriminatedInterfaceSerializer.cs @@ -75,7 +75,7 @@ public DiscriminatedInterfaceSerializer() /// interfaceType /// interfaceType public DiscriminatedInterfaceSerializer(IDiscriminatorConvention discriminatorConvention) - : this(discriminatorConvention, CreateInterfaceSerializer()) + : this(discriminatorConvention, CreateInterfaceSerializer(), objectSerializer: null) { } @@ -87,6 +87,19 @@ public DiscriminatedInterfaceSerializer(IDiscriminatorConvention discriminatorCo /// interfaceType /// interfaceType public DiscriminatedInterfaceSerializer(IDiscriminatorConvention discriminatorConvention, IBsonSerializer interfaceSerializer) + : this(discriminatorConvention, interfaceSerializer, objectSerializer: null) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The discriminator convention. + /// The interface serializer (necessary to support LINQ queries). + /// The serializer that is used to serialize any objects. + /// interfaceType + /// interfaceType + public DiscriminatedInterfaceSerializer(IDiscriminatorConvention discriminatorConvention, IBsonSerializer interfaceSerializer, IBsonSerializer objectSerializer) { var interfaceTypeInfo = typeof(TInterface).GetTypeInfo(); if (!interfaceTypeInfo.IsInterface) @@ -97,10 +110,14 @@ public DiscriminatedInterfaceSerializer(IDiscriminatorConvention discriminatorCo _interfaceType = typeof(TInterface); _discriminatorConvention = discriminatorConvention ?? interfaceSerializer.GetDiscriminatorConvention(); - _objectSerializer = BsonSerializer.LookupSerializer(); + + _objectSerializer = objectSerializer ?? BsonSerializer.LookupSerializer(); if (_objectSerializer is ObjectSerializer standardObjectSerializer) { - _objectSerializer = standardObjectSerializer.WithDiscriminatorConvention(_discriminatorConvention); + Func allowedTypes = (Type type) => typeof(TInterface).IsAssignableFrom(type); + _objectSerializer = standardObjectSerializer + .WithDiscriminatorConvention(_discriminatorConvention) + .WithAllowedTypes(allowedTypes, allowedTypes); } else { diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConvertExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConvertExpressionToAggregationExpressionTranslator.cs index 243e89a9687..532e10c1609 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConvertExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConvertExpressionToAggregationExpressionTranslator.cs @@ -154,7 +154,7 @@ private static bool IsConvertToBaseType(Type sourceType, Type targetType) private static bool IsConvertToDerivedType(Type sourceType, Type targetType) { - return targetType.IsSubclassOf(sourceType); + return sourceType.IsAssignableFrom(targetType); // targetType either derives from sourceType or implements sourceType interface } private static bool IsConvertToNullableType(Type targetType) diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ToFilterFieldTranslators/ConvertExpressionToFilterFieldTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ToFilterFieldTranslators/ConvertExpressionToFilterFieldTranslator.cs index 2fce6e52d1d..7a734d7a075 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ToFilterFieldTranslators/ConvertExpressionToFilterFieldTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ToFilterFieldTranslators/ConvertExpressionToFilterFieldTranslator.cs @@ -87,7 +87,7 @@ private static bool IsConvertToBaseType(Type fieldType, Type targetType) private static bool IsConvertToDerivedType(Type fieldType, Type targetType) { - return targetType.IsSubclassOfOrImplements(fieldType); + return fieldType.IsAssignableFrom(targetType); // targetType either derives from fieldType or implements fieldType interface } private static bool IsConvertToNullable(Type fieldType, Type targetType) diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConvertExpressionToAggregationExpressionTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConvertExpressionToAggregationExpressionTranslatorTests.cs index 73f8bec5b3a..219352693a2 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConvertExpressionToAggregationExpressionTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConvertExpressionToAggregationExpressionTranslatorTests.cs @@ -192,7 +192,47 @@ public void Project_using_convert_nullable_enum_to_nullable_underlying_type_work result.EnumAsNullableInt.Should().Be(2); } + [Fact] + public void Should_translate_from_base_interface_to_derived_class_on_method_call() + { + var collection = GetInterfaceCollection(); + var queryable = collection.AsQueryable() + .Select(p => new DerivedClass + { + Id = p.Id, + A = ((DerivedClass)p).A.ToUpper() + }); + var stages = Translate(collection, queryable); + AssertStages( + stages, + "{ '$project' : { _id : '$_id', A : { '$toUpper' : '$A' } } }"); + + var result = queryable.Single(); + result.Id.Should().Be(1); + result.A.Should().Be("ABC"); + } + + [Fact] + public void Should_translate_from_base_interface_to_derived_class_on_projection() + { + var collection = GetInterfaceCollection(); + var queryable = collection.AsQueryable() + .Select(p => new DerivedClass() + { + Id = p.Id, + A = ((DerivedClass)p).A + }); + + var stages = Translate(collection, queryable); + AssertStages( + stages, + "{ '$project' : { _id : '$_id', A : '$A' } }"); + + var result = queryable.Single(); + result.Id.Should().Be(1); + result.A.Should().Be("abc"); + } private IMongoCollection GetCollection() { @@ -209,7 +249,31 @@ private IMongoCollection GetCollection() return collection; } - private class BaseClass + private IMongoCollection GetInterfaceCollection() + { + var collection = GetCollection("test"); + CreateCollection(collection, new DerivedClass() + { + Id = 1, + A = "abc", + Enum = Enum.Two, + NullableEnum = Enum.Two, + EnumAsInt = 2, + EnumAsNullableInt = 2 + }); + return collection; + } + + private interface IBaseInterface + { + public int Id { get; set; } + public Enum Enum { get; set; } + public Enum? NullableEnum { get; set; } + public int EnumAsInt { get; set; } + public int? EnumAsNullableInt { get; set; } + } + + private class BaseClass : IBaseInterface { public int Id { get; set; } public Enum Enum { get; set; } diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ConvertExpressionToFilterTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ConvertExpressionToFilterTranslatorTests.cs index 1c2dfa911ef..d0b2c51cc60 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ConvertExpressionToFilterTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ConvertExpressionToFilterTranslatorTests.cs @@ -86,6 +86,17 @@ public void Filter_using_convert_nullable_enum_to_underlying_type_should_work() result.Id.Should().Be(2); } + [Fact] + public void Filter_using_field_from_implementing_type_should_work() + { + var collection = GetInterfaceCollection(); + + var filter = Builders.Filter.Eq(x => ((Data)x).AdditionalValue, "value"); + + var result = collection.Find(filter).Single(); + result.Id.Should().Be(2); + } + private IMongoCollection GetCollection() { var collection = GetCollection("test"); @@ -96,13 +107,33 @@ private IMongoCollection GetCollection() return collection; } - private class Data + private IMongoCollection GetInterfaceCollection() + { + var collection = GetCollection("test"); + CreateCollection( + collection, + new Data { Id = 1, Enum = Enum.One, NullableEnum = Enum.One, EnumAsInt = 1, EnumAsNullableInt = 1 }, + new Data { Id = 2, Enum = Enum.Two, NullableEnum = Enum.Two, EnumAsInt = 2, EnumAsNullableInt = 2, AdditionalValue = "value"}); + return collection; + } + + private interface IData + { + public int Id { get; set; } + public Enum Enum { get; set; } + public Enum? NullableEnum { get; set; } + public int EnumAsInt { get; set; } + public int? EnumAsNullableInt { get; set; } + } + + private class Data : IData { public int Id { get; set; } public Enum Enum { get; set; } public Enum? NullableEnum { get; set; } public int EnumAsInt { get; set; } public int? EnumAsNullableInt { get; set; } + public string AdditionalValue { get; set; } } private enum Enum