Skip to content

Commit

Permalink
Refactored the code and implemented if any one of the source stats ar…
Browse files Browse the repository at this point in the history
…e unknown, return unknown.
  • Loading branch information
ScrapCodes committed Sep 23, 2024
1 parent f4fd3f1 commit f5d3275
Show file tree
Hide file tree
Showing 4 changed files with 552 additions and 442 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.facebook.presto.cost;

import com.facebook.presto.common.type.FixedWidthType;
import com.facebook.presto.common.type.StandardTypes;
import com.facebook.presto.common.type.TypeSignature;
import com.facebook.presto.common.type.VarcharType;
import com.facebook.presto.spi.function.ScalarFunctionStatsUtils;
import com.facebook.presto.spi.function.ScalarPropagateSourceStats;
import com.facebook.presto.spi.function.ScalarStatsHeader;
import com.facebook.presto.spi.function.StatsPropagationBehavior;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;

import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static com.facebook.presto.spi.function.StatsPropagationBehavior.SUM_ARGUMENTS;
import static com.facebook.presto.spi.function.StatsPropagationBehavior.UNKNOWN;
import static com.facebook.presto.spi.function.StatsPropagationBehavior.USE_SOURCE_STATS;
import static com.facebook.presto.util.MoreMath.maxExcludingNaNs;
import static com.facebook.presto.util.MoreMath.min;
import static com.facebook.presto.util.MoreMath.minExcludingNaNs;
import static com.facebook.presto.util.MoreMath.sumExcludingNaNs;
import static java.lang.Double.NEGATIVE_INFINITY;
import static java.lang.Double.NaN;
import static java.lang.Double.POSITIVE_INFINITY;
import static java.lang.Double.isFinite;
import static java.lang.Double.isNaN;

/**
 * Derives a {@link VariableStatsEstimate} for the result of a scalar function call from the
 * function's statistics annotations: the constants carried by the {@link ScalarStatsHeader}
 * and the per-argument {@link ScalarPropagateSourceStats} propagation rules.
 *
 * <p>Instances are not thread-safe: {@link #process()} updates the {@code nullFraction} field
 * while it runs and the helper methods read it.
 */
public class ScalarStatsAnnotationProcessor
{
    private final double outputRowCount;
    // Working estimate of the result's null fraction. Reset at the start of every process()
    // call and read by the helpers that derive "non-null row count" style values.
    private double nullFraction = NaN;
    private final ScalarStatsHeader scalarStatsHeader;
    private final CallExpression callExpression;
    private final List<VariableStatsEstimate> sourceStats;

    /**
     * @param outputRowCount estimated number of rows produced; may be non-finite when unknown
     * @param sourceStats statistics of the call's arguments; must not be null. Entries are looked up
     *                    by the argument index of each annotation, so the list is presumably
     *                    positionally aligned with the call's arguments (confirm with caller)
     * @param scalarStatsHeader statistics annotations declared on the function
     * @param callExpression the call whose result statistics are being estimated
     */
    public ScalarStatsAnnotationProcessor(
            double outputRowCount,
            List<VariableStatsEstimate> sourceStats,
            ScalarStatsHeader scalarStatsHeader,
            CallExpression callExpression)
    {
        this.outputRowCount = outputRowCount;
        this.sourceStats = sourceStats;
        this.scalarStatsHeader = scalarStatsHeader;
        this.callExpression = callExpression;
    }

