From f5a15fe7cfd679ebd178a15cae553e08f2123c6b Mon Sep 17 00:00:00 2001 From: Smit Patel Date: Thu, 19 Apr 2018 14:11:34 -0700 Subject: [PATCH] Query: Avoid applying entity equality rewrite to subquery coming from let Resolves #11728 --- .../Query/SimpleQueryTestBase.cs | 27 +++++ src/EFCore/Query/EntityQueryModelVisitor.cs | 31 +++++ ...ntityEqualityRewritingExpressionVisitor.cs | 15 +++ ...equiresMaterializationExpressionVisitor.cs | 31 +++++ src/EFCore/Query/QueryCompilationContext.cs | 2 + .../Query/SimpleQuerySqlServerTest.cs | 114 ++++++++++++++++++ 6 files changed, 220 insertions(+) diff --git a/src/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs b/src/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs index 6a65ad1d418..7d0250b94dd 100644 --- a/src/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs +++ b/src/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs @@ -4258,5 +4258,32 @@ where details.Any() select new { Count = details.Count() }); } + [ConditionalFact] + public virtual void Let_entity_equality_to_null() + { + AssertQuery( + cs => from c in cs.Where(c => c.CustomerID.StartsWith("A")) + let o = c.Orders.OrderBy(e => e.OrderDate).FirstOrDefault() + where o != null + select new + { + c.CustomerID, + o.OrderDate + }); + } + + [ConditionalFact] + public virtual void Let_entity_equality_to_other_entity() + { + AssertQuery( + cs => from c in cs.Where(c => c.CustomerID.StartsWith("A")) + let o = c.Orders.OrderBy(e => e.OrderDate).FirstOrDefault() + where o != new Order() + select new + { + c.CustomerID, + A = (o != null ? o.OrderDate : null) + }); + } } } diff --git a/src/EFCore/Query/EntityQueryModelVisitor.cs b/src/EFCore/Query/EntityQueryModelVisitor.cs index 1895763fc1c..f001a746734 100644 --- a/src/EFCore/Query/EntityQueryModelVisitor.cs +++ b/src/EFCore/Query/EntityQueryModelVisitor.cs @@ -265,6 +265,34 @@ protected virtual void OnBeforeNavigationRewrite([NotNull] QueryModel queryModel { } + private class DuplicateQueryModelIdentifyingExpressionVisitor : RelinqExpressionVisitor + { + private readonly QueryCompilationContext _queryCompilationContext; + private ISet _queryModels = new HashSet(); + + public DuplicateQueryModelIdentifyingExpressionVisitor(QueryCompilationContext queryCompilationContext) + { + _queryCompilationContext = queryCompilationContext; + } + + protected override Expression VisitSubQuery(SubQueryExpression subQueryExpression) + { + var subQueryModel = subQueryExpression.QueryModel; + if (_queryModels.Contains(subQueryModel)) + { + _queryCompilationContext.DuplicateQueryModels.Add(subQueryModel); + } + else + { + _queryModels.Add(subQueryModel); + } + + subQueryModel.TransformExpressions(Visit); + + return base.VisitSubQuery(subQueryExpression); + } + } + /// /// Applies optimizations to the query. /// @@ -276,6 +304,9 @@ protected virtual void OptimizeQueryModel( { Check.NotNull(queryModel, nameof(queryModel)); + queryModel.TransformExpressions( + new DuplicateQueryModelIdentifyingExpressionVisitor(_queryCompilationContext).Visit); + ExtractQueryAnnotations(queryModel); new EagerLoadingExpressionVisitor(_queryCompilationContext, _querySourceTracingExpressionVisitorFactory) diff --git a/src/EFCore/Query/ExpressionVisitors/Internal/EntityEqualityRewritingExpressionVisitor.cs b/src/EFCore/Query/ExpressionVisitors/Internal/EntityEqualityRewritingExpressionVisitor.cs index 15e0b05c417..ddb8b760a9d 100644 --- a/src/EFCore/Query/ExpressionVisitors/Internal/EntityEqualityRewritingExpressionVisitor.cs +++ b/src/EFCore/Query/ExpressionVisitors/Internal/EntityEqualityRewritingExpressionVisitor.cs @@ -166,6 +166,11 @@ private Expression RewriteNullEquality(ExpressionType nodeType, Expression nonNu Expression.Constant(null))); } + if (IsInvalidSubQueryExpression(nonNullExpression)) + { + return null; + } + var entityType = _model.FindEntityType(nonNullExpression.Type) ?? GetEntityType(properties, qsre); @@ -220,6 +225,10 @@ var nullConstantExpression return Expression.MakeBinary(nodeType, keyAccessExpression, nullConstantExpression); } + private bool IsInvalidSubQueryExpression(Expression expression) + => expression is SubQueryExpression subQuery + && _queryCompilationContext.DuplicateQueryModels.Contains(subQuery.QueryModel); + private Expression RewriteEntityEquality(ExpressionType nodeType, Expression left, Expression right) { var leftProperties = MemberAccessBindingExpressionVisitor @@ -249,6 +258,12 @@ private Expression RewriteEntityEquality(ExpressionType nodeType, Expression lef return Expression.Constant(false); } + if (IsInvalidSubQueryExpression(left) + || IsInvalidSubQueryExpression(right)) + { + return null; + } + var entityType = _model.FindEntityType(left.Type) ?? _model.FindEntityType(right.Type) ?? GetEntityType(leftProperties, leftQsre) diff --git a/src/EFCore/Query/ExpressionVisitors/Internal/RequiresMaterializationExpressionVisitor.cs b/src/EFCore/Query/ExpressionVisitors/Internal/RequiresMaterializationExpressionVisitor.cs index 4207c315393..834a6f85192 100644 --- a/src/EFCore/Query/ExpressionVisitors/Internal/RequiresMaterializationExpressionVisitor.cs +++ b/src/EFCore/Query/ExpressionVisitors/Internal/RequiresMaterializationExpressionVisitor.cs @@ -192,6 +192,37 @@ protected override Expression VisitMethodCall(MethodCallExpression node) return newExpression; } + /// + /// This API supports the Entity Framework Core infrastructure and is not intended to be used + /// directly from your code. This API may change or be removed in future releases. + /// + protected override Expression VisitBinary(BinaryExpression node) + => node.Update( + VisitBinaryOperand(node.Left, node.NodeType), + node.Conversion, + VisitBinaryOperand(node.Right, node.NodeType)); + + private Expression VisitBinaryOperand(Expression operand, ExpressionType comparison) + { + if (comparison == ExpressionType.Equal + || comparison == ExpressionType.NotEqual) + { + if (operand is SubQueryExpression subQueryExpression + && _queryModelVisitor.QueryCompilationContext.DuplicateQueryModels.Contains(subQueryExpression.QueryModel)) + { + _queryModelStack.Push(subQueryExpression.QueryModel); + + subQueryExpression.QueryModel.TransformExpressions(Visit); + + _queryModelStack.Pop(); + + return operand; + } + } + + return Visit(operand); + } + /// /// This API supports the Entity Framework Core infrastructure and is not intended to be used /// directly from your code. This API may change or be removed in future releases. diff --git a/src/EFCore/Query/QueryCompilationContext.cs b/src/EFCore/Query/QueryCompilationContext.cs index b61b87a6771..957c06d888e 100644 --- a/src/EFCore/Query/QueryCompilationContext.cs +++ b/src/EFCore/Query/QueryCompilationContext.cs @@ -65,6 +65,8 @@ public QueryCompilationContext( TrackQueryResults = trackQueryResults; } + internal ISet DuplicateQueryModels = new HashSet(); + /// /// Registers a mapping between correlated collection query models and metadata needed to process them. /// diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs index bbf2b53713a..6101217e795 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs @@ -4514,6 +4514,120 @@ FROM [Order Details] AS [od] WHERE ([od].[Quantity] < CAST(10 AS smallint)) AND ([o].[OrderID] = [od].[OrderID]))"); } + public override void Let_entity_equality_to_null() + { + base.Let_entity_equality_to_null(); + + AssertSql( + @"SELECT [c].[CustomerID], ( + SELECT TOP(1) [e0].[OrderDate] + FROM [Orders] AS [e0] + WHERE [c].[CustomerID] = [e0].[CustomerID] + ORDER BY [e0].[OrderDate] +) AS [OrderDate] +FROM [Customers] AS [c] +WHERE [c].[CustomerID] LIKE N'A' + N'%' AND (LEFT([c].[CustomerID], LEN(N'A')) = N'A')", + // + @"@_outer_CustomerID='ALFKI' (Size = 5) + +SELECT TOP(1) [e].[OrderID], [e].[CustomerID], [e].[EmployeeID], [e].[OrderDate] +FROM [Orders] AS [e] +WHERE @_outer_CustomerID = [e].[CustomerID] +ORDER BY [e].[OrderDate]", + // + @"@_outer_CustomerID='ANATR' (Size = 5) + +SELECT TOP(1) [e].[OrderID], [e].[CustomerID], [e].[EmployeeID], [e].[OrderDate] +FROM [Orders] AS [e] +WHERE @_outer_CustomerID = [e].[CustomerID] +ORDER BY [e].[OrderDate]", + // + @"@_outer_CustomerID='ANTON' (Size = 5) + +SELECT TOP(1) [e].[OrderID], [e].[CustomerID], [e].[EmployeeID], [e].[OrderDate] +FROM [Orders] AS [e] +WHERE @_outer_CustomerID = [e].[CustomerID] +ORDER BY [e].[OrderDate]", + // + @"@_outer_CustomerID='AROUT' (Size = 5) + +SELECT TOP(1) [e].[OrderID], [e].[CustomerID], [e].[EmployeeID], [e].[OrderDate] +FROM [Orders] AS [e] +WHERE @_outer_CustomerID = [e].[CustomerID] +ORDER BY [e].[OrderDate]"); + } + + public override void Let_entity_equality_to_other_entity() + { + base.Let_entity_equality_to_other_entity(); + + AssertSql( + @"SELECT [c].[CustomerID], ( + SELECT TOP(1) [e2].[OrderDate] + FROM [Orders] AS [e2] + WHERE [c].[CustomerID] = [e2].[CustomerID] + ORDER BY [e2].[OrderDate] +) +FROM [Customers] AS [c] +WHERE [c].[CustomerID] LIKE N'A' + N'%' AND (LEFT([c].[CustomerID], LEN(N'A')) = N'A')", + // + @"@_outer_CustomerID='ALFKI' (Size = 5) + +SELECT TOP(1) [e].[OrderID], [e].[CustomerID], [e].[EmployeeID], [e].[OrderDate] +FROM [Orders] AS [e] +WHERE @_outer_CustomerID = [e].[CustomerID] +ORDER BY [e].[OrderDate]", + // + @"@_outer_CustomerID1='ALFKI' (Size = 5) + +SELECT TOP(1) [e1].[OrderID], [e1].[CustomerID], [e1].[EmployeeID], [e1].[OrderDate] +FROM [Orders] AS [e1] +WHERE @_outer_CustomerID1 = [e1].[CustomerID] +ORDER BY [e1].[OrderDate]", + // + @"@_outer_CustomerID='ANATR' (Size = 5) + +SELECT TOP(1) [e].[OrderID], [e].[CustomerID], [e].[EmployeeID], [e].[OrderDate] +FROM [Orders] AS [e] +WHERE @_outer_CustomerID = [e].[CustomerID] +ORDER BY [e].[OrderDate]", + // + @"@_outer_CustomerID1='ANATR' (Size = 5) + +SELECT TOP(1) [e1].[OrderID], [e1].[CustomerID], [e1].[EmployeeID], [e1].[OrderDate] +FROM [Orders] AS [e1] +WHERE @_outer_CustomerID1 = [e1].[CustomerID] +ORDER BY [e1].[OrderDate]", + // + @"@_outer_CustomerID='ANTON' (Size = 5) + +SELECT TOP(1) [e].[OrderID], [e].[CustomerID], [e].[EmployeeID], [e].[OrderDate] +FROM [Orders] AS [e] +WHERE @_outer_CustomerID = [e].[CustomerID] +ORDER BY [e].[OrderDate]", + // + @"@_outer_CustomerID1='ANTON' (Size = 5) + +SELECT TOP(1) [e1].[OrderID], [e1].[CustomerID], [e1].[EmployeeID], [e1].[OrderDate] +FROM [Orders] AS [e1] +WHERE @_outer_CustomerID1 = [e1].[CustomerID] +ORDER BY [e1].[OrderDate]", + // + @"@_outer_CustomerID='AROUT' (Size = 5) + +SELECT TOP(1) [e].[OrderID], [e].[CustomerID], [e].[EmployeeID], [e].[OrderDate] +FROM [Orders] AS [e] +WHERE @_outer_CustomerID = [e].[CustomerID] +ORDER BY [e].[OrderDate]", + // + @"@_outer_CustomerID1='AROUT' (Size = 5) + +SELECT TOP(1) [e1].[OrderID], [e1].[CustomerID], [e1].[EmployeeID], [e1].[OrderDate] +FROM [Orders] AS [e1] +WHERE @_outer_CustomerID1 = [e1].[CustomerID] +ORDER BY [e1].[OrderDate]"); + } + private void AssertSql(params string[] expected) => Fixture.TestSqlLoggerFactory.AssertBaseline(expected);