From a6a4f9f0da9b51a4256ea937c124751033f7eb24 Mon Sep 17 00:00:00 2001 From: maumar Date: Sun, 5 Jan 2025 03:06:21 -0800 Subject: [PATCH] Fix to #35393 - GroupJoin in EF Core 9 Returns Null for Joined Entities The problem was that in EF9 we moved some optimizations from the SQL nullability processor to SqlExpressionFactory (so that we optimize things early). One of the optimizations: !(true = a) -> false = a; !(false = a) -> true = a. In principle, this is valid in 3-valued logic (columns: a | (true = a) | !(true = a) | (false = a); rows: 0 | 0 | 1 | 1; 1 | 1 | 0 | 0; N | N | N | N), but not in C# semantics, where both null == true and null == false evaluate to false: when a is null, !(true == a) is true while false == a is false, so the two are not equivalent. The fix is to drop this optimization from SqlExpressionFactory and instead compensate in the nullability processor (where we have full nullability information), so that we don't regress performance for cases in which this optimization happened to be valid. Fixes #35393 --- .../Query/SqlExpressionFactory.cs | 8 +- .../Query/SqlNullabilityProcessor.cs | 161 +++++++++++++----- .../Query/NullSemanticsQueryTestBase.cs | 27 +++ .../Query/NorthwindJoinQueryTestBase.cs | 18 ++ .../Query/NorthwindJoinQuerySqlServerTest.cs | 16 ++ .../Query/NullSemanticsQuerySqlServerTest.cs | 36 ++++ 6 files changed, 221 insertions(+), 45 deletions(-) diff --git a/src/EFCore.Relational/Query/SqlExpressionFactory.cs b/src/EFCore.Relational/Query/SqlExpressionFactory.cs index 47f94f1d78b..b03b8f7625a 100644 --- a/src/EFCore.Relational/Query/SqlExpressionFactory.cs +++ b/src/EFCore.Relational/Query/SqlExpressionFactory.cs @@ -9,6 +9,9 @@ namespace Microsoft.EntityFrameworkCore.Query; /// public class SqlExpressionFactory : ISqlExpressionFactory { + private static readonly bool UseOldBehavior35393 = + AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue35393", out var enabled35393) && enabled35393; + private readonly IRelationalTypeMappingSource _typeMappingSource; private readonly 
RelationalTypeMapping _boolTypeMapping; @@ -660,20 +663,23 @@ private SqlExpression Not(SqlExpression operand, SqlExpression? existingExpressi SqlBinaryExpression { OperatorType: ExpressionType.OrElse } binary => AndAlso(Not(binary.Left), Not(binary.Right)), - // use equality where possible + // see #35393 - this optimization is not safe to do for c# null semantics // !(a == true) -> a == false // !(a == false) -> a == true SqlBinaryExpression { OperatorType: ExpressionType.Equal, Right: SqlConstantExpression { Value: bool } } binary + when UseOldBehavior35393 => Equal(binary.Left, Not(binary.Right)), // !(true == a) -> false == a // !(false == a) -> true == a SqlBinaryExpression { OperatorType: ExpressionType.Equal, Left: SqlConstantExpression { Value: bool } } binary + when UseOldBehavior35393 => Equal(Not(binary.Left), binary.Right), // !(a == b) -> a != b SqlBinaryExpression { OperatorType: ExpressionType.Equal } sqlBinaryOperand => NotEqual( sqlBinaryOperand.Left, sqlBinaryOperand.Right), + // !(a != b) -> a == b SqlBinaryExpression { OperatorType: ExpressionType.NotEqual } sqlBinaryOperand => Equal( sqlBinaryOperand.Left, sqlBinaryOperand.Right), diff --git a/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs b/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs index 407826342f2..aca63c1e9d7 100644 --- a/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs +++ b/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs @@ -19,6 +19,9 @@ namespace Microsoft.EntityFrameworkCore.Query; /// public class SqlNullabilityProcessor { + private static readonly bool UseOldBehavior35393 = + AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue35393", out var enabled35393) && enabled35393; + private readonly List _nonNullableColumns; private readonly List _nullValueColumns; private readonly ISqlExpressionFactory _sqlExpressionFactory; @@ -1326,6 +1329,7 @@ protected virtual SqlExpression VisitSqlBinary( right, leftNullable, rightNullable, + optimize, out 
nullable); if (optimized is SqlUnaryExpression { Operand: ColumnExpression optimizedUnaryColumnOperand } optimizedUnary) @@ -1343,7 +1347,7 @@ protected virtual SqlExpression VisitSqlBinary( // we assume that NullSemantics rewrite is only needed (on the current level) // if the optimization didn't make any changes. // Reason is that optimization can/will change the nullability of the resulting expression - // and that inforation is not tracked/stored anywhere + // and that information is not tracked/stored anywhere // so we can no longer rely on nullabilities that we computed earlier (leftNullable, rightNullable) // when performing null semantics rewrite. // It should be fine because current optimizations *radically* change the expression @@ -1678,6 +1682,7 @@ private SqlExpression ProcessJoinPredicate(SqlExpression predicate) right, leftNullable, rightNullable, + optimize: true, out _); return result; @@ -1704,6 +1709,7 @@ private SqlExpression OptimizeComparison( SqlExpression right, bool leftNullable, bool rightNullable, + bool optimize, out bool nullable) { var leftNullValue = leftNullable && left is SqlConstantExpression or SqlParameterExpression; @@ -1784,37 +1790,63 @@ private SqlExpression OptimizeComparison( && !rightNullable && sqlBinaryExpression.OperatorType is ExpressionType.Equal or ExpressionType.NotEqual) { - var leftUnary = left as SqlUnaryExpression; - var rightUnary = right as SqlUnaryExpression; + nullable = false; - var leftNegated = IsLogicalNot(leftUnary); - var rightNegated = IsLogicalNot(rightUnary); + return OptimizeBooleanComparison(sqlBinaryExpression, left, right, optimize); + } - if (leftNegated) - { - left = leftUnary!.Operand; - } + nullable = false; - if (rightNegated) - { - right = rightUnary!.Operand; - } + return sqlBinaryExpression.Update(left, right); + } - // a == b <=> !a == !b -> a == b - // !a == b <=> a == !b -> a != b - // a != b <=> !a != !b -> a != b - // !a != b <=> a != !b -> a == b + private SqlExpression 
OptimizeBooleanComparison( + SqlBinaryExpression sqlBinaryExpression, + SqlExpression left, + SqlExpression right, + bool optimize) + { + var leftUnary = left as SqlUnaryExpression; + var rightUnary = right as SqlUnaryExpression; - nullable = false; + var leftNegated = IsLogicalNot(leftUnary); + var rightNegated = IsLogicalNot(rightUnary); - return sqlBinaryExpression.OperatorType == ExpressionType.Equal ^ leftNegated == rightNegated - ? _sqlExpressionFactory.NotEqual(left, right) - : _sqlExpressionFactory.Equal(left, right); + if (leftNegated) + { + left = leftUnary!.Operand; + } + if (rightNegated) + { + right = rightUnary!.Operand; } - nullable = false; + var notEqual = sqlBinaryExpression.OperatorType == ExpressionType.Equal ^ leftNegated == rightNegated; + if (!UseOldBehavior35393) + { + // prefer equality in predicates when comparing to constants + if (optimize && notEqual && left.Type == typeof(bool) && (left is SqlConstantExpression || right is SqlConstantExpression)) + { + if (right is ColumnExpression && (left is not ColumnExpression || leftNegated)) + { + left = _sqlExpressionFactory.Not(left); + } + else + { + right = _sqlExpressionFactory.Not(right); + } - return sqlBinaryExpression.Update(left, right); + return _sqlExpressionFactory.Equal(left, right); + } + } + + // a == b <=> !a == !b -> a == b + // !a == b <=> a == !b -> a != b + // a != b <=> !a != !b -> a != b + // !a != b <=> a != !b -> a == b + return notEqual + ? 
_sqlExpressionFactory.NotEqual(left, right) + : _sqlExpressionFactory.Equal(left, right); } private SqlExpression RewriteNullSemantics( @@ -1832,31 +1864,41 @@ private SqlExpression RewriteNullSemantics( var leftNegated = IsLogicalNot(leftUnary); var rightNegated = IsLogicalNot(rightUnary); - if (leftNegated) + if (UseOldBehavior35393) { - left = leftUnary!.Operand; - } + if (leftNegated) + { + left = leftUnary!.Operand; + } - if (rightNegated) - { - right = rightUnary!.Operand; + if (rightNegated) + { + right = rightUnary!.Operand; + } } var leftIsNull = ProcessNullNotNull(_sqlExpressionFactory.IsNull(left), leftNullable); - var leftIsNotNull = _sqlExpressionFactory.Not(leftIsNull); + var leftIsNotNull = OptimizeNotExpression(_sqlExpressionFactory.Not(leftIsNull)); var rightIsNull = ProcessNullNotNull(_sqlExpressionFactory.IsNull(right), rightNullable); - var rightIsNotNull = _sqlExpressionFactory.Not(rightIsNull); + var rightIsNotNull = OptimizeNotExpression(_sqlExpressionFactory.Not(rightIsNull)); SqlExpression body; - if (leftNegated == rightNegated) + if (!UseOldBehavior35393) { - body = _sqlExpressionFactory.Equal(left, right); + body = OptimizeBooleanComparison(sqlBinaryExpression, left, right, optimize); } else { - // a == !b and !a == b in SQL evaluate the same as a != b - body = _sqlExpressionFactory.NotEqual(left, right); + if (leftNegated == rightNegated) + { + body = _sqlExpressionFactory.Equal(left, right); + } + else + { + // a == !b and !a == b in SQL evaluate the same as a != b + body = _sqlExpressionFactory.NotEqual(left, right); + } } // optimized expansion which doesn't distinguish between null and false @@ -1870,6 +1912,12 @@ private SqlExpression RewriteNullSemantics( // doing a full null semantics rewrite - removing all nulls from truth table nullable = false; + if (!UseOldBehavior35393 && sqlBinaryExpression.OperatorType == ExpressionType.NotEqual) + { + // the factory takes care of simplifying equal <-> not-equal + body = 
_sqlExpressionFactory.Not(body); + } + // (a == b && (a != null && b != null)) || (a == null && b == null) body = _sqlExpressionFactory.OrElse( _sqlExpressionFactory.AndAlso(body, _sqlExpressionFactory.AndAlso(leftIsNotNull, rightIsNotNull)), @@ -1878,7 +1926,7 @@ private SqlExpression RewriteNullSemantics( if (sqlBinaryExpression.OperatorType == ExpressionType.NotEqual) { // the factory takes care of simplifying using DeMorgan - body = _sqlExpressionFactory.Not(body); + body = OptimizeNotExpression(_sqlExpressionFactory.Not(body)); } return body; @@ -1900,14 +1948,39 @@ protected virtual SqlExpression OptimizeNotExpression(SqlExpression expression) // !(a >= b) -> a < b // !(a < b) -> a >= b // !(a <= b) -> a > b - if (sqlUnaryExpression.Operand is SqlBinaryExpression sqlBinaryOperand - && TryNegate(sqlBinaryOperand.OperatorType, out var negated)) - { - return _sqlExpressionFactory.MakeBinary( - negated, - sqlBinaryOperand.Left, - sqlBinaryOperand.Right, - sqlBinaryOperand.TypeMapping)!; + if (sqlUnaryExpression.Operand is SqlBinaryExpression sqlBinaryOperand) + { + if (TryNegate(sqlBinaryOperand.OperatorType, out var negated)) + { + return _sqlExpressionFactory.MakeBinary( + negated, + sqlBinaryOperand.Left, + sqlBinaryOperand.Right, + sqlBinaryOperand.TypeMapping)!; + } + + if (!UseOldBehavior35393) + { + // use equality where possible - at this point (true == null) and (false == null) have been converted to + // IS NULL / IS NOT NULL (i.e. false), so this optimization is safe to do. 
See #35393 + // !(a == true) -> a == false + // !(a == false) -> a == true + if (sqlBinaryOperand is { OperatorType: ExpressionType.Equal, Right: SqlConstantExpression { Value: bool } }) + { + return _sqlExpressionFactory.Equal( + sqlBinaryOperand.Left, + OptimizeNotExpression(_sqlExpressionFactory.Not(sqlBinaryOperand.Right))); + } + + // !(true == a) -> false == a + // !(false == a) -> true == a + if (sqlBinaryOperand is { OperatorType: ExpressionType.Equal, Left: SqlConstantExpression { Value: bool } }) + { + return _sqlExpressionFactory.Equal( + OptimizeNotExpression(_sqlExpressionFactory.Not(sqlBinaryOperand.Left)), + sqlBinaryOperand.Right); + } + } } // the factory can optimize most `NOT` expressions diff --git a/test/EFCore.Relational.Specification.Tests/Query/NullSemanticsQueryTestBase.cs b/test/EFCore.Relational.Specification.Tests/Query/NullSemanticsQueryTestBase.cs index 4ac99509883..0374285ba6d 100644 --- a/test/EFCore.Relational.Specification.Tests/Query/NullSemanticsQueryTestBase.cs +++ b/test/EFCore.Relational.Specification.Tests/Query/NullSemanticsQueryTestBase.cs @@ -2462,6 +2462,33 @@ await AssertQueryScalar( ss => ss.Set().Where(e => true).Select(e => e.Id)); } + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Compare_constant_true_to_nullable_column_negated(bool async) + => await AssertQueryScalar( + async, + ss => ss.Set().Where(x => !(true == x.NullableBoolA)).Select(x => x.Id)); + + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Compare_constant_true_to_non_nullable_column_negated(bool async) + => await AssertQueryScalar( + async, + ss => ss.Set().Where(x => !(true == x.BoolA)).Select(x => x.Id)); + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Compare_constant_true_to_expression_which_evaluates_to_null(bool async) + { + var prm = default(bool?); + + await AssertQueryScalar( + async, + ss => ss.Set().Where(x => x.NullableBoolA 
!= null + && !object.Equals(true, x.NullableBoolA == null ? null : prm)).Select(x => x.Id)); + } + // We can't client-evaluate Like (for the expected results). // However, since the test data has no LIKE wildcards, it effectively functions like equality - except that 'null like null' returns // false instead of true. So we have this "lite" implementation which doesn't support wildcards. diff --git a/test/EFCore.Specification.Tests/Query/NorthwindJoinQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/NorthwindJoinQueryTestBase.cs index 6f3350b748f..571d9fcbac9 100644 --- a/test/EFCore.Specification.Tests/Query/NorthwindJoinQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/NorthwindJoinQueryTestBase.cs @@ -654,6 +654,24 @@ join o in ss.Set().OrderBy(o => o.OrderID).Take(100) on c.CustomerID equa from o in lo.Where(x => x.CustomerID.StartsWith("A")) select new { c.CustomerID, o.OrderID }); + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task GroupJoin_on_true_equal_true(bool async) + => AssertQuery( + async, + ss => ss.Set().GroupJoin( + ss.Set(), + x => true, + x => true, + (c, g) => new { c, g }) + .Select(x => new { x.c.CustomerID, Orders = x.g }), + elementSorter: e => e.CustomerID, + elementAsserter: (e, a) => + { + Assert.Equal(e.CustomerID, a.CustomerID); + AssertCollection(e.Orders, a.Orders, elementSorter: ee => ee.OrderID); + }); + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task Inner_join_with_tautology_predicate_converts_to_cross_join(bool async) diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindJoinQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindJoinQuerySqlServerTest.cs index 48c3758e51a..bf0747d92d6 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindJoinQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindJoinQuerySqlServerTest.cs @@ -567,6 +567,22 @@ WHERE [o0].[CustomerID] LIKE N'A%' """); } + 
public override async Task GroupJoin_on_true_equal_true(bool async) + { + await base.GroupJoin_on_true_equal_true(async); + + AssertSql( + """ +SELECT [c].[CustomerID], [o0].[OrderID], [o0].[CustomerID], [o0].[EmployeeID], [o0].[OrderDate] +FROM [Customers] AS [c] +OUTER APPLY ( + SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate] + FROM [Orders] AS [o] +) AS [o0] +ORDER BY [c].[CustomerID] +"""); + } + public override async Task Inner_join_with_tautology_predicate_converts_to_cross_join(bool async) { await base.Inner_join_with_tautology_predicate_converts_to_cross_join(async); diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/NullSemanticsQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/NullSemanticsQuerySqlServerTest.cs index 2581e0a3baa..bbfe95baec6 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/NullSemanticsQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/NullSemanticsQuerySqlServerTest.cs @@ -4568,6 +4568,42 @@ FROM [Entities1] AS [e] """); } + public override async Task Compare_constant_true_to_nullable_column_negated(bool async) + { + await base.Compare_constant_true_to_nullable_column_negated(async); + + AssertSql( + """ + SELECT [e].[Id] + FROM [Entities1] AS [e] + WHERE CAST(0 AS bit) = [e].[NullableBoolA] OR [e].[NullableBoolA] IS NULL + """); + } + + public override async Task Compare_constant_true_to_non_nullable_column_negated(bool async) + { + await base.Compare_constant_true_to_non_nullable_column_negated(async); + + AssertSql( + """ +SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[BoolA] = CAST(0 AS bit) +"""); + } + + public override async Task Compare_constant_true_to_expression_which_evaluates_to_null(bool async) + { + await base.Compare_constant_true_to_expression_which_evaluates_to_null(async); + + AssertSql( + """ +SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[NullableBoolA] IS NOT NULL +"""); + } + private void AssertSql(params string[] expected) => 
Fixture.TestSqlLoggerFactory.AssertBaseline(expected);