Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 438ecd9

Browse files
committedDec 13, 2024
Translate to NULLIF
Closes dotnet#31682
1 parent c099cef commit 438ecd9

File tree

6 files changed

+247
-0
lines changed

6 files changed

+247
-0
lines changed
 

‎EFCore.sln.DotSettings

+2
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,9 @@ The .NET Foundation licenses this file to you under the MIT license.
356356
<s:Boolean x:Key="/Default/UserDictionary/Words/=subquery/@EntryIndexedValue">True</s:Boolean>
357357
<s:Boolean x:Key="/Default/UserDictionary/Words/=subquery_0027s/@EntryIndexedValue">True</s:Boolean>
358358
<s:Boolean x:Key="/Default/UserDictionary/Words/=transactionality/@EntryIndexedValue">True</s:Boolean>
359+
<s:Boolean x:Key="/Default/UserDictionary/Words/=uncoalescing/@EntryIndexedValue">True</s:Boolean>
359360
<s:Boolean x:Key="/Default/UserDictionary/Words/=unconfigured/@EntryIndexedValue">True</s:Boolean>
361+
<s:Boolean x:Key="/Default/UserDictionary/Words/=unequality/@EntryIndexedValue">True</s:Boolean>
360362
<s:Boolean x:Key="/Default/UserDictionary/Words/=unignore/@EntryIndexedValue">True</s:Boolean>
361363
<s:Boolean x:Key="/Default/UserDictionary/Words/=fixup/@EntryIndexedValue">True</s:Boolean>
362364
<s:Boolean x:Key="/Default/UserDictionary/Words/=attacher/@EntryIndexedValue">True</s:Boolean>

‎src/EFCore.Relational/Query/SqlExpressionFactory.cs

+47
Original file line numberDiff line numberDiff line change
@@ -825,6 +825,31 @@ public virtual SqlExpression Case(
825825
elseResult = lastCase.ElseResult;
826826
}
827827

828+
// Optimize:
829+
// a == b ? null : a -> NULLIF(a, b)
830+
// a != b ? a : null -> NULLIF(a, b)
831+
if (operand is null
832+
&& typeMappedWhenClauses is
833+
[
834+
{
835+
Test: SqlBinaryExpression { OperatorType: ExpressionType.Equal or ExpressionType.NotEqual } binary,
836+
Result: var result
837+
}
838+
])
839+
{
840+
switch (binary.OperatorType)
841+
{
842+
case ExpressionType.Equal
843+
when result is SqlConstantExpression { Value: null }
844+
&& elseResult is not null
845+
&& TryTranslateToNullIf(elseResult, out var nullIfTranslation):
846+
case ExpressionType.NotEqual
847+
when elseResult is null or SqlConstantExpression { Value: null }
848+
&& TryTranslateToNullIf(result, out nullIfTranslation):
849+
return nullIfTranslation;
850+
}
851+
}
852+
828853
return existingExpression is CaseExpression expr
829854
&& operand == expr.Operand
830855
&& typeMappedWhenClauses.SequenceEqual(expr.WhenClauses)
@@ -837,6 +862,28 @@ bool IsSkipped(CaseWhenClause clause)
837862

838863
bool IsMatched(CaseWhenClause clause)
839864
=> operand is null && clause.Test is SqlConstantExpression { Value: true };
865+
866+
bool TryTranslateToNullIf(SqlExpression conditionalResult, [NotNullWhen(true)] out SqlExpression? nullIfTranslation)
867+
{
868+
var (left, right) = (binary.Left, binary.Right);
869+
870+
if (left.Equals(conditionalResult) && right is not SqlConstantExpression { Value: null })
871+
{
872+
nullIfTranslation = Function(
873+
"NULLIF", [left, right], true, [false, false], left.Type, left.TypeMapping);
874+
return true;
875+
}
876+
877+
if (right.Equals(conditionalResult) && right is not SqlConstantExpression { Value: null })
878+
{
879+
nullIfTranslation = Function(
880+
"NULLIF", [right, left], true, [false, false], right.Type, right.TypeMapping);
881+
return true;
882+
}
883+
884+
nullIfTranslation = null;
885+
return false;
886+
}
840887
}
841888

