Skip to content

Commit 4b67387

Browse files
author
Oleksandr Poliakov
committed
CSHARP-4779: Support Dictionary(IEnumerable<KeyValuePair<TKey, TValue>> collection) constructor in LINQ3
1 parent a09e9c5 commit 4b67387

File tree

5 files changed

+354
-0
lines changed

5 files changed

+354
-0
lines changed

src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstSimplifier.cs

+57
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
*/
1515

1616
using System;
17+
using System.Collections.Generic;
1718
using System.Linq;
1819
using MongoDB.Bson;
1920
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
@@ -454,8 +455,42 @@ public override AstNode VisitMapExpression(AstMapExpression node)
454455
}
455456
}
456457

458+
if (node.In is AstComputedDocumentExpression inComputedDocumentExpression &&
459+
inComputedDocumentExpression.Fields.All(f => f.Value is AstGetFieldExpression getFieldExpression && getFieldExpression.Input == node.As && getFieldExpression.CanBeConvertedToFieldPath()))
460+
{
461+
462+
// { $map : { input : { $map : { input : <input>, as : "y", in : { A : "$$y.FieldA" } } }, as: "v", in : { B : '$$v.A' } } } => { $map : { input : <input>, as: "v", in : { B : "$$v.FieldA" } } }
463+
if (node.Input is AstMapExpression inputMapExpression &&
464+
inputMapExpression.In is AstComputedDocumentExpression innerInComputedDocumentExpression)
465+
{
466+
var simplified = AstExpression.Map(
467+
inputMapExpression.Input,
468+
inputMapExpression.As,
469+
AstExpression.ComputedDocument(inComputedDocumentExpression.Fields.Select(f => RemapField(f, node.As.Name, innerInComputedDocumentExpression.Fields))));
470+
471+
return Visit(simplified);
472+
}
473+
474+
// { $map : { input : [{ A: "$$ROOT.FieldA" }], as : "v", in: { B : "$$v.A" } } } => [{ B : "$FieldA }]
475+
if (node.Input is AstComputedArrayExpression inputArrayExpression &&
476+
inputArrayExpression.Items.All(i => i is AstComputedDocumentExpression))
477+
{
478+
var simplified = AstExpression.ComputedArray(inputArrayExpression.Items.Select(i =>
479+
AstExpression.ComputedDocument(inComputedDocumentExpression.Fields.Select(f => RemapField(f, node.As.Name, ((AstComputedDocumentExpression)i).Fields)))));
480+
return Visit(simplified);
481+
}
482+
}
483+
457484
return base.VisitMapExpression(node);
458485

486+
static AstComputedField RemapField(AstComputedField field, string @as, IEnumerable<AstComputedField> innerFields)
487+
{
488+
var fieldPath = ((AstGetFieldExpression)field.Value).ConvertToFieldPath().Replace($"$${@as}.", string.Empty);
489+
var innerField = innerFields.Single(f => f.Path == fieldPath);
490+
491+
return AstExpression.ComputedField(field.Path, innerField.Value);
492+
}
493+
459494
static AstExpression UltimateGetFieldInput(AstGetFieldExpression getField)
460495
{
461496
if (getField.Input is AstGetFieldExpression nestedInputGetField)
@@ -574,7 +609,29 @@ arg is AstBinaryExpression argBinaryExpression &&
574609
return AstExpression.Binary(oppositeComparisonOperator, argBinaryExpression.Arg1, argBinaryExpression.Arg2);
575610
}
576611

