Skip to content

Commit

Permalink
Fix to #11687 - Query with projection cast on correlated collections …
Browse files Browse the repository at this point in the history
…throws on 2.1.0-preview2

Problem was that when we apply correlated collection optimization we were not accounting for scenario where the result collection element type was different than the input. This can be achieved by using ToList<T> or ToArray<T> etc.

Fix is to flow the collection result type as well as the input type to the CorrelateSubquery method, and construct the resulting collection based on the output type, if they are different.
  • Loading branch information
maumar committed Apr 20, 2018
1 parent 542de55 commit 60e9a43
Show file tree
Hide file tree
Showing 7 changed files with 201 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1653,5 +1653,43 @@ public virtual async Task Cast_to_derived_type_after_OfType_works()
await AssertQuery<Gear>(
gs => gs.OfType<Officer>().Cast<Officer>());
}

[ConditionalFact]
public virtual async Task Cast_subquery_to_base_type_using_typed_ToList()
{
await AssertQuery<City>(
cs => cs.Where(c => c.Name == "Ephyra").Select(c => c.StationedGears.Select(g => new Officer
{
CityOrBirthName = g.CityOrBirthName,
FullName = g.FullName,
HasSoulPatch = g.HasSoulPatch,
LeaderNickname = g.LeaderNickname,
LeaderSquadId = g.LeaderSquadId,
Nickname = g.Nickname,
Rank = g.Rank,
SquadId = g.SquadId
}).ToList<Gear>()),
assertOrder: true,
elementAsserter: CollectionAsserter<Gear>(e => e.Nickname, (e, a) => Assert.Equal(e.Nickname, a.Nickname)));
}

[ConditionalFact]
public virtual async Task Cast_ordered_subquery_to_base_type_using_typed_ToArray()
{
await AssertQuery<City>(
cs => cs.Where(c => c.Name == "Ephyra").Select(c => c.StationedGears.OrderByDescending(g => g.Nickname).Select(g => new Officer
{
CityOrBirthName = g.CityOrBirthName,
FullName = g.FullName,
HasSoulPatch = g.HasSoulPatch,
LeaderNickname = g.LeaderNickname,
LeaderSquadId = g.LeaderSquadId,
Nickname = g.Nickname,
Rank = g.Rank,
SquadId = g.SquadId
}).ToArray<Gear>()),
assertOrder: true,
elementAsserter: CollectionAsserter<Gear>(e => e.Nickname, (e, a) => Assert.Equal(e.Nickname, a.Nickname)));
}
}
}
38 changes: 38 additions & 0 deletions src/EFCore.Specification.Tests/Query/GearsOfWarQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4957,6 +4957,44 @@ public virtual void Select_subquery_distinct_singleordefault_boolean_empty_with_
assertOrder: true);
}

[ConditionalFact]
public virtual void Cast_subquery_to_base_type_using_typed_ToList()
{
AssertQuery<City>(
cs => cs.Where(c => c.Name == "Ephyra").Select(c => c.StationedGears.Select(g => new Officer
{
CityOrBirthName = g.CityOrBirthName,
FullName = g.FullName,
HasSoulPatch = g.HasSoulPatch,
LeaderNickname = g.LeaderNickname,
LeaderSquadId = g.LeaderSquadId,
Nickname = g.Nickname,
Rank = g.Rank,
SquadId = g.SquadId
}).ToList<Gear>()),
assertOrder: true,
elementAsserter: CollectionAsserter<Gear>(e => e.Nickname, (e, a) => Assert.Equal(e.Nickname, a.Nickname)));
}

[ConditionalFact]
public virtual void Cast_ordered_subquery_to_base_type_using_typed_ToArray()
{
AssertQuery<City>(
cs => cs.Where(c => c.Name == "Ephyra").Select(c => c.StationedGears.OrderByDescending(g => g.Nickname).Select(g => new Officer
{
CityOrBirthName = g.CityOrBirthName,
FullName = g.FullName,
HasSoulPatch = g.HasSoulPatch,
LeaderNickname = g.LeaderNickname,
LeaderSquadId = g.LeaderSquadId,
Nickname = g.Nickname,
Rank = g.Rank,
SquadId = g.SquadId
}).ToArray<Gear>()),
assertOrder: true,
elementAsserter: CollectionAsserter<Gear>(e => e.Nickname, (e, a) => Assert.Equal(e.Nickname, a.Nickname)));
}

// Remember to add any new tests to Async version of this test class

