diff --git a/src/EFCore.Relational/Query/SqlExpressionFactory.cs b/src/EFCore.Relational/Query/SqlExpressionFactory.cs index 47f94f1d78b..b03b8f7625a 100644 --- a/src/EFCore.Relational/Query/SqlExpressionFactory.cs +++ b/src/EFCore.Relational/Query/SqlExpressionFactory.cs @@ -9,6 +9,9 @@ namespace Microsoft.EntityFrameworkCore.Query; /// public class SqlExpressionFactory : ISqlExpressionFactory { + private static readonly bool UseOldBehavior35393 = + AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue35393", out var enabled35393) && enabled35393; + private readonly IRelationalTypeMappingSource _typeMappingSource; private readonly RelationalTypeMapping _boolTypeMapping; @@ -660,20 +663,23 @@ private SqlExpression Not(SqlExpression operand, SqlExpression? existingExpressi SqlBinaryExpression { OperatorType: ExpressionType.OrElse } binary => AndAlso(Not(binary.Left), Not(binary.Right)), - // use equality where possible + // see #35393 - this optimization is not safe to do for c# null semantics // !(a == true) -> a == false // !(a == false) -> a == true SqlBinaryExpression { OperatorType: ExpressionType.Equal, Right: SqlConstantExpression { Value: bool } } binary + when UseOldBehavior35393 => Equal(binary.Left, Not(binary.Right)), // !(true == a) -> false == a // !(false == a) -> true == a SqlBinaryExpression { OperatorType: ExpressionType.Equal, Left: SqlConstantExpression { Value: bool } } binary + when UseOldBehavior35393 => Equal(Not(binary.Left), binary.Right), // !(a == b) -> a != b SqlBinaryExpression { OperatorType: ExpressionType.Equal } sqlBinaryOperand => NotEqual( sqlBinaryOperand.Left, sqlBinaryOperand.Right), + // !(a != b) -> a == b SqlBinaryExpression { OperatorType: ExpressionType.NotEqual } sqlBinaryOperand => Equal( sqlBinaryOperand.Left, sqlBinaryOperand.Right), diff --git a/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs b/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs index 407826342f2..aca63c1e9d7 100644 --- a/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs +++ b/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs @@ -19,6 +19,9 @@ namespace Microsoft.EntityFrameworkCore.Query; /// public class SqlNullabilityProcessor { + private static readonly bool UseOldBehavior35393 = + AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue35393", out var enabled35393) && enabled35393; + private readonly List _nonNullableColumns; private readonly List _nullValueColumns; private readonly ISqlExpressionFactory _sqlExpressionFactory; @@ -1326,6 +1329,7 @@ protected virtual SqlExpression VisitSqlBinary( right, leftNullable, rightNullable, + optimize, out nullable); if (optimized is SqlUnaryExpression { Operand: ColumnExpression optimizedUnaryColumnOperand } optimizedUnary) @@ -1343,7 +1347,7 @@ protected virtual SqlExpression VisitSqlBinary( // we assume that NullSemantics rewrite is only needed (on the current level) // if the optimization didn't make any changes. // Reason is that optimization can/will change the nullability of the resulting expression - // and that inforation is not tracked/stored anywhere + // and that information is not tracked/stored anywhere // so we can no longer rely on nullabilities that we computed earlier (leftNullable, rightNullable) // when performing null semantics rewrite. // It should be fine because current optimizations *radically* change the expression @@ -1678,6 +1682,7 @@ private SqlExpression ProcessJoinPredicate(SqlExpression predicate) right, leftNullable, rightNullable, + optimize: true, out _); return result; @@ -1704,6 +1709,7 @@ private SqlExpression OptimizeComparison( SqlExpression right, bool leftNullable, bool rightNullable, + bool optimize, out bool nullable) { var leftNullValue = leftNullable && left is SqlConstantExpression or SqlParameterExpression; @@ -1784,37 +1790,63 @@ private SqlExpression OptimizeComparison( && !rightNullable && sqlBinaryExpression.OperatorType is ExpressionType.Equal or ExpressionType.NotEqual) { - var leftUnary = left as SqlUnaryExpression; - var rightUnary = right as SqlUnaryExpression; + nullable = false; - var leftNegated = IsLogicalNot(leftUnary); - var rightNegated = IsLogicalNot(rightUnary); + return OptimizeBooleanComparison(sqlBinaryExpression, left, right, optimize); + } - if (leftNegated) - { - left = leftUnary!.Operand; - } + nullable = false; - if (rightNegated) - { - right = rightUnary!.Operand; - } + return sqlBinaryExpression.Update(left, right); + } - // a == b <=> !a == !b -> a == b - // !a == b <=> a == !b -> a != b - // a != b <=> !a != !b -> a != b - // !a != b <=> a != !b -> a == b + private SqlExpression OptimizeBooleanComparison( + SqlBinaryExpression sqlBinaryExpression, + SqlExpression left, + SqlExpression right, + bool optimize) + { + var leftUnary = left as SqlUnaryExpression; + var rightUnary = right as SqlUnaryExpression; - nullable = false; + var leftNegated = IsLogicalNot(leftUnary); + var rightNegated = IsLogicalNot(rightUnary); - return sqlBinaryExpression.OperatorType == ExpressionType.Equal ^ leftNegated == rightNegated - ? _sqlExpressionFactory.NotEqual(left, right) - : _sqlExpressionFactory.Equal(left, right); + if (leftNegated) + { + left = leftUnary!.Operand; + } + if (rightNegated) + { + right = rightUnary!.Operand; } - nullable = false; + var notEqual = sqlBinaryExpression.OperatorType == ExpressionType.Equal ^ leftNegated == rightNegated; + if (!UseOldBehavior35393) + { + // prefer equality in predicates when comparing to constants + if (optimize && notEqual && left.Type == typeof(bool) && (left is SqlConstantExpression || right is SqlConstantExpression)) + { + if (right is ColumnExpression && (left is not ColumnExpression || leftNegated)) + { + left = _sqlExpressionFactory.Not(left); + } + else + { + right = _sqlExpressionFactory.Not(right); + } - return sqlBinaryExpression.Update(left, right); + return _sqlExpressionFactory.Equal(left, right); + } + } + + // a == b <=> !a == !b -> a == b + // !a == b <=> a == !b -> a != b + // a != b <=> !a != !b -> a != b + // !a != b <=> a != !b -> a == b + return notEqual + ? _sqlExpressionFactory.NotEqual(left, right) + : _sqlExpressionFactory.Equal(left, right); } private SqlExpression RewriteNullSemantics( @@ -1832,31 +1864,41 @@ private SqlExpression RewriteNullSemantics( var leftNegated = IsLogicalNot(leftUnary); var rightNegated = IsLogicalNot(rightUnary); - if (leftNegated) + if (UseOldBehavior35393) { - left = leftUnary!.Operand; - } + if (leftNegated) + { + left = leftUnary!.Operand; + } - if (rightNegated) - { - right = rightUnary!.Operand; + if (rightNegated) + { + right = rightUnary!.Operand; + } } var leftIsNull = ProcessNullNotNull(_sqlExpressionFactory.IsNull(left), leftNullable); - var leftIsNotNull = _sqlExpressionFactory.Not(leftIsNull); + var leftIsNotNull = OptimizeNotExpression(_sqlExpressionFactory.Not(leftIsNull)); var rightIsNull = ProcessNullNotNull(_sqlExpressionFactory.IsNull(right), rightNullable); - var rightIsNotNull = _sqlExpressionFactory.Not(rightIsNull); + var rightIsNotNull = OptimizeNotExpression(_sqlExpressionFactory.Not(rightIsNull)); SqlExpression body; - if (leftNegated == rightNegated) + if (!UseOldBehavior35393) { - body = _sqlExpressionFactory.Equal(left, right); + body = OptimizeBooleanComparison(sqlBinaryExpression, left, right, optimize); } else { - // a == !b and !a == b in SQL evaluate the same as a != b - body = _sqlExpressionFactory.NotEqual(left, right); + if (leftNegated == rightNegated) + { + body = _sqlExpressionFactory.Equal(left, right); + } + else + { + // a == !b and !a == b in SQL evaluate the same as a != b + body = _sqlExpressionFactory.NotEqual(left, right); + } } // optimized expansion which doesn't distinguish between null and false @@ -1870,6 +1912,12 @@ private SqlExpression RewriteNullSemantics( // doing a full null semantics rewrite - removing all nulls from truth table nullable = false; + if (!UseOldBehavior35393 && sqlBinaryExpression.OperatorType == ExpressionType.NotEqual) + { + // the factory takes care of simplifying equal <-> not-equal + body = _sqlExpressionFactory.Not(body); + } + // (a == b && (a != null && b != null)) || (a == null && b == null) body = _sqlExpressionFactory.OrElse( _sqlExpressionFactory.AndAlso(body, _sqlExpressionFactory.AndAlso(leftIsNotNull, rightIsNotNull)), @@ -1878,7 +1926,7 @@ private SqlExpression RewriteNullSemantics( if (sqlBinaryExpression.OperatorType == ExpressionType.NotEqual) { // the factory takes care of simplifying using DeMorgan - body = _sqlExpressionFactory.Not(body); + body = OptimizeNotExpression(_sqlExpressionFactory.Not(body)); } return body; @@ -1900,14 +1948,39 @@ protected virtual SqlExpression OptimizeNotExpression(SqlExpression expression) // !(a >= b) -> a < b // !(a < b) -> a >= b // !(a <= b) -> a > b - if (sqlUnaryExpression.Operand is SqlBinaryExpression sqlBinaryOperand - && TryNegate(sqlBinaryOperand.OperatorType, out var negated)) - { - return _sqlExpressionFactory.MakeBinary( - negated, - sqlBinaryOperand.Left, - sqlBinaryOperand.Right, - sqlBinaryOperand.TypeMapping)!; + if (sqlUnaryExpression.Operand is SqlBinaryExpression sqlBinaryOperand) + { + if (TryNegate(sqlBinaryOperand.OperatorType, out var negated)) + { + return _sqlExpressionFactory.MakeBinary( + negated, + sqlBinaryOperand.Left, + sqlBinaryOperand.Right, + sqlBinaryOperand.TypeMapping)!; + } + + if (!UseOldBehavior35393) + { + // use equality where possible - at this point (true == null) and (false == null) have been converted to + // IS NULL / IS NOT NULL (i.e. false), so this optimization is safe to do. See #35393 + // !(a == true) -> a == false + // !(a == false) -> a == true + if (sqlBinaryOperand is { OperatorType: ExpressionType.Equal, Right: SqlConstantExpression { Value: bool } }) + { + return _sqlExpressionFactory.Equal( + sqlBinaryOperand.Left, + OptimizeNotExpression(_sqlExpressionFactory.Not(sqlBinaryOperand.Right))); + } + + // !(true == a) -> false == a + // !(false == a) -> true == a + if (sqlBinaryOperand is { OperatorType: ExpressionType.Equal, Left: SqlConstantExpression { Value: bool } }) + { + return _sqlExpressionFactory.Equal( + OptimizeNotExpression(_sqlExpressionFactory.Not(sqlBinaryOperand.Left)), + sqlBinaryOperand.Right); + } + } } // the factory can optimize most `NOT` expressions diff --git a/test/EFCore.Relational.Specification.Tests/Query/NullSemanticsQueryTestBase.cs b/test/EFCore.Relational.Specification.Tests/Query/NullSemanticsQueryTestBase.cs index 4ac99509883..0374285ba6d 100644 --- a/test/EFCore.Relational.Specification.Tests/Query/NullSemanticsQueryTestBase.cs +++ b/test/EFCore.Relational.Specification.Tests/Query/NullSemanticsQueryTestBase.cs @@ -2462,6 +2462,33 @@ await AssertQueryScalar( ss => ss.Set().Where(e => true).Select(e => e.Id)); } + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Compare_constant_true_to_nullable_column_negated(bool async) + => await AssertQueryScalar( + async, + ss => ss.Set().Where(x => !(true == x.NullableBoolA)).Select(x => x.Id)); + + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Compare_constant_true_to_non_nullable_column_negated(bool async) + => await AssertQueryScalar( + async, + ss => ss.Set().Where(x => !(true == x.BoolA)).Select(x => x.Id)); + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Compare_constant_true_to_expression_which_evaluates_to_null(bool async) + { + var prm = default(bool?); + + await AssertQueryScalar( + async, + ss => ss.Set().Where(x => x.NullableBoolA != null + && !object.Equals(true, x.NullableBoolA == null ? null : prm)).Select(x => x.Id)); + } + // We can't client-evaluate Like (for the expected results). // However, since the test data has no LIKE wildcards, it effectively functions like equality - except that 'null like null' returns // false instead of true. So we have this "lite" implementation which doesn't support wildcards. diff --git a/test/EFCore.Specification.Tests/Query/NorthwindJoinQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/NorthwindJoinQueryTestBase.cs index 6f3350b748f..571d9fcbac9 100644 --- a/test/EFCore.Specification.Tests/Query/NorthwindJoinQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/NorthwindJoinQueryTestBase.cs @@ -654,6 +654,24 @@ join o in ss.Set().OrderBy(o => o.OrderID).Take(100) on c.CustomerID equa from o in lo.Where(x => x.CustomerID.StartsWith("A")) select new { c.CustomerID, o.OrderID }); + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task GroupJoin_on_true_equal_true(bool async) + => AssertQuery( + async, + ss => ss.Set().GroupJoin( + ss.Set(), + x => true, + x => true, + (c, g) => new { c, g }) + .Select(x => new { x.c.CustomerID, Orders = x.g }), + elementSorter: e => e.CustomerID, + elementAsserter: (e, a) => + { + Assert.Equal(e.CustomerID, a.CustomerID); + AssertCollection(e.Orders, a.Orders, elementSorter: ee => ee.OrderID); + }); + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task Inner_join_with_tautology_predicate_converts_to_cross_join(bool async) diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindJoinQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindJoinQuerySqlServerTest.cs index 48c3758e51a..bf0747d92d6 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindJoinQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindJoinQuerySqlServerTest.cs @@ -567,6 +567,22 @@ WHERE [o0].[CustomerID] LIKE N'A%' """); } + public override async Task GroupJoin_on_true_equal_true(bool async) + { + await base.GroupJoin_on_true_equal_true(async); + + AssertSql( + """ +SELECT [c].[CustomerID], [o0].[OrderID], [o0].[CustomerID], [o0].[EmployeeID], [o0].[OrderDate] +FROM [Customers] AS [c] +OUTER APPLY ( + SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate] + FROM [Orders] AS [o] +) AS [o0] +ORDER BY [c].[CustomerID] +"""); + } + public override async Task Inner_join_with_tautology_predicate_converts_to_cross_join(bool async) { await base.Inner_join_with_tautology_predicate_converts_to_cross_join(async); diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/NullSemanticsQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/NullSemanticsQuerySqlServerTest.cs index 2581e0a3baa..bbfe95baec6 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/NullSemanticsQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/NullSemanticsQuerySqlServerTest.cs @@ -4568,6 +4568,42 @@ FROM [Entities1] AS [e] """); } + public override async Task Compare_constant_true_to_nullable_column_negated(bool async) + { + await base.Compare_constant_true_to_nullable_column_negated(async); + + AssertSql( + """ + SELECT [e].[Id] + FROM [Entities1] AS [e] + WHERE CAST(0 AS bit) = [e].[NullableBoolA] OR [e].[NullableBoolA] IS NULL + """); + } + + public override async Task Compare_constant_true_to_non_nullable_column_negated(bool async) + { + await base.Compare_constant_true_to_non_nullable_column_negated(async); + + AssertSql( + """ +SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[BoolA] = CAST(0 AS bit) +"""); + } + + public override async Task Compare_constant_true_to_expression_which_evaluates_to_null(bool async) + { + await base.Compare_constant_true_to_expression_which_evaluates_to_null(async); + + AssertSql( + """ +SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[NullableBoolA] IS NOT NULL +"""); + } + private void AssertSql(params string[] expected) => Fixture.TestSqlLoggerFactory.AssertBaseline(expected);