Skip to content

Commit

Permalink
Query: Avoid applying entity equality rewrite to subquery coming from…
Browse files Browse the repository at this point in the history
… let

Resolves #11728
  • Loading branch information
smitpatel committed Apr 20, 2018
1 parent e08140c commit f5a15fe
Show file tree
Hide file tree
Showing 6 changed files with 220 additions and 0 deletions.
27 changes: 27 additions & 0 deletions src/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4258,5 +4258,32 @@ where details.Any()
select new { Count = details.Count() });
}

[ConditionalFact]
public virtual void Let_entity_equality_to_null()
{
AssertQuery<Customer>(
cs => from c in cs.Where(c => c.CustomerID.StartsWith("A"))
let o = c.Orders.OrderBy(e => e.OrderDate).FirstOrDefault()
where o != null
select new
{
c.CustomerID,
o.OrderDate
});
}

[ConditionalFact]
public virtual void Let_entity_equality_to_other_entity()
{
AssertQuery<Customer>(
cs => from c in cs.Where(c => c.CustomerID.StartsWith("A"))
let o = c.Orders.OrderBy(e => e.OrderDate).FirstOrDefault()
where o != new Order()
select new
{
c.CustomerID,
A = (o != null ? o.OrderDate : null)
});
}
}
}
31 changes: 31 additions & 0 deletions src/EFCore/Query/EntityQueryModelVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,34 @@ protected virtual void OnBeforeNavigationRewrite([NotNull] QueryModel queryModel
{
}

private class DuplicateQueryModelIdentifyingExpressionVisitor : RelinqExpressionVisitor
{
private readonly QueryCompilationContext _queryCompilationContext;
private ISet<QueryModel> _queryModels = new HashSet<QueryModel>();

public DuplicateQueryModelIdentifyingExpressionVisitor(QueryCompilationContext queryCompilationContext)
{
_queryCompilationContext = queryCompilationContext;
}

protected override Expression VisitSubQuery(SubQueryExpression subQueryExpression)
{
var subQueryModel = subQueryExpression.QueryModel;
if (_queryModels.Contains(subQueryModel))
{
_queryCompilationContext.DuplicateQueryModels.Add(subQueryModel);
}
else
{
_queryModels.Add(subQueryModel);
}

subQueryModel.TransformExpressions(Visit);

return base.VisitSubQuery(subQueryExpression);
}
}

/// <summary>
/// Applies optimizations to the query.
/// </summary>
Expand All @@ -276,6 +304,9 @@ protected virtual void OptimizeQueryModel(
{
Check.NotNull(queryModel, nameof(queryModel));

queryModel.TransformExpressions(
new DuplicateQueryModelIdentifyingExpressionVisitor(_queryCompilationContext).Visit);

ExtractQueryAnnotations(queryModel);

new EagerLoadingExpressionVisitor(_queryCompilationContext, _querySourceTracingExpressionVisitorFactory)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,11 @@ private Expression RewriteNullEquality(ExpressionType nodeType, Expression nonNu
Expression.Constant(null)));
}

if (IsInvalidSubQueryExpression(nonNullExpression))
{
return null;
}

var entityType = _model.FindEntityType(nonNullExpression.Type)
?? GetEntityType(properties, qsre);

Expand Down Expand Up @@ -220,6 +225,10 @@ var nullConstantExpression
return Expression.MakeBinary(nodeType, keyAccessExpression, nullConstantExpression);
}

private bool IsInvalidSubQueryExpression(Expression expression)
=> expression is SubQueryExpression subQuery
&& _queryCompilationContext.DuplicateQueryModels.Contains(subQuery.QueryModel);

private Expression RewriteEntityEquality(ExpressionType nodeType, Expression left, Expression right)
{
var leftProperties = MemberAccessBindingExpressionVisitor
Expand Down Expand Up @@ -249,6 +258,12 @@ private Expression RewriteEntityEquality(ExpressionType nodeType, Expression lef
return Expression.Constant(false);
}

if (IsInvalidSubQueryExpression(left)
|| IsInvalidSubQueryExpression(right))
{
return null;
}

var entityType = _model.FindEntityType(left.Type)
?? _model.FindEntityType(right.Type)
?? GetEntityType(leftProperties, leftQsre)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,37 @@ protected override Expression VisitMethodCall(MethodCallExpression node)
return newExpression;
}

/// <summary>
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
/// directly from your code. This API may change or be removed in future releases.
/// </summary>
protected override Expression VisitBinary(BinaryExpression node)
=> node.Update(
VisitBinaryOperand(node.Left, node.NodeType),
node.Conversion,
VisitBinaryOperand(node.Right, node.NodeType));

private Expression VisitBinaryOperand(Expression operand, ExpressionType comparison)
{
if (comparison == ExpressionType.Equal
|| comparison == ExpressionType.NotEqual)
{
if (operand is SubQueryExpression subQueryExpression
&& _queryModelVisitor.QueryCompilationContext.DuplicateQueryModels.Contains(subQueryExpression.QueryModel))
{
_queryModelStack.Push(subQueryExpression.QueryModel);

subQueryExpression.QueryModel.TransformExpressions(Visit);

_queryModelStack.Pop();

return operand;
}
}

return Visit(operand);
}

/// <summary>
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
/// directly from your code. This API may change or be removed in future releases.
Expand Down
2 changes: 2 additions & 0 deletions src/EFCore/Query/QueryCompilationContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ public QueryCompilationContext(
TrackQueryResults = trackQueryResults;
}

internal ISet<QueryModel> DuplicateQueryModels = new HashSet<QueryModel>();

