Skip to content

Commit

Permalink
Moved example queries out of default values to model configs
Browse files Browse the repository at this point in the history
  • Loading branch information
damienbfs committed Nov 13, 2024
1 parent 96e6f81 commit c157a3c
Show file tree
Hide file tree
Showing 10 changed files with 426 additions and 644 deletions.
11 changes: 6 additions & 5 deletions core/lomas_core/error_handler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Type

from fastapi import FastAPI, Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from pymongo.errors import WriteConcernError

Expand Down Expand Up @@ -94,7 +95,7 @@ async def invalid_query_exception_handler(
LOG.info(f"InvalidQueryException raised: {exc.error_message}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=InvalidQueryExceptionModel(message=exc.error_message),
content=jsonable_encoder(InvalidQueryExceptionModel(message=exc.error_message)),
)

@app.exception_handler(ExternalLibraryException)
Expand All @@ -104,9 +105,9 @@ async def external_library_exception_handler(
LOG.info(f"ExternalLibraryException raised: {exc.error_message}")
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content=ExternalLibraryExceptionModel(
content=jsonable_encoder(ExternalLibraryExceptionModel(
message=exc.error_message, library=exc.library
),
)),
)

@app.exception_handler(UnauthorizedAccessException)
Expand All @@ -116,7 +117,7 @@ async def unauthorized_access_exception_handler(
LOG.info(f"UnauthorizedAccessException raised: {exc.error_message}")
return JSONResponse(
status_code=status.HTTP_403_FORBIDDEN,
content=UnauthorizedAccessExceptionModel(message=exc.error_message),
content=jsonable_encoder(UnauthorizedAccessExceptionModel(message=exc.error_message)),
)

@app.exception_handler(InternalServerException)
Expand All @@ -126,7 +127,7 @@ async def internal_server_exception_handler(
LOG.info(f"InternalServerException raised: {exc.error_message}")
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=InternalServerExceptionModel(),
content=jsonable_encoder(InternalServerExceptionModel()),
)


Expand Down
104 changes: 93 additions & 11 deletions core/lomas_core/models/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,18 @@
SSynthMarginalSynthesizer,
)
from lomas_core.error_handler import InternalServerException
from lomas_core.models.requests_examples import (
example_diffprivlib,
example_dummy_diffprivlib,
example_dummy_opendp,
example_dummy_smartnoise_sql,
example_dummy_smartnoise_synth_query,
example_opendp,
example_smartnoise_sql,
example_smartnoise_sql_cost,
example_smartnoise_synth_cost,
example_smartnoise_synth_query,
)


class LomasRequestModel(BaseModel):
Expand All @@ -29,7 +41,9 @@ class GetDummyDataset(LomasRequestModel):
"""Model input to get a dummy dataset."""

dummy_nb_rows: int = Field(..., gt=0)
"""The number of dummy rows to generate."""
dummy_seed: int
"""The seed for the random generation of the dummy dataset."""


class QueryModel(LomasRequestModel):
Expand Down Expand Up @@ -58,18 +72,22 @@ class DummyQueryModel(QueryModel):
class SmartnoiseSQLRequestModel(LomasRequestModel):
"""Base input model for a smarnoise-sql request."""

model_config = ConfigDict(
json_schema_extra={"examples": [example_smartnoise_sql_cost]}
)

query_str: str
"""The SQL query to execute.
NOTE: the table name is \"df\", the query must end with \"FROM df\"
"""
epsilon: float = Field(..., gt=0)
"""Privacy parameter (e.g., 0.1)."""
delta: float = Field(..., gt=0)
delta: float = Field(..., ge=0)
"""Privacy parameter (e.g., 1e-5)."""
mechanisms: dict = {}
mechanisms: dict
"""
Dictionary of mechanisms for the query (default: {}).
Dictionary of mechanisms for the query.
See Smartnoise-SQL mechanisms documentation at
https://docs.smartnoise.org/sql/advanced.html#overriding-mechanisms.
Expand All @@ -79,6 +97,8 @@ class SmartnoiseSQLRequestModel(LomasRequestModel):
class SmartnoiseSQLQueryModel(SmartnoiseSQLRequestModel, QueryModel):
"""Base input model for a smartnoise-sql query."""

model_config = ConfigDict(json_schema_extra={"examples": [example_smartnoise_sql]})

postprocess: bool
"""
Whether to postprocess the query results (default: True).
Expand All @@ -91,77 +111,139 @@ class SmartnoiseSQLQueryModel(SmartnoiseSQLRequestModel, QueryModel):
class SmartnoiseSQLDummyQueryModel(SmartnoiseSQLQueryModel, DummyQueryModel):
"""Input model for a smartnoise-sql query on a dummy dataset."""

model_config = ConfigDict(
json_schema_extra={"examples": [example_dummy_smartnoise_sql]}
)


# SmartnoiseSynth
# ----------------------------------------------------------------------------
class SmartnoiseSynthRequestModel(LomasRequestModel):
"""Base input model for a SmartnoiseSynth request."""

model_config = ConfigDict(
json_schema_extra={"examples": [example_smartnoise_synth_cost]}
)

synth_name: Union[SSynthMarginalSynthesizer, SSynthGanSynthesizer]
"""Name of the synthesizer model to use."""
epsilon: float = Field(..., gt=0)
delta: Optional[float] = None
"""Privacy parameter (e.g., 0.1)."""
delta: float = Field(..., ge=0)
"""Privacy parameter (e.g., 1e-5)."""
select_cols: List
"""List of columns to select."""
synth_params: dict
"""
Keyword arguments to pass to the synthesizer constructor.
See https://docs.smartnoise.org/synth/synthesizers/index.html#, provide
all parameters of the model except `epsilon` and `delta`.
"""
nullable: bool
"""True if some data cells may be null."""
constraints: str
"""
Dictionnary for custom table transformer constraints.
Column that are not specified will be inferred based on metadata.
"""


class SmartnoiseSynthQueryModel(SmartnoiseSynthRequestModel, QueryModel):
"""Base input model for a smarnoise-synth query."""

model_config = ConfigDict(
json_schema_extra={"examples": [example_smartnoise_synth_query]}
)

return_model: bool
"""True to get Synthesizer model, False to get samples."""
condition: str
"""Sampling condition in `model.sample` (only relevant if return_model is False)."""
nb_samples: int
"""Number of samples to generate.
(only relevant if return_model is False)
"""


class SmartnoiseSynthDummyQueryModel(SmartnoiseSynthQueryModel, DummyQueryModel):
"""Input model for a smarnoise-synth query on a dummy dataset."""

# Same as normal query.
return_model: bool
condition: str
nb_samples: int
model_config = ConfigDict(
json_schema_extra={"examples": [example_dummy_smartnoise_synth_query]}
)


# OpenDP
# ----------------------------------------------------------------------------
class OpenDPRequestModel(LomasRequestModel):
"""Base input model for an opendp request."""

model_config = ConfigDict(use_attribute_docstrings=True)
model_config = ConfigDict(
use_attribute_docstrings=True, json_schema_extra={"examples": [example_opendp]}
)

opendp_json: str
"""Opendp pipeline."""
fixed_delta: Optional[float] = None
"""The OpenDP pipeline for the query."""
fixed_delta: Optional[float] = Field(..., ge=0)
"""
If the pipeline measurement is of type "ZeroConcentratedDivergence".
(e.g. with "make_gaussian") then it is converted to "SmoothedMaxDivergence"
with "make_zCDP_to_approxDP" (see "opendp measurements documentation at
https://docs.opendp.org/en/stable/api/python/opendp.combinators.html#opendp.combinators.make_zCDP_to_approxDP). # noqa # pylint: disable=C0301
In that case a "fixed_delta" must be provided by the user.
"""


class OpenDPQueryModel(OpenDPRequestModel, QueryModel):
"""Base input model for an opendp query."""

model_config = ConfigDict(json_schema_extra={"examples": [example_opendp]})


class OpenDPDummyQueryModel(OpenDPRequestModel, DummyQueryModel):
"""Input model for an opendp query on a dummy dataset."""

model_config = ConfigDict(json_schema_extra={"examples": [example_dummy_opendp]})


# DiffPrivLib
# ----------------------------------------------------------------------------
class DiffPrivLibRequestModel(LomasRequestModel):
"""Base input model for a diffprivlib request."""

model_config = ConfigDict(json_schema_extra={"examples": [example_diffprivlib]})

diffprivlib_json: str
"""The DiffPrivLib pipeline for the query (See diffprivlib_logger package.)."""
feature_columns: list
"""The list of feature columns to train."""
target_columns: Optional[list]
"""The list of target columns to predict."""
test_size: float = Field(..., gt=0.0, lt=1.0)
"""The proportion of the test set."""
test_train_split_seed: int
"""The seed for the random train/test split."""
imputer_strategy: str
"""The imputation strategy."""


class DiffPrivLibQueryModel(DiffPrivLibRequestModel, QueryModel):
"""Base input model for a diffprivlib query."""

model_config = ConfigDict(json_schema_extra={"examples": [example_diffprivlib]})


class DiffPrivLibDummyQueryModel(DiffPrivLibQueryModel, DummyQueryModel):
"""Input model for a DiffPrivLib query on a dummy dataset."""

model_config = ConfigDict(
json_schema_extra={"examples": [example_dummy_diffprivlib]}
)


# Utils
# ----------------------------------------------------------------------------
Expand Down
26 changes: 25 additions & 1 deletion core/lomas_core/models/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,23 +37,31 @@ class SpentBudgetResponse(ResponseModel):
"""Model for responses to spent budget queries."""

total_spent_epsilon: float
"""The total spent epsilon privacy loss budget."""
total_spent_delta: float
"""The total spent delta privacy loss budget."""


class RemainingBudgetResponse(ResponseModel):
"""Model for responses to remaining budget queries."""

remaining_epsilon: float
"""The remaining epsilon privacy loss budget."""
remaining_delta: float
"""The remaining delta privacy loss budget."""


class DummyDsResponse(ResponseModel):
"""Model for responses to dummy dataset requests."""

model_config = ConfigDict(arbitrary_types_allowed=True)

dtypes: Dict[str, str]
"""The dummy_df column data types."""
datetime_columns: List[str]
"""The list of columns with datetime type."""
dummy_df: Annotated[pd.DataFrame, PlainSerializer(dataframe_to_dict)]
"""The dummy dataframe."""

@field_validator("dummy_df", mode="before")
@classmethod
Expand Down Expand Up @@ -86,6 +94,7 @@ class CostResponse(ResponseModel):
"""Model for responses to cost estimation requests or queries."""

model_config = ConfigDict(use_attribute_docstrings=True)

epsilon: float
"""The epsilon cost of the query."""
delta: float
Expand All @@ -101,57 +110,72 @@ class DiffPrivLibQueryResult(BaseModel):
"""Model for diffprivlib query result."""

model_config = ConfigDict(arbitrary_types_allowed=True)

res_type: Literal[DPLibraries.DIFFPRIVLIB] = DPLibraries.DIFFPRIVLIB
"""Result type description."""
score: float
"""The trained model score."""
model: Annotated[
DiffprivlibMixin,
PlainSerializer(serialize_model),
PlainValidator(deserialize_model),
]
"""The trained model."""


# SmartnoiseSQL
class SmartnoiseSQLQueryResult(BaseModel):
"""Type for smartnoise_sql result type."""

model_config = ConfigDict(arbitrary_types_allowed=True)

res_type: Literal[DPLibraries.SMARTNOISE_SQL] = DPLibraries.SMARTNOISE_SQL
"""Result type description."""
df: Annotated[
pd.DataFrame,
PlainSerializer(dataframe_to_dict),
PlainValidator(dataframe_from_dict),
]
"""Dataframe containing the query result."""


# SmartnoiseSynth
class SmartnoiseSynthModel(BaseModel):
"""Type for smartnoise_synth result when it is a pickled model."""

model_config = ConfigDict(arbitrary_types_allowed=True)

res_type: Literal[DPLibraries.SMARTNOISE_SYNTH] = DPLibraries.SMARTNOISE_SYNTH
"""Result type description."""
model: Annotated[
Synthesizer, PlainSerializer(serialize_model), PlainValidator(deserialize_model)
]
"""Synthetic data generator model."""


class SmartnoiseSynthSamples(BaseModel):
"""Type for smartnoise_synth result when it is a dataframe of samples."""

model_config = ConfigDict(arbitrary_types_allowed=True)

res_type: Literal["sn_synth_samples"] = "sn_synth_samples"
"""Result type description."""
df_samples: Annotated[
pd.DataFrame,
PlainSerializer(dataframe_to_dict),
PlainValidator(dataframe_from_dict),
]
"""Dataframe containing the generated synthetic samples."""


# OpenDP
class OpenDPQueryResult(BaseModel):
"""Type for opendp result."""

res_type: Literal[DPLibraries.OPENDP] = DPLibraries.OPENDP
"""Result type description."""
value: Union[int, float, List[Union[int, float]]]
"""The result value of the query."""


# Response object
Expand All @@ -173,4 +197,4 @@ class QueryResponse(CostResponse):
QueryResultTypeAlias,
Discriminator("res_type"),
]
"""The query result."""
"""The query result object."""
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,9 @@ def _get_fit_model(
Returns:
Synthesizer: Fitted synthesizer model
"""
if query_json.delta is not None:
if query_json.synth_name != SSynthMarginalSynthesizer.MWEM:
# delta parameter is ignored for this synthesizer.
# TODO improve on this....
query_json.synth_params["delta"] = query_json.delta

if query_json.synth_name == SSynthGanSynthesizer.DP_CTGAN:
Expand Down
Loading

0 comments on commit c157a3c

Please sign in to comment.