    /**
     * Computes the statistics estimate for the call's result.
     *
     * <p>Finite values from the header take precedence. Every statistic that is still unknown is then
     * filled from the per-argument propagation rules, visited in the iteration order of
     * {@code scalarStatsHeader.getArgumentStats()}; the first finite value found is kept.
     * If either bound of the value range is not finite, the range is reported as
     * {@code (-Infinity, +Infinity)}.
     *
     * @return the estimate built from the header constants and the propagated argument statistics
     */
    public VariableStatsEstimate process()
    {
        // Seed from the header only. Seeding from the field's previous value would make a
        // second call on the same instance inherit the result of the first one.
        nullFraction = firstNonNaN(scalarStatsHeader.getNullFraction());

        double distinctValuesCount = NaN;
        double headerDistinctValuesCount = scalarStatsHeader.getDistinctValuesCount();
        if (isFinite(headerDistinctValuesCount)) {
            distinctValuesCount = headerDistinctValuesCount;
            // The header may carry a sentinel constant instead of a literal count.
            if (distinctValuesCount == ScalarFunctionStatsUtils.ROW_COUNT_TIMES_INV_NULL_FRACTION) {
                distinctValuesCount = outputRowCount * (1 - nullFraction);
            }
            else if (distinctValuesCount == ScalarFunctionStatsUtils.ROW_COUNT) {
                distinctValuesCount = outputRowCount;
            }
        }
        double averageRowSize = scalarStatsHeader.getAvgRowSize();
        double maxValue = scalarStatsHeader.getMax();
        double minValue = scalarStatsHeader.getMin();

        // Per-statistic projections of the source stats do not depend on the annotation being
        // processed, so they are built once rather than on every loop iteration.
        List<Double> sourceNullFractions = sourceStats.stream().map(VariableStatsEstimate::getNullsFraction).collect(Collectors.toList());
        List<Double> sourceDistinctValuesCounts = sourceStats.stream().map(VariableStatsEstimate::getDistinctValuesCount).collect(Collectors.toList());
        List<Double> sourceAverageRowSizes = sourceStats.stream().map(VariableStatsEstimate::getAverageRowSize).collect(Collectors.toList());
        List<Double> sourceHighValues = sourceStats.stream().map(VariableStatsEstimate::getHighValue).collect(Collectors.toList());
        List<Double> sourceLowValues = sourceStats.stream().map(VariableStatsEstimate::getLowValue).collect(Collectors.toList());

        for (Map.Entry<Integer, ScalarPropagateSourceStats> argumentEntry : scalarStatsHeader.getArgumentStats().entrySet()) {
            int argumentIndex = argumentEntry.getKey();
            ScalarPropagateSourceStats propagation = argumentEntry.getValue();
            boolean propagateAllStats = propagation.propagateAllStats();

            // The null fraction is resolved first on purpose: the distinct-values computation below reads it.
            nullFraction = min(
                    firstNonNaN(
                            nullFraction,
                            processSingleArgumentStatistic(
                                    sourceNullFractions,
                                    argumentIndex,
                                    applyPropagateAllStats(propagateAllStats, propagation.nullFraction()))),
                    1.0);
            distinctValuesCount = firstNonNaN(
                    distinctValuesCount,
                    processDistinctValuesCount(
                            sourceDistinctValuesCounts,
                            argumentIndex,
                            applyPropagateAllStats(propagateAllStats, propagation.distinctValuesCount())));
            averageRowSize = firstNonNaN(
                    averageRowSize,
                    processAvgRowSize(
                            sourceAverageRowSizes,
                            argumentIndex,
                            applyPropagateAllStats(propagateAllStats, propagation.avgRowSize())));
            maxValue = firstNonNaN(
                    maxValue,
                    processSingleArgumentStatistic(
                            sourceHighValues,
                            argumentIndex,
                            applyPropagateAllStats(propagateAllStats, propagation.maxValue())));
            minValue = firstNonNaN(
                    minValue,
                    processSingleArgumentStatistic(
                            sourceLowValues,
                            argumentIndex,
                            applyPropagateAllStats(propagateAllStats, propagation.minValue())));
        }

        // A half-known range is reported as fully unknown.
        if (!isFinite(maxValue) || !isFinite(minValue)) {
            minValue = NEGATIVE_INFINITY;
            maxValue = POSITIVE_INFINITY;
        }
        return VariableStatsEstimate.builder()
                .setLowValue(minValue)
                .setHighValue(maxValue)
                .setNullsFraction(nullFraction)
                .setAverageRowSize(minExcludingNaNs(averageRowSize, getReturnTypeWidth(callExpression, UNKNOWN)))
                .setDistinctValuesCount(distinctValuesCount).build();
    }