/// <summary>
/// Registers a mapping between correlated collection query models and metadata needed to process them.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4514,6 +4514,120 @@ FROM [Order Details] AS [od]
WHERE ([od].[Quantity] < CAST(10 AS smallint)) AND ([o].[OrderID] = [od].[OrderID]))");
}

public override void Let_entity_equality_to_null()
{
base.Let_entity_equality_to_null();

AssertSql(
@"SELECT [c].[CustomerID], (
SELECT TOP(1) [e0].[OrderDate]
FROM [Orders] AS [e0]
WHERE [c].[CustomerID] = [e0].[CustomerID]
ORDER BY [e0].[OrderDate]
) AS [OrderDate]
FROM [Customers] AS [c]
WHERE [c].[CustomerID] LIKE N'A' + N'%' AND (LEFT([c].[CustomerID], LEN(N'A')) = N'A')",
//
@"@_outer_CustomerID='ALFKI' (Size = 5)
SELECT TOP(1) [e].[OrderID], [e].[CustomerID], [e].[EmployeeID], [e].[OrderDate]
FROM [Orders] AS [e]
WHERE @_outer_CustomerID = [e].[CustomerID]
ORDER BY [e].[OrderDate]",
//
@"@_outer_CustomerID='ANATR' (Size = 5)
SELECT TOP(1) [e].[OrderID], [e].[CustomerID], [e].[EmployeeID], [e].[OrderDate]
FROM [Orders] AS [e]
WHERE @_outer_CustomerID = [e].[CustomerID]
ORDER BY [e].[OrderDate]",
//
@"@_outer_CustomerID='ANTON' (Size = 5)
SELECT TOP(1) [e].[OrderID], [e].[CustomerID], [e].[EmployeeID], [e].[OrderDate]
FROM [Orders] AS [e]
WHERE @_outer_CustomerID = [e].[CustomerID]
ORDER BY [e].[OrderDate]",
//
@"@_outer_CustomerID='AROUT' (Size = 5)
SELECT TOP(1) [e].[OrderID], [e].[CustomerID], [e].[EmployeeID], [e].[OrderDate]
FROM [Orders] AS [e]
WHERE @_outer_CustomerID = [e].[CustomerID]
ORDER BY [e].[OrderDate]");
}

public override void Let_entity_equality_to_other_entity()
{
base.Let_entity_equality_to_other_entity();

AssertSql(
@"SELECT [c].[CustomerID], (
SELECT TOP(1) [e2].[OrderDate]
FROM [Orders] AS [e2]
WHERE [c].[CustomerID] = [e2].[CustomerID]
ORDER BY [e2].[OrderDate]
)
FROM [Customers] AS [c]
WHERE [c].[CustomerID] LIKE N'A' + N'%' AND (LEFT([c].[CustomerID], LEN(N'A')) = N'A')",
//
@"@_outer_CustomerID='ALFKI' (Size = 5)
SELECT TOP(1) [e].[OrderID], [e].[CustomerID], [e].[EmployeeID], [e].[OrderDate]
FROM [Orders] AS [e]
WHERE @_outer_CustomerID = [e].[CustomerID]
ORDER BY [e].[OrderDate]",
//
@"@_outer_CustomerID1='ALFKI' (Size = 5)
SELECT TOP(1) [e1].[OrderID], [e1].[CustomerID], [e1].[EmployeeID], [e1].[OrderDate]
FROM [Orders] AS [e1]
WHERE @_outer_CustomerID1 = [e1].[CustomerID]
ORDER BY [e1].[OrderDate]",
//
@"@_outer_CustomerID='ANATR' (Size = 5)
SELECT TOP(1) [e].[OrderID], [e].[CustomerID], [e].[EmployeeID], [e].[OrderDate]
FROM [Orders] AS [e]
WHERE @_outer_CustomerID = [e].[CustomerID]
ORDER BY [e].[OrderDate]",
//
@"@_outer_CustomerID1='ANATR' (Size = 5)
SELECT TOP(1) [e1].[OrderID], [e1].[CustomerID], [e1].[EmployeeID], [e1].[OrderDate]
FROM [Orders] AS [e1]
WHERE @_outer_CustomerID1 = [e1].[CustomerID]
ORDER BY [e1].[OrderDate]",
//
@"@_outer_CustomerID='ANTON' (Size = 5)
SELECT TOP(1) [e].[OrderID], [e].[CustomerID], [e].[EmployeeID], [e].[OrderDate]
FROM [Orders] AS [e]
WHERE @_outer_CustomerID = [e].[CustomerID]
ORDER BY [e].[OrderDate]",
//
@"@_outer_CustomerID1='ANTON' (Size = 5)
SELECT TOP(1) [e1].[OrderID], [e1].[CustomerID], [e1].[EmployeeID], [e1].[OrderDate]
FROM [Orders] AS [e1]
WHERE @_outer_CustomerID1 = [e1].[CustomerID]
ORDER BY [e1].[OrderDate]",
//
@"@_outer_CustomerID='AROUT' (Size = 5)
SELECT TOP(1) [e].[OrderID], [e].[CustomerID], [e].[EmployeeID], [e].[OrderDate]
FROM [Orders] AS [e]
WHERE @_outer_CustomerID = [e].[CustomerID]
ORDER BY [e].[OrderDate]",
//
@"@_outer_CustomerID1='AROUT' (Size = 5)
SELECT TOP(1) [e1].[OrderID], [e1].[CustomerID], [e1].[EmployeeID], [e1].[OrderDate]
FROM [Orders] AS [e1]
WHERE @_outer_CustomerID1 = [e1].[CustomerID]
ORDER BY [e1].[OrderDate]");
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down

0 comments on commit f5a15fe

Please sign in to comment.