Skip to content

Commit

Permalink
Updated following review.
Browse files Browse the repository at this point in the history
  • Loading branch information
ajcvickers committed Nov 26, 2024
1 parent 43cedd8 commit 12ee469
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 112 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -443,46 +443,7 @@ private ShapedQueryExpression CreateShapedQueryExpression(SelectExpression selec
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override ShapedQueryExpression? TranslateAverage(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
{
var updatedSource = TranslateAggregateCommon(source, selector, resultType, out var selectExpression);
if (updatedSource == null)
{
return null;
}

var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember());
projection = _sqlExpressionFactory.Function("AVG", new[] { projection }, resultType, _typeMappingSource.FindMapping(resultType));

return AggregateResultShaper(updatedSource, projection, resultType);
}

private ShapedQueryExpression? TranslateAggregateCommon(
ShapedQueryExpression source,
LambdaExpression? selector,
Type resultType,
out SelectExpression selectExpression)
{
selectExpression = (SelectExpression)source.QueryExpression;
if (selectExpression.IsDistinct
|| selectExpression.Limit != null
|| selectExpression.Offset != null)
{
return null;
}

if (selector != null)
{
source = TranslateSelect(source, selector);
}

if (resultType.IsNullableType())
{
// For nullable types, we want to return null from Max, Min, and Average, rather than throwing. See Issue #35094.
source = source.UpdateResultCardinality(ResultCardinality.SingleOrDefault);
}

return source;
}
=> TranslateAggregate(source, selector, resultType, "AVG");

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand Down Expand Up @@ -862,19 +823,7 @@ protected override ShapedQueryExpression TranslateCast(ShapedQueryExpression sou
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override ShapedQueryExpression? TranslateMax(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
{
var updatedSource = TranslateAggregateCommon(source, selector, resultType, out var selectExpression);
if (updatedSource == null)
{
return null;
}

var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember());

projection = _sqlExpressionFactory.Function("MAX", new[] { projection }, resultType, projection.TypeMapping);

return AggregateResultShaper(updatedSource, projection, resultType);
}
=> TranslateAggregate(source, selector, resultType, "MAX");

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand All @@ -883,19 +832,7 @@ protected override ShapedQueryExpression TranslateCast(ShapedQueryExpression sou
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override ShapedQueryExpression? TranslateMin(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
{
var updatedSource = TranslateAggregateCommon(source, selector, resultType, out var selectExpression);
if (updatedSource == null)
{
return null;
}

var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember());

projection = _sqlExpressionFactory.Function("MIN", new[] { projection }, resultType, projection.TypeMapping);

return AggregateResultShaper(updatedSource, projection, resultType);
}
=> TranslateAggregate(source, selector, resultType, "MIN");

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand Down Expand Up @@ -1522,6 +1459,33 @@ protected override ShapedQueryExpression TranslateSelect(ShapedQueryExpression s

#endregion Queryable collection support

private ShapedQueryExpression? TranslateAggregate(ShapedQueryExpression source, LambdaExpression? selector, Type resultType, string functionName)
{
var selectExpression = (SelectExpression)source.QueryExpression;
if (selectExpression.IsDistinct
|| selectExpression.Limit != null
|| selectExpression.Offset != null)
{
return null;
}

if (selector != null)
{
source = TranslateSelect(source, selector);
}

if (!_subquery && resultType.IsNullableType())
{
// For nullable types, we want to return null from Max, Min, and Average, rather than throwing. See Issue #35094.
source = source.UpdateResultCardinality(ResultCardinality.SingleOrDefault);
}

var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember());
projection = _sqlExpressionFactory.Function(functionName, [projection], resultType, _typeMappingSource.FindMapping(resultType));

return AggregateResultShaper(source, projection, resultType);
}

private bool TryApplyPredicate(ShapedQueryExpression source, LambdaExpression predicate)
{
var select = (SelectExpression)source.QueryExpression;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ private static ShapedQueryExpression CreateShapedQueryExpression(IEntityType ent
ShapedQueryExpression source,
LambdaExpression? selector,
Type resultType)
=> TranslateAggregateWithSelector(source, selector, QueryableMethods.GetAverageWithoutSelector, throwWhenEmpty: true, resultType);
=> TranslateAggregateWithSelector(source, selector, QueryableMethods.GetAverageWithoutSelector, resultType);

/// <inheritdoc />
protected override ShapedQueryExpression TranslateCast(ShapedQueryExpression source, Type resultType)
Expand Down Expand Up @@ -971,7 +971,7 @@ private SqlExpression CreateJoinPredicate(Expression outerKey, Expression innerK
}

return TranslateAggregateWithSelector(
source, selector, t => QueryableMethods.MaxWithoutSelector.MakeGenericMethod(t), throwWhenEmpty: true, resultType);
source, selector, t => QueryableMethods.MaxWithoutSelector.MakeGenericMethod(t), resultType);
}

/// <inheritdoc />
Expand All @@ -990,7 +990,7 @@ private SqlExpression CreateJoinPredicate(Expression outerKey, Expression innerK
}

return TranslateAggregateWithSelector(
source, selector, t => QueryableMethods.MinWithoutSelector.MakeGenericMethod(t), throwWhenEmpty: true, resultType);
source, selector, t => QueryableMethods.MinWithoutSelector.MakeGenericMethod(t), resultType);
}

