From 60e9a430262e413670d534ea9c9effc5e43d696d Mon Sep 17 00:00:00 2001 From: Maurycy Markowski Date: Thu, 19 Apr 2018 16:22:25 -0700 Subject: [PATCH] Fix to #11687 - Query with projection cast on correlated collections throws on 2.1.0-preview2 Problem was that when we apply correlated collection optimization we were not accounting for scenario where the result collection element type was different than the input. This can be achieved by using ToList or ToArray etc. Fix is to flow the collection result type as well as the input type to the CorrelateSubquery method, and construct the resulting collection based on the output type, if they are different. --- .../Query/AsyncGearsOfWarQueryTestBase.cs | 38 ++++++++++++++++ .../Query/GearsOfWarQueryTestBase.cs | 38 ++++++++++++++++ .../CorrelatedCollectionOptimizingVisitor.cs | 43 +++++++++++-------- src/EFCore/Query/Internal/IQueryBuffer.cs | 12 ++++-- src/EFCore/Query/Internal/QueryBuffer.cs | 12 ++++-- .../AsyncGearsOfWarQuerySqlServerTest.cs | 42 ++++++++++++++++++ .../Query/GearsOfWarQuerySqlServerTest.cs | 42 ++++++++++++++++++ 7 files changed, 201 insertions(+), 26 deletions(-) diff --git a/src/EFCore.Specification.Tests/Query/AsyncGearsOfWarQueryTestBase.cs b/src/EFCore.Specification.Tests/Query/AsyncGearsOfWarQueryTestBase.cs index f92c322ec23..d8f76e87ccc 100644 --- a/src/EFCore.Specification.Tests/Query/AsyncGearsOfWarQueryTestBase.cs +++ b/src/EFCore.Specification.Tests/Query/AsyncGearsOfWarQueryTestBase.cs @@ -1653,5 +1653,43 @@ public virtual async Task Cast_to_derived_type_after_OfType_works() await AssertQuery( gs => gs.OfType().Cast()); } + + [ConditionalFact] + public virtual async Task Cast_subquery_to_base_type_using_typed_ToList() + { + await AssertQuery( + cs => cs.Where(c => c.Name == "Ephyra").Select(c => c.StationedGears.Select(g => new Officer + { + CityOrBirthName = g.CityOrBirthName, + FullName = g.FullName, + HasSoulPatch = g.HasSoulPatch, + LeaderNickname = g.LeaderNickname, + LeaderSquadId = g.LeaderSquadId, + Nickname = g.Nickname, + Rank = g.Rank, + SquadId = g.SquadId + }).ToList()), + assertOrder: true, + elementAsserter: CollectionAsserter(e => e.Nickname, (e, a) => Assert.Equal(e.Nickname, a.Nickname))); + } + + [ConditionalFact] + public virtual async Task Cast_ordered_subquery_to_base_type_using_typed_ToArray() + { + await AssertQuery( + cs => cs.Where(c => c.Name == "Ephyra").Select(c => c.StationedGears.OrderByDescending(g => g.Nickname).Select(g => new Officer + { + CityOrBirthName = g.CityOrBirthName, + FullName = g.FullName, + HasSoulPatch = g.HasSoulPatch, + LeaderNickname = g.LeaderNickname, + LeaderSquadId = g.LeaderSquadId, + Nickname = g.Nickname, + Rank = g.Rank, + SquadId = g.SquadId + }).ToArray()), + assertOrder: true, + elementAsserter: CollectionAsserter(e => e.Nickname, (e, a) => Assert.Equal(e.Nickname, a.Nickname))); + } } } diff --git a/src/EFCore.Specification.Tests/Query/GearsOfWarQueryTestBase.cs b/src/EFCore.Specification.Tests/Query/GearsOfWarQueryTestBase.cs index 96ab9d73405..7d71ae2ac45 100644 --- a/src/EFCore.Specification.Tests/Query/GearsOfWarQueryTestBase.cs +++ b/src/EFCore.Specification.Tests/Query/GearsOfWarQueryTestBase.cs @@ -4957,6 +4957,44 @@ public virtual void Select_subquery_distinct_singleordefault_boolean_empty_with_ assertOrder: true); } + [ConditionalFact] + public virtual void Cast_subquery_to_base_type_using_typed_ToList() + { + AssertQuery( + cs => cs.Where(c => c.Name == "Ephyra").Select(c => c.StationedGears.Select(g => new Officer + { + CityOrBirthName = g.CityOrBirthName, + FullName = g.FullName, + HasSoulPatch = g.HasSoulPatch, + LeaderNickname = g.LeaderNickname, + LeaderSquadId = g.LeaderSquadId, + Nickname = g.Nickname, + Rank = g.Rank, + SquadId = g.SquadId + }).ToList()), + assertOrder: true, + elementAsserter: CollectionAsserter(e => e.Nickname, (e, a) => Assert.Equal(e.Nickname, a.Nickname))); + } + + [ConditionalFact] + public virtual void Cast_ordered_subquery_to_base_type_using_typed_ToArray() + { + AssertQuery( + cs => cs.Where(c => c.Name == "Ephyra").Select(c => c.StationedGears.OrderByDescending(g => g.Nickname).Select(g => new Officer + { + CityOrBirthName = g.CityOrBirthName, + FullName = g.FullName, + HasSoulPatch = g.HasSoulPatch, + LeaderNickname = g.LeaderNickname, + LeaderSquadId = g.LeaderSquadId, + Nickname = g.Nickname, + Rank = g.Rank, + SquadId = g.SquadId + }).ToArray()), + assertOrder: true, + elementAsserter: CollectionAsserter(e => e.Nickname, (e, a) => Assert.Equal(e.Nickname, a.Nickname))); + } + // Remember to add any new tests to Async version of this test class protected GearsOfWarContext CreateContext() => Fixture.CreateContext(); diff --git a/src/EFCore/Query/ExpressionVisitors/Internal/CorrelatedCollectionOptimizingVisitor.cs b/src/EFCore/Query/ExpressionVisitors/Internal/CorrelatedCollectionOptimizingVisitor.cs index c7b795dcbdc..3ba28ddc289 100644 --- a/src/EFCore/Query/ExpressionVisitors/Internal/CorrelatedCollectionOptimizingVisitor.cs +++ b/src/EFCore/Query/ExpressionVisitors/Internal/CorrelatedCollectionOptimizingVisitor.cs @@ -91,7 +91,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp && innerMethodCallExpression.Method.MethodIsClosedFormOf(CollectionNavigationSubqueryInjector.MaterializeCollectionNavigationMethodInfo) && innerMethodCallExpression.Arguments[1] is SubQueryExpression subQueryExpression1) { - return TryRewrite(subQueryExpression1, /*forceToListResult*/ true, out var result) + return TryRewrite(subQueryExpression1, /*forceToListResult*/ true, methodCallExpression.Type.GetSequenceType(), out var result) ? result : methodCallExpression; } @@ -99,7 +99,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp if (methodCallExpression.Method.MethodIsClosedFormOf(_toListMethodInfo) && methodCallExpression.Arguments[0] is SubQueryExpression subQueryExpression2) { - return TryRewrite(subQueryExpression2, /*forceToListResult*/ true, out var result) + return TryRewrite(subQueryExpression2, /*forceToListResult*/ true, methodCallExpression.Type.GetSequenceType(), out var result) ? result : methodCallExpression; } @@ -107,7 +107,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp if (methodCallExpression.Method.MethodIsClosedFormOf(CollectionNavigationSubqueryInjector.MaterializeCollectionNavigationMethodInfo) && methodCallExpression.Arguments[1] is SubQueryExpression subQueryExpression3) { - return TryRewrite(subQueryExpression3, /*forceToListResult*/ false, out var result) + return TryRewrite(subQueryExpression3, /*forceToListResult*/ false, /* listResultElementType */ null, out var result) ? result : methodCallExpression; } @@ -120,11 +120,11 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp /// directly from your code. This API may change or be removed in future releases. /// protected override Expression VisitSubQuery(SubQueryExpression subQueryExpression) - => TryRewrite(subQueryExpression, /*forceToListResult*/ false, out var result) + => TryRewrite(subQueryExpression, /*forceToListResult*/ false, /* listResultElementType */ null, out var result) ? result : base.VisitSubQuery(subQueryExpression); - private bool TryRewrite(SubQueryExpression subQueryExpression, bool forceToListResult, out Expression result) + private bool TryRewrite(SubQueryExpression subQueryExpression, bool forceToListResult, Type listResultElementType, out Expression result) { if (_queryCompilationContext.TryGetCorrelatedSubqueryMetadata(subQueryExpression.QueryModel.MainFromClause, out var correlatedSubqueryMetadata)) { @@ -135,7 +135,8 @@ private bool TryRewrite(SubQueryExpression subQueryExpression, bool forceToListR correlatedSubqueryMetadata.CollectionNavigation, correlatedSubqueryMetadata.TrackingQuery, parentQsre, - forceToListResult); + forceToListResult, + listResultElementType); return true; } @@ -151,7 +152,8 @@ private Expression Rewrite( INavigation navigation, bool trackingQuery, QuerySourceReferenceExpression originQuerySource, - bool forceListResult) + bool forceListResult, + Type listResultElementType) { var querySourceReferenceFindingExpressionTreeVisitor = new QuerySourceReferenceFindingExpressionVisitor(); @@ -294,8 +296,10 @@ var joinQuerySourceReferenceExpression var newOriginKeyElements = ((NewArrayExpression)(((NewExpression)newOriginKey).Arguments[0])).Expressions; var remappedOriginKeyElements = RemapOriginKeyExpressions(newOriginKeyElements, joinQuerySourceReferenceExpression, subQueryProjection); + var collectionQueryModelSelectorType = collectionQueryModel.SelectClause.Selector.Type; + var tupleCtor = typeof(Tuple<,,>).MakeGenericType( - collectionQueryModel.SelectClause.Selector.Type, + collectionQueryModelSelectorType, typeof(MaterializedAnonymousObject), typeof(MaterializedAnonymousObject)).GetConstructors().FirstOrDefault(); @@ -307,14 +311,16 @@ var joinQuerySourceReferenceExpression Expression resultCollectionFactoryExpressionBody; if (forceListResult - || navigation.ForeignKey.DeclaringEntityType.ClrType != collectionQueryModel.SelectClause.Selector.Type) + || navigation.ForeignKey.DeclaringEntityType.ClrType != collectionQueryModelSelectorType) { - var resultCollectionType = typeof(List<>).MakeGenericType(collectionQueryModel.SelectClause.Selector.Type); + listResultElementType = listResultElementType ?? collectionQueryModelSelectorType; + var resultCollectionType = typeof(List<>).MakeGenericType(listResultElementType); var resultCollectionCtor = resultCollectionType.GetTypeInfo().GetDeclaredConstructor(Array.Empty()); correlateSubqueryMethod = correlateSubqueryMethod.MakeGenericMethod( - collectionQueryModel.SelectClause.Selector.Type, - typeof(List<>).MakeGenericType(collectionQueryModel.SelectClause.Selector.Type)); + collectionQueryModelSelectorType, + listResultElementType, + typeof(List<>).MakeGenericType(listResultElementType)); resultCollectionFactoryExpressionBody = Expression.New(resultCollectionCtor); @@ -323,8 +329,9 @@ var joinQuerySourceReferenceExpression else { correlateSubqueryMethod = correlateSubqueryMethod.MakeGenericMethod( - collectionQueryModel.SelectClause.Selector.Type, - navigation.GetCollectionAccessor().CollectionType); + collectionQueryModelSelectorType, + collectionQueryModelSelectorType, + navigation.GetCollectionAccessor().CollectionType); resultCollectionFactoryExpressionBody = Expression.Convert( @@ -352,12 +359,12 @@ var joinQuerySourceReferenceExpression remappedOriginKeyElements)) }); - var collectionModelSelectorType = collectionQueryModel.SelectClause.Selector.Type; + collectionQueryModelSelectorType = collectionQueryModel.SelectClause.Selector.Type; // Enumerable or OrderedEnumerable collectionQueryModel.ResultTypeOverride = collectionQueryModel.BodyClauses.OfType().Any() - ? typeof(IOrderedEnumerable<>).MakeGenericType(collectionModelSelectorType) - : typeof(IEnumerable<>).MakeGenericType(collectionModelSelectorType); + ? typeof(IOrderedEnumerable<>).MakeGenericType(collectionQueryModelSelectorType) + : typeof(IEnumerable<>).MakeGenericType(collectionQueryModelSelectorType); var lambda = (Expression)Expression.Lambda(new SubQueryExpression(collectionQueryModel)); if (_queryCompilationContext.IsAsyncQuery) @@ -365,7 +372,7 @@ var joinQuerySourceReferenceExpression lambda = Expression.Convert( lambda, typeof(Func<>).MakeGenericType( - typeof(IAsyncEnumerable<>).MakeGenericType(collectionModelSelectorType))); + typeof(IAsyncEnumerable<>).MakeGenericType(collectionQueryModelSelectorType))); } // since we cloned QM, we need to check if it's query sources require materialization (e.g. TypeIs operation for InMemory) diff --git a/src/EFCore/Query/Internal/IQueryBuffer.cs b/src/EFCore/Query/Internal/IQueryBuffer.cs index 5cdaf898c1b..c5eb1cb4c70 100644 --- a/src/EFCore/Query/Internal/IQueryBuffer.cs +++ b/src/EFCore/Query/Internal/IQueryBuffer.cs @@ -90,20 +90,22 @@ Task IncludeCollectionAsync( /// This API supports the Entity Framework Core infrastructure and is not intended to be used /// directly from your code. This API may change or be removed in future releases. /// - TCollection CorrelateSubquery( + TCollection CorrelateSubquery( int correlatedCollectionId, [NotNull] INavigation navigation, [NotNull] Func resultCollectionFactory, in MaterializedAnonymousObject outerKey, bool tracking, [NotNull] Func>> correlatedCollectionFactory, - [NotNull] Func correlationPredicate) where TCollection : ICollection; + [NotNull] Func correlationPredicate) + where TCollection : ICollection + where TInner : TOut; /// /// This API supports the Entity Framework Core infrastructure and is not intended to be used /// directly from your code. This API may change or be removed in future releases. /// - Task CorrelateSubqueryAsync( + Task CorrelateSubqueryAsync( int correlatedCollectionId, [NotNull] INavigation navigation, [NotNull] Func resultCollectionFactory, @@ -111,6 +113,8 @@ Task CorrelateSubqueryAsync( bool tracking, [NotNull] Func>> correlatedCollectionFactory, [NotNull] Func correlationPredicate, - CancellationToken cancellationToken) where TCollection : ICollection; + CancellationToken cancellationToken) + where TCollection : ICollection + where TInner : TOut; } } diff --git a/src/EFCore/Query/Internal/QueryBuffer.cs b/src/EFCore/Query/Internal/QueryBuffer.cs index 67cab7c0b6e..4e2cae78f9e 100644 --- a/src/EFCore/Query/Internal/QueryBuffer.cs +++ b/src/EFCore/Query/Internal/QueryBuffer.cs @@ -498,14 +498,16 @@ private IWeakReferenceIdentityMap GetOrCreateIdentityMap(IKey key) /// This API supports the Entity Framework Core infrastructure and is not intended to be used /// directly from your code. This API may change or be removed in future releases. /// - public virtual TCollection CorrelateSubquery( + public virtual TCollection CorrelateSubquery( int correlatedCollectionId, INavigation navigation, Func resultCollectionFactory, in MaterializedAnonymousObject outerKey, bool tracking, Func>> correlatedCollectionFactory, - Func correlationPredicate) where TCollection : ICollection + Func correlationPredicate) + where TCollection : ICollection + where TInner: TOut { IDisposable untypedEnumerator = null; IEnumerator> enumerator = null; @@ -594,7 +596,7 @@ public virtual TCollection CorrelateSubquery( /// This API supports the Entity Framework Core infrastructure and is not intended to be used /// directly from your code. This API may change or be removed in future releases. /// - public virtual async Task CorrelateSubqueryAsync( + public virtual async Task CorrelateSubqueryAsync( int correlatedCollectionId, INavigation navigation, Func resultCollectionFactory, @@ -602,7 +604,9 @@ public virtual async Task CorrelateSubqueryAsync>> correlatedCollectionFactory, Func correlationPredicate, - CancellationToken cancellationToken) where TCollection : ICollection + CancellationToken cancellationToken) + where TCollection : ICollection + where TInner: TOut { IDisposable untypedEnumerator = null; IAsyncEnumerator> enumerator = null; diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/AsyncGearsOfWarQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/AsyncGearsOfWarQuerySqlServerTest.cs index a02b6ecc85b..a6dffdb4d91 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/AsyncGearsOfWarQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/AsyncGearsOfWarQuerySqlServerTest.cs @@ -2357,6 +2357,48 @@ WHERE [o.Reports].[Discriminator] IN (N'Officer', N'Gear') AND ([o.Reports].[Has ORDER BY [t].[c], [t].[Nickname], [t].[SquadId]"); } + public override async Task Cast_subquery_to_base_type_using_typed_ToList() + { + await base.Cast_subquery_to_base_type_using_typed_ToList(); + + AssertSql( + @"SELECT [c].[Name] +FROM [Cities] AS [c] +WHERE [c].[Name] = N'Ephyra' +ORDER BY [c].[Name]", + // + @"SELECT [t].[Name], [c.StationedGears].[CityOrBirthName], [c.StationedGears].[FullName], [c.StationedGears].[HasSoulPatch], [c.StationedGears].[LeaderNickname], [c.StationedGears].[LeaderSquadId], [c.StationedGears].[Nickname], [c.StationedGears].[Rank], [c.StationedGears].[SquadId], [c.StationedGears].[AssignedCityName] +FROM [Gears] AS [c.StationedGears] +INNER JOIN ( + SELECT [c0].[Name] + FROM [Cities] AS [c0] + WHERE [c0].[Name] = N'Ephyra' +) AS [t] ON [c.StationedGears].[AssignedCityName] = [t].[Name] +WHERE [c.StationedGears].[Discriminator] IN (N'Officer', N'Gear') +ORDER BY [t].[Name]"); + } + + public override async Task Cast_ordered_subquery_to_base_type_using_typed_ToArray() + { + await base.Cast_ordered_subquery_to_base_type_using_typed_ToArray(); + + AssertSql( + @"SELECT [c].[Name] +FROM [Cities] AS [c] +WHERE [c].[Name] = N'Ephyra' +ORDER BY [c].[Name]", + // + @"SELECT [t].[Name], [c.StationedGears].[CityOrBirthName], [c.StationedGears].[FullName], [c.StationedGears].[HasSoulPatch], [c.StationedGears].[LeaderNickname], [c.StationedGears].[LeaderSquadId], [c.StationedGears].[Nickname], [c.StationedGears].[Rank], [c.StationedGears].[SquadId], [c.StationedGears].[AssignedCityName] +FROM [Gears] AS [c.StationedGears] +INNER JOIN ( + SELECT [c0].[Name] + FROM [Cities] AS [c0] + WHERE [c0].[Name] = N'Ephyra' +) AS [t] ON [c.StationedGears].[AssignedCityName] = [t].[Name] +WHERE [c.StationedGears].[Discriminator] IN (N'Officer', N'Gear') +ORDER BY [t].[Name], [c.StationedGears].[Nickname] DESC"); + } + private void AssertSql(params string[] expected) => Fixture.TestSqlLoggerFactory.AssertBaseline(expected); } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/GearsOfWarQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/GearsOfWarQuerySqlServerTest.cs index a76b2af9221..afdb5ec8fe1 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/GearsOfWarQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/GearsOfWarQuerySqlServerTest.cs @@ -7274,6 +7274,48 @@ FROM [Weapons] AS [w0] ) AS [t0]"); } + public override void Cast_subquery_to_base_type_using_typed_ToList() + { + base.Cast_subquery_to_base_type_using_typed_ToList(); + + AssertSql( + @"SELECT [c].[Name] +FROM [Cities] AS [c] +WHERE [c].[Name] = N'Ephyra' +ORDER BY [c].[Name]", + // + @"SELECT [t].[Name], [c.StationedGears].[CityOrBirthName], [c.StationedGears].[FullName], [c.StationedGears].[HasSoulPatch], [c.StationedGears].[LeaderNickname], [c.StationedGears].[LeaderSquadId], [c.StationedGears].[Nickname], [c.StationedGears].[Rank], [c.StationedGears].[SquadId], [c.StationedGears].[AssignedCityName] +FROM [Gears] AS [c.StationedGears] +INNER JOIN ( + SELECT [c0].[Name] + FROM [Cities] AS [c0] + WHERE [c0].[Name] = N'Ephyra' +) AS [t] ON [c.StationedGears].[AssignedCityName] = [t].[Name] +WHERE [c.StationedGears].[Discriminator] IN (N'Officer', N'Gear') +ORDER BY [t].[Name]"); + } + + public override void Cast_ordered_subquery_to_base_type_using_typed_ToArray() + { + base.Cast_ordered_subquery_to_base_type_using_typed_ToArray(); + + AssertSql( + @"SELECT [c].[Name] +FROM [Cities] AS [c] +WHERE [c].[Name] = N'Ephyra' +ORDER BY [c].[Name]", + // + @"SELECT [t].[Name], [c.StationedGears].[CityOrBirthName], [c.StationedGears].[FullName], [c.StationedGears].[HasSoulPatch], [c.StationedGears].[LeaderNickname], [c.StationedGears].[LeaderSquadId], [c.StationedGears].[Nickname], [c.StationedGears].[Rank], [c.StationedGears].[SquadId], [c.StationedGears].[AssignedCityName] +FROM [Gears] AS [c.StationedGears] +INNER JOIN ( + SELECT [c0].[Name] + FROM [Cities] AS [c0] + WHERE [c0].[Name] = N'Ephyra' +) AS [t] ON [c.StationedGears].[AssignedCityName] = [t].[Name] +WHERE [c.StationedGears].[Discriminator] IN (N'Officer', N'Gear') +ORDER BY [t].[Name], [c.StationedGears].[Nickname] DESC"); + } + private void AssertSql(params string[] expected) => Fixture.TestSqlLoggerFactory.AssertBaseline(expected);