-
Notifications
You must be signed in to change notification settings - Fork 5.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactored the code and implemented if any one of the source stats ar…
…e unknown, return unknown.
- Loading branch information
1 parent
f4fd3f1
commit f5d3275
Showing
4 changed files
with
552 additions
and
442 deletions.
There are no files selected for viewing
255 changes: 255 additions & 0 deletions
255
presto-main/src/main/java/com/facebook/presto/cost/ScalarStatsAnnotationProcessor.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,255 @@ | ||
/* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.facebook.presto.cost; | ||
|
||
import com.facebook.presto.common.type.FixedWidthType; | ||
import com.facebook.presto.common.type.StandardTypes; | ||
import com.facebook.presto.common.type.TypeSignature; | ||
import com.facebook.presto.common.type.VarcharType; | ||
import com.facebook.presto.spi.function.ScalarFunctionStatsUtils; | ||
import com.facebook.presto.spi.function.ScalarPropagateSourceStats; | ||
import com.facebook.presto.spi.function.ScalarStatsHeader; | ||
import com.facebook.presto.spi.function.StatsPropagationBehavior; | ||
import com.facebook.presto.spi.relation.CallExpression; | ||
import com.facebook.presto.spi.relation.RowExpression; | ||
|
||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.stream.Collectors; | ||
|
||
import static com.facebook.presto.spi.function.StatsPropagationBehavior.SUM_ARGUMENTS; | ||
import static com.facebook.presto.spi.function.StatsPropagationBehavior.UNKNOWN; | ||
import static com.facebook.presto.spi.function.StatsPropagationBehavior.USE_SOURCE_STATS; | ||
import static com.facebook.presto.util.MoreMath.maxExcludingNaNs; | ||
import static com.facebook.presto.util.MoreMath.min; | ||
import static com.facebook.presto.util.MoreMath.minExcludingNaNs; | ||
import static com.facebook.presto.util.MoreMath.sumExcludingNaNs; | ||
import static java.lang.Double.NEGATIVE_INFINITY; | ||
import static java.lang.Double.NaN; | ||
import static java.lang.Double.POSITIVE_INFINITY; | ||
import static java.lang.Double.isFinite; | ||
import static java.lang.Double.isNaN; | ||
|
||
public class ScalarStatsAnnotationProcessor | ||
{ | ||
private final double outputRowCount; | ||
private double nullFraction = NaN; | ||
private final ScalarStatsHeader scalarStatsHeader; | ||
private final CallExpression callExpression; | ||
private final List<VariableStatsEstimate> sourceStats; | ||
|
||
public ScalarStatsAnnotationProcessor( | ||
double outputRowCount, | ||
List<VariableStatsEstimate> sourceStats, | ||
ScalarStatsHeader scalarStatsHeader, | ||
CallExpression callExpression) | ||
{ | ||
this.outputRowCount = outputRowCount; | ||
this.sourceStats = sourceStats; | ||
this.scalarStatsHeader = scalarStatsHeader; | ||
this.callExpression = callExpression; | ||
} | ||
|
||
public VariableStatsEstimate process() | ||
{ | ||
nullFraction = firstNonNaN(scalarStatsHeader.getNullFraction(), nullFraction); | ||
double distinctValuesCount = NaN; | ||
if (isFinite(scalarStatsHeader.getDistinctValuesCount())) { | ||
distinctValuesCount = scalarStatsHeader.getDistinctValuesCount(); | ||
if (distinctValuesCount == ScalarFunctionStatsUtils.ROW_COUNT_TIMES_INV_NULL_FRACTION) { | ||
distinctValuesCount = outputRowCount * (1 - nullFraction); | ||
} | ||
else if (distinctValuesCount == ScalarFunctionStatsUtils.ROW_COUNT) { | ||
distinctValuesCount = outputRowCount; | ||
} | ||
} | ||
double averageRowSize = scalarStatsHeader.getAvgRowSize(); | ||
double maxValue = scalarStatsHeader.getMax(); | ||
double minValue = scalarStatsHeader.getMin(); | ||
for (Map.Entry<Integer, ScalarPropagateSourceStats> paramIndexVsStatsMap : scalarStatsHeader.getArgumentStats().entrySet()) { | ||
ScalarPropagateSourceStats scalarPropagateSourceStats = paramIndexVsStatsMap.getValue(); | ||
boolean propagateAllStats = scalarPropagateSourceStats.propagateAllStats(); | ||
nullFraction = min(firstNonNaN(nullFraction, processSingleArgumentStatistic( | ||
sourceStats.stream().map(VariableStatsEstimate::getNullsFraction).collect(Collectors.toList()), | ||
paramIndexVsStatsMap.getKey(), | ||
applyPropagateAllStats(propagateAllStats, scalarPropagateSourceStats.nullFraction()))), 1.0); | ||
distinctValuesCount = firstNonNaN(distinctValuesCount, processDistinctValuesCount( | ||
sourceStats.stream().map(VariableStatsEstimate::getDistinctValuesCount).collect(Collectors.toList()), // an array list ! | ||
paramIndexVsStatsMap.getKey(), | ||
applyPropagateAllStats(propagateAllStats, scalarPropagateSourceStats.distinctValuesCount()))); | ||
averageRowSize = firstNonNaN(averageRowSize, processAvgRowSize( | ||
sourceStats.stream().map(VariableStatsEstimate::getAverageRowSize).collect(Collectors.toList()), | ||
paramIndexVsStatsMap.getKey(), | ||
applyPropagateAllStats(propagateAllStats, scalarPropagateSourceStats.avgRowSize()))); | ||
maxValue = firstNonNaN(maxValue, processSingleArgumentStatistic( | ||
sourceStats.stream().map(VariableStatsEstimate::getHighValue).collect(Collectors.toList()), | ||
paramIndexVsStatsMap.getKey(), | ||
applyPropagateAllStats(propagateAllStats, scalarPropagateSourceStats.maxValue()))); | ||
minValue = firstNonNaN(minValue, processSingleArgumentStatistic( | ||
sourceStats.stream().map(VariableStatsEstimate::getLowValue).collect(Collectors.toList()), | ||
paramIndexVsStatsMap.getKey(), | ||
applyPropagateAllStats(propagateAllStats, scalarPropagateSourceStats.minValue()))); | ||
} | ||
if (!isFinite(maxValue) || !isFinite(minValue)) { | ||
minValue = NEGATIVE_INFINITY; | ||
maxValue = POSITIVE_INFINITY; | ||
} | ||
return VariableStatsEstimate.builder() | ||
.setLowValue(minValue) | ||
.setHighValue(maxValue) | ||
.setNullsFraction(nullFraction) | ||
.setAverageRowSize(minExcludingNaNs(averageRowSize, getReturnTypeWidth(callExpression, UNKNOWN))) | ||
.setDistinctValuesCount(distinctValuesCount).build(); | ||
} | ||
|
||
private double processDistinctValuesCount(List<Double> sourceStats, int sourceStatsArgumentIndex, StatsPropagationBehavior operation) | ||
{ | ||
double distinctValuesCount = | ||
processSingleArgumentStatistic(sourceStats, sourceStatsArgumentIndex, operation); | ||
if (isFinite(outputRowCount)) { | ||
distinctValuesCount = min(distinctValuesCount, outputRowCount * (1 - firstNonNaN(nullFraction, 0.0))); | ||
} | ||
return distinctValuesCount; | ||
} | ||
|
||
private double processAvgRowSize(List<Double> sourceStats, int sourceStatsArgumentIndex, StatsPropagationBehavior operation) | ||
{ | ||
double avgRowSize = | ||
processSingleArgumentStatistic(sourceStats, sourceStatsArgumentIndex, operation); | ||
return minExcludingNaNs(avgRowSize, getReturnTypeWidth(callExpression, operation)); | ||
} | ||
|
||
private double processSingleArgumentStatistic( | ||
List<Double> sourceStats, | ||
int sourceStatsArgumentIndex, | ||
StatsPropagationBehavior operation) | ||
{ | ||
// sourceStatsArgumentIndex is index of the argument on which | ||
// ScalarPropagateSourceStats annotation was applied. | ||
double statValue = NaN; | ||
if (!operation.isSingleArgumentStats()) { | ||
boolean sourceStatUnknown = false; | ||
for (int i = 0; i < sourceStats.size() && !sourceStatUnknown; i++) { | ||
switch (operation) { | ||
case MAX_TYPE_WIDTH_VARCHAR: | ||
statValue = maxExcludingNaNs(statValue, typeWidthVarchar(callExpression, i)); | ||
break; | ||
case USE_MIN_ARGUMENT: | ||
sourceStatUnknown = !isFinite(sourceStats.get(i)); | ||
statValue = minExcludingNaNs(statValue, sourceStats.get(i)); | ||
break; | ||
case USE_MAX_ARGUMENT: | ||
sourceStatUnknown = !isFinite(sourceStats.get(i)); | ||
statValue = maxExcludingNaNs(statValue, sourceStats.get(i)); | ||
break; | ||
case SUM_ARGUMENTS: | ||
sourceStatUnknown = !isFinite(sourceStats.get(i)); | ||
statValue = sumExcludingNaNs(statValue, sourceStats.get(i)); | ||
break; | ||
} | ||
} | ||
if (sourceStatUnknown) { | ||
statValue = NaN; | ||
} | ||
} | ||
else { | ||
switch (operation) { | ||
case USE_SOURCE_STATS: | ||
statValue = sourceStats.get(sourceStatsArgumentIndex); | ||
break; | ||
case ROW_COUNT: | ||
statValue = outputRowCount; | ||
break; | ||
case NON_NULL_ROW_COUNT: | ||
statValue = outputRowCount * (1 - nullFraction); | ||
break; | ||
case USE_TYPE_WIDTH_VARCHAR: | ||
statValue = typeWidthVarchar(callExpression, sourceStatsArgumentIndex); | ||
break; | ||
} | ||
} | ||
return statValue; | ||
} | ||
|
||
private double typeWidthVarchar(CallExpression call, int argumentIndex) | ||
{ | ||
TypeSignature typeSignature = call.getArguments().get(argumentIndex).getType().getTypeSignature(); | ||
if (typeSignature.getTypeSignatureBase().hasStandardType() && typeSignature.getTypeSignatureBase().getStandardTypeBase().equals( | ||
StandardTypes.VARCHAR)) { | ||
if (typeSignature.getParameters().size() == 1) { // Varchar type should have 1 parameter i.e. size.; | ||
Long longLiteral = typeSignature.getParameters().get(0).getLongLiteral(); | ||
if (longLiteral > 0 && longLiteral != VarcharType.UNBOUNDED_LENGTH) { | ||
return longLiteral; | ||
} | ||
} | ||
} | ||
return NaN; | ||
} | ||
|
||
private double getReturnTypeWidth(CallExpression call, StatsPropagationBehavior operation) | ||
{ | ||
if (call.getType() instanceof FixedWidthType) { | ||
return ((FixedWidthType) call.getType()).getFixedSize(); | ||
} | ||
if (call.getType() instanceof VarcharType) { | ||
VarcharType returnType = (VarcharType) call.getType(); | ||
if (!returnType.isUnbounded()) { | ||
return returnType.getLengthSafe(); | ||
} | ||
if (operation == SUM_ARGUMENTS) { | ||
// since return type is a varchar and length is unknown, | ||
// getting it by doing a SUM each argument's varchar type bounds. | ||
double sum = 0; | ||
for (RowExpression r : call.getArguments()) { | ||
if (r instanceof CallExpression) { // argument is another function call | ||
sum += getReturnTypeWidth((CallExpression) r, UNKNOWN); | ||
} | ||
if (r.getType() instanceof VarcharType) { | ||
VarcharType argType = (VarcharType) r.getType(); | ||
if (!argType.isUnbounded()) { | ||
sum += argType.getLengthSafe(); | ||
} | ||
} | ||
} | ||
if (sum > 0) { | ||
return sum; | ||
} | ||
} | ||
} | ||
return NaN; | ||
} | ||
|
||
private double firstNonNaN(double... values) | ||
{ | ||
double target = NaN; | ||
for (double v : values) { | ||
if (isFinite(v)) { | ||
if (isNaN(target)) { | ||
target = v; | ||
} | ||
} | ||
} | ||
return target; | ||
} | ||
|
||
private StatsPropagationBehavior applyPropagateAllStats( | ||
boolean propagateAllStats, StatsPropagationBehavior operation) | ||
{ | ||
if (operation == UNKNOWN && propagateAllStats) { | ||
return USE_SOURCE_STATS; | ||
} | ||
return operation; | ||
} | ||
} |
Oops, something went wrong.