Add noisy_avg_gaussian aggregation
This commit adds `noisy_avg_gaussian` aggregation. It can be used to replace `avg(col)` with `noisy_avg_gaussian(col, noiseScale[, lower, upper][, randomSeed])`.

This is one of the aggregations in our effort to add Presto UDFs for noisy aggregations, which serve as building blocks for differential privacy in Presto.

`col` can be any of the numeric types: TINYINT, SMALLINT, INTEGER, BIGINT, REAL, DOUBLE, DECIMAL.

Because noise is of type `double`, all values are converted to `double` before being added to the avg, and the return type is `double`.
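
As a rough illustration of the behaviour described above, the following is a minimal sketch and not the Presto implementation. The class `NoisyAvgSketch` and its method `noisyAvg` are hypothetical names introduced here; the sketch assumes the inputs have already been widened to `double`, and it returns `null` for empty input to mirror the documented `NULL` result.

```java
import java.util.Random;

// Illustrative sketch only; NoisyAvgSketch is a hypothetical name, not Presto code.
final class NoisyAvgSketch
{
    private NoisyAvgSketch() {}

    // Returns null when there are no input values, mirroring the documented NULL result.
    static Double noisyAvg(double[] values, double noiseScale, Random rng)
    {
        if (values.length == 0) {
            return null;
        }
        double sum = 0.0;
        for (double v : values) {
            sum += v; // inputs are assumed to be widened to double already
        }
        double trueAvg = sum / values.length;
        // Gaussian noise with mean 0 and standard deviation noiseScale
        return trueAvg + rng.nextGaussian() * noiseScale;
    }

    public static void main(String[] args)
    {
        double[] orderKeys = {1, 2, 3, 4};
        System.out.println(noisyAvg(orderKeys, 20.0, new Random(321)));
        System.out.println(noisyAvg(new double[0], 20.0, new Random(321)));
    }
}
```

Working in `double` throughout keeps the sum, the average, and the noise in one numeric domain, which is why the return type is `double` regardless of the input type.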

When a bound [lower, upper] is provided, each value is clipped to this range before being added to the sum (which is later used to compute the avg).
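
The clipping step can be sketched as follows. This is an assumption-laden illustration rather than the code in this commit: `ClippedAvgSketch`, `clip`, and `clippedAvg` are hypothetical names, and noise is left out so that the effect of clipping is visible on its own.

```java
// Illustrative sketch only; ClippedAvgSketch is a hypothetical name, not Presto code.
final class ClippedAvgSketch
{
    private ClippedAvgSketch() {}

    // Clamp a single value into the closed range [lower, upper].
    static double clip(double value, double lower, double upper)
    {
        return Math.max(lower, Math.min(upper, value));
    }

    // Average of the clipped values; noise is omitted to isolate the clipping step.
    static double clippedAvg(double[] values, double lower, double upper)
    {
        if (values.length == 0) {
            throw new IllegalArgumentException("values must not be empty");
        }
        double sum = 0.0;
        for (double v : values) {
            sum += clip(v, lower, upper); // clip first, then add to the running sum
        }
        return sum / values.length;
    }

    public static void main(String[] args)
    {
        // 5 is raised to 10 and 100 is lowered to 50 before averaging
        System.out.println(clippedAvg(new double[] {5, 30, 100}, 10.0, 50.0));
    }
}
```

Clipping each value before it enters the sum bounds how much any single row can move the average, which is the property a system built on top of these functions would typically rely on when calibrating `noiseScale`.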

Optional randomSeed is used to get a fixed value of noise, often for reproducibility purposes. If randomSeed is omitted, SecureRandom is used. If randomSeed is provided, Random is used.
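
The choice between the two generators might look like the sketch below. It is a hedged illustration, not the code added by this commit: `NoiseSourceSketch`, `chooseRandom`, and `gaussianNoise` are hypothetical names, and a `null` seed stands in for an omitted `randomSeed` argument.

```java
import java.security.SecureRandom;
import java.util.Random;

// Illustrative sketch only; NoiseSourceSketch is a hypothetical name, not Presto code.
final class NoiseSourceSketch
{
    private NoiseSourceSketch() {}

    // A null seed models an omitted randomSeed: use SecureRandom.
    // A provided seed gives reproducible, non-secure noise via Random.
    static Random chooseRandom(Long randomSeed)
    {
        return randomSeed == null ? new SecureRandom() : new Random(randomSeed);
    }

    // Draw one Gaussian sample with mean 0 and standard deviation noiseScale.
    static double gaussianNoise(double noiseScale, Long randomSeed)
    {
        return chooseRandom(randomSeed).nextGaussian() * noiseScale;
    }

    public static void main(String[] args)
    {
        System.out.println(gaussianNoise(20.0, 321L)); // same value on every run
        System.out.println(gaussianNoise(20.0, null)); // varies between runs
    }
}
```

A seeded `Random` is predictable by design, so the seeded variants suit tests and debugging rather than production queries that depend on the noise being unguessable.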

Why we want these functions:
The purpose is to help build systems, tools, and frameworks that provide differential privacy guarantees. Differential privacy has been used by multiple teams within Meta to develop privacy-preserving systems. The current implementation requires complicated SQL operations even for the simplest aggregations, which increases development time, complexity, and maintenance and sharing costs, and sometimes completely blocks the development of new features.

While these functions on their own do not guarantee 100% differential privacy, they are the building blocks for other systems. That is also why we do not call these functions “differentially private aggregations” but only “noisy aggregations” to avoid a wrong impression of achieving differential privacy solely by using these functions.
duykienvp authored and pranjalssh committed Sep 19, 2023
1 parent 1bf8afe commit 92f51c0
Showing 21 changed files with 3,502 additions and 148 deletions.
59 changes: 59 additions & 0 deletions presto-docs/src/main/sphinx/functions/aggregate.rst
@@ -418,6 +418,65 @@ Approximate Aggregate Functions
SELECT noisy_sum_gaussian(orderkey, 20.0, 10.0, 50.0, 321) FROM tpch.tiny.lineitem WHERE false; -- NULL (1 row)
SELECT noisy_sum_gaussian(orderkey, 20.0, 10.0, 50.0, 321) FROM tpch.tiny.lineitem WHERE false GROUP BY orderkey; -- (0 row)

.. function:: noisy_avg_gaussian(x, noise_scale) -> double

Calculates the average (arithmetic mean) of all the input values and then adds random Gaussian noise
with 0 mean and standard deviation of ``noise_scale``.
All values are converted to double before being added to the avg, and the return type is double.

When there are no input rows, this function returns ``NULL``.

Noise is from a secure random. ::

SELECT noisy_avg_gaussian(orderkey, 20.0) FROM tpch.tiny.lineitem WHERE false; -- NULL (1 row)
SELECT noisy_avg_gaussian(orderkey, 20.0) FROM tpch.tiny.lineitem WHERE false GROUP BY orderkey; -- (0 row)

.. function:: noisy_avg_gaussian(x, noise_scale, random_seed) -> double

Calculates the average (arithmetic mean) of all the input values and then adds random Gaussian noise
with 0 mean and standard deviation of ``noise_scale``.
All values are converted to double before being added to the avg, and the return type is double.

When there are no input rows, this function returns ``NULL``.

Random seed is used to seed the random generator.
This method does not use a secure random. ::

SELECT noisy_avg_gaussian(orderkey, 20.0, 321) FROM tpch.tiny.lineitem WHERE false; -- NULL (1 row)
SELECT noisy_avg_gaussian(orderkey, 20.0, 321) FROM tpch.tiny.lineitem WHERE false GROUP BY orderkey; -- (0 row)

.. function:: noisy_avg_gaussian(x, noise_scale, lower, upper) -> double

Calculates the average (arithmetic mean) of all the input values and then adds random Gaussian noise
with 0 mean and standard deviation of ``noise_scale``.
All values are converted to double before being added to the avg, and the return type is double.

Each value is clipped to the range of ``[lower, upper]`` before adding to the avg.

When there are no input rows, this function returns ``NULL``.

Noise is from a secure random. ::

SELECT noisy_avg_gaussian(orderkey, 20.0, 10.0, 50.0) FROM tpch.tiny.lineitem WHERE false; -- NULL (1 row)
SELECT noisy_avg_gaussian(orderkey, 20.0, 10.0, 50.0) FROM tpch.tiny.lineitem WHERE false GROUP BY orderkey; -- (0 row)

.. function:: noisy_avg_gaussian(x, noise_scale, lower, upper, random_seed) -> double

Calculates the average (arithmetic mean) of all the input values and then adds random Gaussian noise
with 0 mean and standard deviation of ``noise_scale``.
All values are converted to double before being added to the avg, and the return type is double.

Each value is clipped to the range of ``[lower, upper]`` before adding to the avg.

When there are no input rows, this function returns ``NULL``.

Random seed is used to seed the random generator.
This method does not use a secure random. ::

SELECT noisy_avg_gaussian(orderkey, 20.0, 10.0, 50.0, 321) FROM tpch.tiny.lineitem WHERE false; -- NULL (1 row)
SELECT noisy_avg_gaussian(orderkey, 20.0, 10.0, 50.0, 321) FROM tpch.tiny.lineitem WHERE false GROUP BY orderkey; -- (0 row)


Statistical Aggregate Functions
-------------------------------

@@ -344,6 +344,10 @@
import static com.facebook.presto.operator.aggregation.minmaxby.MaxByNAggregationFunction.MAX_BY_N_AGGREGATION;
import static com.facebook.presto.operator.aggregation.minmaxby.MinByAggregationFunction.MIN_BY;
import static com.facebook.presto.operator.aggregation.minmaxby.MinByNAggregationFunction.MIN_BY_N_AGGREGATION;
import static com.facebook.presto.operator.aggregation.noisyaggregation.NoisyAverageGaussianAggregation.NOISY_AVERAGE_GAUSSIAN_AGGREGATION;
import static com.facebook.presto.operator.aggregation.noisyaggregation.NoisyAverageGaussianClippingAggregation.NOISY_AVERAGE_GAUSSIAN_CLIPPING_AGGREGATION;
import static com.facebook.presto.operator.aggregation.noisyaggregation.NoisyAverageGaussianClippingRandomSeedAggregation.NOISY_AVERAGE_GAUSSIAN_CLIPPING_RANDOM_SEED_AGGREGATION;
import static com.facebook.presto.operator.aggregation.noisyaggregation.NoisyAverageGaussianRandomSeedAggregation.NOISY_AVERAGE_GAUSSIAN_RANDOM_SEED_AGGREGATION;
import static com.facebook.presto.operator.aggregation.noisyaggregation.NoisyCountGaussianColumnAggregation.NOISY_COUNT_GAUSSIAN_AGGREGATION;
import static com.facebook.presto.operator.aggregation.noisyaggregation.NoisyCountGaussianColumnRandomSeedAggregation.NOISY_COUNT_GAUSSIAN_RANDOM_SEED_AGGREGATION;
import static com.facebook.presto.operator.aggregation.noisyaggregation.NoisySumGaussianAggregation.NOISY_SUM_GAUSSIAN_AGGREGATION;
@@ -669,6 +673,10 @@ private List<? extends SqlFunction> getBuildInFunctions(FeaturesConfig featuresC
.function(NOISY_SUM_GAUSSIAN_RANDOM_SEED_AGGREGATION)
.function(NOISY_SUM_GAUSSIAN_CLIPPING_AGGREGATION)
.function(NOISY_SUM_GAUSSIAN_CLIPPING_RANDOM_SEED_AGGREGATION)
.function(NOISY_AVERAGE_GAUSSIAN_AGGREGATION)
.function(NOISY_AVERAGE_GAUSSIAN_CLIPPING_AGGREGATION)
.function(NOISY_AVERAGE_GAUSSIAN_CLIPPING_RANDOM_SEED_AGGREGATION)
.function(NOISY_AVERAGE_GAUSSIAN_RANDOM_SEED_AGGREGATION)
.function(REAL_AVERAGE_AGGREGATION)
.aggregates(IntervalDayToSecondAverageAggregation.class)
.aggregates(IntervalYearToMonthAverageAggregation.class)
@@ -0,0 +1,251 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.aggregation.noisyaggregation;

import com.facebook.presto.bytecode.DynamicClassLoader;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.DecimalType;
import com.facebook.presto.common.type.IntegerType;
import com.facebook.presto.common.type.RealType;
import com.facebook.presto.common.type.SmallintType;
import com.facebook.presto.common.type.StandardTypes;
import com.facebook.presto.common.type.TinyintType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.BoundVariables;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.SqlAggregationFunction;
import com.facebook.presto.operator.aggregation.AccumulatorCompiler;
import com.facebook.presto.operator.aggregation.BuiltInAggregationFunctionImplementation;
import com.facebook.presto.operator.aggregation.state.StateCompiler;
import com.facebook.presto.spi.function.AccumulatorStateFactory;
import com.facebook.presto.spi.function.AccumulatorStateSerializer;
import com.facebook.presto.spi.function.FunctionKind;
import com.facebook.presto.spi.function.aggregation.Accumulator;
import com.facebook.presto.spi.function.aggregation.AggregationMetadata;
import com.facebook.presto.spi.function.aggregation.AggregationMetadata.AccumulatorStateDescriptor;
import com.facebook.presto.spi.function.aggregation.GroupedAccumulator;
import com.google.common.collect.ImmutableList;

import java.lang.invoke.MethodHandle;
import java.util.List;

import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.Decimals.MAX_PRECISION;
import static com.facebook.presto.common.type.Decimals.MAX_SHORT_PRECISION;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.RealType.REAL;
import static com.facebook.presto.common.type.SmallintType.SMALLINT;
import static com.facebook.presto.common.type.TinyintType.TINYINT;
import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature;
import static com.facebook.presto.common.type.UnscaledDecimal128Arithmetic.unscaledDecimal;
import static com.facebook.presto.common.type.UnscaledDecimal128Arithmetic.unscaledDecimalToBigInteger;
import static com.facebook.presto.operator.aggregation.AggregationUtils.generateAggregationName;
import static com.facebook.presto.operator.aggregation.noisyaggregation.NoisyCountAndSumAggregationUtils.combineStates;
import static com.facebook.presto.operator.aggregation.noisyaggregation.NoisyCountAndSumAggregationUtils.updateState;
import static com.facebook.presto.operator.aggregation.noisyaggregation.NoisyCountAndSumAggregationUtils.writeNoisyAvgOutput;
import static com.facebook.presto.spi.function.Signature.typeVariable;
import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata;
import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INDEX;
import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INPUT_CHANNEL;
import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.STATE;
import static com.facebook.presto.util.Reflection.methodHandle;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.lang.Float.intBitsToFloat;

public class NoisyAverageGaussianAggregation
extends SqlAggregationFunction
{
// Constant references for short/long decimal types for use in operations that only manipulate unscaled values
private static final DecimalType LONG_DECIMAL_TYPE = DecimalType.createDecimalType(MAX_PRECISION, 0);
private static final DecimalType SHORT_DECIMAL_TYPE = DecimalType.createDecimalType(MAX_SHORT_PRECISION, 0);

public static final NoisyAverageGaussianAggregation NOISY_AVERAGE_GAUSSIAN_AGGREGATION = new NoisyAverageGaussianAggregation();
private static final String NAME = "noisy_avg_gaussian";
private static final MethodHandle SHORT_DECIMAL_INPUT_FUNCTION = methodHandle(NoisyAverageGaussianAggregation.class, "inputShortDecimal", NoisyCountAndSumState.class, Block.class, Block.class, int.class);
private static final MethodHandle LONG_DECIMAL_INPUT_FUNCTION = methodHandle(NoisyAverageGaussianAggregation.class, "inputLongDecimal", NoisyCountAndSumState.class, Block.class, Block.class, int.class);
private static final MethodHandle DOUBLE_INPUT_FUNCTION = methodHandle(NoisyAverageGaussianAggregation.class, "inputDouble", NoisyCountAndSumState.class, Block.class, Block.class, int.class);
private static final MethodHandle REAL_INPUT_FUNCTION = methodHandle(NoisyAverageGaussianAggregation.class, "inputReal", NoisyCountAndSumState.class, Block.class, Block.class, int.class);
private static final MethodHandle BIGINT_INPUT_FUNCTION = methodHandle(NoisyAverageGaussianAggregation.class, "inputBigInt", NoisyCountAndSumState.class, Block.class, Block.class, int.class);
private static final MethodHandle INTEGER_INPUT_FUNCTION = methodHandle(NoisyAverageGaussianAggregation.class, "inputInteger", NoisyCountAndSumState.class, Block.class, Block.class, int.class);
private static final MethodHandle SMALLINT_INPUT_FUNCTION = methodHandle(NoisyAverageGaussianAggregation.class, "inputSmallInt", NoisyCountAndSumState.class, Block.class, Block.class, int.class);
private static final MethodHandle TINYINT_INPUT_FUNCTION = methodHandle(NoisyAverageGaussianAggregation.class, "inputTinyInt", NoisyCountAndSumState.class, Block.class, Block.class, int.class);

private static final MethodHandle OUTPUT_FUNCTION = methodHandle(NoisyAverageGaussianAggregation.class, "output", NoisyCountAndSumState.class, BlockBuilder.class);

private static final MethodHandle COMBINE_FUNCTION = methodHandle(NoisyAverageGaussianAggregation.class, "combine", NoisyCountAndSumState.class, NoisyCountAndSumState.class);

public NoisyAverageGaussianAggregation()
{
super(NAME,
ImmutableList.of(typeVariable("T")),
ImmutableList.of(),
parseTypeSignature(StandardTypes.DOUBLE),
ImmutableList.of(parseTypeSignature("T"),
parseTypeSignature(StandardTypes.DOUBLE)),
FunctionKind.AGGREGATE);
}

@Override
public String getDescription()
{
return "Calculates the average (arithmetic mean) of all the input values and then adds random Gaussian noise.";
}

@Override
public BuiltInAggregationFunctionImplementation specialize(BoundVariables boundVariables, int arity, FunctionAndTypeManager functionAndTypeManager)
{
Type type = boundVariables.getTypeVariable("T");
return generateAggregation(type);
}

private static BuiltInAggregationFunctionImplementation generateAggregation(Type type)
{
DynamicClassLoader classLoader = new DynamicClassLoader(NoisyAverageGaussianAggregation.class.getClassLoader());

AccumulatorStateSerializer<?> stateSerializer = new NoisyCountAndSumStateSerializer();
AccumulatorStateFactory<?> stateFactory = StateCompiler.generateStateFactory(NoisyCountAndSumState.class, classLoader);
List<Type> inputTypes = ImmutableList.of(type, DOUBLE);

MethodHandle inputFunction;
if (type instanceof DecimalType) {
inputFunction = ((DecimalType) type).isShort() ? SHORT_DECIMAL_INPUT_FUNCTION : LONG_DECIMAL_INPUT_FUNCTION;
}
else if (type instanceof TinyintType) {
inputFunction = TINYINT_INPUT_FUNCTION;
}
else if (type instanceof SmallintType) {
inputFunction = SMALLINT_INPUT_FUNCTION;
}
else if (type instanceof IntegerType) {
inputFunction = INTEGER_INPUT_FUNCTION;
}
else if (type instanceof BigintType) {
inputFunction = BIGINT_INPUT_FUNCTION;
}
else if (type instanceof RealType) {
inputFunction = REAL_INPUT_FUNCTION;
}
else {
inputFunction = DOUBLE_INPUT_FUNCTION;
}

AggregationMetadata metadata = new AggregationMetadata(
generateAggregationName(NAME, DOUBLE.getTypeSignature(), inputTypes.stream().map(Type::getTypeSignature).collect(toImmutableList())),
createInputParameterMetadata(type),
inputFunction,
COMBINE_FUNCTION,
OUTPUT_FUNCTION,
ImmutableList.of(new AccumulatorStateDescriptor(
NoisyCountAndSumState.class,
stateSerializer,
stateFactory)),
DOUBLE);

Type intermediateType = stateSerializer.getSerializedType();

Class<? extends Accumulator> accumulatorClass = AccumulatorCompiler.generateAccumulatorClass(
Accumulator.class,
metadata,
classLoader);
Class<? extends GroupedAccumulator> groupedAccumulatorClass = AccumulatorCompiler.generateAccumulatorClass(
GroupedAccumulator.class,
metadata,
classLoader);
return new BuiltInAggregationFunctionImplementation(NAME, inputTypes, ImmutableList.of(intermediateType), DOUBLE,
true, false, metadata, accumulatorClass, groupedAccumulatorClass);
}

private static List<ParameterMetadata> createInputParameterMetadata(Type type)
{
return ImmutableList.of(
new ParameterMetadata(STATE),
new ParameterMetadata(BLOCK_INPUT_CHANNEL, type),
new ParameterMetadata(BLOCK_INPUT_CHANNEL, DOUBLE),
new ParameterMetadata(BLOCK_INDEX));
}

public static void inputShortDecimal(NoisyCountAndSumState state, Block valueBlock, Block noiseScaleBlock, int position)
{
double value = unscaledDecimalToBigInteger(
unscaledDecimal(SHORT_DECIMAL_TYPE.getLong(valueBlock, position)))
.doubleValue();
input(state, value, noiseScaleBlock, position);
}

public static void inputLongDecimal(NoisyCountAndSumState state, Block valueBlock, Block noiseScaleBlock, int position)
{
double value = unscaledDecimalToBigInteger(
unscaledDecimal(LONG_DECIMAL_TYPE.getSlice(valueBlock, position)))
.doubleValue();
input(state, value, noiseScaleBlock, position);
}

public static void inputDouble(NoisyCountAndSumState state, Block valueBlock, Block noiseScaleBlock, int position)
{
double value = DOUBLE.getDouble(valueBlock, position);
input(state, value, noiseScaleBlock, position);
}

public static void inputReal(NoisyCountAndSumState state, Block valueBlock, Block noiseScaleBlock, int position)
{
double value = intBitsToFloat((int) REAL.getLong(valueBlock, position));
input(state, value, noiseScaleBlock, position);
}

public static void inputBigInt(NoisyCountAndSumState state, Block valueBlock, Block noiseScaleBlock, int position)
{
double value = BIGINT.getLong(valueBlock, position);
input(state, value, noiseScaleBlock, position);
}

public static void inputInteger(NoisyCountAndSumState state, Block valueBlock, Block noiseScaleBlock, int position)
{
double value = INTEGER.getLong(valueBlock, position);
input(state, value, noiseScaleBlock, position);
}

public static void inputSmallInt(NoisyCountAndSumState state, Block valueBlock, Block noiseScaleBlock, int position)
{
double value = SMALLINT.getLong(valueBlock, position);
input(state, value, noiseScaleBlock, position);
}

public static void inputTinyInt(NoisyCountAndSumState state, Block valueBlock, Block noiseScaleBlock, int position)
{
double value = TINYINT.getLong(valueBlock, position);
input(state, value, noiseScaleBlock, position);
}

private static void input(NoisyCountAndSumState state, double value, Block noiseScaleBlock, int position)
{
double noiseScale = DOUBLE.getDouble(noiseScaleBlock, position);

updateState(state, value, noiseScale, null, null, null);
}

public static void combine(NoisyCountAndSumState state, NoisyCountAndSumState otherState)
{
combineStates(state, otherState);
}

public static void output(NoisyCountAndSumState state, BlockBuilder out)
{
writeNoisyAvgOutput(state, out);
}
}