protected GearsOfWarContext CreateContext() => Fixture.CreateContext();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,23 +91,23 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
&& innerMethodCallExpression.Method.MethodIsClosedFormOf(CollectionNavigationSubqueryInjector.MaterializeCollectionNavigationMethodInfo)
&& innerMethodCallExpression.Arguments[1] is SubQueryExpression subQueryExpression1)
{
return TryRewrite(subQueryExpression1, /*forceToListResult*/ true, out var result)
return TryRewrite(subQueryExpression1, /*forceToListResult*/ true, methodCallExpression.Type.GetSequenceType(), out var result)
? result
: methodCallExpression;
}

if (methodCallExpression.Method.MethodIsClosedFormOf(_toListMethodInfo)
&& methodCallExpression.Arguments[0] is SubQueryExpression subQueryExpression2)
{
return TryRewrite(subQueryExpression2, /*forceToListResult*/ true, out var result)
return TryRewrite(subQueryExpression2, /*forceToListResult*/ true, methodCallExpression.Type.GetSequenceType(), out var result)
? result
: methodCallExpression;
}

if (methodCallExpression.Method.MethodIsClosedFormOf(CollectionNavigationSubqueryInjector.MaterializeCollectionNavigationMethodInfo)
&& methodCallExpression.Arguments[1] is SubQueryExpression subQueryExpression3)
{
return TryRewrite(subQueryExpression3, /*forceToListResult*/ false, out var result)
return TryRewrite(subQueryExpression3, /*forceToListResult*/ false, /* listResultElementType */ null, out var result)
? result
: methodCallExpression;
}
Expand All @@ -120,11 +120,11 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
/// directly from your code. This API may change or be removed in future releases.
/// </summary>
protected override Expression VisitSubQuery(SubQueryExpression subQueryExpression)
=> TryRewrite(subQueryExpression, /*forceToListResult*/ false, out var result)
=> TryRewrite(subQueryExpression, /*forceToListResult*/ false, /* listResultElementType */ null, out var result)
? result
: base.VisitSubQuery(subQueryExpression);

private bool TryRewrite(SubQueryExpression subQueryExpression, bool forceToListResult, out Expression result)
private bool TryRewrite(SubQueryExpression subQueryExpression, bool forceToListResult, Type listResultElementType, out Expression result)
{
if (_queryCompilationContext.TryGetCorrelatedSubqueryMetadata(subQueryExpression.QueryModel.MainFromClause, out var correlatedSubqueryMetadata))
{
Expand All @@ -135,7 +135,8 @@ private bool TryRewrite(SubQueryExpression subQueryExpression, bool forceToListR
correlatedSubqueryMetadata.CollectionNavigation,
correlatedSubqueryMetadata.TrackingQuery,
parentQsre,
forceToListResult);
forceToListResult,
listResultElementType);

return true;
}
Expand All @@ -151,7 +152,8 @@ private Expression Rewrite(
INavigation navigation,
bool trackingQuery,
QuerySourceReferenceExpression originQuerySource,
bool forceListResult)
bool forceListResult,
Type listResultElementType)
{
var querySourceReferenceFindingExpressionTreeVisitor
= new QuerySourceReferenceFindingExpressionVisitor();
Expand Down Expand Up @@ -294,8 +296,10 @@ var joinQuerySourceReferenceExpression
var newOriginKeyElements = ((NewArrayExpression)(((NewExpression)newOriginKey).Arguments[0])).Expressions;
var remappedOriginKeyElements = RemapOriginKeyExpressions(newOriginKeyElements, joinQuerySourceReferenceExpression, subQueryProjection);

var collectionQueryModelSelectorType = collectionQueryModel.SelectClause.Selector.Type;

var tupleCtor = typeof(Tuple<,,>).MakeGenericType(
collectionQueryModel.SelectClause.Selector.Type,
collectionQueryModelSelectorType,
typeof(MaterializedAnonymousObject),
typeof(MaterializedAnonymousObject)).GetConstructors().FirstOrDefault();

Expand All @@ -307,14 +311,16 @@ var joinQuerySourceReferenceExpression

Expression resultCollectionFactoryExpressionBody;
if (forceListResult
|| navigation.ForeignKey.DeclaringEntityType.ClrType != collectionQueryModel.SelectClause.Selector.Type)
|| navigation.ForeignKey.DeclaringEntityType.ClrType != collectionQueryModelSelectorType)
{
var resultCollectionType = typeof(List<>).MakeGenericType(collectionQueryModel.SelectClause.Selector.Type);
listResultElementType = listResultElementType ?? collectionQueryModelSelectorType;
var resultCollectionType = typeof(List<>).MakeGenericType(listResultElementType);
var resultCollectionCtor = resultCollectionType.GetTypeInfo().GetDeclaredConstructor(Array.Empty<Type>());

correlateSubqueryMethod = correlateSubqueryMethod.MakeGenericMethod(
collectionQueryModel.SelectClause.Selector.Type,
typeof(List<>).MakeGenericType(collectionQueryModel.SelectClause.Selector.Type));
collectionQueryModelSelectorType,
listResultElementType,
typeof(List<>).MakeGenericType(listResultElementType));

