-
Notifications
You must be signed in to change notification settings - Fork 5.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This commit adds `noisy_avg_gaussian` aggregation. It can be used to replace `avg(col)` with `noisy_avg_gaussian(col, noiseScale[, lower, upper][, randomSeed])`. This is one of aggregations in our effort to add Presto UDF for noisy aggregations, used as building block for differential privacy in Presto. `col` can be of numerical types: INT, SMALLINT, INTEGER, BIGINT, REAL, DOUBLE, DECIMAL. Because noise is of type `double`, all values are converted to `double` before being added to the avg, and the return type is `double`. When a bound [lower, upper] is provided, each value is clipped to this range before being added to the sum (which is later used to compute the avg). Optional randomSeed is used to get a fixed value of noise, often for reproducibility purposes. If randomSeed is omitted, SecureRandom is used. If randomSeed is provided, Random is used. Why we want these functions: The purpose is to help build systems/tools/framework that provide differential privacy guarantees. Differential privacy has been used by multiple teams within Meta to develop privacy-preserving systems. Current implementation involves complicated SQL operation even for simplest aggregations, increasing development time, complexity, maintenance and sharing cost, and sometimes completely blocking development of new features. While these functions on their own do not guarantee 100% differential privacy, they are the building blocks for other systems. That is also why we do not call these functions “differentially private aggregations” but only “noisy aggregations” to avoid a wrong impression of achieving differential privacy solely by using these functions.
- Loading branch information
1 parent
1bf8afe
commit 92f51c0
Showing
21 changed files
with
3,502 additions
and
148 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
251 changes: 251 additions & 0 deletions
251
...acebook/presto/operator/aggregation/noisyaggregation/NoisyAverageGaussianAggregation.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,251 @@ | ||
/* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package com.facebook.presto.operator.aggregation.noisyaggregation; | ||
|
||
import com.facebook.presto.bytecode.DynamicClassLoader; | ||
import com.facebook.presto.common.block.Block; | ||
import com.facebook.presto.common.block.BlockBuilder; | ||
import com.facebook.presto.common.type.BigintType; | ||
import com.facebook.presto.common.type.DecimalType; | ||
import com.facebook.presto.common.type.IntegerType; | ||
import com.facebook.presto.common.type.RealType; | ||
import com.facebook.presto.common.type.SmallintType; | ||
import com.facebook.presto.common.type.StandardTypes; | ||
import com.facebook.presto.common.type.TinyintType; | ||
import com.facebook.presto.common.type.Type; | ||
import com.facebook.presto.metadata.BoundVariables; | ||
import com.facebook.presto.metadata.FunctionAndTypeManager; | ||
import com.facebook.presto.metadata.SqlAggregationFunction; | ||
import com.facebook.presto.operator.aggregation.AccumulatorCompiler; | ||
import com.facebook.presto.operator.aggregation.BuiltInAggregationFunctionImplementation; | ||
import com.facebook.presto.operator.aggregation.state.StateCompiler; | ||
import com.facebook.presto.spi.function.AccumulatorStateFactory; | ||
import com.facebook.presto.spi.function.AccumulatorStateSerializer; | ||
import com.facebook.presto.spi.function.FunctionKind; | ||
import com.facebook.presto.spi.function.aggregation.Accumulator; | ||
import com.facebook.presto.spi.function.aggregation.AggregationMetadata; | ||
import com.facebook.presto.spi.function.aggregation.AggregationMetadata.AccumulatorStateDescriptor; | ||
import com.facebook.presto.spi.function.aggregation.GroupedAccumulator; | ||
import com.google.common.collect.ImmutableList; | ||
|
||
import java.lang.invoke.MethodHandle; | ||
import java.util.List; | ||
|
||
import static com.facebook.presto.common.type.BigintType.BIGINT; | ||
import static com.facebook.presto.common.type.Decimals.MAX_PRECISION; | ||
import static com.facebook.presto.common.type.Decimals.MAX_SHORT_PRECISION; | ||
import static com.facebook.presto.common.type.DoubleType.DOUBLE; | ||
import static com.facebook.presto.common.type.IntegerType.INTEGER; | ||
import static com.facebook.presto.common.type.RealType.REAL; | ||
import static com.facebook.presto.common.type.SmallintType.SMALLINT; | ||
import static com.facebook.presto.common.type.TinyintType.TINYINT; | ||
import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; | ||
import static com.facebook.presto.common.type.UnscaledDecimal128Arithmetic.unscaledDecimal; | ||
import static com.facebook.presto.common.type.UnscaledDecimal128Arithmetic.unscaledDecimalToBigInteger; | ||
import static com.facebook.presto.operator.aggregation.AggregationUtils.generateAggregationName; | ||
import static com.facebook.presto.operator.aggregation.noisyaggregation.NoisyCountAndSumAggregationUtils.combineStates; | ||
import static com.facebook.presto.operator.aggregation.noisyaggregation.NoisyCountAndSumAggregationUtils.updateState; | ||
import static com.facebook.presto.operator.aggregation.noisyaggregation.NoisyCountAndSumAggregationUtils.writeNoisyAvgOutput; | ||
import static com.facebook.presto.spi.function.Signature.typeVariable; | ||
import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata; | ||
import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INDEX; | ||
import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INPUT_CHANNEL; | ||
import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.STATE; | ||
import static com.facebook.presto.util.Reflection.methodHandle; | ||
import static com.google.common.collect.ImmutableList.toImmutableList; | ||
import static java.lang.Float.intBitsToFloat; | ||
|
||
public class NoisyAverageGaussianAggregation | ||
extends SqlAggregationFunction | ||
{ | ||
// Constant references for short/long decimal types for use in operations that only manipulate unscaled values | ||
private static final DecimalType LONG_DECIMAL_TYPE = DecimalType.createDecimalType(MAX_PRECISION, 0); | ||
private static final DecimalType SHORT_DECIMAL_TYPE = DecimalType.createDecimalType(MAX_SHORT_PRECISION, 0); | ||
|
||
public static final NoisyAverageGaussianAggregation NOISY_AVERAGE_GAUSSIAN_AGGREGATION = new NoisyAverageGaussianAggregation(); | ||
private static final String NAME = "noisy_avg_gaussian"; | ||
private static final MethodHandle SHORT_DECIMAL_INPUT_FUNCTION = methodHandle(NoisyAverageGaussianAggregation.class, "inputShortDecimal", NoisyCountAndSumState.class, Block.class, Block.class, int.class); | ||
private static final MethodHandle LONG_DECIMAL_INPUT_FUNCTION = methodHandle(NoisyAverageGaussianAggregation.class, "inputLongDecimal", NoisyCountAndSumState.class, Block.class, Block.class, int.class); | ||
private static final MethodHandle DOUBLE_INPUT_FUNCTION = methodHandle(NoisyAverageGaussianAggregation.class, "inputDouble", NoisyCountAndSumState.class, Block.class, Block.class, int.class); | ||
private static final MethodHandle REAL_INPUT_FUNCTION = methodHandle(NoisyAverageGaussianAggregation.class, "inputReal", NoisyCountAndSumState.class, Block.class, Block.class, int.class); | ||
private static final MethodHandle BIGINT_INPUT_FUNCTION = methodHandle(NoisyAverageGaussianAggregation.class, "inputBigInt", NoisyCountAndSumState.class, Block.class, Block.class, int.class); | ||
private static final MethodHandle INTEGER_INPUT_FUNCTION = methodHandle(NoisyAverageGaussianAggregation.class, "inputInteger", NoisyCountAndSumState.class, Block.class, Block.class, int.class); | ||
private static final MethodHandle SMALLINT_INPUT_FUNCTION = methodHandle(NoisyAverageGaussianAggregation.class, "inputSmallInt", NoisyCountAndSumState.class, Block.class, Block.class, int.class); | ||
private static final MethodHandle TINYINT_INPUT_FUNCTION = methodHandle(NoisyAverageGaussianAggregation.class, "inputTinyInt", NoisyCountAndSumState.class, Block.class, Block.class, int.class); | ||
|
||
private static final MethodHandle OUTPUT_FUNCTION = methodHandle(NoisyAverageGaussianAggregation.class, "output", NoisyCountAndSumState.class, BlockBuilder.class); | ||
|
||
private static final MethodHandle COMBINE_FUNCTION = methodHandle(NoisyAverageGaussianAggregation.class, "combine", NoisyCountAndSumState.class, NoisyCountAndSumState.class); | ||
|
||
public NoisyAverageGaussianAggregation() | ||
{ | ||
super(NAME, | ||
ImmutableList.of(typeVariable("T")), | ||
ImmutableList.of(), | ||
parseTypeSignature(StandardTypes.DOUBLE), | ||
ImmutableList.of(parseTypeSignature("T"), | ||
parseTypeSignature(StandardTypes.DOUBLE)), | ||
FunctionKind.AGGREGATE); | ||
} | ||
|
||
@Override | ||
public String getDescription() | ||
{ | ||
return "Calculates the average (arithmetic mean) of all the input values and then adds random Gaussian noise."; | ||
} | ||
|
||
@Override | ||
public BuiltInAggregationFunctionImplementation specialize(BoundVariables boundVariables, int arity, FunctionAndTypeManager functionAndTypeManager) | ||
{ | ||
Type type = boundVariables.getTypeVariable("T"); | ||
return generateAggregation(type); | ||
} | ||
|
||
private static BuiltInAggregationFunctionImplementation generateAggregation(Type type) | ||
{ | ||
DynamicClassLoader classLoader = new DynamicClassLoader(NoisyAverageGaussianAggregation.class.getClassLoader()); | ||
|
||
AccumulatorStateSerializer<?> stateSerializer = new NoisyCountAndSumStateSerializer(); | ||
AccumulatorStateFactory<?> stateFactory = StateCompiler.generateStateFactory(NoisyCountAndSumState.class, classLoader); | ||
List<Type> inputTypes = ImmutableList.of(type, DOUBLE); | ||
|
||
MethodHandle inputFunction; | ||
if (type instanceof DecimalType) { | ||
inputFunction = ((DecimalType) type).isShort() ? SHORT_DECIMAL_INPUT_FUNCTION : LONG_DECIMAL_INPUT_FUNCTION; | ||
} | ||
else if (type instanceof TinyintType) { | ||
inputFunction = TINYINT_INPUT_FUNCTION; | ||
} | ||
else if (type instanceof SmallintType) { | ||
inputFunction = SMALLINT_INPUT_FUNCTION; | ||
} | ||
else if (type instanceof IntegerType) { | ||
inputFunction = INTEGER_INPUT_FUNCTION; | ||
} | ||
else if (type instanceof BigintType) { | ||
inputFunction = BIGINT_INPUT_FUNCTION; | ||
} | ||
else if (type instanceof RealType) { | ||
inputFunction = REAL_INPUT_FUNCTION; | ||
} | ||
else { | ||
inputFunction = DOUBLE_INPUT_FUNCTION; | ||
} | ||
|
||
AggregationMetadata metadata = new AggregationMetadata( | ||
generateAggregationName(NAME, DOUBLE.getTypeSignature(), inputTypes.stream().map(Type::getTypeSignature).collect(toImmutableList())), | ||
createInputParameterMetadata(type), | ||
inputFunction, | ||
COMBINE_FUNCTION, | ||
OUTPUT_FUNCTION, | ||
ImmutableList.of(new AccumulatorStateDescriptor( | ||
NoisyCountAndSumState.class, | ||
stateSerializer, | ||
stateFactory)), | ||
DOUBLE); | ||
|
||
Type intermediateType = stateSerializer.getSerializedType(); | ||
|
||
Class<? extends Accumulator> accumulatorClass = AccumulatorCompiler.generateAccumulatorClass( | ||
Accumulator.class, | ||
metadata, | ||
classLoader); | ||
Class<? extends GroupedAccumulator> groupedAccumulatorClass = AccumulatorCompiler.generateAccumulatorClass( | ||
GroupedAccumulator.class, | ||
metadata, | ||
classLoader); | ||
return new BuiltInAggregationFunctionImplementation(NAME, inputTypes, ImmutableList.of(intermediateType), DOUBLE, | ||
true, false, metadata, accumulatorClass, groupedAccumulatorClass); | ||
} | ||
|
||
private static List<ParameterMetadata> createInputParameterMetadata(Type type) | ||
{ | ||
return ImmutableList.of( | ||
new ParameterMetadata(STATE), | ||
new ParameterMetadata(BLOCK_INPUT_CHANNEL, type), | ||
new ParameterMetadata(BLOCK_INPUT_CHANNEL, DOUBLE), | ||
new ParameterMetadata(BLOCK_INDEX)); | ||
} | ||
|
||
public static void inputShortDecimal(NoisyCountAndSumState state, Block valueBlock, Block noiseScaleBlock, int position) | ||
{ | ||
double value = unscaledDecimalToBigInteger( | ||
unscaledDecimal(SHORT_DECIMAL_TYPE.getLong(valueBlock, position))) | ||
.doubleValue(); | ||
double noiseScale = DOUBLE.getDouble(noiseScaleBlock, position); | ||
|
||
updateState(state, value, noiseScale, null, null, null); | ||
} | ||
|
||
public static void inputLongDecimal(NoisyCountAndSumState state, Block valueBlock, Block noiseScaleBlock, int position) | ||
{ | ||
double value = unscaledDecimalToBigInteger( | ||
unscaledDecimal(LONG_DECIMAL_TYPE.getSlice(valueBlock, position))) | ||
.doubleValue(); | ||
input(state, value, noiseScaleBlock, position); | ||
} | ||
|
||
public static void inputDouble(NoisyCountAndSumState state, Block valueBlock, Block noiseScaleBlock, int position) | ||
{ | ||
double value = DOUBLE.getDouble(valueBlock, position); | ||
input(state, value, noiseScaleBlock, position); | ||
} | ||
|
||
public static void inputReal(NoisyCountAndSumState state, Block valueBlock, Block noiseScaleBlock, int position) | ||
{ | ||
double value = intBitsToFloat((int) REAL.getLong(valueBlock, position)); | ||
input(state, value, noiseScaleBlock, position); | ||
} | ||
|
||
public static void inputBigInt(NoisyCountAndSumState state, Block valueBlock, Block noiseScaleBlock, int position) | ||
{ | ||
double value = BIGINT.getLong(valueBlock, position); | ||
input(state, value, noiseScaleBlock, position); | ||
} | ||
|
||
public static void inputInteger(NoisyCountAndSumState state, Block valueBlock, Block noiseScaleBlock, int position) | ||
{ | ||
double value = INTEGER.getLong(valueBlock, position); | ||
input(state, value, noiseScaleBlock, position); | ||
} | ||
|
||
public static void inputSmallInt(NoisyCountAndSumState state, Block valueBlock, Block noiseScaleBlock, int position) | ||
{ | ||
double value = SMALLINT.getLong(valueBlock, position); | ||
input(state, value, noiseScaleBlock, position); | ||
} | ||
|
||
public static void inputTinyInt(NoisyCountAndSumState state, Block valueBlock, Block noiseScaleBlock, int position) | ||
{ | ||
double value = TINYINT.getLong(valueBlock, position); | ||
input(state, value, noiseScaleBlock, position); | ||
} | ||
|
||
private static void input(NoisyCountAndSumState state, double value, Block noiseScaleBlock, int position) | ||
{ | ||
double noiseScale = DOUBLE.getDouble(noiseScaleBlock, position); | ||
|
||
updateState(state, value, noiseScale, null, null, null); | ||
} | ||
|
||
public static void combine(NoisyCountAndSumState state, NoisyCountAndSumState otherState) | ||
{ | ||
combineStates(state, otherState); | ||
} | ||
|
||
public static void output(NoisyCountAndSumState state, BlockBuilder out) | ||
{ | ||
writeNoisyAvgOutput(state, out); | ||
} | ||
} |
Oops, something went wrong.