612+
// { $arrayToObject : [[{ k : 'A', v : '$A' }, { k : 'B', v : '$B' }]] } => { A : '$A', B : '$B' }
613+
if (node.Operator is AstUnaryOperator.ArrayToObject &&
614+
arg is AstComputedArrayExpression computedArrayExpression &&
615+
computedArrayExpression.Items.All(
616+
i => i is AstComputedDocumentExpression computedDocumentExpression &&
617+
computedDocumentExpression.Fields.FirstOrDefault(f => f.Path == "k")?.Value is AstConstantExpression &&
618+
computedDocumentExpression.Fields.Any(f => f.Path == "v"))
619+
)
620+
{
621+
var fields = computedArrayExpression.Items.Select(KeyValuePairDocumentToComputedField);
622+
return AstExpression.ComputedDocument(fields);
623+
}
624+
577625
return node.Update(arg);
626+
627+
static AstComputedField KeyValuePairDocumentToComputedField(AstExpression expression)
628+
{
629+
var documentExpression = (AstComputedDocumentExpression)expression;
630+
var keyExpression = documentExpression.Fields.First(f => f.Path == "k").Value;
631+
var valueExpression = documentExpression.Fields.First(f => f.Path == "v").Value;
632+
633+
return AstExpression.ComputedField(((AstConstantExpression)keyExpression).Value.AsString, valueExpression);
634+
}
578635
}
579636
}
580637
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System.Collections.Generic;
17+
using System.Reflection;
18+
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
19+
20+
namespace MongoDB.Driver.Linq.Linq3Implementation.Reflection
21+
{
22+
internal static class DictionaryConstructor
23+
{
24+
// public static methods
25+
public static bool IsIEnumerableKeyValuePairConstructor(ConstructorInfo ctor)
26+
{
27+
var parameters = ctor.GetParameters();
28+
return parameters.Length == 1 &&
29+
parameters[0].ParameterType.ImplementsIEnumerable(out var enumerableType) &&
30+
enumerableType.IsConstructedGenericType &&
31+
enumerableType.GetGenericTypeDefinition() == typeof(KeyValuePair<,>);
32+
}
33+
}
34+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System;
17+
using System.Collections.Generic;
18+
using System.Linq;
19+
using System.Linq.Expressions;
20+
using MongoDB.Bson;
21+
using MongoDB.Bson.Serialization;
22+
using MongoDB.Bson.Serialization.Options;
23+
using MongoDB.Bson.Serialization.Serializers;
24+
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
25+
using MongoDB.Driver.Linq.Linq3Implementation.Reflection;
26+
27+
namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators
28+
{
29+
internal static class NewDictionaryExpressionToAggregationExpressionTranslator
30+
{
31+
public static TranslatedExpression Translate(TranslationContext context, NewExpression expression)
32+
{
33+
var arguments = expression.Arguments;
34+
var collectionExpression = arguments.Single();
35+
var collectionTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, collectionExpression);
36+
37+
if (collectionTranslation.Serializer is IBsonArraySerializer bsonArraySerializer &&
38+
bsonArraySerializer.TryGetItemSerializationInfo(out var itemSerializationInfo))
39+
{
40+
IBsonSerializer keySerializer = null;
41+
IBsonSerializer valueSerializer = null;
42+
AstExpression collectionTranslationAst;
43+
44+
if (itemSerializationInfo.Serializer is IRepresentationConfigurable { Representation: BsonType.Array })
45+
{
46+
collectionTranslationAst = collectionTranslation.Ast;
47+
}
48+
else if (itemSerializationInfo.Serializer is IBsonDocumentSerializer itemDocumentSerializer)
49+
{
50+
if (!itemDocumentSerializer.TryGetMemberSerializationInfo("Key", out var keyMemberSerializationInfo) ||
51+
!itemDocumentSerializer.TryGetMemberSerializationInfo("Value", out var valueMemberSerializationInfo))
52+
{
53+
throw new ExpressionNotSupportedException(expression, because: $"document serializer class {itemSerializationInfo.Serializer.GetType()} does not provide member serialization info for required fields.");
54+
}
55+
56+
if (keyMemberSerializationInfo.ElementName == "k" && valueMemberSerializationInfo.ElementName == "v")
57+
{
58+
collectionTranslationAst = collectionTranslation.Ast;
59+
}
60+
else
61+
{
62+
keySerializer = keyMemberSerializationInfo.Serializer;
63+
valueSerializer = valueMemberSerializationInfo.Serializer;
64+
65+
var pairVar = AstExpression.Var("pair");
66+
var computedDocumentAst = AstExpression.ComputedDocument([
67+
AstExpression.ComputedField("k", AstExpression.GetField(pairVar, keyMemberSerializationInfo.ElementName)),
68+
AstExpression.ComputedField("v", AstExpression.GetField(pairVar, valueMemberSerializationInfo.ElementName))
69+
]);
70+
collectionTranslationAst = AstExpression.Map(collectionTranslation.Ast, pairVar, computedDocumentAst);
71+
}
72+
}
73+
else
74+
{
75+
throw new ExpressionNotSupportedException(expression, because: $"document serializer class {itemSerializationInfo.Serializer.GetType()} does not implement {nameof(IBsonDocumentSerializer)}");
76+
}
77+
78+
if (keySerializer is not IRepresentationConfigurable { Representation: BsonType.String })
79+
{
80+
throw new ExpressionNotSupportedException(expression, because: "key did not serialize as a string");
81+
}
82+
83+
var ast = AstExpression.Unary(AstUnaryOperator.ArrayToObject, collectionTranslationAst);
84+
var resultSerializer = CreateDictionarySerializer(keySerializer, valueSerializer);
85+
return new TranslatedExpression(expression, ast, resultSerializer);
86+
}
87+
88+
throw new ExpressionNotSupportedException(expression);
89+
}
90+
91+
public static bool CanTranslate(NewExpression expression)
92+
=> expression.Type.IsConstructedGenericType &&
93+
expression.Type.GetGenericTypeDefinition() == typeof(Dictionary<,>) &&
94+
DictionaryConstructor.IsIEnumerableKeyValuePairConstructor(expression.Constructor);
95+
96+
private static IBsonSerializer CreateDictionarySerializer(IBsonSerializer keySerializer, IBsonSerializer valueSerializer)
97+
{
98+
var dictionaryType = typeof(Dictionary<,>).MakeGenericType(keySerializer.ValueType, valueSerializer.ValueType);
99+
var serializerType = typeof(DictionaryInterfaceImplementerSerializer<,,>).MakeGenericType(dictionaryType, keySerializer.ValueType, valueSerializer.ValueType);
100+
101+
return (IBsonSerializer)Activator.CreateInstance(serializerType, DictionaryRepresentation.Document, keySerializer, valueSerializer);
102+
}
103+
}
104+
}

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewExpressionToAggregationExpressionTranslator.cs