resultCollectionFactoryExpressionBody = Expression.New(resultCollectionCtor);

Expand All @@ -323,8 +329,9 @@ var joinQuerySourceReferenceExpression
else
{
correlateSubqueryMethod = correlateSubqueryMethod.MakeGenericMethod(
collectionQueryModel.SelectClause.Selector.Type,
navigation.GetCollectionAccessor().CollectionType);
collectionQueryModelSelectorType,
collectionQueryModelSelectorType,
navigation.GetCollectionAccessor().CollectionType);

resultCollectionFactoryExpressionBody
= Expression.Convert(
Expand Down Expand Up @@ -352,20 +359,20 @@ var joinQuerySourceReferenceExpression
remappedOriginKeyElements))
});

var collectionModelSelectorType = collectionQueryModel.SelectClause.Selector.Type;
collectionQueryModelSelectorType = collectionQueryModel.SelectClause.Selector.Type;

// Enumerable or OrderedEnumerable
collectionQueryModel.ResultTypeOverride = collectionQueryModel.BodyClauses.OfType<OrderByClause>().Any()
? typeof(IOrderedEnumerable<>).MakeGenericType(collectionModelSelectorType)
: typeof(IEnumerable<>).MakeGenericType(collectionModelSelectorType);
? typeof(IOrderedEnumerable<>).MakeGenericType(collectionQueryModelSelectorType)
: typeof(IEnumerable<>).MakeGenericType(collectionQueryModelSelectorType);

var lambda = (Expression)Expression.Lambda(new SubQueryExpression(collectionQueryModel));
if (_queryCompilationContext.IsAsyncQuery)
{
lambda = Expression.Convert(
lambda,
typeof(Func<>).MakeGenericType(
typeof(IAsyncEnumerable<>).MakeGenericType(collectionModelSelectorType)));
typeof(IAsyncEnumerable<>).MakeGenericType(collectionQueryModelSelectorType)));
}

