Skip to content

Commit

Permalink
Recover parameters properly
Browse files Browse the repository at this point in the history
  • Loading branch information
csala committed Jul 9, 2020
1 parent 5503b16 commit 96395e8
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 22 deletions.
11 changes: 11 additions & 0 deletions sdv/metadata/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,3 +386,14 @@ def to_json(self, path):
"""
with open(path, 'w') as out_file:
json.dump(self._metadata, out_file, indent=4)

@classmethod
def from_json(cls, path):
"""Load a Table from a JSON
Args:
path (str):
Path of the JSON file to load
"""
with open(path, 'r') as in_file:
return cls(json.load(in_file))
61 changes: 47 additions & 14 deletions sdv/tabular/copulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ class GaussianCopula(BaseTabularModel):
``categorical_fuzzy``.
"""

DISTRIBUTION = copulas.univariate.GaussianUnivariate
DEFAULT_DISTRIBUTION = copulas.univariate.Univariate
_distribution = None
_categorical_transformer = None
_model = None

HYPERPARAMETERS = {
Expand Down Expand Up @@ -51,6 +52,7 @@ class GaussianCopula(BaseTabularModel):
]
}
}
DEFAULT_TRANSFORMER = 'one_hot_encoding'
CATEGORICAL_TRANSFORMERS = {
'categorical': rdt.transformers.CategoricalTransformer(fuzzy=False),
'categorical_fuzzy': rdt.transformers.CategoricalTransformer(fuzzy=True),
Expand All @@ -61,21 +63,40 @@ class GaussianCopula(BaseTabularModel):
'O': rdt.transformers.OneHotEncodingTransformer
}

def __init__(self, distribution=None, categorical_transformer='categorical',
*args, **kwargs):
def __init__(self, distribution=None, categorical_transformer=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self._distribution = distribution or self.DISTRIBUTION

if self._metadata is not None and 'model_kwargs' in self._metadata._metadata:
model_kwargs = self._metadata._metadata['model_kwargs']
if distribution is None:
distribution = model_kwargs['distribution']

if categorical_transformer is None:
categorical_transformer = model_kwargs['categorical_transformer']

self._size = model_kwargs['_size']

self._distribution = distribution or self.DEFAULT_DISTRIBUTION

categorical_transformer = categorical_transformer or self.DEFAULT_TRANSFORMER
self._categorical_transformer = categorical_transformer

self.TRANSFORMER_TEMPLATES['O'] = self.CATEGORICAL_TRANSFORMERS[categorical_transformer]

def _update_metadata(self):
parameters = self._model.to_dict()
univariates = parameters['univariates']
columns = parameters['columns']

fields = self._metadata.get_fields()
for field_name, univariate in zip(columns, univariates):
field_meta = fields[field_name]
field_meta['distribution'] = univariate['type']
distributions = {}
for column, univariate in zip(columns, univariates):
distributions[column] = univariate['type']

self._metadata._metadata['model_kwargs'] = {
'distribution': distributions,
'categorical_transformer': self._categorical_transformer,
'_size': self._size
}

def _fit(self, data):
"""Fit the model to the table.
Expand All @@ -86,6 +107,7 @@ def _fit(self, data):
"""
self._model = copulas.multivariate.GaussianMultivariate(distribution=self._distribution)
self._model.fit(data)
self._update_metadata()

def _sample(self, size):
"""Sample ``size`` rows from the model.
Expand All @@ -100,16 +122,24 @@ def _sample(self, size):
"""
return self._model.sample(size)

def get_parameters(self):
def get_parameters(self, flatten=False):
"""Get copula model parameters.
Compute model ``covariance`` and ``distribution.std``
before it returns the flatten dict.
Args:
flatten (bool):
Whether to flatten the parameters or not before
returning them.
Returns:
dict:
Copula flatten parameters.
Copula parameters.
"""
if not flatten:
return self._model.to_dict()

values = list()
triangle = np.tril(self._model.covariance)

Expand Down Expand Up @@ -195,7 +225,7 @@ def _unflatten_gaussian_copula(self, model_parameters):

return model_parameters

def set_parameters(self, parameters):
def set_parameters(self, parameters, unflatten=False):
"""Set copula model parameters.
Add additional keys after unflatte the parameters
Expand All @@ -204,10 +234,13 @@ def set_parameters(self, parameters):
Args:
dict:
Copula flatten parameters.
unflatten (bool):
Whether the parameters need to be unflattened or not.
"""
parameters = unflatten_dict(parameters)
parameters.setdefault('distribution', self.distribution)
if unflatten:
parameters = unflatten_dict(parameters)
parameters.setdefault('distribution', self._distribution)

parameters = self._unflatten_gaussian_copula(parameters)
parameters = self._unflatten_gaussian_copula(parameters)

self._model = copulas.multivariate.GaussianMultivariate.from_dict(parameters)
24 changes: 16 additions & 8 deletions tests/integration/tabular/test_copulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,14 @@ def test_gaussian_copula():
)
gc.fit(users)

sampled = gc.sample()
parameters = gc.get_parameters()
new_gc = GaussianCopula(
table_metadata=gc.get_metadata(),
categorical_transformer='one_hot_encoding',
)
new_gc.set_parameters(parameters)

sampled = new_gc.sample()

# test shape is right
assert sampled.shape == users.shape
Expand All @@ -38,11 +45,12 @@ def test_gaussian_copula():
# country codes have been replaced with new ones
assert set(sampled.country.unique()) != set(users.country.unique())

assert gc.get_metadata().to_dict() == {
'fields': {
'user_id': {'type': 'id', 'subtype': 'integer'},
'country': {'type': 'categorical'},
'gender': {'type': 'categorical'},
'age': {'type': 'numerical', 'subtype': 'integer'}
}
metadata = gc.get_metadata().to_dict()
assert metadata['fields'] == {
'user_id': {'type': 'id', 'subtype': 'integer'},
'country': {'type': 'categorical'},
'gender': {'type': 'categorical'},
'age': {'type': 'numerical', 'subtype': 'integer'}
}

assert 'model_kwargs' in metadata

0 comments on commit 96395e8

Please sign in to comment.