Skip to content

Commit

Permalink
feat(eap): Add support for extrapolation (#6536)
Browse files Browse the repository at this point in the history
Added support for EXTRAPOLATION_MODE_SAMPLE_WEIGHTED
  • Loading branch information
davidtsuk authored Nov 12, 2024
1 parent b80738f commit c3b10a7
Show file tree
Hide file tree
Showing 6 changed files with 365 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,13 @@ query_processors:
- sum
- count
- avg
- avgWeighted
- max
- min
- uniq
curried_aggregation_names:
- quantile
- quantileTDigestWeighted
- processor: HashBucketFunctionTransformer
args:
hash_bucket_names:
Expand Down
97 changes: 53 additions & 44 deletions snuba/query/processors/logical/optional_attribute_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,66 +37,75 @@ def __init__(
self._curried_aggregation_names = curried_aggregation_names

def process_query(self, query: Query, query_settings: QuerySettings) -> None:
def is_attribute_expression(exp_parameters: tuple[Expression, ...]) -> bool:
    """Return True when the parameters are exactly one attribute-map subscript.

    The check passes only if there is a single parameter, that parameter is a
    ``SubscriptableReference``, and the column it subscripts is one of the
    configured attribute columns.
    """
    if len(exp_parameters) != 1:
        return False
    (sole_param,) = exp_parameters
    return (
        isinstance(sole_param, SubscriptableReference)
        and sole_param.column.column_name in self._attribute_column_names
    )
def find_subscriptable_reference(
    exp: Expression,
) -> SubscriptableReference | None:
    """Return the first attribute-map subscript found inside ``exp``, if any.

    ``exp`` itself is returned when it is a ``SubscriptableReference`` on one
    of the configured attribute columns. Otherwise the parameters of plain and
    curried function calls are searched recursively, in order, and the first
    match wins. ``None`` means no such reference exists in the tree.
    """
    is_attribute_ref = (
        isinstance(exp, SubscriptableReference)
        and exp.column.column_name in self._attribute_column_names
    )
    if is_attribute_ref:
        return exp

    if isinstance(exp, (FunctionCall, CurriedFunctionCall)):
        # Lazy generator: stop descending as soon as one parameter matches.
        candidates = (
            find_subscriptable_reference(child) for child in exp.parameters
        )
        return next((found for found in candidates if found), None)

    return None

def transform_aggregates_to_conditionals(exp: Expression) -> Expression:
    """Rewrite an aggregate over an optional attribute into its ``-If`` form.

    For a configured aggregation (plain or curried) whose parameters contain
    an attribute-map subscript, the function name gets an ``If`` suffix and a
    trailing ``mapContains(column, key)`` argument is appended after the
    existing parameters. Any other expression is returned unchanged.
    """

    def contains_check(ref: SubscriptableReference) -> FunctionCall:
        # Condition appended as the last argument of the ``-If`` aggregate.
        return FunctionCall(
            alias=None,
            function_name="mapContains",
            parameters=(ref.column, ref.key),
        )

    if (
        isinstance(exp, FunctionCall)
        and exp.function_name in self._aggregation_names
    ):
        attribute_ref = find_subscriptable_reference(exp)
        if attribute_ref:
            return FunctionCall(
                alias=exp.alias,
                function_name=f"{exp.function_name}If",
                parameters=(*exp.parameters, contains_check(attribute_ref)),
            )
        return exp

    if (
        isinstance(exp, CurriedFunctionCall)
        and exp.internal_function.function_name
        in self._curried_aggregation_names
    ):
        attribute_ref = find_subscriptable_reference(exp)
        if attribute_ref:
            # Only the inner function is renamed; its own (curried)
            # parameters are carried over untouched.
            conditional_inner = FunctionCall(
                alias=None,
                function_name=f"{exp.internal_function.function_name}If",
                parameters=exp.internal_function.parameters,
            )
            return CurriedFunctionCall(
                alias=exp.alias,
                internal_function=conditional_inner,
                parameters=(*exp.parameters, contains_check(attribute_ref)),
            )

    return exp

Expand Down
100 changes: 72 additions & 28 deletions snuba/web/rpc/common/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from sentry_protos.snuba.v1.trace_item_attribute_pb2 import (
AttributeAggregation,
AttributeKey,
ExtrapolationMode,
Function,
VirtualColumnContext,
)
Expand All @@ -17,17 +18,13 @@
from snuba.query.conditions import combine_and_conditions, combine_or_conditions
from snuba.query.dsl import CurriedFunctions as cf
from snuba.query.dsl import Functions as f
from snuba.query.dsl import (
_CurriedFunctionCall,
_FunctionCall,
and_cond,
column,
in_cond,
literal,
literals_array,
or_cond,
from snuba.query.dsl import and_cond, column, in_cond, literal, literals_array, or_cond
from snuba.query.expressions import (
CurriedFunctionCall,
Expression,
FunctionCall,
SubscriptableReference,
)
from snuba.query.expressions import Expression, FunctionCall, SubscriptableReference
from snuba.web.rpc.common.exceptions import BadSnubaRPCRequestException


Expand Down Expand Up @@ -77,29 +74,76 @@ def transform(exp: Expression) -> Expression:


def aggregation_to_expression(aggregation: AttributeAggregation) -> Expression:
    """Translate an ``AttributeAggregation`` into an aggregate expression.

    The aggregated field comes from ``aggregation.key`` and the result is
    aliased with ``aggregation.label`` when a label is set. When
    ``aggregation.extrapolation_mode`` is
    ``EXTRAPOLATION_MODE_SAMPLE_WEIGHTED`` the sum, count, average and
    percentile functions use the ``sampling_weight`` column; max, min and
    uniq are the same in both modes.

    Raises:
        BadSnubaRPCRequestException: if ``aggregation.aggregate`` has no entry
            in the function map selected for the extrapolation mode.
    """
    field = attribute_key_to_expression(aggregation.key)
    alias = aggregation.label if aggregation.label else None
    alias_dict = {"alias": alias} if alias else {}
    function_map: dict[Function.ValueType, CurriedFunctionCall | FunctionCall] = {
        Function.FUNCTION_SUM: f.sum(field, **alias_dict),
        Function.FUNCTION_AVERAGE: f.avg(field, **alias_dict),
        Function.FUNCTION_COUNT: f.count(field, **alias_dict),
        Function.FUNCTION_P50: cf.quantile(0.5)(field, **alias_dict),
        Function.FUNCTION_P90: cf.quantile(0.9)(field, **alias_dict),
        Function.FUNCTION_P95: cf.quantile(0.95)(field, **alias_dict),
        Function.FUNCTION_P99: cf.quantile(0.99)(field, **alias_dict),
        Function.FUNCTION_AVG: f.avg(field, **alias_dict),
        Function.FUNCTION_MAX: f.max(field, **alias_dict),
        Function.FUNCTION_MIN: f.min(field, **alias_dict),
        Function.FUNCTION_UNIQ: f.uniq(field, **alias_dict),
    }

    sampling_weight_column = column("sampling_weight")
    function_map_sample_weighted: dict[
        Function.ValueType, CurriedFunctionCall | FunctionCall
    ] = {
        Function.FUNCTION_SUM: f.sum(
            f.multiply(field, sampling_weight_column), **alias_dict
        ),
        Function.FUNCTION_AVERAGE: f.avgWeighted(
            field, sampling_weight_column, **alias_dict
        ),
        Function.FUNCTION_COUNT: (
            # The weighted count sums the weight column rather than the
            # attribute itself, so the optional attribute aggregation
            # processor cannot add the presence check for us; it is spelled
            # out here with an explicit mapContains condition.
            f.sumIf(
                sampling_weight_column,
                f.mapContains(field.column, field.key),
                **alias_dict,
            )
            if isinstance(field, SubscriptableReference)
            else f.sum(sampling_weight_column, **alias_dict)
        ),
        Function.FUNCTION_P50: cf.quantileTDigestWeighted(0.5)(
            field, sampling_weight_column, **alias_dict
        ),
        Function.FUNCTION_P90: cf.quantileTDigestWeighted(0.9)(
            field, sampling_weight_column, **alias_dict
        ),
        Function.FUNCTION_P95: cf.quantileTDigestWeighted(0.95)(
            field, sampling_weight_column, **alias_dict
        ),
        Function.FUNCTION_P99: cf.quantileTDigestWeighted(0.99)(
            field, sampling_weight_column, **alias_dict
        ),
        # Must be the same function as FUNCTION_AVERAGE above: the ClickHouse
        # aggregate is named ``avgWeighted`` (there is no ``weightedAvg``).
        Function.FUNCTION_AVG: f.avgWeighted(
            field, sampling_weight_column, **alias_dict
        ),
        Function.FUNCTION_MAX: f.max(field, **alias_dict),
        Function.FUNCTION_MIN: f.min(field, **alias_dict),
        Function.FUNCTION_UNIQ: f.uniq(field, **alias_dict),
    }

    if (
        aggregation.extrapolation_mode
        == ExtrapolationMode.EXTRAPOLATION_MODE_SAMPLE_WEIGHTED
    ):
        agg_func_expr = function_map_sample_weighted.get(aggregation.aggregate)
    else:
        agg_func_expr = function_map.get(aggregation.aggregate)

    if agg_func_expr is None:
        raise BadSnubaRPCRequestException(
            f"Aggregation not specified for {aggregation.key.name}"
        )

    return agg_func_expr


# These are the columns which aren't stored in attr_str_ nor attr_num_ in clickhouse
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,36 @@
],
),
),
(
Query(
QueryEntity(EntityKey.EAP_SPANS, ColumnSet([])),
selected_columns=[
SelectedExpression(
"sum(x)",
f.sum(
f.multiply(attr_num["x"], column("sampling_weight")),
alias="sum(x)",
),
),
],
),
Query(
QueryEntity(EntityKey.EAP_SPANS, ColumnSet([])),
selected_columns=[
SelectedExpression(
"sum(x)",
f.sumIf(
f.multiply(
attr_num["x"],
column("sampling_weight"),
),
f.mapContains(column("attr_num", alias="_snuba_attr_num"), "x"),
alias="sum(x)",
),
),
],
),
),
]


Expand All @@ -70,7 +100,7 @@ def test_query_processing(pre_format: Query, expected_query: Query) -> None:
copy = deepcopy(pre_format)
OptionalAttributeAggregationTransformer(
attribute_column_names=["attr_num"],
aggregation_names=["avg"],
aggregation_names=["avg", "sum"],
curried_aggregation_names=["quantile"],
).process_query(copy, HTTPQuerySettings())
assert copy.get_selected_columns() == expected_query.get_selected_columns()
Expand Down
Loading

0 comments on commit c3b10a7

Please sign in to comment.