// since we cloned QM, we need to check if it's query sources require materialization (e.g. TypeIs operation for InMemory)
Expand Down
12 changes: 8 additions & 4 deletions src/EFCore/Query/Internal/IQueryBuffer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,27 +90,31 @@ Task IncludeCollectionAsync<TEntity, TRelated, TElement>(
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
/// directly from your code. This API may change or be removed in future releases.
/// </summary>
TCollection CorrelateSubquery<TInner, TCollection>(
TCollection CorrelateSubquery<TInner, TOut, TCollection>(
int correlatedCollectionId,
[NotNull] INavigation navigation,
[NotNull] Func<INavigation, TCollection> resultCollectionFactory,
in MaterializedAnonymousObject outerKey,
bool tracking,
[NotNull] Func<IEnumerable<Tuple<TInner, MaterializedAnonymousObject, MaterializedAnonymousObject>>> correlatedCollectionFactory,
[NotNull] Func<MaterializedAnonymousObject, MaterializedAnonymousObject, bool> correlationPredicate) where TCollection : ICollection<TInner>;
[NotNull] Func<MaterializedAnonymousObject, MaterializedAnonymousObject, bool> correlationPredicate)
where TCollection : ICollection<TOut>
where TInner : TOut;

/// <summary>
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
/// directly from your code. This API may change or be removed in future releases.
/// </summary>
Task<TCollection> CorrelateSubqueryAsync<TInner, TCollection>(
Task<TCollection> CorrelateSubqueryAsync<TInner, TOut, TCollection>(
int correlatedCollectionId,
[NotNull] INavigation navigation,
[NotNull] Func<INavigation, TCollection> resultCollectionFactory,
MaterializedAnonymousObject outerKey,
bool tracking,
[NotNull] Func<IAsyncEnumerable<Tuple<TInner, MaterializedAnonymousObject, MaterializedAnonymousObject>>> correlatedCollectionFactory,
[NotNull] Func<MaterializedAnonymousObject, MaterializedAnonymousObject, bool> correlationPredicate,
CancellationToken cancellationToken) where TCollection : ICollection<TInner>;
CancellationToken cancellationToken)
where TCollection : ICollection<TOut>
where TInner : TOut;
}
}
12 changes: 8 additions & 4 deletions src/EFCore/Query/Internal/QueryBuffer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -498,14 +498,16 @@ private IWeakReferenceIdentityMap GetOrCreateIdentityMap(IKey key)
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
/// directly from your code. This API may change or be removed in future releases.
/// </summary>
public virtual TCollection CorrelateSubquery<TInner, TCollection>(
public virtual TCollection CorrelateSubquery<TInner, TOut, TCollection>(
int correlatedCollectionId,
INavigation navigation,
Func<INavigation, TCollection> resultCollectionFactory,
in MaterializedAnonymousObject outerKey,
bool tracking,
Func<IEnumerable<Tuple<TInner, MaterializedAnonymousObject, MaterializedAnonymousObject>>> correlatedCollectionFactory,
Func<MaterializedAnonymousObject, MaterializedAnonymousObject, bool> correlationPredicate) where TCollection : ICollection<TInner>
Func<MaterializedAnonymousObject, MaterializedAnonymousObject, bool> correlationPredicate)
where TCollection : ICollection<TOut>
where TInner: TOut
{
IDisposable untypedEnumerator = null;
IEnumerator<Tuple<TInner, MaterializedAnonymousObject, MaterializedAnonymousObject>> enumerator = null;
Expand Down Expand Up @@ -594,15 +596,17 @@ public virtual TCollection CorrelateSubquery<TInner, TCollection>(
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
/// directly from your code. This API may change or be removed in future releases.
/// </summary>
public virtual async Task<TCollection> CorrelateSubqueryAsync<TInner, TCollection>(
public virtual async Task<TCollection> CorrelateSubqueryAsync<TInner, TOut, TCollection>(
int correlatedCollectionId,
INavigation navigation,
Func<INavigation, TCollection> resultCollectionFactory,
MaterializedAnonymousObject outerKey,
bool tracking,
Func<IAsyncEnumerable<Tuple<TInner, MaterializedAnonymousObject, MaterializedAnonymousObject>>> correlatedCollectionFactory,
Func<MaterializedAnonymousObject, MaterializedAnonymousObject, bool> correlationPredicate,
CancellationToken cancellationToken) where TCollection : ICollection<TInner>
CancellationToken cancellationToken)
where TCollection : ICollection<TOut>
where TInner: TOut
{
IDisposable untypedEnumerator = null;
IAsyncEnumerator<Tuple<TInner, MaterializedAnonymousObject, MaterializedAnonymousObject>> enumerator = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2357,6 +2357,48 @@ WHERE [o.Reports].[Discriminator] IN (N'Officer', N'Gear') AND ([o.Reports].[Has
ORDER BY [t].[c], [t].[Nickname], [t].[SquadId]");
}

public override async Task Cast_subquery_to_base_type_using_typed_ToList()
{
await base.Cast_subquery_to_base_type_using_typed_ToList();

AssertSql(
@"SELECT [c].[Name]
FROM [Cities] AS [c]
WHERE [c].[Name] = N'Ephyra'
ORDER BY [c].[Name]",
//
@"SELECT [t].[Name], [c.StationedGears].[CityOrBirthName], [c.StationedGears].[FullName], [c.StationedGears].[HasSoulPatch], [c.StationedGears].[LeaderNickname], [c.StationedGears].[LeaderSquadId], [c.StationedGears].[Nickname], [c.StationedGears].[Rank], [c.StationedGears].[SquadId], [c.StationedGears].[AssignedCityName]
FROM [Gears] AS [c.StationedGears]
INNER JOIN (
SELECT [c0].[Name]
FROM [Cities] AS [c0]
WHERE [c0].[Name] = N'Ephyra'
) AS [t] ON [c.StationedGears].[AssignedCityName] = [t].[Name]
WHERE [c.StationedGears].[Discriminator] IN (N'Officer', N'Gear')
ORDER BY [t].[Name]");
}

public override async Task Cast_ordered_subquery_to_base_type_using_typed_ToArray()
{
await base.Cast_ordered_subquery_to_base_type_using_typed_ToArray();

AssertSql(
@"SELECT [c].[Name]
FROM [Cities] AS [c]
WHERE [c].[Name] = N'Ephyra'
ORDER BY [c].[Name]",
//
@"SELECT [t].[Name], [c.StationedGears].[CityOrBirthName], [c.StationedGears].[FullName], [c.StationedGears].[HasSoulPatch], [c.StationedGears].[LeaderNickname], [c.StationedGears].[LeaderSquadId], [c.StationedGears].[Nickname], [c.StationedGears].[Rank], [c.StationedGears].[SquadId], [c.StationedGears].[AssignedCityName]
FROM [Gears] AS [c.StationedGears]
INNER JOIN (
SELECT [c0].[Name]
FROM [Cities] AS [c0]
WHERE [c0].[Name] = N'Ephyra'
) AS [t] ON [c.StationedGears].[AssignedCityName] = [t].[Name]
WHERE [c.StationedGears].[Discriminator] IN (N'Officer', N'Gear')
ORDER BY [t].[Name], [c.StationedGears].[Nickname] DESC");
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);
}
Expand Down
Loading

0 comments on commit 60e9a43

Please sign in to comment.