842889
/// <inheritdoc />

‎test/EFCore.Cosmos.FunctionalTests/Query/Translations/MiscellaneousTranslationsCosmosTest.cs

+60
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,66 @@ public override async Task TimeSpan_Compare_to_simple_zero(bool async, bool comp
175175

176176
#endregion Compare
177177

178+
#region Uncoalescing conditional / NullIf
179+
180+
public override Task Uncoalescing_conditional_with_equality_left(bool async)
181+
=> Fixture.NoSyncTest(
182+
async, async a =>
183+
{
184+
await base.Uncoalescing_conditional_with_equality_left(a);
185+
186+
AssertSql(
187+
"""
188+
SELECT VALUE c
189+
FROM root c
190+
WHERE (((c["Int"] = 9) ? null : c["Int"]) > 1)
191+
""");
192+
});
193+
194+
public override Task Uncoalescing_conditional_with_equality_right(bool async)
195+
=> Fixture.NoSyncTest(
196+
async, async a =>
197+
{
198+
await base.Uncoalescing_conditional_with_equality_right(a);
199+
200+
AssertSql(
201+
"""
202+
SELECT VALUE c
203+
FROM root c
204+
WHERE (((9 = c["Int"]) ? null : c["Int"]) > 1)
205+
""");
206+
});
207+
208+
public override Task Uncoalescing_conditional_with_unequality_left(bool async)
209+
=> Fixture.NoSyncTest(
210+
async, async a =>
211+
{
212+
await base.Uncoalescing_conditional_with_unequality_left(a);
213+
214+
AssertSql(
215+
"""
216+
SELECT VALUE c
217+
FROM root c
218+
WHERE (((c["Int"] != 9) ? c["Int"] : null) > 1)
219+
""");
220+
});
221+
222+
public override Task Uncoalescing_conditional_with_inequality_right(bool async)
223+
=> Fixture.NoSyncTest(
224+
async, async a =>
225+
{
226+
await base.Uncoalescing_conditional_with_inequality_right(a);
227+
228+
AssertSql(
229+
"""
230+
SELECT VALUE c
231+
FROM root c
232+
WHERE (((9 != c["Int"]) ? c["Int"] : null) > 1)
233+
""");
234+
});
235+
236+
#endregion Uncoalescing conditional / NullIf
237+
178238
[ConditionalFact]
179239
public virtual void Check_all_tests_overridden()
180240
=> TestHelpers.AssertAllMethodsOverridden(GetType());

‎test/EFCore.Specification.Tests/Query/Translations/MiscellaneousTranslationsTestBase.cs

+34
Original file line numberDiff line numberDiff line change
@@ -429,4 +429,38 @@ await AssertQuery(
429429
}
430430

431431
#endregion
432+
433+
#region Uncoalescing conditional
434+
435+
// In relational providers, x == a ? null : x is translated to SQL NULLIF
436+
437+
[Theory]
438+
[MemberData(nameof(IsAsyncData))]
439+
public virtual Task Uncoalescing_conditional_with_equality_left(bool async)
440+
=> AssertQuery(
441+
async,
442+
cs => cs.Set<BasicTypesEntity>().Where(x => (x.Int == 9 ? null : x.Int) > 1));
443+
444+
[Theory]
445+
[MemberData(nameof(IsAsyncData))]
446+
public virtual Task Uncoalescing_conditional_with_equality_right(bool async)
447+
=> AssertQuery(
448+
async,
449+
cs => cs.Set<BasicTypesEntity>().Where(x => (9 == x.Int ? null : x.Int) > 1));
450+
451+
[Theory]
452+
[MemberData(nameof(IsAsyncData))]
453+
public virtual Task Uncoalescing_conditional_with_unequality_left(bool async)
454+
=> AssertQuery(
455+
async,
456+
cs => cs.Set<BasicTypesEntity>().Where(x => (x.Int != 9 ? x.Int : null) > 1));
457+
458+
[Theory]
459+
[MemberData(nameof(IsAsyncData))]
460+
public virtual Task Uncoalescing_conditional_with_inequality_right(bool async)
461+
=> AssertQuery(
462+
async,
463+
cs => cs.Set<BasicTypesEntity>().Where(x => (9 != x.Int ? x.Int : null) > 1));
464+
465+
#endregion Uncoalescing conditional
432466
}