/// <inheritdoc />
Expand Down Expand Up @@ -1241,7 +1241,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp

/// <inheritdoc />
protected override ShapedQueryExpression? TranslateSum(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
=> TranslateAggregateWithSelector(source, selector, QueryableMethods.GetSumWithoutSelector, throwWhenEmpty: false, resultType);
=> TranslateAggregateWithSelector(source, selector, QueryableMethods.GetSumWithoutSelector, resultType);

/// <inheritdoc />
protected override ShapedQueryExpression? TranslateTake(ShapedQueryExpression source, Expression count)
Expand Down Expand Up @@ -1966,7 +1966,6 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape
ShapedQueryExpression source,
LambdaExpression? selectorLambda,
Func<Type, MethodInfo> methodGenerator,
bool throwWhenEmpty,
Type resultType)
{
var selectExpression = (SelectExpression)source.QueryExpression;
Expand Down Expand Up @@ -2012,48 +2011,13 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape
new Dictionary<ProjectionMember, Expression> { { new ProjectionMember(), translation } });

selectExpression.ClearOrdering();
Expression shaper;

if (throwWhenEmpty)
{
// Avg/Max/Min case.
// We always read nullable value
// If resultType is nullable then we always return null. Only non-null result shows throwing behavior.
// otherwise, if projection.Type is nullable then server result is passed through DefaultIfEmpty, hence we return default
// otherwise, server would return null only if it is empty, and we throw
var nullableResultType = resultType.MakeNullable();
shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), nullableResultType);
var resultVariable = Expression.Variable(nullableResultType, "result");
var returnValueForNull = resultType.IsNullableType()
? (Expression)Expression.Default(resultType)
: translation.Type.IsNullableType()
? Expression.Default(resultType)
: Expression.Throw(
Expression.New(
typeof(InvalidOperationException).GetConstructors()
.Single(ci => ci.GetParameters().Length == 1),
Expression.Constant(CoreStrings.SequenceContainsNoElements)),
resultType);

shaper = Expression.Block(
new[] { resultVariable },
Expression.Assign(resultVariable, shaper),
Expression.Condition(
Expression.Equal(resultVariable, Expression.Default(nullableResultType)),
returnValueForNull,
resultType != resultVariable.Type
? Expression.Convert(resultVariable, resultType)
: resultVariable));
}
else
{
// Sum case. Projection is always non-null. We read nullable value.
shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), translation.Type.MakeNullable());

if (resultType != shaper.Type)
{
shaper = Expression.Convert(shaper, resultType);
}
// Sum case. Projection is always non-null. We read nullable value.
Expression shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), translation.Type.MakeNullable());