    /**
     * Propagates the distinct values count and, when the output row count is finite, caps it at the
     * estimated number of non-null output rows (an unknown null fraction is treated as 0 for the cap).
     */
    private double processDistinctValuesCount(List<Double> sourceStats, int sourceStatsArgumentIndex, StatsPropagationBehavior operation)
    {
        double distinctValuesCount = processSingleArgumentStatistic(sourceStats, sourceStatsArgumentIndex, operation);
        if (isFinite(outputRowCount)) {
            distinctValuesCount = min(distinctValuesCount, outputRowCount * (1 - firstNonNaN(nullFraction, 0.0)));
        }
        return distinctValuesCount;
    }

    /**
     * Propagates the average row size and combines it with the call's return type width through
     * {@code minExcludingNaNs} (presumably so that the width acts as an upper bound when it is known).
     */
    private double processAvgRowSize(List<Double> sourceStats, int sourceStatsArgumentIndex, StatsPropagationBehavior operation)
    {
        double avgRowSize = processSingleArgumentStatistic(sourceStats, sourceStatsArgumentIndex, operation);
        return minExcludingNaNs(avgRowSize, getReturnTypeWidth(callExpression, operation));
    }

    /**
     * Applies one propagation rule to one statistic.
     *
     * @param sourceStats the value of the statistic for every source, in source order
     * @param sourceStatsArgumentIndex index of the argument on which the
     *                                 {@link ScalarPropagateSourceStats} annotation was applied
     * @param operation the propagation behavior to apply
     * @return the propagated value, or {@code NaN} when the behavior is not handled here or,
     *         for the min/max/sum behaviors, when any source value is not finite
     */
    private double processSingleArgumentStatistic(
            List<Double> sourceStats,
            int sourceStatsArgumentIndex,
            StatsPropagationBehavior operation)
    {
        double statValue = NaN;
        if (!operation.isSingleArgumentStats()) {
            // Behaviors that combine the values of all sources. Min/max/sum stop at the first
            // source whose value is not finite and report the result as unknown.
            boolean sourceStatUnknown = false;
            for (int i = 0; i < sourceStats.size() && !sourceStatUnknown; i++) {
                switch (operation) {
                    case MAX_TYPE_WIDTH_VARCHAR:
                        statValue = maxExcludingNaNs(statValue, typeWidthVarchar(callExpression, i));
                        break;
                    case USE_MIN_ARGUMENT:
                        sourceStatUnknown = !isFinite(sourceStats.get(i));
                        statValue = minExcludingNaNs(statValue, sourceStats.get(i));
                        break;
                    case USE_MAX_ARGUMENT:
                        sourceStatUnknown = !isFinite(sourceStats.get(i));
                        statValue = maxExcludingNaNs(statValue, sourceStats.get(i));
                        break;
                    case SUM_ARGUMENTS:
                        sourceStatUnknown = !isFinite(sourceStats.get(i));
                        statValue = sumExcludingNaNs(statValue, sourceStats.get(i));
                        break;
                }
            }
            if (sourceStatUnknown) {
                statValue = NaN;
            }
        }
        else {
            // Behaviors that look at the annotated argument only, or at no argument at all.
            switch (operation) {
                case USE_SOURCE_STATS:
                    statValue = sourceStats.get(sourceStatsArgumentIndex);
                    break;
                case ROW_COUNT:
                    statValue = outputRowCount;
                    break;
                case NON_NULL_ROW_COUNT:
                    statValue = outputRowCount * (1 - nullFraction);
                    break;
                case USE_TYPE_WIDTH_VARCHAR:
                    statValue = typeWidthVarchar(callExpression, sourceStatsArgumentIndex);
                    break;
            }
        }
        return statValue;
    }

