From d9a08a87b4b91eeedfe4d455dc27e545a93d79aa Mon Sep 17 00:00:00 2001 From: Smit Patel Date: Fri, 20 Apr 2018 11:20:23 -0700 Subject: [PATCH] Query: Improve alias-ing of projection in SelectExpression Issue: When we try to find name of the projection, we go through Convert/NullableExpression nodes. But when updating the alias we just wrap it inside a new AliasExpression. If the inner expression had AliasExpression giving it a name, then we would have multiple "AS" in SQL Fix: When assigning new alias to projection, we run visitor and tries to replace old name to new name to avoid multiple aliases. Resolves #11757 --- .../Query/Expressions/SelectExpression.cs | 38 +++++++++++++--- .../Query/AsyncGroupByQueryTestBase.cs | 43 ++++++++++++++++--- .../Query/GroupByQueryTestBase.cs | 43 ++++++++++++++++--- .../Query/GroupByQuerySqlServerTest.cs | 36 +++++++++++++--- 4 files changed, 137 insertions(+), 23 deletions(-) diff --git a/src/EFCore.Relational/Query/Expressions/SelectExpression.cs b/src/EFCore.Relational/Query/Expressions/SelectExpression.cs index ee6c9181a47..39e8dd28357 100644 --- a/src/EFCore.Relational/Query/Expressions/SelectExpression.cs +++ b/src/EFCore.Relational/Query/Expressions/SelectExpression.cs @@ -760,16 +760,20 @@ var currentProjectionIndex var updatedExpression = expression; - if (!(expression is ColumnReferenceExpression - || expression is ColumnExpression - || expression is AliasExpression) + // All subquery projections are required to be ColumnExpression/ColumnReferenceExpression/AliasExpression + if (!(updatedExpression is ColumnExpression + || updatedExpression is ColumnReferenceExpression + || updatedExpression is AliasExpression) || !string.Equals(currentAlias, uniqueAlias, StringComparison.OrdinalIgnoreCase)) { - updatedExpression = new AliasExpression(uniqueAlias, (expression as AliasExpression)?.Expression ?? expression); + var newExpression = new NameReplacingExpressionVisitor(currentAlias, uniqueAlias) + .Visit(updatedExpression); + updatedExpression = newExpression != updatedExpression + ? newExpression + : new AliasExpression(uniqueAlias, updatedExpression); } var currentOrderingIndex = _orderBy.FindIndex(e => e.Expression.Equals(expression)); - if (currentOrderingIndex != -1) { var oldOrdering = _orderBy[currentOrderingIndex]; @@ -786,6 +790,30 @@ var currentProjectionIndex return updatedExpression; } + private class NameReplacingExpressionVisitor : ExpressionVisitor + { + private readonly string _oldName; + private readonly string _newName; + + public NameReplacingExpressionVisitor(string oldName, string newName) + { + _oldName = oldName; + _newName = newName; + } + + protected override Expression VisitExtension(Expression extensionExpression) + { + switch (extensionExpression) + { + case AliasExpression aliasExpression + when string.Equals(aliasExpression.Alias, _oldName, StringComparison.OrdinalIgnoreCase): + return new AliasExpression(_newName, aliasExpression.Expression); + } + + return base.VisitExtension(extensionExpression); + } + } + private static string GetColumnName(Expression expression) { expression = expression.RemoveConvert(); diff --git a/src/EFCore.Specification.Tests/Query/AsyncGroupByQueryTestBase.cs b/src/EFCore.Specification.Tests/Query/AsyncGroupByQueryTestBase.cs index 2c1a7e57dcf..d6bdd9076b0 100644 --- a/src/EFCore.Specification.Tests/Query/AsyncGroupByQueryTestBase.cs +++ b/src/EFCore.Specification.Tests/Query/AsyncGroupByQueryTestBase.cs @@ -906,10 +906,11 @@ await AssertQuery( (from c in cs join o in os on c.CustomerID equals o.CustomerID into grouping - from o in grouping + from o in grouping.DefaultIfEmpty() + where o != null select o) .GroupBy(o => o.CustomerID) - .Select(g => new { g.Key, Count = g.Average(o => o.OrderID) }), + .Select(g => new { g.Key, Average = g.Average(o => o.OrderID) }), e => e.Key); } @@ -921,10 +922,10 @@ await AssertQuery( (from c in cs join o in os on c.CustomerID equals o.CustomerID into grouping - from o in grouping + from o in grouping.DefaultIfEmpty() select c) .GroupBy(c => c.CustomerID) - .Select(g => new { g.Key, Count = g.Max(c => c.City) }), + .Select(g => new { g.Key, Max = g.Max(c => c.City) }), e => e.Key); } @@ -936,13 +937,43 @@ await AssertQuery( (from o in os join c in cs on o.CustomerID equals c.CustomerID into grouping - from c in grouping + from c in grouping.DefaultIfEmpty() select o) .GroupBy(o => o.CustomerID) - .Select(g => new { g.Key, Count = g.Average(o => o.OrderID) }), + .Select(g => new { g.Key, Average = g.Average(o => o.OrderID) }), e => e.Key); } + [ConditionalFact] + public virtual async Task GroupJoin_GroupBy_Aggregate_4() + { + await AssertQuery( + (os, cs) => + (from c in cs + join o in os + on c.CustomerID equals o.CustomerID into grouping + from o in grouping.DefaultIfEmpty() + select c) + .GroupBy(c => c.CustomerID) + .Select(g => new { Value = g.Key, Max = g.Max(c => c.City) }), + e => e.Value); + } + + [ConditionalFact] + public virtual async Task GroupJoin_GroupBy_Aggregate_5() + { + await AssertQuery( + (os, cs) => + (from o in os + join c in cs + on o.CustomerID equals c.CustomerID into grouping + from c in grouping.DefaultIfEmpty() + select o) + .GroupBy(o => o.OrderID) + .Select(g => new { Value = g.Key, Average = g.Average(o => o.OrderID) }), + e => e.Value); + } + [ConditionalFact] public virtual async Task GroupBy_optional_navigation_member_Aggregate() { diff --git a/src/EFCore.Specification.Tests/Query/GroupByQueryTestBase.cs b/src/EFCore.Specification.Tests/Query/GroupByQueryTestBase.cs index ce67c42c6ee..e5dcf3059c1 100644 --- a/src/EFCore.Specification.Tests/Query/GroupByQueryTestBase.cs +++ b/src/EFCore.Specification.Tests/Query/GroupByQueryTestBase.cs @@ -909,10 +909,11 @@ public virtual void GroupJoin_GroupBy_Aggregate() (from c in cs join o in os on c.CustomerID equals o.CustomerID into grouping - from o in grouping + from o in grouping.DefaultIfEmpty() + where o != null select o) .GroupBy(o => o.CustomerID) - .Select(g => new { g.Key, Count = g.Average(o => o.OrderID) }), + .Select(g => new { g.Key, Average = g.Average(o => o.OrderID) }), e => e.Key); } @@ -924,10 +925,10 @@ public virtual void GroupJoin_GroupBy_Aggregate_2() (from c in cs join o in os on c.CustomerID equals o.CustomerID into grouping - from o in grouping + from o in grouping.DefaultIfEmpty() select c) .GroupBy(c => c.CustomerID) - .Select(g => new { g.Key, Count = g.Max(c => c.City) }), + .Select(g => new { g.Key, Max = g.Max(c => c.City) }), e => e.Key); } @@ -939,13 +940,43 @@ public virtual void GroupJoin_GroupBy_Aggregate_3() (from o in os join c in cs on o.CustomerID equals c.CustomerID into grouping - from c in grouping + from c in grouping.DefaultIfEmpty() select o) .GroupBy(o => o.CustomerID) - .Select(g => new { g.Key, Count = g.Average(o => o.OrderID) }), + .Select(g => new { g.Key, Average = g.Average(o => o.OrderID) }), e => e.Key); } + [ConditionalFact] + public virtual void GroupJoin_GroupBy_Aggregate_4() + { + AssertQuery( + (os, cs) => + (from c in cs + join o in os + on c.CustomerID equals o.CustomerID into grouping + from o in grouping.DefaultIfEmpty() + select c) + .GroupBy(c => c.CustomerID) + .Select(g => new { Value = g.Key, Max = g.Max(c => c.City) }), + e => e.Value); + } + + [ConditionalFact] + public virtual void GroupJoin_GroupBy_Aggregate_5() + { + AssertQuery( + (os, cs) => + (from o in os + join c in cs + on o.CustomerID equals c.CustomerID into grouping + from c in grouping.DefaultIfEmpty() + select o) + .GroupBy(o => o.OrderID) + .Select(g => new { Value = g.Key, Average = g.Average(o => o.OrderID) }), + e => e.Value); + } + [ConditionalFact] public virtual void GroupBy_optional_navigation_member_Aggregate() { diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/GroupByQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/GroupByQuerySqlServerTest.cs index c3c72964989..fdba8e9b7db 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/GroupByQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/GroupByQuerySqlServerTest.cs @@ -810,9 +810,10 @@ public override void GroupJoin_GroupBy_Aggregate() base.GroupJoin_GroupBy_Aggregate(); AssertSql( - @"SELECT [o].[CustomerID] AS [Key], AVG(CAST([o].[OrderID] AS float)) AS [Count] + @"SELECT [o].[CustomerID] AS [Key], AVG(CAST([o].[OrderID] AS float)) AS [Average] FROM [Customers] AS [c] -INNER JOIN [Orders] AS [o] ON [c].[CustomerID] = [o].[CustomerID] +LEFT JOIN [Orders] AS [o] ON [c].[CustomerID] = [o].[CustomerID] +WHERE [o].[OrderID] IS NOT NULL GROUP BY [o].[CustomerID]"); } @@ -821,9 +822,9 @@ public override void GroupJoin_GroupBy_Aggregate_2() base.GroupJoin_GroupBy_Aggregate_2(); AssertSql( - @"SELECT [c].[CustomerID] AS [Key], MAX([c].[City]) AS [Count] + @"SELECT [c].[CustomerID] AS [Key], MAX([c].[City]) AS [Max] FROM [Customers] AS [c] -INNER JOIN [Orders] AS [o] ON [c].[CustomerID] = [o].[CustomerID] +LEFT JOIN [Orders] AS [o] ON [c].[CustomerID] = [o].[CustomerID] GROUP BY [c].[CustomerID]"); } @@ -832,12 +833,35 @@ public override void GroupJoin_GroupBy_Aggregate_3() base.GroupJoin_GroupBy_Aggregate_3(); AssertSql( - @"SELECT [o].[CustomerID] AS [Key], AVG(CAST([o].[OrderID] AS float)) AS [Count] + @"SELECT [o].[CustomerID] AS [Key], AVG(CAST([o].[OrderID] AS float)) AS [Average] FROM [Orders] AS [o] -INNER JOIN [Customers] AS [c] ON [o].[CustomerID] = [c].[CustomerID] +LEFT JOIN [Customers] AS [c] ON [o].[CustomerID] = [c].[CustomerID] GROUP BY [o].[CustomerID]"); } + public override void GroupJoin_GroupBy_Aggregate_4() + { + base.GroupJoin_GroupBy_Aggregate_4(); + + AssertSql( + @"SELECT [c].[CustomerID] AS [Value], MAX([c].[City]) AS [Max] +FROM [Customers] AS [c] +LEFT JOIN [Orders] AS [o] ON [c].[CustomerID] = [o].[CustomerID] +GROUP BY [c].[CustomerID]"); + } + + public override void GroupJoin_GroupBy_Aggregate_5() + { + base.GroupJoin_GroupBy_Aggregate_5(); + + AssertSql( + @"SELECT [o].[OrderID] AS [Value], AVG(CAST([o].[OrderID] AS float)) AS [Average] +FROM [Orders] AS [o] +LEFT JOIN [Customers] AS [c] ON [o].[CustomerID] = [c].[CustomerID] +GROUP BY [o].[OrderID]"); + + } + public override void GroupBy_optional_navigation_member_Aggregate() { base.GroupBy_optional_navigation_member_Aggregate();