if (resultType != shaper.Type)
{
shaper = Expression.Convert(shaper, resultType);
}

return source.UpdateShaperExpression(shaper);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.ComponentModel.DataAnnotations.Schema;

namespace Microsoft.EntityFrameworkCore.Query;

#nullable disable
Expand Down Expand Up @@ -50,6 +52,113 @@ public enum MemberType

#endregion 34911

#region 35094

[ConditionalFact]
public virtual async Task Min_over_value_type_containing_nulls()
{
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
Assert.Null(await context.Set<Context35094.Product>().MinAsync(p => p.NullableVal));
}

[ConditionalFact]
public virtual async Task Min_over_value_type_containing_all_nulls()
{
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
Assert.Null(await context.Set<Context35094.Product>().Where(e => e.NullableVal == null).MinAsync(p => p.NullableVal));
}

[ConditionalFact]
public virtual async Task Min_over_reference_type_containing_nulls()
{
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
Assert.Null(await context.Set<Context35094.Product>().MinAsync(p => p.NullableRef));
}

[ConditionalFact]
public virtual async Task Min_over_reference_type_containing_all_nulls()
{
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
Assert.Null(await context.Set<Context35094.Product>().Where(e => e.NullableRef == null).MinAsync(p => p.NullableRef));
}

[ConditionalFact]
public virtual async Task Min_over_reference_type_containing_no_data()
{
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
Assert.Null(await context.Set<Context35094.Product>().Where(e => e.Id < 0).MinAsync(p => p.NullableRef));
}

[ConditionalFact]
public virtual async Task Max_over_value_type_containing_nulls()
{
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
Assert.Equal(3.14, await context.Set<Context35094.Product>().MaxAsync(p => p.NullableVal));
}

[ConditionalFact]
public virtual async Task Max_over_value_type_containing_all_nulls()
{
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
Assert.Null(await context.Set<Context35094.Product>().Where(e => e.NullableVal == null).MaxAsync(p => p.NullableVal));
}

[ConditionalFact]
public virtual async Task Max_over_reference_type_containing_nulls()
{
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
Assert.Equal("Value", await context.Set<Context35094.Product>().MaxAsync(p => p.NullableRef));
}

[ConditionalFact]
public virtual async Task Max_over_reference_type_containing_all_nulls()
{
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
Assert.Null(await context.Set<Context35094.Product>().Where(e => e.NullableRef == null).MaxAsync(p => p.NullableRef));
}

[ConditionalFact]
public virtual async Task Max_over_reference_type_containing_no_data()
{
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
Assert.Null(await context.Set<Context35094.Product>().Where(e => e.Id < 0).MaxAsync(p => p.NullableRef));
}

[ConditionalFact]
public virtual async Task Average_over_value_type_containing_nulls()
{
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
Assert.Null(await context.Set<Context35094.Product>().AverageAsync(p => p.NullableVal));
}

[ConditionalFact]
public virtual async Task Average_over_value_type_containing_all_nulls()
{
await using var context = (await InitializeAsync<Context35094>()).CreateContext();
Assert.Null(await context.Set<Context35094.Product>().Where(e => e.NullableVal == null).AverageAsync(p => p.NullableVal));
}

protected class Context35094(DbContextOptions options) : DbContext(options)
{
public DbSet<Product> Products { get; set; }

protected override void OnModelCreating(ModelBuilder modelBuilder)
=> modelBuilder.Entity<Product>().HasData(
new Product { Id = 1, NullableRef = "Value", NullableVal = 3.14 },
new Product { Id = 2, NullableVal = 3.14 },
new Product { Id = 3, NullableRef = "Value" });

public class Product
{
[DatabaseGenerated(DatabaseGeneratedOption.None)]
public int Id { get; set; }
public double? NullableVal { get; set; }
public string NullableRef { get; set; }
}
}

#endregion 35094

protected override string StoreName
=> "AdHocMiscellaneousQueryTests";

Expand Down

0 comments on commit 12ee469

Please sign in to comment.