‎test/EFCore.SqlServer.FunctionalTests/Query/Translations/MiscellaneousTranslationsSqlServerTest.cs

+52
Original file line numberDiff line numberDiff line change
@@ -803,6 +803,58 @@ FROM [BasicTypesEntities] AS [b]
803803

804804
#endregion Compare
805805

806+
#region Uncoalescing conditional / NullIf
807+
808+
public override async Task Uncoalescing_conditional_with_equality_left(bool async)
809+
{
810+
await base.Uncoalescing_conditional_with_equality_left(async);
811+
812+
AssertSql(
813+
"""
814+
SELECT [b].[Id], [b].[Bool], [b].[Byte], [b].[ByteArray], [b].[DateOnly], [b].[DateTime], [b].[DateTimeOffset], [b].[Decimal], [b].[Double], [b].[Enum], [b].[FlagsEnum], [b].[Float], [b].[Guid], [b].[Int], [b].[Long], [b].[Short], [b].[String], [b].[TimeOnly], [b].[TimeSpan]
815+
FROM [BasicTypesEntities] AS [b]
816+
WHERE NULLIF([b].[Int], 9) > 1
817+
""");
818+
}
819+
820+
public override async Task Uncoalescing_conditional_with_equality_right(bool async)
821+
{
822+
await base.Uncoalescing_conditional_with_equality_right(async);
823+
824+
AssertSql(
825+
"""
826+
SELECT [b].[Id], [b].[Bool], [b].[Byte], [b].[ByteArray], [b].[DateOnly], [b].[DateTime], [b].[DateTimeOffset], [b].[Decimal], [b].[Double], [b].[Enum], [b].[FlagsEnum], [b].[Float], [b].[Guid], [b].[Int], [b].[Long], [b].[Short], [b].[String], [b].[TimeOnly], [b].[TimeSpan]
827+
FROM [BasicTypesEntities] AS [b]
828+
WHERE NULLIF([b].[Int], 9) > 1
829+
""");
830+
}
831+
832+
public override async Task Uncoalescing_conditional_with_unequality_left(bool async)
833+
{
834+
await base.Uncoalescing_conditional_with_unequality_left(async);
835+
836+
AssertSql(
837+
"""
838+
SELECT [b].[Id], [b].[Bool], [b].[Byte], [b].[ByteArray], [b].[DateOnly], [b].[DateTime], [b].[DateTimeOffset], [b].[Decimal], [b].[Double], [b].[Enum], [b].[FlagsEnum], [b].[Float], [b].[Guid], [b].[Int], [b].[Long], [b].[Short], [b].[String], [b].[TimeOnly], [b].[TimeSpan]
839+
FROM [BasicTypesEntities] AS [b]
840+
WHERE NULLIF([b].[Int], 9) > 1
841+
""");
842+
}
843+
844+
public override async Task Uncoalescing_conditional_with_inequality_right(bool async)
845+
{
846+
await base.Uncoalescing_conditional_with_inequality_right(async);
847+
848+
AssertSql(
849+
"""
850+
SELECT [b].[Id], [b].[Bool], [b].[Byte], [b].[ByteArray], [b].[DateOnly], [b].[DateTime], [b].[DateTimeOffset], [b].[Decimal], [b].[Double], [b].[Enum], [b].[FlagsEnum], [b].[Float], [b].[Guid], [b].[Int], [b].[Long], [b].[Short], [b].[String], [b].[TimeOnly], [b].[TimeSpan]
851+
FROM [BasicTypesEntities] AS [b]
852+
WHERE NULLIF([b].[Int], 9) > 1
853+
""");
854+
}
855+
856+
#endregion Uncoalescing conditional / NullIf
857+
806858
[ConditionalFact]
807859
public virtual void Check_all_tests_overridden()
808860
=> TestHelpers.AssertAllMethodsOverridden(GetType());

