From fc071005b241fc7d0574f0ef359bfe1325cd5c54 Mon Sep 17 00:00:00 2001 From: LucStr <25279790+LucStr@users.noreply.github.com> Date: Thu, 18 Jan 2024 19:52:31 +0100 Subject: [PATCH 1/3] CSHARP-4935 allow Linq Translation conversion from interface to derived type. --- .../DiscriminatedInterfaceSerializer.cs | 18 ++++- ...essionToAggregationExpressionTranslator.cs | 2 +- ...onvertExpressionToFilterFieldTranslator.cs | 2 +- ...nToAggregationExpressionTranslatorTests.cs | 66 ++++++++++++++++++- ...onvertExpressionToFilterTranslatorTests.cs | 33 +++++++++- 5 files changed, 115 insertions(+), 6 deletions(-) diff --git a/src/MongoDB.Bson/Serialization/Serializers/DiscriminatedInterfaceSerializer.cs b/src/MongoDB.Bson/Serialization/Serializers/DiscriminatedInterfaceSerializer.cs index ca2fef6fe07..63869a8582a 100644 --- a/src/MongoDB.Bson/Serialization/Serializers/DiscriminatedInterfaceSerializer.cs +++ b/src/MongoDB.Bson/Serialization/Serializers/DiscriminatedInterfaceSerializer.cs @@ -75,7 +75,7 @@ public DiscriminatedInterfaceSerializer() /// interfaceType /// interfaceType public DiscriminatedInterfaceSerializer(IDiscriminatorConvention discriminatorConvention) - : this(discriminatorConvention, CreateInterfaceSerializer()) + : this(discriminatorConvention, CreateInterfaceSerializer(), objectSerializer: null) { } @@ -87,6 +87,19 @@ public DiscriminatedInterfaceSerializer(IDiscriminatorConvention discriminatorCo /// interfaceType /// interfaceType public DiscriminatedInterfaceSerializer(IDiscriminatorConvention discriminatorConvention, IBsonSerializer interfaceSerializer) + : this(discriminatorConvention, interfaceSerializer, objectSerializer: null) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The discriminator convention. + /// The interface serializer (necessary to support LINQ queries). + /// The serializer that is used to serialize any objects. + /// interfaceType + /// interfaceType + public DiscriminatedInterfaceSerializer(IDiscriminatorConvention discriminatorConvention, IBsonSerializer interfaceSerializer, IBsonSerializer objectSerializer) { var interfaceTypeInfo = typeof(TInterface).GetTypeInfo(); if (!interfaceTypeInfo.IsInterface) @@ -97,7 +110,8 @@ public DiscriminatedInterfaceSerializer(IDiscriminatorConvention discriminatorCo _interfaceType = typeof(TInterface); _discriminatorConvention = discriminatorConvention ?? interfaceSerializer.GetDiscriminatorConvention(); - _objectSerializer = BsonSerializer.LookupSerializer(); + _objectSerializer = objectSerializer ?? new ObjectSerializer(allowedTypes: type => typeof(TInterface).IsAssignableFrom(type)); + if (_objectSerializer is ObjectSerializer standardObjectSerializer) { _objectSerializer = standardObjectSerializer.WithDiscriminatorConvention(_discriminatorConvention); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConvertExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConvertExpressionToAggregationExpressionTranslator.cs index 243e89a9687..532e10c1609 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConvertExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConvertExpressionToAggregationExpressionTranslator.cs @@ -154,7 +154,7 @@ private static bool IsConvertToBaseType(Type sourceType, Type targetType) private static bool IsConvertToDerivedType(Type sourceType, Type targetType) { - return targetType.IsSubclassOf(sourceType); + return sourceType.IsAssignableFrom(targetType); // targetType either derives from sourceType or implements sourceType interface } private static bool IsConvertToNullableType(Type targetType) diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ToFilterFieldTranslators/ConvertExpressionToFilterFieldTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ToFilterFieldTranslators/ConvertExpressionToFilterFieldTranslator.cs index 2fce6e52d1d..7a734d7a075 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ToFilterFieldTranslators/ConvertExpressionToFilterFieldTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ToFilterFieldTranslators/ConvertExpressionToFilterFieldTranslator.cs @@ -87,7 +87,7 @@ private static bool IsConvertToBaseType(Type fieldType, Type targetType) private static bool IsConvertToDerivedType(Type fieldType, Type targetType) { - return targetType.IsSubclassOfOrImplements(fieldType); + return fieldType.IsAssignableFrom(targetType); // targetType either derives from fieldType or implements fieldType interface } private static bool IsConvertToNullable(Type fieldType, Type targetType) diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConvertExpressionToAggregationExpressionTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConvertExpressionToAggregationExpressionTranslatorTests.cs index 73f8bec5b3a..219352693a2 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConvertExpressionToAggregationExpressionTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConvertExpressionToAggregationExpressionTranslatorTests.cs @@ -192,7 +192,47 @@ public void Project_using_convert_nullable_enum_to_nullable_underlying_type_work result.EnumAsNullableInt.Should().Be(2); } + [Fact] + public void Should_translate_from_base_interface_to_derived_class_on_method_call() + { + var collection = GetInterfaceCollection(); + var queryable = collection.AsQueryable() + .Select(p => new DerivedClass + { + Id = p.Id, + A = ((DerivedClass)p).A.ToUpper() + }); + var stages = Translate(collection, queryable); + AssertStages( + stages, + "{ '$project' : { _id : '$_id', A : { '$toUpper' : '$A' } } }"); + + var result = queryable.Single(); + result.Id.Should().Be(1); + result.A.Should().Be("ABC"); + } + + [Fact] + public void Should_translate_from_base_interface_to_derived_class_on_projection() + { + var collection = GetInterfaceCollection(); + var queryable = collection.AsQueryable() + .Select(p => new DerivedClass() + { + Id = p.Id, + A = ((DerivedClass)p).A + }); + + var stages = Translate(collection, queryable); + AssertStages( + stages, + "{ '$project' : { _id : '$_id', A : '$A' } }"); + + var result = queryable.Single(); + result.Id.Should().Be(1); + result.A.Should().Be("abc"); + } private IMongoCollection GetCollection() { @@ -209,7 +249,31 @@ private IMongoCollection GetCollection() return collection; } - private class BaseClass + private IMongoCollection GetInterfaceCollection() + { + var collection = GetCollection("test"); + CreateCollection(collection, new DerivedClass() + { + Id = 1, + A = "abc", + Enum = Enum.Two, + NullableEnum = Enum.Two, + EnumAsInt = 2, + EnumAsNullableInt = 2 + }); + return collection; + } + + private interface IBaseInterface + { + public int Id { get; set; } + public Enum Enum { get; set; } + public Enum? NullableEnum { get; set; } + public int EnumAsInt { get; set; } + public int? EnumAsNullableInt { get; set; } + } + + private class BaseClass : IBaseInterface { public int Id { get; set; } public Enum Enum { get; set; } diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ConvertExpressionToFilterTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ConvertExpressionToFilterTranslatorTests.cs index 1c2dfa911ef..d0b2c51cc60 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ConvertExpressionToFilterTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ConvertExpressionToFilterTranslatorTests.cs @@ -86,6 +86,17 @@ public void Filter_using_convert_nullable_enum_to_underlying_type_should_work() result.Id.Should().Be(2); } + [Fact] + public void Filter_using_field_from_implementing_type_should_work() + { + var collection = GetInterfaceCollection(); + + var filter = Builders.Filter.Eq(x => ((Data)x).AdditionalValue, "value"); + + var result = collection.Find(filter).Single(); + result.Id.Should().Be(2); + } + private IMongoCollection GetCollection() { var collection = GetCollection("test"); @@ -96,13 +107,33 @@ private IMongoCollection GetCollection() return collection; } - private class Data + private IMongoCollection GetInterfaceCollection() + { + var collection = GetCollection("test"); + CreateCollection( + collection, + new Data { Id = 1, Enum = Enum.One, NullableEnum = Enum.One, EnumAsInt = 1, EnumAsNullableInt = 1 }, + new Data { Id = 2, Enum = Enum.Two, NullableEnum = Enum.Two, EnumAsInt = 2, EnumAsNullableInt = 2, AdditionalValue = "value"}); + return collection; + } + + private interface IData + { + public int Id { get; set; } + public Enum Enum { get; set; } + public Enum? NullableEnum { get; set; } + public int EnumAsInt { get; set; } + public int? EnumAsNullableInt { get; set; } + } + + private class Data : IData { public int Id { get; set; } public Enum Enum { get; set; } public Enum? NullableEnum { get; set; } public int EnumAsInt { get; set; } public int? EnumAsNullableInt { get; set; } + public string AdditionalValue { get; set; } } private enum Enum From 7f1f1dc590334176efbd54865fde726ba7984174 Mon Sep 17 00:00:00 2001 From: rstam Date: Wed, 16 Apr 2025 15:49:47 -0400 Subject: [PATCH 2/3] CSHARP-4935: Requested changes. --- .../Serializers/DiscriminatedInterfaceSerializer.cs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/MongoDB.Bson/Serialization/Serializers/DiscriminatedInterfaceSerializer.cs b/src/MongoDB.Bson/Serialization/Serializers/DiscriminatedInterfaceSerializer.cs index 63869a8582a..4d8f3c35394 100644 --- a/src/MongoDB.Bson/Serialization/Serializers/DiscriminatedInterfaceSerializer.cs +++ b/src/MongoDB.Bson/Serialization/Serializers/DiscriminatedInterfaceSerializer.cs @@ -110,11 +110,14 @@ public DiscriminatedInterfaceSerializer(IDiscriminatorConvention discriminatorCo _interfaceType = typeof(TInterface); _discriminatorConvention = discriminatorConvention ?? interfaceSerializer.GetDiscriminatorConvention(); - _objectSerializer = objectSerializer ?? new ObjectSerializer(allowedTypes: type => typeof(TInterface).IsAssignableFrom(type)); + _objectSerializer = objectSerializer ?? BsonSerializer.LookupSerializer(); if (_objectSerializer is ObjectSerializer standardObjectSerializer) { - _objectSerializer = standardObjectSerializer.WithDiscriminatorConvention(_discriminatorConvention); + Func allowedTypes = (Type type) => typeof(TInterface).IsAssignableFrom(type); + _objectSerializer = standardObjectSerializer + .WithDiscriminatorConvention(_discriminatorConvention) + .WithAllowedTypes(allowedTypes, allowedTypes); } else { From 1f28a1796a9ae99422b0a47abf3fe32f739fba83 Mon Sep 17 00:00:00 2001 From: rstam Date: Thu, 17 Apr 2025 09:47:35 -0400 Subject: [PATCH 3/3] CSHARP-4935: Review changes. --- .../DiscriminatedInterfaceSerializer.cs | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/MongoDB.Bson/Serialization/Serializers/DiscriminatedInterfaceSerializer.cs b/src/MongoDB.Bson/Serialization/Serializers/DiscriminatedInterfaceSerializer.cs index 4d8f3c35394..eb0035489e0 100644 --- a/src/MongoDB.Bson/Serialization/Serializers/DiscriminatedInterfaceSerializer.cs +++ b/src/MongoDB.Bson/Serialization/Serializers/DiscriminatedInterfaceSerializer.cs @@ -110,24 +110,25 @@ public DiscriminatedInterfaceSerializer(IDiscriminatorConvention discriminatorCo _interfaceType = typeof(TInterface); _discriminatorConvention = discriminatorConvention ?? interfaceSerializer.GetDiscriminatorConvention(); + _interfaceSerializer = interfaceSerializer; - _objectSerializer = objectSerializer ?? BsonSerializer.LookupSerializer(); - if (_objectSerializer is ObjectSerializer standardObjectSerializer) - { - Func allowedTypes = (Type type) => typeof(TInterface).IsAssignableFrom(type); - _objectSerializer = standardObjectSerializer - .WithDiscriminatorConvention(_discriminatorConvention) - .WithAllowedTypes(allowedTypes, allowedTypes); - } - else + if (objectSerializer == null) { - if (discriminatorConvention != null) + objectSerializer = BsonSerializer.LookupSerializer(); + if (objectSerializer is ObjectSerializer standardObjectSerializer) + { + Func allowedTypes = (Type type) => typeof(TInterface).IsAssignableFrom(type); + objectSerializer = standardObjectSerializer + .WithDiscriminatorConvention(_discriminatorConvention) + .WithAllowedTypes(allowedTypes, allowedTypes); + } + else { throw new BsonSerializationException("Can't set discriminator convention on custom object serializer."); } } - _interfaceSerializer = interfaceSerializer; + _objectSerializer = objectSerializer; } // public properties