+4
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ public static TranslatedExpression Translate(TranslationContext context, NewExpr
5050
{
5151
return NewKeyValuePairExpressionToAggregationExpressionTranslator.Translate(context, expression);
5252
}
53+
if (NewDictionaryExpressionToAggregationExpressionTranslator.CanTranslate(expression))
54+
{
55+
return NewDictionaryExpressionToAggregationExpressionTranslator.Translate(context, expression);
56+
}
5357
return MemberInitExpressionToAggregationExpressionTranslator.Translate(context, expression, expression, Array.Empty<MemberBinding>());
5458
}
5559
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
#if NET6_0_OR_GREATER || NETCOREAPP3_1_OR_GREATER
17+
18+
using System;
19+
using System.Collections.Generic;
20+
using System.Linq;
21+
using FluentAssertions;
22+
using MongoDB.Bson;
23+
using MongoDB.Bson.Serialization.Attributes;
24+
using MongoDB.Driver.Linq;
25+
using MongoDB.Driver.TestHelpers;
26+
using Xunit;
27+
28+
namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators
29+
{
30+
public class NewDictionaryExpressionToAggregationExpressionTranslatorTests : LinqIntegrationTest<NewDictionaryExpressionToAggregationExpressionTranslatorTests.ClassFixture>
31+
{
32+
public NewDictionaryExpressionToAggregationExpressionTranslatorTests(ClassFixture fixture)
33+
: base(fixture)
34+
{
35+
}
36+
37+
[Fact]
38+
public void NewDictionary_with_KeyValuePairs_should_translate()
39+
{
40+
var collection = Fixture.Collection;
41+
42+
var queryable = collection.AsQueryable()
43+
.Select(d => new Dictionary<string, string>(
44+
new[] { new KeyValuePair<string, string>("A", d.A), new KeyValuePair<string, string>("B", d.B) }));
45+
46+
var stages = Translate(collection, queryable);
47+
48+
AssertStages(stages, "{ $project : { _v : { A : '$A', B: '$B' }, _id : 0 } }");
49+
50+
var result = queryable.Single();
51+
result.Should().Equal(new Dictionary<string, string>{ ["A"] = "a", ["B"] = "b" });
52+
}
53+
54+
[Fact]
55+
public void NewDictionary_with_KeyValuePairs_Create_should_translate()
56+
{
57+
var collection = Fixture.Collection;
58+
59+
var queryable = collection.AsQueryable()
60+
.Select(d => new Dictionary<string, string>(
61+
new[] { KeyValuePair.Create("A", d.A), KeyValuePair.Create("B", d.B) }));
62+
63+
var stages = Translate(collection, queryable);
64+
65+
AssertStages(stages, "{ $project : { _v : { A : '$A', B: '$B' }, _id : 0 } }");
66+
67+
var result = queryable.Single();
68+
result.Should().Equal(new Dictionary<string, string>{ ["A"] = "a", ["B"] = "b" });
69+
}
70+
71+
[Fact]
72+
public void NewDictionary_with_KeyValuePairs_should_translate_Guid_as_string_key()
73+
{
74+
var collection = Fixture.Collection;
75+
76+
var queryable = collection.AsQueryable()
77+
.Select(d => new Dictionary<Guid, string>(
78+
new[] { new KeyValuePair<Guid, string>(d.GuidAsString, d.A) }));
79+
80+
var stages = Translate(collection, queryable);
81+
82+
AssertStages(stages, "{ $project : { _v : { $arrayToObject : [[{ k : '$GuidAsString', v : '$A' }]] }, _id : 0 } }");
83+
84+
var result = queryable.Single();
85+
result.Should().Equal(new Dictionary<Guid, string>{ [Guid.Parse("3E9AE467-9705-4C17-9655-EE7730BCC2EE")] = "a" });
86+
}
87+
88+
89+
[Fact]
90+
public void NewDictionary_with_KeyValuePairs_should_translate_dynamic_array()
91+
{
92+
var collection = Fixture.Collection;
93+
94+
var queryable = collection.AsQueryable()
95+
.Select(d => new Dictionary<string, string>(
96+
d.Items.Select(i => new KeyValuePair<string, string>(i.P, i.W))));
97+
98+
var stages = Translate(collection, queryable);
99+
100+
AssertStages(stages, "{ $project : { _v : { $arrayToObject : { $map: { input: '$Items', as: 'i', in: { k: '$$i.P', v: '$$i.W' } } } }, _id : 0 } }");
101+
102+
var result = queryable.Single();
103+
result.Should().Equal(new Dictionary<string, string>{ ["x"] = "y" });
104+
}
105+
106+
[Fact]
107+
public void NewDictionary_with_KeyValuePairs_throws_on_non_string_key()
108+
{
109+
var collection = Fixture.Collection;
110+
111+
var queryable = collection.AsQueryable()
112+
.Select(d => new Dictionary<int, string>(
113+
new[] { new KeyValuePair<int, string>(42, d.A) }));
114+
115+
var exception = Record.Exception(() => queryable.ToList());
116+
117+
exception.Should().NotBeNull();
118+
exception.Should().BeOfType<ExpressionNotSupportedException>();
119+
}
120+
121+
public class C
122+
{
123+
public string A { get; set; }
124+
125+
public string B { get; set; }
126+
127+
[BsonRepresentation(BsonType.String)]
128+
public Guid GuidAsString { get; set; }
129+
130+
public Item[] Items { get; set; }
131+
}
132+
133+
public class Item
134+
{
135+
public string P { get; set; }
136+
137+
public string W { get; set; }
138+
}
139+
140+
public sealed class ClassFixture : MongoCollectionFixture<C>
141+
{
142+
protected override IEnumerable<C> InitialData =>
143+
[
144+
new C
145+
{
146+
A = "a",
147+
B = "b",
148+
GuidAsString = Guid.Parse("3E9AE467-9705-4C17-9655-EE7730BCC2EE"),
149+
Items = [ new Item { P = "x", W = "y" } ]
150+
},
151+
];
152+
}
153+
}
154+
}
155+
#endif

0 commit comments

Comments
 (0)