‎test/EFCore.Sqlite.FunctionalTests/Query/Translations/MiscellaneousTranslationsSqliteTest.cs

+52
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,58 @@ public override async Task TimeSpan_Compare_to_simple_zero(bool async, bool comp
274274

275275
#endregion Compare
276276

277+
#region Uncoalescing conditional / NullIf
278+
279+
public override async Task Uncoalescing_conditional_with_equality_left(bool async)
280+
{
281+
await base.Uncoalescing_conditional_with_equality_left(async);
282+
283+
AssertSql(
284+
"""
285+
SELECT "b"."Id", "b"."Bool", "b"."Byte", "b"."ByteArray", "b"."DateOnly", "b"."DateTime", "b"."DateTimeOffset", "b"."Decimal", "b"."Double", "b"."Enum", "b"."FlagsEnum", "b"."Float", "b"."Guid", "b"."Int", "b"."Long", "b"."Short", "b"."String", "b"."TimeOnly", "b"."TimeSpan"
286+
FROM "BasicTypesEntities" AS "b"
287+
WHERE NULLIF("b"."Int", 9) > 1
288+
""");
289+
}
290+
291+
public override async Task Uncoalescing_conditional_with_equality_right(bool async)
292+
{
293+
await base.Uncoalescing_conditional_with_equality_right(async);
294+
295+
AssertSql(
296+
"""
297+
SELECT "b"."Id", "b"."Bool", "b"."Byte", "b"."ByteArray", "b"."DateOnly", "b"."DateTime", "b"."DateTimeOffset", "b"."Decimal", "b"."Double", "b"."Enum", "b"."FlagsEnum", "b"."Float", "b"."Guid", "b"."Int", "b"."Long", "b"."Short", "b"."String", "b"."TimeOnly", "b"."TimeSpan"
298+
FROM "BasicTypesEntities" AS "b"
299+
WHERE NULLIF("b"."Int", 9) > 1
300+
""");
301+
}
302+
303+
public override async Task Uncoalescing_conditional_with_unequality_left(bool async)
304+
{
305+
await base.Uncoalescing_conditional_with_unequality_left(async);
306+
307+
AssertSql(
308+
"""
309+
SELECT "b"."Id", "b"."Bool", "b"."Byte", "b"."ByteArray", "b"."DateOnly", "b"."DateTime", "b"."DateTimeOffset", "b"."Decimal", "b"."Double", "b"."Enum", "b"."FlagsEnum", "b"."Float", "b"."Guid", "b"."Int", "b"."Long", "b"."Short", "b"."String", "b"."TimeOnly", "b"."TimeSpan"
310+
FROM "BasicTypesEntities" AS "b"
311+
WHERE NULLIF("b"."Int", 9) > 1
312+
""");
313+
}
314+
315+
public override async Task Uncoalescing_conditional_with_inequality_right(bool async)
316+
{
317+
await base.Uncoalescing_conditional_with_inequality_right(async);
318+
319+
AssertSql(
320+
"""
321+
SELECT "b"."Id", "b"."Bool", "b"."Byte", "b"."ByteArray", "b"."DateOnly", "b"."DateTime", "b"."DateTimeOffset", "b"."Decimal", "b"."Double", "b"."Enum", "b"."FlagsEnum", "b"."Float", "b"."Guid", "b"."Int", "b"."Long", "b"."Short", "b"."String", "b"."TimeOnly", "b"."TimeSpan"
322+
FROM "BasicTypesEntities" AS "b"
323+
WHERE NULLIF("b"."Int", 9) > 1
324+
""");
325+
}
326+
327+
#endregion Uncoalescing conditional / NullIf
328+
277329
[ConditionalFact]
278330
public virtual void Check_all_tests_overridden()
279331
=> TestHelpers.AssertAllMethodsOverridden(GetType());

0 commit comments

Comments
 (0)
Please sign in to comment.