    /**
     * Returns the declared length of the varchar argument at {@code argumentIndex}, or {@code NaN}
     * when the argument is not a varchar, does not have exactly one type parameter, or its length
     * is non-positive or equal to {@code VarcharType.UNBOUNDED_LENGTH}.
     */
    private double typeWidthVarchar(CallExpression call, int argumentIndex)
    {
        TypeSignature typeSignature = call.getArguments().get(argumentIndex).getType().getTypeSignature();
        boolean isVarchar = typeSignature.getTypeSignatureBase().hasStandardType()
                && typeSignature.getTypeSignatureBase().getStandardTypeBase().equals(StandardTypes.VARCHAR);
        // A varchar type is expected to have exactly one parameter: its length.
        if (isVarchar && typeSignature.getParameters().size() == 1) {
            Long longLiteral = typeSignature.getParameters().get(0).getLongLiteral();
            if (longLiteral > 0 && longLiteral != VarcharType.UNBOUNDED_LENGTH) {
                return longLiteral;
            }
        }
        return NaN;
    }

    /**
     * Returns the width of the call's return type, or {@code NaN} when it cannot be determined.
     *
     * <p>Fixed-width types report their fixed size and bounded varchars report their length. For an
     * unbounded varchar with {@code SUM_ARGUMENTS}, the width is the sum of the argument widths:
     * nested calls contribute their own return type width, other arguments contribute their bounded
     * varchar length. A nested call of unknown width makes the whole sum unknown.
     */
    private double getReturnTypeWidth(CallExpression call, StatsPropagationBehavior operation)
    {
        if (call.getType() instanceof FixedWidthType) {
            return ((FixedWidthType) call.getType()).getFixedSize();
        }
        if (call.getType() instanceof VarcharType) {
            VarcharType returnType = (VarcharType) call.getType();
            if (!returnType.isUnbounded()) {
                return returnType.getLengthSafe();
            }
            if (operation == SUM_ARGUMENTS) {
                double sum = 0;
                for (RowExpression argument : call.getArguments()) {
                    if (argument instanceof CallExpression) {
                        // The recursive call already accounts for a bounded varchar return type,
                        // so the argument must not be counted a second time by the branch below.
                        sum += getReturnTypeWidth((CallExpression) argument, UNKNOWN);
                    }
                    else if (argument.getType() instanceof VarcharType) {
                        VarcharType argumentType = (VarcharType) argument.getType();
                        // NOTE(review): unbounded varchar arguments are skipped rather than making
                        // the sum unknown - confirm that a partial sum is an acceptable bound.
                        if (!argumentType.isUnbounded()) {
                            sum += argumentType.getLengthSafe();
                        }
                    }
                }
                // A NaN sum (unknown nested width) fails this comparison and falls through to NaN.
                if (sum > 0) {
                    return sum;
                }
            }
        }
        return NaN;
    }

    /**
     * Returns the first <em>finite</em> value among {@code values}, or {@code NaN} if there is none.
     * Despite the name, infinities are skipped as well as NaNs.
     */
    private double firstNonNaN(double... values)
    {
        for (double value : values) {
            if (isFinite(value)) {
                return value;
            }
        }
        return NaN;
    }

    /**
     * Resolves the effective behavior for a statistic: when the annotation asks to propagate all
     * statistics and this one is left as {@code UNKNOWN}, it defaults to {@code USE_SOURCE_STATS}.
     */
    private StatsPropagationBehavior applyPropagateAllStats(
            boolean propagateAllStats, StatsPropagationBehavior operation)
    {
        if (operation == UNKNOWN && propagateAllStats) {
            return USE_SOURCE_STATS;
        }
        return operation;
    }
}
Loading

0 comments on commit f5d3275

Please sign in to comment.