From 45ebc1617ab82d3c801bbd697f6ed4317db70d65 Mon Sep 17 00:00:00 2001
From: Carles Sala
Date: Fri, 26 Jun 2020 20:22:35 +0200
Subject: [PATCH 01/33] =?UTF-8?q?Bump=20version:=200.3.3=20=E2=86=92=200.3?=
 =?UTF-8?q?.4.dev0?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 sdv/__init__.py | 2 +-
 setup.cfg       | 2 +-
 setup.py        | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/sdv/__init__.py b/sdv/__init__.py
index 5213dd4ca..8caed53d9 100644
--- a/sdv/__init__.py
+++ b/sdv/__init__.py
@@ -6,7 +6,7 @@
 
 __author__ = """MIT Data To AI Lab"""
 __email__ = 'dailabmit@gmail.com'
-__version__ = '0.3.3'
+__version__ = '0.3.4.dev0'
 
 from sdv.demo import get_available_demos, load_demo
 from sdv.metadata import Metadata
diff --git a/setup.cfg b/setup.cfg
index 24389c418..6a2ef34f0 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.3.3
+current_version = 0.3.4.dev0
 commit = True
 tag = True
 parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\.(?P<release>[a-z]+)(?P<candidate>\d+))?
diff --git a/setup.py b/setup.py
index 0c8189cdf..81283021b 100644
--- a/setup.py
+++ b/setup.py
@@ -93,6 +93,6 @@
     test_suite='tests',
     tests_require=tests_require,
     url='https://github.com/sdv-dev/SDV',
-    version='0.3.3',
+    version='0.3.4.dev0',
     zip_safe=False,
 )

From 2b52f86201ceee8785ed9353f212a27dd67ca556 Mon Sep 17 00:00:00 2001
From: Carles Sala
Date: Wed, 1 Jul 2020 16:42:12 +0200
Subject: [PATCH 02/33] WIP

---
 sdv/metadata.py |  2 +-
 sdv/modeler.py  |  6 +++-
 sdv/sampler.py  | 80 ++++++++++++++++++++++++++++++++++++++++---------
 sdv/sdv.py      | 13 ++++----
 4 files changed, 80 insertions(+), 21 deletions(-)

diff --git a/sdv/metadata.py b/sdv/metadata.py
index dbf01ca60..94f3e8e55 100644
--- a/sdv/metadata.py
+++ b/sdv/metadata.py
@@ -624,7 +624,7 @@ def validate(self, tables=None):
             self._validate_table(table_name, table_meta, table)
 
         self._validate_circular_relationships(table_name)
-        self._validate_parents(table_name)
+        # self._validate_parents(table_name)
 
     def _check_field(self, table, field, exists=False):
         """Validate the existence of the table and existence (or not) of field."""
diff --git a/sdv/modeler.py b/sdv/modeler.py
index aecfbd91a..7e766deb6 100644
--- a/sdv/modeler.py
+++ b/sdv/modeler.py
@@ -27,6 +27,7 @@ def __init__(self, metadata, model=GaussianCopula, model_kwargs=None):
         self.metadata = metadata
         self.model = model
         self.model_kwargs = dict() if model_kwargs is None else model_kwargs
+        self.table_sizes = dict()
 
     def _get_extension(self, child_name, child_table, foreign_key):
         """Generate list of extension for child tables.
@@ -78,7 +79,7 @@ def cpa(self, table_name, tables, foreign_key=None):
             table_name (str):
                 Name of the table to model.
             tables (dict):
-                Dict of tables that have been already modeled.
+                Dict of original tables.
             foreign_key (str):
                 Name of the foreign key that references this table. Used only when applying
                 CPA on a child table.
@@ -87,6 +88,8 @@ def cpa(self, table_name, tables, foreign_key=None):
             pandas.DataFrame:
                 table data with the extensions created while modeling its children.
""" + LOGGER.info('Modeling %s', table_name) if tables: @@ -94,6 +96,8 @@ def cpa(self, table_name, tables, foreign_key=None): else: table = self.metadata.load_table(table_name) + self.table_sizes[table_name] = len(table) + extended = self.metadata.transform(table_name, table) primary_key = self.metadata.get_primary_key(table_name) diff --git a/sdv/sampler.py b/sdv/sampler.py index b67f3e87a..8e0639b2b 100644 --- a/sdv/sampler.py +++ b/sdv/sampler.py @@ -23,13 +23,14 @@ class Sampler: primary_key = None remaining_primary_key = None - def __init__(self, metadata, models, model, model_kwargs): + def __init__(self, metadata, models, model, model_kwargs, table_sizes): self.metadata = metadata self.models = models self.primary_key = dict() self.remaining_primary_key = dict() self.model = model self.model_kwargs = model_kwargs + self.table_sizes = table_sizes def _reset_primary_keys_generators(self): """Reset the primary key generators.""" @@ -161,9 +162,9 @@ def _sample_children(self, table_name, sampled): table_rows = sampled[table_name] for child_name in self.metadata.get_children(table_name): for _, row in table_rows.iterrows(): - self._sample_table(child_name, table_name, row, sampled) + self._sample_child_rows(child_name, table_name, row, sampled) - def _sample_table(self, table_name, parent_name, parent_row, sampled): + def _sample_child_rows(self, table_name, parent_name, parent_row, sampled): extension = self._get_extension(parent_row, table_name) model = self.model(**self.model_kwargs) @@ -184,7 +185,49 @@ def _sample_table(self, table_name, parent_name, parent_row, sampled): self._sample_children(table_name, sampled) - def sample(self, table_name, num_rows, reset_primary_keys=False, sample_children=True): + def _get_pdfs(self, parent_rows, child_name): + pdfs = dict() + for parent_id, row in parent_rows.iterrows(): + extension = self._get_extension(row, child_name) + model = self.model(**self.model_kwargs) + model.set_parameters(extension) + pdfs[parent_id] = model.model.probability_density + + return pdfs + + def _find_parent_id(self, row, pdfs): + likelihoods = dict() + for parent_id, pdf in pdfs.items(): + try: + likelihoods[parent_id] = pdf(row) + except np.linalg.LinAlgError: + likelihoods[parent_id] = None + + likelihoods = pd.Series(likelihoods) + likelihoods = likelihoods.fillna(likelihoods.mean()).fillna(0) + + # weights = likelihoods.values / likelihoods.sum() + + return likelihoods.idxmax() + + def _find_parent_ids(self, table_name, parent_name, sampled_rows, sampled=None): + ratio = self.table_sizes[parent_name] / self.table_sizes[table_name] + num_parent_rows = max(int(round(len(sampled_rows) * ratio)), 1) + parent_model = self.models[parent_name] + parent_rows = self._sample_rows(parent_model, num_parent_rows, parent_name) + + foreign_key = self.metadata.get_foreign_key(parent_name, table_name) + primary_key = self.metadata.get_primary_key(parent_name) + pdfs = self._get_pdfs(parent_rows.set_index(primary_key), table_name) + + parent_ids = list() + for _, row in sampled_rows.iterrows(): + parent_ids.append(self._find_parent_id(row, pdfs)) + + return parent_ids + + def sample(self, table_name, num_rows=None, reset_primary_keys=False, + sample_children=True, sample_parents=True): """Sample one table. Child tables will be sampled when ``sample_children`` is ``True``. @@ -195,11 +238,14 @@ def sample(self, table_name, num_rows, reset_primary_keys=False, sample_children table_name (str): Table name to sample. num_rows (int): - Amount of rows to sample. 
+ Amount of rows to sample. If ``None``, sample the same number of rows + as there were in the original table. reset_primary_keys (bool): Whether or not reset the primary keys generators. Defaults to ``False``. sample_children (bool): Whether or not sample child tables. Defaults to ``True``. + sample_parents (bool): + Whether or not sample parent tables. Defaults to ``True``. Returns: dict or pandas.DataFrame: @@ -211,16 +257,21 @@ def sample(self, table_name, num_rows, reset_primary_keys=False, sample_children if reset_primary_keys: self._reset_primary_keys_generators() - model = self.models[table_name] + if num_rows is None: + num_rows = self.table_sizes[table_name] + model = self.models[table_name] sampled_rows = self._sample_rows(model, num_rows, table_name) parents = self.metadata.get_parents(table_name) - if parents: - parent_name = list(parents)[0] - foreign_key = self.metadata.get_foreign_key(parent_name, table_name) - parent_id = self._get_primary_keys(parent_name, 1)[1][0] - sampled_rows[foreign_key] = parent_id + if parents and sample_parents: + data_fields = list(self.metadata.get_dtypes(table_name, ids=False).keys()) + table_data = sampled_rows[data_fields] + for parent_name in parents: + parent_ids = self._find_parent_ids(table_name, parent_name, table_data) + foreign_key = self.metadata.get_foreign_key(parent_name, table_name) + # parent_ids = self._get_primary_keys(parent_name, num_rows)[1] + sampled_rows[foreign_key] = parent_ids if sample_children: sampled_data = { @@ -237,7 +288,7 @@ def sample(self, table_name, num_rows, reset_primary_keys=False, sample_children else: return self._transform_synthesized_rows(sampled_rows, table_name) - def sample_all(self, num_rows=5, reset_primary_keys=False): + def sample_all(self, num_rows=None, reset_primary_keys=False): """Samples the entire dataset. ``sample_all`` returns a dictionary with all the tables of the dataset sampled. @@ -249,9 +300,10 @@ def sample_all(self, num_rows=5, reset_primary_keys=False): Args: num_rows (int): - Number of rows to be sampled on the parent tables. + Number of rows to be sampled on the first parent tables. If ``None``, + sample the same number of rows as in the original tables. reset_primary_keys (bool): - Wheter or not reset the primary key generators. + Whether or not reset the primary key generators. Returns: dict: diff --git a/sdv/sdv.py b/sdv/sdv.py index f309801a0..4f05282b7 100644 --- a/sdv/sdv.py +++ b/sdv/sdv.py @@ -67,16 +67,18 @@ def fit(self, metadata, tables=None, root_path=None): self.modeler = Modeler(self.metadata, self.model, self.model_kwargs) self.modeler.model_database(tables) - self.sampler = Sampler(self.metadata, self.modeler.models, self.model, self.model_kwargs) + self.sampler = Sampler(self.metadata, self.modeler.models, self.model, + self.model_kwargs, self.modeler.table_sizes) - def sample(self, table_name, num_rows, sample_children=True, reset_primary_keys=False): + def sample(self, table_name, num_rows=None, sample_children=True, reset_primary_keys=False): """Sample ``num_rows`` rows from the indicated table. Args: table_name (str): Name of the table to sample from. num_rows (int): - Amount of rows to sample. + Amount of rows to sample. If ``None``, sample the same number of rows + as there were in the original table. sample_children (bool): Whether or not to sample children tables. Defaults to ``True``. 
            reset_primary_keys (bool):
@@ -100,12 +102,13 @@ def sample(self, table_name, num_rows=None, sample_children=True, reset_primary_keys=
             reset_primary_keys=reset_primary_keys
         )
 
-    def sample_all(self, num_rows=5, reset_primary_keys=False):
+    def sample_all(self, num_rows=None, reset_primary_keys=False):
         """Sample the entire dataset.
 
         Args:
             num_rows (int):
-                Amount of rows to sample. Defaults to ``5``.
+                Number of rows to be sampled on the first parent tables. If ``None``,
+                sample the same number of rows as in the original tables.
             reset_primary_keys (bool):
                 Whether or not reset the primary key generators. Defaults to ``False``.

From 8eae9c7b4069b3bfc01396c8d72c5c03a41801b7 Mon Sep 17 00:00:00 2001
From: Carles Sala
Date: Wed, 1 Jul 2020 23:44:41 +0200
Subject: [PATCH 03/33] Multi-parent support

---
 sdv/models/copulas.py |  47 ++++++++++-------
 sdv/models/utils.py   |  10 ++--
 sdv/sampler.py        | 115 +++++++++++++++++++++++-----------
 setup.py              |   2 +-
 4 files changed, 102 insertions(+), 72 deletions(-)

diff --git a/sdv/models/copulas.py b/sdv/models/copulas.py
index 61c72bd9f..ff1665d50 100644
--- a/sdv/models/copulas.py
+++ b/sdv/models/copulas.py
@@ -1,5 +1,7 @@
 import numpy as np
-from copulas import multivariate, univariate
+from copulas import EPSILON
+from copulas.multivariate import GaussianMultivariate
+from copulas.univariate import GaussianUnivariate
 
 from sdv.models.base import SDVModel
 from sdv.models.utils import (
@@ -29,7 +31,7 @@ class GaussianCopula(SDVModel):
         4    1.925887
     """
 
-    DISTRIBUTION = univariate.GaussianUnivariate
+    DISTRIBUTION = GaussianUnivariate
     distribution = None
     model = None
 
@@ -46,7 +48,7 @@ def fit(self, table_data):
                 Data to be fitted.
         """
         table_data = impute(table_data)
-        self.model = multivariate.GaussianMultivariate(distribution=self.distribution)
+        self.model = GaussianMultivariate(distribution=self.distribution)
         self.model.fit(table_data)
 
     def sample(self, num_samples):
@@ -79,11 +81,20 @@ def get_parameters(self):
                 values.append(row[:index + 1])
 
         self.model.covariance = np.array(values)
-        for distribution in self.model.distribs.values():
-            if distribution.std is not None:
-                distribution.std = np.log(distribution.std)
 
-        return flatten_dict(self.model.to_dict())
+        params = self.model.to_dict()
+        univariates = dict()
+        for name, univariate in zip(params.pop('columns'), params['univariates']):
+            univariates[name] = univariate
+            if 'scale' in univariate:
+                scale = univariate['scale']
+                if scale == 0:
+                    scale = EPSILON
+
+                univariate['scale'] = np.log(scale)
+
+        params['univariates'] = univariates
+
+        return flatten_dict(params)
 
     def _prepare_sampled_covariance(self, covariance):
         """Prepare a covariance matrix.
@@ -126,15 +137,20 @@ def _unflatten_gaussian_copula(self, model_parameters):
                 Model parameters ready to recreate the model.
""" - distribution_kwargs = { - 'fitted': True, + univariate_kwargs = { 'type': model_parameters['distribution'] } - distribs = model_parameters['distribs'] - for distribution in distribs.values(): - distribution.update(distribution_kwargs) - distribution['std'] = np.exp(distribution['std']) + columns = list() + univariates = list() + for column, univariate in model_parameters['univariates'].items(): + columns.append(column) + univariate.update(univariate_kwargs) + univariate['scale'] = np.exp(univariate['scale']) + univariates.append(univariate) + + model_parameters['univariates'] = univariates + model_parameters['columns'] = columns covariance = model_parameters['covariance'] model_parameters['covariance'] = self._prepare_sampled_covariance(covariance) @@ -156,8 +172,5 @@ def set_parameters(self, parameters): parameters.setdefault('distribution', self.distribution) parameters = self._unflatten_gaussian_copula(parameters) - for param in parameters['distribs'].values(): - param.setdefault('type', self.distribution) - param.setdefault('fitted', True) - self.model = multivariate.GaussianMultivariate.from_dict(parameters) + self.model = GaussianMultivariate.from_dict(parameters) diff --git a/sdv/models/utils.py b/sdv/models/utils.py index a839ecfb5..241a20048 100644 --- a/sdv/models/utils.py +++ b/sdv/models/utils.py @@ -20,11 +20,15 @@ def flatten_array(nested, prefix=''): for index in range(len(nested)): prefix_key = '__'.join([prefix, str(index)]) if len(prefix) else str(index) - if isinstance(nested[index], (list, np.ndarray)): - result.update(flatten_array(nested[index], prefix=prefix_key)) + value = nested[index] + if isinstance(value, (list, np.ndarray)): + result.update(flatten_array(value, prefix=prefix_key)) + + elif isinstance(value, dict): + result.update(flatten_dict(value, prefix=prefix_key)) else: - result[prefix_key] = nested[index] + result[prefix_key] = value return result diff --git a/sdv/sampler.py b/sdv/sampler.py index 8e0639b2b..706d8c064 100644 --- a/sdv/sampler.py +++ b/sdv/sampler.py @@ -37,24 +37,33 @@ def _reset_primary_keys_generators(self): self.primary_key = dict() self.remaining_primary_key = dict() - def _transform_synthesized_rows(self, synthesized, table_name): + def _finalize(self, sampled_data): """Reverse transform synthetized data. Args: - synthesized (pandas.DataFrame): - Generated data from model - table_name (str): - Name of the table. + sampled_data (dict): + Generated data Return: pandas.DataFrame: Formatted synthesized data. """ - reversed_data = self.metadata.reverse_transform(table_name, synthesized) + final_data = dict() + for table_name, table_rows in sampled_data.items(): + parents = self.metadata.get_parents(table_name) + if parents: + for parent_name in parents: + parent_ids = self._find_parent_ids(table_name, parent_name, sampled_data) + foreign_key = self.metadata.get_foreign_key(parent_name, table_name) + table_rows[foreign_key] = parent_ids + + reversed_data = self.metadata.reverse_transform(table_name, table_rows) + + fields = self.metadata.get_fields(table_name) - fields = self.metadata.get_fields(table_name) + final_data[table_name] = reversed_data[list(fields.keys())] - return reversed_data[list(fields.keys())] + return final_data def _get_primary_keys(self, table_name, num_rows): """Return the primary key and amount of values for the requested table. 
@@ -119,7 +128,7 @@ def _get_primary_keys(self, table_name, num_rows):
 
         return primary_key, primary_key_values
 
-    def _get_extension(self, parent_row, table_name):
+    def _extract_parameters(self, parent_row, table_name):
         """Get the params from a generated parent row.
 
         Args:
@@ -128,7 +137,6 @@ def _get_extension(self, parent_row, table_name):
             table_name (str):
                 Name of the table to make the model for.
         """
-
         prefix = '__{}__'.format(table_name)
         keys = [key for key in parent_row.keys() if key.startswith(prefix)]
         new_keys = {key: key[len(prefix):] for key in keys}
@@ -165,64 +173,84 @@ def _sample_children(self, table_name, sampled):
             self._sample_child_rows(child_name, table_name, row, sampled)
 
     def _sample_child_rows(self, table_name, parent_name, parent_row, sampled):
-        extension = self._get_extension(parent_row, table_name)
+        parameters = self._extract_parameters(parent_row, table_name)
 
         model = self.model(**self.model_kwargs)
-        model.set_parameters(extension)
-        num_rows = max(round(extension['child_rows']), 0)
+        model.set_parameters(parameters)
+        num_rows = max(round(parameters['child_rows']), 0)
 
-        sampled_rows = self._sample_rows(model, num_rows, table_name)
+        table_rows = self._sample_rows(model, num_rows, table_name)
 
         parent_key = self.metadata.get_primary_key(parent_name)
         foreign_key = self.metadata.get_foreign_key(parent_name, table_name)
-        sampled_rows[foreign_key] = parent_row[parent_key]
+        table_rows[foreign_key] = parent_row[parent_key]
 
         previous = sampled.get(table_name)
         if previous is None:
-            sampled[table_name] = sampled_rows
+            sampled[table_name] = table_rows
         else:
-            sampled[table_name] = pd.concat([previous, sampled_rows]).reset_index(drop=True)
+            sampled[table_name] = pd.concat([previous, table_rows]).reset_index(drop=True)
 
         self._sample_children(table_name, sampled)
 
     def _get_pdfs(self, parent_rows, child_name):
         pdfs = dict()
         for parent_id, row in parent_rows.iterrows():
-            extension = self._get_extension(row, child_name)
+            parameters = self._extract_parameters(row, child_name)
             model = self.model(**self.model_kwargs)
-            model.set_parameters(extension)
+            model.set_parameters(parameters)
             pdfs[parent_id] = model.model.probability_density
 
         return pdfs
 
-    def _find_parent_id(self, row, pdfs):
+    def _find_parent_id(self, row, pdfs, num_rows):
         likelihoods = dict()
         for parent_id, pdf in pdfs.items():
             try:
-                likelihoods[parent_id] = pdf(row)
+                likelihoods[parent_id] = max(pdf(row), 0.0)
             except np.linalg.LinAlgError:
+                # Singular matrix
                 likelihoods[parent_id] = None
 
         likelihoods = pd.Series(likelihoods)
-        likelihoods = likelihoods.fillna(likelihoods.mean()).fillna(0)
-
-        # weights = likelihoods.values / likelihoods.sum()
-
-        return likelihoods.idxmax()
+        mean = likelihoods.mean()
+        if (likelihoods == 0).all():
+            # All rows got 0 likelihood, fallback to num_rows
+            likelihoods = num_rows
+        elif pd.isnull(mean) or mean == 0:
+            # No row got likelihood > 0, but some got a singular matrix.
+            # Fallback to num_rows on the singular matrix rows
+            likelihoods = likelihoods.fillna(num_rows)
+        else:
+            # At least one row got likelihood > 0, so fill the
+            # singular matrix rows with the mean
+            likelihoods = likelihoods.fillna(mean)
+
+        total = likelihoods.sum()
+        if total == 0:
+            # Normalize the uniform fallback so the weights sum to 1
+            weights = np.ones(len(likelihoods)) / len(likelihoods)
+        else:
+            weights = likelihoods.values / total
+
+        return np.random.choice(likelihoods.index, p=weights)
 
-    def _find_parent_ids(self, table_name, parent_name, sampled_rows, sampled=None):
-        ratio = self.table_sizes[parent_name] / self.table_sizes[table_name]
-        num_parent_rows = max(int(round(len(sampled_rows) * ratio)), 1)
-        parent_model = self.models[parent_name]
-        parent_rows = self._sample_rows(parent_model, num_parent_rows, parent_name)
+    def _find_parent_ids(self, table_name, parent_name, sampled_data):
+        table_rows = sampled_data[table_name]
+        if parent_name in sampled_data:
+            parent_rows = sampled_data[parent_name]
+        else:
+            ratio = self.table_sizes[parent_name] / self.table_sizes[table_name]
+            num_parent_rows = max(int(round(len(table_rows) * ratio)), 1)
+            parent_model = self.models[parent_name]
+            parent_rows = self._sample_rows(parent_model, num_parent_rows, parent_name)
 
-        foreign_key = self.metadata.get_foreign_key(parent_name, table_name)
         primary_key = self.metadata.get_primary_key(parent_name)
         pdfs = self._get_pdfs(parent_rows.set_index(primary_key), table_name)
 
+        num_rows = parent_rows['__' + table_name + '__child_rows'].clip(0)
+
         parent_ids = list()
-        for _, row in sampled_rows.iterrows():
-            parent_ids.append(self._find_parent_id(row, pdfs))
+        for _, row in table_rows.iterrows():
+            parent_ids.append(self._find_parent_id(row, pdfs, num_rows))
 
         return parent_ids
 
@@ -253,7 +281,6 @@ def sample(self, table_name, num_rows=None, reset_primary_keys=False,
             and child tables.
             - Returns a ``pandas.DataFrame`` when ``sample_children`` is ``False``.
         """
-
         if reset_primary_keys:
             self._reset_primary_keys_generators()
 
@@ -261,32 +288,18 @@ def sample(self, table_name, num_rows=None, reset_primary_keys=False,
             num_rows = self.table_sizes[table_name]
 
         model = self.models[table_name]
-        sampled_rows = self._sample_rows(model, num_rows, table_name)
-
-        parents = self.metadata.get_parents(table_name)
-        if parents and sample_parents:
-            data_fields = list(self.metadata.get_dtypes(table_name, ids=False).keys())
-            table_data = sampled_rows[data_fields]
-            for parent_name in parents:
-                parent_ids = self._find_parent_ids(table_name, parent_name, table_data)
-                foreign_key = self.metadata.get_foreign_key(parent_name, table_name)
-                # parent_ids = self._get_primary_keys(parent_name, num_rows)[1]
-                sampled_rows[foreign_key] = parent_ids
+        table_rows = self._sample_rows(model, num_rows, table_name)
 
         if sample_children:
             sampled_data = {
-                table_name: sampled_rows
+                table_name: table_rows
            }
 
             self._sample_children(table_name, sampled_data)
-
-            for table, sampled_rows in sampled_data.items():
-                sampled_data[table] = self._transform_synthesized_rows(sampled_rows, table)
-
-            return sampled_data
+            return self._finalize(sampled_data)
 
         else:
-            return self._transform_synthesized_rows(sampled_rows, table_name)
+            return self._finalize({table_name: table_rows})[table_name]
 
     def sample_all(self, num_rows=None, reset_primary_keys=False):
         """Samples the entire dataset.
diff --git a/setup.py b/setup.py
index 81283021b..bca8362a0 100644
--- a/setup.py
+++ b/setup.py
@@ -15,7 +15,7 @@
     'exrex>=0.9.4,<0.11',
     'numpy>=1.15.4,<1.17',
     'pandas>=0.23.4,<0.25',
-    'copulas>=0.2.5,<0.3',
+    'copulas>=0.3,<0.4',
     'rdt>=0.2.1,<0.3',
     'graphviz>=0.13.2',
     'sdmetrics>=0.0.1,<0.0.2'

From bf50f63d33cdb5f548d637287e9b0fa53538c555 Mon Sep 17 00:00:00 2001
From: Carles Sala
Date: Thu, 2 Jul 2020 15:49:58 +0200
Subject: [PATCH 04/33] Run numbered examples only

---
 Makefile                                      |   2 +-
 examples/0. Quickstart - README.ipynb         | 506 ++++++------------
 ...uickstart - Single Table - In Memory.ipynb | 117 ++--
 .... Quickstart - Single Table - Census.ipynb | 159 +++---
 .../3. Quickstart - Multitable - Files.ipynb  | 165 +++---
 examples/4. Anonymization.ipynb               |  72 ++-
 ...5. 
Generate Metadata from Dataframes.ipynb | 51 +- examples/Demo - Walmart.ipynb | 185 ++++++- examples/Evaluation.ipynb | 193 +++++++ examples/demo_metadata.json | 26 +- 10 files changed, 831 insertions(+), 645 deletions(-) create mode 100644 examples/Evaluation.ipynb diff --git a/Makefile b/Makefile index bdcbeac21..31f1d6d6f 100644 --- a/Makefile +++ b/Makefile @@ -110,7 +110,7 @@ test-readme: ## run the readme snippets .PHONY: test-tutorials test-tutorials: ## run the tutorial notebooks - jupyter nbconvert --execute --ExecutePreprocessor.timeout=600 examples/*.ipynb --stdout > /dev/null + jupyter nbconvert --execute --ExecutePreprocessor.timeout=600 examples/?.\ *.ipynb --stdout > /dev/null .PHONY: test test: test-unit test-readme test-tutorials ## test everything that needs test dependencies diff --git a/examples/0. Quickstart - README.ipynb b/examples/0. Quickstart - README.ipynb index fa7ba905b..f6f2fecdc 100644 --- a/examples/0. Quickstart - README.ipynb +++ b/examples/0. Quickstart - README.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 7, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -20,7 +20,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -48,7 +48,7 @@ " 'approved': {'type': 'boolean'}}}}}" ] }, - "execution_count": 9, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -59,7 +59,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -100,7 +100,7 @@ " 9 9 9 2019-01-29 12:10:48 99.9 True}" ] }, - "execution_count": 10, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -111,7 +111,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -188,10 +188,10 @@ "\n" ], "text/plain": [ - "" + "" ] }, - "execution_count": 11, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -202,7 +202,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 6, "metadata": { "scrolled": false }, @@ -211,27 +211,43 @@ "name": "stderr", "output_type": "stream", "text": [ - "2020-06-25 16:03:03,773 - INFO - modeler - Modeling users\n", - "2020-06-25 16:03:03,778 - INFO - modeler - Modeling sessions\n", - "2020-06-25 16:03:03,783 - INFO - modeler - Modeling transactions\n", - "2020-06-25 16:03:03,791 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-06-25 16:03:03,803 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-06-25 16:03:03,812 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-06-25 16:03:03,819 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-06-25 16:03:03,825 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-06-25 16:03:03,834 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-06-25 16:03:03,840 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-06-25 16:03:03,845 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-06-25 16:03:03,861 - INFO - gaussian - Fitting 
GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-06-25 16:03:03,891 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-06-25 16:03:03,909 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-06-25 16:03:03,935 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-06-25 16:03:03,954 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-06-25 16:03:03,972 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-06-25 16:03:03,990 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-06-25 16:03:04,018 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-06-25 16:03:04,141 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-06-25 16:03:04,326 - INFO - modeler - Modeling Complete\n" + "2020-07-02 14:24:31,338 - INFO - modeler - Modeling users\n", + "2020-07-02 14:24:31,339 - INFO - metadata - Loading transformer CategoricalTransformer for field country\n", + "2020-07-02 14:24:31,340 - INFO - metadata - Loading transformer CategoricalTransformer for field gender\n", + "2020-07-02 14:24:31,341 - INFO - metadata - Loading transformer NumericalTransformer for field age\n", + "2020-07-02 14:24:31,356 - INFO - modeler - Modeling sessions\n", + "2020-07-02 14:24:31,357 - INFO - metadata - Loading transformer CategoricalTransformer for field device\n", + "2020-07-02 14:24:31,357 - INFO - metadata - Loading transformer CategoricalTransformer for field os\n", + "2020-07-02 14:24:31,371 - INFO - modeler - Modeling transactions\n", + "2020-07-02 14:24:31,372 - INFO - metadata - Loading transformer DatetimeTransformer for field timestamp\n", + "2020-07-02 14:24:31,373 - INFO - metadata - Loading transformer NumericalTransformer for field amount\n", + "2020-07-02 14:24:31,373 - INFO - metadata - Loading transformer BooleanTransformer for field approved\n", + "2020-07-02 14:24:31,386 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:24:31,396 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/core/fromnumeric.py:56: FutureWarning: Series.nonzero() is deprecated and will be removed in a future version.Use Series.to_numpy().nonzero() instead\n", + " return getattr(obj, method)(*args, **kwds)\n", + "2020-07-02 14:24:31,404 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/pandas/core/frame.py:7143: RuntimeWarning: Degrees of freedom <= 0 for slice\n", + " baseCov = np.cov(mat.T)\n", + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/lib/function_base.py:2451: RuntimeWarning: divide by zero encountered in true_divide\n", + " c *= np.true_divide(1, fact)\n", + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/lib/function_base.py:2451: RuntimeWarning: invalid value encountered in multiply\n", + " c *= np.true_divide(1, fact)\n", + "2020-07-02 14:24:31,413 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:24:31,418 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + 
"2020-07-02 14:24:31,425 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:24:31,431 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:24:31,435 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:24:31,447 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:24:31,471 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:24:31,485 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:24:31,502 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:24:31,516 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:24:31,530 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:24:31,544 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:24:31,563 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:24:31,629 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:24:31,734 - INFO - modeler - Modeling Complete\n" ] } ], @@ -244,321 +260,129 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 7, "metadata": { "scrolled": false }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "> \u001b[0;32m/home/xals/Projects/MIT/SDV/sdv/models/copulas.py\u001b[0m(102)\u001b[0;36m_prepare_sampled_covariance\u001b[0;34m()\u001b[0m\n", - "\u001b[0;32m 101 \u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mipdb\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mipdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\u001b[0;32m--> 102 \u001b[0;31m \u001b[0mcovariance\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msquare_matrix\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\u001b[0;32m 103 \u001b[0;31m \u001b[0mcovariance\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mcovariance\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mT\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0midentity\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mcovariance\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\n", - "ipdb> covariance\n", - "[[1.726615061615442], [1.0454201815611962e-17, 0.8202611657476363], [0.906087951592798, 0.0, 0.9063498818810053], [0.0, 0.0, 0.0, 1.1920928955078125e-07], [0.0, 0.0, 0.0, 0.0, 1.1920928955078125e-07], [0.0, 0.0, 0.0, 0.0, 0.0, 1.1920928955078125e-07], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1920928955078125e-07], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
1.1920928955078125e-07], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1920928955078125e-07], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1920928955078125e-07], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1920928955078125e-07], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1920928955078125e-07], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1920928955078125e-07], [1.7211659999181501, 0.09407387578039639, 0.9062265705680096, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.7265698690747444], [0.1962595829950089, -0.4101010550927784, 0.9064318910717786, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.15363369110349148, 1.7265890069169954], [1.6501160558524515, 0.34588621287899607, 0.9060859813924124, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.6848330864411356, 0.089532714183855, 1.7265426703918063], [0.19506154837332199, -0.41017105995034014, 0.9064258143689862, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.15319193719168775, 1.7265715342431274, 0.08933484748381948, 1.726601450683982], [1.558118079154194, -0.49788243931513426, 0.9061755927291784, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.4967074003515342, 0.5906071587381734, 1.287493411569124, 0.5910547002051788, 1.7265085917492808], [0.19622238674880996, -0.4101312643965461, 0.9063303749716511, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.15392250979132835, 1.7265715878572472, 0.08897912790316595, 1.726633522580159, 0.5911723429044428, 1.7265328687279475], [-1.426683586477038, 0.6340766049086322, -0.9064943110193431, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.350356266672644, -0.7728329555360014, -1.1105383964628412, -0.7729561856331211, -1.7046109781197039, -0.7727141956567691, 1.7265200912725471], [0.19613513187055676, -0.4101447540378119, 0.9060999824108501, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1536514208926178, 1.726564276607309, 0.08894497870076296, 1.7266086430031098, 0.5907859264695859, 1.7266240785900109, -0.7726518888264287, 1.7265194550742498], [-0.19600331514505734, 0.4100687094013292, -0.9060002638533913, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.15335524134834988, -1.7265840276560094, -0.08909093050930667, -1.7265435834346627, -0.5903898337829075, -1.7265249406712062, 0.7728873327661734, -1.7266499172273617, 1.726589663707634]]\n", - "ipdb> n\n", - "> \u001b[0;32m/home/xals/Projects/MIT/SDV/sdv/models/copulas.py\u001b[0m(103)\u001b[0;36m_prepare_sampled_covariance\u001b[0;34m()\u001b[0m\n", - "\u001b[0;32m 102 \u001b[0;31m \u001b[0mcovariance\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msquare_matrix\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\u001b[0;32m--> 103 \u001b[0;31m \u001b[0mcovariance\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mcovariance\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mT\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0midentity\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mcovariance\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - 
"\u001b[0m\u001b[0;32m 104 \u001b[0;31m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\n", - "ipdb> covariance\n", - "array([[ 1.72661506e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 1.04542018e-17, 8.20261166e-01, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 9.06087952e-01, 0.00000000e+00, 9.06349882e-01,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 1.19209290e-07, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 1.19209290e-07, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 1.19209290e-07,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 1.19209290e-07, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 1.19209290e-07, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 
0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 1.19209290e-07,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 1.19209290e-07, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 1.19209290e-07, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 1.19209290e-07,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 1.19209290e-07, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 1.72116600e+00, 9.40738758e-02, 9.06226571e-01,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 1.72656987e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 1.96259583e-01, -4.10101055e-01, 9.06431891e-01,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 1.53633691e-01, 1.72658901e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 1.65011606e+00, 3.45886213e-01, 9.06085981e-01,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 1.68483309e+00, 8.95327142e-02,\n", - " 1.72654267e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 1.95061548e-01, -4.10171060e-01, 9.06425814e-01,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 1.53191937e-01, 
1.72657153e+00,\n", - " 8.93348475e-02, 1.72660145e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 1.55811808e+00, -4.97882439e-01, 9.06175593e-01,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 1.49670740e+00, 5.90607159e-01,\n", - " 1.28749341e+00, 5.91054700e-01, 1.72650859e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 1.96222387e-01, -4.10131264e-01, 9.06330375e-01,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 1.53922510e-01, 1.72657159e+00,\n", - " 8.89791279e-02, 1.72663352e+00, 5.91172343e-01,\n", - " 1.72653287e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [-1.42668359e+00, 6.34076605e-01, -9.06494311e-01,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, -1.35035627e+00, -7.72832956e-01,\n", - " -1.11053840e+00, -7.72956186e-01, -1.70461098e+00,\n", - " -7.72714196e-01, 1.72652009e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 1.96135132e-01, -4.10144754e-01, 9.06099982e-01,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 1.53651421e-01, 1.72656428e+00,\n", - " 8.89449787e-02, 1.72660864e+00, 5.90785926e-01,\n", - " 1.72662408e+00, -7.72651889e-01, 1.72651946e+00,\n", - " 0.00000000e+00],\n", - " [-1.96003315e-01, 4.10068709e-01, -9.06000264e-01,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, -1.53355241e-01, -1.72658403e+00,\n", - " -8.90909305e-02, -1.72654358e+00, -5.90389834e-01,\n", - " -1.72652494e+00, 7.72887333e-01, -1.72664992e+00,\n", - " 1.72658966e+00]])\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "ipdb> ll\n", - "\u001b[1;32m 90 \u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_prepare_sampled_covariance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcovariance\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[1;32m 91 \u001b[0m \"\"\"Prepare a covariance matrix.\n", - "\u001b[1;32m 92 \u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[1;32m 93 \u001b[0m \u001b[0mArgs\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[1;32m 94 \u001b[0m \u001b[0mcovariance\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[1;32m 95 \u001b[0m \u001b[0mcovariance\u001b[0m \u001b[0mafter\u001b[0m \u001b[0munflattening\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0mparameters\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[1;32m 96 \u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[1;32m 97 \u001b[0m \u001b[0mResult\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - 
"\u001b[1;32m 98 \u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[1;32m 99 \u001b[0m \u001b[0msymmetric\u001b[0m \u001b[0mPositive\u001b[0m \u001b[0msemi\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mdefinite\u001b[0m \u001b[0mmatrix\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[1;32m 100 \u001b[0m \"\"\"\n", - "\u001b[1;32m 101 \u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mipdb\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mipdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[1;32m 102 \u001b[0m \u001b[0mcovariance\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msquare_matrix\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m--> 103 \u001b[0;31m \u001b[0mcovariance\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mcovariance\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mT\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0midentity\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mcovariance\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\u001b[1;32m 104 \u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[1;32m 105 \u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mcheck_matrix_symmetric_positive_definite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[1;32m 106 \u001b[0m \u001b[0mcovariance\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmake_positive_definite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[1;32m 107 \u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[1;32m 108 \u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mcovariance\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtolist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[1;32m 109 \u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\n", - "ipdb> n\n", - "> \u001b[0;32m/home/xals/Projects/MIT/SDV/sdv/models/copulas.py\u001b[0m(105)\u001b[0;36m_prepare_sampled_covariance\u001b[0;34m()\u001b[0m\n", - "\u001b[0;32m 104 \u001b[0;31m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\u001b[0;32m--> 105 \u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mcheck_matrix_symmetric_positive_definite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\u001b[0;32m 106 \u001b[0;31m \u001b[0mcovariance\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mmake_positive_definite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
-      "\n",
-      "ipdb> n\n",
-      "> /home/xals/Projects/MIT/SDV/sdv/models/copulas.py(106)_prepare_sampled_covariance()\n",
-      "    105         if not check_matrix_symmetric_positive_definite(covariance):\n",
-      "--> 106             covariance = make_positive_definite(covariance)\n",
-      "    107 \n",
-      "\n",
-      "ipdb> n\n",
-      "> /home/xals/Projects/MIT/SDV/sdv/models/copulas.py(108)_prepare_sampled_covariance()\n",
-      "    107 \n",
-      "--> 108         return covariance.tolist()\n",
-      "    109 \n",
-      "\n",
-      "ipdb> n\n",
-      "--Return--\n",
-      "[[1.726928289354223, -6.284976501719013e-05, 0.9060437009442075, 1.1799638567951807e-18, -1.7530700091068226e-20, 5.17606189171064e-21, ...], [-6.284976501719013e-05, 0.8203273652660589, -3.601422526398147e-05, 5.414945862916836e-17, 3.245206248848917e-20, -1.2006964972373577e-20, ...], [0.9060437009442075, -3.601422526398147e-05, 0.9064526869220875, 1.933376368463285e-17, -2.6655267441977074e-20, 2.5829833603018656e-21, ...], [1.1799638567951807e-18, 5.414945862916836e-17, 1.933376368463285e-17, 1.192092915444104e-07, -1.1365716705948445e-20, 4.7277921079772496e-21, ...], [-1.7530700091068226e-20, 3.245206248848917e-20, -2.6655267441977074e-20, -1.1365716705948445e-20, 1.1920929154436114e-07, 5.437058680680646e-21, ...], [5.17606189171064e-21, -1.2006964972373577e-20, 2.5829833603018656e-21, 4.7277921079772496e-21, 5.437058680680646e-21, 1.1920929154437684e-07, ...], ...]\n",
-      "> /home/xals/Projects/MIT/SDV/sdv/models/copulas.py(108)_prepare_sampled_covariance()\n",
-      "    107 \n",
-      "--> 108         return covariance.tolist()\n",
-      "    109 \n",
-      "\n",
-      "ipdb> n\n",
-      "> /home/xals/Projects/MIT/SDV/sdv/models/copulas.py(145)_unflatten_gaussian_copula()\n",
-      "    144 \n",
-      "--> 145         return model_parameters\n",
-      "    146 \n",
-      "\n",
-      "ipdb> c\n",
-      "> /home/xals/Projects/MIT/SDV/sdv/models/copulas.py(102)_prepare_sampled_covariance()\n",
-      "    101         import ipdb; ipdb.set_trace()\n",
-      "--> 102         covariance = np.array(square_matrix(covariance))\n",
-      "    103         covariance = (covariance + covariance.T - (np.identity(covariance.shape[0]) * covariance))\n",
-      "\n",
-      "ipdb> covariance\n",
-      "[[1.9735777830655135], [1.0081332489157226e-18, 0.07922245304646469], [1.8944415499749927, 0.0, 1.8945678304393692], ...]  (remaining rows of this 22x22 lower-triangular covariance dump trimmed)\n"
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "ipdb> c\n",
-      "> /home/xals/Projects/MIT/SDV/sdv/models/copulas.py(102)_prepare_sampled_covariance()\n",
-      "    101         import ipdb; ipdb.set_trace()\n",
-      "--> 102         covariance = np.array(square_matrix(covariance))\n",
-      "    103         covariance = (covariance + covariance.T - (np.identity(covariance.shape[0]) * covariance))\n",
-      "\n",
-      "ipdb> covariance\n",
-      "[[1.7798122161501428], [8.418645096312187e-18, 0.6606961152453757], [1.1193765416016843, 0.0, 1.119151425141485], ...]  (remaining rows of this 22x22 lower-triangular covariance dump trimmed)\n",
-      "ipdb> c\n",
-      "> /home/xals/Projects/MIT/SDV/sdv/models/copulas.py(102)_prepare_sampled_covariance()\n",
-      "    101         import ipdb; ipdb.set_trace()\n",
-      "--> 102         covariance = np.array(square_matrix(covariance))\n",
-      "    103         covariance = (covariance + covariance.T - (np.identity(covariance.shape[0]) * covariance))\n",
-      "\n",
-      "ipdb> covariance\n",
-      "[[None], [None, None], [None, None, None], [None, None, None, None]]\n",
-      "ipdb> q\n"
-     ]
-    },
-    {
-     "ename": "BdbQuit",
-     "evalue": "",
-     "output_type": "error",
-     "traceback": [
-      "---------------------------------------------------------------------------",
-      "BdbQuit                                   Traceback (most recent call last)",
-      "<ipython-input-...> in <module>\n----> 1 sdv.sample_all(10)\n",
-      "~/Projects/MIT/SDV/sdv/sdv.py in sample_all(self, num_rows, reset_primary_keys)\n    113             raise NotFittedError('SDV instance has not been fitted')\n    114 \n--> 115         return self.sampler.sample_all(num_rows, reset_primary_keys=reset_primary_keys)\n    116 \n    117     def save(self, path):\n",
-      "~/Projects/MIT/SDV/sdv/sampler.py in sample_all(self, num_rows, reset_primary_keys)\n    265         for table in self.metadata.get_tables():\n    266             if not self.metadata.get_parents(table):\n--> 267                 sampled_data.update(self.sample(table, num_rows))\n    268 \n    269         return sampled_data\n",
-      "~/Projects/MIT/SDV/sdv/sampler.py in sample(self, table_name, num_rows, reset_primary_keys, sample_children)\n    228         }\n    229 \n--> 230         self._sample_children(table_name, sampled_data)\n    231 \n    232         for table, sampled_rows in sampled_data.items():\n",
-      "~/Projects/MIT/SDV/sdv/sampler.py in _sample_children(self, table_name, sampled)\n    162         for child_name in self.metadata.get_children(table_name):\n    163             for _, row in table_rows.iterrows():\n--> 164                 self._sample_table(child_name, table_name, row, sampled)\n    165 \n    166     def _sample_table(self, table_name, parent_name, parent_row, sampled):\n",
-      "~/Projects/MIT/SDV/sdv/sampler.py in _sample_table(self, table_name, parent_name, parent_row, sampled)\n    183             sampled[table_name] = pd.concat([previous, sampled_rows]).reset_index(drop=True)\n    184 \n--> 185         self._sample_children(table_name, sampled)\n    186 \n    187     def sample(self, table_name, num_rows, reset_primary_keys=False, sample_children=True):\n",
-      "~/Projects/MIT/SDV/sdv/sampler.py in _sample_children(self, table_name, sampled)\n    162         for child_name in self.metadata.get_children(table_name):\n    163             for _, row in table_rows.iterrows():\n--> 164                 self._sample_table(child_name, table_name, row, sampled)\n    165 \n    166     def _sample_table(self, table_name, parent_name, parent_row, sampled):\n",
-      "~/Projects/MIT/SDV/sdv/sampler.py in _sample_table(self, table_name, parent_name, parent_row, sampled)\n    168 \n    169         model = self.model(**self.model_kwargs)\n--> 170         model.set_parameters(extension)\n    171         num_rows = max(round(extension['child_rows']), 0)\n    172 \n",
-      "~/Projects/MIT/SDV/sdv/models/copulas.py in set_parameters(self, parameters)\n    159         parameters.setdefault('distribution', self.distribution)\n    160 \n--> 161         parameters = self._unflatten_gaussian_copula(parameters)\n    162         for param in parameters['univariates'].values():\n    163             param.setdefault('type', self.distribution)\n",
-      "~/Projects/MIT/SDV/sdv/models/copulas.py in _unflatten_gaussian_copula(self, model_parameters)\n    141 \n    142         covariance = model_parameters['covariance']\n--> 143         model_parameters['covariance'] = self._prepare_sampled_covariance(covariance)\n    144 \n    145         return model_parameters\n",
-      "~/Projects/MIT/SDV/sdv/models/copulas.py in _prepare_sampled_covariance(self, covariance)\n    100         \"\"\"\n    101         import ipdb; ipdb.set_trace()\n--> 102         covariance = np.array(square_matrix(covariance))\n    103         covariance = (covariance + covariance.T - (np.identity(covariance.shape[0]) * covariance))\n    104 \n",
-      "~/Projects/MIT/SDV/sdv/models/copulas.py in _prepare_sampled_covariance(self, covariance)\n    100         \"\"\"\n    101         import ipdb; ipdb.set_trace()\n--> 102         covariance = np.array(square_matrix(covariance))\n    103         covariance = (covariance + covariance.T - (np.identity(covariance.shape[0]) * covariance))\n    104 \n",
-      "/usr/lib/python3.6/bdb.py in trace_dispatch(self, frame, event, arg)\n     49             return # None\n     50         if event == 'line':\n---> 51             return self.dispatch_line(frame)\n     52         if event == 'call':\n     53             return self.dispatch_call(frame, arg)\n",
-      "/usr/lib/python3.6/bdb.py in dispatch_line(self, frame)\n     68         if self.stop_here(frame) or self.break_here(frame):\n     69             self.user_line(frame)\n---> 70             if self.quitting: raise BdbQuit\n     71         return self.trace_dispatch\n     72 \n",
-      "BdbQuit: "
-     ]
+     "data": {
+      "text/plain": [
+       "{'users':    user_id country gender  age\n",
+       " 0        0      ES      F   50\n",
+       " 1        1      FR    NaN   61\n",
+       " 2        2      UK      F   19\n",
+       " 3        3     USA      F   37\n",
+       " 4        4      UK      M   14\n",
+       " 5        5     USA      M   20\n",
+       " 6        6      FR      F   41\n",
+       " 7        7      UK    NaN   15\n",
+       " 8        8     USA      F   17\n",
+       " 9        9      UK      M   27,\n",
+       " 'sessions':    session_id  user_id  device       os\n",
+       " 0           0        5  mobile  android\n",
+       " 1           1        3  tablet      ios\n",
+       " 2           2        5  mobile      ios\n",
+       " 3           3        2  mobile      ios\n",
+       " 4           4        1  mobile      ios\n",
+       " 5           5        8  mobile      ios\n",
+       " 6           6        3  mobile      ios\n",
+       " 7           7        2  tablet      ios\n",
+       " 8           8        8  tablet      ios,\n",
+       " 'transactions':     transaction_id  session_id                     timestamp       amount  \\\n",
+       " 0                0           1 2019-01-05 19:44:17.109082880   108.520480   \n",
+       " 1                1           0 2019-01-05 19:44:17.109082880   108.520446   \n",
+       " 2                2           4 2019-01-05 19:44:17.109082880   108.520456   \n",
+       " 3                3           4 2019-01-05 19:44:17.109082880   108.520463   \n",
+       " 4                4           0 2019-01-05 19:44:17.109082880   108.520451   \n",
+       " 5                5           4 2019-01-14 10:59:23.654829312  -293.744075   \n",
+       " 6                6           0 2019-01-08 21:06:22.124677120  -157.555578   \n",
+       " 7                7           4 2019-01-05 19:44:17.109082880   108.520405   \n",
+       " 8                8           4 2019-01-20 05:49:39.108628736  1436.388992   \n",
+       " 9                9           0 2019-01-13 01:21:03.496772608  -688.014399   \n",
+       " 10              10           5 2019-01-15 19:14:13.132234496    92.573682   \n",
+       " 11              11           0 2019-01-15 19:13:47.834800128    88.733610   \n",
+       " 12              12           6 2019-01-15 19:13:25.752608512    87.612329   \n",
+       " 13              13           6 2019-01-15 19:14:06.131340544    91.149239   \n",
+       " 14              14           4 2019-01-05 19:44:17.109082880   108.520315   \n",
+       " 15              15           4 2019-01-10 14:54:25.195068416  -405.271997   \n",
+       " 16              16           0 2019-01-06 11:03:17.424094208   321.267007   \n",
+       " 17              17           1 2019-01-15 19:13:30.394386944    89.473644   \n",
+       " 18              18           5 2019-01-15 19:13:23.178177536    87.084668   \n",
+       " 19              19           5 2019-01-15 19:13:53.814526976    87.328847   \n",
+       " 20              20           6 2019-01-15 19:13:47.778595840    87.698023   \n",
+       " 21              21           0 2019-01-05 19:44:17.109082880   108.520444   \n",
+       " 22              22           4 2019-01-03 13:01:55.231368192  -171.151299   \n",
+       " 23              23           4 2019-01-10 03:03:38.631104256  -711.512501   \n",
+       " 24              24           5 2019-01-15 19:13:39.923056128    92.779222   \n",
+       " 25              25           5 2019-01-15 19:13:54.239550976    87.046705   \n",
+       " 26              26           1 2019-01-15 19:13:51.475357952    89.768894   \n",
+       " 27              27           6 2019-01-15 19:13:38.378663168    85.864682   \n",
+       " 28              28           7 2019-01-20 22:53:01.679367168    84.550410   \n",
+       " 29              29           7 2019-01-20 22:53:01.672930560    84.739006   \n",
+       " 30              30           1 2019-01-20 22:53:01.444472832    84.790641   \n",
+       " 31              31           8 2019-01-20 22:53:01.404470784    84.595276   \n",
+       " 32              32           4 2019-01-05 19:44:17.109082880   108.520411   \n",
+       " 33              33           4 2019-01-15 
17:22:02.305563904 -602.769803 \n", + " 34 34 1 2019-01-15 04:01:37.883063808 -493.475733 \n", + " 35 35 0 2019-01-15 19:13:20.536071936 82.887047 \n", + " 36 36 5 2019-01-15 19:13:50.685021696 90.577240 \n", + " 37 37 6 2019-01-15 19:14:05.128847616 91.543630 \n", + " 38 38 6 2019-01-15 19:13:10.650768896 90.478404 \n", + " 39 39 7 2019-01-20 22:53:01.680873472 84.876897 \n", + " 40 40 7 2019-01-20 22:53:01.735595520 85.080808 \n", + " 41 41 0 2019-01-20 22:53:01.332932352 84.584933 \n", + " 42 42 8 2019-01-20 22:53:01.451664640 84.779657 \n", + " \n", + " approved \n", + " 0 False \n", + " 1 False \n", + " 2 False \n", + " 3 False \n", + " 4 False \n", + " 5 True \n", + " 6 True \n", + " 7 False \n", + " 8 True \n", + " 9 True \n", + " 10 True \n", + " 11 True \n", + " 12 True \n", + " 13 True \n", + " 14 False \n", + " 15 True \n", + " 16 True \n", + " 17 True \n", + " 18 True \n", + " 19 True \n", + " 20 True \n", + " 21 False \n", + " 22 True \n", + " 23 True \n", + " 24 True \n", + " 25 True \n", + " 26 True \n", + " 27 True \n", + " 28 True \n", + " 29 True \n", + " 30 True \n", + " 31 True \n", + " 32 False \n", + " 33 True \n", + " 34 True \n", + " 35 True \n", + " 36 True \n", + " 37 True \n", + " 38 True \n", + " 39 True \n", + " 40 True \n", + " 41 True \n", + " 42 True }" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ diff --git a/examples/1. Quickstart - Single Table - In Memory.ipynb b/examples/1. Quickstart - Single Table - In Memory.ipynb index 6a07eeb32..b9f48bbd2 100644 --- a/examples/1. Quickstart - Single Table - In Memory.ipynb +++ b/examples/1. Quickstart - Single Table - In Memory.ipynb @@ -232,14 +232,15 @@ "name": "stderr", "output_type": "stream", "text": [ - "2020-02-01 19:45:45,775 - INFO - modeler - Modeling data\n", - "2020-02-01 19:45:45,776 - INFO - metadata - Loading transformer NumericalTransformer for field integer\n", - "2020-02-01 19:45:45,777 - INFO - metadata - Loading transformer NumericalTransformer for field float\n", - "2020-02-01 19:45:45,777 - INFO - metadata - Loading transformer CategoricalTransformer for field categorical\n", - "2020-02-01 19:45:45,778 - INFO - metadata - Loading transformer BooleanTransformer for field bool\n", - "2020-02-01 19:45:45,779 - INFO - metadata - Loading transformer NumericalTransformer for field nullable\n", - "2020-02-01 19:45:45,779 - INFO - metadata - Loading transformer DatetimeTransformer for field datetime\n", - "2020-02-01 19:45:45,824 - INFO - modeler - Modeling Complete\n" + "2020-07-02 14:25:11,511 - INFO - modeler - Modeling data\n", + "2020-07-02 14:25:11,512 - INFO - metadata - Loading transformer NumericalTransformer for field integer\n", + "2020-07-02 14:25:11,512 - INFO - metadata - Loading transformer NumericalTransformer for field float\n", + "2020-07-02 14:25:11,513 - INFO - metadata - Loading transformer CategoricalTransformer for field categorical\n", + "2020-07-02 14:25:11,513 - INFO - metadata - Loading transformer BooleanTransformer for field bool\n", + "2020-07-02 14:25:11,513 - INFO - metadata - Loading transformer NumericalTransformer for field nullable\n", + "2020-07-02 14:25:11,514 - INFO - metadata - Loading transformer DatetimeTransformer for field datetime\n", + "2020-07-02 14:25:11,551 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:25:11,564 - INFO - modeler - Modeling Complete\n" ] } ], @@ -291,52 +292,82 @@ " \n", " 0\n", " 0\n", - " 1\n", - " 0.080051\n", - " 
a\n", " NaN\n", " NaN\n", - " NaT\n", + " b\n", + " False\n", + " NaN\n", + " 2010-01-21 11:39:29.986688000\n", " \n", " \n", " 1\n", " 1\n", - " 0\n", - " -0.015712\n", - " a\n", - " False\n", " NaN\n", - " 2009-12-08 11:20:58.439345408\n", + " NaN\n", + " NaN\n", + " True\n", + " NaN\n", + " 2010-02-17 13:51:34.408565760\n", " \n", " \n", " 2\n", " 2\n", - " 1\n", - " 0.142979\n", - " a\n", - " True\n", - " 4.971408\n", - " 2010-01-08 00:32:21.629585920\n", + " 3.0\n", + " 0.304326\n", + " c\n", + " False\n", + " NaN\n", + " 2010-02-25 12:07:19.982103552\n", " \n", " \n", " 3\n", " 3\n", - " 1\n", - " 0.133913\n", + " 2.0\n", + " 0.174580\n", " a\n", " True\n", " NaN\n", - " 2010-01-14 21:37:22.051623936\n", + " 2010-01-21 04:04:59.336169472\n", " \n", " \n", " 4\n", " 4\n", - " 1\n", - " 0.159480\n", + " 3.0\n", + " 0.208637\n", " a\n", + " NaN\n", + " NaN\n", + " NaT\n", + " \n", + " \n", + " 5\n", + " 5\n", + " 1.0\n", + " 0.026796\n", + " b\n", + " NaN\n", + " NaN\n", + " NaT\n", + " \n", + " \n", + " 6\n", + " 6\n", + " 2.0\n", + " 0.166949\n", + " NaN\n", + " False\n", + " NaN\n", + " 2010-01-28 23:23:34.873413888\n", + " \n", + " \n", + " 7\n", + " 7\n", + " 1.0\n", + " 0.086972\n", + " b\n", " False\n", - " 4.294080\n", - " 2010-01-09 21:09:11.245925888\n", + " NaN\n", + " 2010-01-08 09:44:47.101891840\n", " \n", " \n", "\n", @@ -344,18 +375,24 @@ ], "text/plain": [ " index integer float categorical bool nullable \\\n", - "0 0 1 0.080051 a NaN NaN \n", - "1 1 0 -0.015712 a False NaN \n", - "2 2 1 0.142979 a True 4.971408 \n", - "3 3 1 0.133913 a True NaN \n", - "4 4 1 0.159480 a False 4.294080 \n", + "0 0 NaN NaN b False NaN \n", + "1 1 NaN NaN NaN True NaN \n", + "2 2 3.0 0.304326 c False NaN \n", + "3 3 2.0 0.174580 a True NaN \n", + "4 4 3.0 0.208637 a NaN NaN \n", + "5 5 1.0 0.026796 b NaN NaN \n", + "6 6 2.0 0.166949 NaN False NaN \n", + "7 7 1.0 0.086972 b False NaN \n", "\n", " datetime \n", - "0 NaT \n", - "1 2009-12-08 11:20:58.439345408 \n", - "2 2010-01-08 00:32:21.629585920 \n", - "3 2010-01-14 21:37:22.051623936 \n", - "4 2010-01-09 21:09:11.245925888 " + "0 2010-01-21 11:39:29.986688000 \n", + "1 2010-02-17 13:51:34.408565760 \n", + "2 2010-02-25 12:07:19.982103552 \n", + "3 2010-01-21 04:04:59.336169472 \n", + "4 NaT \n", + "5 NaT \n", + "6 2010-01-28 23:23:34.873413888 \n", + "7 2010-01-08 09:44:47.101891840 " ] }, "execution_count": 5, diff --git a/examples/2. Quickstart - Single Table - Census.ipynb b/examples/2. Quickstart - Single Table - Census.ipynb index 595fe502c..6857b7827 100644 --- a/examples/2. Quickstart - Single Table - Census.ipynb +++ b/examples/2. 
Quickstart - Single Table - Census.ipynb @@ -312,23 +312,24 @@ "name": "stderr", "output_type": "stream", "text": [ - "2020-06-25 23:39:05,857 - INFO - modeler - Modeling census\n", - "2020-06-25 23:39:05,857 - INFO - metadata - Loading transformer NumericalTransformer for field age\n", - "2020-06-25 23:39:05,858 - INFO - metadata - Loading transformer CategoricalTransformer for field workclass\n", - "2020-06-25 23:39:05,858 - INFO - metadata - Loading transformer NumericalTransformer for field fnlwgt\n", - "2020-06-25 23:39:05,859 - INFO - metadata - Loading transformer CategoricalTransformer for field education\n", - "2020-06-25 23:39:05,859 - INFO - metadata - Loading transformer NumericalTransformer for field education-num\n", - "2020-06-25 23:39:05,860 - INFO - metadata - Loading transformer CategoricalTransformer for field marital-status\n", - "2020-06-25 23:39:05,860 - INFO - metadata - Loading transformer CategoricalTransformer for field occupation\n", - "2020-06-25 23:39:05,860 - INFO - metadata - Loading transformer CategoricalTransformer for field relationship\n", - "2020-06-25 23:39:05,861 - INFO - metadata - Loading transformer CategoricalTransformer for field race\n", - "2020-06-25 23:39:05,861 - INFO - metadata - Loading transformer CategoricalTransformer for field sex\n", - "2020-06-25 23:39:05,862 - INFO - metadata - Loading transformer NumericalTransformer for field capital-gain\n", - "2020-06-25 23:39:05,863 - INFO - metadata - Loading transformer NumericalTransformer for field capital-loss\n", - "2020-06-25 23:39:05,863 - INFO - metadata - Loading transformer NumericalTransformer for field hours-per-week\n", - "2020-06-25 23:39:05,863 - INFO - metadata - Loading transformer CategoricalTransformer for field native-country\n", - "2020-06-25 23:39:05,864 - INFO - metadata - Loading transformer CategoricalTransformer for field income\n", - "2020-06-25 23:39:06,119 - INFO - modeler - Modeling Complete\n" + "2020-07-02 14:25:31,656 - INFO - modeler - Modeling census\n", + "2020-07-02 14:25:31,657 - INFO - metadata - Loading transformer NumericalTransformer for field age\n", + "2020-07-02 14:25:31,657 - INFO - metadata - Loading transformer CategoricalTransformer for field workclass\n", + "2020-07-02 14:25:31,658 - INFO - metadata - Loading transformer NumericalTransformer for field fnlwgt\n", + "2020-07-02 14:25:31,658 - INFO - metadata - Loading transformer CategoricalTransformer for field education\n", + "2020-07-02 14:25:31,658 - INFO - metadata - Loading transformer NumericalTransformer for field education-num\n", + "2020-07-02 14:25:31,659 - INFO - metadata - Loading transformer CategoricalTransformer for field marital-status\n", + "2020-07-02 14:25:31,659 - INFO - metadata - Loading transformer CategoricalTransformer for field occupation\n", + "2020-07-02 14:25:31,659 - INFO - metadata - Loading transformer CategoricalTransformer for field relationship\n", + "2020-07-02 14:25:31,660 - INFO - metadata - Loading transformer CategoricalTransformer for field race\n", + "2020-07-02 14:25:31,660 - INFO - metadata - Loading transformer CategoricalTransformer for field sex\n", + "2020-07-02 14:25:31,661 - INFO - metadata - Loading transformer NumericalTransformer for field capital-gain\n", + "2020-07-02 14:25:31,661 - INFO - metadata - Loading transformer NumericalTransformer for field capital-loss\n", + "2020-07-02 14:25:31,662 - INFO - metadata - Loading transformer NumericalTransformer for field hours-per-week\n", + "2020-07-02 14:25:31,662 - INFO - metadata - Loading 
transformer CategoricalTransformer for field native-country\n", + "2020-07-02 14:25:31,663 - INFO - metadata - Loading transformer CategoricalTransformer for field income\n", + "2020-07-02 14:25:31,831 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:25:31,928 - INFO - modeler - Modeling Complete\n" ] } ], @@ -385,91 +386,91 @@ " \n", " \n", " 0\n", - " 42\n", + " 35\n", " Private\n", - " 47585\n", - " Assoc-voc\n", + " 468185\n", + " HS-grad\n", " 9\n", - " Never-married\n", - " Prof-specialty\n", - " Not-in-family\n", + " Married-civ-spouse\n", + " Sales\n", + " Husband\n", " White\n", " Male\n", - " 2398\n", - " 3\n", - " 44\n", + " 2039\n", + " 436\n", + " 48\n", " United-States\n", " <=50K\n", " \n", " \n", " 1\n", - " 56\n", + " 47\n", " Private\n", - " 92870\n", + " 104447\n", " HS-grad\n", - " 13\n", + " 7\n", " Married-civ-spouse\n", - " Sales\n", - " Not-in-family\n", + " Adm-clerical\n", + " Husband\n", " White\n", " Male\n", - " 3624\n", - " 269\n", - " 30\n", + " -703\n", + " 81\n", + " 53\n", " United-States\n", " <=50K\n", " \n", " \n", " 2\n", - " 54\n", + " 20\n", " Private\n", - " 218711\n", + " 231391\n", " HS-grad\n", - " 13\n", + " 11\n", " Never-married\n", - " Other-service\n", + " Exec-managerial\n", " Not-in-family\n", " White\n", - " Female\n", - " 4420\n", - " -269\n", - " 43\n", + " Male\n", + " 654\n", + " -8\n", + " 57\n", " United-States\n", " <=50K\n", " \n", " \n", " 3\n", - " 31\n", - " ?\n", - " 71625\n", - " Some-college\n", - " 14\n", + " 35\n", + " Private\n", + " 223275\n", + " Masters\n", + " 12\n", " Married-civ-spouse\n", - " Prof-specialty\n", - " Husband\n", + " Sales\n", + " Not-in-family\n", " White\n", " Male\n", - " 3196\n", - " 898\n", - " 51\n", + " -392\n", + " -178\n", + " 71\n", " United-States\n", " <=50K\n", " \n", " \n", " 4\n", - " 37\n", + " 26\n", " Private\n", - " 184276\n", - " Bachelors\n", - " 14\n", - " Never-married\n", - " Craft-repair\n", + " -11408\n", + " HS-grad\n", + " 8\n", + " Married-civ-spouse\n", + " Machine-op-inspct\n", " Husband\n", " White\n", " Male\n", - " 4785\n", - " 907\n", - " 32\n", + " 5799\n", + " 13\n", + " 47\n", " United-States\n", " <=50K\n", " \n", @@ -478,26 +479,26 @@ "" ], "text/plain": [ - " age workclass fnlwgt education education-num marital-status \\\n", - "0 42 Private 47585 Assoc-voc 9 Never-married \n", - "1 56 Private 92870 HS-grad 13 Married-civ-spouse \n", - "2 54 Private 218711 HS-grad 13 Never-married \n", - "3 31 ? 
71625 Some-college 14 Married-civ-spouse \n", - "4 37 Private 184276 Bachelors 14 Never-married \n", + " age workclass fnlwgt education education-num marital-status \\\n", + "0 35 Private 468185 HS-grad 9 Married-civ-spouse \n", + "1 47 Private 104447 HS-grad 7 Married-civ-spouse \n", + "2 20 Private 231391 HS-grad 11 Never-married \n", + "3 35 Private 223275 Masters 12 Married-civ-spouse \n", + "4 26 Private -11408 HS-grad 8 Married-civ-spouse \n", "\n", - " occupation relationship race sex capital-gain \\\n", - "0 Prof-specialty Not-in-family White Male 2398 \n", - "1 Sales Not-in-family White Male 3624 \n", - "2 Other-service Not-in-family White Female 4420 \n", - "3 Prof-specialty Husband White Male 3196 \n", - "4 Craft-repair Husband White Male 4785 \n", + " occupation relationship race sex capital-gain \\\n", + "0 Sales Husband White Male 2039 \n", + "1 Adm-clerical Husband White Male -703 \n", + "2 Exec-managerial Not-in-family White Male 654 \n", + "3 Sales Not-in-family White Male -392 \n", + "4 Machine-op-inspct Husband White Male 5799 \n", "\n", " capital-loss hours-per-week native-country income \n", - "0 3 44 United-States <=50K \n", - "1 269 30 United-States <=50K \n", - "2 -269 43 United-States <=50K \n", - "3 898 51 United-States <=50K \n", - "4 907 32 United-States <=50K " + "0 436 48 United-States <=50K \n", + "1 81 53 United-States <=50K \n", + "2 -8 57 United-States <=50K \n", + "3 -178 71 United-States <=50K \n", + "4 13 47 United-States <=50K " ] }, "execution_count": 7, @@ -520,7 +521,7 @@ { "data": { "text/plain": [ - "-43.51455835774797" + "-43.151506729008716" ] }, "execution_count": 8, diff --git a/examples/3. Quickstart - Multitable - Files.ipynb b/examples/3. Quickstart - Multitable - Files.ipynb index b4183f965..d00a5989e 100644 --- a/examples/3. Quickstart - Multitable - Files.ipynb +++ b/examples/3. 
Quickstart - Multitable - Files.ipynb @@ -11,21 +11,38 @@ "name": "stderr", "output_type": "stream", "text": [ - "2020-02-01 19:45:58,147 - INFO - modeler - Modeling customers\n", - "2020-02-01 19:45:58,148 - INFO - metadata - Loading table customers\n", - "2020-02-01 19:45:58,156 - INFO - metadata - Loading transformer CategoricalTransformer for field cust_postal_code\n", - "2020-02-01 19:45:58,157 - INFO - metadata - Loading transformer NumericalTransformer for field phone_number1\n", - "2020-02-01 19:45:58,158 - INFO - metadata - Loading transformer NumericalTransformer for field credit_limit\n", - "2020-02-01 19:45:58,158 - INFO - metadata - Loading transformer CategoricalTransformer for field country\n", - "2020-02-01 19:45:58,174 - INFO - modeler - Modeling orders\n", - "2020-02-01 19:45:58,175 - INFO - metadata - Loading table orders\n", - "2020-02-01 19:45:58,179 - INFO - metadata - Loading transformer NumericalTransformer for field order_total\n", - "2020-02-01 19:45:58,183 - INFO - modeler - Modeling order_items\n", - "2020-02-01 19:45:58,183 - INFO - metadata - Loading table order_items\n", - "2020-02-01 19:45:58,187 - INFO - metadata - Loading transformer CategoricalTransformer for field product_id\n", - "2020-02-01 19:45:58,187 - INFO - metadata - Loading transformer NumericalTransformer for field unit_price\n", - "2020-02-01 19:45:58,187 - INFO - metadata - Loading transformer NumericalTransformer for field quantity\n", - "2020-02-01 19:45:58,824 - INFO - modeler - Modeling Complete\n" + "2020-07-02 14:25:44,066 - INFO - modeler - Modeling customers\n", + "2020-07-02 14:25:44,067 - INFO - metadata - Loading table customers\n", + "2020-07-02 14:25:44,074 - INFO - metadata - Loading transformer CategoricalTransformer for field cust_postal_code\n", + "2020-07-02 14:25:44,074 - INFO - metadata - Loading transformer NumericalTransformer for field phone_number1\n", + "2020-07-02 14:25:44,075 - INFO - metadata - Loading transformer NumericalTransformer for field credit_limit\n", + "2020-07-02 14:25:44,076 - INFO - metadata - Loading transformer CategoricalTransformer for field country\n", + "2020-07-02 14:25:44,094 - INFO - modeler - Modeling orders\n", + "2020-07-02 14:25:44,095 - INFO - metadata - Loading table orders\n", + "2020-07-02 14:25:44,098 - INFO - metadata - Loading transformer NumericalTransformer for field order_total\n", + "2020-07-02 14:25:44,101 - INFO - modeler - Modeling order_items\n", + "2020-07-02 14:25:44,102 - INFO - metadata - Loading table order_items\n", + "2020-07-02 14:25:44,106 - INFO - metadata - Loading transformer CategoricalTransformer for field product_id\n", + "2020-07-02 14:25:44,107 - INFO - metadata - Loading transformer NumericalTransformer for field unit_price\n", + "2020-07-02 14:25:44,108 - INFO - metadata - Loading transformer NumericalTransformer for field quantity\n", + "2020-07-02 14:25:44,120 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:25:44,131 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:25:44,138 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:25:44,147 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:25:44,155 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:25:44,164 - INFO - gaussian - Fitting 
GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:25:44,171 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:25:44,177 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:25:44,183 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:25:44,189 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:25:44,196 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:25:44,210 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:25:44,231 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:25:44,258 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:25:44,281 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:25:44,303 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:25:44,370 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:25:44,496 - INFO - modeler - Modeling Complete\n" ] } ], @@ -45,9 +62,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "2020-02-01 19:45:58,874 - INFO - metadata - Loading table customers\n", - "2020-02-01 19:45:58,878 - INFO - metadata - Loading table orders\n", - "2020-02-01 19:45:58,881 - INFO - metadata - Loading table order_items\n" + "2020-07-02 14:25:46,641 - INFO - metadata - Loading table customers\n", + "2020-07-02 14:25:46,646 - INFO - metadata - Loading table orders\n", + "2020-07-02 14:25:46,649 - INFO - metadata - Loading table order_items\n" ] } ], @@ -97,41 +114,41 @@ " 0\n", " 0\n", " 63145\n", - " 5941487535\n", - " 1694\n", - " UK\n", + " 7811758514\n", + " 510\n", + " US\n", " \n", " \n", " 1\n", " 1\n", - " 11371\n", - " 8369261990\n", - " 501\n", - " UK\n", + " 20166\n", + " 6720163478\n", + " 317\n", + " SPAIN\n", " \n", " \n", " 2\n", " 2\n", " 11371\n", - " 6106157274\n", - " 1287\n", - " CANADA\n", + " 7447861719\n", + " 1371\n", + " SPAIN\n", " \n", " \n", " 3\n", " 3\n", - " 6096\n", - " 4243270462\n", - " 1584\n", - " UK\n", + " 63145\n", + " 6835026704\n", + " 1024\n", + " US\n", " \n", " \n", " 4\n", " 4\n", - " 6096\n", - " 5705933008\n", - " 584\n", - " UK\n", + " 11371\n", + " 5371158268\n", + " 739\n", + " FRANCE\n", " \n", " \n", "\n", @@ -139,11 +156,11 @@ ], "text/plain": [ " customer_id cust_postal_code phone_number1 credit_limit country\n", - "0 0 63145 5941487535 1694 UK\n", - "1 1 11371 8369261990 501 UK\n", - "2 2 11371 6106157274 1287 CANADA\n", - "3 3 6096 4243270462 1584 UK\n", - "4 4 6096 5705933008 584 UK" + "0 0 63145 7811758514 510 US\n", + "1 1 20166 6720163478 317 SPAIN\n", + "2 2 11371 7447861719 1371 SPAIN\n", + "3 3 63145 6835026704 1024 US\n", + "4 4 11371 5371158268 739 FRANCE" ] }, "execution_count": 3, @@ -286,32 +303,32 @@ " \n", " 0\n", " 0\n", - " 1\n", - " 581\n", + " 0\n", + " 1784\n", " \n", " \n", " 1\n", " 1\n", " 1\n", - " 707\n", + " 1864\n", " \n", " \n", " 2\n", " 2\n", - " 1\n", - " 878\n", + " 4\n", + " 2677\n", " \n", " \n", " 3\n", " 3\n", - " 2\n", - " 285\n", + " 1\n", + " 1485\n", " \n", " \n", " 4\n", " 4\n", - " 3\n", - " 840\n", + " 1\n", + " 2435\n", " \n", " 
\n", "\n", @@ -319,11 +336,11 @@ ], "text/plain": [ " order_id customer_id order_total\n", - "0 0 1 581\n", - "1 1 1 707\n", - "2 2 1 878\n", - "3 3 2 285\n", - "4 4 3 840" + "0 0 0 1784\n", + "1 1 1 1864\n", + "2 2 4 2677\n", + "3 3 1 1485\n", + "4 4 1 2435" ] }, "execution_count": 5, @@ -456,41 +473,41 @@ " \n", " 0\n", " 0\n", - " 0\n", + " 2\n", " 6\n", - " 62\n", - " 3\n", + " 208\n", + " 7\n", " \n", " \n", " 1\n", " 1\n", - " 0\n", + " 7\n", " 10\n", - " 108\n", - " 5\n", + " 116\n", + " 2\n", " \n", " \n", " 2\n", " 2\n", " 0\n", - " 6\n", + " 10\n", " 57\n", - " 3\n", + " 6\n", " \n", " \n", " 3\n", " 3\n", - " 1\n", - " 10\n", - " 114\n", + " 3\n", " 6\n", + " -24\n", + " 1\n", " \n", " \n", " 4\n", " 4\n", - " 1\n", - " 10\n", - " 81\n", + " 3\n", + " 6\n", + " 38\n", " 4\n", " \n", " \n", @@ -499,11 +516,11 @@ ], "text/plain": [ " order_item_id order_id product_id unit_price quantity\n", - "0 0 0 6 62 3\n", - "1 1 0 10 108 5\n", - "2 2 0 6 57 3\n", - "3 3 1 10 114 6\n", - "4 4 1 10 81 4" + "0 0 2 6 208 7\n", + "1 1 7 10 116 2\n", + "2 2 0 10 57 6\n", + "3 3 3 6 -24 1\n", + "4 4 3 6 38 4" ] }, "execution_count": 7, diff --git a/examples/4. Anonymization.ipynb b/examples/4. Anonymization.ipynb index 65382e218..5ff4caa51 100644 --- a/examples/4. Anonymization.ipynb +++ b/examples/4. Anonymization.ipynb @@ -185,13 +185,15 @@ "name": "stderr", "output_type": "stream", "text": [ - "2020-02-01 19:46:01,587 - INFO - modeler - Modeling anonymized\n", - "2020-02-01 19:46:01,589 - INFO - metadata - Loading transformer CategoricalTransformer for field name\n", - "2020-02-01 19:46:01,589 - INFO - metadata - Loading transformer CategoricalTransformer for field credit_card_number\n", - "2020-02-01 19:46:01,640 - INFO - modeler - Modeling normal\n", - "2020-02-01 19:46:01,641 - INFO - metadata - Loading transformer CategoricalTransformer for field name\n", - "2020-02-01 19:46:01,641 - INFO - metadata - Loading transformer CategoricalTransformer for field credit_card_number\n", - "2020-02-01 19:46:01,658 - INFO - modeler - Modeling Complete\n" + "2020-07-02 14:26:26,044 - INFO - modeler - Modeling anonymized\n", + "2020-07-02 14:26:26,044 - INFO - metadata - Loading transformer CategoricalTransformer for field name\n", + "2020-07-02 14:26:26,045 - INFO - metadata - Loading transformer CategoricalTransformer for field credit_card_number\n", + "2020-07-02 14:26:26,087 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:26:26,092 - INFO - modeler - Modeling normal\n", + "2020-07-02 14:26:26,092 - INFO - metadata - Loading transformer CategoricalTransformer for field name\n", + "2020-07-02 14:26:26,093 - INFO - metadata - Loading transformer CategoricalTransformer for field credit_card_number\n", + "2020-07-02 14:26:26,109 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:26:26,113 - INFO - modeler - Modeling Complete\n" ] } ], @@ -250,44 +252,37 @@ " \n", " 0\n", " 0\n", - " Brandon\n", - " 4134540887507040\n", + " Pamela\n", + " 4801288395665668\n", " \n", " \n", " 1\n", " 1\n", - " Larry\n", - " 4288700396707168\n", + " Kimberly\n", + " 4801288395665668\n", " \n", " \n", " 2\n", " 2\n", - " Brandon\n", - " 4869967025488786\n", + " Pamela\n", + " 4801288395665668\n", " \n", " \n", " 3\n", " 3\n", - " Brandon\n", - " 4288700396707168\n", - " \n", - " \n", - " 4\n", - " 4\n", - " Larry\n", - " 4869967025488786\n", + " Kimberly\n", + " 4592405566223480\n", " \n", " \n", "\n", "" ], 
"text/plain": [ - " index name credit_card_number\n", - "0 0 Brandon 4134540887507040\n", - "1 1 Larry 4288700396707168\n", - "2 2 Brandon 4869967025488786\n", - "3 3 Brandon 4288700396707168\n", - "4 4 Larry 4869967025488786" + " index name credit_card_number\n", + "0 0 Pamela 4801288395665668\n", + "1 1 Kimberly 4801288395665668\n", + "2 2 Pamela 4801288395665668\n", + "3 3 Kimberly 4592405566223480" ] }, "execution_count": 7, @@ -337,19 +332,19 @@ " 0\n", " 0\n", " Joe\n", - " 1111222233334444\n", + " 8888888888888888\n", " \n", " \n", " 1\n", " 1\n", - " Bill\n", + " Joe\n", " 0000000000000000\n", " \n", " \n", " 2\n", " 2\n", - " Jeff\n", - " 1111222233334444\n", + " Joe\n", + " 8888888888888888\n", " \n", " \n", " 3\n", @@ -357,23 +352,16 @@ " Bill\n", " 8888888888888888\n", " \n", - " \n", - " 4\n", - " 4\n", - " Jeff\n", - " 0000000000000000\n", - " \n", " \n", "\n", "" ], "text/plain": [ " index name credit_card_number\n", - "0 0 Joe 1111222233334444\n", - "1 1 Bill 0000000000000000\n", - "2 2 Jeff 1111222233334444\n", - "3 3 Bill 8888888888888888\n", - "4 4 Jeff 0000000000000000" + "0 0 Joe 8888888888888888\n", + "1 1 Joe 0000000000000000\n", + "2 2 Joe 8888888888888888\n", + "3 3 Bill 8888888888888888" ] }, "execution_count": 8, diff --git a/examples/5. Generate Metadata from Dataframes.ipynb b/examples/5. Generate Metadata from Dataframes.ipynb index 5768ce084..91b001036 100644 --- a/examples/5. Generate Metadata from Dataframes.ipynb +++ b/examples/5. Generate Metadata from Dataframes.ipynb @@ -163,27 +163,26 @@ { "data": { "text/plain": [ - "{'tables': {'users': {'fields': {'user_id': {'type': 'id',\n", - " 'subtype': 'integer'},\n", - " 'gender': {'type': 'categorical'},\n", - " 'age': {'type': 'numerical', 'subtype': 'integer'},\n", - " 'country': {'type': 'categorical'}},\n", + "{'tables': {'users': {'fields': {'gender': {'type': 'categorical'},\n", + " 'user_id': {'type': 'id', 'subtype': 'integer'},\n", + " 'country': {'type': 'categorical'},\n", + " 'age': {'type': 'numerical', 'subtype': 'integer'}},\n", " 'primary_key': 'user_id'},\n", - " 'sessions': {'fields': {'user_id': {'type': 'id',\n", + " 'sessions': {'fields': {'session_id': {'type': 'id', 'subtype': 'integer'},\n", + " 'user_id': {'type': 'id',\n", " 'subtype': 'integer',\n", " 'ref': {'table': 'users', 'field': 'user_id'}},\n", " 'os': {'type': 'categorical'},\n", - " 'device': {'type': 'categorical'},\n", - " 'session_id': {'type': 'id', 'subtype': 'integer'}},\n", + " 'device': {'type': 'categorical'}},\n", " 'primary_key': 'session_id'},\n", " 'transactions': {'fields': {'timestamp': {'type': 'datetime',\n", " 'format': '%Y-%m-%d'},\n", - " 'approved': {'type': 'boolean'},\n", - " 'amount': {'type': 'numerical', 'subtype': 'float'},\n", - " 'transaction_id': {'type': 'id', 'subtype': 'integer'},\n", " 'session_id': {'type': 'id',\n", " 'subtype': 'integer',\n", - " 'ref': {'table': 'sessions', 'field': 'session_id'}}},\n", + " 'ref': {'table': 'sessions', 'field': 'session_id'}},\n", + " 'transaction_id': {'type': 'id', 'subtype': 'integer'},\n", + " 'approved': {'type': 'boolean'},\n", + " 'amount': {'type': 'numerical', 'subtype': 'float'}},\n", " 'primary_key': 'transaction_id'}}}" ] }, @@ -333,7 +332,7 @@ "\n" ], "text/plain": [ - "" + "" ] }, "execution_count": 14, @@ -370,10 +369,10 @@ "\n", "users\n", "\n", - "user_id : id - integer\n", - "gender : categorical\n", - "age : numerical - integer\n", - "country : categorical\n", + "gender : categorical\n", + "user_id : id - integer\n", + "country : 
categorical\n", + "age : numerical - integer\n", "\n", "Primary key: user_id\n", "\n", @@ -383,10 +382,10 @@ "\n", "sessions\n", "\n", - "user_id : id - integer\n", - "os : categorical\n", - "device : categorical\n", - "session_id : id - integer\n", + "session_id : id - integer\n", + "user_id : id - integer\n", + "os : categorical\n", + "device : categorical\n", "\n", "Primary key: session_id\n", "Foreign key (users): user_id\n", @@ -405,10 +404,10 @@ "transactions\n", "\n", "timestamp : datetime\n", - "approved : boolean\n", - "amount : numerical - float\n", - "transaction_id : id - integer\n", - "session_id : id - integer\n", + "session_id : id - integer\n", + "transaction_id : id - integer\n", + "approved : boolean\n", + "amount : numerical - float\n", "\n", "Primary key: transaction_id\n", "Foreign key (sessions): session_id\n", @@ -424,7 +423,7 @@ "\n" ], "text/plain": [ - "" + "" ] }, "execution_count": 15, diff --git a/examples/Demo - Walmart.ipynb b/examples/Demo - Walmart.ipynb index a4015bb0f..d3bbe80e9 100644 --- a/examples/Demo - Walmart.ipynb +++ b/examples/Demo - Walmart.ipynb @@ -100,9 +100,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "2020-02-01 19:46:10,378 - INFO - metadata - Loading table stores\n", - "2020-02-01 19:46:10,386 - INFO - metadata - Loading table features\n", - "2020-02-01 19:46:10,401 - INFO - metadata - Loading table depts\n" + "2020-07-02 14:27:25,479 - INFO - metadata - Loading table stores\n", + "2020-07-02 14:27:25,486 - INFO - metadata - Loading table features\n", + "2020-07-02 14:27:25,502 - INFO - metadata - Loading table depts\n" ] } ], @@ -213,7 +213,7 @@ "\n" ], "text/plain": [ - "" + "" ] }, "execution_count": 2, @@ -253,35 +253,136 @@ { "cell_type": "code", "execution_count": 4, - "metadata": {}, + "metadata": { + "scrolled": false + }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "2020-02-01 19:46:10,769 - INFO - modeler - Modeling stores\n", - "2020-02-01 19:46:10,769 - INFO - metadata - Loading transformer CategoricalTransformer for field Type\n", - "2020-02-01 19:46:10,770 - INFO - metadata - Loading transformer NumericalTransformer for field Size\n", - "2020-02-01 19:46:10,781 - INFO - modeler - Modeling features\n", - "2020-02-01 19:46:10,781 - INFO - metadata - Loading transformer DatetimeTransformer for field Date\n", - "2020-02-01 19:46:10,782 - INFO - metadata - Loading transformer NumericalTransformer for field MarkDown1\n", - "2020-02-01 19:46:10,782 - INFO - metadata - Loading transformer BooleanTransformer for field IsHoliday\n", - "2020-02-01 19:46:10,783 - INFO - metadata - Loading transformer NumericalTransformer for field MarkDown4\n", - "2020-02-01 19:46:10,783 - INFO - metadata - Loading transformer NumericalTransformer for field MarkDown3\n", - "2020-02-01 19:46:10,784 - INFO - metadata - Loading transformer NumericalTransformer for field Fuel_Price\n", - "2020-02-01 19:46:10,785 - INFO - metadata - Loading transformer NumericalTransformer for field Unemployment\n", - "2020-02-01 19:46:10,785 - INFO - metadata - Loading transformer NumericalTransformer for field Temperature\n", - "2020-02-01 19:46:10,786 - INFO - metadata - Loading transformer NumericalTransformer for field MarkDown5\n", - "2020-02-01 19:46:10,786 - INFO - metadata - Loading transformer NumericalTransformer for field MarkDown2\n", - "2020-02-01 19:46:10,787 - INFO - metadata - Loading transformer NumericalTransformer for field CPI\n", - "2020-02-01 19:46:12,119 - INFO - modeler - Modeling depts\n", - "2020-02-01 
19:46:12,119 - INFO - metadata - Loading transformer DatetimeTransformer for field Date\n", - "2020-02-01 19:46:12,119 - INFO - metadata - Loading transformer NumericalTransformer for field Weekly_Sales\n", - "2020-02-01 19:46:12,120 - INFO - metadata - Loading transformer NumericalTransformer for field Dept\n", - "2020-02-01 19:46:12,120 - INFO - metadata - Loading transformer BooleanTransformer for field IsHoliday\n", - "/home/xals/.virtualenvs/SDV.merge/lib/python3.6/site-packages/numpy/core/fromnumeric.py:56: FutureWarning: Series.nonzero() is deprecated and will be removed in a future version.Use Series.to_numpy().nonzero() instead\n", + "2020-07-02 14:27:31,883 - INFO - modeler - Modeling stores\n", + "2020-07-02 14:27:31,883 - INFO - metadata - Loading transformer CategoricalTransformer for field Type\n", + "2020-07-02 14:27:31,884 - INFO - metadata - Loading transformer NumericalTransformer for field Size\n", + "2020-07-02 14:27:31,895 - INFO - modeler - Modeling depts\n", + "2020-07-02 14:27:31,895 - INFO - metadata - Loading transformer DatetimeTransformer for field Date\n", + "2020-07-02 14:27:31,896 - INFO - metadata - Loading transformer NumericalTransformer for field Weekly_Sales\n", + "2020-07-02 14:27:31,896 - INFO - metadata - Loading transformer NumericalTransformer for field Dept\n", + "2020-07-02 14:27:31,896 - INFO - metadata - Loading transformer BooleanTransformer for field IsHoliday\n", + "2020-07-02 14:27:32,007 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,272 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,288 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,302 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,319 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,335 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,350 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,367 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,383 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,399 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,414 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,430 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,445 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,460 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,475 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,490 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,508 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,522 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 
14:27:32,536 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,550 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,563 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,576 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,594 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,607 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,621 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,636 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,651 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,664 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,679 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,695 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,710 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,723 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,738 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,752 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,765 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,781 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,796 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,809 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,823 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,837 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,852 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,868 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,881 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,896 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,907 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,919 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:32,938 - INFO - modeler - Modeling features\n", + "2020-07-02 14:27:32,939 - INFO - metadata - Loading transformer DatetimeTransformer for field Date\n", + "2020-07-02 14:27:32,939 - INFO - metadata - Loading transformer NumericalTransformer for field MarkDown1\n", + 
"2020-07-02 14:27:32,940 - INFO - metadata - Loading transformer BooleanTransformer for field IsHoliday\n", + "2020-07-02 14:27:32,941 - INFO - metadata - Loading transformer NumericalTransformer for field MarkDown4\n", + "2020-07-02 14:27:32,941 - INFO - metadata - Loading transformer NumericalTransformer for field MarkDown3\n", + "2020-07-02 14:27:32,941 - INFO - metadata - Loading transformer NumericalTransformer for field Fuel_Price\n", + "2020-07-02 14:27:32,942 - INFO - metadata - Loading transformer NumericalTransformer for field Unemployment\n", + "2020-07-02 14:27:32,942 - INFO - metadata - Loading transformer NumericalTransformer for field Temperature\n", + "2020-07-02 14:27:32,942 - INFO - metadata - Loading transformer NumericalTransformer for field MarkDown5\n", + "2020-07-02 14:27:32,942 - INFO - metadata - Loading transformer NumericalTransformer for field MarkDown2\n", + "2020-07-02 14:27:32,943 - INFO - metadata - Loading transformer NumericalTransformer for field CPI\n", + "2020-07-02 14:27:32,993 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:33,043 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:33,073 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:33,104 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:33,133 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:33,189 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:33,220 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:33,253 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:33,285 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:33,314 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:33,342 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:33,371 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-07-02 14:27:33,399 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:33,425 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:33,452 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:33,479 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:33,504 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:33,531 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:33,558 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:33,592 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:33,624 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 
14:27:33,653 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:33,683 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:33,713 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:33,743 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:33,770 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:33,797 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:33,826 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:33,852 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:33,880 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:33,910 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:33,939 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:33,967 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:33,996 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:34,025 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:34,051 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:34,078 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:34,107 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:34,134 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:34,160 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:34,187 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:34,216 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:34,246 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:34,275 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:34,303 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:34,333 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 14:27:34,452 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/core/fromnumeric.py:56: FutureWarning: Series.nonzero() is deprecated and will be removed in a future version.Use Series.to_numpy().nonzero() instead\n", " return getattr(obj, method)(*args, **kwds)\n", - "2020-02-01 19:46:13,498 - INFO - modeler - Modeling Complete\n" + "2020-07-02 14:27:34,680 - INFO - modeler - Modeling Complete\n" ] } ], @@ -314,7 +415,31 @@ "cell_type": "code", "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + 
"ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0msamples\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msdv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample_all\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/Projects/MIT/SDV/sdv/sdv.py\u001b[0m in \u001b[0;36msample_all\u001b[0;34m(self, num_rows, reset_primary_keys)\u001b[0m\n\u001b[1;32m 124\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mNotFittedError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'SDV instance has not been fitted'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 125\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 126\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msampler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample_all\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnum_rows\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreset_primary_keys\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mreset_primary_keys\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 127\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 128\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0msave\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Projects/MIT/SDV/sdv/sampler.py\u001b[0m in \u001b[0;36msample_all\u001b[0;34m(self, num_rows, reset_primary_keys)\u001b[0m\n\u001b[1;32m 331\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mtable\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetadata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_tables\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 332\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetadata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_parents\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtable\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 333\u001b[0;31m \u001b[0msampled_data\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_rows\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 334\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 335\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0msampled_data\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Projects/MIT/SDV/sdv/sampler.py\u001b[0m in \u001b[0;36msample\u001b[0;34m(self, table_name, num_rows, reset_primary_keys, sample_children, sample_parents)\u001b[0m\n\u001b[1;32m 298\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 299\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_sample_children\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtable_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msampled_data\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 300\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_finalize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msampled_data\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 301\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 302\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Projects/MIT/SDV/sdv/sampler.py\u001b[0m in \u001b[0;36m_finalize\u001b[0;34m(self, sampled_data)\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mparents\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mparent_name\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mparents\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 60\u001b[0;31m \u001b[0mparent_ids\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_find_parent_ids\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtable_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparent_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msampled_data\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 61\u001b[0m \u001b[0mforeign_key\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetadata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_foreign_key\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparent_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtable_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[0mtable_rows\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mforeign_key\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mparent_ids\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Projects/MIT/SDV/sdv/sampler.py\u001b[0m in \u001b[0;36m_find_parent_ids\u001b[0;34m(self, table_name, parent_name, sampled_data)\u001b[0m\n\u001b[1;32m 252\u001b[0m \u001b[0mparent_ids\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 253\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrow\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtable_rows\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miterrows\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 254\u001b[0;31m \u001b[0mparent_ids\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_find_parent_id\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrow\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpdfs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_rows\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 255\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 256\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mparent_ids\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + 
"\u001b[0;32m~/Projects/MIT/SDV/sdv/sampler.py\u001b[0m in \u001b[0;36m_find_parent_id\u001b[0;34m(self, row, pdfs, num_rows)\u001b[0m\n\u001b[1;32m 212\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mparent_id\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpdf\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mpdfs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 213\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 214\u001b[0;31m \u001b[0mlikelihoods\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mparent_id\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpdf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrow\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0.0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 215\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinalg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLinAlgError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 216\u001b[0m \u001b[0;31m# Singular matrix\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.virtualenvs/SDV/lib/python3.6/site-packages/copulas/multivariate/gaussian.py\u001b[0m in \u001b[0;36mprobability_density\u001b[0;34m(self, X)\u001b[0m\n\u001b[1;32m 133\u001b[0m \"\"\"\n\u001b[1;32m 134\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcheck_fit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 135\u001b[0;31m \u001b[0mtransformed\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_transform_to_normal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 136\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mstats\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmultivariate_normal\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpdf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtransformed\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcov\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcovariance\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.virtualenvs/SDV/lib/python3.6/site-packages/copulas/multivariate/gaussian.py\u001b[0m in \u001b[0;36m_transform_to_normal\u001b[0;34m(self, X)\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mcolumn_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0munivariate\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcolumns\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munivariates\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0mcolumn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mcolumn_name\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 58\u001b[0;31m 
\u001b[0mU\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0munivariate\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcdf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcolumn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mEPSILON\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mEPSILON\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 59\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mstats\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mppf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcolumn_stack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mU\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.virtualenvs/SDV/lib/python3.6/site-packages/copulas/univariate/base.py\u001b[0m in \u001b[0;36mcdf\u001b[0;34m(self, X)\u001b[0m\n\u001b[1;32m 306\u001b[0m \u001b[0mCumulative\u001b[0m \u001b[0mdistribution\u001b[0m \u001b[0mvalues\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mpoints\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 307\u001b[0m \"\"\"\n\u001b[0;32m--> 308\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcumulative_distribution\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 309\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 310\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpercent_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mU\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.virtualenvs/SDV/lib/python3.6/site-packages/copulas/univariate/base.py\u001b[0m in \u001b[0;36mcumulative_distribution\u001b[0;34m(self, X)\u001b[0m\n\u001b[1;32m 523\u001b[0m \"\"\"\n\u001b[1;32m 524\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcheck_fit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 525\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_model\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcdf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 526\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 527\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpercent_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mU\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.virtualenvs/SDV/lib/python3.6/site-packages/scipy/stats/_distn_infrastructure.py\u001b[0m in \u001b[0;36mcdf\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 452\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 453\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcdf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 454\u001b[0;31m 
\u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdist\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcdf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 455\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 456\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mlogcdf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.virtualenvs/SDV/lib/python3.6/site-packages/scipy/stats/_distn_infrastructure.py\u001b[0m in \u001b[0;36mcdf\u001b[0;34m(self, x, *args, **kwds)\u001b[0m\n\u001b[1;32m 1747\u001b[0m \u001b[0mcond1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_open_support_mask\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m&\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mscale\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1748\u001b[0m \u001b[0mcond2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m \u001b[0;34m>=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m&\u001b[0m \u001b[0mcond0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1749\u001b[0;31m \u001b[0mcond\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcond0\u001b[0m \u001b[0;34m&\u001b[0m \u001b[0mcond1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1750\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mzeros\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcond\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtyp\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1751\u001b[0m \u001b[0mplace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mcond0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misnan\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbadvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], "source": [ "samples = sdv.sample_all()" ] @@ -670,7 +795,9 @@ { "cell_type": "code", "execution_count": 9, - "metadata": {}, + "metadata": { + "scrolled": false + }, "outputs": [ { "data": { diff --git a/examples/Evaluation.ipynb b/examples/Evaluation.ipynb new file mode 100644 index 000000000..3d2d5bd7e --- /dev/null +++ b/examples/Evaluation.ipynb @@ -0,0 +1,193 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
nametablesrowscolumnsdescription
0airbnb-simplified2575140822airbnb user account information, including fir...
1rossmann2101832419rossmann stores promotional information and it...
2walmart342980420walmart stores information (size and type) wit...
\n", + "
" + ], + "text/plain": [ + " name tables rows columns \\\n", + "0 airbnb-simplified 2 5751408 22 \n", + "1 rossmann 2 1018324 19 \n", + "2 walmart 3 429804 20 \n", + "\n", + " description \n", + "0 airbnb user account information, including fir... \n", + "1 rossmann stores promotional information and it... \n", + "2 walmart stores information (size and type) wit... " + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import warnings\n", + "warnings.filterwarnings('ignore')\n", + "\n", + "from sdv.demo import get_available_demos\n", + "\n", + "get_available_demos()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-06-26 00:11:41,404 - INFO - metadata - Loading table store\n", + "2020-06-26 00:11:41,434 - INFO - metadata - Loading table historical\n", + "2020-06-26 00:11:43,265 - INFO - modeler - Modeling store\n", + "2020-06-26 00:11:43,265 - INFO - metadata - Loading transformer NumericalTransformer for field CompetitionOpenSinceYear\n", + "2020-06-26 00:11:43,266 - INFO - metadata - Loading transformer NumericalTransformer for field CompetitionOpenSinceMonth\n", + "2020-06-26 00:11:43,266 - INFO - metadata - Loading transformer NumericalTransformer for field CompetitionDistance\n", + "2020-06-26 00:11:43,267 - INFO - metadata - Loading transformer BooleanTransformer for field Promo2\n", + "2020-06-26 00:11:43,267 - INFO - metadata - Loading transformer NumericalTransformer for field Promo2SinceYear\n", + "2020-06-26 00:11:43,268 - INFO - metadata - Loading transformer NumericalTransformer for field Promo2SinceWeek\n", + "2020-06-26 00:11:43,268 - INFO - metadata - Loading transformer CategoricalTransformer for field StoreType\n", + "2020-06-26 00:11:43,269 - INFO - metadata - Loading transformer CategoricalTransformer for field Assortment\n", + "2020-06-26 00:11:43,270 - INFO - metadata - Loading transformer CategoricalTransformer for field PromoInterval\n", + "2020-06-26 00:11:43,310 - INFO - modeler - Modeling historical\n", + "2020-06-26 00:11:43,311 - INFO - metadata - Loading transformer DatetimeTransformer for field Date\n", + "2020-06-26 00:11:43,311 - INFO - metadata - Loading transformer NumericalTransformer for field DayOfWeek\n", + "2020-06-26 00:11:43,311 - INFO - metadata - Loading transformer NumericalTransformer for field Promo\n", + "2020-06-26 00:11:43,312 - INFO - metadata - Loading transformer CategoricalTransformer for field StateHoliday\n", + "2020-06-26 00:11:43,312 - INFO - metadata - Loading transformer NumericalTransformer for field Open\n", + "2020-06-26 00:11:43,312 - INFO - metadata - Loading transformer NumericalTransformer for field SchoolHoliday\n", + "2020-06-26 00:11:43,312 - INFO - metadata - Loading transformer NumericalTransformer for field Customers\n", + "2020-06-26 00:13:14,951 - INFO - modeler - Modeling Complete\n" + ] + }, + { + "ename": "TypeError", + "evalue": "ufunc 'isnan' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0msampled\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0msdv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample_all\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtables\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mfirst_table\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m \u001b[0mscore\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mevaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msampled\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtables\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetadata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 16\u001b[0m \u001b[0mscores\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mdemo\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mscore\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Projects/MIT/SDV/sdv/evaluation.py\u001b[0m in \u001b[0;36mevaluate\u001b[0;34m(synth, real, metadata, root_path, table_name, get_report)\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0msynth\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreal\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetadata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_validate_arguments\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msynth\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreal\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetadata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mroot_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtable_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 102\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 103\u001b[0;31m \u001b[0mreport\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msdmetrics\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmetadata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreal\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msynth\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 104\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 105\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mget_report\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Projects/MIT/SDMetrics/sdmetrics/evaluation.py\u001b[0m in \u001b[0;36mevaluate\u001b[0;34m(metadata, real, synthetic)\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0m_validate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmetadata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreal\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msynthetic\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0mreport\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mMetricsReport\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 55\u001b[0;31m \u001b[0mreport\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_metrics\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_metrics\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmetadata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreal\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msynthetic\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 56\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mreport\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Projects/MIT/SDMetrics/sdmetrics/report.py\u001b[0m in \u001b[0;36madd_metrics\u001b[0;34m(self, 
iterator)\u001b[0m\n\u001b[1;32m 94\u001b[0m \u001b[0mof\u001b[0m \u001b[0mthese\u001b[0m \u001b[0mmetrics\u001b[0m \u001b[0mto\u001b[0m \u001b[0mthis\u001b[0m \u001b[0mreport\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 95\u001b[0m \"\"\"\n\u001b[0;32m---> 96\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mmetric\u001b[0m \u001b[0;32min\u001b[0m \u001b[0miterator\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 97\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_metric\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmetric\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 98\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Projects/MIT/SDMetrics/sdmetrics/evaluation.py\u001b[0m in \u001b[0;36m_metrics\u001b[0;34m(metadata, real, synthetic)\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[0;32myield\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mconstraint\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetrics\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmetadata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreal\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msynthetic\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 36\u001b[0;31m \u001b[0;32myield\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mdetection\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetrics\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmetadata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreal\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msynthetic\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 37\u001b[0m \u001b[0;32myield\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mstatistical\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetrics\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmetadata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreal\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msynthetic\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 38\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Projects/MIT/SDMetrics/sdmetrics/detection/__init__.py\u001b[0m in \u001b[0;36mmetrics\u001b[0;34m(metadata, real_tables, synthetic_tables)\u001b[0m\n\u001b[1;32m 21\u001b[0m \"\"\"\n\u001b[1;32m 22\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mdetector\u001b[0m \u001b[0;32min\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mLogisticDetector\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 23\u001b[0;31m \u001b[0;32myield\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mdetector\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetrics\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmetadata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreal_tables\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msynthetic_tables\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/Projects/MIT/SDMetrics/sdmetrics/detection/tabular/base.py\u001b[0m in \u001b[0;36mmetrics\u001b[0;34m(self, metadata, real_tables, synthetic_tables)\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0mMetric\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mThe\u001b[0m \u001b[0mnext\u001b[0m \u001b[0mmetric\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 49\u001b[0m \"\"\"\n\u001b[0;32m---> 50\u001b[0;31m 
\u001b[0;32myield\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_single_table_detection\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmetadata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreal_tables\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msynthetic_tables\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 51\u001b[0m \u001b[0;32myield\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_parent_child_detection\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmetadata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreal_tables\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msynthetic_tables\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Projects/MIT/SDMetrics/sdmetrics/detection/tabular/base.py\u001b[0m in \u001b[0;36m_single_table_detection\u001b[0;34m(self, metadata, real_tables, synthetic_tables)\u001b[0m\n\u001b[1;32m 57\u001b[0m auroc = self._compute_auroc(\n\u001b[1;32m 58\u001b[0m \u001b[0mreal_tables\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtable_name\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtable_fields\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 59\u001b[0;31m synthetic_tables[table_name][table_fields])\n\u001b[0m\u001b[1;32m 60\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 61\u001b[0m yield Metric(\n", + "\u001b[0;32m~/Projects/MIT/SDMetrics/sdmetrics/detection/tabular/base.py\u001b[0m in \u001b[0;36m_compute_auroc\u001b[0;34m(self, real_table, synthetic_table)\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[0mX\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mreal_table\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msynthetic_table\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhstack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mones\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreal_table\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzeros\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msynthetic_table\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 117\u001b[0;31m \u001b[0mX\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misnan\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0.0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 118\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0;36m20\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mTypeError\u001b[0m: ufunc 'isnan' not supported for the input types, and the inputs could not be 
safely coerced to any supported types according to the casting rule ''safe''" + ] + } + ], + "source": [ + "from sdv import SDV\n", + "from sdv.demo import load_demo\n", + "from sdv.evaluation import evaluate\n", + "\n", + "scores = dict()\n", + "for demo in get_available_demos().name[1:]:\n", + " metadata, tables = load_demo(demo, metadata=True)\n", + " \n", + " sdv = SDV()\n", + " sdv.fit(metadata, tables)\n", + " \n", + " first_table = metadata.get_tables()[0]\n", + " sampled = sdv.sample_all(len(tables[first_table]))\n", + " \n", + " score = evaluate(sampled, tables, metadata)\n", + " scores[demo] = score" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "scores" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/demo_metadata.json b/examples/demo_metadata.json index a8b345822..bb24a5963 100644 --- a/examples/demo_metadata.json +++ b/examples/demo_metadata.json @@ -2,10 +2,6 @@ "tables": { "users": { "fields": { - "user_id": { - "type": "id", - "subtype": "integer" - }, "age": { "type": "numerical", "subtype": "integer" @@ -13,6 +9,10 @@ "gender": { "type": "categorical" }, + "user_id": { + "type": "id", + "subtype": "integer" + }, "country": { "type": "categorical" } @@ -21,7 +21,11 @@ }, "sessions": { "fields": { - "os": { + "session_id": { + "type": "id", + "subtype": "integer" + }, + "device": { "type": "categorical" }, "user_id": { @@ -32,12 +36,8 @@ "field": "user_id" } }, - "device": { + "os": { "type": "categorical" - }, - "session_id": { - "type": "id", - "subtype": "integer" } }, "primary_key": "session_id" @@ -52,9 +52,6 @@ "type": "id", "subtype": "integer" }, - "approved": { - "type": "boolean" - }, "session_id": { "type": "id", "subtype": "integer", @@ -66,6 +63,9 @@ "amount": { "type": "numerical", "subtype": "float" + }, + "approved": { + "type": "boolean" } }, "primary_key": "transaction_id" From 33a8bff51d7e9f800fde579b54dc312bae7e671f Mon Sep 17 00:00:00 2001 From: Carles Sala Date: Thu, 2 Jul 2020 15:51:13 +0200 Subject: [PATCH 05/33] Update tests --- examples/Evaluation.ipynb | 193 ---------------------------------- sdv/metadata.py | 6 -- sdv/sampler.py | 15 +-- tests/integration/test_sdv.py | 36 +++++++ tests/models/test_copulas.py | 18 ++-- tests/test_metadata.py | 22 ---- tests/test_modeler.py | 1 + tests/test_sampler.py | 109 ++++++++++--------- tests/test_sdv.py | 2 +- 9 files changed, 114 insertions(+), 288 deletions(-) delete mode 100644 examples/Evaluation.ipynb create mode 100644 tests/integration/test_sdv.py diff --git a/examples/Evaluation.ipynb b/examples/Evaluation.ipynb deleted file mode 100644 index 3d2d5bd7e..000000000 --- a/examples/Evaluation.ipynb +++ /dev/null @@ -1,193 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
nametablesrowscolumnsdescription
0airbnb-simplified2575140822airbnb user account information, including fir...
1rossmann2101832419rossmann stores promotional information and it...
2walmart342980420walmart stores information (size and type) wit...
\n", - "
" - ], - "text/plain": [ - " name tables rows columns \\\n", - "0 airbnb-simplified 2 5751408 22 \n", - "1 rossmann 2 1018324 19 \n", - "2 walmart 3 429804 20 \n", - "\n", - " description \n", - "0 airbnb user account information, including fir... \n", - "1 rossmann stores promotional information and it... \n", - "2 walmart stores information (size and type) wit... " - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import warnings\n", - "warnings.filterwarnings('ignore')\n", - "\n", - "from sdv.demo import get_available_demos\n", - "\n", - "get_available_demos()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2020-06-26 00:11:41,404 - INFO - metadata - Loading table store\n", - "2020-06-26 00:11:41,434 - INFO - metadata - Loading table historical\n", - "2020-06-26 00:11:43,265 - INFO - modeler - Modeling store\n", - "2020-06-26 00:11:43,265 - INFO - metadata - Loading transformer NumericalTransformer for field CompetitionOpenSinceYear\n", - "2020-06-26 00:11:43,266 - INFO - metadata - Loading transformer NumericalTransformer for field CompetitionOpenSinceMonth\n", - "2020-06-26 00:11:43,266 - INFO - metadata - Loading transformer NumericalTransformer for field CompetitionDistance\n", - "2020-06-26 00:11:43,267 - INFO - metadata - Loading transformer BooleanTransformer for field Promo2\n", - "2020-06-26 00:11:43,267 - INFO - metadata - Loading transformer NumericalTransformer for field Promo2SinceYear\n", - "2020-06-26 00:11:43,268 - INFO - metadata - Loading transformer NumericalTransformer for field Promo2SinceWeek\n", - "2020-06-26 00:11:43,268 - INFO - metadata - Loading transformer CategoricalTransformer for field StoreType\n", - "2020-06-26 00:11:43,269 - INFO - metadata - Loading transformer CategoricalTransformer for field Assortment\n", - "2020-06-26 00:11:43,270 - INFO - metadata - Loading transformer CategoricalTransformer for field PromoInterval\n", - "2020-06-26 00:11:43,310 - INFO - modeler - Modeling historical\n", - "2020-06-26 00:11:43,311 - INFO - metadata - Loading transformer DatetimeTransformer for field Date\n", - "2020-06-26 00:11:43,311 - INFO - metadata - Loading transformer NumericalTransformer for field DayOfWeek\n", - "2020-06-26 00:11:43,311 - INFO - metadata - Loading transformer NumericalTransformer for field Promo\n", - "2020-06-26 00:11:43,312 - INFO - metadata - Loading transformer CategoricalTransformer for field StateHoliday\n", - "2020-06-26 00:11:43,312 - INFO - metadata - Loading transformer NumericalTransformer for field Open\n", - "2020-06-26 00:11:43,312 - INFO - metadata - Loading transformer NumericalTransformer for field SchoolHoliday\n", - "2020-06-26 00:11:43,312 - INFO - metadata - Loading transformer NumericalTransformer for field Customers\n", - "2020-06-26 00:13:14,951 - INFO - modeler - Modeling Complete\n" - ] - }, - { - "ename": "TypeError", - "evalue": "ufunc 'isnan' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0msampled\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0msdv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample_all\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtables\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mfirst_table\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m \u001b[0mscore\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mevaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msampled\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtables\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetadata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 16\u001b[0m \u001b[0mscores\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mdemo\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mscore\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Projects/MIT/SDV/sdv/evaluation.py\u001b[0m in \u001b[0;36mevaluate\u001b[0;34m(synth, real, metadata, root_path, table_name, get_report)\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0msynth\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreal\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetadata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_validate_arguments\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msynth\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreal\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetadata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mroot_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtable_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 102\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 103\u001b[0;31m \u001b[0mreport\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msdmetrics\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmetadata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreal\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msynth\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 104\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 105\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mget_report\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Projects/MIT/SDMetrics/sdmetrics/evaluation.py\u001b[0m in \u001b[0;36mevaluate\u001b[0;34m(metadata, real, synthetic)\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0m_validate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmetadata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreal\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msynthetic\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0mreport\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mMetricsReport\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 55\u001b[0;31m \u001b[0mreport\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_metrics\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_metrics\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmetadata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreal\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msynthetic\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 56\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mreport\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Projects/MIT/SDMetrics/sdmetrics/report.py\u001b[0m in \u001b[0;36madd_metrics\u001b[0;34m(self, 
iterator)\u001b[0m\n\u001b[1;32m 94\u001b[0m \u001b[0mof\u001b[0m \u001b[0mthese\u001b[0m \u001b[0mmetrics\u001b[0m \u001b[0mto\u001b[0m \u001b[0mthis\u001b[0m \u001b[0mreport\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 95\u001b[0m \"\"\"\n\u001b[0;32m---> 96\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mmetric\u001b[0m \u001b[0;32min\u001b[0m \u001b[0miterator\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 97\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_metric\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmetric\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 98\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Projects/MIT/SDMetrics/sdmetrics/evaluation.py\u001b[0m in \u001b[0;36m_metrics\u001b[0;34m(metadata, real, synthetic)\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[0;32myield\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mconstraint\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetrics\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmetadata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreal\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msynthetic\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 36\u001b[0;31m \u001b[0;32myield\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mdetection\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetrics\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmetadata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreal\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msynthetic\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 37\u001b[0m \u001b[0;32myield\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mstatistical\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetrics\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmetadata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreal\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msynthetic\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 38\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Projects/MIT/SDMetrics/sdmetrics/detection/__init__.py\u001b[0m in \u001b[0;36mmetrics\u001b[0;34m(metadata, real_tables, synthetic_tables)\u001b[0m\n\u001b[1;32m 21\u001b[0m \"\"\"\n\u001b[1;32m 22\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mdetector\u001b[0m \u001b[0;32min\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mLogisticDetector\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 23\u001b[0;31m \u001b[0;32myield\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mdetector\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetrics\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmetadata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreal_tables\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msynthetic_tables\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m~/Projects/MIT/SDMetrics/sdmetrics/detection/tabular/base.py\u001b[0m in \u001b[0;36mmetrics\u001b[0;34m(self, metadata, real_tables, synthetic_tables)\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0mMetric\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mThe\u001b[0m \u001b[0mnext\u001b[0m \u001b[0mmetric\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 49\u001b[0m \"\"\"\n\u001b[0;32m---> 50\u001b[0;31m 
\u001b[0;32myield\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_single_table_detection\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmetadata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreal_tables\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msynthetic_tables\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 51\u001b[0m \u001b[0;32myield\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_parent_child_detection\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmetadata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreal_tables\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msynthetic_tables\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Projects/MIT/SDMetrics/sdmetrics/detection/tabular/base.py\u001b[0m in \u001b[0;36m_single_table_detection\u001b[0;34m(self, metadata, real_tables, synthetic_tables)\u001b[0m\n\u001b[1;32m 57\u001b[0m auroc = self._compute_auroc(\n\u001b[1;32m 58\u001b[0m \u001b[0mreal_tables\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtable_name\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtable_fields\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 59\u001b[0;31m synthetic_tables[table_name][table_fields])\n\u001b[0m\u001b[1;32m 60\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 61\u001b[0m yield Metric(\n", - "\u001b[0;32m~/Projects/MIT/SDMetrics/sdmetrics/detection/tabular/base.py\u001b[0m in \u001b[0;36m_compute_auroc\u001b[0;34m(self, real_table, synthetic_table)\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[0mX\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mreal_table\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msynthetic_table\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhstack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mones\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreal_table\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzeros\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msynthetic_table\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 117\u001b[0;31m \u001b[0mX\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misnan\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0.0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 118\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0;36m20\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mTypeError\u001b[0m: ufunc 'isnan' not supported for the input types, and the inputs could not be 
safely coerced to any supported types according to the casting rule ''safe''" - ] - } - ], - "source": [ - "from sdv import SDV\n", - "from sdv.demo import load_demo\n", - "from sdv.evaluation import evaluate\n", - "\n", - "scores = dict()\n", - "for demo in get_available_demos().name[1:]:\n", - " metadata, tables = load_demo(demo, metadata=True)\n", - " \n", - " sdv = SDV()\n", - " sdv.fit(metadata, tables)\n", - " \n", - " first_table = metadata.get_tables()[0]\n", - " sampled = sdv.sample_all(len(tables[first_table]))\n", - " \n", - " score = evaluate(sampled, tables, metadata)\n", - " scores[demo] = score" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "scores" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.9" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/sdv/metadata.py b/sdv/metadata.py index 94f3e8e55..efa610e58 100644 --- a/sdv/metadata.py +++ b/sdv/metadata.py @@ -575,11 +575,6 @@ def _validate_circular_relationships(self, parent, children=None): for child in children: self._validate_circular_relationships(parent, self.get_children(child)) - def _validate_parents(self, table_name): - """Make sure that the table has only one parent.""" - if len(self.get_parents(table_name)) > 1: - raise MetadataError('Table {} has more than one parent.'.format(table_name)) - def validate(self, tables=None): """Validate this metadata. @@ -624,7 +619,6 @@ def validate(self, tables=None): self._validate_table(table_name, table_meta, table) self._validate_circular_relationships(table_name) - # self._validate_parents(table_name) def _check_field(self, table, field, exists=False): """Validate the existance of the table and existance (or not) of field.""" diff --git a/sdv/sampler.py b/sdv/sampler.py index 706d8c064..3aecd076d 100644 --- a/sdv/sampler.py +++ b/sdv/sampler.py @@ -38,7 +38,11 @@ def _reset_primary_keys_generators(self): self.remaining_primary_key = dict() def _finalize(self, sampled_data): - """Reverse transform synthetized data. + """Do the final touches to the generated data. + + This method reverts the previous transformations to go back + to values in the original space and also adds the parent + keys in case foreign key relationships exist between the tables. 
Args: sampled_data (dict): @@ -226,11 +230,7 @@ def _find_parent_id(self, row, pdfs, num_rows): # singular matrix rows with the mean likelihoods = likelihoods.fillna(mean) - total = likelihoods.sum() - if total == 0: - weights = np.ones(len(likelihoods)) - else: - weights = likelihoods.values / likelihoods.sum() + weights = likelihoods.values / likelihoods.sum() return np.random.choice(likelihoods.index, p=weights) @@ -245,7 +245,8 @@ def _find_parent_ids(self, table_name, parent_name, sampled_data): parent_rows = self._sample_rows(parent_model, num_parent_rows, parent_name) primary_key = self.metadata.get_primary_key(parent_name) - pdfs = self._get_pdfs(parent_rows.set_index(primary_key), table_name) + parent_rows = parent_rows.set_index(primary_key) + pdfs = self._get_pdfs(parent_rows, table_name) num_rows = parent_rows['__' + table_name + '__child_rows'].clip(0) parent_ids = list() diff --git a/tests/integration/test_sdv.py b/tests/integration/test_sdv.py new file mode 100644 index 000000000..2314aee8b --- /dev/null +++ b/tests/integration/test_sdv.py @@ -0,0 +1,36 @@ +from sdv import SDV, load_demo + + +def test_sdv(): + metadata, tables = load_demo(metadata=True) + + sdv = SDV() + sdv.fit(metadata, tables) + + # Sample all + sampled = sdv.sample_all() + + assert set(sampled.keys()) == {'users', 'sessions', 'transactions'} + assert len(sampled['users']) == 10 + + # Sample with children + sampled = sdv.sample('users', reset_primary_keys=True) + + assert set(sampled.keys()) == {'users', 'sessions', 'transactions'} + assert len(sampled['users']) == 10 + + # Sample without children + users = sdv.sample('users', sample_children=False) + + assert users.shape == tables['users'].shape + assert list(users.columns) == list(tables['users'].columns) + + sessions = sdv.sample('sessions', sample_children=False) + + assert sessions.shape == tables['sessions'].shape + assert list(sessions.columns) == list(tables['sessions'].columns) + + transactions = sdv.sample('transactions', sample_children=False) + + assert transactions.shape == tables['transactions'].shape + assert list(transactions.columns) == list(tables['transactions'].columns) diff --git a/tests/models/test_copulas.py b/tests/models/test_copulas.py index 0dd0579d1..56a14215b 100644 --- a/tests/models/test_copulas.py +++ b/tests/models/test_copulas.py @@ -26,8 +26,11 @@ def test__unflatten_gaussian_copula(self): # Run model_parameters = { - 'distribs': { - 'foo': {'std': 0.5} + 'univariates': { + 'foo': { + 'scale': 0.0, + 'loc': 5 + }, }, 'covariance': [[0.4, 0.1], [0.1]], 'distribution': 'GaussianUnivariate' @@ -36,13 +39,14 @@ def test__unflatten_gaussian_copula(self): # Asserts expected = { - 'distribs': { - 'foo': { - 'fitted': True, - 'std': 1.6487212707001282, + 'univariates': [ + { + 'scale': 1.0, + 'loc': 5, 'type': 'GaussianUnivariate' } - }, + ], + 'columns': ['foo'], 'distribution': 'GaussianUnivariate', 'covariance': [[0.4, 0.2], [0.2, 0.0]] } diff --git a/tests/test_metadata.py b/tests/test_metadata.py index a322bc7c7..137a258b8 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -218,28 +218,6 @@ def test__dict_metadata_dict(self): } assert result == expected - def test__validate_parents_no_error(self): - """Test that any error is raised with a supported structure""" - # Setup - mock = MagicMock(spec_set=Metadata) - mock.get_parents.return_value = [] - - # Run - Metadata._validate_parents(mock, 'demo') - - # Asserts - mock.get_parents.assert_called_once_with('demo') - - def 
test__validate_parents_raise_error(self): - """Test that a ValueError is raised because the bad structure""" - # Setup - mock = MagicMock(spec_set=Metadata) - mock.get_parents.return_value = ['foo', 'bar'] - - # Run - with pytest.raises(MetadataError): - Metadata._validate_parents(mock, 'demo') - @patch('sdv.metadata.Metadata._analyze_relationships') @patch('sdv.metadata.Metadata._dict_metadata') def test___init__default_metadata_dict(self, mock_meta, mock_relationships): diff --git a/tests/test_modeler.py b/tests/test_modeler.py index bf162d6c8..f701e2469 100644 --- a/tests/test_modeler.py +++ b/tests/test_modeler.py @@ -69,6 +69,7 @@ def test_cpa_with_tables_no_primary_key(self): modeler.model = Mock(spec=SDVModel) modeler.model_kwargs = dict() modeler.models = dict() + modeler.table_sizes = {'data': 5} modeler.metadata.transform.return_value = pd.DataFrame({'data': [1, 2, 3]}) modeler.metadata.get_primary_key.return_value = None diff --git a/tests/test_sampler.py b/tests/test_sampler.py index a5eddf75f..fd3e0daa0 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -15,7 +15,13 @@ def test___init__(self): """Test create a default instance of Sampler class""" # Run models = {'test': Mock()} - sampler = Sampler('test_metadata', models, SDVModel, dict()) + sampler = Sampler( + 'test_metadata', + models, + SDVModel, + {'model': 'kwargs'}, + {'table': 'sizes'} + ) # Asserts assert sampler.metadata == 'test_metadata' @@ -23,7 +29,8 @@ def test___init__(self): assert sampler.primary_key == dict() assert sampler.remaining_primary_key == dict() assert sampler.model == SDVModel - assert sampler.model_kwargs == dict() + assert sampler.model_kwargs == {'model': 'kwargs'} + assert sampler.table_sizes == {'table': 'sizes'} def test__reset_primary_keys_generators(self): """Test reset values""" @@ -39,23 +46,50 @@ def test__reset_primary_keys_generators(self): assert sampler.primary_key == dict() assert sampler.remaining_primary_key == dict() - def test__transform_synthesized_rows(self): - """Test transform synthesized rows""" + def test__finalize(self): + """Test finalize""" # Setup - metadata_reverse_transform = pd.DataFrame({'foo': [0, 1], 'bar': [2, 3], 'tar': [4, 5]}) - sampler = Mock(spec=Sampler) sampler.metadata = Mock(spec=Metadata) - sampler.metadata.reverse_transform.return_value = metadata_reverse_transform - sampler.metadata.get_fields.return_value = {'foo': 'some data', 'tar': 'some data'} + sampler.metadata.get_parents.return_value = ['b', 'c'] + + sampler.metadata.reverse_transform.side_effect = lambda x, y: y + + sampler.metadata.get_fields.return_value = { + 'a': 'some data', + 'b': 'some data', # fk + 'c': 'some data' # fk + } + + sampler._find_parent_ids.side_effect = [ + [2, 3], + [4, 5] + ] + sampler.metadata.get_foreign_key.side_effect = [ + 'b', + 'c', + ] # Run - synthesized = pd.DataFrame({'data': [1, 2, 3]}) - result = Sampler._transform_synthesized_rows(sampler, synthesized, 'test') + sampled_data = { + 'test': pd.DataFrame({ + 'a': [0, 1], # actual data + 'z': [6, 7] # not used + }) + } + result = Sampler._finalize(sampler, sampled_data) # Asserts - expected = pd.DataFrame({'foo': [0, 1], 'tar': [4, 5]}) - pd.testing.assert_frame_equal(result.sort_index(axis=1), expected.sort_index(axis=1)) + assert isinstance(result, dict) + expected = pd.DataFrame({ + 'a': [0, 1], + 'b': [2, 3], + 'c': [4, 5], + }) + pd.testing.assert_frame_equal( + result['test'].sort_index(axis=1), + expected.sort_index(axis=1) + ) def test__get_primary_keys_none(self): """Test returns a 
tuple of none when a table doesn't have a primary key""" @@ -133,7 +167,7 @@ def test__get_primary_keys_raises_value_error_remaining(self): with pytest.raises(ValueError): Sampler._get_primary_keys(sampler, 'test', 5) - def test__get_extension(self): + def test__extract_parameters(self): """Test get extension""" # Setup sampler = Mock(spec=Sampler) @@ -141,7 +175,7 @@ def test__get_extension(self): # Run parent_row = pd.Series([[0, 1], [1, 0]], index=['__foo__field', '__foo__field2']) table_name = 'foo' - result = Sampler._get_extension(sampler, parent_row, table_name) + result = Sampler._extract_parameters(sampler, parent_row, table_name) # Asserts expected = {'field': [0, 1], 'field2': [1, 0]} @@ -191,14 +225,14 @@ def test__sample_children(self): ['child C', 'test', pd.Series([22], index=['field'], name=1), sampled], ['child C', 'test', pd.Series([33], index=['field'], name=2), sampled], ] - actual_calls = sampler._sample_table.call_args_list + actual_calls = sampler._sample_child_rows.call_args_list for result_call, expected_call in zip(actual_calls, expected_calls): assert result_call[0][0] == expected_call[0] assert result_call[0][1] == expected_call[1] assert result_call[0][3] == expected_call[3] pd.testing.assert_series_equal(result_call[0][2], expected_call[2]) - def test__sample_table_sampled_empty(self): + def test__sample_child_rows_sampled_empty(self): """Test sample table when sampled is still an empty dict.""" # Setup model = Mock(spec=SDVModel) @@ -208,7 +242,7 @@ def test__sample_table_sampled_empty(self): sampler.model = model sampler.model_kwargs = dict() - sampler._get_extension.return_value = {'child_rows': 5} + sampler._extract_parameters.return_value = {'child_rows': 5} table_model_mock = Mock() sampler.models = {'test': table_model_mock} @@ -222,10 +256,10 @@ def test__sample_table_sampled_empty(self): # Run parent_row = pd.Series({'id': 0}) sampled = dict() - Sampler._sample_table(sampler, 'test', 'parent', parent_row, sampled) + Sampler._sample_child_rows(sampler, 'test', 'parent', parent_row, sampled) # Asserts - sampler._get_extension.assert_called_once_with(parent_row, 'test') + sampler._extract_parameters.assert_called_once_with(parent_row, 'test') sampler._sample_rows.assert_called_once_with(model, 5, 'test') assert sampler._sample_children.call_count == 1 @@ -240,7 +274,7 @@ def test__sample_table_sampled_empty(self): expected_sampled ) - def test__sample_table_sampled_not_empty(self): + def test__sample_child_rows_sampled_not_empty(self): """Test sample table when sampled previous sampled rows exist.""" # Setup model = Mock(spec=SDVModel) @@ -249,7 +283,7 @@ def test__sample_table_sampled_not_empty(self): sampler = Mock(spec=Sampler) sampler.model = model sampler.model_kwargs = dict() - sampler._get_extension.return_value = {'child_rows': 5} + sampler._extract_parameters.return_value = {'child_rows': 5} table_model_mock = Mock() sampler.models = {'test': table_model_mock} @@ -270,11 +304,10 @@ def test__sample_table_sampled_not_empty(self): 'parent_id': [0, 0, 0, 0, 0] }) } - Sampler._sample_table(sampler, 'test', 'parent', parent_row, sampled) + Sampler._sample_child_rows(sampler, 'test', 'parent', parent_row, sampled) # Asserts - sampler._get_extension.assert_called_once_with(parent_row, 'test') - # sampler._get_model.assert_called_once_with({'child_rows': 5}, table_model_mock) + sampler._extract_parameters.assert_called_once_with(parent_row, 'test') sampler._sample_rows.assert_called_once_with(model, 5, 'test') assert sampler._sample_children.call_count 
== 1 @@ -308,31 +341,3 @@ def sample_side_effect(table, num_rows): assert sampler._reset_primary_keys_generators.call_count == 1 pd.testing.assert_frame_equal(result['table a'], pd.DataFrame({'foo': range(3)})) pd.testing.assert_frame_equal(result['table c'], pd.DataFrame({'foo': range(3)})) - - def test_sample_table_with_parents(self): - """Test sample table with parents.""" - sampler = Mock(spec=Sampler) - sampler.metadata = Mock(spec=Metadata) - sampler.metadata.get_parents.return_value = ['test_parent'] - sampler.metadata.get_foreign_key.return_value = 'id' - sampler.models = {'test': 'some model'} - sampler._get_primary_keys.return_value = None, pd.Series({'id': 0}) - sampler._sample_rows.return_value = pd.DataFrame({'id': [0, 1]}) - - Sampler.sample(sampler, 'test', 5) - sampler.metadata.get_parents.assert_called_once_with('test') - sampler.metadata.get_foreign_key.assert_called_once_with('test_parent', 'test') - - def test_sample_no_sample_children(self): - """Test sample no sample children""" - # Setup - sampler = Mock(spec=Sampler) - sampler.models = {'test': 'model'} - sampler.metadata.get_parents.return_value = None - - # Run - Sampler.sample(sampler, 'test', 5, sample_children=False) - sampler._transform_synthesized_rows.assert_called_once_with( - sampler._sample_rows.return_value, - 'test' - ) diff --git a/tests/test_sdv.py b/tests/test_sdv.py index f1373129e..21c61f7ef 100644 --- a/tests/test_sdv.py +++ b/tests/test_sdv.py @@ -86,7 +86,7 @@ def test_sample_all_fitted(self): # Asserts assert result == 'test' - sdv.sampler.sample_all.assert_called_once_with(5, reset_primary_keys=False) + sdv.sampler.sample_all.assert_called_once_with(None, reset_primary_keys=False) def test_sample_all_not_fitted(self): """Check that the sample_all raise an exception when is not fitted.""" From 92da737989324ffd544c9897891f4656dde02fd8 Mon Sep 17 00:00:00 2001 From: Carles Sala Date: Thu, 2 Jul 2020 16:52:38 +0200 Subject: [PATCH 06/33] Fix test on py35 --- tests/integration/test_sdv.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/integration/test_sdv.py b/tests/integration/test_sdv.py index 2314aee8b..be9b40325 100644 --- a/tests/integration/test_sdv.py +++ b/tests/integration/test_sdv.py @@ -23,14 +23,14 @@ def test_sdv(): users = sdv.sample('users', sample_children=False) assert users.shape == tables['users'].shape - assert list(users.columns) == list(tables['users'].columns) + assert set(users.columns) == set(tables['users'].columns) sessions = sdv.sample('sessions', sample_children=False) assert sessions.shape == tables['sessions'].shape - assert list(sessions.columns) == list(tables['sessions'].columns) + assert set(sessions.columns) == set(tables['sessions'].columns) transactions = sdv.sample('transactions', sample_children=False) assert transactions.shape == tables['transactions'].shape - assert list(transactions.columns) == list(tables['transactions'].columns) + assert set(transactions.columns) == set(tables['transactions'].columns) From 5b041f465d30ce9bf69603cd50ea917ae581f309 Mon Sep 17 00:00:00 2001 From: Carles Sala Date: Thu, 2 Jul 2020 18:56:37 +0200 Subject: [PATCH 07/33] Vectorize and search for parents only if necessary --- sdv/sampler.py | 38 ++++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/sdv/sampler.py b/sdv/sampler.py index 3aecd076d..2dd318e7d 100644 --- a/sdv/sampler.py +++ b/sdv/sampler.py @@ -57,9 +57,10 @@ def _finalize(self, sampled_data): parents = 
self.metadata.get_parents(table_name) if parents: for parent_name in parents: - parent_ids = self._find_parent_ids(table_name, parent_name, sampled_data) foreign_key = self.metadata.get_foreign_key(parent_name, table_name) - table_rows[foreign_key] = parent_ids + if foreign_key not in table_rows: + parent_ids = self._find_parent_ids(table_name, parent_name, sampled_data) + table_rows[foreign_key] = parent_ids reversed_data = self.metadata.reverse_transform(table_name, table_rows) @@ -198,6 +199,7 @@ def _sample_child_rows(self, table_name, parent_name, parent_row, sampled): self._sample_children(table_name, sampled) def _get_pdfs(self, parent_rows, child_name): + """Build a model for each parent row and get its pdf function.""" pdfs = dict() for parent_id, row in parent_rows.iterrows(): parameters = self._extract_parameters(row, child_name) @@ -207,16 +209,7 @@ def _get_pdfs(self, parent_rows, child_name): return pdfs - def _find_parent_id(self, row, pdfs, num_rows): - likelihoods = dict() - for parent_id, pdf in pdfs.items(): - try: - likelihoods[parent_id] = max(pdf(row), 0.0) - except np.linalg.LinAlgError: - # Singular matrix - likelihoods[parent_id] = None - - likelihoods = pd.Series(likelihoods) + def _find_parent_id(self, likelihoods, num_rows): mean = likelihoods.mean() if (likelihoods == 0).all(): # All rows got 0 likelihood, fallback to num_rows @@ -234,6 +227,19 @@ def _find_parent_id(self, row, pdfs, num_rows): return np.random.choice(likelihoods.index, p=weights) + def _get_likelihoods(self, table_rows, parent_rows, table_name): + likelihoods = dict() + for parent_id, row in parent_rows.iterrows(): + parameters = self._extract_parameters(row, table_name) + model = self.model(**self.model_kwargs) + model.set_parameters(parameters) + try: + likelihoods[parent_id] = model.model.probability_density(table_rows) + except np.linalg.LinAlgError: + likelihoods[parent_id] = None + + return pd.DataFrame(likelihoods, index=table_rows.index) + def _find_parent_ids(self, table_name, parent_name, sampled_data): table_rows = sampled_data[table_name] if parent_name in sampled_data: @@ -246,14 +252,10 @@ def _find_parent_ids(self, table_name, parent_name, sampled_data): primary_key = self.metadata.get_primary_key(parent_name) parent_rows = parent_rows.set_index(primary_key) - pdfs = self._get_pdfs(parent_rows, table_name) num_rows = parent_rows['__' + table_name + '__child_rows'].clip(0) - parent_ids = list() - for _, row in table_rows.iterrows(): - parent_ids.append(self._find_parent_id(row, pdfs, num_rows)) - - return parent_ids + likelihoods = self._get_likelihoods(table_rows, parent_rows, table_name) + return likelihoods.apply(self._find_parent_id, axis=1, num_rows=num_rows) def sample(self, table_name, num_rows=None, reset_primary_keys=False, sample_children=True, sample_parents=True): From 6c7be1a50280f244dd073b1bd981f9e0304734f8 Mon Sep 17 00:00:00 2001 From: Carles Sala Date: Thu, 2 Jul 2020 18:57:10 +0200 Subject: [PATCH 08/33] Update notebook examples --- examples/Demo - Walmart.ipynb | 872 +++++++++++++++++----------------- examples/demo_metadata.json | 30 +- 2 files changed, 450 insertions(+), 452 deletions(-) diff --git a/examples/Demo - Walmart.ipynb b/examples/Demo - Walmart.ipynb index d3bbe80e9..2e3303259 100644 --- a/examples/Demo - Walmart.ipynb +++ b/examples/Demo - Walmart.ipynb @@ -100,9 +100,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "2020-07-02 14:27:25,479 - INFO - metadata - Loading table stores\n", - "2020-07-02 14:27:25,486 - INFO - metadata - 
Loading table features\n", - "2020-07-02 14:27:25,502 - INFO - metadata - Loading table depts\n" + "2020-07-02 18:50:09,230 - INFO - metadata - Loading table stores\n", + "2020-07-02 18:50:09,237 - INFO - metadata - Loading table features\n", + "2020-07-02 18:50:09,255 - INFO - metadata - Loading table depts\n" ] } ], @@ -213,7 +213,7 @@ "\n" ], "text/plain": [ - "" + "" ] }, "execution_count": 2, @@ -261,128 +261,128 @@ "name": "stderr", "output_type": "stream", "text": [ - "2020-07-02 14:27:31,883 - INFO - modeler - Modeling stores\n", - "2020-07-02 14:27:31,883 - INFO - metadata - Loading transformer CategoricalTransformer for field Type\n", - "2020-07-02 14:27:31,884 - INFO - metadata - Loading transformer NumericalTransformer for field Size\n", - "2020-07-02 14:27:31,895 - INFO - modeler - Modeling depts\n", - "2020-07-02 14:27:31,895 - INFO - metadata - Loading transformer DatetimeTransformer for field Date\n", - "2020-07-02 14:27:31,896 - INFO - metadata - Loading transformer NumericalTransformer for field Weekly_Sales\n", - "2020-07-02 14:27:31,896 - INFO - metadata - Loading transformer NumericalTransformer for field Dept\n", - "2020-07-02 14:27:31,896 - INFO - metadata - Loading transformer BooleanTransformer for field IsHoliday\n", - "2020-07-02 14:27:32,007 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,272 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,288 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,302 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,319 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,335 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,350 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,367 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,383 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,399 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,414 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,430 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,445 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,460 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,475 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,490 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,508 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,522 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,536 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,550 - INFO - gaussian - Fitting 
GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,563 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,576 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,594 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,607 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,621 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,636 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,651 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,664 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,679 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,695 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,710 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,723 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,738 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,752 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,765 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,781 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,796 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,809 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,823 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,837 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,852 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,868 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,881 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,896 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,907 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,919 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:32,938 - INFO - modeler - Modeling features\n", - "2020-07-02 14:27:32,939 - INFO - metadata - Loading transformer DatetimeTransformer for field Date\n", - "2020-07-02 14:27:32,939 - INFO - metadata - Loading transformer NumericalTransformer for field MarkDown1\n", - "2020-07-02 14:27:32,940 - INFO - metadata - Loading transformer BooleanTransformer for field IsHoliday\n", - "2020-07-02 14:27:32,941 - INFO - metadata - Loading 
transformer NumericalTransformer for field MarkDown4\n", - "2020-07-02 14:27:32,941 - INFO - metadata - Loading transformer NumericalTransformer for field MarkDown3\n", - "2020-07-02 14:27:32,941 - INFO - metadata - Loading transformer NumericalTransformer for field Fuel_Price\n", - "2020-07-02 14:27:32,942 - INFO - metadata - Loading transformer NumericalTransformer for field Unemployment\n", - "2020-07-02 14:27:32,942 - INFO - metadata - Loading transformer NumericalTransformer for field Temperature\n", - "2020-07-02 14:27:32,942 - INFO - metadata - Loading transformer NumericalTransformer for field MarkDown5\n", - "2020-07-02 14:27:32,942 - INFO - metadata - Loading transformer NumericalTransformer for field MarkDown2\n", - "2020-07-02 14:27:32,943 - INFO - metadata - Loading transformer NumericalTransformer for field CPI\n", - "2020-07-02 14:27:32,993 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:33,043 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:33,073 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:33,104 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:33,133 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:33,189 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:33,220 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:33,253 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:33,285 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:33,314 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:33,342 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:33,371 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n" + "2020-07-02 18:50:09,639 - INFO - modeler - Modeling stores\n", + "2020-07-02 18:50:09,640 - INFO - metadata - Loading transformer CategoricalTransformer for field Type\n", + "2020-07-02 18:50:09,640 - INFO - metadata - Loading transformer NumericalTransformer for field Size\n", + "2020-07-02 18:50:09,653 - INFO - modeler - Modeling features\n", + "2020-07-02 18:50:09,653 - INFO - metadata - Loading transformer DatetimeTransformer for field Date\n", + "2020-07-02 18:50:09,654 - INFO - metadata - Loading transformer NumericalTransformer for field MarkDown1\n", + "2020-07-02 18:50:09,654 - INFO - metadata - Loading transformer BooleanTransformer for field IsHoliday\n", + "2020-07-02 18:50:09,654 - INFO - metadata - Loading transformer NumericalTransformer for field MarkDown4\n", + "2020-07-02 18:50:09,655 - INFO - metadata - Loading transformer NumericalTransformer for field MarkDown3\n", + "2020-07-02 18:50:09,656 - INFO - metadata - Loading transformer NumericalTransformer for field Fuel_Price\n", + "2020-07-02 18:50:09,656 - INFO - metadata - Loading transformer NumericalTransformer for field Unemployment\n", + "2020-07-02 18:50:09,657 - INFO - metadata - Loading transformer NumericalTransformer for field Temperature\n", + "2020-07-02 18:50:09,657 - INFO - metadata - Loading transformer 
NumericalTransformer for field MarkDown5\n", + "2020-07-02 18:50:09,657 - INFO - metadata - Loading transformer NumericalTransformer for field MarkDown2\n", + "2020-07-02 18:50:09,659 - INFO - metadata - Loading transformer NumericalTransformer for field CPI\n", + "2020-07-02 18:50:09,709 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:09,760 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:09,791 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:09,817 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:09,845 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:09,871 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:09,901 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:09,929 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:09,958 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:09,985 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:10,016 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:10,045 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:10,076 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:10,131 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:10,158 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:10,184 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:10,211 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:10,239 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:10,268 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:10,297 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:10,325 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:10,352 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:10,378 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:10,410 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:10,440 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:10,467 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:10,494 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:10,519 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 
18:50:10,547 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:10,572 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:10,602 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:10,628 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:10,656 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:10,686 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:10,713 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:10,740 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:10,768 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:10,799 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:10,827 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:10,858 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:10,886 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:10,912 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:10,939 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:10,967 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:10,996 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,028 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,058 - INFO - modeler - Modeling depts\n", + "2020-07-02 18:50:11,058 - INFO - metadata - Loading transformer DatetimeTransformer for field Date\n", + "2020-07-02 18:50:11,059 - INFO - metadata - Loading transformer NumericalTransformer for field Weekly_Sales\n", + "2020-07-02 18:50:11,059 - INFO - metadata - Loading transformer NumericalTransformer for field Dept\n", + "2020-07-02 18:50:11,059 - INFO - metadata - Loading transformer BooleanTransformer for field IsHoliday\n", + "2020-07-02 18:50:11,169 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,445 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,461 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,474 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,489 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,505 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,521 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,536 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 
18:50:11,552 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,568 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,583 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,603 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2020-07-02 14:27:33,399 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:33,425 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:33,452 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:33,479 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:33,504 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:33,531 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:33,558 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:33,592 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:33,624 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:33,653 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:33,683 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:33,713 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:33,743 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:33,770 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:33,797 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:33,826 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:33,852 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:33,880 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:33,910 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:33,939 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:33,967 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:33,996 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:34,025 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:34,051 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:34,078 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:34,107 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", 
- "2020-07-02 14:27:34,134 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:34,160 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:34,187 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:34,216 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:34,246 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:34,275 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:34,303 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:34,333 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:27:34,452 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,618 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,634 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,652 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,667 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,683 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,698 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,713 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,728 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,741 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,754 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,769 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,782 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,800 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,814 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,829 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,844 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,859 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,874 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,887 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,900 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,914 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,928 - INFO - gaussian 
- Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,940 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,955 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,968 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,979 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:11,991 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:12,003 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:12,017 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:12,032 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:12,045 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:12,058 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:12,069 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:12,084 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-02 18:50:12,177 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/core/fromnumeric.py:56: FutureWarning: Series.nonzero() is deprecated and will be removed in a future version.Use Series.to_numpy().nonzero() instead\n", " return getattr(obj, method)(*args, **kwds)\n", - "2020-07-02 14:27:34,680 - INFO - modeler - Modeling Complete\n" + "2020-07-02 18:50:12,380 - INFO - modeler - Modeling Complete\n" ] } ], @@ -417,29 +417,27 @@ "metadata": {}, "outputs": [ { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0msamples\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msdv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample_all\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m~/Projects/MIT/SDV/sdv/sdv.py\u001b[0m in \u001b[0;36msample_all\u001b[0;34m(self, num_rows, reset_primary_keys)\u001b[0m\n\u001b[1;32m 124\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mNotFittedError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'SDV instance has not been fitted'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 125\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 126\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msampler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample_all\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnum_rows\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreset_primary_keys\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mreset_primary_keys\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 127\u001b[0m 
\n    128     def save(self, path):\n",
    "~/Projects/MIT/SDV/sdv/sampler.py in sample_all(self, num_rows, reset_primary_keys)\n    331         for table in self.metadata.get_tables():\n    332             if not self.metadata.get_parents(table):\n--> 333                 sampled_data.update(self.sample(table, num_rows))\n    334 \n    335         return sampled_data\n",
    "~/Projects/MIT/SDV/sdv/sampler.py in sample(self, table_name, num_rows, reset_primary_keys, sample_children, sample_parents)\n    298 \n    299             self._sample_children(table_name, sampled_data)\n--> 300             return self._finalize(sampled_data)\n    301 \n    302         else:\n",
    "~/Projects/MIT/SDV/sdv/sampler.py in _finalize(self, sampled_data)\n     58             if parents:\n     59                 for parent_name in parents:\n---> 60                     parent_ids = self._find_parent_ids(table_name, parent_name, sampled_data)\n     61                     foreign_key = self.metadata.get_foreign_key(parent_name, table_name)\n     62                     table_rows[foreign_key] = parent_ids\n",
    "~/Projects/MIT/SDV/sdv/sampler.py in _find_parent_ids(self, table_name, parent_name, sampled_data)\n    252         parent_ids = list()\n    253         for _, row in table_rows.iterrows():\n--> 254             parent_ids.append(self._find_parent_id(row, pdfs, num_rows))\n    255 \n    256         return parent_ids\n",
    "~/Projects/MIT/SDV/sdv/sampler.py in _find_parent_id(self, row, pdfs, num_rows)\n    212         for parent_id, pdf in pdfs.items():\n    213             try:\n--> 214                 likelihoods[parent_id] = max(pdf(row), 0.0)\n    215             except np.linalg.LinAlgError:\n    216                 # Singular matrix\n",
    "~/.virtualenvs/SDV/lib/python3.6/site-packages/copulas/multivariate/gaussian.py in probability_density(self, X)\n    133         \"\"\"\n    134         self.check_fit()\n--> 135         transformed = self._transform_to_normal(X)\n    136         return stats.multivariate_normal.pdf(transformed, cov=self.covariance)\n    137 \n",
    "~/.virtualenvs/SDV/lib/python3.6/site-packages/copulas/multivariate/gaussian.py in _transform_to_normal(self, X)\n     56         for column_name, univariate in zip(self.columns, self.univariates):\n     57             column = X[column_name]\n---> 58             U.append(univariate.cdf(column).clip(EPSILON, 1 - EPSILON))\n     59 \n     60         return stats.norm.ppf(np.column_stack(U))\n",
    "~/.virtualenvs/SDV/lib/python3.6/site-packages/copulas/univariate/base.py in cdf(self, X)\n    306             Cumulative distribution values for points in X.\n    307         \"\"\"\n--> 308
\u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcumulative_distribution\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 309\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 310\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpercent_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mU\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.virtualenvs/SDV/lib/python3.6/site-packages/copulas/univariate/base.py\u001b[0m in \u001b[0;36mcumulative_distribution\u001b[0;34m(self, X)\u001b[0m\n\u001b[1;32m 523\u001b[0m \"\"\"\n\u001b[1;32m 524\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcheck_fit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 525\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_model\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcdf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 526\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 527\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpercent_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mU\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.virtualenvs/SDV/lib/python3.6/site-packages/scipy/stats/_distn_infrastructure.py\u001b[0m in \u001b[0;36mcdf\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 452\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 453\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcdf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 454\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdist\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcdf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 455\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 456\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mlogcdf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.virtualenvs/SDV/lib/python3.6/site-packages/scipy/stats/_distn_infrastructure.py\u001b[0m in \u001b[0;36mcdf\u001b[0;34m(self, x, *args, **kwds)\u001b[0m\n\u001b[1;32m 1747\u001b[0m \u001b[0mcond1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_open_support_mask\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m&\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mscale\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1748\u001b[0m \u001b[0mcond2\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m \u001b[0;34m>=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m&\u001b[0m \u001b[0mcond0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1749\u001b[0;31m \u001b[0mcond\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcond0\u001b[0m \u001b[0;34m&\u001b[0m \u001b[0mcond1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1750\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mzeros\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcond\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtyp\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1751\u001b[0m \u001b[0mplace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mcond0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misnan\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbadvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " - ] + "data": { + "text/plain": [ + "{'stores': 45, 'features': 8190, 'depts': 421570}" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" } ], + "source": [ + "sdv.modeler.table_sizes" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "scrolled": false + }, + "outputs": [], "source": [ "samples = sdv.sample_all()" ] @@ -453,7 +451,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -486,32 +484,32 @@ " \n", " 0\n", " B\n", - " 108725\n", - " 0\n", + " 85496\n", + " 3\n", " \n", " \n", " 1\n", - " B\n", - " 71407\n", - " 1\n", + " A\n", + " 178862\n", + " 4\n", " \n", " \n", " 2\n", - " A\n", - " 192291\n", - " 2\n", + " B\n", + " 69654\n", + " 5\n", " \n", " \n", " 3\n", - " B\n", - " 91790\n", - " 3\n", + " A\n", + " 211981\n", + " 6\n", " \n", " \n", " 4\n", - " B\n", - " 147700\n", - " 4\n", + " A\n", + " 131188\n", + " 7\n", " \n", " \n", "\n", @@ -519,14 +517,14 @@ ], "text/plain": [ " Type Size Store\n", - "0 B 108725 0\n", - "1 B 71407 1\n", - "2 A 192291 2\n", - "3 B 91790 3\n", - "4 B 147700 4" + "0 B 85496 3\n", + "1 A 178862 4\n", + "2 B 69654 5\n", + "3 A 211981 6\n", + "4 A 131188 7" ] }, - "execution_count": 6, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -537,7 +535,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -578,107 +576,107 @@ " \n", " \n", " 0\n", - " 2012-01-06 12:07:18.407483136\n", + " 2012-04-23 03:09:52.638271488\n", " NaN\n", - " 0\n", + " 3\n", " False\n", " NaN\n", - " 3371.662534\n", - " 3.518252\n", - " 8.595226\n", - " 82.638090\n", - " NaN\n", + " -9803.940199\n", + " 3.561375\n", + " 8.838728\n", + " 67.162475\n", " NaN\n", - " 190.969301\n", + " 2703.17729\n", + " 186.471991\n", " \n", " \n", " 1\n", - " 2012-02-04 17:43:32.634723840\n", - " NaN\n", - " 0\n", + " 2011-04-19 08:45:12.429521664\n", + " 483.892955\n", + " 3\n", " False\n", + " 7504.524416\n", " NaN\n", + " 3.495118\n", + " 7.360667\n", + " 42.785730\n", + " 2772.597105\n", " NaN\n", - " 3.690170\n", - " 
8.564543\n", - " 75.341551\n", - " 5178.361034\n", - " NaN\n", - " 197.207944\n", + " 192.048268\n", " \n", " \n", " 2\n", - " 2011-02-28 11:10:43.405782528\n", - " 3991.026002\n", - " 0\n", + " 2011-01-30 22:13:59.841415680\n", + " NaN\n", + " 3\n", " False\n", " NaN\n", " NaN\n", - " 3.736181\n", - " 7.736966\n", - " 52.131744\n", - " 2939.548298\n", - " 2017.730801\n", - " 198.774994\n", + " 3.361946\n", + " 7.524812\n", + " 34.945770\n", + " NaN\n", + " NaN\n", + " 192.100673\n", " \n", " \n", " 3\n", - " 2013-10-07 10:01:00.746535424\n", - " NaN\n", - " 0\n", + " 2011-10-08 08:16:00.977235968\n", + " 4661.175670\n", + " 3\n", " False\n", + " 3392.028528\n", + " -5837.649003\n", + " 2.994273\n", + " 7.993152\n", + " 66.818180\n", " NaN\n", - " 2969.019454\n", - " 3.004613\n", - " 9.759439\n", - " 73.570879\n", - " 10744.780013\n", - " 4958.619653\n", - " 173.358556\n", + " NaN\n", + " 197.921916\n", " \n", " \n", " 4\n", - " 2011-12-29 13:02:32.885990144\n", - " 6395.942402\n", - " 0\n", + " 2011-09-29 23:18:43.912751616\n", + " NaN\n", + " 3\n", " False\n", " NaN\n", - " -1840.929955\n", - " 3.882088\n", - " 10.158594\n", - " 77.868550\n", " NaN\n", + " 3.116659\n", + " 4.945393\n", + " 48.979036\n", " NaN\n", - " 197.546955\n", + " 2957.77612\n", + " 192.677797\n", " \n", " \n", "\n", "" ], "text/plain": [ - " Date MarkDown1 Store IsHoliday MarkDown4 \\\n", - "0 2012-01-06 12:07:18.407483136 NaN 0 False NaN \n", - "1 2012-02-04 17:43:32.634723840 NaN 0 False NaN \n", - "2 2011-02-28 11:10:43.405782528 3991.026002 0 False NaN \n", - "3 2013-10-07 10:01:00.746535424 NaN 0 False NaN \n", - "4 2011-12-29 13:02:32.885990144 6395.942402 0 False NaN \n", + " Date MarkDown1 Store IsHoliday MarkDown4 \\\n", + "0 2012-04-23 03:09:52.638271488 NaN 3 False NaN \n", + "1 2011-04-19 08:45:12.429521664 483.892955 3 False 7504.524416 \n", + "2 2011-01-30 22:13:59.841415680 NaN 3 False NaN \n", + "3 2011-10-08 08:16:00.977235968 4661.175670 3 False 3392.028528 \n", + "4 2011-09-29 23:18:43.912751616 NaN 3 False NaN \n", "\n", - " MarkDown3 Fuel_Price Unemployment Temperature MarkDown5 \\\n", - "0 3371.662534 3.518252 8.595226 82.638090 NaN \n", - "1 NaN 3.690170 8.564543 75.341551 5178.361034 \n", - "2 NaN 3.736181 7.736966 52.131744 2939.548298 \n", - "3 2969.019454 3.004613 9.759439 73.570879 10744.780013 \n", - "4 -1840.929955 3.882088 10.158594 77.868550 NaN \n", + " MarkDown3 Fuel_Price Unemployment Temperature MarkDown5 \\\n", + "0 -9803.940199 3.561375 8.838728 67.162475 NaN \n", + "1 NaN 3.495118 7.360667 42.785730 2772.597105 \n", + "2 NaN 3.361946 7.524812 34.945770 NaN \n", + "3 -5837.649003 2.994273 7.993152 66.818180 NaN \n", + "4 NaN 3.116659 4.945393 48.979036 NaN \n", "\n", - " MarkDown2 CPI \n", - "0 NaN 190.969301 \n", - "1 NaN 197.207944 \n", - "2 2017.730801 198.774994 \n", - "3 4958.619653 173.358556 \n", - "4 NaN 197.546955 " + " MarkDown2 CPI \n", + "0 2703.17729 186.471991 \n", + "1 NaN 192.048268 \n", + "2 NaN 192.100673 \n", + "3 NaN 197.921916 \n", + "4 2957.77612 192.677797 " ] }, - "execution_count": 7, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -689,7 +687,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -723,42 +721,42 @@ " \n", " \n", " 0\n", - " 2012-05-01 02:58:51.702994688\n", - " 16832.327095\n", - " 0\n", - " 77\n", + " 2011-04-22 15:24:18.057608704\n", + " 11196.989134\n", + " 3\n", + " 38\n", " False\n", " \n", " \n", " 1\n", - " 2011-01-06 18:03:36.333044224\n", 
- " 52593.751325\n", - " 0\n", - " -35\n", + " 2010-02-06 19:21:02.431538688\n", + " -14038.529503\n", + " 3\n", + " 29\n", " False\n", " \n", " \n", " 2\n", - " 2012-09-19 18:41:44.623307520\n", - " 41113.554799\n", - " 0\n", - " 48\n", + " 2012-06-04 16:23:09.227934976\n", + " -6519.738485\n", + " 3\n", + " 46\n", " False\n", " \n", " \n", " 3\n", - " 2012-02-18 19:35:16.603736320\n", - " 46877.378050\n", - " 0\n", - " 21\n", + " 2011-08-09 17:18:54.910250752\n", + " 23194.918038\n", + " 3\n", + " 45\n", " False\n", " \n", " \n", " 4\n", - " 2011-10-04 10:57:22.125225216\n", - " 47731.413911\n", - " 0\n", - " 60\n", + " 2010-09-01 23:10:54.986872576\n", + " 16761.426407\n", + " 3\n", + " 29\n", " False\n", " \n", " \n", @@ -767,14 +765,14 @@ ], "text/plain": [ " Date Weekly_Sales Store Dept IsHoliday\n", - "0 2012-05-01 02:58:51.702994688 16832.327095 0 77 False\n", - "1 2011-01-06 18:03:36.333044224 52593.751325 0 -35 False\n", - "2 2012-09-19 18:41:44.623307520 41113.554799 0 48 False\n", - "3 2012-02-18 19:35:16.603736320 46877.378050 0 21 False\n", - "4 2011-10-04 10:57:22.125225216 47731.413911 0 60 False" + "0 2011-04-22 15:24:18.057608704 11196.989134 3 38 False\n", + "1 2010-02-06 19:21:02.431538688 -14038.529503 3 29 False\n", + "2 2012-06-04 16:23:09.227934976 -6519.738485 3 46 False\n", + "3 2011-08-09 17:18:54.910250752 23194.918038 3 45 False\n", + "4 2010-09-01 23:10:54.986872576 16761.426407 3 29 False" ] }, - "execution_count": 8, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -794,7 +792,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 13, "metadata": { "scrolled": false }, @@ -803,198 +801,198 @@ "data": { "text/plain": [ "{'features': Date MarkDown1 Store IsHoliday \\\n", - " 0 2012-05-01 16:47:44.635517440 9856.327932 5 False \n", - " 1 2011-03-02 09:17:23.001815040 NaN 5 False \n", - " 2 2009-12-09 12:09:15.280852480 NaN 5 False \n", - " 3 2012-09-27 22:00:49.812289792 7575.545504 5 False \n", - " 4 2011-05-09 00:56:06.155403520 NaN 5 False \n", - " 5 2011-09-07 10:54:38.274752256 -319.511710 5 False \n", - " 6 2011-07-15 04:51:41.685670656 NaN 5 False \n", - " 7 2010-10-26 23:57:49.155473408 NaN 5 False \n", - " 8 2013-01-07 03:08:40.389892352 -25.338635 5 False \n", - " 9 2014-02-11 19:59:19.362542592 8154.995622 5 False \n", - " 10 2009-08-16 21:08:36.649604352 NaN 5 True \n", - " 11 2011-04-07 23:12:30.381310976 NaN 5 True \n", - " 12 2011-06-01 23:40:25.271012608 NaN 5 False \n", - " 13 2010-11-25 16:27:01.937167872 NaN 5 False \n", - " 14 2012-04-01 23:49:59.091133440 NaN 5 False \n", - " 15 2011-11-29 06:44:29.486765568 5394.076595 5 False \n", - " 16 2012-12-23 01:54:43.755820800 616.736543 5 False \n", - " 17 2011-02-22 12:01:12.161310464 NaN 5 False \n", - " 18 2012-09-14 09:12:48.080929792 4264.005722 5 False \n", - " 19 2011-09-29 04:49:59.894278656 NaN 5 True \n", - " 20 2011-07-06 18:19:10.024774400 NaN 5 False \n", - " 21 2010-12-20 03:51:25.390490880 8089.734544 5 False \n", - " 22 2011-12-13 08:37:11.030729472 9237.405853 5 False \n", - " 23 2011-05-10 05:02:03.746472704 NaN 5 False \n", - " 24 2011-09-16 10:58:14.326885632 2452.122989 5 False \n", - " 25 2013-08-04 01:41:06.983268352 3136.407925 5 False \n", - " 26 2012-02-27 13:58:51.523597568 4152.240138 5 False \n", - " 27 2011-02-14 11:07:31.025720320 NaN 5 False \n", - " 28 2010-12-31 09:46:24.324750592 NaN 5 False \n", - " 29 2011-11-01 17:13:08.654945280 3392.664690 5 False \n", + " 0 2012-03-09 19:29:04.121987584 4547.265066 48 False 
\n", + " 1 2012-01-12 12:05:27.772331264 NaN 51 False \n", + " 2 2012-10-09 23:52:58.454752256 13291.983302 51 False \n", + " 3 2011-07-26 23:58:33.986580992 NaN 49 False \n", + " 4 2012-01-17 02:27:51.821400832 6865.756770 50 False \n", + " 5 2009-12-27 02:32:16.625306880 NaN 51 False \n", + " 6 2011-07-12 23:01:03.955540224 NaN 48 False \n", + " 7 2011-09-25 10:16:19.542280704 7114.339579 48 False \n", + " 8 2011-05-06 02:00:24.919497728 NaN 48 False \n", + " 9 2013-10-23 10:17:59.945650432 2267.991754 50 False \n", + " 10 2011-05-18 08:31:22.419635968 NaN 48 False \n", + " 11 2010-07-27 18:16:54.600127744 NaN 49 False \n", + " 12 2010-11-04 00:37:47.996870656 NaN 51 False \n", + " 13 2012-03-23 15:04:51.172607488 10192.668685 48 False \n", + " 14 2007-11-14 22:28:51.711726848 NaN 48 False \n", + " 15 2013-02-07 20:55:46.684175872 9146.998153 49 False \n", + " 16 2011-07-18 00:03:42.998976512 NaN 50 False \n", + " 17 2010-10-02 21:34:09.917561088 NaN 49 False \n", + " 18 2010-03-12 18:25:31.509475328 NaN 51 False \n", + " 19 2013-06-24 02:50:15.027540736 2943.979267 50 False \n", + " 20 2010-09-24 15:42:39.457930752 NaN 48 False \n", + " 21 2012-03-27 19:00:12.877346816 7750.605869 51 False \n", + " 22 2010-05-26 23:28:27.091251968 NaN 48 False \n", + " 23 2011-05-01 23:07:37.680983296 NaN 52 False \n", + " 24 2012-10-07 16:56:39.076689152 11838.898938 51 False \n", + " 25 2011-02-19 21:38:01.903840000 NaN 52 False \n", + " 26 2012-01-18 06:02:24.764590592 7428.686729 52 False \n", + " 27 2012-09-28 01:36:19.603453184 4962.148181 52 False \n", + " 28 2011-07-28 19:21:49.952510208 12773.901269 48 False \n", + " 29 2011-09-04 16:53:45.561477888 NaN 50 False \n", " .. ... ... ... ... \n", - " 970 2011-10-05 18:36:28.353068288 NaN 5 False \n", - " 971 2012-01-12 04:51:28.610967552 NaN 5 False \n", - " 972 2012-03-24 12:11:51.862262272 NaN 5 False \n", - " 973 2013-07-12 16:33:09.626688000 6561.766212 5 False \n", - " 974 2011-07-07 11:50:18.363575808 NaN 5 False \n", - " 975 2012-04-22 12:03:07.605280512 451.002188 5 False \n", - " 976 2011-04-24 16:40:02.879920896 NaN 5 False \n", - " 977 2013-05-10 19:00:49.186777088 14456.109439 5 False \n", - " 978 2011-09-16 12:39:50.775564288 NaN 5 False \n", - " 979 2011-01-27 23:51:25.038483968 NaN 5 False \n", - " 980 2011-04-08 07:11:55.072327424 NaN 5 False \n", - " 981 2012-04-14 11:31:04.547722496 NaN 5 False \n", - " 982 2010-12-31 16:22:41.385283072 NaN 5 False \n", - " 983 2014-05-03 06:44:59.234656512 7187.928500 5 False \n", - " 984 2010-09-10 17:16:27.645561344 NaN 5 True \n", - " 985 2012-02-26 19:44:58.053891840 12936.660878 5 True \n", - " 986 2011-08-12 15:34:48.039582976 1592.820792 5 False \n", - " 987 2011-01-04 14:20:20.553192192 NaN 5 False \n", - " 988 2011-10-07 13:59:02.167729408 NaN 5 False \n", - " 989 2012-11-21 05:42:10.223450880 5971.323886 5 False \n", - " 990 2011-10-01 05:02:17.844523520 NaN 5 False \n", - " 991 2012-02-23 17:58:47.762346752 -2849.677150 5 False \n", - " 992 2011-04-20 10:01:13.460115456 NaN 5 False \n", - " 993 2010-12-21 01:41:52.356143872 NaN 5 False \n", - " 994 2011-01-17 09:10:37.772929536 NaN 5 False \n", - " 995 2011-06-03 03:47:46.385799936 NaN 5 False \n", - " 996 2009-01-13 10:16:37.091725312 NaN 5 False \n", - " 997 2012-03-11 15:48:35.927031808 NaN 5 True \n", - " 998 2010-11-17 15:02:44.981103616 NaN 5 False \n", - " 999 2010-09-07 12:56:31.595172096 NaN 5 False \n", + " 970 2011-06-14 18:48:47.646437632 7563.489728 48 False \n", + " 971 2011-05-11 09:27:03.321238016 3763.579757 50 False 
\n", + " 972 2012-10-27 08:44:16.987041024 -2107.559543 49 False \n", + " 973 2011-03-26 17:36:28.323915264 NaN 48 False \n", + " 974 2012-03-30 09:18:45.289797888 NaN 52 False \n", + " 975 2012-12-25 00:57:24.386340608 3483.807111 51 False \n", + " 976 2010-09-17 18:59:15.931293440 NaN 52 False \n", + " 977 2012-04-18 06:40:20.102041344 2869.173605 50 False \n", + " 978 2010-01-18 10:21:26.293353216 NaN 49 False \n", + " 979 2010-08-13 00:08:32.324416256 NaN 50 False \n", + " 980 2012-01-17 04:04:26.614297088 NaN 50 False \n", + " 981 2011-06-23 01:40:48.209999104 NaN 48 False \n", + " 982 2012-05-15 04:01:02.762140160 3935.159830 50 False \n", + " 983 2011-03-03 10:20:57.530126848 5704.369183 52 False \n", + " 984 2011-09-11 11:54:13.537061888 -3939.448245 49 False \n", + " 985 2012-07-01 08:38:12.122587648 NaN 49 False \n", + " 986 2013-01-15 14:24:51.253277952 7802.046675 52 False \n", + " 987 2011-03-16 12:56:22.305051648 NaN 49 False \n", + " 988 2012-02-09 21:21:51.383019776 3563.474259 50 False \n", + " 989 2011-04-28 07:35:57.312682752 NaN 49 False \n", + " 990 2012-07-06 11:32:42.706951424 4134.684319 48 False \n", + " 991 2012-09-10 11:27:22.375229184 2670.997722 48 False \n", + " 992 2010-04-19 00:59:34.282525696 NaN 51 False \n", + " 993 2012-05-21 02:19:56.773740288 725.161402 51 False \n", + " 994 2012-08-25 02:46:15.815232768 2736.817373 52 False \n", + " 995 2011-11-27 15:29:59.302161664 NaN 48 False \n", + " 996 2012-11-19 23:22:47.651802880 NaN 49 False \n", + " 997 2012-03-06 18:23:43.093342208 NaN 51 True \n", + " 998 2011-11-17 03:00:20.669487872 NaN 50 False \n", + " 999 2010-10-18 08:04:04.443984128 NaN 51 False \n", " \n", - " MarkDown4 MarkDown3 Fuel_Price Unemployment Temperature \\\n", - " 0 2843.666588 2084.265105 3.539294 8.026048 67.279415 \n", - " 1 NaN NaN 3.136235 7.464250 35.839727 \n", - " 2 NaN NaN 3.117027 8.804691 55.531072 \n", - " 3 6291.467352 7644.651968 3.950191 8.449126 63.312543 \n", - " 4 NaN NaN 3.151100 10.724954 78.355700 \n", - " 5 319.073405 8048.361448 3.214550 6.730897 46.004140 \n", - " 6 NaN NaN 3.073315 8.794787 46.230595 \n", - " 7 NaN NaN 2.851889 8.455094 73.007249 \n", - " 8 -842.343741 1973.521805 3.219424 NaN 83.146804 \n", - " 9 4753.392021 2744.820760 4.389613 6.049940 55.535728 \n", - " 10 NaN NaN 2.921131 10.411630 81.934195 \n", - " 11 -380.767318 NaN 3.588675 10.915397 41.862765 \n", - " 12 3841.062967 -692.534491 3.183958 8.449273 91.520699 \n", - " 13 NaN NaN 2.676381 10.171808 59.799911 \n", - " 14 NaN NaN 4.068388 11.061642 49.266088 \n", - " 15 4759.123311 3958.005275 3.860776 8.559916 54.371928 \n", - " 16 -2880.580355 6861.020784 3.241783 7.662614 44.933106 \n", - " 17 NaN NaN 3.287047 8.232218 73.642978 \n", - " 18 NaN NaN 3.553487 4.932510 70.156403 \n", - " 19 NaN 3498.711115 3.650312 10.192201 64.554441 \n", - " 20 NaN NaN 3.711555 5.975218 45.371293 \n", - " 21 NaN 8008.053972 2.965420 6.715859 65.736442 \n", - " 22 NaN 5226.209147 3.770431 4.689591 48.907799 \n", - " 23 NaN NaN 3.148174 7.048833 63.931931 \n", - " 24 NaN NaN 3.228123 7.125082 54.541421 \n", - " 25 -2311.380219 9527.093555 4.376510 6.227941 39.120836 \n", - " 26 625.303990 NaN 3.789428 5.716800 99.128378 \n", - " 27 NaN NaN 3.259632 8.651978 75.544242 \n", - " 28 NaN NaN 3.703069 10.753015 40.004955 \n", - " 29 NaN NaN 3.741760 6.060680 100.625665 \n", - " .. ... ... ... ... ... 
\n", - " 970 NaN NaN 3.572695 7.966283 57.537589 \n", - " 971 NaN NaN 3.828349 9.731679 82.126278 \n", - " 972 NaN NaN 3.313500 NaN 46.785653 \n", - " 973 3054.127592 -724.702300 4.287166 9.651842 51.824356 \n", - " 974 NaN NaN 3.761812 8.287437 88.088342 \n", - " 975 -3266.187881 8215.493689 3.642633 9.500544 29.831979 \n", - " 976 NaN NaN 2.819650 7.832418 40.807527 \n", - " 977 4122.302492 6313.274115 4.174946 8.727237 70.737165 \n", - " 978 5838.158736 NaN 3.066125 NaN 31.147061 \n", - " 979 NaN NaN 3.151944 7.658218 61.645832 \n", - " 980 NaN 630.957174 3.183844 9.862215 17.948335 \n", - " 981 NaN NaN 3.271366 9.679093 61.315823 \n", - " 982 NaN NaN 3.867379 8.191903 51.461519 \n", - " 983 2608.840965 -562.271212 3.781071 NaN 42.556422 \n", - " 984 NaN NaN 2.956717 7.515493 75.537151 \n", - " 985 3915.581744 1272.550074 3.473630 5.244393 78.032536 \n", - " 986 2982.566125 NaN 2.981562 9.439555 62.315341 \n", - " 987 NaN NaN 2.839158 9.016700 58.018581 \n", - " 988 3139.374304 NaN 3.952791 6.206879 68.542122 \n", - " 989 4102.145584 -2255.694045 4.091854 9.904070 48.245039 \n", - " 990 NaN NaN 3.318221 9.150930 100.355027 \n", - " 991 NaN -852.368857 3.281618 6.774089 63.196870 \n", - " 992 NaN NaN 3.199449 6.715256 26.426601 \n", - " 993 NaN NaN 3.339197 11.524258 56.354388 \n", - " 994 NaN NaN 3.422232 7.571462 65.796130 \n", - " 995 1692.546761 NaN 3.178186 5.160924 43.172044 \n", - " 996 NaN NaN 2.748483 11.166695 34.682053 \n", - " 997 NaN NaN 3.713086 8.024620 35.345627 \n", - " 998 NaN NaN 3.683783 8.076637 43.072280 \n", - " 999 NaN NaN 2.804057 9.006230 79.267188 \n", + " MarkDown4 MarkDown3 Fuel_Price Unemployment Temperature \\\n", + " 0 2262.924323 4314.258484 2.944899 12.207513 44.866684 \n", + " 1 648.327530 NaN 3.415876 8.296209 79.755000 \n", + " 2 7464.648400 4818.694540 3.438865 10.375869 53.011465 \n", + " 3 7722.936675 NaN 3.238451 8.340640 81.331769 \n", + " 4 385.720439 2933.746873 3.352123 4.517977 33.216711 \n", + " 5 NaN NaN 2.595952 8.423675 58.153432 \n", + " 6 NaN NaN 3.337125 7.289977 68.198691 \n", + " 7 1376.144152 -1375.405701 3.586827 4.433586 57.457667 \n", + " 8 NaN 885.964999 3.132711 7.941049 82.624443 \n", + " 9 4380.021000 4861.726545 4.617105 7.636990 64.616449 \n", + " 10 NaN NaN 3.490697 9.699187 113.890285 \n", + " 11 NaN NaN 3.337917 10.428851 76.354063 \n", + " 12 NaN NaN 3.016818 10.044523 52.925513 \n", + " 13 NaN 1505.002284 3.244111 7.329469 34.127333 \n", + " 14 NaN NaN 2.497412 12.247448 100.350997 \n", + " 15 7982.863954 7539.825400 3.909380 5.689230 65.945165 \n", + " 16 NaN NaN 2.895094 7.718521 54.750702 \n", + " 17 NaN NaN 3.722465 7.339952 56.152632 \n", + " 18 NaN NaN 2.665366 8.045471 78.257558 \n", + " 19 -3310.004199 -808.378062 4.379650 5.320308 39.604265 \n", + " 20 NaN 4604.622951 3.058964 8.584222 44.177016 \n", + " 21 3790.217545 3053.642374 3.091990 5.069122 25.777591 \n", + " 22 NaN NaN 2.931042 9.062068 41.649615 \n", + " 23 NaN NaN 3.721348 6.586825 67.785117 \n", + " 24 4610.034450 -617.120861 4.010528 8.686732 35.445785 \n", + " 25 NaN NaN 3.249523 8.450699 38.639637 \n", + " 26 4005.965014 -28.792693 3.808513 9.188265 57.584018 \n", + " 27 6433.867544 -3996.702966 2.976900 8.991618 6.179927 \n", + " 28 NaN NaN 3.208463 11.875517 50.400717 \n", + " 29 NaN NaN 3.455848 7.377382 56.480212 \n", + " .. ... ... ... ... ... 
\n", + " 970 7130.653517 613.434417 3.798558 8.849840 45.081481 \n", + " 971 -950.449794 8369.741392 3.691277 8.687173 21.625873 \n", + " 972 1423.082811 3915.095969 3.716980 7.812599 38.891102 \n", + " 973 NaN NaN 3.277952 7.881991 74.018427 \n", + " 974 NaN NaN 3.444190 4.147388 79.128181 \n", + " 975 5645.978692 3893.099917 3.821702 8.410268 87.162204 \n", + " 976 NaN NaN 2.640021 7.458197 35.588448 \n", + " 977 -2627.096894 -142.014936 3.320107 6.970687 49.700872 \n", + " 978 NaN NaN 2.710125 9.927722 83.348488 \n", + " 979 NaN NaN 3.139506 8.348364 46.708000 \n", + " 980 NaN NaN 3.578857 7.555184 25.595379 \n", + " 981 NaN -1147.726465 3.685507 9.407928 91.745352 \n", + " 982 375.025287 11130.134824 3.467141 6.285321 38.179293 \n", + " 983 NaN NaN 2.675670 10.258619 50.599203 \n", + " 984 NaN NaN 3.375990 9.175979 44.036130 \n", + " 985 NaN 649.171815 3.370836 6.948002 22.863602 \n", + " 986 1944.149217 896.663662 3.971550 6.242637 97.107161 \n", + " 987 NaN NaN 3.524965 7.059279 65.469921 \n", + " 988 2550.079436 121.672409 2.898533 5.929999 58.612442 \n", + " 989 NaN NaN 3.237647 6.992924 72.225708 \n", + " 990 NaN NaN 4.075160 8.392256 80.837964 \n", + " 991 -1543.390402 5264.766267 3.157407 7.017727 81.515109 \n", + " 992 NaN NaN 3.293817 8.719239 65.735676 \n", + " 993 932.568195 1845.259380 3.157982 7.494639 65.103553 \n", + " 994 -2905.284812 -1268.636611 3.776062 6.736775 46.083801 \n", + " 995 NaN NaN 3.424730 NaN 79.231321 \n", + " 996 9027.387381 NaN 3.960882 8.911136 26.948635 \n", + " 997 NaN NaN 3.764028 8.049413 111.723841 \n", + " 998 NaN 6261.182193 3.609560 6.024725 38.411214 \n", + " 999 NaN NaN 3.051021 7.855145 75.446009 \n", " \n", - " MarkDown5 MarkDown2 CPI \n", - " 0 10204.869614 2819.319182 194.172590 \n", - " 1 NaN 1682.652584 137.004267 \n", - " 2 NaN NaN 179.678323 \n", - " 3 7501.331616 4715.605494 161.759497 \n", - " 4 NaN NaN 128.911283 \n", - " 5 6750.992820 797.112925 205.683144 \n", - " 6 NaN NaN 203.634987 \n", - " 7 NaN NaN 230.663986 \n", - " 8 3871.914754 3988.518657 NaN \n", - " 9 1624.808895 735.714943 133.713539 \n", - " 10 NaN NaN 197.575244 \n", - " 11 NaN NaN 139.060197 \n", - " 12 NaN NaN 194.510387 \n", - " 13 NaN NaN 78.886026 \n", - " 14 NaN NaN 81.669058 \n", - " 15 7758.050721 NaN 198.052999 \n", - " 16 2853.381375 9269.494345 104.932028 \n", - " 17 NaN NaN 204.210702 \n", - " 18 3343.712458 NaN 178.168110 \n", - " 19 NaN NaN 161.572206 \n", - " 20 NaN NaN 128.629900 \n", - " 21 NaN NaN 181.929594 \n", - " 22 3656.616824 1231.941589 193.633886 \n", - " 23 NaN 33.178516 198.074896 \n", - " 24 -1626.757019 NaN 105.550579 \n", - " 25 644.446061 10905.729749 137.971932 \n", - " 26 6373.039633 NaN 192.460300 \n", - " 27 NaN NaN 213.233609 \n", - " 28 NaN NaN 123.195502 \n", - " 29 2826.315133 NaN 195.701117 \n", - " .. ... ... ... 
\n", - " 970 NaN NaN 148.906214 \n", - " 971 NaN NaN 204.262748 \n", - " 972 NaN NaN NaN \n", - " 973 -1499.994553 16305.139629 140.188752 \n", - " 974 NaN NaN 161.099549 \n", - " 975 2685.234955 4271.729391 108.874111 \n", - " 976 NaN NaN 188.576448 \n", - " 977 3914.364495 3982.010819 196.334916 \n", - " 978 NaN NaN NaN \n", - " 979 NaN NaN 158.554149 \n", - " 980 NaN NaN 163.242219 \n", - " 981 NaN NaN 184.637472 \n", - " 982 NaN NaN 95.502993 \n", - " 983 5061.605364 4816.687238 NaN \n", - " 984 NaN NaN 150.901411 \n", - " 985 5375.719303 10141.431209 250.053444 \n", - " 986 6132.068608 7917.959484 120.500763 \n", - " 987 NaN NaN 127.836518 \n", - " 988 NaN NaN 206.249045 \n", - " 989 9976.937232 3037.291454 126.800358 \n", - " 990 NaN NaN 239.016886 \n", - " 991 4704.812744 NaN 213.808756 \n", - " 992 NaN NaN 136.904711 \n", - " 993 NaN NaN 177.862208 \n", - " 994 NaN NaN 166.085349 \n", - " 995 NaN NaN 163.449486 \n", - " 996 NaN NaN 136.797738 \n", - " 997 464.410675 NaN 186.260545 \n", - " 998 NaN NaN 140.825171 \n", - " 999 NaN NaN 168.683814 \n", + " MarkDown5 MarkDown2 CPI \n", + " 0 7218.910186 NaN 154.437715 \n", + " 1 4928.572337 1860.474257 126.006100 \n", + " 2 4381.712904 6491.080062 182.505905 \n", + " 3 NaN 2075.661361 170.252793 \n", + " 4 -1363.927516 6383.011111 181.850615 \n", + " 5 NaN NaN 250.558631 \n", + " 6 NaN NaN 213.846371 \n", + " 7 1644.489130 4966.726612 189.105468 \n", + " 8 NaN 4689.948004 152.931090 \n", + " 9 3291.191563 2344.032770 152.780297 \n", + " 10 NaN NaN 195.602913 \n", + " 11 NaN NaN 176.381561 \n", + " 12 NaN NaN 147.881741 \n", + " 13 1668.882903 NaN 155.004680 \n", + " 14 NaN NaN 193.442985 \n", + " 15 1437.517079 4874.558939 152.281926 \n", + " 16 8439.806431 NaN 130.530240 \n", + " 17 NaN NaN 201.486725 \n", + " 18 NaN NaN 188.760741 \n", + " 19 -6907.068884 3246.451314 108.120768 \n", + " 20 NaN NaN 148.788886 \n", + " 21 5724.976324 4496.854628 190.036218 \n", + " 22 NaN NaN 181.746008 \n", + " 23 NaN NaN 200.412670 \n", + " 24 6637.078664 -1388.022496 87.164451 \n", + " 25 NaN NaN 195.622231 \n", + " 26 2922.749710 NaN 224.233432 \n", + " 27 -3042.801607 NaN 131.909287 \n", + " 28 7063.102432 NaN 146.156666 \n", + " 29 NaN NaN 176.694595 \n", + " .. ... ... ... 
\n", + " 970 2882.447318 NaN 148.323409 \n", + " 971 1918.923930 NaN 141.723698 \n", + " 972 151.136535 1221.530133 155.532651 \n", + " 973 NaN NaN 213.405340 \n", + " 974 NaN NaN 186.636286 \n", + " 975 2598.731935 -846.566855 190.957987 \n", + " 976 NaN NaN 153.917260 \n", + " 977 8372.618269 1768.408869 249.068471 \n", + " 978 NaN NaN 177.318454 \n", + " 979 NaN NaN 208.719040 \n", + " 980 NaN 754.021358 108.580497 \n", + " 981 NaN NaN 159.526588 \n", + " 982 7349.003146 -732.495611 173.559001 \n", + " 983 NaN NaN 140.910196 \n", + " 984 401.319810 NaN 175.340840 \n", + " 985 NaN NaN 171.683362 \n", + " 986 -465.292159 -744.114065 177.553651 \n", + " 987 NaN NaN 144.344477 \n", + " 988 567.400188 NaN 213.833680 \n", + " 989 NaN NaN 157.113353 \n", + " 990 NaN -232.801449 132.592027 \n", + " 991 -72.622621 NaN 207.754643 \n", + " 992 NaN NaN 132.575018 \n", + " 993 4955.778653 4683.646548 160.260401 \n", + " 994 2898.089751 NaN 129.801694 \n", + " 995 NaN NaN NaN \n", + " 996 NaN 608.026190 115.452177 \n", + " 997 NaN NaN 206.323679 \n", + " 998 NaN 7020.489423 178.966240 \n", + " 999 NaN NaN 152.636567 \n", " \n", " [1000 rows x 12 columns]}" ] }, - "execution_count": 9, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } diff --git a/examples/demo_metadata.json b/examples/demo_metadata.json index bb24a5963..1becd1764 100644 --- a/examples/demo_metadata.json +++ b/examples/demo_metadata.json @@ -2,32 +2,32 @@ "tables": { "users": { "fields": { - "age": { - "type": "numerical", - "subtype": "integer" - }, - "gender": { + "country": { "type": "categorical" }, "user_id": { "type": "id", "subtype": "integer" }, - "country": { + "gender": { "type": "categorical" + }, + "age": { + "type": "numerical", + "subtype": "integer" } }, "primary_key": "user_id" }, "sessions": { "fields": { + "device": { + "type": "categorical" + }, "session_id": { "type": "id", "subtype": "integer" }, - "device": { - "type": "categorical" - }, "user_id": { "type": "id", "subtype": "integer", @@ -48,9 +48,9 @@ "type": "datetime", "format": "%Y-%m-%d" }, - "transaction_id": { - "type": "id", - "subtype": "integer" + "amount": { + "type": "numerical", + "subtype": "float" }, "session_id": { "type": "id", @@ -60,9 +60,9 @@ "field": "session_id" } }, - "amount": { - "type": "numerical", - "subtype": "float" + "transaction_id": { + "type": "id", + "subtype": "integer" }, "approved": { "type": "boolean" From 03f46a663cbfd704ffbf5179030f675104432c84 Mon Sep 17 00:00:00 2001 From: Carles Sala Date: Thu, 2 Jul 2020 19:24:19 +0200 Subject: [PATCH 09/33] Add tests for _find_parent_ids --- sdv/sampler.py | 39 ++++++--------- tests/test_sampler.py | 107 ++++++++++++++++++++++++++++++++++-------- 2 files changed, 103 insertions(+), 43 deletions(-) diff --git a/sdv/sampler.py b/sdv/sampler.py index 2dd318e7d..dc6be988f 100644 --- a/sdv/sampler.py +++ b/sdv/sampler.py @@ -171,13 +171,13 @@ def _sample_rows(self, model, num_rows, table_name): return sampled - def _sample_children(self, table_name, sampled): - table_rows = sampled[table_name] + def _sample_children(self, table_name, sampled_data): + table_rows = sampled_data[table_name] for child_name in self.metadata.get_children(table_name): for _, row in table_rows.iterrows(): - self._sample_child_rows(child_name, table_name, row, sampled) + self._sample_child_rows(child_name, table_name, row, sampled_data) - def _sample_child_rows(self, table_name, parent_name, parent_row, sampled): + def _sample_child_rows(self, table_name, parent_name, parent_row, 
sampled_data):
         parameters = self._extract_parameters(parent_row, table_name)
         model = self.model(**self.model_kwargs)
@@ -190,37 +190,28 @@ def _sample_child_rows(self, table_name, parent_name, parent_row, sampled):
         foreign_key = self.metadata.get_foreign_key(parent_name, table_name)
         table_rows[foreign_key] = parent_row[parent_key]
 
-        previous = sampled.get(table_name)
+        previous = sampled_data.get(table_name)
         if previous is None:
-            sampled[table_name] = table_rows
+            sampled_data[table_name] = table_rows
         else:
-            sampled[table_name] = pd.concat([previous, table_rows]).reset_index(drop=True)
-
-        self._sample_children(table_name, sampled)
-
-    def _get_pdfs(self, parent_rows, child_name):
-        """Build a model for each parent row and get its pdf function."""
-        pdfs = dict()
-        for parent_id, row in parent_rows.iterrows():
-            parameters = self._extract_parameters(row, child_name)
-            model = self.model(**self.model_kwargs)
-            model.set_parameters(parameters)
-            pdfs[parent_id] = model.model.probability_density
+            sampled_data[table_name] = pd.concat([previous, table_rows]).reset_index(drop=True)
 
-        return pdfs
+        self._sample_children(table_name, sampled_data)
 
-    def _find_parent_id(self, likelihoods, num_rows):
+    @staticmethod
+    def _find_parent_id(likelihoods, num_rows):
         mean = likelihoods.mean()
         if (likelihoods == 0).all():
             # All rows got 0 likelihood, fallback to num_rows
             likelihoods = num_rows
         elif pd.isnull(mean) or mean == 0:
-            # No row got likelihood > 0, but some got singlar matrix
-            # Fallback to num_rows on the singular matrix rows
+            # Some rows got a singular matrix error and the rest were 0.
+            # Fall back to num_rows on the singular matrix rows and
+            # keep 0s on the rest.
             likelihoods = likelihoods.fillna(num_rows)
         else:
-            # at least one row got likelihood > 0, so fill the
-            # singular matrix rows with the mean
+            # At least one row got a valid likelihood, so fill the
+            # rows that got a singular matrix error with the mean
             likelihoods = likelihoods.fillna(mean)
 
         weights = likelihoods.values / likelihoods.sum()
diff --git a/tests/test_sampler.py b/tests/test_sampler.py
index fd3e0daa0..a881e3dd8 100644
--- a/tests/test_sampler.py
+++ b/tests/test_sampler.py
@@ -1,6 +1,6 @@
-from unittest import TestCase
-from unittest.mock import Mock
+from unittest.mock import Mock, patch
 
+import numpy as np
 import pandas as pd
 import pytest
 
@@ -9,7 +9,7 @@
 from sdv.sampler import Sampler
 
 
-class TestSampler(TestCase):
+class TestSampler:
 
     def test___init__(self):
         """Test create a default instance of Sampler class"""
@@ -61,10 +61,7 @@ def test__finalize(self):
             'c': 'some data'  # fk
         }
 
-        sampler._find_parent_ids.side_effect = [
-            [2, 3],
-            [4, 5]
-        ]
+        sampler._find_parent_ids.return_value = [4, 5]
         sampler.metadata.get_foreign_key.side_effect = [
             'b',
             'c',
@@ -74,6 +71,7 @@ def test__finalize(self):
         sampled_data = {
             'test': pd.DataFrame({
                 'a': [0, 1],  # actual data
+                'b': [2, 3],  # existing fk column
                 'z': [6, 7]  # not used
             })
         }
@@ -90,6 +88,7 @@ def test__finalize(self):
             result['test'].sort_index(axis=1),
             expected.sort_index(axis=1)
         )
+        sampler._find_parent_ids.assert_called_once()
 
     def test__get_primary_keys_none(self):
         """Test returns a tuple of none when a table doesn't have a primary key"""
@@ -168,7 +167,7 @@ def test__get_primary_keys_raises_value_error_remaining(self):
             Sampler._get_primary_keys(sampler, 'test', 5)
 
     def test__extract_parameters(self):
-        """Test get extension"""
+        """Test extract parameters"""
         # Setup
         sampler = Mock(spec=Sampler)
 
@@ -206,24 +205,24 @@ def test__sample_children(self):
sampler.metadata.get_children.return_value = ['child A', 'child B', 'child C']
 
         # Run
-        sampled = {
+        sampled_data = {
             'test': pd.DataFrame({'field': [11, 22, 33]})
         }
-        Sampler._sample_children(sampler, 'test', sampled)
+        Sampler._sample_children(sampler, 'test', sampled_data)
 
         # Asserts
         sampler.metadata.get_children.assert_called_once_with('test')
 
         expected_calls = [
-            ['child A', 'test', pd.Series([11], index=['field'], name=0), sampled],
-            ['child A', 'test', pd.Series([22], index=['field'], name=1), sampled],
-            ['child A', 'test', pd.Series([33], index=['field'], name=2), sampled],
-            ['child B', 'test', pd.Series([11], index=['field'], name=0), sampled],
-            ['child B', 'test', pd.Series([22], index=['field'], name=1), sampled],
-            ['child B', 'test', pd.Series([33], index=['field'], name=2), sampled],
-            ['child C', 'test', pd.Series([11], index=['field'], name=0), sampled],
-            ['child C', 'test', pd.Series([22], index=['field'], name=1), sampled],
-            ['child C', 'test', pd.Series([33], index=['field'], name=2), sampled],
+            ['child A', 'test', pd.Series([11], index=['field'], name=0), sampled_data],
+            ['child A', 'test', pd.Series([22], index=['field'], name=1), sampled_data],
+            ['child A', 'test', pd.Series([33], index=['field'], name=2), sampled_data],
+            ['child B', 'test', pd.Series([11], index=['field'], name=0), sampled_data],
+            ['child B', 'test', pd.Series([22], index=['field'], name=1), sampled_data],
+            ['child B', 'test', pd.Series([33], index=['field'], name=2), sampled_data],
+            ['child C', 'test', pd.Series([11], index=['field'], name=0), sampled_data],
+            ['child C', 'test', pd.Series([22], index=['field'], name=1), sampled_data],
+            ['child C', 'test', pd.Series([33], index=['field'], name=2), sampled_data],
         ]
         actual_calls = sampler._sample_child_rows.call_args_list
         for result_call, expected_call in zip(actual_calls, expected_calls):
@@ -341,3 +340,73 @@ def sample_side_effect(table, num_rows):
         assert sampler._reset_primary_keys_generators.call_count == 1
 
         pd.testing.assert_frame_equal(result['table a'], pd.DataFrame({'foo': range(3)}))
         pd.testing.assert_frame_equal(result['table c'], pd.DataFrame({'foo': range(3)}))
+
+    @patch('sdv.sampler.np.random.choice')
+    def test__find_parent_id_all_0(self, choice_mock):
+        """If all likelihoods are 0, use num_rows."""
+        likelihoods = pd.Series([0, 0, 0, 0])
+        num_rows = pd.Series([1, 2, 3, 4])
+
+        Sampler._find_parent_id(likelihoods, num_rows)
+
+        expected_weights = np.array([1 / 10, 2 / 10, 3 / 10, 4 / 10])
+
+        choice_mock.assert_called_once()
+        assert list(choice_mock.call_args[0][0]) == list(likelihoods.index)
+        np.testing.assert_array_equal(choice_mock.call_args[1]['p'], expected_weights)
+
+    @patch('sdv.sampler.np.random.choice')
+    def test__find_parent_id_all_singlar_matrix(self, choice_mock):
+        """If all likelihoods got a singular matrix error, use num_rows."""
+        likelihoods = pd.Series([None, None, None, None])
+        num_rows = pd.Series([1, 2, 3, 4])
+
+        Sampler._find_parent_id(likelihoods, num_rows)
+
+        expected_weights = np.array([1 / 10, 2 / 10, 3 / 10, 4 / 10])
+
+        choice_mock.assert_called_once()
+        assert list(choice_mock.call_args[0][0]) == list(likelihoods.index)
+        np.testing.assert_array_equal(choice_mock.call_args[1]['p'], expected_weights)
+
+    @patch('sdv.sampler.np.random.choice')
+    def test__find_parent_id_all_0_or_singlar_matrix(self, choice_mock):
+        """If likelihoods are either 0 or NaN, fill the gaps with num_rows."""
+        likelihoods = pd.Series([0, None, 0, None])
+        num_rows = pd.Series([1, 2, 3, 4])
+
+        
Sampler._find_parent_id(likelihoods, num_rows)
+
+        expected_weights = np.array([0, 2 / 6, 0, 4 / 6])
+
+        choice_mock.assert_called_once()
+        assert list(choice_mock.call_args[0][0]) == list(likelihoods.index)
+        np.testing.assert_array_equal(choice_mock.call_args[1]['p'], expected_weights)
+
+    @patch('sdv.sampler.np.random.choice')
+    def test__find_parent_id_some_good(self, choice_mock):
+        """If some likelihoods are good, fill the gaps with the mean."""
+        likelihoods = pd.Series([0.5, None, 1.5, None])
+        num_rows = pd.Series([1, 2, 3, 4])
+
+        Sampler._find_parent_id(likelihoods, num_rows)
+
+        expected_weights = np.array([0.5 / 4, 1 / 4, 1.5 / 4, 1 / 4])
+
+        choice_mock.assert_called_once()
+        assert list(choice_mock.call_args[0][0]) == list(likelihoods.index)
+        np.testing.assert_array_equal(choice_mock.call_args[1]['p'], expected_weights)
+
+    @patch('sdv.sampler.np.random.choice')
+    def test__find_parent_id_all_good(self, choice_mock):
+        """If all are good, use the likelihoods unmodified."""
+        likelihoods = pd.Series([0.5, 1, 1.5, 2])
+        num_rows = pd.Series([1, 2, 3, 4])
+
+        Sampler._find_parent_id(likelihoods, num_rows)
+
+        expected_weights = np.array([0.5 / 5, 1 / 5, 1.5 / 5, 2 / 5])
+
+        choice_mock.assert_called_once()
+        assert list(choice_mock.call_args[0][0]) == list(likelihoods.index)
+        np.testing.assert_array_equal(choice_mock.call_args[1]['p'], expected_weights)
From ca6dd51230be7d64ca46f298a740e9df9034ca07 Mon Sep 17 00:00:00 2001
From: Carles Sala
Date: Thu, 2 Jul 2020 19:39:47 +0200
Subject: [PATCH 10/33] Fix test on py35

---
 tests/test_sampler.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/tests/test_sampler.py b/tests/test_sampler.py
index a881e3dd8..6474780ad 100644
--- a/tests/test_sampler.py
+++ b/tests/test_sampler.py
@@ -88,7 +88,7 @@ def test__finalize(self):
             result['test'].sort_index(axis=1),
             expected.sort_index(axis=1)
         )
-        sampler._find_parent_ids.assert_called_once()
+        assert sampler._find_parent_ids.call_count == 1
 
     def test__get_primary_keys_none(self):
         """Test returns a tuple of none when a table doesn't have a primary key"""
@@ -351,7 +351,7 @@ def test__find_parent_id_all_0(self, choice_mock):
 
         expected_weights = np.array([1 / 10, 2 / 10, 3 / 10, 4 / 10])
 
-        choice_mock.assert_called_once()
+        assert choice_mock.call_count == 1
         assert list(choice_mock.call_args[0][0]) == list(likelihoods.index)
         np.testing.assert_array_equal(choice_mock.call_args[1]['p'], expected_weights)
 
@@ -365,7 +365,7 @@ def test__find_parent_id_all_singlar_matrix(self, choice_mock):
 
         expected_weights = np.array([1 / 10, 2 / 10, 3 / 10, 4 / 10])
 
-        choice_mock.assert_called_once()
+        assert choice_mock.call_count == 1
         assert list(choice_mock.call_args[0][0]) == list(likelihoods.index)
         np.testing.assert_array_equal(choice_mock.call_args[1]['p'], expected_weights)
 
@@ -379,7 +379,7 @@ def test__find_parent_id_all_0_or_singlar_matrix(self, choice_mock):
 
         expected_weights = np.array([0, 2 / 6, 0, 4 / 6])
 
-        choice_mock.assert_called_once()
+        assert choice_mock.call_count == 1
         assert list(choice_mock.call_args[0][0]) == list(likelihoods.index)
         np.testing.assert_array_equal(choice_mock.call_args[1]['p'], expected_weights)
 
@@ -393,7 +393,7 @@ def test__find_parent_id_some_good(self, choice_mock):
 
         expected_weights = np.array([0.5 / 4, 1 / 4, 1.5 / 4, 1 / 4])
 
-        choice_mock.assert_called_once()
+        assert choice_mock.call_count == 1
         assert list(choice_mock.call_args[0][0]) == list(likelihoods.index)
         np.testing.assert_array_equal(choice_mock.call_args[1]['p'], expected_weights)
 
@@
-407,6 +407,6 @@ def test__find_parent_id_all_good(self, choice_mock): expected_weights = np.array([0.5 / 5, 1 / 5, 1.5 / 5, 2 / 5]) - choice_mock.assert_called_once() + assert choice_mock.call_count == 1 assert list(choice_mock.call_args[0][0]) == list(likelihoods.index) np.testing.assert_array_equal(choice_mock.call_args[1]['p'], expected_weights) From 9753a217955bfc222d925b06bef2ab5cdbf68d4a Mon Sep 17 00:00:00 2001 From: Carles Sala Date: Thu, 2 Jul 2020 20:04:28 +0200 Subject: [PATCH 11/33] Add integration test for multi-parent dataset --- sdv/data/got_families/character_families.csv | 13 ++++ sdv/data/got_families/characters.csv | 8 ++ sdv/data/got_families/families.csv | 5 ++ sdv/data/got_families/metadata.json | 79 ++++++++++++++++++++ tests/integration/test_sdv.py | 36 +++++++++ 5 files changed, 141 insertions(+) create mode 100644 sdv/data/got_families/character_families.csv create mode 100644 sdv/data/got_families/characters.csv create mode 100644 sdv/data/got_families/families.csv create mode 100644 sdv/data/got_families/metadata.json diff --git a/sdv/data/got_families/character_families.csv b/sdv/data/got_families/character_families.csv new file mode 100644 index 000000000..6b5560db5 --- /dev/null +++ b/sdv/data/got_families/character_families.csv @@ -0,0 +1,13 @@ +character_id,family_id,generation,type +1,1,8,father +1,4,5,mother +2,1,10,father +2,2,4,mother +3,3,7,both +4,4,12,both +5,1,10,father +5,2,4,mother +6,1,10,father +6,2,4,mother +7,1,10,father +7,2,4,mother diff --git a/sdv/data/got_families/characters.csv b/sdv/data/got_families/characters.csv new file mode 100644 index 000000000..f37e21af5 --- /dev/null +++ b/sdv/data/got_families/characters.csv @@ -0,0 +1,8 @@ +age,character_id,name +20,1,Jon +16,2,Arya +35,3,Tyrion +18,4,Daenerys +19,5,Sansa +24,6,Robb +15,7,Bran diff --git a/sdv/data/got_families/families.csv b/sdv/data/got_families/families.csv new file mode 100644 index 000000000..51745d59e --- /dev/null +++ b/sdv/data/got_families/families.csv @@ -0,0 +1,5 @@ +family_id,name +1,Stark +2,Tully +3,Lannister +4,Targaryen diff --git a/sdv/data/got_families/metadata.json b/sdv/data/got_families/metadata.json new file mode 100644 index 000000000..9ecff24a5 --- /dev/null +++ b/sdv/data/got_families/metadata.json @@ -0,0 +1,79 @@ +{ + "path": "", + "tables": [ + { + "use": true, + "primary_key": "character_id", + "fields": [ + { + "regex": "^[1-9]{1,2}$", + "type": "id", + "name": "character_id" + }, + { + "type": "categorical", + "name": "name" + }, + { + "subtype": "integer", + "type": "numerical", + "name": "age" + } + ], + "headers": true, + "path": "characters.csv", + "name": "characters" + }, + { + "use": true, + "primary_key": "family_id", + "fields": [ + { + "regex": "^[1-9]$", + "type": "id", + "name": "family_id" + }, + { + "type": "categorical", + "name": "name" + } + ], + "headers": true, + "path": "families.csv", + "name": "families" + }, + { + "headers": true, + "path": "character_families.csv", + "use": true, + "name": "character_families", + "fields": [ + { + "type": "id", + "ref": { + "field": "character_id", + "table": "characters" + }, + "name": "character_id" + }, + { + "type": "id", + "ref": { + "field": "family_id", + "table": "families" + }, + "name": "family_id" + }, + { + "type": "categorical", + "name": "type" + }, + { + "subtype": "integer", + "type": "numerical", + "name": "generation" + } + ] + } + ] +} diff --git a/tests/integration/test_sdv.py b/tests/integration/test_sdv.py index be9b40325..f72cb22b4 100644 --- 
a/tests/integration/test_sdv.py
+++ b/tests/integration/test_sdv.py
@@ -34,3 +34,39 @@ def test_sdv():
 
     assert transactions.shape == tables['transactions'].shape
     assert set(transactions.columns) == set(tables['transactions'].columns)
+
+
+def test_sdv_multiparent():
+    metadata, tables = load_demo('got_families', metadata=True)
+
+    sdv = SDV()
+    sdv.fit(metadata, tables)
+
+    # Sample all
+    sampled = sdv.sample_all()
+
+    assert set(sampled.keys()) == {'characters', 'families', 'character_families'}
+    assert len(sampled['characters']) == 7
+
+    # Sample with children
+    sampled = sdv.sample('characters', reset_primary_keys=True)
+
+    assert set(sampled.keys()) == {'characters', 'character_families'}
+    assert len(sampled['characters']) == 7
+    assert 'family_id' in sampled['character_families']
+
+    # Sample without children
+    characters = sdv.sample('characters', sample_children=False)
+
+    assert characters.shape == tables['characters'].shape
+    assert set(characters.columns) == set(tables['characters'].columns)
+
+    families = sdv.sample('families', sample_children=False)
+
+    assert families.shape == tables['families'].shape
+    assert set(families.columns) == set(tables['families'].columns)
+
+    character_families = sdv.sample('character_families', sample_children=False)
+
+    assert character_families.shape == tables['character_families'].shape
+    assert set(character_families.columns) == set(tables['character_families'].columns)
From 60dbcf74280c370911ef17f72c5d50f72981efc3 Mon Sep 17 00:00:00 2001
From: Carles Sala
Date: Thu, 2 Jul 2020 22:59:49 +0200
Subject: [PATCH 12/33] Allow creating a metadata with multiple parents

---
 sdv/metadata.py        | 86 +++++++++++++++++++++++++++++++++---------
 sdv/models/copulas.py  |  2 +-
 tests/test_metadata.py | 11 +----
 3 files changed, 72 insertions(+), 27 deletions(-)

diff --git a/sdv/metadata.py b/sdv/metadata.py
index efa610e58..f19ed985c 100644
--- a/sdv/metadata.py
+++ b/sdv/metadata.py
@@ -240,6 +240,30 @@ def get_tables(self):
         """
         return list(self._metadata['tables'].keys())
 
+    def get_field_meta(self, table_name, field_name):
+        """Get the metadata dict for a field.
+
+        Args:
+            table_name (str):
+                Name of the table to which the field belongs.
+            field_name (str):
+                Name of the field to get metadata for.
+
+        Returns:
+            dict:
+                field metadata
+
+        Raises:
+            ValueError:
+                If the table or the field does not exist in this metadata.
+        """
+        field_meta = self.get_fields(table_name).get(field_name)
+        if field_meta is None:
+            raise ValueError(
+                'Table "{}" does not contain a field named "{}"'.format(table_name, field_name))
+
+        return copy.deepcopy(field_meta)
+
     def get_fields(self, table_name):
         """Get table fields metadata.
 
@@ -291,11 +315,9 @@ def get_foreign_key(self, parent, child):
             ValueError:
                 If the relationship does not exist.
""" - primary = self.get_primary_key(parent) - for name, field in self.get_fields(child).items(): ref = field.get('ref') - if ref and ref['field'] == primary: + if ref and ref['table'] == parent: return name raise ValueError('{} is not parent of {}'.format(parent, child)) @@ -370,8 +392,13 @@ def get_dtypes(self, table_name, ids=False): if ids and field_type == 'id': if (name != table_meta.get('primary_key')) and not field.get('ref'): - raise MetadataError( - 'id field `{}` is neither a primary or a foreign key'.format(name)) + for child_table in self.get_children(table_name): + if name == self.get_foreign_key(table_name, child_table): + break + + else: + raise MetadataError( + 'id field `{}` is neither a primary or a foreign key'.format(name)) if ids or (field_type != 'id'): dtypes[name] = dtype @@ -672,14 +699,17 @@ def add_field(self, table, field, field_type, field_subtype=None, properties=Non def _get_key_subtype(field_meta): """Get the appropriate key subtype.""" field_type = field_meta['type'] + if field_type == 'categorical': field_subtype = 'string' + elif field_type in ('numerical', 'id'): field_subtype = field_meta['subtype'] if field_subtype not in ('integer', 'string'): raise ValueError( 'Invalid field "subtype" for key field: "{}"'.format(field_subtype) ) + else: raise ValueError( 'Invalid field "type" for key field: "{}"'.format(field_type) @@ -740,7 +770,11 @@ def add_relationship(self, parent, child, foreign_key=None): * The child table already has a parent. * The new relationship closes a relationship circle. """ - # Validate table and field names + # Validate tables exists + self.get_table_meta(parent) + self.get_table_meta(child) + + # Validate field names primary_key = self.get_primary_key(parent) if not primary_key: raise ValueError('Parent table "{}" does not have a primary key'.format(parent)) @@ -748,33 +782,51 @@ def add_relationship(self, parent, child, foreign_key=None): if foreign_key is None: foreign_key = primary_key + parent_key_meta = copy.deepcopy(self.get_field_meta(parent, primary_key)) + child_key_meta = copy.deepcopy(self.get_field_meta(child, foreign_key)) + # Validate relationships - if self.get_parents(child): - raise ValueError('Table "{}" already has a parent'.format(child)) + child_ref = child_key_meta.get('ref') + if child_ref: + raise ValueError( + 'Field "{}.{}" already defines a relationship'.format(child, foreign_key)) grandchildren = self.get_children(child) if grandchildren: self._validate_circular_relationships(parent, grandchildren) - # Copy primary key details over to the foreign key - foreign_key_details = copy.deepcopy(self.get_fields(parent)[primary_key]) - foreign_key_details['ref'] = { + # Make sure that the parent key is an id + if parent_key_meta['type'] != 'id': + parent_key_meta['subtype'] = self._get_key_subtype(parent_key_meta) + parent_key_meta['type'] = 'id' + + # Update the child key meta + child_key_meta['subtype'] = self._get_key_subtype(child_key_meta) + child_key_meta['type'] = 'id' + child_key_meta['ref'] = { 'table': parent, 'field': primary_key } # Make sure that key subtypes are the same - foreign_meta = self.get_fields(child).get(foreign_key) - if foreign_meta: - foreign_subtype = self._get_key_subtype(foreign_meta) - if foreign_subtype != foreign_key_details['subtype']: - raise ValueError('Primary and Foreign key subtypes mismatch') + if child_key_meta['subtype'] != parent_key_meta['subtype']: + raise ValueError('Parent and Child key subtypes mismatch') + + # Make a backup + metadata_backup = 
copy.deepcopy(self._metadata) - self._metadata['tables'][child]['fields'][foreign_key] = foreign_key_details + self._metadata['tables'][parent]['fields'][primary_key] = parent_key_meta + self._metadata['tables'][child]['fields'][foreign_key] = child_key_meta # Re-analyze the relationships self._analyze_relationships() + try: + self.validate() + except MetadataError: + self._metadata = metadata_backup + raise + def _get_field_details(self, data, fields): """Get or build all the fields metadata. diff --git a/sdv/models/copulas.py b/sdv/models/copulas.py index ff1665d50..fcd8a3c8b 100644 --- a/sdv/models/copulas.py +++ b/sdv/models/copulas.py @@ -152,7 +152,7 @@ def _unflatten_gaussian_copula(self, model_parameters): model_parameters['univariates'] = univariates model_parameters['columns'] = columns - covariance = model_parameters['covariance'] + covariance = model_parameters.get('covariance') model_parameters['covariance'] = self._prepare_sampled_covariance(covariance) return model_parameters diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 137a258b8..812ac7bcd 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -381,6 +381,7 @@ def test_get_dtypes_error_id(self): } metadata = Mock(spec_set=Metadata) metadata.get_table_meta.return_value = table_meta + metadata.get_children.return_value = [] metadata._DTYPES = Metadata._DTYPES # Run @@ -608,23 +609,16 @@ def test_get_primary_key(self): def test_get_foreign_key(self): """Test get foreign key""" # Setup - primary_key = 'a_primary_key' fields = { 'a_field': { 'ref': { + 'table': 'parent', 'field': 'a_primary_key' }, 'name': 'a_field' - }, - 'p_field': { - 'ref': { - 'field': 'another_key_field' - }, - 'name': 'p_field' } } metadata = Mock(spec_set=Metadata) - metadata.get_primary_key.return_value = primary_key metadata.get_fields.return_value = fields # Run @@ -632,7 +626,6 @@ def test_get_foreign_key(self): # Asserts assert result == 'a_field' - metadata.get_primary_key.assert_called_once_with('parent') metadata.get_fields.assert_called_once_with('child') def test_reverse_transform(self): From 8ab9734d2e6e8343865309975097d269c19c33d5 Mon Sep 17 00:00:00 2001 From: Carles Sala Date: Sat, 4 Jul 2020 14:36:24 +0200 Subject: [PATCH 13/33] Add benchmark module --- EVALUATION.md | 93 +++++++++++++++++++++++++++++++++++++++++++++- sdv/benchmark.py | 96 ++++++++++++++++++++++++++++++++++++++++++++++++ sdv/demo.py | 8 ++-- 3 files changed, 192 insertions(+), 5 deletions(-) create mode 100644 sdv/benchmark.py diff --git a/EVALUATION.md b/EVALUATION.md index 5d64ac6db..a8343184f 100644 --- a/EVALUATION.md +++ b/EVALUATION.md @@ -12,7 +12,7 @@ generate a simple standardized score. After you have modeled your databased and generated samples out of the SDV models you will be left with a dictionary that contains table names and dataframes. -For exmple, if we model and sample the demo dataset: +For example, if we model and sample the demo dataset: ```python3 from sdv import SDV @@ -44,3 +44,94 @@ the value will be negative. For further options, including visualizations and more detailed reports, please refer to the [SDMetrics](https://github.com/sdv-dev/SDMetrics) library. + + +## SDV Benchmark + +SDV also provides a simple functionality to evaluate the performance of SDV across a +collection of demo datasets or custom datasets hosted in a local folder. 
+
+This evaluation can be executed using the `sdv.benchmark.benchmark` function:
+
+```python3
+from sdv.benchmark import benchmark
+
+scores = benchmark()
+```
+
+This function has the following arguments:
+
+* `datasets`: List of dataset names, which can either be names of demo datasets or
+  names of custom datasets stored in a local folder.
+* `datasets_path`: Path where the custom datasets are stored. If not provided, the
+  dataset names are interpreted as demo datasets.
+* `distributed`: Whether to execute the benchmark using Dask. Defaults to `True`.
+* `timeout`: Maximum time allowed for each dataset to be modeled, sampled and evaluated.
+  Any dataset that takes longer to run will return a score of `None`.
+
+For example, the following command will run the SDV benchmark on the three given demo
+datasets using `dask` and a timeout of 60 seconds:
+
+```python3
+scores = benchmark(
+    datasets=['DCG_v1', 'trains_v1', 'UTube_v1'],
+    distributed=True,
+    timeout=60
+)
+```
+
+The result will be a DataFrame with the columns `dataset` and `score`:
+
+| dataset | score |
+|:-------:|:-----:|
+| DCG_v1 | -14.49341665631863 |
+| trains_v1 | -30.26840342069557 |
+| UTube_v1 | -8.57618576332235 |
+
+Additionally, if any dataset raised an error or reached the timeout, an `error`
+column will be added indicating the details.
+
+### Demo Datasets
+
+The collection of available datasets can be listed using `sdv.demo.get_available_demos`,
+which returns a table with a description of the dataset properties:
+
+```python3
+from sdv.demo import get_available_demos
+
+demos = get_available_demos()
+```
+
+The result is a table indicating the name of the dataset and a few properties, such as the
+number of tables that compose the dataset and the total number of rows and columns:
+
+| name | tables | rows | columns |
+|-----------------------|----------|---------|-----------|
+| UTube_v1 | 2 | 2735 | 10 |
+| SAP_v1 | 4 | 3841029 | 71 |
+| NCAA_v1 | 9 | 202305 | 333 |
+| airbnb-simplified | 2 | 5751408 | 22 |
+| Atherosclerosis_v1 | 4 | 12781 | 307 |
+| rossmann | 3 | 2035533 | 21 |
+| walmart | 4 | 544869 | 24 |
+| AustralianFootball_v1 | 4 | 139179 | 193 |
+| Pyrimidine_v1 | 2 | 296 | 38 |
+| world_v1 | 3 | 5302 | 39 |
+| Accidents_v1 | 3 | 1463093 | 87 |
+| trains_v1 | 2 | 83 | 15 |
+| legalActs_v1 | 5 | 1754397 | 50 |
+| DCG_v1 | 2 | 8258 | 9 |
+| imdb_ijs_v1 | 7 | 5647694 | 50 |
+| SalesDB_v1 | 4 | 6735507 | 35 |
+| MuskSmall_v1 | 2 | 568 | 173 |
+| KRK_v1 | 1 | 1000 | 9 |
+| Chess_v1 | 2 | 2052 | 57 |
+| Telstra_v1 | 5 | 148021 | 23 |
+| mutagenesis_v1 | 3 | 10324 | 26 |
+| PremierLeague_v1 | 4 | 11308 | 250 |
+| census | 1 | 32561 | 15 |
+| FNHK_v1 | 3 | 2113275 | 43 |
+| imdb_MovieLens_v1 | 7 | 1249411 | 58 |
+| financial_v1 | 8 | 1079680 | 84 |
+| ftp_v1 | 2 | 96491 | 13 |
+| Triazine_v1 | 2 | 1302 | 35 |
diff --git a/sdv/benchmark.py b/sdv/benchmark.py
new file mode 100644
index 000000000..7641bd1f5
--- /dev/null
+++ b/sdv/benchmark.py
@@ -0,0 +1,96 @@
+# -*- coding: utf-8 -*-
+
+import logging
+import multiprocessing
+import os
+import signal
+import sys
+from datetime import datetime
+
+import pandas as pd
+
+from sdv import Metadata, SDV
+from sdv.demo import get_available_demos, load_demo
+from sdv.evaluation import evaluate
+
+LOGGER = logging.getLogger(__name__)
+
+
+def _score_dataset(dataset, datasets_path, output):
+    start = datetime.now()
+
+    try:
+        if datasets_path is None:
+            metadata, tables = load_demo(dataset, metadata=True)
+        else:
+            metadata
= Metadata(os.path.join(datasets_path, dataset, 'metadata.json')) + tables = metadata.load_tables() + + sdv = SDV() + LOGGER.info('Modeling dataset %s', dataset) + sdv.fit(metadata, tables) + + LOGGER.info('Sampling dataset %s', dataset) + sampled = sdv.sample_all(10) + + LOGGER.info('Evaluating dataset %s', dataset) + score = evaluate(sampled, metadata=metadata) + + LOGGER.info('%s: %s - ELAPSED: %s', dataset, score, datetime.now() - start) + output.update({ + 'dataset': dataset, + 'score': score, + }) + + except Exception as ex: + error = '{}: {}'.format(type(ex).__name__, str(ex)) + LOGGER.error('%s: %s - ELAPSED: %s', dataset, error, datetime.now() - start) + output.update({ + 'dataset': dataset, + 'error': error + }) + + +def score_dataset(dataset, datasets_path, timeout=None): + with multiprocessing.Manager() as manager: + output = manager.dict() + process = multiprocessing.Process( + target=_score_dataset, + args=(dataset, datasets_path, output) + ) + + process.start() + process.join(timeout) + process.terminate() + + if not output: + LOGGER.warn('%s: TIMEOUT', dataset) + return { + 'dataset': dataset, + 'error': 'timeout' + } + + return dict(output) + + +def benchmark(datasets=None, datasets_path=None, distributed=True, timeout=None): + if datasets is None: + if datasets_path is None: + datasets = get_available_demos().name + else: + datasets = os.listdir(datasets_path) + + if distributed: + import dask + + global score_dataset + score_dataset = dask.delayed(score_dataset) + + scores = list() + for dataset in datasets: + scores.append(score_dataset(dataset, datasets_path, timeout)) + + if distributed: + scores = dask.compute(*scores) + + return pd.DataFrame(scores) diff --git a/sdv/demo.py b/sdv/demo.py index 672130ada..5e723b4b0 100644 --- a/sdv/demo.py +++ b/sdv/demo.py @@ -104,9 +104,9 @@ def _download(dataset_name, data_path): zf.extractall(data_path) -def _load(dataset_name, data_path): - if not os.path.exists(DATA_PATH): - os.makedirs(DATA_PATH) +def _get_dataset_path(dataset_name, data_path): + if not os.path.exists(data_path): + os.makedirs(data_path) if not os.path.exists(os.path.join(data_path, dataset_name)): _download(dataset_name, data_path) @@ -151,7 +151,7 @@ def _load_dummy(): def _load_demo_dataset(dataset_name, data_path): - data_path = _load(dataset_name, data_path) + dataset_path = _get_dataset_path(dataset_name, data_path) meta = Metadata(metadata=os.path.join(data_path, 'metadata.json')) tables = meta.load_tables() return meta, tables From 0cd81bdbd08fdc2b099c233b3466efd7cd6dc52e Mon Sep 17 00:00:00 2001 From: Carles Sala Date: Sat, 4 Jul 2020 14:42:15 +0200 Subject: [PATCH 14/33] Fix lint --- sdv/benchmark.py | 4 +--- sdv/demo.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/sdv/benchmark.py b/sdv/benchmark.py index 7641bd1f5..df4365d0e 100644 --- a/sdv/benchmark.py +++ b/sdv/benchmark.py @@ -3,13 +3,11 @@ import logging import multiprocessing import os -import signal -import sys from datetime import datetime import pandas as pd -from sdv import Metadata, SDV +from sdv import SDV, Metadata from sdv.demo import get_available_demos, load_demo from sdv.evaluation import evaluate diff --git a/sdv/demo.py b/sdv/demo.py index 5e723b4b0..9ef350904 100644 --- a/sdv/demo.py +++ b/sdv/demo.py @@ -152,7 +152,7 @@ def _load_dummy(): def _load_demo_dataset(dataset_name, data_path): dataset_path = _get_dataset_path(dataset_name, data_path) - meta = Metadata(metadata=os.path.join(data_path, 'metadata.json')) + meta = 
Metadata(metadata=os.path.join(dataset_path, 'metadata.json')) tables = meta.load_tables() return meta, tables From daefd3b05a36b1aeea3243c5372a9b72d15b4595 Mon Sep 17 00:00:00 2001 From: Carles Sala Date: Sat, 4 Jul 2020 14:51:00 +0200 Subject: [PATCH 15/33] =?UTF-8?q?Bump=20version:=200.3.4.dev0=20=E2=86=92?= =?UTF-8?q?=200.3.4.dev1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sdv/__init__.py | 2 +- setup.cfg | 2 +- setup.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sdv/__init__.py b/sdv/__init__.py index 8caed53d9..3a0f29f7d 100644 --- a/sdv/__init__.py +++ b/sdv/__init__.py @@ -6,7 +6,7 @@ __author__ = """MIT Data To AI Lab""" __email__ = 'dailabmit@gmail.com' -__version__ = '0.3.4.dev0' +__version__ = '0.3.4.dev1' from sdv.demo import get_available_demos, load_demo from sdv.metadata import Metadata diff --git a/setup.cfg b/setup.cfg index 6a2ef34f0..d39776a09 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.3.4.dev0 +current_version = 0.3.4.dev1 commit = True tag = True parse = (?P\d+)\.(?P\d+)\.(?P\d+)(\.(?P[a-z]+)(?P\d+))? diff --git a/setup.py b/setup.py index bca8362a0..3784752e2 100644 --- a/setup.py +++ b/setup.py @@ -93,6 +93,6 @@ test_suite='tests', tests_require=tests_require, url='https://github.com/sdv-dev/SDV', - version='0.3.4.dev0', + version='0.3.4.dev1', zip_safe=False, ) From 83bd3471c777ca40aced2ca43a221b020e0a01d6 Mon Sep 17 00:00:00 2001 From: Carles Sala Date: Sat, 4 Jul 2020 14:53:03 +0200 Subject: [PATCH 16/33] Add release notes for v0.3.4 --- HISTORY.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/HISTORY.md b/HISTORY.md index b2a2e996d..00b164949 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,16 @@ # History +## 0.3.4 - 2020-07-04 + +## New Features + +* Support for Multiple Parents - [Issue #162](https://github.com/sdv-dev/SDV/issues/162) by @csala +* Sample by default the same number of rows as in the original table - [Issue #163](https://github.com/sdv-dev/SDV/issues/163) by @csala + +### General Improvements + +* Add benchmark - [Issue #165](https://github.com/sdv-dev/SDV/issues/165) by @csala + ## 0.3.3 - 2020-06-26 ### General Improvements From 729a59b0d25d672660f96aca32bb2425e75d60fa Mon Sep 17 00:00:00 2001 From: Carles Sala Date: Sat, 4 Jul 2020 14:55:51 +0200 Subject: [PATCH 17/33] =?UTF-8?q?Bump=20version:=200.3.4.dev1=20=E2=86=92?= =?UTF-8?q?=200.3.4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sdv/__init__.py | 2 +- setup.cfg | 2 +- setup.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sdv/__init__.py b/sdv/__init__.py index 3a0f29f7d..e56c24244 100644 --- a/sdv/__init__.py +++ b/sdv/__init__.py @@ -6,7 +6,7 @@ __author__ = """MIT Data To AI Lab""" __email__ = 'dailabmit@gmail.com' -__version__ = '0.3.4.dev1' +__version__ = '0.3.4' from sdv.demo import get_available_demos, load_demo from sdv.metadata import Metadata diff --git a/setup.cfg b/setup.cfg index d39776a09..9890ee3b2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.3.4.dev1 +current_version = 0.3.4 commit = True tag = True parse = (?P\d+)\.(?P\d+)\.(?P\d+)(\.(?P[a-z]+)(?P\d+))? 
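The `timeout` argument handled by `score_dataset` in `sdv/benchmark.py` above works by running each dataset in a throwaway process and joining it with a deadline; an empty shared dict afterwards means the run never finished. A minimal sketch of the same pattern, with `_work` as a hypothetical stand-in for `_score_dataset`:

```python
import multiprocessing


def _work(output):
    # Hypothetical stand-in for _score_dataset: write results into the shared dict.
    output['score'] = 42


if __name__ == '__main__':
    with multiprocessing.Manager() as manager:
        output = manager.dict()
        process = multiprocessing.Process(target=_work, args=(output,))

        process.start()
        process.join(10)      # wait at most 10 seconds
        process.terminate()   # kill the worker if it is still alive

        # No output means the worker never finished within the deadline.
        result = dict(output) or {'error': 'timeout'}
```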
diff --git a/setup.py b/setup.py index 3784752e2..53c783ad9 100644 --- a/setup.py +++ b/setup.py @@ -93,6 +93,6 @@ test_suite='tests', tests_require=tests_require, url='https://github.com/sdv-dev/SDV', - version='0.3.4.dev1', + version='0.3.4', zip_safe=False, ) From 21a5fdc27fc88139db80e858f77178aafcdce1cc Mon Sep 17 00:00:00 2001 From: Carles Sala Date: Sat, 4 Jul 2020 14:56:10 +0200 Subject: [PATCH 18/33] =?UTF-8?q?Bump=20version:=200.3.4=20=E2=86=92=200.3?= =?UTF-8?q?.5.dev0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sdv/__init__.py | 2 +- setup.cfg | 2 +- setup.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sdv/__init__.py b/sdv/__init__.py index e56c24244..d8e94dd37 100644 --- a/sdv/__init__.py +++ b/sdv/__init__.py @@ -6,7 +6,7 @@ __author__ = """MIT Data To AI Lab""" __email__ = 'dailabmit@gmail.com' -__version__ = '0.3.4' +__version__ = '0.3.5.dev0' from sdv.demo import get_available_demos, load_demo from sdv.metadata import Metadata diff --git a/setup.cfg b/setup.cfg index 9890ee3b2..e4bf6d2c7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.3.4 +current_version = 0.3.5.dev0 commit = True tag = True parse = (?P\d+)\.(?P\d+)\.(?P\d+)(\.(?P[a-z]+)(?P\d+))? diff --git a/setup.py b/setup.py index 53c783ad9..c8a4a56b5 100644 --- a/setup.py +++ b/setup.py @@ -93,6 +93,6 @@ test_suite='tests', tests_require=tests_require, url='https://github.com/sdv-dev/SDV', - version='0.3.4', + version='0.3.5.dev0', zip_safe=False, ) From a54702957f43234d9e717b24320fcecf2f4801ad Mon Sep 17 00:00:00 2001 From: Carles Sala Date: Thu, 9 Jul 2020 13:20:27 +0200 Subject: [PATCH 19/33] Tabular models, work in progress --- sdv/{metadata.py => metadata/__init__.py} | 141 ++----- sdv/metadata/errors.py | 2 + sdv/metadata/table.py | 489 ++++++++++++++++++++++ sdv/metadata/visualization.py | 120 ++++++ sdv/tabular/__init__.py | 0 sdv/tabular/base.py | 180 ++++++++ sdv/tabular/copulas.py | 186 ++++++++ sdv/tabular/utils.py | 219 ++++++++++ 8 files changed, 1224 insertions(+), 113 deletions(-) rename sdv/{metadata.py => metadata/__init__.py} (90%) create mode 100644 sdv/metadata/errors.py create mode 100644 sdv/metadata/table.py create mode 100644 sdv/metadata/visualization.py create mode 100644 sdv/tabular/__init__.py create mode 100644 sdv/tabular/base.py create mode 100644 sdv/tabular/copulas.py create mode 100644 sdv/tabular/utils.py diff --git a/sdv/metadata.py b/sdv/metadata/__init__.py similarity index 90% rename from sdv/metadata.py rename to sdv/metadata/__init__.py index f19ed985c..fc91ad549 100644 --- a/sdv/metadata.py +++ b/sdv/metadata/__init__.py @@ -9,6 +9,17 @@ import pandas as pd from rdt import HyperTransformer, transformers +from sdv.metadata import visualization +from sdv.metadata.errors import MetadataError +from sdv.metadata.table import Table + +__all__ = [ + 'Metadata', + 'MetadataError', + 'Table', + 'visualization' +] + LOGGER = logging.getLogger(__name__) @@ -965,119 +976,7 @@ def to_json(self, path): with open(path, 'w') as out_file: json.dump(self._metadata, out_file, indent=4) - @staticmethod - def _get_graphviz_extension(path): - if path: - path_splitted = path.split('.') - if len(path_splitted) == 1: - raise ValueError('Path without graphviz extansion.') - - graphviz_extension = path_splitted[-1] - - if graphviz_extension not in graphviz.backend.FORMATS: - raise ValueError( - '"{}" not a valid graphviz extension format.'.format(graphviz_extension) - ) - - return 
'.'.join(path_splitted[:-1]), graphviz_extension - - return None, None - - def _visualize_add_nodes(self, plot): - """Add nodes into a `graphviz.Digraph`. - - Each node represent a metadata table. - - Args: - plot (graphviz.Digraph) - """ - for table in self.get_tables(): - # Append table fields - fields = [] - - for name, value in self.get_fields(table).items(): - if value.get('subtype') is not None: - fields.append('{} : {} - {}'.format(name, value['type'], value['subtype'])) - - else: - fields.append('{} : {}'.format(name, value['type'])) - - fields = r'\l'.join(fields) - - # Append table extra information - extras = [] - - primary_key = self.get_primary_key(table) - if primary_key is not None: - extras.append('Primary key: {}'.format(primary_key)) - - parents = self.get_parents(table) - for parent in parents: - foreign_key = self.get_foreign_key(parent, table) - extras.append('Foreign key ({}): {}'.format(parent, foreign_key)) - - path = self.get_table_meta(table).get('path') - if path is not None: - extras.append('Data path: {}'.format(path)) - - extras = r'\l'.join(extras) - - # Add table node - title = r'{%s|%s\l|%s\l}' % (table, fields, extras) - plot.node(table, label=title) - - def _visualize_add_edges(self, plot): - """Add edges into a `graphviz.Digraph`. - - Each edge represents a relationship between two metadata tables. - - Args: - plot (graphviz.Digraph) - """ - for table in self.get_tables(): - for parent in list(self.get_parents(table)): - plot.edge( - parent, - table, - label=' {}.{} -> {}.{}'.format( - table, self.get_foreign_key(parent, table), - parent, self.get_primary_key(parent) - ), - arrowhead='crow' - ) - - def visualize(self, path=None): - """Plot metadata usign graphviz. - - Try to generate a plot using graphviz. - If a ``path`` is provided save the output into a file. - - Args: - path (str): - Output file path to save the plot, it requires a graphviz - supported extension. If ``None`` do not save the plot. - Defaults to ``None``. - """ - filename, graphviz_extension = self._get_graphviz_extension(path) - plot = graphviz.Digraph( - 'Metadata', - format=graphviz_extension, - node_attr={ - "shape": "Mrecord", - "fillcolor": "lightgoldenrod1", - "style": "filled" - }, - ) - - self._visualize_add_nodes(plot) - self._visualize_add_edges(plot) - - if filename: - plot.render(filename=filename, cleanup=True, format=graphviz_extension) - else: - return plot - - def __str__(self): + def __repr__(self): tables = self.get_tables() relationships = [ ' {}.{} -> {}.{}'.format( @@ -1099,3 +998,19 @@ def __str__(self): tables, '\n'.join(relationships) ) + + def visualize(self, path=None): + """Plot metadata usign graphviz. + + Generate a plot using graphviz. + If a ``path`` is provided save the output into a file. + + Args: + path (str): + Output file path to save the plot. It requires a graphviz + supported extension. If ``None`` do not save the plot and + just return the ``graphviz.Digraph`` object. + Defaults to ``None``. 
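As a quick usage sketch of the method above (assuming graphviz is installed and a `metadata.json` file is available):

```python
from sdv import Metadata

metadata = Metadata('metadata.json')

# With no path, the graphviz.Digraph object is returned (handy in notebooks)...
digraph = metadata.visualize()

# ...and with a path, the plot is rendered to disk instead.
metadata.visualize('metadata.png')
```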
+ """ + return visualization.visualize(self, path) + diff --git a/sdv/metadata/errors.py b/sdv/metadata/errors.py new file mode 100644 index 000000000..8e1bada44 --- /dev/null +++ b/sdv/metadata/errors.py @@ -0,0 +1,2 @@ +class MetadataError(Exception): + pass diff --git a/sdv/metadata/table.py b/sdv/metadata/table.py new file mode 100644 index 000000000..85b129d78 --- /dev/null +++ b/sdv/metadata/table.py @@ -0,0 +1,489 @@ +import copy +import json +import logging +import os + +import numpy as np +import pandas as pd +from faker import Faker +from rdt import HyperTransformer, transformers + +from sdv.metadata.error import MetadataError + +LOGGER = logging.getLogger(__name__) + + +class Table: + """Table Metadata. + + The Metadata class provides a unified layer of abstraction over the metadata + of a single Table, which includes both the necessary details to load the data + from the filesystem and to know how to parse and transform it to numerical data. + + Args: + metadata (str or dict): + Path to a ``json`` file that contains the metadata or a ``dict`` representation + of ``metadata`` following the same structure. + root_path (str): + The path to which the dataset is located. Defaults to ``None``. + """ + + _metadata = None + _hyper_transformer = None + _root_path = None + + _FIELD_TEMPLATES = { + 'i': { + 'type': 'numerical', + 'subtype': 'integer', + }, + 'f': { + 'type': 'numerical', + 'subtype': 'float', + }, + 'O': { + 'type': 'categorical', + }, + 'b': { + 'type': 'boolean', + }, + 'M': { + 'type': 'datetime', + } + } + _DTYPES = { + ('categorical', None): 'object', + ('boolean', None): 'bool', + ('numerical', None): 'float', + ('numerical', 'float'): 'float', + ('numerical', 'integer'): 'int', + ('datetime', None): 'datetime64', + ('id', None): 'int', + ('id', 'integer'): 'int', + ('id', 'string'): 'str' + } + + def __init__(self, metadata=None, root_path=None, field_names=None, primary_key=None, + field_types=None, anonymize_fields=None, constraints=None,): + self._metadata = metadata or dict() + self._hyper_transformer = None + self._root_path = None + + self._field_names = field_names + self._primary_key = primary_key + self._field_types = field_types or {} + self._anonymize_fields = anonymize_fields + self._constraints = constraints + + def _get_field_dtype(self, field_name, field_metadata): + field_type = field_metadata['type'] + field_subtype = field_metadata.get('subtype') + dtype = self._DTYPES.get((field_type, field_subtype)) + if not dtype: + raise MetadataError( + 'Invalid type and subtype combination for field {}: ({}, {})'.format( + field_name, field_type, field_subtype) + ) + + return dtype + + def get_dtypes(self, ids=False): + """Get a ``dict`` with the ``dtypes`` for each field of the table. + + Args: + ids (bool): + Whether or not include the id fields. Defaults to ``False``. + + Returns: + dict: + Dictionary that contains the field names and data types. + + Raises: + ValueError: + If a field has an invalid type or subtype. + """ + dtypes = dict() + for name, field_meta in self._metadata['fields'].items(): + field_type = field_meta['type'] + + if ids or (field_type != 'id'): + dtypes[name] = self._get_field_dtype(name, field_meta) + + return dtypes + + def _get_faker(self): + """Return the faker object to anonymize data. + + Returns: + function: + Faker function to generate new data instances with ``self.anonymize`` arguments. + + Raises: + ValueError: + A ``ValueError`` is raised if the faker category we want don't exist. 
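Based on the constructor signature above, a sketch of how the new `Table` metadata is meant to be created (the field names are invented for illustration, and nothing is learned until `fit` is called):

```python
from sdv.metadata import Table

table = Table(
    field_names=['user_id', 'age', 'country'],
    primary_key='user_id',
    field_types={
        'user_id': {'type': 'id', 'subtype': 'integer'},
        'age': {'type': 'numerical', 'subtype': 'integer'},
    },
)
```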
+ """ + if isinstance(self._anonymize, (tuple, list)): + category, *args = self._anonymize + else: + category = self._anonymize + args = tuple() + + try: + faker_method = getattr(Faker(), category) + + def faker(): + return faker_method(*args) + + return faker + except AttributeError: + raise ValueError('Category "{}" couldn\'t be found on faker'.format(self.anonymize)) + + def _anonymize(self, data): + """Anonymize data and save the anonymization mapping in-memory.""" + # TODO: Do this by column + faker = self._get_faker() + uniques = data.unique() + fake_data = [faker() for x in range(len(uniques))] + + mapping = dict(zip(uniques, fake_data)) + MAPS[id(self)] = mapping + + return data.map(mapping) + + def _build_fields_metadata(self, data): + """Build all the fields metadata. + + Args: + data (pandas.DataFrame): + Data to be analyzed. + + Returns: + dict: + Dict of valid fields. + + Raises: + ValueError: + If a column from the data analyzed is an unsupported data type + """ + field_names = self._field_names or data.columns + + fields_metadata = dict() + for field_name in field_names: + if not field_name in data: + raise ValueError('Field {} not found in given data'.format(field_name)) + + field_meta = self._field_types.get(field_name) + if field_meta: + # Validate the given meta + self._get_field_dtype(field_name, field_meta) + else: + dtype = data[field_name].dtype + field_template = self._FIELD_TEMPLATES.get(dtype.kind) + if field_template is None: + raise ValueError('Unsupported dtype {} in column {}'.format(dtype, field_name)) + + field_meta = copy.deepcopy(field_template) + + fields_metadata[field_name] = field_meta + + return fields_metadata + + def _get_pii_fields(self): + """Get the ``pii_category`` for each field that contains PII. + + Returns: + dict: + pii field names and categories. + """ + pii_fields = dict() + for name, field in self._metadata['fields'].items(): + if field['type'] == 'categorical' and field.get('pii', False): + pii_fields[name] = field['pii_category'] + + return pii_fields + + def _get_transformers(self, dtypes): + """Create the transformer instances needed to process the given dtypes. + + Temporary drop-in replacement of ``HyperTransformer._analyze`` method, + before RDT catches up. + + Args: + dtypes (dict): + mapping of field names and dtypes. + pii_fields (dict): + mapping of pii field names and categories. + + Returns: + dict: + mapping of field names and transformer instances. + """ + transformers_dict = dict() + pii_fields = self._get_pii_fields() + for name, dtype in dtypes.items(): + dtype = np.dtype(dtype) + if dtype.kind == 'i': + transformer = transformers.NumericalTransformer(dtype=int) + elif dtype.kind == 'f': + transformer = transformers.NumericalTransformer(dtype=float) + elif dtype.kind == 'O': + anonymize = pii_fields.get(name) + transformer = transformers.CategoricalTransformer(anonymize=anonymize) + elif dtype.kind == 'b': + transformer = transformers.BooleanTransformer() + elif dtype.kind == 'M': + transformer = transformers.DatetimeTransformer() + else: + raise ValueError('Unsupported dtype: {}'.format(dtype)) + + LOGGER.info('Loading transformer %s for field %s', + transformer.__class__.__name__, name) + transformers_dict[name] = transformer + + return transformers_dict + + def _fit_hyper_transformer(self, data): + """Create and return a new ``rdt.HyperTransformer`` instance. + + First get the ``dtypes`` and then use them to build a transformer dictionary + to be used by the ``HyperTransformer``. 
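The `(category, *args)` convention handled by `_get_faker` above maps directly onto `faker` method calls. A sketch of what the two accepted forms resolve to (the categories shown are examples; availability depends on the installed Faker version):

```python
from faker import Faker

faker = Faker()

# A plain string category is just a Faker method name...
print(faker.name())

# ...and the tuple form passes the remaining items as positional arguments.
category, *args = ('numerify', '###')
print(getattr(faker, category)(*args))   # e.g. '482'
```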
+ + Returns: + rdt.HyperTransformer + """ + dtypes = self.get_dtypes(ids=False) + transformers_dict = self._get_transformers(dtypes) + self._hyper_transformer = HyperTransformer(transformers=transformers_dict) + self._hyper_transformer.fit(data[list(dtypes.keys())]) + + def set_primary_key(self, field_name): + """Set the primary key of this table. + + The field must exist and either be an integer or categorical field. + + Args: + field_name (str): + Name of the field to be used as the new primary key. + + Raises: + ValueError: + If the table or the field do not exist or if the field has an + invalid type or subtype. + """ + if field_name is not None: + if field_name not in self.get_fields(): + raise ValueError('Field "{}" does not exist in this table'.format(field_name)) + + field_metadata = self._metadata['fields'][field_name] + field_subtype = self._get_key_subtype(field_metadata) + + field_metadata.update({ + 'type': 'id', + 'subtype': field_subtype + }) + + self._primary_key = field_name + + def fit(self, data): + """Fit this metadata to the given data. + + Args: + data (pandas.DataFrame): + Table to be analyzed. + """ + self._field_names = self._field_names or list(data.columns) + self._metadata['fields'] = self._build_fields_metadata(data) + self.set_primary_key(self._primary_key) + + # TODO: Treat/Learn constraints + + self._fit_hyper_transformer(data) + + def get_fields(self): + """Get fields metadata. + + Returns: + dict: + Mapping of field names and their metadata dicts. + """ + return self._metadata['fields'] + + def transform(self, data): + """Transform the given data. + + Args: + data (pandas.DataFrame): + Table data. + + Returns: + pandas.DataFrame: + Transformed data. + """ + + # TODO: Do this by column + # if self.anonymize: + # data = data.map(MAPS[id(self)]) + + fields = list(self._hyper_transformer.transformers.keys()) + return self._hyper_transformer.transform(data[fields]) + + def reverse_transform(self, data): + """Reverse the transformed data to the original format. + + Args: + data (pandas.DataFrame): + Data to be reverse transformed. + + Returns: + pandas.DataFrame + """ + reversed_data = self._hyper_transformer.reverse_transform(data) + + fields = self._metadata['fields'] + for name, dtype in self.get_dtypes(ids=True).items(): + field_type = fields[name]['type'] + if field_type == 'id': + field_data = pd.Series(np.arange(len(reversed_data))) + else: + field_data = reversed_data[name] + + reversed_data[name] = field_data.dropna().astype(dtype) + + return reversed_data[self._field_names] + + def get_children(self): + """Get tables for which this table is parent. + + Returns: + set: + Set of children for this table. + """ + return self._children + + def get_parents(self): + """Get tables for with this table is child. + + Returns: + set: + Set of parents for this table. + """ + return self._parents + + def get_field(self, field_name): + """Get the metadata dict for a field. + + Args: + field_name (str): + Name of the field to get data for. + + Returns: + dict: + field metadata + + Raises: + ValueError: + If the table or the field do not exist in this metadata. 
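Chained together, the methods above give `Table` its fit/transform/reverse cycle. A small sketch on made-up data (the class is still marked as work in progress, so treat the API as illustrative):

```python
import pandas as pd

from sdv.metadata import Table

data = pd.DataFrame({
    'age': [21, 35, 46],
    'country': ['US', 'ES', 'FR'],
})

table = Table()
table.fit(data)   # builds the fields metadata and fits the HyperTransformer

numerical = table.transform(data)               # all-numerical representation
recovered = table.reverse_transform(numerical)  # back to the original dtypes
```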
+ """ + field_meta = self._metadata['fields'].get(field_name) + if field_meta is None: + raise ValueError('Invalid field name "{}"'.format(field_name)) + + return copy.deepcopy(field_meta) + + # def _read_csv_dtypes(self): + # """Get the dtypes specification that needs to be passed to read_csv.""" + # dtypes = dict() + # for name, field in self._metadata['fields'].items(): + # field_type = field['type'] + # if field_type == 'categorical': + # dtypes[name] = str + # elif field_type == 'id' and field.get('subtype', 'integer') == 'string': + # dtypes[name] = str + + # return dtypes + + # def _parse_dtypes(self, data): + # """Convert the data columns to the right dtype after loading the CSV.""" + # for name, field in self._metadata['fields'].items(): + # field_type = field['type'] + # if field_type == 'datetime': + # datetime_format = field.get('format') + # data[name] = pd.to_datetime(data[name], format=datetime_format, exact=False) + # elif field_type == 'numerical' and field.get('subtype') == 'integer': + # data[name] = data[name].dropna().astype(int) + # elif field_type == 'id' and field.get('subtype', 'integer') == 'integer': + # data[name] = data[name].dropna().astype(int) + + # return data + + # def load(self): + # """Load table data. + + # First load the CSV with the right dtypes and then parse the columns + # to the final dtypes. + + # Returns: + # pandas.DataFrame: + # DataFrame with the contents of the table. + # """ + # relative_path = os.path.join(self.root_path, self.path) + # dtypes = self._read_csv_dtypes() + + # data = pd.read_csv(relative_path, dtype=dtypes) + # data = self._parse_dtypes(data) + + # return data + + @staticmethod + def _get_key_subtype(field_meta): + """Get the appropriate key subtype.""" + field_type = field_meta['type'] + + if field_type == 'categorical': + field_subtype = 'string' + + elif field_type in ('numerical', 'id'): + field_subtype = field_meta['subtype'] + if field_subtype not in ('integer', 'string'): + raise ValueError( + 'Invalid field "subtype" for key field: "{}"'.format(field_subtype) + ) + + else: + raise ValueError( + 'Invalid field "type" for key field: "{}"'.format(field_type) + ) + + return field_subtype + + def _check_field(self, field, exists=False): + """Validate the existance of the table and existance (or not) of field.""" + table_fields = self.get_fields(table) + if exists and (field not in table_fields): + raise ValueError('Field "{}" does not exist in table "{}"'.format(field, table)) + + if not exists and (field in table_fields): + raise ValueError('Field "{}" already exists in table "{}"'.format(field, table)) + + # ###################### # + # Metadata Serialization # + # ###################### # + + def to_dict(self): + """Get a dict representation of this metadata. + + Returns: + dict: + dict representation of this metadata. + """ + return copy.deepcopy(self._metadata) + + def to_json(self, path): + """Dump this metadata into a JSON file. + + Args: + path (str): + Path of the JSON file where this metadata will be stored. 
+        """
+        with open(path, 'w') as out_file:
+            json.dump(self._metadata, out_file, indent=4)
diff --git a/sdv/metadata/visualization.py b/sdv/metadata/visualization.py
new file mode 100644
index 000000000..1eb1dca6d
--- /dev/null
+++ b/sdv/metadata/visualization.py
@@ -0,0 +1,120 @@
+import graphviz
+
+
+def _get_graphviz_extension(path):
+    if path:
+        path_splitted = path.split('.')
+        if len(path_splitted) == 1:
+            raise ValueError('Path without graphviz extension.')
+
+        graphviz_extension = path_splitted[-1]
+
+        if graphviz_extension not in graphviz.backend.FORMATS:
+            raise ValueError(
+                '"{}" not a valid graphviz extension format.'.format(graphviz_extension)
+            )
+
+        return '.'.join(path_splitted[:-1]), graphviz_extension
+
+    return None, None
+
+def _add_nodes(metadata, digraph):
+    """Add nodes into a `graphviz.Digraph`.
+
+    Each node represents a metadata table.
+
+    Args:
+        metadata (Metadata):
+            Metadata object to plot.
+        digraph (graphviz.Digraph):
+            graphviz.Digraph being built
+    """
+    for table in metadata.get_tables():
+        # Append table fields
+        fields = []
+
+        for name, value in metadata.get_fields(table).items():
+            if value.get('subtype') is not None:
+                fields.append('{} : {} - {}'.format(name, value['type'], value['subtype']))
+
+            else:
+                fields.append('{} : {}'.format(name, value['type']))
+
+        fields = r'\l'.join(fields)
+
+        # Append table extra information
+        extras = []
+
+        primary_key = metadata.get_primary_key(table)
+        if primary_key is not None:
+            extras.append('Primary key: {}'.format(primary_key))
+
+        parents = metadata.get_parents(table)
+        for parent in parents:
+            foreign_key = metadata.get_foreign_key(parent, table)
+            extras.append('Foreign key ({}): {}'.format(parent, foreign_key))
+
+        path = metadata.get_table_meta(table).get('path')
+        if path is not None:
+            extras.append('Data path: {}'.format(path))
+
+        extras = r'\l'.join(extras)
+
+        # Add table node
+        title = r'{%s|%s\l|%s\l}' % (table, fields, extras)
+        digraph.node(table, label=title)
+
+
+def _add_edges(metadata, digraph):
+    """Add edges into a `graphviz.Digraph`.
+
+    Each edge represents a relationship between two metadata tables.
+
+    Args:
+        metadata (Metadata):
+            Metadata object to plot.
+        digraph (graphviz.Digraph):
+            graphviz.Digraph being built
+    """
+    for table in metadata.get_tables():
+        for parent in list(metadata.get_parents(table)):
+            digraph.edge(
+                parent,
+                table,
+                label=' {}.{} -> {}.{}'.format(
+                    table, metadata.get_foreign_key(parent, table),
+                    parent, metadata.get_primary_key(parent)
+                ),
+                arrowhead='oinv'
+            )
+
+
+def visualize(metadata, path=None):
+    """Plot metadata using graphviz.
+
+    Try to generate a plot using graphviz.
+    If a ``path`` is provided, save the output into a file.
+
+    Args:
+        metadata (Metadata):
+            Metadata object to plot.
+        path (str):
+            Output file path to save the plot. It requires a graphviz
+            supported extension. If ``None`` do not save the plot.
+            Defaults to ``None``.
+ """ + filename, graphviz_extension = _get_graphviz_extension(path) + digraph = graphviz.Digraph( + 'Metadata', + format=graphviz_extension, + node_attr={ + "shape": "Mrecord", + "fillcolor": "lightgoldenrod1", + "style": "filled" + }, + ) + + _add_nodes(metadata, digraph) + _add_edges(metadata, digraph) + + if filename: + digraph.render(filename=filename, cleanup=True, format=graphviz_extension) + else: + return digraph diff --git a/sdv/tabular/__init__.py b/sdv/tabular/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/sdv/tabular/base.py b/sdv/tabular/base.py new file mode 100644 index 000000000..940282fd4 --- /dev/null +++ b/sdv/tabular/base.py @@ -0,0 +1,180 @@ +from sdv.metadata import Table + + +ANONYMIZATION_MAPS = {} + + +class BaseTabularModel(): + """Base class for all the tabular models. + + The ``BaseTabularModel`` class defines the common API that all the + TabularModels need to implement, as well as common functionality. + """ + + _metadata = None + + def __init__(self, field_names=None, primary_key=None, field_types=None, + anonymize_fields=None, constraints=None, table_metadata=None, + model_kwargs=None, *args, **kwargs): + """Initialize a Tabular Model. + + Args: + field_names (list[str]): + List of names of the fields that need to be modeled + and included in the generated output data. Any additional + fields found in the data will be ignored and will not be + included in the generated output, except if they have + been added as primary keys or fields to anonymize. + If ``None``, all the fields found in the data are used. + primary_key (str, list[str] or dict[str, dict]): + Specification about which field or fields are the + primary key of the table and information about how + to generate them. + field_types (dict[str, dict]): + Dictinary specifying the data types and subtypes + of the fields that will be modeled. + anonymize_fields (dict[str, str]): + Dict specifying which fields to anonymize and what faker + category they belong to. + constraints (list[dict]): + List of dicts specifying field and inter-field constraints. + TODO: Format TBD + table_metadata (dict or metadata.Table): + Table metadata instance or dict representation. + If given alongside any other metadata-related arguments, an + exception will be raised. + If not given at all, it will be built using the other + arguments or learned from the data. + *args, **kwargs: + Subclasses will add any arguments or keyword arguments needed. + """ + if table_metadata is not None: + if isinstance(table_metadata, dict): + table_metadata = Table(table_metadata) + + for arg in (field_names, primary_key, field_types, anonymize, constraints): + if arg: + raise ValueError( + 'If table_metadata is given {} must be None'.format(arg.__name__)) + + self._metadata = table_metadata + + else: + self._field_names = field_names + self._primary_key = primary_key + self._field_types = field_types + self._anonymize_fields = anonymize_fields + self._constraints = constraints + + self._model_kwargs = model_kwargs + + def _fit_metadata(self, data): + metadata = Table( + field_names=self._field_names, + primary_key=self._primary_key, + field_types=self._field_types, + anonymize_fields=self._anonymize_fields, + constraints=self._constraints, + ) + metadata.fit(data) + + self._metadata = metadata + + def fit(self, data): + """Fit this model to the data. + + If table metadata has not been given, learn it from the data. + + Args: + data (pandas.DataFrame or str): + Data to fit the model to. 
It can be passed as a + ``pandas.DataFrame`` or as an ``str``. + If an ``str`` is passed, it is assumed to be + the path to a CSV file which can be loaded using + ``pandas.read_csv``. + """ + if self._metadata is None: + self._fit_metadata(data) + + transformed = self._metadata.transform(data) + self._fit(transformed) + + def get_metadata(self): + """Get metadata about the table. + + This will return an ``sdv.metadata.Table`` object containing + the information about the data that this model has learned. + + This Table metadata will contain some common information, + such as field names and data types, as well as additional + information that each Sub-class might add, such as the + observed data field distributions and their parameters. + + Returns: + sdv.metadata.Table: + Table metadata. + """ + return self._metadata + + def sample(self, size=None, values=None): + """Sample rows from this table. + + Args: + size (int): + Number of rows to sample. If not given the model + will generate as many rows as there were in the + data passed to the ``fit`` method. + values (dict): <- FUTURE + Fixed values to use for knowledge-based sampling. + In case the model does not support knowledge-based + sampling, a discard+resample strategy will be used. + + Returns: + pandas.DataFrame: + Sampled data. + """ + sampled = self._sample(size) + return self._metadata.reverse_transform(sampled) + + def get_parameters(self): + """Get the parameters learned from the data. + + The result is a flat dict (single level) which contains + all the necessary parameters to be able to reproduce + this model. + + Subclasses which are not parametric, such as DeepLearning + based models, raise a NonParametricError indicating that + this method is not supported for their implementation. + + Returns: + parameters (dict): + flat dict (single level) which contains all the + necessary parameters to be able to reproduce + this model. + + Raises: + NonParametricError: + If the model is not parametric or cannot be described + using a simple dictionary. + """ + raise NotImplementedError() + + @classmethod + def from_parameters(cls): + """Regenerate a previously learned model from its parameters. + + Subclasses which are not parametric, such as DeepLearning + based models, raise a NonParametricError indicating that + this method is not supported for their implementation. + + Returns: + BaseTabularModel: + New instance with its parameters set. + + Raises: + NonParametricError: + If the model is not parametric or cannot be described + using a simple dictionary. + """ + raise NotImplementedError() diff --git a/sdv/tabular/copulas.py b/sdv/tabular/copulas.py new file mode 100644 index 000000000..389fdc6ff --- /dev/null +++ b/sdv/tabular/copulas.py @@ -0,0 +1,186 @@ +import numpy as np +import copulas + +from sdv.tabular.base import BaseTabularModel +from sdv.tabular.utils import ( + check_matrix_symmetric_positive_definite, flatten_dict, make_positive_definite, + square_matrix, unflatten_dict) + + +class GaussianCopula(BaseTabularModel): + """Model wrapping ``copulas.multivariate.GaussianMultivariate`` copula. + + Args: + distribution (copulas.univariate.Univariate or str): + Copulas univariate distribution to use. 
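A usage sketch for the tabular API that `GaussianCopula` implements (per the base class above, `fit` learns a `Table` metadata automatically when none is given):

```python
import pandas as pd

from sdv.tabular.copulas import GaussianCopula

data = pd.DataFrame({
    'age': [21, 35, 46, 58],
    'income': [20000.0, 45000.0, 61000.0, 52000.0],
})

model = GaussianCopula()
model.fit(data)

sampled = model.sample(10)   # 10 synthetic rows with the original columns
```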
+ """ + + DISTRIBUTION = copulas.univariate.GaussianUnivariate + _distribution = None + _model = None + + HYPERPARAMETERS = { + 'distribution': { + 'type': 'str or copulas.univariate.Univariate', + 'default': 'copulas.univariate.Univariate', + 'choices': [ + 'copulas.univariate.Univariate', + 'copulas.univariate.GaussianUnivariate', + 'copulas.univariate.GammaUnivariate', + 'copulas.univariate.BetaUnivariate', + 'copulas.univariate.StudentTUnivariate', + 'copulas.univariate.GaussianKDE', + 'copulas.univariate.TruncatedGaussian', + ] + } + } + + def __init__(self, distribution=None, *args, **kwargs): + super().__init__(*args, **kwargs) + self._distribution = distribution or self.DISTRIBUTION + + def _update_metadata(self): + parameters = self._model.to_dict() + univariates = parameters['univariates'] + columns = parameters['columns'] + + fields = self._metadata.get_fields() + for field_name, univariate in zip(columns, univariates): + field_meta = fields[field_name] + field_meta['distribution'] = univariate['type'] + + def _fit(self, data): + """Fit the model to the table. + + Args: + table_data (pandas.DataFrame): + Data to be fitted. + """ + self._model = copulas.multivariate.GaussianMultivariate(distribution=self._distribution) + self._model.fit(data) + # self._update_metadata() + + def _sample(self, size): + """Sample ``size`` rows from the model. + + Args: + size (int): + Amount of rows to sample. + + Returns: + pandas.DataFrame: + Sampled data. + """ + return self._model.sample(size) + + def get_parameters(self): + """Get copula model parameters. + + Compute model ``covariance`` and ``distribution.std`` + before it returns the flatten dict. + + Returns: + dict: + Copula flatten parameters. + """ + values = list() + triangle = np.tril(self._model.covariance) + + for index, row in enumerate(triangle.tolist()): + values.append(row[:index + 1]) + + self._model.covariance = np.array(values) + params = self._model.to_dict() + univariates = dict() + for name, univariate in zip(params.pop('columns'), params['univariates']): + univariates[name] = univariate + if 'scale' in univariate: + scale = univariate['scale'] + if scale == 0: + scale = copulas.EPSILON + + univariate['scale'] = np.log(scale) + + params['univariates'] = univariates + + return flatten_dict(params) + + def _prepare_sampled_covariance(self, covariance): + """Prepare a covariance matrix. + + Args: + covariance (list): + covariance after unflattening model parameters. + + Result: + list[list]: + symmetric Positive semi-definite matrix. + """ + covariance = np.array(square_matrix(covariance)) + covariance = (covariance + covariance.T - (np.identity(covariance.shape[0]) * covariance)) + + if not check_matrix_symmetric_positive_definite(covariance): + covariance = make_positive_definite(covariance) + + return covariance.tolist() + + def _unflatten_gaussian_copula(self, model_parameters): + """Prepare unflattened model params to recreate Gaussian Multivariate instance. + + The preparations consist basically in: + + - Transform sampled negative standard deviations from distributions into positive + numbers + + - Ensure the covariance matrix is a valid symmetric positive-semidefinite matrix. + + - Add string parameters kept inside the class (as they can't be modelled), + like ``distribution_type``. + + Args: + model_parameters (dict): + Sampled and reestructured model parameters. + + Returns: + dict: + Model parameters ready to recreate the model. 
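The covariance handling above can be checked in isolation: `get_parameters` stores only the lower triangle, and `_prepare_sampled_covariance` rebuilds a symmetric matrix and validates it (a Cholesky factorization succeeds only on a symmetric positive-definite matrix). A small numpy sketch of both steps:

```python
import numpy as np

covariance = np.array([
    [1.0, 0.3],
    [0.3, 1.0],
])

# Keep only the lower triangle, as get_parameters does.
triangle = [row[:index + 1] for index, row in enumerate(covariance.tolist())]

# Rebuild the square, symmetric matrix, as _prepare_sampled_covariance does.
square = np.zeros((2, 2))
for index, row in enumerate(triangle):
    square[index, :len(row)] = row

square = square + square.T - np.identity(2) * square

# Raises numpy.linalg.LinAlgError if the matrix is not positive-definite.
np.linalg.cholesky(square)
```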
+ """ + + univariate_kwargs = { + 'type': model_parameters['distribution'] + } + + columns = list() + univariates = list() + for column, univariate in model_parameters['univariates'].items(): + columns.append(column) + univariate.update(univariate_kwargs) + if 'scale' in univariate: + univariate['scale'] = np.exp(univariate['scale']) + + univariates.append(univariate) + + model_parameters['univariates'] = univariates + model_parameters['columns'] = columns + + covariance = model_parameters.get('covariance') + model_parameters['covariance'] = self._prepare_sampled_covariance(covariance) + + return model_parameters + + def set_parameters(self, parameters): + """Set copula model parameters. + + Add additional keys after unflatte the parameters + in order to set expected parameters for the copula. + + Args: + dict: + Copula flatten parameters. + """ + parameters = unflatten_dict(parameters) + parameters.setdefault('distribution', self.distribution) + + parameters = self._unflatten_gaussian_copula(parameters) + + self._model = copulas.multivariate.GaussianMultivariate.from_dict(parameters) diff --git a/sdv/tabular/utils.py b/sdv/tabular/utils.py new file mode 100644 index 000000000..241a20048 --- /dev/null +++ b/sdv/tabular/utils.py @@ -0,0 +1,219 @@ +import numpy as np + +IGNORED_DICT_KEYS = ['fitted', 'distribution', 'type'] + + +def flatten_array(nested, prefix=''): + """Flatten an array as a dict. + + Args: + nested (list, numpy.array): + Iterable to flatten. + prefix (str): + Name to append to the array indices. Defaults to ``''``. + + Returns: + dict: + Flattened array. + """ + result = dict() + for index in range(len(nested)): + prefix_key = '__'.join([prefix, str(index)]) if len(prefix) else str(index) + + value = nested[index] + if isinstance(value, (list, np.ndarray)): + result.update(flatten_array(value, prefix=prefix_key)) + + elif isinstance(value, dict): + result.update(flatten_dict(value, prefix=prefix_key)) + + else: + result[prefix_key] = value + + return result + + +def flatten_dict(nested, prefix=''): + """Flatten a dictionary. + + This method returns a flatten version of a dictionary, concatenating key names with + double underscores. + + Args: + nested (dict): + Original dictionary to flatten. + prefix (str): + Prefix to append to key name. Defaults to ``''``. + + Returns: + dict: + Flattened dictionary. + """ + result = dict() + + for key, value in nested.items(): + prefix_key = '__'.join([prefix, str(key)]) if len(prefix) else key + + if key in IGNORED_DICT_KEYS and not isinstance(value, (dict, list)): + continue + + elif isinstance(value, dict): + result.update(flatten_dict(value, prefix_key)) + + elif isinstance(value, (np.ndarray, list)): + result.update(flatten_array(value, prefix_key)) + + else: + result[prefix_key] = value + + return result + + +def _key_order(key_value): + parts = list() + for part in key_value[0].split('__'): + if part.isdigit(): + part = int(part) + + parts.append(part) + + return parts + + +def unflatten_dict(flat): + """Transform a flattened dict into its original form. + + Args: + flat (dict): + Flattened dict. 
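The double-underscore convention shared by `flatten_dict` and `unflatten_dict` is easiest to see on a tiny round trip (a sketch; the key names are arbitrary):

```python
from sdv.tabular.utils import flatten_dict, unflatten_dict

nested = {'univariates': {'age': {'loc': 0.5, 'scale': 1.2}}}

flat = flatten_dict(nested)
# {'univariates__age__loc': 0.5, 'univariates__age__scale': 1.2}

assert unflatten_dict(flat) == nested
```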
+ + Returns: + dict: + Nested dict (if corresponds) + """ + unflattened = dict() + + for key, value in sorted(flat.items(), key=_key_order): + if '__' in key: + key, subkey = key.split('__', 1) + subkey, name = subkey.rsplit('__', 1) + + if name.isdigit(): + column_index = int(name) + row_index = int(subkey) + + array = unflattened.setdefault(key, list()) + + if len(array) == row_index: + row = list() + array.append(row) + elif len(array) == row_index + 1: + row = array[row_index] + else: + # This should never happen + raise ValueError('There was an error unflattening the extension.') + + if len(row) == column_index: + row.append(value) + else: + # This should never happen + raise ValueError('There was an error unflattening the extension.') + + else: + subdict = unflattened.setdefault(key, dict()) + if subkey.isdigit(): + subkey = int(subkey) + + inner = subdict.setdefault(subkey, dict()) + inner[name] = value + + else: + unflattened[key] = value + + return unflattened + + +def impute(data): + for column in data: + column_data = data[column] + if column_data.dtype in (np.int, np.float): + fill_value = column_data.mean() + else: + fill_value = column_data.mode()[0] + + data[column] = data[column].fillna(fill_value) + + return data + + +def square_matrix(triangular_matrix): + """Fill with zeros a triangular matrix to reshape it to a square one. + + Args: + triangular_matrix (list [list [float]]): + Array of arrays of + + Returns: + list: + Square matrix. + """ + length = len(triangular_matrix) + zero = [0.0] + + for item in triangular_matrix: + item.extend(zero * (length - len(item))) + + return triangular_matrix + + +def check_matrix_symmetric_positive_definite(matrix): + """Check if a matrix is symmetric positive-definite. + + Args: + matrix (list or numpy.ndarray): + Matrix to evaluate. + + Returns: + bool + """ + try: + if len(matrix.shape) != 2 or matrix.shape[0] != matrix.shape[1]: + # Not 2-dimensional or square, so not simmetric. + return False + + np.linalg.cholesky(matrix) + return True + + except np.linalg.LinAlgError: + return False + + +def make_positive_definite(matrix): + """Find the nearest positive-definite matrix to input. + + Args: + matrix (numpy.ndarray): + Matrix to transform + + Returns: + numpy.ndarray: + Closest symetric positive-definite matrix. + """ + symetric_matrix = (matrix + matrix.T) / 2 + _, s, V = np.linalg.svd(symetric_matrix) + symmetric_polar = np.dot(V.T, np.dot(np.diag(s), V)) + A2 = (symetric_matrix + symmetric_polar) / 2 + A3 = (A2 + A2.T) / 2 + + if check_matrix_symmetric_positive_definite(A3): + return A3 + + spacing = np.spacing(np.linalg.norm(matrix)) + identity = np.eye(matrix.shape[0]) + iterations = 1 + while not check_matrix_symmetric_positive_definite(A3): + min_eigenvals = np.min(np.real(np.linalg.eigvals(A3))) + A3 += identity * (-min_eigenvals * iterations**2 + spacing) + iterations += 1 + + return A3 From 25284d4e092095c3d512d47eae16391e84115536 Mon Sep 17 00:00:00 2001 From: Carles Sala Date: Thu, 9 Jul 2020 19:46:35 +0200 Subject: [PATCH 20/33] Tabular models. WIP. 
Finish Table and Copulas and add CTGAN --- sdv/metadata/__init__.py | 6 - sdv/metadata/table.py | 345 ++++++++-------------- sdv/metadata/visualization.py | 1 + sdv/tabular/base.py | 32 +- sdv/tabular/copulas.py | 38 ++- sdv/tabular/ctgan.py | 142 +++++++++ setup.py | 17 +- tests/integration/tabular/test_copulas.py | 48 +++ tests/integration/tabular/test_ctgan.py | 29 ++ 9 files changed, 408 insertions(+), 250 deletions(-) create mode 100644 sdv/tabular/ctgan.py create mode 100644 tests/integration/tabular/test_copulas.py create mode 100644 tests/integration/tabular/test_ctgan.py diff --git a/sdv/metadata/__init__.py b/sdv/metadata/__init__.py index fc91ad549..d9b3ad368 100644 --- a/sdv/metadata/__init__.py +++ b/sdv/metadata/__init__.py @@ -4,7 +4,6 @@ import os from collections import defaultdict -import graphviz import numpy as np import pandas as pd from rdt import HyperTransformer, transformers @@ -62,10 +61,6 @@ def _load_csv(root_path, table_meta): return data -class MetadataError(Exception): - pass - - class Metadata: """Dataset Metadata. @@ -1013,4 +1008,3 @@ def visualize(self, path=None): Defaults to ``None``. """ return visualization.visualize(self, path) - diff --git a/sdv/metadata/table.py b/sdv/metadata/table.py index 85b129d78..893549272 100644 --- a/sdv/metadata/table.py +++ b/sdv/metadata/table.py @@ -1,14 +1,13 @@ import copy import json import logging -import os import numpy as np import pandas as pd +import rdt from faker import Faker -from rdt import HyperTransformer, transformers -from sdv.metadata.error import MetadataError +from sdv.metadata.errors import MetadataError LOGGER = logging.getLogger(__name__) @@ -19,18 +18,11 @@ class Table: The Metadata class provides a unified layer of abstraction over the metadata of a single Table, which includes both the necessary details to load the data from the filesystem and to know how to parse and transform it to numerical data. - - Args: - metadata (str or dict): - Path to a ``json`` file that contains the metadata or a ``dict`` representation - of ``metadata`` following the same structure. - root_path (str): - The path to which the dataset is located. Defaults to ``None``. """ _metadata = None _hyper_transformer = None - _root_path = None + _anonymization_mappings = None _FIELD_TEMPLATES = { 'i': { @@ -63,11 +55,47 @@ class Table: ('id', 'string'): 'str' } - def __init__(self, metadata=None, root_path=None, field_names=None, primary_key=None, - field_types=None, anonymize_fields=None, constraints=None,): + def _get_faker(self, category): + """Return the faker object to anonymize data. + + Args: + category (str or tuple): + Fake category to use. If a tuple is passed, the first element is + the category and the rest are additional arguments for the Faker. + + Returns: + function: + Faker function to generate new fake data instances. + + Raises: + ValueError: + A ``ValueError`` is raised if the faker category we want don't exist. 
+ """ + if isinstance(category, (tuple, list)): + category, *args = category + else: + args = tuple() + + try: + faker_method = getattr(Faker(), category) + + if not args: + return faker_method + + def faker(): + return faker_method(*args) + + return faker + + except AttributeError: + raise ValueError('Category "{}" couldn\'t be found on faker'.format(category)) + + def __init__(self, metadata=None, field_names=None, primary_key=None, + field_types=None, anonymize_fields=None, constraints=None, + transformer_templates=None): + # TODO: Validate that the given metadata is a valid dict self._metadata = metadata or dict() self._hyper_transformer = None - self._root_path = None self._field_names = field_names self._primary_key = primary_key @@ -75,6 +103,13 @@ def __init__(self, metadata=None, root_path=None, field_names=None, primary_key= self._anonymize_fields = anonymize_fields self._constraints = constraints + self._transformer_templates = transformer_templates or {} + + self._fakers = { + name: self._get_faker(category) + for name, category in (anonymize_fields or {}).items() + } + def _get_field_dtype(self, field_name, field_metadata): field_type = field_metadata['type'] field_subtype = field_metadata.get('subtype') @@ -87,6 +122,15 @@ def _get_field_dtype(self, field_name, field_metadata): return dtype + def get_fields(self): + """Get fields metadata. + + Returns: + dict: + Dictionary of fields metadata for this table. + """ + return copy.deepcopy(self._metadata['fields']) + def get_dtypes(self, ids=False): """Get a ``dict`` with the ``dtypes`` for each field of the table. @@ -111,45 +155,6 @@ def get_dtypes(self, ids=False): return dtypes - def _get_faker(self): - """Return the faker object to anonymize data. - - Returns: - function: - Faker function to generate new data instances with ``self.anonymize`` arguments. - - Raises: - ValueError: - A ``ValueError`` is raised if the faker category we want don't exist. - """ - if isinstance(self._anonymize, (tuple, list)): - category, *args = self._anonymize - else: - category = self._anonymize - args = tuple() - - try: - faker_method = getattr(Faker(), category) - - def faker(): - return faker_method(*args) - - return faker - except AttributeError: - raise ValueError('Category "{}" couldn\'t be found on faker'.format(self.anonymize)) - - def _anonymize(self, data): - """Anonymize data and save the anonymization mapping in-memory.""" - # TODO: Do this by column - faker = self._get_faker() - uniques = data.unique() - fake_data = [faker() for x in range(len(uniques))] - - mapping = dict(zip(uniques, fake_data)) - MAPS[id(self)] = mapping - - return data.map(mapping) - def _build_fields_metadata(self, data): """Build all the fields metadata. @@ -169,7 +174,7 @@ def _build_fields_metadata(self, data): fields_metadata = dict() for field_name in field_names: - if not field_name in data: + if field_name not in data: raise ValueError('Field {} not found in given data'.format(field_name)) field_meta = self._field_types.get(field_name) @@ -188,59 +193,39 @@ def _build_fields_metadata(self, data): return fields_metadata - def _get_pii_fields(self): - """Get the ``pii_category`` for each field that contains PII. - - Returns: - dict: - pii field names and categories. 
- """ - pii_fields = dict() - for name, field in self._metadata['fields'].items(): - if field['type'] == 'categorical' and field.get('pii', False): - pii_fields[name] = field['pii_category'] - - return pii_fields - def _get_transformers(self, dtypes): """Create the transformer instances needed to process the given dtypes. - Temporary drop-in replacement of ``HyperTransformer._analyze`` method, - before RDT catches up. - Args: dtypes (dict): mapping of field names and dtypes. - pii_fields (dict): - mapping of pii field names and categories. Returns: dict: mapping of field names and transformer instances. """ - transformers_dict = dict() - pii_fields = self._get_pii_fields() + transformer_templates = { + 'i': rdt.transformers.NumericalTransformer(dtype=int), + 'f': rdt.transformers.NumericalTransformer(dtype=float), + 'O': rdt.transformers.CategoricalTransformer, + 'b': rdt.transformers.BooleanTransformer, + 'M': rdt.transformers.DatetimeTransformer, + } + transformer_templates.update(self._transformer_templates) + + transformers = dict() for name, dtype in dtypes.items(): - dtype = np.dtype(dtype) - if dtype.kind == 'i': - transformer = transformers.NumericalTransformer(dtype=int) - elif dtype.kind == 'f': - transformer = transformers.NumericalTransformer(dtype=float) - elif dtype.kind == 'O': - anonymize = pii_fields.get(name) - transformer = transformers.CategoricalTransformer(anonymize=anonymize) - elif dtype.kind == 'b': - transformer = transformers.BooleanTransformer() - elif dtype.kind == 'M': - transformer = transformers.DatetimeTransformer() + transformer_template = transformer_templates[np.dtype(dtype).kind] + if isinstance(transformer_template, type): + transformer = transformer_template() else: - raise ValueError('Unsupported dtype: {}'.format(dtype)) + transformer = copy.deepcopy(transformer_template) LOGGER.info('Loading transformer %s for field %s', transformer.__class__.__name__, name) - transformers_dict[name] = transformer + transformers[name] = transformer - return transformers_dict + return transformers def _fit_hyper_transformer(self, data): """Create and return a new ``rdt.HyperTransformer`` instance. @@ -253,9 +238,31 @@ def _fit_hyper_transformer(self, data): """ dtypes = self.get_dtypes(ids=False) transformers_dict = self._get_transformers(dtypes) - self._hyper_transformer = HyperTransformer(transformers=transformers_dict) + self._hyper_transformer = rdt.HyperTransformer(transformers=transformers_dict) self._hyper_transformer.fit(data[list(dtypes.keys())]) + @staticmethod + def _get_key_subtype(field_meta): + """Get the appropriate key subtype.""" + field_type = field_meta['type'] + + if field_type == 'categorical': + field_subtype = 'string' + + elif field_type in ('numerical', 'id'): + field_subtype = field_meta['subtype'] + if field_subtype not in ('integer', 'string'): + raise ValueError( + 'Invalid field "subtype" for key field: "{}"'.format(field_subtype) + ) + + else: + raise ValueError( + 'Invalid field "type" for key field: "{}"'.format(field_type) + ) + + return field_subtype + def set_primary_key(self, field_name): """Set the primary key of this table. @@ -271,7 +278,7 @@ def set_primary_key(self, field_name): invalid type or subtype. 
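+
+        Example:
+            An illustrative call; assumes that ``user_id`` is an existing
+            field of this table with a valid key type:
+
+                table.set_primary_key('user_id')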
""" if field_name is not None: - if field_name not in self.get_fields(): + if field_name not in self._metadata['fields']: raise ValueError('Field "{}" does not exist in this table'.format(field_name)) field_metadata = self._metadata['fields'][field_name] @@ -284,6 +291,23 @@ def set_primary_key(self, field_name): self._primary_key = field_name + def _make_anonymization_mappings(self, data): + mappings = {} + for name, faker in self._fakers.items(): + uniques = data[name].unique() + fake_values = [faker() for x in range(len(uniques))] + mappings[name] = dict(zip(uniques, fake_values)) + + self._anonymization_mappings = mappings + + def _anonymize(self, data): + if self._anonymization_mappings: + data = data.copy() + for name, mapping in self._anonymization_mappings.items(): + data[name] = data[name].map(mapping) + + return data + def fit(self, data): """Fit this metadata to the given data. @@ -295,19 +319,13 @@ def fit(self, data): self._metadata['fields'] = self._build_fields_metadata(data) self.set_primary_key(self._primary_key) - # TODO: Treat/Learn constraints + if self._fakers: + self._make_anonymization_mappings(data) + data = self._anonymize(data) + # TODO: Treat/Learn constraints self._fit_hyper_transformer(data) - def get_fields(self): - """Get fields metadata. - - Returns: - dict: - Mapping of field names and their metadata dicts. - """ - return self._metadata['fields'] - def transform(self, data): """Transform the given data. @@ -319,13 +337,8 @@ def transform(self, data): pandas.DataFrame: Transformed data. """ - - # TODO: Do this by column - # if self.anonymize: - # data = data.map(MAPS[id(self)]) - - fields = list(self._hyper_transformer.transformers.keys()) - return self._hyper_transformer.transform(data[fields]) + data = self._anonymize(data[self._field_names]) + return self._hyper_transformer.transform(data) def reverse_transform(self, data): """Reverse the transformed data to the original format. @@ -351,120 +364,6 @@ def reverse_transform(self, data): return reversed_data[self._field_names] - def get_children(self): - """Get tables for which this table is parent. - - Returns: - set: - Set of children for this table. - """ - return self._children - - def get_parents(self): - """Get tables for with this table is child. - - Returns: - set: - Set of parents for this table. - """ - return self._parents - - def get_field(self, field_name): - """Get the metadata dict for a field. - - Args: - field_name (str): - Name of the field to get data for. - - Returns: - dict: - field metadata - - Raises: - ValueError: - If the table or the field do not exist in this metadata. 
- """ - field_meta = self._metadata['fields'].get(field_name) - if field_meta is None: - raise ValueError('Invalid field name "{}"'.format(field_name)) - - return copy.deepcopy(field_meta) - - # def _read_csv_dtypes(self): - # """Get the dtypes specification that needs to be passed to read_csv.""" - # dtypes = dict() - # for name, field in self._metadata['fields'].items(): - # field_type = field['type'] - # if field_type == 'categorical': - # dtypes[name] = str - # elif field_type == 'id' and field.get('subtype', 'integer') == 'string': - # dtypes[name] = str - - # return dtypes - - # def _parse_dtypes(self, data): - # """Convert the data columns to the right dtype after loading the CSV.""" - # for name, field in self._metadata['fields'].items(): - # field_type = field['type'] - # if field_type == 'datetime': - # datetime_format = field.get('format') - # data[name] = pd.to_datetime(data[name], format=datetime_format, exact=False) - # elif field_type == 'numerical' and field.get('subtype') == 'integer': - # data[name] = data[name].dropna().astype(int) - # elif field_type == 'id' and field.get('subtype', 'integer') == 'integer': - # data[name] = data[name].dropna().astype(int) - - # return data - - # def load(self): - # """Load table data. - - # First load the CSV with the right dtypes and then parse the columns - # to the final dtypes. - - # Returns: - # pandas.DataFrame: - # DataFrame with the contents of the table. - # """ - # relative_path = os.path.join(self.root_path, self.path) - # dtypes = self._read_csv_dtypes() - - # data = pd.read_csv(relative_path, dtype=dtypes) - # data = self._parse_dtypes(data) - - # return data - - @staticmethod - def _get_key_subtype(field_meta): - """Get the appropriate key subtype.""" - field_type = field_meta['type'] - - if field_type == 'categorical': - field_subtype = 'string' - - elif field_type in ('numerical', 'id'): - field_subtype = field_meta['subtype'] - if field_subtype not in ('integer', 'string'): - raise ValueError( - 'Invalid field "subtype" for key field: "{}"'.format(field_subtype) - ) - - else: - raise ValueError( - 'Invalid field "type" for key field: "{}"'.format(field_type) - ) - - return field_subtype - - def _check_field(self, field, exists=False): - """Validate the existance of the table and existance (or not) of field.""" - table_fields = self.get_fields(table) - if exists and (field not in table_fields): - raise ValueError('Field "{}" does not exist in table "{}"'.format(field, table)) - - if not exists and (field in table_fields): - raise ValueError('Field "{}" already exists in table "{}"'.format(field, table)) - # ###################### # # Metadata Serialization # # ###################### # diff --git a/sdv/metadata/visualization.py b/sdv/metadata/visualization.py index 1eb1dca6d..b14fb160a 100644 --- a/sdv/metadata/visualization.py +++ b/sdv/metadata/visualization.py @@ -18,6 +18,7 @@ def _get_graphviz_extension(path): return None, None + def _add_nodes(metadata, digraph): """Add nodes into a `graphviz.Digraph`. diff --git a/sdv/tabular/base.py b/sdv/tabular/base.py index 940282fd4..1bdf4b854 100644 --- a/sdv/tabular/base.py +++ b/sdv/tabular/base.py @@ -1,9 +1,6 @@ from sdv.metadata import Table -ANONYMIZATION_MAPS = {} - - class BaseTabularModel(): """Base class for all the tabular models. @@ -11,11 +8,13 @@ class BaseTabularModel(): TabularModels need to implement, as well as common functionality. 
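+
+    Example:
+        A minimal sketch of the common lifecycle that the subclasses expose;
+        ``GaussianCopula`` is one such subclass and ``data`` is assumed to be
+        a ``pandas.DataFrame``:
+
+            model = GaussianCopula()
+            model.fit(data)
+            sampled = model.sample(10)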
""" + TRANSFORMER_TEMPLATES = None + _metadata = None def __init__(self, field_names=None, primary_key=None, field_types=None, anonymize_fields=None, constraints=None, table_metadata=None, - model_kwargs=None, *args, **kwargs): + *args, **kwargs): """Initialize a Tabular Model. Args: @@ -32,7 +31,8 @@ def __init__(self, field_names=None, primary_key=None, field_types=None, to generate them. field_types (dict[str, dict]): Dictinary specifying the data types and subtypes - of the fields that will be modeled. + of the fields that will be modeled. Field types and subtypes + combinations must be compatible with the SDV Metadata Schema. anonymize_fields (dict[str, str]): Dict specifying which fields to anonymize and what faker category they belong to. @@ -50,9 +50,9 @@ def __init__(self, field_names=None, primary_key=None, field_types=None, """ if table_metadata is not None: if isinstance(table_metadata, dict): - table_metadata = Table(table_metadata) + table_metadata = Table(table_metadata,) - for arg in (field_names, primary_key, field_types, anonymize, constraints): + for arg in (field_names, primary_key, field_types, anonymize_fields, constraints): if arg: raise ValueError( 'If table_metadata is given {} must be None'.format(arg.__name__)) @@ -66,15 +66,24 @@ def __init__(self, field_names=None, primary_key=None, field_types=None, self._anonymize_fields = anonymize_fields self._constraints = constraints - self._model_kwargs = model_kwargs - def _fit_metadata(self, data): + """Generate a new Table metadata and fit it to the data. + + The information provided will be used to create the Table instance + and then the rest of information will be learned from the given + data. + + Args: + data (pandas.DataFrame): + Data to learn from. + """ metadata = Table( field_names=self._field_names, primary_key=self._primary_key, field_types=self._field_types, anonymize_fields=self._anonymize_fields, constraints=self._constraints, + transformer_templates=self.TRANSFORMER_TEMPLATES, ) metadata.fit(data) @@ -83,7 +92,7 @@ def _fit_metadata(self, data): def fit(self, data): """Fit this model to the data. - If table metadata has not been given, learn it from the data. + If the table metadata has not been given, learn it from the data. Args: data (pandas.DataFrame or str): @@ -96,6 +105,8 @@ def fit(self, data): if self._metadata is None: self._fit_metadata(data) + self._size = len(data) + transformed = self._metadata.transform(data) self._fit(transformed) @@ -133,6 +144,7 @@ def sample(self, size=None, values=None): pandas.DataFrame: Sampled data. """ + size = size or self._size sampled = self._sample(size) return self._metadata.reverse_transform(sampled) diff --git a/sdv/tabular/copulas.py b/sdv/tabular/copulas.py index 389fdc6ff..024205cad 100644 --- a/sdv/tabular/copulas.py +++ b/sdv/tabular/copulas.py @@ -1,10 +1,11 @@ -import numpy as np import copulas +import numpy as np +import rdt from sdv.tabular.base import BaseTabularModel from sdv.tabular.utils import ( - check_matrix_symmetric_positive_definite, flatten_dict, make_positive_definite, - square_matrix, unflatten_dict) + check_matrix_symmetric_positive_definite, flatten_dict, make_positive_definite, square_matrix, + unflatten_dict) class GaussianCopula(BaseTabularModel): @@ -13,6 +14,10 @@ class GaussianCopula(BaseTabularModel): Args: distribution (copulas.univariate.Univariate or str): Copulas univariate distribution to use. 
+ categorical_transformer (str): + Type of transformer to use for the categorical variables, to choose + from ``one_hot_encoding``, ``label_encoding``, ``categorical`` and + ``categorical_fuzzy``. """ DISTRIBUTION = copulas.univariate.GaussianUnivariate @@ -23,6 +28,7 @@ class GaussianCopula(BaseTabularModel): 'distribution': { 'type': 'str or copulas.univariate.Univariate', 'default': 'copulas.univariate.Univariate', + 'description': 'Univariate distribution to use to model each column', 'choices': [ 'copulas.univariate.Univariate', 'copulas.univariate.GaussianUnivariate', @@ -32,12 +38,34 @@ class GaussianCopula(BaseTabularModel): 'copulas.univariate.GaussianKDE', 'copulas.univariate.TruncatedGaussian', ] + }, + 'categorical_transformer': { + 'type': 'str', + 'default': 'categoircal_fuzzy', + 'description': 'Type of transformer to use for the categorical variables', + 'choices': [ + 'categorical', + 'categorical_fuzzy', + 'one_hot_encoding', + 'label_encoding' + ] } } + CATEGORICAL_TRANSFORMERS = { + 'categorical': rdt.transformers.CategoricalTransformer(fuzzy=False), + 'categorical_fuzzy': rdt.transformers.CategoricalTransformer(fuzzy=True), + 'one_hot_encoding': rdt.transformers.OneHotEncodingTransformer, + 'label_encoding': rdt.transformers.LabelEncodingTransformer, + } + TRANSFORMER_TEMPLATES = { + 'O': rdt.transformers.OneHotEncodingTransformer + } - def __init__(self, distribution=None, *args, **kwargs): + def __init__(self, distribution=None, categorical_transformer='categorical', + *args, **kwargs): super().__init__(*args, **kwargs) self._distribution = distribution or self.DISTRIBUTION + self.TRANSFORMER_TEMPLATES['O'] = self.CATEGORICAL_TRANSFORMERS[categorical_transformer] def _update_metadata(self): parameters = self._model.to_dict() @@ -56,9 +84,9 @@ def _fit(self, data): table_data (pandas.DataFrame): Data to be fitted. """ + params = self._metadata.get_model_params() self._model = copulas.multivariate.GaussianMultivariate(distribution=self._distribution) self._model.fit(data) - # self._update_metadata() def _sample(self, size): """Sample ``size`` rows from the model. diff --git a/sdv/tabular/ctgan.py b/sdv/tabular/ctgan.py new file mode 100644 index 000000000..eb1f87d33 --- /dev/null +++ b/sdv/tabular/ctgan.py @@ -0,0 +1,142 @@ +import rdt + +from sdv.tabular.base import BaseTabularModel + + +class CTGAN(BaseTabularModel): + """Model wrapping ``CTGANSynthesizer`` copula. + + Args: + TBD + """ + + _CTGAN_CLASS = None + _model = None + + HYPERPARAMETERS = { + 'TBD' + } + TRANSFORMER_TEMPLATES = { + 'O': rdt.transformers.LabelEncodingTransformer + } + + def __init__(self, field_names=None, primary_key=None, field_types=None, + anonymize_fields=None, constraints=None, table_metadata=None, + epochs=300, log_frequency=True, embedding_dim=128, gen_dim=(256, 256), + dis_dim=(256, 256), l2scale=1e-6, batch_size=500): + """Initialize this CTGAN model. + + Args: + field_names (list[str]): + List of names of the fields that need to be modeled + and included in the generated output data. Any additional + fields found in the data will be ignored and will not be + included in the generated output, except if they have + been added as primary keys or fields to anonymize. + If ``None``, all the fields found in the data are used. + primary_key (str, list[str] or dict[str, dict]): + Specification about which field or fields are the + primary key of the table and information about how + to generate them. 
+ field_types (dict[str, dict]): + Dictinary specifying the data types and subtypes + of the fields that will be modeled. Field types and subtypes + combinations must be compatible with the SDV Metadata Schema. + anonymize_fields (dict[str, str or tuple]): + Dict specifying which fields to anonymize and what faker + category they belong to. If arguments for the faker need to be + passed to fine tune the value generation a tuple can be passed, + where the first element is the category and the rest are additional + positional arguments for the Faker. + constraints (list[dict]): + List of dicts specifying field and inter-field constraints. + TODO: Format TBD + table_metadata (dict or metadata.Table): + Table metadata instance or dict representation. + If given alongside any other metadata-related arguments, an + exception will be raised. + If not given at all, it will be built using the other + arguments or learned from the data. + epochs (int): + Number of training epochs. Defaults to 300. + log_frequency (boolean): + Whether to use log frequency of categorical levels in conditional + sampling. Defaults to ``True``. + embedding_dim (int): + Size of the random sample passed to the Generator. Defaults to 128. + gen_dim (tuple or list of ints): + Size of the output samples for each one of the Residuals. A Resiudal Layer + will be created for each one of the values provided. Defaults to (256, 256). + dis_dim (tuple or list of ints): + Size of the output samples for each one of the Discriminator Layers. A Linear + Layer will be created for each one of the values provided. Defaults to (256, 256). + l2scale (float): + Wheight Decay for the Adam Optimizer. Defaults to 1e-6. + batch_size (int): + Number of data samples to process in each step. + """ + super().__init__( + field_names=field_names, + primary_key=primary_key, + field_types=field_types, + anonymize_fields=anonymize_fields, + constraints=constraints, + table_metadata=table_metadata + ) + try: + from ctgan import CTGANSynthesizer # Lazy import to make dependency optional + + self._CTGAN_CLASS = CTGANSynthesizer + except ImportError as ie: + ie.msg += ( + '\n\nIt seems like `ctgan` is not installed.\n' + 'Please install it using:\n\n pip install ctgan' + ) + raise + + self._embedding_dim = embedding_dim + self._gen_dim = gen_dim + self._dis_dim = dis_dim + self._l2scale = l2scale + self._batch_size = batch_size + self._epochs = epochs + self._log_frequency = log_frequency + + def _fit(self, data): + """Fit the model to the table. + + Args: + data (pandas.DataFrame): + Data to be learned. + """ + self._model = self._CTGAN_CLASS( + embedding_dim=self._embedding_dim, + gen_dim=self._gen_dim, + dis_dim=self._dis_dim, + l2scale=self._l2scale, + batch_size=self._batch_size, + ) + categoricals = [ + field + for field, meta in self._metadata.get_fields().items() + if meta['type'] == 'categorical' + ] + self._model.fit( + data, + epochs=self._epochs, + discrete_columns=categoricals, + log_frequency=self._log_frequency, + ) + + def _sample(self, size): + """Sample ``size`` rows from the model. + + Args: + size (int): + Amount of rows to sample. + + Returns: + pandas.DataFrame: + Sampled data. 
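+
+        Example:
+            Illustrative sketch of the flow that reaches this method; ``data``
+            is assumed to be a ``pandas.DataFrame``:
+
+                model = CTGAN(epochs=1)
+                model.fit(data)
+                sampled = model.sample(10)   # calls _sample(10) internally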
+ """ + return self._model.sample(size) diff --git a/setup.py b/setup.py index c8a4a56b5..ef713040b 100644 --- a/setup.py +++ b/setup.py @@ -13,12 +13,17 @@ install_requires = [ 'exrex>=0.9.4,<0.11', - 'numpy>=1.15.4,<1.17', + 'numpy>=1.15.4,<2', 'pandas>=0.23.4,<0.25', - 'copulas>=0.3,<0.4', - 'rdt>=0.2.1,<0.3', + 'copulas>=0.3.1.dev0,<0.4', + 'rdt>=0.2.3.dev0,<0.3', 'graphviz>=0.13.2', - 'sdmetrics>=0.0.1,<0.0.2' + 'sdmetrics>=0.0.2.dev0,<0.0.3', + 'scikit-learn<0.23,>=0.21', +] + +ctgan_requires = [ + 'ctgan>=0.2.2.dev0,<0.3', ] setup_requires = [ @@ -76,8 +81,8 @@ ], description='Automated Generative Modeling and Sampling', extras_require={ - 'test': tests_require, - 'dev': development_requires + tests_require, + 'test': tests_require + ctgan_requires, + 'dev': development_requires + tests_require + ctgan_requires, }, install_package_data=True, install_requires=install_requires, diff --git a/tests/integration/tabular/test_copulas.py b/tests/integration/tabular/test_copulas.py new file mode 100644 index 000000000..037e32f0c --- /dev/null +++ b/tests/integration/tabular/test_copulas.py @@ -0,0 +1,48 @@ +from sdv.demo import load_demo +from sdv.tabular.copulas import GaussianCopula + + +def test_gaussian_copula(): + users = load_demo(metadata=False)['users'] + + field_types = { + 'age': { + 'type': 'numerical', + 'subtype': 'integer', + }, + 'country': { + 'type': 'categorical' + } + } + anonymize_fields = { + 'country': 'country_code' + } + + gc = GaussianCopula( + field_names=['user_id', 'country', 'gender', 'age'], + field_types=field_types, + primary_key='user_id', + anonymize_fields=anonymize_fields, + categorical_transformer='one_hot_encoding', + ) + gc.fit(users) + + sampled = gc.sample() + + # test shape is right + assert sampled.shape == users.shape + + # test user_id has been generated as an ID field + assert list(sampled['user_id']) == list(range(0, len(users))) + + # country codes have been replaced with new ones + assert set(sampled.country.unique()) & set(users.country.unique()) == set() + + assert gc.get_metadata().to_dict() == { + 'fields': { + 'user_id': {'type': 'id', 'subtype': 'integer'}, + 'country': {'type': 'categorical'}, + 'gender': {'type': 'categorical'}, + 'age': {'type': 'numerical', 'subtype': 'integer'} + } + } diff --git a/tests/integration/tabular/test_ctgan.py b/tests/integration/tabular/test_ctgan.py new file mode 100644 index 000000000..8cc8d4ad6 --- /dev/null +++ b/tests/integration/tabular/test_ctgan.py @@ -0,0 +1,29 @@ +from sdv.demo import load_demo +from sdv.tabular.ctgan import CTGAN + + +def test_ctgan(): + users = load_demo(metadata=False)['users'] + + gc = CTGAN( + primary_key='user_id', + epochs=1 + ) + gc.fit(users) + + sampled = gc.sample() + + # test shape is right + assert sampled.shape == users.shape + + # test user_id has been generated as an ID field + assert list(sampled['user_id']) == list(range(0, len(users))) + + assert gc.get_metadata().to_dict() == { + 'fields': { + 'user_id': {'type': 'id', 'subtype': 'integer'}, + 'country': {'type': 'categorical'}, + 'gender': {'type': 'categorical'}, + 'age': {'type': 'numerical', 'subtype': 'integer'} + } + } From a198a2dc5a916d383342605661dca3294ed7e80b Mon Sep 17 00:00:00 2001 From: Carles Sala Date: Thu, 9 Jul 2020 19:51:01 +0200 Subject: [PATCH 21/33] Remove unused line --- sdv/tabular/copulas.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sdv/tabular/copulas.py b/sdv/tabular/copulas.py index 024205cad..dc82dd862 100644 --- a/sdv/tabular/copulas.py +++ b/sdv/tabular/copulas.py @@ 
-84,7 +84,6 @@ def _fit(self, data): table_data (pandas.DataFrame): Data to be fitted. """ - params = self._metadata.get_model_params() self._model = copulas.multivariate.GaussianMultivariate(distribution=self._distribution) self._model.fit(data) From 92b18f75dccc6d5835439a304f5ab6c80723754f Mon Sep 17 00:00:00 2001 From: Carles Sala Date: Thu, 9 Jul 2020 20:05:14 +0200 Subject: [PATCH 22/33] Fix tests --- .../test___init__.py} | 121 +--------------- tests/metadata/test_visualization.py | 132 ++++++++++++++++++ 2 files changed, 133 insertions(+), 120 deletions(-) rename tests/{test_metadata.py => metadata/test___init__.py} (87%) create mode 100644 tests/metadata/test_visualization.py diff --git a/tests/test_metadata.py b/tests/metadata/test___init__.py similarity index 87% rename from tests/test_metadata.py rename to tests/metadata/test___init__.py index 812ac7bcd..9968597ae 100644 --- a/tests/test_metadata.py +++ b/tests/metadata/test___init__.py @@ -1,7 +1,6 @@ from unittest import TestCase -from unittest.mock import MagicMock, Mock, call, patch +from unittest.mock import Mock, call, patch -import graphviz import pandas as pd import pytest @@ -963,121 +962,3 @@ def test_add_field(self): assert metadata._metadata == expected_metadata metadata._check_field.assert_called_once_with('a_table', 'a_field', exists=False) - - def test__get_graphviz_extension_path_without_extension(self): - """Raises a ValueError when the path doesn't contains an extension.""" - with pytest.raises(ValueError): - Metadata._get_graphviz_extension('/some/path') - - def test__get_graphviz_extension_invalid_extension(self): - """Raises a ValueError when the path contains an invalid extension.""" - with pytest.raises(ValueError): - Metadata._get_graphviz_extension('/some/path.foo') - - def test__get_graphviz_extension_none(self): - """Get graphviz with path equals to None.""" - # Run - result = Metadata._get_graphviz_extension(None) - - # Asserts - assert result == (None, None) - - def test__get_graphviz_extension_valid(self): - """Get a valid graphviz extension.""" - # Run - result = Metadata._get_graphviz_extension('/some/path.png') - - # Asserts - assert result == ('/some/path', 'png') - - def test__visualize_add_nodes(self): - """Add nodes into a graphviz digraph.""" - # Setup - metadata = MagicMock(spec_set=Metadata) - minimock = Mock() - - # pass tests in python3.5 - minimock.items.return_value = ( - ('a_field', {'type': 'numerical', 'subtype': 'integer'}), - ('b_field', {'type': 'id'}), - ('c_field', {'type': 'id', 'ref': {'table': 'other', 'field': 'pk_field'}}) - ) - - metadata.get_tables.return_value = ['demo'] - metadata.get_fields.return_value = minimock - - metadata.get_primary_key.return_value = 'b_field' - metadata.get_parents.return_value = set(['other']) - metadata.get_foreign_key.return_value = 'c_field' - - metadata.get_table_meta.return_value = {'path': None} - - plot = Mock() - - # Run - Metadata._visualize_add_nodes(metadata, plot) - - # Asserts - expected_node_label = r"{demo|a_field : numerical - integer\lb_field : id\l" \ - r"c_field : id\l|Primary key: b_field\l" \ - r"Foreign key (other): c_field\l}" - - metadata.get_fields.assert_called_once_with('demo') - metadata.get_primary_key.assert_called_once_with('demo') - metadata.get_parents.assert_called_once_with('demo') - metadata.get_table_meta.assert_called_once_with('demo') - metadata.get_foreign_key.assert_called_once_with('other', 'demo') - metadata.get_table_meta.assert_called_once_with('demo') - - plot.node.assert_called_once_with('demo', 
label=expected_node_label) - - def test__visualize_add_edges(self): - """Add edges into a graphviz digraph.""" - # Setup - metadata = MagicMock(spec_set=Metadata) - - metadata.get_tables.return_value = ['demo', 'other'] - metadata.get_parents.side_effect = [set(['other']), set()] - - metadata.get_foreign_key.return_value = 'fk' - metadata.get_primary_key.return_value = 'pk' - - plot = Mock() - - # Run - Metadata._visualize_add_edges(metadata, plot) - - # Asserts - expected_edge_label = ' {}.{} -> {}.{}'.format('demo', 'fk', 'other', 'pk') - - metadata.get_tables.assert_called_once_with() - metadata.get_foreign_key.assert_called_once_with('other', 'demo') - metadata.get_primary_key.assert_called_once_with('other') - assert metadata.get_parents.call_args_list == [call('demo'), call('other')] - - plot.edge.assert_called_once_with( - 'other', - 'demo', - label=expected_edge_label, - arrowhead='crow' - ) - - @patch('sdv.metadata.graphviz') - def test_visualize(self, graphviz_mock): - """Metadata visualize digraph""" - # Setup - plot = Mock(spec_set=graphviz.Digraph) - graphviz_mock.Digraph.return_value = plot - - metadata = MagicMock(spec_set=Metadata) - metadata._get_graphviz_extension.return_value = ('output', 'png') - - # Run - Metadata.visualize(metadata, path='output.png') - - # Asserts - metadata._get_graphviz_extension.assert_called_once_with('output.png') - metadata._visualize_add_nodes.assert_called_once_with(plot) - metadata._visualize_add_edges.assert_called_once_with(plot) - - plot.render.assert_called_once_with(filename='output', cleanup=True, format='png') diff --git a/tests/metadata/test_visualization.py b/tests/metadata/test_visualization.py new file mode 100644 index 000000000..b9f588593 --- /dev/null +++ b/tests/metadata/test_visualization.py @@ -0,0 +1,132 @@ +from unittest.mock import MagicMock, Mock, call, patch + +import graphviz +import pytest + +from sdv.metadata import Metadata, visualization + + +def test__get_graphviz_extension_path_without_extension(): + """Raises a ValueError when the path doesn't contains an extension.""" + with pytest.raises(ValueError): + visualization._get_graphviz_extension('/some/path') + + +def test__get_graphviz_extension_invalid_extension(): + """Raises a ValueError when the path contains an invalid extension.""" + with pytest.raises(ValueError): + visualization._get_graphviz_extension('/some/path.foo') + + +def test__get_graphviz_extension_none(): + """Get graphviz with path equals to None.""" + # Run + result = visualization._get_graphviz_extension(None) + + # Asserts + assert result == (None, None) + + +def test__get_graphviz_extension_valid(): + """Get a valid graphviz extension.""" + # Run + result = visualization._get_graphviz_extension('/some/path.png') + + # Asserts + assert result == ('/some/path', 'png') + + +def test__add_nodes(): + """Add nodes into a graphviz digraph.""" + # Setup + metadata = MagicMock(spec_set=Metadata) + minimock = Mock() + + # pass tests in python3.5 + minimock.items.return_value = ( + ('a_field', {'type': 'numerical', 'subtype': 'integer'}), + ('b_field', {'type': 'id'}), + ('c_field', {'type': 'id', 'ref': {'table': 'other', 'field': 'pk_field'}}) + ) + + metadata.get_tables.return_value = ['demo'] + metadata.get_fields.return_value = minimock + + metadata.get_primary_key.return_value = 'b_field' + metadata.get_parents.return_value = set(['other']) + metadata.get_foreign_key.return_value = 'c_field' + + metadata.get_table_meta.return_value = {'path': None} + + plot = Mock() + + # Run + 
visualization._add_nodes(metadata, plot) + + # Asserts + expected_node_label = r"{demo|a_field : numerical - integer\lb_field : id\l" \ + r"c_field : id\l|Primary key: b_field\l" \ + r"Foreign key (other): c_field\l}" + + metadata.get_fields.assert_called_once_with('demo') + metadata.get_primary_key.assert_called_once_with('demo') + metadata.get_parents.assert_called_once_with('demo') + metadata.get_table_meta.assert_called_once_with('demo') + metadata.get_foreign_key.assert_called_once_with('other', 'demo') + metadata.get_table_meta.assert_called_once_with('demo') + + plot.node.assert_called_once_with('demo', label=expected_node_label) + + +def test__add_edges(): + """Add edges into a graphviz digraph.""" + # Setup + metadata = MagicMock(spec_set=Metadata) + + metadata.get_tables.return_value = ['demo', 'other'] + metadata.get_parents.side_effect = [set(['other']), set()] + + metadata.get_foreign_key.return_value = 'fk' + metadata.get_primary_key.return_value = 'pk' + + plot = Mock() + + # Run + visualization._add_edges(metadata, plot) + + # Asserts + expected_edge_label = ' {}.{} -> {}.{}'.format('demo', 'fk', 'other', 'pk') + + metadata.get_tables.assert_called_once_with() + metadata.get_foreign_key.assert_called_once_with('other', 'demo') + metadata.get_primary_key.assert_called_once_with('other') + assert metadata.get_parents.call_args_list == [call('demo'), call('other')] + + plot.edge.assert_called_once_with( + 'other', + 'demo', + label=expected_edge_label, + arrowhead='oinv' + ) + + +@patch('sdv.metadata.visualization.graphviz.Digraph', spec_set=graphviz.Digraph) +@patch('sdv.metadata.visualization._add_nodes') +@patch('sdv.metadata.visualization._add_edges') +def test_visualize(add_nodes_mock, add_edges_mock, digraph_mock): + """Metadata visualize digraph""" + # Setup + # plot = Mock(spec_set=graphviz.Digraph) + # graphviz_mock.Digraph.return_value = plot + + metadata = MagicMock(spec_set=Metadata) + + # Run + visualization.visualize(metadata, path='output.png') + + # Asserts + digraph = digraph_mock.return_value + add_nodes_mock.assert_called_once_with(metadata, digraph) + add_edges_mock.assert_called_once_with(metadata, digraph) + + digraph.render.assert_called_once_with(filename='output', cleanup=True, format='png') From 762b26c97cff50a8bf6086e6d6591cdd250a13f3 Mon Sep 17 00:00:00 2001 From: Carles Sala Date: Thu, 9 Jul 2020 20:12:51 +0200 Subject: [PATCH 23/33] Fix test that fails randomly --- tests/integration/tabular/test_copulas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/tabular/test_copulas.py b/tests/integration/tabular/test_copulas.py index 037e32f0c..a40fbe7bb 100644 --- a/tests/integration/tabular/test_copulas.py +++ b/tests/integration/tabular/test_copulas.py @@ -36,7 +36,7 @@ def test_gaussian_copula(): assert list(sampled['user_id']) == list(range(0, len(users))) # country codes have been replaced with new ones - assert set(sampled.country.unique()) & set(users.country.unique()) == set() + assert set(sampled.country.unique()) != set(users.country.unique()) assert gc.get_metadata().to_dict() == { 'fields': { From 06df78c6271f949068c78b13566a9565f603eef8 Mon Sep 17 00:00:00 2001 From: Carles Sala Date: Thu, 9 Jul 2020 21:09:12 +0200 Subject: [PATCH 24/33] Add save/load and fix imports --- sdv/metadata.py.composite | 1054 +++++++++++++++++++++++++++++++++ sdv/metadata.py.non_primary | 1109 +++++++++++++++++++++++++++++++++++ sdv/tabular/__init__.py | 8 + sdv/tabular/base.py | 27 + 4 files changed, 2198 insertions(+) create 
mode 100644 sdv/metadata.py.composite create mode 100644 sdv/metadata.py.non_primary diff --git a/sdv/metadata.py.composite b/sdv/metadata.py.composite new file mode 100644 index 000000000..10dda41e3 --- /dev/null +++ b/sdv/metadata.py.composite @@ -0,0 +1,1054 @@ +import copy +import json +import logging +import os +from collections import defaultdict + +import graphviz +import numpy as np +import pandas as pd +from rdt import HyperTransformer, transformers + +LOGGER = logging.getLogger(__name__) + + +def _read_csv_dtypes(table_meta): + """Get the dtypes specification that needs to be passed to read_csv.""" + dtypes = dict() + for name, field in table_meta['fields'].items(): + field_type = field['type'] + if field_type == 'categorical': + dtypes[name] = str + elif field_type == 'id' and field.get('subtype', 'integer') == 'string': + dtypes[name] = str + + return dtypes + + +def _parse_dtypes(data, table_meta): + """Convert the data columns to the right dtype after loading the CSV.""" + for name, field in table_meta['fields'].items(): + field_type = field['type'] + if field_type == 'datetime': + datetime_format = field.get('format') + data[name] = pd.to_datetime(data[name], format=datetime_format, exact=False) + elif field_type == 'numerical' and field.get('subtype') == 'integer': + data[name] = data[name].dropna().astype(int) + elif field_type == 'id' and field.get('subtype', 'integer') == 'integer': + data[name] = data[name].dropna().astype(int) + + return data + + +def _load_csv(root_path, table_meta): + """Load a CSV with the right dtypes and then parse the columns.""" + relative_path = os.path.join(root_path, table_meta['path']) + dtypes = _read_csv_dtypes(table_meta) + + data = pd.read_csv(relative_path, dtype=dtypes) + data = _parse_dtypes(data, table_meta) + + return data + + +class MetadataError(Exception): + pass + + +class Metadata: + """Dataset Metadata. + + The Metadata class provides a unified layer of abstraction over the dataset + metadata, which includes both the necessary details to load the data from + the hdd and to know how to parse and transform it to numerical data. + + Args: + metadata (str or dict): + Path to a ``json`` file that contains the metadata or a ``dict`` representation + of ``metadata`` following the same structure. + + root_path (str): + The path where the ``metadata.json`` is located. Defaults to ``None``. + """ + + _child_map = None + _hyper_transformers = None + _metadata = None + _parent_map = None + + root_path = None + + _FIELD_TEMPLATES = { + 'i': { + 'type': 'numerical', + 'subtype': 'integer', + }, + 'f': { + 'type': 'numerical', + 'subtype': 'float', + }, + 'O': { + 'type': 'categorical', + }, + 'b': { + 'type': 'boolean', + }, + 'M': { + 'type': 'datetime', + } + } + _DTYPES = { + ('categorical', None): 'object', + ('boolean', None): 'bool', + ('numerical', None): 'float', + ('numerical', 'float'): 'float', + ('numerical', 'integer'): 'int', + ('datetime', None): 'datetime64', + ('id', None): 'int', + ('id', 'integer'): 'int', + ('id', 'string'): 'str' + } + + def _analyze_relationships(self): + """Extract information about child-parent relationships. + + Creates the following attributes: + * ``_child_map``: set of child tables that each table has. + * ``_parent_map``: set ot parents that each table has. 
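+
+        Example:
+            Illustrative result for a metadata where a ``sessions`` table
+            references the primary key of a ``users`` table:
+
+                self._child_map == {'users': {'sessions'}}
+                self._parent_map == {'sessions': {'users'}}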
+ """ + self._child_map = defaultdict(set) + self._parent_map = defaultdict(set) + + for table, table_meta in self._metadata['tables'].items(): + if table_meta.get('use', True): + for field_meta in table_meta['fields'].values(): + ref = field_meta.get('ref') + if ref: + parent = ref['table'] + self._child_map[parent].add(table) + self._parent_map[table].add(parent) + + @staticmethod + def _transform_metadata(metadata): + """Ensure metadata has the internal SDV format. + + Convert list of tables and list of fields to dicts. + Ensure primary keys are defined as lists of fields. + + Args: + metadata (dict): + Original metadata to format. + + Returns: + dict: + Formated metadata dict. + """ + new_metadata = copy.deepcopy(metadata) + tables = new_metadata['tables'] + if isinstance(tables, dict): + new_metadata['tables'] = { + table: meta + for table, meta in tables.items() + if meta.pop('use', True) + } + return new_metadata + + new_tables = dict() + for table in tables: + if table.pop('use', True): + new_tables[table.pop('name')] = table + + fields = table['fields'] + new_fields = dict() + for field in fields: + new_fields[field.pop('name')] = field + + table['fields'] = new_fields + + primary_key = table.get('primary_key') + if isinstance(primary_key, str): + table['primary_key'] = [primary_key] + + new_metadata['tables'] = new_tables + + return new_metadata + + def __init__(self, metadata=None, root_path=None): + if isinstance(metadata, str): + self.root_path = root_path or os.path.dirname(metadata) + with open(metadata) as metadata_file: + metadata = json.load(metadata_file) + else: + self.root_path = root_path or '.' + + if metadata is not None: + self._metadata = self._transform_metadata(metadata) + else: + self._metadata = {'tables': {}} + + self._hyper_transformers = dict() + self._analyze_relationships() + + def get_children(self, table_name): + """Get tables for which the given table is parent. + + Args: + table_name (str): + Name of the table from which to get the children. + + Returns: + set: + Set of children for the given table. + """ + return self._child_map[table_name] + + def get_parents(self, table_name): + """Get tables for with the given table is child. + + Args: + table_name (str): + Name of the table from which to get the parents. + + Returns: + set: + Set of parents for the given table. + """ + return self._parent_map[table_name] + + def get_table_meta(self, table_name): + """Get the metadata dict for a table. + + Args: + table_name (str): + Name of table to get data for. + + Returns: + dict: + table metadata + + Raises: + ValueError: + If table does not exist in this metadata. + """ + table = self._metadata['tables'].get(table_name) + if table is None: + raise ValueError('Table "{}" does not exist'.format(table_name)) + + return copy.deepcopy(table) + + def get_tables(self): + """Get the list of table names. + + Returns: + list: + table names. + """ + return list(self._metadata['tables'].keys()) + + def get_fields(self, table_name): + """Get table fields metadata. + + Args: + table_name (str): + Name of the table to get the fields from. + + Returns: + dict: + Mapping of field names and their metadata dicts. + + Raises: + ValueError: + If table does not exist in this metadata. + """ + return self.get_table_meta(table_name)['fields'] + + def get_primary_key(self, table_name): + """Get the primary key name of the indicated table. + + Args: + table_name (str): + Name of table for which to get the primary key field. + + Returns: + list or None: + Primary key field names. 
``None`` if the table has no primary key. + + Raises: + ValueError: + If table does not exist in this metadata. + """ + return self.get_table_meta(table_name).get('primary_key') + + def get_foreign_key(self, parent, child): + """Get table foreign key field name. + + Args: + parent (str): + Name of the parent table. + child (str): + Name of the child table. + + Returns: + str or None: + Foreign key field name. + + Raises: + ValueError: + If the relationship does not exist. + """ + primary = self.get_primary_key(parent) + + for name, field in self.get_fields(child).items(): + ref = field.get('ref') + if ref and ref['field'] == primary: + return name + + raise ValueError('{} is not parent of {}'.format(parent, child)) + + def load_table(self, table_name): + """Load table data. + + Args: + table_name (str): + Name of the table to load. + + Returns: + pandas.DataFrame: + DataFrame with the contents of the table. + + Raises: + ValueError: + If table does not exist in this metadata. + """ + LOGGER.info('Loading table %s', table_name) + table_meta = self.get_table_meta(table_name) + return _load_csv(self.root_path, table_meta) + + def load_tables(self, tables=None): + """Get a dictionary with data from multiple tables. + + If a ``tables`` list is given, only load the indicated tables. + Otherwise, load all the tables from this metadata. + + Args: + tables (list): + List of table names. Defaults to ``None``. + + Returns: + dict(str, pandasd.DataFrame): + mapping of table names and their data loaded as ``pandas.DataFrame`` instances. + """ + return { + table_name: self.load_table(table_name) + for table_name in tables or self.get_tables() + } + + def get_dtypes(self, table_name, ids=False): + """Get a ``dict`` with the ``dtypes`` for each field of a given table. + + Args: + table_name (str): + Table name for which to retrive the ``dtypes``. + ids (bool): + Whether or not include the id fields. Defaults to ``False``. + + Returns: + dict: + Dictionary that contains the field names and data types from a table. + + Raises: + ValueError: + If a field has an invalid type or subtype or if the table does not + exist in this metadata. + """ + dtypes = dict() + table_meta = self.get_table_meta(table_name) + for name, field in table_meta['fields'].items(): + field_type = field['type'] + field_subtype = field.get('subtype') + dtype = self._DTYPES.get((field_type, field_subtype)) + if not dtype: + raise MetadataError( + 'Invalid type and subtype combination for field {}: ({}, {})'.format( + name, field_type, field_subtype) + ) + + if ids and field_type == 'id': + if (name not in table_meta.get('primary_key', [])) and not field.get('ref'): + raise MetadataError( + 'id field `{}` is neither a primary or a foreign key'.format(name)) + + if ids or (field_type != 'id'): + dtypes[name] = dtype + + return dtypes + + def _get_pii_fields(self, table_name): + """Get the ``pii_category`` for each field that contains PII. + + Args: + table_name (str): + Table name for which to get the pii fields. + + Returns: + dict: + pii field names and categories. + """ + pii_fields = dict() + for name, field in self.get_table_meta(table_name)['fields'].items(): + if field['type'] == 'categorical' and field.get('pii', False): + pii_fields[name] = field['pii_category'] + + return pii_fields + + @staticmethod + def _get_transformers(dtypes, pii_fields): + """Create the transformer instances needed to process the given dtypes. + + Temporary drop-in replacement of ``HyperTransformer._analyze`` method, + before RDT catches up. 
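+
+        Example:
+            Illustrative mapping; the exact transformers returned depend on
+            the dtype kinds and on the pii categories that are passed:
+
+                _get_transformers({'age': 'int'}, {})
+                # -> {'age': transformers.NumericalTransformer(dtype=int)}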
+ + Args: + dtypes (dict): + mapping of field names and dtypes. + pii_fields (dict): + mapping of pii field names and categories. + + Returns: + dict: + mapping of field names and transformer instances. + """ + transformers_dict = dict() + for name, dtype in dtypes.items(): + dtype = np.dtype(dtype) + if dtype.kind == 'i': + transformer = transformers.NumericalTransformer(dtype=int) + elif dtype.kind == 'f': + transformer = transformers.NumericalTransformer(dtype=float) + elif dtype.kind == 'O': + anonymize = pii_fields.get(name) + transformer = transformers.CategoricalTransformer(anonymize=anonymize) + elif dtype.kind == 'b': + transformer = transformers.BooleanTransformer() + elif dtype.kind == 'M': + transformer = transformers.DatetimeTransformer() + else: + raise ValueError('Unsupported dtype: {}'.format(dtype)) + + LOGGER.info('Loading transformer %s for field %s', + transformer.__class__.__name__, name) + transformers_dict[name] = transformer + + return transformers_dict + + def _load_hyper_transformer(self, table_name): + """Create and return a new ``rdt.HyperTransformer`` instance for a table. + + First get the ``dtypes`` and ``pii fields`` from a given table, then use + those to build a transformer dictionary to be used by the ``HyperTransformer``. + + Args: + table_name (str): + Table name for which to load the HyperTransformer. + + Returns: + rdt.HyperTransformer: + Instance of ``rdt.HyperTransformer`` for the given table. + """ + dtypes = self.get_dtypes(table_name) + pii_fields = self._get_pii_fields(table_name) + transformers_dict = self._get_transformers(dtypes, pii_fields) + return HyperTransformer(transformers=transformers_dict) + + def transform(self, table_name, data): + """Transform data for a given table. + + If the ``HyperTransformer`` for a table is ``None`` it is created. + + Args: + table_name (str): + Name of the table that is being transformer. + data (pandas.DataFrame): + Table data. + + Returns: + pandas.DataFrame: + Transformed data. + """ + hyper_transformer = self._hyper_transformers.get(table_name) + if hyper_transformer is None: + hyper_transformer = self._load_hyper_transformer(table_name) + fields = list(hyper_transformer.transformers.keys()) + hyper_transformer.fit(data[fields]) + self._hyper_transformers[table_name] = hyper_transformer + + hyper_transformer = self._hyper_transformers.get(table_name) + fields = list(hyper_transformer.transformers.keys()) + return hyper_transformer.transform(data[fields]) + + def reverse_transform(self, table_name, data): + """Reverse the transformed data for a given table. + + Args: + table_name (str): + Name of the table to reverse transform. + data (pandas.DataFrame): + Data to be reversed. + + Returns: + pandas.DataFrame + """ + hyper_transformer = self._hyper_transformers[table_name] + reversed_data = hyper_transformer.reverse_transform(data) + + for name, dtype in self.get_dtypes(table_name, ids=True).items(): + reversed_data[name] = reversed_data[name].dropna().astype(dtype) + + return reversed_data + + # ################### # + # Metadata Validation # + # ################### # + + def _validate_table(self, table_name, table_meta, table_data=None): + """Validate table metadata. + + Validate the type and subtype combination for each field in ``table_meta``. + If a field has type ``id``, validate that it either is the ``primary_key`` or + has a ``ref`` entry. + + If the table has ``primary_key``, make sure that the corresponding field exists + and its type is ``id``. 
+ + If ``table_data`` is provided, also check that the list of columns corresponds + to the ones indicated in the metadata and that all the dtypes are valid. + + Args: + table_name (str): + Name of the table to validate. + table_meta (dict): + Metadata of the table to validate. + table_data (pandas.DataFrame): + If provided, make sure that the data matches the one described + on the metadata. + + Raises: + MetadataError: + If there is any error in the metadata or the data does not + match the metadata description. + """ + dtypes = self.get_dtypes(table_name, ids=True) + + # Primary key field exists and its type is 'id' + primary_key_fields = table_meta.get('primary_key', []) + for primary_key in primary_key_fields: + pk_field = table_meta['fields'].get(primary_key) + + if not pk_field: + raise MetadataError('Primary key is not an existing field.') + + if pk_field['type'] != 'id': + raise MetadataError('Primary key is not of type `id`.') + + if table_data is not None: + for column in table_data: + try: + dtype = dtypes.pop(column) + table_data[column].dropna().astype(dtype) + except KeyError: + message = 'Unexpected column in table `{}`: `{}`'.format(table_name, column) + raise MetadataError(message) from None + except ValueError as ve: + message = 'Invalid values found in column `{}` of table `{}`: `{}`'.format( + column, table_name, ve) + raise MetadataError(message) from None + + # assert all dtypes are in data + if dtypes: + raise MetadataError( + 'Missing columns on table {}: {}.'.format(table_name, list(dtypes.keys())) + ) + + def _validate_circular_relationships(self, parent, children=None): + """Validate that there is no circular relatioship in the metadata.""" + if children is None: + children = self.get_children(parent) + + if parent in children: + raise MetadataError('Circular relationship found for table "{}"'.format(parent)) + + for child in children: + self._validate_circular_relationships(parent, self.get_children(child)) + + def validate(self, tables=None): + """Validate this metadata. + + For each table from in metadata ``tables`` entry: + * Validate the table metadata is correct. + + * If ``tables`` are provided or they have been loaded, check + that all the metadata tables exists in the ``tables`` dictionary. + * Validate the type/subtype combination for each field and + if a field of type ``id`` exists it must be the ``primary_key`` + or must have a ``ref`` entry. + * If ``primary_key`` entry exists, check that it's an existing + field and its type is ``id``. + * If ``tables`` are provided or they have been loaded, check + all the data types for the table correspond to each column and + all the data types exists on the table. + * Validate that there is no circular relatioship in the metadata. + * Check that all the tables have at most one parent. + + Args: + tables (bool, dict): + If a dict of table is passed, validate that the columns and + dtypes match the metadata. If ``True`` is passed, load the + tables from the Metadata instead. If ``None``, omit the data + validation. Defaults to ``None``. 
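+
+        Example:
+            Illustrative calls; the first one validates the metadata alone
+            and the second one also loads the tables and checks their dtypes:
+
+                metadata.validate()
+                metadata.validate(tables=True)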
+ """ + tables_meta = self._metadata.get('tables') + if not tables_meta: + raise MetadataError('"tables" entry not found in Metadata.') + + if tables and not isinstance(tables, dict): + tables = self.load_tables() + + for table_name, table_meta in tables_meta.items(): + if tables: + table = tables.get(table_name) + if table is None: + raise MetadataError('Table `{}` not found in tables'.format(table_name)) + + else: + table = None + + self._validate_table(table_name, table_meta, table) + self._validate_circular_relationships(table_name) + + def _check_field(self, table, field, exists=False): + """Validate the existance of the table and existance (or not) of field.""" + table_fields = self.get_fields(table) + if exists and (field not in table_fields): + raise ValueError('Field "{}" does not exist in table "{}"'.format(field, table)) + + if not exists and (field in table_fields): + raise ValueError('Field "{}" already exists in table "{}"'.format(field, table)) + + # ################# # + # Metadata Creation # + # ################# # + + def add_field(self, table, field, field_type, field_subtype=None, properties=None): + """Add a new field to the indicated table. + + Args: + table (str): + Table name to add the new field, it must exist. + field (str): + Field name to be added, it must not exist. + field_type (str): + Data type of field to be added. Required. + field_subtype (str): + Data subtype of field to be added. Optional. + Defaults to ``None``. + properties (dict): + Extra properties of field like: ref, format, min, max, etc. Optional. + Defaults to ``None``. + + Raises: + ValueError: + If the table does not exist or it already contains the field. + """ + self._check_field(table, field, exists=False) + + field_details = { + 'type': field_type + } + + if field_subtype: + field_details['subtype'] = field_subtype + + if properties: + field_details.update(properties) + + self._metadata['tables'][table]['fields'][field] = field_details + + @staticmethod + def _get_key_subtype(field_meta): + """Get the appropriate key subtype.""" + field_type = field_meta['type'] + if field_type == 'categorical': + field_subtype = 'string' + elif field_type in ('numerical', 'id'): + field_subtype = field_meta['subtype'] + if field_subtype not in ('integer', 'string'): + raise ValueError( + 'Invalid field "subtype" for key field: "{}"'.format(field_subtype) + ) + else: + raise ValueError( + 'Invalid field "type" for key field: "{}"'.format(field_type) + ) + + return field_subtype + + def set_primary_key(self, table, field): + """Set the primary key field of the indicated table. + + The field must exist and either be an integer or categorical field. + + Args: + table (str): + Name of the table where the primary key will be set. + field (str): + Field to be used as the new primary key. + + Raises: + ValueError: + If the table or the field do not exist or if the field has an + invalid type or subtype. + """ + self._check_field(table, field, exists=True) + + field_meta = self.get_fields(table).get(field) + field_subtype = self._get_key_subtype(field_meta) + + table_meta = self._metadata['tables'][table] + table_meta['fields'][field] = { + 'type': 'id', + 'subtype': field_subtype + } + table_meta['primary_key'] = field + + def add_relationship(self, parent, child, foreign_key=None): + """Add a new relationship between the parent and child tables. + + The relationship is created by adding a reference (``ref``) on the ``foreign_key`` + field of the ``child`` table pointing at the ``parent`` primary key. 
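+
+        Example:
+            Illustrative call; assumes a ``users`` table whose primary key is
+            referenced by a new ``user_id`` field in the ``sessions`` table:
+
+                metadata.add_relationship('users', 'sessions', 'user_id')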
+
+        Args:
+            parent (str):
+                Name of the parent table.
+            child (str):
+                Name of the child table.
+            foreign_key (str):
+                Field in the child table through which the relationship is created.
+                If ``None``, use the parent primary key name.
+
+        Raises:
+            ValueError:
+                If any of the following happens:
+                    * The parent table does not exist.
+                    * The child table does not exist.
+                    * The parent table does not have a primary key.
+                    * The foreign_key field already exists in the child table.
+                    * The child table already has a parent.
+                    * The new relationship closes a relationship circle.
+        """
+        # Validate table and field names
+        primary_key = self.get_primary_key(parent)
+        if not primary_key:
+            raise ValueError('Parent table "{}" does not have a primary key'.format(parent))
+
+        if foreign_key is None:
+            foreign_key = primary_key
+
+        # Validate relationships
+        if self.get_parents(child):
+            raise ValueError('Table "{}" already has a parent'.format(child))
+
+        grandchildren = self.get_children(child)
+        if grandchildren:
+            self._validate_circular_relationships(parent, grandchildren)
+
+        # Copy primary key details over to the foreign key
+        foreign_key_details = copy.deepcopy(self.get_fields(parent)[primary_key])
+        foreign_key_details['ref'] = {
+            'table': parent,
+            'field': primary_key
+        }
+
+        # Make sure that key subtypes are the same
+        foreign_meta = self.get_fields(child).get(foreign_key)
+        if foreign_meta:
+            foreign_subtype = self._get_key_subtype(foreign_meta)
+            if foreign_subtype != foreign_key_details['subtype']:
+                raise ValueError('Primary and Foreign key subtypes mismatch')
+
+        self._metadata['tables'][child]['fields'][foreign_key] = foreign_key_details
+
+        # Re-analyze the relationships
+        self._analyze_relationships()
+
+    def _get_field_details(self, data, fields):
+        """Get or build all the fields metadata.
+
+        Analyze a ``pandas.DataFrame`` to build a ``dict`` with the name of the column, and
+        their data type and subtype. If ``fields`` are provided, only those columns will be
+        analyzed.
+
+        Args:
+            data (pandas.DataFrame):
+                Table to be analyzed.
+            fields (set):
+                Set of field names or field specifications.
+
+        Returns:
+            dict:
+                Dict of valid fields.
+
+        Raises:
+            TypeError:
+                If a field specification is not a str or a dict.
+            ValueError:
+                If a column from the data analyzed is of an unsupported data type.
+        """
+        fields_metadata = dict()
+        for field in fields:
+            dtype = data[field].dtype
+            field_template = self._FIELD_TEMPLATES.get(dtype.kind)
+            if not field_template:
+                raise ValueError('Unsupported dtype {} in column {}'.format(dtype, field))
+
+            field_details = copy.deepcopy(field_template)
+            fields_metadata[field] = field_details
+
+        return fields_metadata
+
+    def add_table(self, name, data=None, fields=None, fields_metadata=None,
+                  primary_key=None, parent=None, foreign_key=None):
+        """Add a new table to this metadata.
+
+        The ``fields`` list can be a mixture of field names, which will be built
+        automatically from the data, and dictionaries specifying the field details.
+        If a field needs to be analyzed, the data also has to be passed.
+
+        If ``parent`` is given, a relationship will be established between this table
+        and the specified parent.
+
+        Args:
+            name (str):
+                Name of the new table.
+            data (str or pandas.DataFrame):
+                Table to be analyzed or path to the csv file.
+                If it's a relative path, use ``root_path`` to find the file.
+                Only needed if the fields have to be built from the data.
+                Defaults to ``None``.
+            fields (list):
+                List of field names to build. If ``None`` is given, all the fields
+                found in the data will be used.
+                Defaults to ``None``.
+            fields_metadata (dict):
+                Metadata to be used when creating fields. This will overwrite the
+                metadata built from the fields found in data.
+                Defaults to ``None``.
+            primary_key (str):
+                Field name to add as primary key; it must not exist yet. Defaults to ``None``.
+            parent (str):
+                Name of the parent table to reference through a foreign key field.
+                Defaults to ``None``.
+            foreign_key (str):
+                Foreign key field name pointing at the ``parent`` table primary key.
+                Defaults to ``None``.
+
+        Raises:
+            ValueError:
+                If the table ``name`` already exists or ``data`` is not passed and
+                fields need to be built from it.
+        """
+        if name in self.get_tables():
+            raise ValueError('Table "{}" already exists.'.format(name))
+
+        path = None
+        if data is not None:
+            if isinstance(data, str):
+                path = data
+                if not os.path.isabs(data):
+                    data = os.path.join(self.root_path, data)
+
+                data = pd.read_csv(data)
+
+            fields = set(fields or data.columns)
+            if fields_metadata:
+                fields = fields - set(fields_metadata.keys())
+            else:
+                fields_metadata = dict()
+
+            fields_metadata.update(self._get_field_details(data, fields))
+
+        elif fields_metadata is None:
+            fields_metadata = dict()
+
+        table_metadata = {'fields': fields_metadata}
+        if path:
+            table_metadata['path'] = path
+
+        self._metadata['tables'][name] = table_metadata
+
+        try:
+            if primary_key:
+                self.set_primary_key(name, primary_key)
+
+            if parent:
+                self.add_relationship(parent, name, foreign_key)
+
+        except ValueError:
+            # Cleanup: remove the partially added table before re-raising
+            del self._metadata['tables'][name]
+            raise
+
+    # ###################### #
+    # Metadata Serialization #
+    # ###################### #
+
+    def to_dict(self):
+        """Get a dict representation of this metadata.
+
+        Returns:
+            dict:
+                dict representation of this metadata.
+        """
+        return copy.deepcopy(self._metadata)
+
+    def to_json(self, path):
+        """Dump this metadata into a JSON file.
+
+        Args:
+            path (str):
+                Path of the JSON file where this metadata will be stored.
+        """
+        with open(path, 'w') as out_file:
+            json.dump(self._metadata, out_file, indent=4)
+
+    @staticmethod
+    def _get_graphviz_extension(path):
+        if path:
+            path_splitted = path.split('.')
+            if len(path_splitted) == 1:
+                raise ValueError('Path without graphviz extension.')
+
+            graphviz_extension = path_splitted[-1]
+
+            if graphviz_extension not in graphviz.backend.FORMATS:
+                raise ValueError(
+                    '"{}" not a valid graphviz extension format.'.format(graphviz_extension)
+                )
+
+            return '.'.join(path_splitted[:-1]), graphviz_extension
+
+        return None, None
+
+    def _visualize_add_nodes(self, plot):
+        """Add nodes into a `graphviz.Digraph`.
+
+        Each node represents a metadata table.
+
+        Args:
+            plot (graphviz.Digraph)
+        """
+        for table in self.get_tables():
+            # Append table fields
+            fields = []
+
+            for name, value in self.get_fields(table).items():
+                if value.get('subtype') is not None:
+                    fields.append('{} : {} - {}'.format(name, value['type'], value['subtype']))
+
+                else:
+                    fields.append('{} : {}'.format(name, value['type']))
+
+            fields = r'\l'.join(fields)
+
+            # Append table extra information
+            extras = []
+
+            primary_key = self.get_primary_key(table)
+            if primary_key is not None:
+                extras.append('Primary key: {}'.format(primary_key))
+
+            parents = self.get_parents(table)
+            for parent in parents:
+                foreign_key = self.get_foreign_key(parent, table)
+                extras.append('Foreign key ({}): {}'.format(parent, foreign_key))
+
+            path = self.get_table_meta(table).get('path')
+            if path is not None:
+                extras.append('Data path: {}'.format(path))
+
+            extras = r'\l'.join(extras)
+
+            # Add table node
+            title = r'{%s|%s\l|%s\l}' % (table, fields, extras)
+            plot.node(table, label=title)
+
+    def _visualize_add_edges(self, plot):
+        """Add edges into a `graphviz.Digraph`.
+
+        Each edge represents a relationship between two metadata tables.
+
+        Args:
+            plot (graphviz.Digraph)
+        """
+        for table in self.get_tables():
+            for parent in list(self.get_parents(table)):
+                plot.edge(
+                    parent,
+                    table,
+                    label=' {}.{} -> {}.{}'.format(
+                        table, self.get_foreign_key(parent, table),
+                        parent, self.get_primary_key(parent)
+                    ),
+                    arrowhead='crow'
+                )
+
+    def visualize(self, path=None):
+        """Plot metadata using graphviz.
+
+        Generate a plot of this metadata using graphviz.
+        If a ``path`` is provided, save the output into a file.
+
+        Args:
+            path (str):
+                Output file path to save the plot. It requires a graphviz
+                supported extension. If ``None``, do not save the plot.
+                Defaults to ``None``.
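+
+        Example:
+            A minimal sketch; the output path is illustrative, and any other
+            graphviz-supported extension could be used instead::
+
+                metadata.visualize('metadata.png')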
+ """ + filename, graphviz_extension = self._get_graphviz_extension(path) + plot = graphviz.Digraph( + 'Metadata', + format=graphviz_extension, + node_attr={ + "shape": "Mrecord", + "fillcolor": "lightgoldenrod1", + "style": "filled" + }, + ) + + self._visualize_add_nodes(plot) + self._visualize_add_edges(plot) + + if filename: + plot.render(filename=filename, cleanup=True, format=graphviz_extension) + else: + return plot + + def __str__(self): + tables = self.get_tables() + relationships = [ + ' {}.{} -> {}.{}'.format( + table, self.get_foreign_key(parent, table), + parent, self.get_primary_key(parent) + ) + for table in tables + for parent in list(self.get_parents(table)) + ] + + return ( + "Metadata\n" + " root_path: {}\n" + " tables: {}\n" + " relationships:\n" + "{}" + ).format( + os.path.abspath(self.root_path), + tables, + '\n'.join(relationships) + ) diff --git a/sdv/metadata.py.non_primary b/sdv/metadata.py.non_primary new file mode 100644 index 000000000..2df337bd2 --- /dev/null +++ b/sdv/metadata.py.non_primary @@ -0,0 +1,1109 @@ +import copy +import json +import logging +import os +from collections import defaultdict + +import graphviz +import numpy as np +import pandas as pd +from rdt import HyperTransformer, transformers + +LOGGER = logging.getLogger(__name__) + + +def _read_csv_dtypes(table_meta): + """Get the dtypes specification that needs to be passed to read_csv.""" + dtypes = dict() + for name, field in table_meta['fields'].items(): + field_type = field['type'] + if field_type == 'categorical': + dtypes[name] = str + elif field_type == 'id' and field.get('subtype', 'integer') == 'string': + dtypes[name] = str + + return dtypes + + +def _parse_dtypes(data, table_meta): + """Convert the data columns to the right dtype after loading the CSV.""" + for name, field in table_meta['fields'].items(): + field_type = field['type'] + if field_type == 'datetime': + datetime_format = field.get('format') + data[name] = pd.to_datetime(data[name], format=datetime_format, exact=False) + elif field_type == 'numerical' and field.get('subtype') == 'integer': + data[name] = data[name].dropna().astype(int) + elif field_type == 'id' and field.get('subtype', 'integer') == 'integer': + data[name] = data[name].dropna().astype(int) + + return data + + +def _load_csv(root_path, table_meta): + """Load a CSV with the right dtypes and then parse the columns.""" + relative_path = os.path.join(root_path, table_meta['path']) + dtypes = _read_csv_dtypes(table_meta) + + data = pd.read_csv(relative_path, dtype=dtypes) + data = _parse_dtypes(data, table_meta) + + return data + + +class MetadataError(Exception): + pass + + +class Metadata: + """Dataset Metadata. + + The Metadata class provides a unified layer of abstraction over the dataset + metadata, which includes both the necessary details to load the data from + the hdd and to know how to parse and transform it to numerical data. + + Args: + metadata (str or dict): + Path to a ``json`` file that contains the metadata or a ``dict`` representation + of ``metadata`` following the same structure. + + root_path (str): + The path where the ``metadata.json`` is located. Defaults to ``None``. 
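+
+    Example:
+        A minimal construction sketch; the path below is illustrative and
+        should point at an existing ``metadata.json`` file::
+
+            metadata = Metadata('data/metadata.json')
+            tables = metadata.load_tables()
+
+        A ``dict`` following the same structure can also be passed instead
+        of a path.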
+    """
+
+    _child_map = None
+    _hyper_transformers = None
+    _metadata = None
+    _parent_map = None
+
+    root_path = None
+
+    _FIELD_TEMPLATES = {
+        'i': {
+            'type': 'numerical',
+            'subtype': 'integer',
+        },
+        'f': {
+            'type': 'numerical',
+            'subtype': 'float',
+        },
+        'O': {
+            'type': 'categorical',
+        },
+        'b': {
+            'type': 'boolean',
+        },
+        'M': {
+            'type': 'datetime',
+        }
+    }
+    _DTYPES = {
+        ('categorical', None): 'object',
+        ('boolean', None): 'bool',
+        ('numerical', None): 'float',
+        ('numerical', 'float'): 'float',
+        ('numerical', 'integer'): 'int',
+        ('datetime', None): 'datetime64',
+        ('id', None): 'int',
+        ('id', 'integer'): 'int',
+        ('id', 'string'): 'str'
+    }
+
+    def _analyze_relationships(self):
+        """Extract information about child-parent relationships.
+
+        Creates the following attributes:
+
+        * ``_child_map``: set of child tables that each table has.
+        * ``_parent_map``: set of parents that each table has.
+        """
+        self._child_map = defaultdict(set)
+        self._parent_map = defaultdict(set)
+
+        for table, table_meta in self._metadata['tables'].items():
+            if table_meta.get('use', True):
+                for field_meta in table_meta['fields'].values():
+                    ref = field_meta.get('ref')
+                    if ref:
+                        parent = ref['table']
+                        self._child_map[parent].add(table)
+                        self._parent_map[table].add(parent)
+
+    @staticmethod
+    def _dict_metadata(metadata):
+        """Get a metadata ``dict`` with SDV format.
+
+        For each table, create a dict of fields from a previous list of fields.
+
+        Args:
+            metadata (dict):
+                Original metadata to format.
+
+        Returns:
+            dict:
+                Formatted metadata dict.
+        """
+        new_metadata = copy.deepcopy(metadata)
+        tables = new_metadata['tables']
+        if isinstance(tables, dict):
+            new_metadata['tables'] = {
+                table: meta
+                for table, meta in tables.items()
+                if meta.pop('use', True)
+            }
+            return new_metadata
+
+        new_tables = dict()
+        for table in tables:
+            if table.pop('use', True):
+                new_tables[table.pop('name')] = table
+
+                fields = table['fields']
+                new_fields = dict()
+                for field in fields:
+                    new_fields[field.pop('name')] = field
+
+                table['fields'] = new_fields
+
+        new_metadata['tables'] = new_tables
+
+        return new_metadata
+
+    def __init__(self, metadata=None, root_path=None):
+        if isinstance(metadata, str):
+            self.root_path = root_path or os.path.dirname(metadata)
+            with open(metadata) as metadata_file:
+                metadata = json.load(metadata_file)
+        else:
+            self.root_path = root_path or '.'
+
+        if metadata is not None:
+            self._metadata = self._dict_metadata(metadata)
+        else:
+            self._metadata = {'tables': {}}
+
+        self._hyper_transformers = dict()
+        self._analyze_relationships()
+
+    def get_children(self, table_name):
+        """Get the tables for which the given table is a parent.
+
+        Args:
+            table_name (str):
+                Name of the table from which to get the children.
+
+        Returns:
+            set:
+                Set of children for the given table.
+        """
+        return self._child_map[table_name]
+
+    def get_parents(self, table_name):
+        """Get the tables for which the given table is a child.
+
+        Args:
+            table_name (str):
+                Name of the table from which to get the parents.
+
+        Returns:
+            set:
+                Set of parents for the given table.
+        """
+        return self._parent_map[table_name]
+
+    def get_table_meta(self, table_name):
+        """Get the metadata dict for a table.
+
+        Args:
+            table_name (str):
+                Name of table to get data for.
+
+        Returns:
+            dict:
+                table metadata
+
+        Raises:
+            ValueError:
+                If table does not exist in this metadata.
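+
+        Example:
+            A small sketch, assuming a table named ``users`` is defined in
+            this metadata::
+
+                users_meta = metadata.get_table_meta('users')
+                field_names = list(users_meta['fields'].keys())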
+        """
+        table = self._metadata['tables'].get(table_name)
+        if table is None:
+            raise ValueError('Table "{}" does not exist'.format(table_name))
+
+        return copy.deepcopy(table)
+
+    def get_tables(self):
+        """Get the list of table names.
+
+        Returns:
+            list:
+                table names.
+        """
+        return list(self._metadata['tables'].keys())
+
+    def get_field_meta(self, table_name, field_name):
+        """Get the metadata dict for a field of a table.
+
+        Args:
+            table_name (str):
+                Name of the table to which the field belongs.
+            field_name (str):
+                Name of the field to get data for.
+
+        Returns:
+            dict:
+                field metadata
+
+        Raises:
+            ValueError:
+                If the table or the field do not exist in this metadata.
+        """
+        field_meta = self.get_fields(table_name).get(field_name)
+        if field_meta is None:
+            raise ValueError(
+                'Table "{}" does not contain a field named "{}"'.format(table_name, field_name))
+
+        return copy.deepcopy(field_meta)
+
+    def get_fields(self, table_name):
+        """Get table fields metadata.
+
+        Args:
+            table_name (str):
+                Name of the table to get the fields from.
+
+        Returns:
+            dict:
+                Mapping of field names and their metadata dicts.
+
+        Raises:
+            ValueError:
+                If table does not exist in this metadata.
+        """
+        return self.get_table_meta(table_name)['fields']
+
+    def get_primary_key(self, table_name):
+        """Get the primary key name of the indicated table.
+
+        Args:
+            table_name (str):
+                Name of table for which to get the primary key field.
+
+        Returns:
+            str or None:
+                Primary key field name. ``None`` if the table has no primary key.
+
+        Raises:
+            ValueError:
+                If table does not exist in this metadata.
+        """
+        return self.get_table_meta(table_name).get('primary_key')
+
+    def get_foreign_key(self, parent, child):
+        """Get the name of the foreign key field in the child table.
+
+        Args:
+            parent (str):
+                Name of the parent table.
+            child (str):
+                Name of the child table.
+
+        Returns:
+            str:
+                Foreign key field name.
+
+        Raises:
+            ValueError:
+                If the relationship does not exist.
+        """
+        for name, field in self.get_fields(child).items():
+            ref = field.get('ref')
+            if ref and ref['table'] == parent:
+                return name
+
+        raise ValueError('{} is not parent of {}'.format(parent, child))
+
+    def load_table(self, table_name):
+        """Load table data.
+
+        Args:
+            table_name (str):
+                Name of the table to load.
+
+        Returns:
+            pandas.DataFrame:
+                DataFrame with the contents of the table.
+
+        Raises:
+            ValueError:
+                If table does not exist in this metadata.
+        """
+        LOGGER.info('Loading table %s', table_name)
+        table_meta = self.get_table_meta(table_name)
+        return _load_csv(self.root_path, table_meta)
+
+    def load_tables(self, tables=None):
+        """Get a dictionary with data from multiple tables.
+
+        If a ``tables`` list is given, only load the indicated tables.
+        Otherwise, load all the tables from this metadata.
+
+        Args:
+            tables (list):
+                List of table names. Defaults to ``None``.
+
+        Returns:
+            dict(str, pandas.DataFrame):
+                mapping of table names and their data loaded as ``pandas.DataFrame`` instances.
+        """
+        return {
+            table_name: self.load_table(table_name)
+            for table_name in tables or self.get_tables()
+        }
+
+    def get_dtypes(self, table_name, ids=False):
+        """Get a ``dict`` with the ``dtypes`` for each field of a given table.
+
+        Args:
+            table_name (str):
+                Table name for which to retrieve the ``dtypes``.
+            ids (bool):
+                Whether or not to include the id fields. Defaults to ``False``.
+
+        Returns:
+            dict:
+                Dictionary that contains the field names and data types from a table.
+
+        Raises:
+            ValueError:
+                If a field has an invalid type or subtype or if the table does not
+                exist in this metadata.
+        """
+        dtypes = dict()
+        table_meta = self.get_table_meta(table_name)
+        for name, field in table_meta['fields'].items():
+            field_type = field['type']
+            field_subtype = field.get('subtype')
+            dtype = self._DTYPES.get((field_type, field_subtype))
+            if not dtype:
+                raise MetadataError(
+                    'Invalid type and subtype combination for field {}: ({}, {})'.format(
+                        name, field_type, field_subtype)
+                )
+
+            if ids and field_type == 'id':
+                if (name != table_meta.get('primary_key')) and not field.get('ref'):
+                    for child_table in self.get_children(table_name):
+                        if name == self.get_foreign_key(table_name, child_table):
+                            break
+
+                    else:
+                        raise MetadataError(
+                            'id field `{}` is neither a primary nor a foreign key'.format(name))
+
+            if ids or (field_type != 'id'):
+                dtypes[name] = dtype
+
+        return dtypes
+
+    def _get_pii_fields(self, table_name):
+        """Get the ``pii_category`` for each field that contains PII.
+
+        Args:
+            table_name (str):
+                Table name for which to get the pii fields.
+
+        Returns:
+            dict:
+                pii field names and categories.
+        """
+        pii_fields = dict()
+        for name, field in self.get_table_meta(table_name)['fields'].items():
+            if field['type'] == 'categorical' and field.get('pii', False):
+                pii_fields[name] = field['pii_category']
+
+        return pii_fields
+
+    @staticmethod
+    def _get_transformers(dtypes, pii_fields):
+        """Create the transformer instances needed to process the given dtypes.
+
+        Temporary drop-in replacement of the ``HyperTransformer._analyze`` method,
+        to be used until RDT catches up.
+
+        Args:
+            dtypes (dict):
+                mapping of field names and dtypes.
+            pii_fields (dict):
+                mapping of pii field names and categories.
+
+        Returns:
+            dict:
+                mapping of field names and transformer instances.
+        """
+        transformers_dict = dict()
+        for name, dtype in dtypes.items():
+            dtype = np.dtype(dtype)
+            if dtype.kind == 'i':
+                transformer = transformers.NumericalTransformer(dtype=int)
+            elif dtype.kind == 'f':
+                transformer = transformers.NumericalTransformer(dtype=float)
+            elif dtype.kind == 'O':
+                anonymize = pii_fields.get(name)
+                transformer = transformers.CategoricalTransformer(anonymize=anonymize)
+            elif dtype.kind == 'b':
+                transformer = transformers.BooleanTransformer()
+            elif dtype.kind == 'M':
+                transformer = transformers.DatetimeTransformer()
+            else:
+                raise ValueError('Unsupported dtype: {}'.format(dtype))
+
+            LOGGER.info('Loading transformer %s for field %s',
+                        transformer.__class__.__name__, name)
+            transformers_dict[name] = transformer
+
+        return transformers_dict
+
+    def _load_hyper_transformer(self, table_name):
+        """Create and return a new ``rdt.HyperTransformer`` instance for a table.
+
+        First get the ``dtypes`` and ``pii fields`` from a given table, then use
+        those to build a transformer dictionary to be used by the ``HyperTransformer``.
+
+        Args:
+            table_name (str):
+                Table name for which to load the HyperTransformer.
+
+        Returns:
+            rdt.HyperTransformer:
+                Instance of ``rdt.HyperTransformer`` for the given table.
+        """
+        dtypes = self.get_dtypes(table_name)
+        pii_fields = self._get_pii_fields(table_name)
+        transformers_dict = self._get_transformers(dtypes, pii_fields)
+        return HyperTransformer(transformers=transformers_dict)
+
+    def transform(self, table_name, data):
+        """Transform data for a given table.
+
+        If the ``HyperTransformer`` for the table does not exist yet, it is
+        created and fitted first.
+
+        Args:
+            table_name (str):
+                Name of the table that is being transformed.
+ data (pandas.DataFrame): + Table data. + + Returns: + pandas.DataFrame: + Transformed data. + """ + hyper_transformer = self._hyper_transformers.get(table_name) + if hyper_transformer is None: + hyper_transformer = self._load_hyper_transformer(table_name) + fields = list(hyper_transformer.transformers.keys()) + hyper_transformer.fit(data[fields]) + self._hyper_transformers[table_name] = hyper_transformer + + hyper_transformer = self._hyper_transformers.get(table_name) + fields = list(hyper_transformer.transformers.keys()) + return hyper_transformer.transform(data[fields]) + + def reverse_transform(self, table_name, data): + """Reverse the transformed data for a given table. + + Args: + table_name (str): + Name of the table to reverse transform. + data (pandas.DataFrame): + Data to be reversed. + + Returns: + pandas.DataFrame + """ + hyper_transformer = self._hyper_transformers[table_name] + reversed_data = hyper_transformer.reverse_transform(data) + + for name, dtype in self.get_dtypes(table_name, ids=True).items(): + reversed_data[name] = reversed_data[name].dropna().astype(dtype) + + return reversed_data + + # ################### # + # Metadata Validation # + # ################### # + + def _validate_table(self, table_name, table_meta, table_data=None): + """Validate table metadata. + + Validate the type and subtype combination for each field in ``table_meta``. + If a field has type ``id``, validate that it either is the ``primary_key`` or + has a ``ref`` entry. + + If the table has ``primary_key``, make sure that the corresponding field exists + and its type is ``id``. + + If ``table_data`` is provided, also check that the list of columns corresponds + to the ones indicated in the metadata and that all the dtypes are valid. + + Args: + table_name (str): + Name of the table to validate. + table_meta (dict): + Metadata of the table to validate. + table_data (pandas.DataFrame): + If provided, make sure that the data matches the one described + on the metadata. + + Raises: + MetadataError: + If there is any error in the metadata or the data does not + match the metadata description. 
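+
+        Example:
+            A sketch of direct usage; in practice this is normally reached
+            through ``validate`` (the table name and dataframe below are
+            illustrative)::
+
+                meta = metadata.get_table_meta('users')
+                metadata._validate_table('users', meta, users_data)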
+        """
+        dtypes = self.get_dtypes(table_name, ids=True)
+
+        # Primary key field exists and its type is 'id'
+        primary_key = table_meta.get('primary_key')
+        if primary_key:
+            pk_field = table_meta['fields'].get(primary_key)
+
+            if not pk_field:
+                raise MetadataError('Primary key is not an existing field.')
+
+            if pk_field['type'] != 'id':
+                raise MetadataError('Primary key is not of type `id`.')
+
+        if table_data is not None:
+            for column in table_data:
+                try:
+                    dtype = dtypes.pop(column)
+                    table_data[column].dropna().astype(dtype)
+                except KeyError:
+                    message = 'Unexpected column in table `{}`: `{}`'.format(table_name, column)
+                    raise MetadataError(message) from None
+                except ValueError as ve:
+                    message = 'Invalid values found in column `{}` of table `{}`: `{}`'.format(
+                        column, table_name, ve)
+                    raise MetadataError(message) from None
+
+            # Any dtypes left at this point correspond to columns missing from the data
+            if dtypes:
+                raise MetadataError(
+                    'Missing columns on table {}: {}.'.format(table_name, list(dtypes.keys()))
+                )
+
+    def _validate_circular_relationships(self, parent, children=None):
+        """Validate that there is no circular relationship in the metadata."""
+        if children is None:
+            children = self.get_children(parent)
+
+        if parent in children:
+            raise MetadataError('Circular relationship found for table "{}"'.format(parent))
+
+        for child in children:
+            self._validate_circular_relationships(parent, self.get_children(child))
+
+    def validate(self, tables=None):
+        """Validate this metadata.
+
+        For each table in the metadata ``tables`` entry:
+
+        * Validate that the table metadata is correct.
+        * If ``tables`` are provided or they have been loaded, check
+          that all the metadata tables exist in the ``tables`` dictionary.
+        * Validate the type/subtype combination for each field and,
+          if a field of type ``id`` exists, check that it is either the
+          ``primary_key`` or has a ``ref`` entry.
+        * If a ``primary_key`` entry exists, check that it is an existing
+          field and that its type is ``id``.
+        * If ``tables`` are provided or they have been loaded, check that
+          all the data types for the table correspond to each column and
+          that all the data types exist on the table.
+        * Validate that there is no circular relationship in the metadata.
+        * Check that all the tables have at most one parent.
+
+        Args:
+            tables (bool, dict):
+                If a dict of tables is passed, validate that the columns and
+                dtypes match the metadata. If ``True`` is passed, load the
+                tables from the Metadata instead. If ``None``, omit the data
+                validation. Defaults to ``None``.
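+
+        Example:
+            A minimal usage sketch; ``True`` makes the metadata load the CSV
+            files it references before validating them::
+
+                metadata.validate(tables=True)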
+ """ + tables_meta = self._metadata.get('tables') + if not tables_meta: + raise MetadataError('"tables" entry not found in Metadata.') + + if tables and not isinstance(tables, dict): + tables = self.load_tables() + + for table_name, table_meta in tables_meta.items(): + if tables: + table = tables.get(table_name) + if table is None: + raise MetadataError('Table `{}` not found in tables'.format(table_name)) + + else: + table = None + + self._validate_table(table_name, table_meta, table) + self._validate_circular_relationships(table_name) + + def _check_field(self, table, field, exists=False): + """Validate the existance of the table and existance (or not) of field.""" + table_fields = self.get_fields(table) + if exists and (field not in table_fields): + raise ValueError('Field "{}" does not exist in table "{}"'.format(field, table)) + + if not exists and (field in table_fields): + raise ValueError('Field "{}" already exists in table "{}"'.format(field, table)) + + # ################# # + # Metadata Creation # + # ################# # + + def add_field(self, table, field, field_type, field_subtype=None, properties=None): + """Add a new field to the indicated table. + + Args: + table (str): + Table name to add the new field, it must exist. + field (str): + Field name to be added, it must not exist. + field_type (str): + Data type of field to be added. Required. + field_subtype (str): + Data subtype of field to be added. Optional. + Defaults to ``None``. + properties (dict): + Extra properties of field like: ref, format, min, max, etc. Optional. + Defaults to ``None``. + + Raises: + ValueError: + If the table does not exist or it already contains the field. + """ + self._check_field(table, field, exists=False) + + field_details = { + 'type': field_type + } + + if field_subtype: + field_details['subtype'] = field_subtype + + if properties: + field_details.update(properties) + + self._metadata['tables'][table]['fields'][field] = field_details + + @staticmethod + def _get_key_subtype(field_meta): + """Get the appropriate key subtype.""" + field_type = field_meta['type'] + + if field_type == 'categorical': + field_subtype = 'string' + + elif field_type in ('numerical', 'id'): + field_subtype = field_meta['subtype'] + if field_subtype not in ('integer', 'string'): + raise ValueError( + 'Invalid field "subtype" for key field: "{}"'.format(field_subtype) + ) + + else: + raise ValueError( + 'Invalid field "type" for key field: "{}"'.format(field_type) + ) + + return field_subtype + + def set_primary_key(self, table, field): + """Set the primary key field of the indicated table. + + The field must exist and either be an integer or categorical field. + + Args: + table (str): + Name of the table where the primary key will be set. + field (str): + Field to be used as the new primary key. + + Raises: + ValueError: + If the table or the field do not exist or if the field has an + invalid type or subtype. + """ + self._check_field(table, field, exists=True) + + field_meta = self.get_fields(table).get(field) + field_subtype = self._get_key_subtype(field_meta) + + table_meta = self._metadata['tables'][table] + table_meta['fields'][field] = { + 'type': 'id', + 'subtype': field_subtype + } + table_meta['primary_key'] = field + + def add_relationship(self, parent, child, parent_key=None, child_key=None): + """Add a new relationship between the parent and child tables. 
+
+        The relationship is created by adding a reference (``ref``) on the ``child_key``
+        field of the ``child`` table pointing at the ``parent_key`` field from the
+        ``parent`` table.
+
+        Args:
+            parent (str):
+                Name of the parent table.
+            child (str):
+                Name of the child table.
+            parent_key (str):
+                Field in the parent table through which the relationship is created.
+                If ``None``, use the parent primary key name.
+            child_key (str):
+                Field in the child table through which the relationship is created.
+                If ``None``, use the name of the parent key.
+
+        Raises:
+            ValueError:
+                If any of the following happens:
+                    * The parent or child tables do not exist.
+                    * The parent_key or child_key fields do not exist.
+                    * The child_key is already a foreign key.
+                    * The new relationship closes a relationship circle.
+        """
+        # Validate that the tables exist
+        self.get_table_meta(parent)
+        self.get_table_meta(child)
+
+        # Validate that the fields exist
+        if parent_key is None:
+            parent_key = self.get_primary_key(parent)
+            if not parent_key:
+                msg = 'Parent table "{}" has no primary key, so a `parent_key` must be given'
+                raise ValueError(msg.format(parent))
+
+        if child_key is None:
+            child_key = parent_key
+
+        parent_key_meta = copy.deepcopy(self.get_field_meta(parent, parent_key))
+        child_key_meta = copy.deepcopy(self.get_field_meta(child, child_key))
+
+        # Validate relationships
+        child_ref = child_key_meta.get('ref')
+        if child_ref:
+            raise ValueError(
+                'Field "{}.{}" already defines a relationship'.format(child, child_key))
+
+        grandchildren = self.get_children(child)
+        if grandchildren:
+            self._validate_circular_relationships(parent, grandchildren)
+
+        # Make sure that the parent key is an id
+        if parent_key_meta['type'] != 'id':
+            parent_key_meta['subtype'] = self._get_key_subtype(parent_key_meta)
+            parent_key_meta['type'] = 'id'
+
+        # Update the child key meta
+        child_key_meta['subtype'] = self._get_key_subtype(child_key_meta)
+        child_key_meta['type'] = 'id'
+        child_key_meta['ref'] = {
+            'table': parent,
+            'field': parent_key
+        }
+
+        # Make sure that key subtypes are the same
+        if child_key_meta['subtype'] != parent_key_meta['subtype']:
+            raise ValueError('Parent and Child key subtypes mismatch')
+
+        # Make a backup
+        metadata_backup = copy.deepcopy(self._metadata)
+
+        self._metadata['tables'][parent]['fields'][parent_key] = parent_key_meta
+        self._metadata['tables'][child]['fields'][child_key] = child_key_meta
+
+        # Re-analyze the relationships
+        self._analyze_relationships()
+
+        try:
+            self.validate()
+        except MetadataError:
+            self._metadata = metadata_backup
+            raise
+
+    def _get_field_details(self, data, fields):
+        """Get or build all the fields metadata.
+
+        Analyze a ``pandas.DataFrame`` to build a ``dict`` with the name of the column, and
+        their data type and subtype. If ``fields`` are provided, only those columns will be
+        analyzed.
+
+        Args:
+            data (pandas.DataFrame):
+                Table to be analyzed.
+            fields (set):
+                Set of field names or field specifications.
+
+        Returns:
+            dict:
+                Dict of valid fields.
+
+        Raises:
+            TypeError:
+                If a field specification is not a str or a dict.
+            ValueError:
+                If a column from the data analyzed is of an unsupported data type.
+        """
+        fields_metadata = dict()
+        for field in fields:
+            dtype = data[field].dtype
+            field_template = self._FIELD_TEMPLATES.get(dtype.kind)
+            if not field_template:
+                raise ValueError('Unsupported dtype {} in column {}'.format(dtype, field))
+
+            field_details = copy.deepcopy(field_template)
+            fields_metadata[field] = field_details
+
+        return fields_metadata
+
+    def add_table(self, name, data=None, fields=None, fields_metadata=None,
+                  primary_key=None, parent=None, parent_key=None, foreign_key=None):
+        """Add a new table to this metadata.
+
+        The ``fields`` list can be a mixture of field names, which will be built
+        automatically from the data, and dictionaries specifying the field details.
+        If a field needs to be analyzed, the data also has to be passed.
+
+        If ``parent`` is given, a relationship will be established between this table
+        and the specified parent.
+
+        Args:
+            name (str):
+                Name of the new table.
+            data (str or pandas.DataFrame):
+                Table to be analyzed or path to the csv file.
+                If it's a relative path, use ``root_path`` to find the file.
+                Only needed if the fields have to be built from the data.
+                Defaults to ``None``.
+            fields (list):
+                List of field names to build. If ``None`` is given, all the fields
+                found in the data will be used.
+                Defaults to ``None``.
+            fields_metadata (dict):
+                Metadata to be used when creating fields. This will overwrite the
+                metadata built from the fields found in data.
+                Defaults to ``None``.
+            primary_key (str):
+                Field name to add as primary key; it must not exist yet. Defaults to ``None``.
+            parent (str):
+                Name of the parent table to reference through a foreign key field.
+                Defaults to ``None``.
+            parent_key (str):
+                Name of the field from the ``parent`` table that is pointed at by the given
+                ``foreign_key``. Defaults to the ``parent`` primary key.
+            foreign_key (str):
+                Name of the field from the added table that forms a relationship with
+                the ``parent`` table. Defaults to the same name as ``parent_key``.
+
+        Raises:
+            ValueError:
+                If the table ``name`` already exists or ``data`` is not passed and
+                fields need to be built from it.
+        """
+        if name in self.get_tables():
+            raise ValueError('Table "{}" already exists.'.format(name))
+
+        path = None
+        if data is not None:
+            if isinstance(data, str):
+                path = data
+                if not os.path.isabs(data):
+                    data = os.path.join(self.root_path, data)
+
+                data = pd.read_csv(data)
+
+            fields = set(fields or data.columns)
+            if fields_metadata:
+                fields = fields - set(fields_metadata.keys())
+            else:
+                fields_metadata = dict()
+
+            fields_metadata.update(self._get_field_details(data, fields))
+
+        elif fields_metadata is None:
+            fields_metadata = dict()
+
+        table_metadata = {'fields': fields_metadata}
+        if path:
+            table_metadata['path'] = path
+
+        self._metadata['tables'][name] = table_metadata
+
+        try:
+            if primary_key:
+                self.set_primary_key(name, primary_key)
+
+            if parent:
+                self.add_relationship(parent, name, parent_key, foreign_key)
+
+        except ValueError:
+            # Cleanup: remove the partially added table before re-raising
+            del self._metadata['tables'][name]
+            raise
+
+    # ###################### #
+    # Metadata Serialization #
+    # ###################### #
+
+    def to_dict(self):
+        """Get a dict representation of this metadata.
+
+        Returns:
+            dict:
+                dict representation of this metadata.
+        """
+        return copy.deepcopy(self._metadata)
+
+    def to_json(self, path):
+        """Dump this metadata into a JSON file.
+
+        Args:
+            path (str):
+                Path of the JSON file where this metadata will be stored.
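+
+        Example:
+            A short sketch; the output path is illustrative::
+
+                metadata.to_json('metadata_v2.json')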
+ """ + with open(path, 'w') as out_file: + json.dump(self._metadata, out_file, indent=4) + + @staticmethod + def _get_graphviz_extension(path): + if path: + path_splitted = path.split('.') + if len(path_splitted) == 1: + raise ValueError('Path without graphviz extansion.') + + graphviz_extension = path_splitted[-1] + + if graphviz_extension not in graphviz.backend.FORMATS: + raise ValueError( + '"{}" not a valid graphviz extension format.'.format(graphviz_extension) + ) + + return '.'.join(path_splitted[:-1]), graphviz_extension + + return None, None + + def _visualize_add_nodes(self, plot): + """Add nodes into a `graphviz.Digraph`. + + Each node represent a metadata table. + + Args: + plot (graphviz.Digraph) + """ + for table in self.get_tables(): + # Append table fields + fields = [] + + for name, value in self.get_fields(table).items(): + if value.get('subtype') is not None: + fields.append('{} : {} - {}'.format(name, value['type'], value['subtype'])) + + else: + fields.append('{} : {}'.format(name, value['type'])) + + fields = r'\l'.join(fields) + + # Append table extra information + extras = [] + + primary_key = self.get_primary_key(table) + if primary_key is not None: + extras.append('Primary key: {}'.format(primary_key)) + + parents = self.get_parents(table) + for parent in parents: + foreign_key = self.get_foreign_key(parent, table) + extras.append('Foreign key ({}): {}'.format(parent, foreign_key)) + + path = self.get_table_meta(table).get('path') + if path is not None: + extras.append('Data path: {}'.format(path)) + + extras = r'\l'.join(extras) + + # Add table node + title = r'{%s|%s\l|%s\l}' % (table, fields, extras) + plot.node(table, label=title) + + def _visualize_add_edges(self, plot): + """Add edges into a `graphviz.Digraph`. + + Each edge represents a relationship between two metadata tables. + + Args: + plot (graphviz.Digraph) + """ + for table in self.get_tables(): + for parent in list(self.get_parents(table)): + plot.edge( + parent, + table, + label=' {}.{} -> {}.{}'.format( + table, self.get_foreign_key(parent, table), + parent, self.get_primary_key(parent) + ), + arrowhead='crow' + ) + + def visualize(self, path=None): + """Plot metadata usign graphviz. + + Try to generate a plot using graphviz. + If a ``path`` is provided save the output into a file. + + Args: + path (str): + Output file path to save the plot, it requires a graphviz + supported extension. If ``None`` do not save the plot. + Defaults to ``None``. 
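+
+        Example:
+            A minimal sketch; any graphviz-supported extension could be used
+            in place of the illustrative path below::
+
+                metadata.visualize('metadata.png')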
+ """ + filename, graphviz_extension = self._get_graphviz_extension(path) + plot = graphviz.Digraph( + 'Metadata', + format=graphviz_extension, + node_attr={ + "shape": "Mrecord", + "fillcolor": "lightgoldenrod1", + "style": "filled" + }, + ) + + self._visualize_add_nodes(plot) + self._visualize_add_edges(plot) + + if filename: + plot.render(filename=filename, cleanup=True, format=graphviz_extension) + else: + return plot + + def __str__(self): + tables = self.get_tables() + relationships = [ + ' {}.{} -> {}.{}'.format( + table, self.get_foreign_key(parent, table), + parent, self.get_primary_key(parent) + ) + for table in tables + for parent in list(self.get_parents(table)) + ] + + return ( + "Metadata\n" + " root_path: {}\n" + " tables: {}\n" + " relationships:\n" + "{}" + ).format( + os.path.abspath(self.root_path), + tables, + '\n'.join(relationships) + ) diff --git a/sdv/tabular/__init__.py b/sdv/tabular/__init__.py index e69de29bb..752184c61 100644 --- a/sdv/tabular/__init__.py +++ b/sdv/tabular/__init__.py @@ -0,0 +1,8 @@ +from sdv.tabular.copulas import GaussianCopula +from sdv.tabular.ctgan import CTGAN + + +__all__ = [ + 'GaussianCopula', + 'CTGAN' +] diff --git a/sdv/tabular/base.py b/sdv/tabular/base.py index 1bdf4b854..70b5aeb7a 100644 --- a/sdv/tabular/base.py +++ b/sdv/tabular/base.py @@ -1,3 +1,5 @@ +import pickle + from sdv.metadata import Table @@ -190,3 +192,28 @@ def from_parameters(cls): using a simple dictionary. """ raise NotImplementedError() + + def save(self, path): + """Save this model instance to the given path using pickle. + + Args: + path (str): + Path where the SDV instance will be serialized. + """ + with open(path, 'wb') as output: + pickle.dump(self, output) + + @classmethod + def load(cls, path): + """Load a TabularModel instance from a given path. + + Args: + path (str): + Path from which to load the instance. + + Returns: + TabularModel: + The loaded tabular model. + """ + with open(path, 'rb') as f: + return pickle.load(f) From bf1624fdc55d30cf9109b3268e3cc421bd26cc56 Mon Sep 17 00:00:00 2001 From: Carles Sala Date: Thu, 9 Jul 2020 21:09:25 +0200 Subject: [PATCH 25/33] Cleanup examples and add tutorials --- examples/0. Quickstart - README.ipynb | 414 ------ ...uickstart - Single Table - In Memory.ipynb | 430 ------ .... Quickstart - Single Table - Census.ipynb | 562 -------- .../3. Quickstart - Multitable - Files.ipynb | 653 --------- examples/4. Anonymization.ipynb | 398 ------ examples/6. Metadata Validation.ipynb | 239 ---- .../Aibnb Dataset Subsampling.ipynb | 206 --- .../Aibnb Example.ipynb | 730 ---------- .../metadata.json | 116 -- examples/demo_metadata.json | 74 - examples/quickstart/customers.csv | 8 - examples/quickstart/metadata.json | 96 -- examples/quickstart/order_items.csv | 50 - examples/quickstart/orders.csv | 11 - tutorials/01_Quickstart.ipynb | 838 ++++++++++++ tutorials/02_Single_Table_Modeling.ipynb | 1204 +++++++++++++++++ .../03_Relational_Data_Modeling.ipynb | 854 ++++++------ .../04_Working_with_Metadata.ipynb | 0 18 files changed, 2472 insertions(+), 4411 deletions(-) delete mode 100644 examples/0. Quickstart - README.ipynb delete mode 100644 examples/1. Quickstart - Single Table - In Memory.ipynb delete mode 100644 examples/2. Quickstart - Single Table - Census.ipynb delete mode 100644 examples/3. Quickstart - Multitable - Files.ipynb delete mode 100644 examples/4. Anonymization.ipynb delete mode 100644 examples/6. 
Metadata Validation.ipynb delete mode 100644 examples/airbnb-recruiting-new-user-bookings/Aibnb Dataset Subsampling.ipynb delete mode 100644 examples/airbnb-recruiting-new-user-bookings/Aibnb Example.ipynb delete mode 100644 examples/airbnb-recruiting-new-user-bookings/metadata.json delete mode 100644 examples/demo_metadata.json delete mode 100644 examples/quickstart/customers.csv delete mode 100644 examples/quickstart/metadata.json delete mode 100644 examples/quickstart/order_items.csv delete mode 100644 examples/quickstart/orders.csv create mode 100644 tutorials/01_Quickstart.ipynb create mode 100644 tutorials/02_Single_Table_Modeling.ipynb rename examples/Demo - Walmart.ipynb => tutorials/03_Relational_Data_Modeling.ipynb (55%) rename examples/5. Generate Metadata from Dataframes.ipynb => tutorials/04_Working_with_Metadata.ipynb (100%) diff --git a/examples/0. Quickstart - README.ipynb b/examples/0. Quickstart - README.ipynb deleted file mode 100644 index f6f2fecdc..000000000 --- a/examples/0. Quickstart - README.ipynb +++ /dev/null @@ -1,414 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "from sdv import load_demo" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "metadata, tables = load_demo(metadata=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'tables': {'users': {'primary_key': 'user_id',\n", - " 'fields': {'user_id': {'type': 'id', 'subtype': 'integer'},\n", - " 'country': {'type': 'categorical'},\n", - " 'gender': {'type': 'categorical'},\n", - " 'age': {'type': 'numerical', 'subtype': 'integer'}}},\n", - " 'sessions': {'primary_key': 'session_id',\n", - " 'fields': {'session_id': {'type': 'id', 'subtype': 'integer'},\n", - " 'user_id': {'ref': {'field': 'user_id', 'table': 'users'},\n", - " 'type': 'id',\n", - " 'subtype': 'integer'},\n", - " 'device': {'type': 'categorical'},\n", - " 'os': {'type': 'categorical'}}},\n", - " 'transactions': {'primary_key': 'transaction_id',\n", - " 'fields': {'transaction_id': {'type': 'id', 'subtype': 'integer'},\n", - " 'session_id': {'ref': {'field': 'session_id', 'table': 'sessions'},\n", - " 'type': 'id',\n", - " 'subtype': 'integer'},\n", - " 'timestamp': {'type': 'datetime', 'format': '%Y-%m-%d'},\n", - " 'amount': {'type': 'numerical', 'subtype': 'float'},\n", - " 'approved': {'type': 'boolean'}}}}}" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "metadata.to_dict()" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'users': user_id country gender age\n", - " 0 0 USA M 34\n", - " 1 1 UK F 23\n", - " 2 2 ES None 44\n", - " 3 3 UK M 22\n", - " 4 4 USA F 54\n", - " 5 5 DE M 57\n", - " 6 6 BG F 45\n", - " 7 7 ES None 41\n", - " 8 8 FR F 23\n", - " 9 9 UK None 30,\n", - " 'sessions': session_id user_id device os\n", - " 0 0 0 mobile android\n", - " 1 1 1 tablet ios\n", - " 2 2 1 tablet android\n", - " 3 3 2 mobile android\n", - " 4 4 4 mobile ios\n", - " 5 5 5 mobile android\n", - " 6 6 6 mobile ios\n", - " 7 7 6 tablet ios\n", - " 8 8 6 mobile ios\n", - " 9 9 8 tablet ios,\n", - " 'transactions': transaction_id session_id timestamp amount approved\n", - " 0 0 0 2019-01-01 12:34:32 100.0 True\n", - " 1 1 0 2019-01-01 12:42:21 55.3 True\n", - " 2 2 1 2019-01-07 17:23:11 
79.5 True\n", - " 3 3 3 2019-01-10 11:08:57 112.1 False\n", - " 4 4 5 2019-01-10 21:54:08 110.0 False\n", - " 5 5 5 2019-01-11 11:21:20 76.3 True\n", - " 6 6 7 2019-01-22 14:44:10 89.5 True\n", - " 7 7 8 2019-01-23 10:14:09 132.1 False\n", - " 8 8 9 2019-01-27 16:09:17 68.0 True\n", - " 9 9 9 2019-01-29 12:10:48 99.9 True}" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tables" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "Metadata\n", - "\n", - "\n", - "\n", - "users\n", - "\n", - "users\n", - "\n", - "user_id : id - integer\n", - "country : categorical\n", - "gender : categorical\n", - "age : numerical - integer\n", - "\n", - "Primary key: user_id\n", - "\n", - "\n", - "\n", - "sessions\n", - "\n", - "sessions\n", - "\n", - "session_id : id - integer\n", - "user_id : id - integer\n", - "device : categorical\n", - "os : categorical\n", - "\n", - "Primary key: session_id\n", - "Foreign key (users): user_id\n", - "\n", - "\n", - "\n", - "users->sessions\n", - "\n", - "\n", - "   sessions.user_id -> users.user_id\n", - "\n", - "\n", - "\n", - "transactions\n", - "\n", - "transactions\n", - "\n", - "transaction_id : id - integer\n", - "session_id : id - integer\n", - "timestamp : datetime\n", - "amount : numerical - float\n", - "approved : boolean\n", - "\n", - "Primary key: transaction_id\n", - "Foreign key (sessions): session_id\n", - "\n", - "\n", - "\n", - "sessions->transactions\n", - "\n", - "\n", - "   transactions.session_id -> sessions.session_id\n", - "\n", - "\n", - "\n" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "metadata.visualize()" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2020-07-02 14:24:31,338 - INFO - modeler - Modeling users\n", - "2020-07-02 14:24:31,339 - INFO - metadata - Loading transformer CategoricalTransformer for field country\n", - "2020-07-02 14:24:31,340 - INFO - metadata - Loading transformer CategoricalTransformer for field gender\n", - "2020-07-02 14:24:31,341 - INFO - metadata - Loading transformer NumericalTransformer for field age\n", - "2020-07-02 14:24:31,356 - INFO - modeler - Modeling sessions\n", - "2020-07-02 14:24:31,357 - INFO - metadata - Loading transformer CategoricalTransformer for field device\n", - "2020-07-02 14:24:31,357 - INFO - metadata - Loading transformer CategoricalTransformer for field os\n", - "2020-07-02 14:24:31,371 - INFO - modeler - Modeling transactions\n", - "2020-07-02 14:24:31,372 - INFO - metadata - Loading transformer DatetimeTransformer for field timestamp\n", - "2020-07-02 14:24:31,373 - INFO - metadata - Loading transformer NumericalTransformer for field amount\n", - "2020-07-02 14:24:31,373 - INFO - metadata - Loading transformer BooleanTransformer for field approved\n", - "2020-07-02 14:24:31,386 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:24:31,396 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/core/fromnumeric.py:56: FutureWarning: Series.nonzero() is deprecated and will be removed in a future version.Use 
Series.to_numpy().nonzero() instead\n", - " return getattr(obj, method)(*args, **kwds)\n", - "2020-07-02 14:24:31,404 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/pandas/core/frame.py:7143: RuntimeWarning: Degrees of freedom <= 0 for slice\n", - " baseCov = np.cov(mat.T)\n", - "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/lib/function_base.py:2451: RuntimeWarning: divide by zero encountered in true_divide\n", - " c *= np.true_divide(1, fact)\n", - "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/lib/function_base.py:2451: RuntimeWarning: invalid value encountered in multiply\n", - " c *= np.true_divide(1, fact)\n", - "2020-07-02 14:24:31,413 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:24:31,418 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:24:31,425 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:24:31,431 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:24:31,435 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:24:31,447 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:24:31,471 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:24:31,485 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:24:31,502 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:24:31,516 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:24:31,530 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:24:31,544 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:24:31,563 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:24:31,629 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:24:31,734 - INFO - modeler - Modeling Complete\n" - ] - } - ], - "source": [ - "from sdv import SDV\n", - "\n", - "sdv = SDV()\n", - "sdv.fit(metadata, tables)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "data": { - "text/plain": [ - "{'users': user_id country gender age\n", - " 0 0 ES F 50\n", - " 1 1 FR NaN 61\n", - " 2 2 UK F 19\n", - " 3 3 USA F 37\n", - " 4 4 UK M 14\n", - " 5 5 USA M 20\n", - " 6 6 FR F 41\n", - " 7 7 UK NaN 15\n", - " 8 8 USA F 17\n", - " 9 9 UK M 27,\n", - " 'sessions': session_id user_id device os\n", - " 0 0 5 mobile android\n", - " 1 1 3 tablet ios\n", - " 2 2 5 mobile ios\n", - " 3 3 2 mobile ios\n", - " 4 4 1 mobile ios\n", - " 5 5 8 mobile ios\n", - " 6 6 3 mobile ios\n", - " 7 7 2 tablet ios\n", - " 8 8 8 tablet ios,\n", - " 'transactions': transaction_id session_id timestamp amount \\\n", - " 0 0 1 2019-01-05 19:44:17.109082880 108.520480 \n", - " 1 1 0 2019-01-05 19:44:17.109082880 108.520446 \n", - " 2 2 4 2019-01-05 19:44:17.109082880 108.520456 \n", - " 3 3 4 2019-01-05 19:44:17.109082880 108.520463 \n", - " 4 
4 0 2019-01-05 19:44:17.109082880 108.520451 \n", - " 5 5 4 2019-01-14 10:59:23.654829312 -293.744075 \n", - " 6 6 0 2019-01-08 21:06:22.124677120 -157.555578 \n", - " 7 7 4 2019-01-05 19:44:17.109082880 108.520405 \n", - " 8 8 4 2019-01-20 05:49:39.108628736 1436.388992 \n", - " 9 9 0 2019-01-13 01:21:03.496772608 -688.014399 \n", - " 10 10 5 2019-01-15 19:14:13.132234496 92.573682 \n", - " 11 11 0 2019-01-15 19:13:47.834800128 88.733610 \n", - " 12 12 6 2019-01-15 19:13:25.752608512 87.612329 \n", - " 13 13 6 2019-01-15 19:14:06.131340544 91.149239 \n", - " 14 14 4 2019-01-05 19:44:17.109082880 108.520315 \n", - " 15 15 4 2019-01-10 14:54:25.195068416 -405.271997 \n", - " 16 16 0 2019-01-06 11:03:17.424094208 321.267007 \n", - " 17 17 1 2019-01-15 19:13:30.394386944 89.473644 \n", - " 18 18 5 2019-01-15 19:13:23.178177536 87.084668 \n", - " 19 19 5 2019-01-15 19:13:53.814526976 87.328847 \n", - " 20 20 6 2019-01-15 19:13:47.778595840 87.698023 \n", - " 21 21 0 2019-01-05 19:44:17.109082880 108.520444 \n", - " 22 22 4 2019-01-03 13:01:55.231368192 -171.151299 \n", - " 23 23 4 2019-01-10 03:03:38.631104256 -711.512501 \n", - " 24 24 5 2019-01-15 19:13:39.923056128 92.779222 \n", - " 25 25 5 2019-01-15 19:13:54.239550976 87.046705 \n", - " 26 26 1 2019-01-15 19:13:51.475357952 89.768894 \n", - " 27 27 6 2019-01-15 19:13:38.378663168 85.864682 \n", - " 28 28 7 2019-01-20 22:53:01.679367168 84.550410 \n", - " 29 29 7 2019-01-20 22:53:01.672930560 84.739006 \n", - " 30 30 1 2019-01-20 22:53:01.444472832 84.790641 \n", - " 31 31 8 2019-01-20 22:53:01.404470784 84.595276 \n", - " 32 32 4 2019-01-05 19:44:17.109082880 108.520411 \n", - " 33 33 4 2019-01-15 17:22:02.305563904 -602.769803 \n", - " 34 34 1 2019-01-15 04:01:37.883063808 -493.475733 \n", - " 35 35 0 2019-01-15 19:13:20.536071936 82.887047 \n", - " 36 36 5 2019-01-15 19:13:50.685021696 90.577240 \n", - " 37 37 6 2019-01-15 19:14:05.128847616 91.543630 \n", - " 38 38 6 2019-01-15 19:13:10.650768896 90.478404 \n", - " 39 39 7 2019-01-20 22:53:01.680873472 84.876897 \n", - " 40 40 7 2019-01-20 22:53:01.735595520 85.080808 \n", - " 41 41 0 2019-01-20 22:53:01.332932352 84.584933 \n", - " 42 42 8 2019-01-20 22:53:01.451664640 84.779657 \n", - " \n", - " approved \n", - " 0 False \n", - " 1 False \n", - " 2 False \n", - " 3 False \n", - " 4 False \n", - " 5 True \n", - " 6 True \n", - " 7 False \n", - " 8 True \n", - " 9 True \n", - " 10 True \n", - " 11 True \n", - " 12 True \n", - " 13 True \n", - " 14 False \n", - " 15 True \n", - " 16 True \n", - " 17 True \n", - " 18 True \n", - " 19 True \n", - " 20 True \n", - " 21 False \n", - " 22 True \n", - " 23 True \n", - " 24 True \n", - " 25 True \n", - " 26 True \n", - " 27 True \n", - " 28 True \n", - " 29 True \n", - " 30 True \n", - " 31 True \n", - " 32 False \n", - " 33 True \n", - " 34 True \n", - " 35 True \n", - " 36 True \n", - " 37 True \n", - " 38 True \n", - " 39 True \n", - " 40 True \n", - " 41 True \n", - " 42 True }" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "sdv.sample_all(10)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git 
a/examples/1. Quickstart - Single Table - In Memory.ipynb b/examples/1. Quickstart - Single Table - In Memory.ipynb deleted file mode 100644 index b9f48bbd2..000000000 --- a/examples/1. Quickstart - Single Table - In Memory.ipynb +++ /dev/null @@ -1,430 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import pandas as pd\n", - "import numpy as np\n", - "\n", - "data = pd.DataFrame({\n", - " 'index': [1, 2, 3, 4, 5, 6, 7, 8],\n", - " 'integer': [1, None, 1, 2, 1, 2, 3, 2],\n", - " 'float': [0.1, None, 0.1, 0.2, 0.1, 0.2, 0.3, 0.1],\n", - " 'categorical': ['a', 'b', 'a', 'b', 'a', None, 'c', None],\n", - " 'bool': [False, True, False, True, False, False, False, None],\n", - " 'nullable': [1, None, 3, None, 5, None, 7, None],\n", - " 'datetime': [\n", - " '2010-01-01', '2010-02-01', '2010-01-01', '2010-02-01',\n", - " '2010-01-01', '2010-02-01', '2010-03-01', None\n", - " ]\n", - "})\n", - "data['datetime'] = pd.to_datetime(data['datetime'])\n", - "\n", - "tables = {\n", - " 'data': data\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
" - ], - "text/plain": [ - " index integer float categorical bool nullable datetime\n", - "0 1 1.0 0.1 a False 1.0 2010-01-01\n", - "1 2 NaN NaN b True NaN 2010-02-01\n", - "2 3 1.0 0.1 a False 3.0 2010-01-01\n", - "3 4 2.0 0.2 b True NaN 2010-02-01\n", - "4 5 1.0 0.1 a False 5.0 2010-01-01\n", - "5 6 2.0 0.2 None False NaN 2010-02-01\n", - "6 7 3.0 0.3 c False 7.0 2010-03-01\n", - "7 8 2.0 0.1 None None NaN NaT" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "data" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "metadata = {\n", - " \"tables\": [\n", - " {\n", - " \"fields\": [\n", - " {\n", - " \"name\": \"index\",\n", - " \"type\": \"id\"\n", - " },\n", - " {\n", - " \"name\": \"integer\",\n", - " \"type\": \"numerical\",\n", - " \"subtype\": \"integer\",\n", - " },\n", - " {\n", - " \"name\": \"float\",\n", - " \"type\": \"numerical\",\n", - " \"subtype\": \"float\",\n", - " },\n", - " {\n", - " \"name\": \"categorical\",\n", - " \"type\": \"categorical\",\n", - " \"pii\": False,\n", - " \"pii_category\": \"email\"\n", - " },\n", - " {\n", - " \"name\": \"bool\",\n", - " \"type\": \"boolean\",\n", - " },\n", - " {\n", - " \"name\": \"nullable\",\n", - " \"type\": \"numerical\",\n", - " \"subtype\": \"float\",\n", - " },\n", - " {\n", - " \"name\": \"datetime\",\n", - " \"type\": \"datetime\",\n", - " \"format\": \"%Y-%m-%d\"\n", - " },\n", - " ],\n", - " \"name\": \"data\",\n", - " \"primary_key\": \"index\"\n", - " }\n", - " ]\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2020-07-02 14:25:11,511 - INFO - modeler - Modeling data\n", - "2020-07-02 14:25:11,512 - INFO - metadata - Loading transformer NumericalTransformer for field integer\n", - "2020-07-02 14:25:11,512 - INFO - metadata - Loading transformer NumericalTransformer for field float\n", - "2020-07-02 14:25:11,513 - INFO - metadata - Loading transformer CategoricalTransformer for field categorical\n", - "2020-07-02 14:25:11,513 - INFO - metadata - Loading transformer BooleanTransformer for field bool\n", - "2020-07-02 14:25:11,513 - INFO - metadata - Loading transformer NumericalTransformer for field nullable\n", - "2020-07-02 14:25:11,514 - INFO - metadata - Loading transformer DatetimeTransformer for field datetime\n", - "2020-07-02 14:25:11,551 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:11,564 - INFO - modeler - Modeling Complete\n" - ] - } - ], - "source": [ - "from sdv import SDV\n", - "\n", - "sdv = SDV()\n", - "sdv.fit(metadata, tables={'data': data})" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "data": { - "text/html": [ - "
" - ], - "text/plain": [ - " index integer float categorical bool nullable \\\n", - "0 0 NaN NaN b False NaN \n", - "1 1 NaN NaN NaN True NaN \n", - "2 2 3.0 0.304326 c False NaN \n", - "3 3 2.0 0.174580 a True NaN \n", - "4 4 3.0 0.208637 a NaN NaN \n", - "5 5 1.0 0.026796 b NaN NaN \n", - "6 6 2.0 0.166949 NaN False NaN \n", - "7 7 1.0 0.086972 b False NaN \n", - "\n", - " datetime \n", - "0 2010-01-21 11:39:29.986688000 \n", - "1 2010-02-17 13:51:34.408565760 \n", - "2 2010-02-25 12:07:19.982103552 \n", - "3 2010-01-21 04:04:59.336169472 \n", - "4 NaT \n", - "5 NaT \n", - "6 2010-01-28 23:23:34.873413888 \n", - "7 2010-01-08 09:44:47.101891840 " - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "samples = sdv.sample_all()\n", - "samples['data']" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/2. Quickstart - Single Table - Census.ipynb b/examples/2. Quickstart - Single Table - Census.ipynb deleted file mode 100644 index 6857b7827..000000000 --- a/examples/2. Quickstart - Single Table - Census.ipynb +++ /dev/null @@ -1,562 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import warnings\n", - "warnings.filterwarnings('ignore')\n", - "\n", - "import pandas as pd" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data'\n", - "columns = [\n", - " 'age',\n", - " 'workclass',\n", - " 'fnlwgt',\n", - " 'education',\n", - " 'education-num',\n", - " 'marital-status',\n", - " 'occupation',\n", - " 'relationship',\n", - " 'race',\n", - " 'sex',\n", - " 'capital-gain',\n", - " 'capital-loss',\n", - " 'hours-per-week',\n", - " 'native-country',\n", - " 'income'\n", - "]" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
" - ], - "text/plain": [ - " age workclass fnlwgt education education-num \\\n", - "0 39 State-gov 77516 Bachelors 13 \n", - "1 50 Self-emp-not-inc 83311 Bachelors 13 \n", - "2 38 Private 215646 HS-grad 9 \n", - "3 53 Private 234721 11th 7 \n", - "4 28 Private 338409 Bachelors 13 \n", - "\n", - " marital-status occupation relationship race sex \\\n", - "0 Never-married Adm-clerical Not-in-family White Male \n", - "1 Married-civ-spouse Exec-managerial Husband White Male \n", - "2 Divorced Handlers-cleaners Not-in-family White Male \n", - "3 Married-civ-spouse Handlers-cleaners Husband Black Male \n", - "4 Married-civ-spouse Prof-specialty Wife Black Female \n", - "\n", - " capital-gain capital-loss hours-per-week native-country income \n", - "0 2174 0 40 United-States <=50K \n", - "1 0 0 13 United-States <=50K \n", - "2 0 0 40 United-States <=50K \n", - "3 0 0 40 United-States <=50K \n", - "4 0 0 40 Cuba <=50K " - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df = pd.read_csv(url, names=columns)\n", - "df.head()" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "tables = {\n", - " 'census': df\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "metadata = {\n", - " \"tables\": [\n", - " {\n", - " \"fields\": [\n", - " {\n", - " \"name\": \"age\",\n", - " \"type\": \"numerical\",\n", - " \"subtype\": \"integer\",\n", - " },\n", - " {\n", - " \"name\": \"workclass\",\n", - " \"type\": \"categorical\",\n", - " },\n", - " {\n", - " \"name\": \"fnlwgt\",\n", - " \"type\": \"numerical\",\n", - " \"subtype\": \"integer\",\n", - " },\n", - " {\n", - " \"name\": \"education\",\n", - " \"type\": \"categorical\",\n", - " },\n", - " {\n", - " \"name\": \"education-num\",\n", - " \"type\": \"numerical\",\n", - " \"subtype\": \"integer\",\n", - " },\n", - " {\n", - " \"name\": \"marital-status\",\n", - " \"type\": \"categorical\",\n", - " },\n", - " {\n", - " \"name\": \"occupation\",\n", - " \"type\": \"categorical\",\n", - " },\n", - " {\n", - " \"name\": \"relationship\",\n", - " \"type\": \"categorical\",\n", - " },\n", - " {\n", - " \"name\": \"race\",\n", - " \"type\": \"categorical\",\n", - " },\n", - " {\n", - " \"name\": \"sex\",\n", - " \"type\": \"categorical\",\n", - " },\n", - " {\n", - " \"name\": \"capital-gain\",\n", - " \"type\": \"numerical\",\n", - " \"subtype\": \"integer\",\n", - " },\n", - " {\n", - " \"name\": \"capital-loss\",\n", - " \"type\": \"numerical\",\n", - " \"subtype\": \"integer\",\n", - " },\n", - " {\n", - " \"name\": \"hours-per-week\",\n", - " \"type\": \"numerical\",\n", - " \"subtype\": \"integer\",\n", - " },\n", - " {\n", - " \"name\": \"native-country\",\n", - " \"type\": \"categorical\",\n", - " },\n", - " {\n", - " \"name\": \"income\",\n", - " \"type\": \"categorical\",\n", - " }\n", - " ],\n", - " \"name\": \"census\",\n", - " }\n", - " ]\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2020-07-02 14:25:31,656 - INFO - modeler - Modeling census\n", - "2020-07-02 14:25:31,657 - INFO - metadata - Loading transformer NumericalTransformer for field age\n", - "2020-07-02 14:25:31,657 - INFO - metadata - Loading transformer CategoricalTransformer for field workclass\n", - "2020-07-02 14:25:31,658 - INFO - metadata - Loading transformer 
NumericalTransformer for field fnlwgt\n", - "2020-07-02 14:25:31,658 - INFO - metadata - Loading transformer CategoricalTransformer for field education\n", - "2020-07-02 14:25:31,658 - INFO - metadata - Loading transformer NumericalTransformer for field education-num\n", - "2020-07-02 14:25:31,659 - INFO - metadata - Loading transformer CategoricalTransformer for field marital-status\n", - "2020-07-02 14:25:31,659 - INFO - metadata - Loading transformer CategoricalTransformer for field occupation\n", - "2020-07-02 14:25:31,659 - INFO - metadata - Loading transformer CategoricalTransformer for field relationship\n", - "2020-07-02 14:25:31,660 - INFO - metadata - Loading transformer CategoricalTransformer for field race\n", - "2020-07-02 14:25:31,660 - INFO - metadata - Loading transformer CategoricalTransformer for field sex\n", - "2020-07-02 14:25:31,661 - INFO - metadata - Loading transformer NumericalTransformer for field capital-gain\n", - "2020-07-02 14:25:31,661 - INFO - metadata - Loading transformer NumericalTransformer for field capital-loss\n", - "2020-07-02 14:25:31,662 - INFO - metadata - Loading transformer NumericalTransformer for field hours-per-week\n", - "2020-07-02 14:25:31,662 - INFO - metadata - Loading transformer CategoricalTransformer for field native-country\n", - "2020-07-02 14:25:31,663 - INFO - metadata - Loading transformer CategoricalTransformer for field income\n", - "2020-07-02 14:25:31,831 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:31,928 - INFO - modeler - Modeling Complete\n" - ] - } - ], - "source": [ - "from sdv import SDV\n", - "\n", - "sdv = SDV()\n", - "sdv.fit(metadata, tables)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
" - ], - "text/plain": [ - " age workclass fnlwgt education education-num marital-status \\\n", - "0 35 Private 468185 HS-grad 9 Married-civ-spouse \n", - "1 47 Private 104447 HS-grad 7 Married-civ-spouse \n", - "2 20 Private 231391 HS-grad 11 Never-married \n", - "3 35 Private 223275 Masters 12 Married-civ-spouse \n", - "4 26 Private -11408 HS-grad 8 Married-civ-spouse \n", - "\n", - " occupation relationship race sex capital-gain \\\n", - "0 Sales Husband White Male 2039 \n", - "1 Adm-clerical Husband White Male -703 \n", - "2 Exec-managerial Not-in-family White Male 654 \n", - "3 Sales Not-in-family White Male -392 \n", - "4 Machine-op-inspct Husband White Male 5799 \n", - "\n", - " capital-loss hours-per-week native-country income \n", - "0 436 48 United-States <=50K \n", - "1 81 53 United-States <=50K \n", - "2 -8 57 United-States <=50K \n", - "3 -178 71 United-States <=50K \n", - "4 13 47 United-States <=50K " - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "sampled = sdv.sample('census', num_rows=len(df))\n", - "sampled['census'].head()" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "data": { - "text/plain": [ - "-43.151506729008716" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from sdv.evaluation import evaluate\n", - "\n", - "samples = sdv.sample_all(len(tables['census']))\n", - "\n", - "evaluate(samples, real=tables, metadata=sdv.metadata)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/3. Quickstart - Multitable - Files.ipynb b/examples/3. Quickstart - Multitable - Files.ipynb deleted file mode 100644 index d00a5989e..000000000 --- a/examples/3. 
Quickstart - Multitable - Files.ipynb +++ /dev/null @@ -1,653 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2020-07-02 14:25:44,066 - INFO - modeler - Modeling customers\n", - "2020-07-02 14:25:44,067 - INFO - metadata - Loading table customers\n", - "2020-07-02 14:25:44,074 - INFO - metadata - Loading transformer CategoricalTransformer for field cust_postal_code\n", - "2020-07-02 14:25:44,074 - INFO - metadata - Loading transformer NumericalTransformer for field phone_number1\n", - "2020-07-02 14:25:44,075 - INFO - metadata - Loading transformer NumericalTransformer for field credit_limit\n", - "2020-07-02 14:25:44,076 - INFO - metadata - Loading transformer CategoricalTransformer for field country\n", - "2020-07-02 14:25:44,094 - INFO - modeler - Modeling orders\n", - "2020-07-02 14:25:44,095 - INFO - metadata - Loading table orders\n", - "2020-07-02 14:25:44,098 - INFO - metadata - Loading transformer NumericalTransformer for field order_total\n", - "2020-07-02 14:25:44,101 - INFO - modeler - Modeling order_items\n", - "2020-07-02 14:25:44,102 - INFO - metadata - Loading table order_items\n", - "2020-07-02 14:25:44,106 - INFO - metadata - Loading transformer CategoricalTransformer for field product_id\n", - "2020-07-02 14:25:44,107 - INFO - metadata - Loading transformer NumericalTransformer for field unit_price\n", - "2020-07-02 14:25:44,108 - INFO - metadata - Loading transformer NumericalTransformer for field quantity\n", - "2020-07-02 14:25:44,120 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:44,131 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:44,138 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:44,147 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:44,155 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:44,164 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:44,171 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:44,177 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:44,183 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:44,189 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:44,196 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:44,210 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:44,231 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:44,258 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:44,281 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:44,303 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:44,370 - INFO - gaussian - Fitting 
GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:44,496 - INFO - modeler - Modeling Complete\n" - ] - } - ], - "source": [ - "from sdv import SDV\n", - "\n", - "sdv = SDV()\n", - "sdv.fit('quickstart/metadata.json')" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2020-07-02 14:25:46,641 - INFO - metadata - Loading table customers\n", - "2020-07-02 14:25:46,646 - INFO - metadata - Loading table orders\n", - "2020-07-02 14:25:46,649 - INFO - metadata - Loading table order_items\n" - ] - } - ], - "source": [ - "real = sdv.metadata.load_tables()\n", - "\n", - "samples = sdv.sample_all(len(real['customers']), reset_primary_keys=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/html": [ - "
" - ], - "text/plain": [ - " customer_id cust_postal_code phone_number1 credit_limit country\n", - "0 0 63145 7811758514 510 US\n", - "1 1 20166 6720163478 317 SPAIN\n", - "2 2 11371 7447861719 1371 SPAIN\n", - "3 3 63145 6835026704 1024 US\n", - "4 4 11371 5371158268 739 FRANCE" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "samples['customers'].head()" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
" - ], - "text/plain": [ - " customer_id cust_postal_code phone_number1 credit_limit country\n", - "0 50 11371 6175553295 1000 UK\n", - "1 4 63145 8605551835 500 US\n", - "2 97338810 6096 7035552143 1000 CANADA\n", - "3 630407 63145 7035552143 2000 US\n", - "4 826362 11371 6175553295 1000 UK" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "real['customers'].head()" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
" - ], - "text/plain": [ - " order_id customer_id order_total\n", - "0 0 0 1784\n", - "1 1 1 1864\n", - "2 2 4 2677\n", - "3 3 1 1485\n", - "4 4 1 2435" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "samples['orders'].head()" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
" - ], - "text/plain": [ - " order_id customer_id order_total\n", - "0 1 50 2310\n", - "1 2 4 1507\n", - "2 10 97338810 730\n", - "3 6 55996144 730\n", - "4 3 55996144 939" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "real['orders'].head()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
" - ], - "text/plain": [ - " order_item_id order_id product_id unit_price quantity\n", - "0 0 2 6 208 7\n", - "1 1 7 10 116 2\n", - "2 2 0 10 57 6\n", - "3 3 3 6 -24 1\n", - "4 4 3 6 38 4" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "samples['order_items'].head()" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
" - ], - "text/plain": [ - " order_item_id order_id product_id unit_price quantity\n", - "0 100 10 7 52 8\n", - "1 101 8 6 125 4\n", - "2 102 1 6 125 4\n", - "3 103 4 9 125 4\n", - "4 104 1 9 113 4" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "real['order_items'].head()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/4. Anonymization.ipynb b/examples/4. Anonymization.ipynb deleted file mode 100644 index 5ff4caa51..000000000 --- a/examples/4. Anonymization.ipynb +++ /dev/null @@ -1,398 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import pandas as pd\n", - "\n", - "data = pd.DataFrame([\n", - " {\n", - " 'index': 1,\n", - " 'name': 'Bill',\n", - " 'credit_card_number': '1111222233334444'\n", - " },\n", - " {\n", - " 'index': 2,\n", - " 'name': 'Jeff',\n", - " 'credit_card_number': '0000000000000000'\n", - " },\n", - " {\n", - " 'index': 3,\n", - " 'name': 'Bill',\n", - " 'credit_card_number': '9999999999999999'\n", - " },\n", - " {\n", - " 'index': 4,\n", - " 'name': 'Joe',\n", - " 'credit_card_number': '8888888888888888'\n", - " },\n", - "])" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
" - ], - "text/plain": [ - " credit_card_number index name\n", - "0 1111222233334444 1 Bill\n", - "1 0000000000000000 2 Jeff\n", - "2 9999999999999999 3 Bill\n", - "3 8888888888888888 4 Joe" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "data" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "metadata = {\n", - " \"tables\": [\n", - " {\n", - " \"fields\": [\n", - " {\n", - " \"name\": \"index\",\n", - " \"type\": \"id\"\n", - " },\n", - " {\n", - " \"name\": \"name\",\n", - " \"type\": \"categorical\",\n", - " \"pii\": True,\n", - " \"pii_category\": \"first_name\"\n", - " },\n", - " {\n", - " \"name\": \"credit_card_number\",\n", - " \"type\": \"categorical\",\n", - " \"pii\": True,\n", - " \"pii_category\": [\n", - " \"credit_card_number\",\n", - " \"visa\"\n", - " ]\n", - " }\n", - " ],\n", - " \"name\": \"anonymized\",\n", - " \"primary_key\": \"index\",\n", - " },\n", - " {\n", - " \"fields\": [\n", - " {\n", - " \"name\": \"index\",\n", - " \"type\": \"id\"\n", - " },\n", - " {\n", - " \"name\": \"name\",\n", - " \"type\": \"categorical\"\n", - " },\n", - " {\n", - " \"name\": \"credit_card_number\",\n", - " \"type\": \"categorical\"\n", - " }\n", - " ],\n", - " \"name\": \"normal\",\n", - " \"primary_key\": \"index\",\n", - " }\n", - " ]\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "tables = {\n", - " 'anonymized': data,\n", - " 'normal': data.copy()\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2020-07-02 14:26:26,044 - INFO - modeler - Modeling anonymized\n", - "2020-07-02 14:26:26,044 - INFO - metadata - Loading transformer CategoricalTransformer for field name\n", - "2020-07-02 14:26:26,045 - INFO - metadata - Loading transformer CategoricalTransformer for field credit_card_number\n", - "2020-07-02 14:26:26,087 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:26:26,092 - INFO - modeler - Modeling normal\n", - "2020-07-02 14:26:26,092 - INFO - metadata - Loading transformer CategoricalTransformer for field name\n", - "2020-07-02 14:26:26,093 - INFO - metadata - Loading transformer CategoricalTransformer for field credit_card_number\n", - "2020-07-02 14:26:26,109 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:26:26,113 - INFO - modeler - Modeling Complete\n" - ] - } - ], - "source": [ - "from sdv import SDV\n", - "\n", - "sdv = SDV()\n", - "sdv.fit(metadata, tables)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "scrolled": false - }, - "outputs": [], - "source": [ - "sampled = sdv.sample_all()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "data": { - "text/html": [ - "
" - ], - "text/plain": [ - " index name credit_card_number\n", - "0 0 Pamela 4801288395665668\n", - "1 1 Kimberly 4801288395665668\n", - "2 2 Pamela 4801288395665668\n", - "3 3 Kimberly 4592405566223480" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "sampled['anonymized']" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "data": { - "text/html": [ - "
" - ], - "text/plain": [ - " index name credit_card_number\n", - "0 0 Joe 8888888888888888\n", - "1 1 Joe 0000000000000000\n", - "2 2 Joe 8888888888888888\n", - "3 3 Bill 8888888888888888" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "sampled['normal']" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/6. Metadata Validation.ipynb b/examples/6. Metadata Validation.ipynb deleted file mode 100644 index 70ab58913..000000000 --- a/examples/6. Metadata Validation.ipynb +++ /dev/null @@ -1,239 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "from sdv import load_demo" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "metadata, tables = load_demo(metadata=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'tables': {'users': {'primary_key': 'user_id',\n", - " 'fields': {'user_id': {'type': 'id', 'subtype': 'integer'},\n", - " 'country': {'type': 'categorical'},\n", - " 'gender': {'type': 'categorical'},\n", - " 'age': {'type': 'numerical', 'subtype': 'integer'}}},\n", - " 'sessions': {'primary_key': 'session_id',\n", - " 'fields': {'session_id': {'type': 'id', 'subtype': 'integer'},\n", - " 'user_id': {'ref': {'field': 'user_id', 'table': 'users'},\n", - " 'type': 'id',\n", - " 'subtype': 'integer'},\n", - " 'device': {'type': 'categorical'},\n", - " 'os': {'type': 'categorical'}}},\n", - " 'transactions': {'primary_key': 'transaction_id',\n", - " 'fields': {'transaction_id': {'type': 'id', 'subtype': 'integer'},\n", - " 'session_id': {'ref': {'field': 'session_id', 'table': 'sessions'},\n", - " 'type': 'id',\n", - " 'subtype': 'integer'},\n", - " 'timestamp': {'type': 'datetime', 'format': '%Y-%m-%d'},\n", - " 'amount': {'type': 'numerical', 'subtype': 'float'},\n", - " 'approved': {'type': 'boolean'}}}}}" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "metadata.to_dict()" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'users': user_id country gender age\n", - " 0 0 USA M 34\n", - " 1 1 UK F 23\n", - " 2 2 ES None 44\n", - " 3 3 UK M 22\n", - " 4 4 USA F 54\n", - " 5 5 DE M 57\n", - " 6 6 BG F 45\n", - " 7 7 ES None 41\n", - " 8 8 FR F 23\n", - " 9 9 UK None 30,\n", - " 'sessions': session_id user_id device os\n", - " 0 0 0 mobile android\n", - " 1 1 1 tablet ios\n", - " 2 2 1 tablet android\n", - " 3 3 2 mobile android\n", - " 4 4 4 mobile ios\n", - " 5 5 5 mobile android\n", - " 6 6 6 mobile ios\n", - " 7 7 6 tablet ios\n", - " 8 8 6 mobile ios\n", - " 9 9 8 tablet ios,\n", - " 'transactions': transaction_id session_id timestamp amount approved\n", - " 0 0 0 2019-01-01 12:34:32 100.0 True\n", - " 1 1 0 2019-01-01 12:42:21 55.3 True\n", - " 2 2 1 2019-01-07 17:23:11 79.5 True\n", - " 3 3 3 2019-01-10 11:08:57 112.1 False\n", - " 4 4 5 2019-01-10 21:54:08 110.0 False\n", - 
" 5 5 5 2019-01-11 11:21:20 76.3 True\n", - " 6 6 7 2019-01-22 14:44:10 89.5 True\n", - " 7 7 8 2019-01-23 10:14:09 132.1 False\n", - " 8 8 9 2019-01-27 16:09:17 68.0 True\n", - " 9 9 9 2019-01-29 12:10:48 99.9 True}" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tables" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "metadata.validate()" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "metadata.validate(tables)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "metadata._metadata['tables']['users']['primary_key'] = 'country'" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Error: id field `user_id` is neither a primary or a foreign key\n" - ] - } - ], - "source": [ - "from sdv.metadata import MetadataError\n", - "\n", - "try:\n", - " metadata.validate()\n", - "except MetadataError as me:\n", - " print('Error: {}'.format(me))" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "metadata._metadata['tables']['users']['primary_key'] = 'user_id'" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "metadata.validate()" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "metadata._metadata['tables']['users']['fields']['gender']['type'] = 'numerical'" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "metadata.validate()" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Error: Invalid values found in column `gender` of table `users`: `could not convert string to float: 'M'`\n" - ] - } - ], - "source": [ - "try:\n", - " metadata.validate(tables)\n", - "except MetadataError as me:\n", - " print('Error: {}'.format(me))" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/airbnb-recruiting-new-user-bookings/Aibnb Dataset Subsampling.ipynb b/examples/airbnb-recruiting-new-user-bookings/Aibnb Dataset Subsampling.ipynb deleted file mode 100644 index d4b39df8b..000000000 --- a/examples/airbnb-recruiting-new-user-bookings/Aibnb Dataset Subsampling.ipynb +++ /dev/null @@ -1,206 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Airbnb Dataset Subsampling\n", - "\n", - "This notebooks shows how to generate the `users_sample.csv` and `sessions_sample.csv` files.\n", - "Before running this notebook, make sure to having downloaded the `train_users_2.csv` and `sessions.csv`\n", - "files from https://www.kaggle.com/c/airbnb-recruiting-new-user-bookings/data" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import 
pandas as pd" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "users = pd.read_csv('train_users_2.csv')\n", - "sessions = pd.read_csv('sessions.csv')" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(213451, 16)" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "users.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(10567737, 6)" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "sessions.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "users_sample = users.sample(1000).reset_index(drop=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "sessions_sample = sessions[\n", - " sessions['user_id'].isin(users_sample['id'])\n", - "].reset_index(drop=True).copy()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(24988, 6)" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "sessions_sample.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "user_ids = users_sample['id'].reset_index().set_index('id')['index']" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "id\n", - "c2cga16ny5 0\n", - "08yitv7nz0 1\n", - "3oobewtnls 2\n", - "owd2csvj6u 3\n", - "i66dlbdyrt 4\n", - "Name: index, dtype: int64" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "user_ids.head()" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "users_sample['id'] = users_sample['id'].map(user_ids)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "sessions_sample['user_id'] = sessions_sample['user_id'].map(user_ids)" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "users_sample.to_csv('users_sample.csv', index=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "sessions_sample.to_csv('sessions_sample.csv', index=False)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.8" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/airbnb-recruiting-new-user-bookings/Aibnb Example.ipynb b/examples/airbnb-recruiting-new-user-bookings/Aibnb Example.ipynb deleted file mode 100644 index e4b43f916..000000000 --- a/examples/airbnb-recruiting-new-user-bookings/Aibnb Example.ipynb +++ /dev/null @@ -1,730 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Airbnb Dataset Synthesis\n", - "\n", - 
"This notebook shows a usage example over a sample of the Airbnb New User Bookings dataset\n", - "available on Kaggle: https://www.kaggle.com/c/airbnb-recruiting-new-user-bookings\n", - "\n", - "Before running this notebook, make sure to having downloaded the data and run the\n", - "`Airbnb Dataset Sampling` notebook which can be found in the same folder as this one." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2019-11-03 16:07:17,764 - INFO - modeler - Modeling users\n", - "2019-11-03 16:07:17,764 - INFO - metadata - Loading table users\n", - "2019-11-03 16:07:17,781 - INFO - metadata - Loading transformer DatetimeTransformer for field date_account_created\n", - "2019-11-03 16:07:17,782 - INFO - metadata - Loading transformer DatetimeTransformer for field timestamp_first_active\n", - "2019-11-03 16:07:17,782 - INFO - metadata - Loading transformer DatetimeTransformer for field date_first_booking\n", - "2019-11-03 16:07:17,783 - INFO - metadata - Loading transformer CategoricalTransformer for field gender\n", - "2019-11-03 16:07:17,783 - INFO - metadata - Loading transformer NumericalTransformer for field age\n", - "2019-11-03 16:07:17,783 - INFO - metadata - Loading transformer CategoricalTransformer for field signup_method\n", - "2019-11-03 16:07:17,783 - INFO - metadata - Loading transformer CategoricalTransformer for field signup_flow\n", - "2019-11-03 16:07:17,784 - INFO - metadata - Loading transformer CategoricalTransformer for field language\n", - "2019-11-03 16:07:17,784 - INFO - metadata - Loading transformer CategoricalTransformer for field affiliate_channel\n", - "2019-11-03 16:07:17,784 - INFO - metadata - Loading transformer CategoricalTransformer for field affiliate_provider\n", - "2019-11-03 16:07:17,784 - INFO - metadata - Loading transformer CategoricalTransformer for field first_affiliate_tracked\n", - "2019-11-03 16:07:17,785 - INFO - metadata - Loading transformer CategoricalTransformer for field signup_app\n", - "2019-11-03 16:07:17,785 - INFO - metadata - Loading transformer CategoricalTransformer for field first_device_type\n", - "2019-11-03 16:07:17,785 - INFO - metadata - Loading transformer CategoricalTransformer for field first_browser\n", - "2019-11-03 16:07:17,785 - INFO - metadata - Loading transformer CategoricalTransformer for field country_destination\n", - "2019-11-03 16:07:17,852 - INFO - modeler - Modeling sessions\n", - "2019-11-03 16:07:17,853 - INFO - metadata - Loading table sessions\n", - "2019-11-03 16:07:17,872 - INFO - metadata - Loading transformer CategoricalTransformer for field action\n", - "2019-11-03 16:07:17,873 - INFO - metadata - Loading transformer CategoricalTransformer for field action_type\n", - "2019-11-03 16:07:17,873 - INFO - metadata - Loading transformer CategoricalTransformer for field action_detail\n", - "2019-11-03 16:07:17,873 - INFO - metadata - Loading transformer CategoricalTransformer for field device_type\n", - "2019-11-03 16:07:17,874 - INFO - metadata - Loading transformer NumericalTransformer for field secs_elapsed\n", - "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/pandas/core/frame.py:7143: RuntimeWarning: Degrees of freedom <= 0 for slice\n", - " baseCov = np.cov(mat.T)\n", - "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/lib/function_base.py:2451: RuntimeWarning: divide by zero encountered in true_divide\n", - " c *= np.true_divide(1, fact)\n", - 
"/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/lib/function_base.py:2451: RuntimeWarning: invalid value encountered in multiply\n", - " c *= np.true_divide(1, fact)\n", - "2019-11-03 16:07:22,394 - INFO - modeler - Modeling Complete\n" - ] - } - ], - "source": [ - "from sdv import SDV\n", - "\n", - "sdv = SDV()\n", - "sdv.fit('metadata.json')" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2019-11-03 16:07:22,443 - INFO - metadata - Loading table users\n", - "2019-11-03 16:07:22,454 - INFO - metadata - Loading table sessions\n" - ] - } - ], - "source": [ - "real = sdv.metadata.get_tables()" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
" - ], - "text/plain": [ - " id date_account_created timestamp_first_active date_first_booking \\\n", - "790 790 2011-04-30 2011-04-30 17:58:15 2012-03-22 \n", - "517 517 2011-09-20 2011-09-20 08:14:06 2011-09-22 \n", - "902 902 2014-05-17 2014-05-17 01:03:00 NaT \n", - "70 70 2013-10-21 2013-10-21 21:50:15 NaT \n", - "608 608 2013-04-01 2013-04-01 20:05:13 2013-04-02 \n", - "\n", - " gender age signup_method signup_flow language affiliate_channel \\\n", - "790 FEMALE 36.0 basic 2 en sem-non-brand \n", - "517 -unknown- NaN basic 2 en sem-brand \n", - "902 MALE 48.0 basic 24 en direct \n", - "70 -unknown- NaN basic 0 en direct \n", - "608 -unknown- NaN basic 24 en direct \n", - "\n", - " affiliate_provider first_affiliate_tracked signup_app first_device_type \\\n", - "790 google untracked Web Mac Desktop \n", - "517 google omg Web Mac Desktop \n", - "902 direct untracked Moweb iPhone \n", - "70 direct linked Web Mac Desktop \n", - "608 direct linked Moweb Mac Desktop \n", - "\n", - " first_browser country_destination \n", - "790 Firefox US \n", - "517 Chrome US \n", - "902 Mobile Safari NDF \n", - "70 Firefox NDF \n", - "608 Safari other " - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "real['users'].sample(5)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
" - ], - "text/plain": [ - " user_id action action_type \\\n", - "6345 101 search_results click \n", - "5417 72 lookup NaN \n", - "22406 823 show view \n", - "6279 101 search_results click \n", - "6561 883 ajax_refresh_subtotal click \n", - "\n", - " action_detail device_type secs_elapsed \n", - "6345 view_search_results Mac Desktop 1592.0 \n", - "5417 NaN Mac Desktop 651.0 \n", - "22406 p3 Android Phone 1198.0 \n", - "6279 view_search_results Mac Desktop 3389.0 \n", - "6561 change_trip_characteristics Mac Desktop 268.0 " - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "real['sessions'].sample(5)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "length = len(real['users'])" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "1000" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "length" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "samples = sdv.sample_all(length, reset_primary_keys=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
" - ], - "text/plain": [ - " id date_account_created timestamp_first_active \\\n", - "857 857 2014-03-16 21:04:00.313496320 2014-03-15 05:39:43.249984000 \n", - "489 489 2014-10-03 17:46:33.698169088 2014-10-14 06:11:35.937526784 \n", - "780 780 2014-03-02 14:25:57.251373312 2014-03-05 14:28:40.767416832 \n", - "596 596 2014-01-12 05:40:57.161886720 2014-01-18 03:57:01.785481984 \n", - "776 776 2013-05-28 14:15:41.311144960 2013-05-28 06:13:45.242490880 \n", - "\n", - " date_first_booking gender age signup_method signup_flow \\\n", - "857 NaT -unknown- 45.0 facebook 0 \n", - "489 NaT FEMALE 56.0 basic 0 \n", - "780 NaT -unknown- NaN basic 0 \n", - "596 NaT MALE 51.0 basic 0 \n", - "776 2013-08-07 07:45:31.016155904 FEMALE 46.0 basic 0 \n", - "\n", - " language affiliate_channel affiliate_provider first_affiliate_tracked \\\n", - "857 en direct direct untracked \n", - "489 en direct direct untracked \n", - "780 en direct direct linked \n", - "596 en sem-brand google linked \n", - "776 en direct direct untracked \n", - "\n", - " signup_app first_device_type first_browser country_destination \n", - "857 Web iPhone IE NDF \n", - "489 Web Windows Desktop -unknown- NDF \n", - "780 Web Windows Desktop Safari NDF \n", - "596 Web Mac Desktop Safari NDF \n", - "776 Web Mac Desktop Chrome NDF " - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "samples['users'].sample(5)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_idactionaction_typeaction_detaildevice_typesecs_elapsed
7985222ajax_refresh_subtotalNaNlisting_reviewsWindows Desktop6730.0
23883629ajax_refresh_subtotaldata-unknown-Windows Desktop4842.0
15908418termsclickview_search_resultsMac Desktop111772.0
8900250personalizedataNaNMac Desktop20486.0
26026691search-unknown--unknown-Mac Desktop149526.0
\n", - "
" - ], - "text/plain": [ - " user_id action action_type action_detail \\\n", - "7985 222 ajax_refresh_subtotal NaN listing_reviews \n", - "23883 629 ajax_refresh_subtotal data -unknown- \n", - "15908 418 terms click view_search_results \n", - "8900 250 personalize data NaN \n", - "26026 691 search -unknown- -unknown- \n", - "\n", - " device_type secs_elapsed \n", - "7985 Windows Desktop 6730.0 \n", - "23883 Windows Desktop 4842.0 \n", - "15908 Mac Desktop 111772.0 \n", - "8900 Mac Desktop 20486.0 \n", - "26026 Mac Desktop 149526.0 " - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "samples['sessions'].sample(5)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.8" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/airbnb-recruiting-new-user-bookings/metadata.json b/examples/airbnb-recruiting-new-user-bookings/metadata.json deleted file mode 100644 index 4c03506bf..000000000 --- a/examples/airbnb-recruiting-new-user-bookings/metadata.json +++ /dev/null @@ -1,116 +0,0 @@ -{ - "tables": [ - { - "headers": true, - "name": "users", - "path": "users_sample.csv", - "primary_key": "id", - "fields": [ - { - "name": "id", - "type": "id" - }, - { - "name": "date_account_created", - "type": "datetime", - "format": "%Y-%m-%d" - }, - { - "name": "timestamp_first_active", - "type": "datetime", - "format": "%Y%m%d%H%M%S" - }, - { - "name": "date_first_booking", - "type": "datetime", - "format": "%Y-%m-%d" - }, - { - "name": "gender", - "type": "categorical" - }, - { - "name": "age", - "type": "numerical", - "subtype": "integer" - }, - { - "name": "signup_method", - "type": "categorical" - }, - { - "name": "signup_flow", - "type": "categorical" - }, - { - "name": "language", - "type": "categorical" - }, - { - "name": "affiliate_channel", - "type": "categorical" - }, - { - "name": "affiliate_provider", - "type": "categorical" - }, - { - "name": "first_affiliate_tracked", - "type": "categorical" - }, - { - "name": "signup_app", - "type": "categorical" - }, - { - "name": "first_device_type", - "type": "categorical" - }, - { - "name": "first_browser", - "type": "categorical" - }, - { - "name": "country_destination", - "type": "categorical" - } - ] - }, - { - "headers": true, - "name": "sessions", - "path": "sessions_sample.csv", - "fields": [ - { - "name": "user_id", - "ref": { - "field": "id", - "table": "users" - }, - "type": "id" - }, - { - "name": "action", - "type": "categorical" - }, - { - "name": "action_type", - "type": "categorical" - }, - { - "name": "action_detail", - "type": "categorical" - }, - { - "name": "device_type", - "type": "categorical" - }, - { - "name": "secs_elapsed", - "type": "numerical", - "subtype": "integer" - } - ] - } - ] -} diff --git a/examples/demo_metadata.json b/examples/demo_metadata.json deleted file mode 100644 index 1becd1764..000000000 --- a/examples/demo_metadata.json +++ /dev/null @@ -1,74 +0,0 @@ -{ - "tables": { - "users": { - "fields": { - "country": { - "type": "categorical" - }, - "user_id": { - "type": "id", - "subtype": "integer" - }, - "gender": { - "type": "categorical" - }, - "age": { - "type": "numerical", - "subtype": "integer" - } - }, - "primary_key": 
"user_id" - }, - "sessions": { - "fields": { - "device": { - "type": "categorical" - }, - "session_id": { - "type": "id", - "subtype": "integer" - }, - "user_id": { - "type": "id", - "subtype": "integer", - "ref": { - "table": "users", - "field": "user_id" - } - }, - "os": { - "type": "categorical" - } - }, - "primary_key": "session_id" - }, - "transactions": { - "fields": { - "timestamp": { - "type": "datetime", - "format": "%Y-%m-%d" - }, - "amount": { - "type": "numerical", - "subtype": "float" - }, - "session_id": { - "type": "id", - "subtype": "integer", - "ref": { - "table": "sessions", - "field": "session_id" - } - }, - "transaction_id": { - "type": "id", - "subtype": "integer" - }, - "approved": { - "type": "boolean" - } - }, - "primary_key": "transaction_id" - } - } -} \ No newline at end of file diff --git a/examples/quickstart/customers.csv b/examples/quickstart/customers.csv deleted file mode 100644 index 14922885f..000000000 --- a/examples/quickstart/customers.csv +++ /dev/null @@ -1,8 +0,0 @@ -customer_id,cust_postal_code,phone_number1,credit_limit,country -50,11371,6175553295,1000,UK -4,63145,8605551835,500,US -97338810,6096,7035552143,1000,CANADA -630407,63145,7035552143,2000,US -826362,11371,6175553295,1000,UK -55996144,20166,4045553285,1000,FRANCE -598112,63145,6175553295,1000,SPAIN diff --git a/examples/quickstart/metadata.json b/examples/quickstart/metadata.json deleted file mode 100644 index 119d30c76..000000000 --- a/examples/quickstart/metadata.json +++ /dev/null @@ -1,96 +0,0 @@ -{ - "tables": [ - { - "fields": [ - { - "name": "customer_id", - "type": "id" - }, - { - "name": "cust_postal_code", - "type": "categorical" - }, - { - "name": "phone_number1", - "type": "numerical", - "subtype": "integer" - }, - { - "name": "credit_limit", - "type": "numerical", - "subtype": "integer" - }, - { - "name": "country", - "type": "categorical" - } - ], - "headers": true, - "name": "customers", - "path": "customers.csv", - "primary_key": "customer_id", - "use": true - }, - { - "fields": [ - { - "name": "order_id", - "type": "id" - }, - { - "name": "customer_id", - "ref": { - "field": "customer_id", - "table": "customers" - }, - "type": "id" - }, - { - "name": "order_total", - "type": "numerical", - "subtype": "integer" - } - ], - "headers": true, - "name": "orders", - "path": "orders.csv", - "primary_key": "order_id", - "use": true - }, - { - "fields": [ - { - "name": "order_item_id", - "type": "id" - }, - { - "name": "order_id", - "ref": { - "field": "order_id", - "table": "orders" - }, - "type": "id" - }, - { - "name": "product_id", - "type": "categorical" - }, - { - "name": "unit_price", - "type": "numerical", - "subtype": "integer" - }, - { - "name": "quantity", - "type": "numerical", - "subtype": "integer" - } - ], - "headers": true, - "name": "order_items", - "path": "order_items.csv", - "primary_key": "order_item_id", - "use": true - } - ] -} diff --git a/examples/quickstart/order_items.csv b/examples/quickstart/order_items.csv deleted file mode 100644 index 7f47c3d69..000000000 --- a/examples/quickstart/order_items.csv +++ /dev/null @@ -1,50 +0,0 @@ -order_item_id,order_id,product_id,unit_price,quantity -100,10,7,52,8 -101,8,6,125,4 -102,1,6,125,4 -103,4,9,125,4 -104,1,9,113,4 -105,9,10,87,2 -106,10,6,39,4 -107,1,6,50,4 -108,2,3,31,2 -109,4,6,37,3 -110,10,10,50,4 -111,1,10,50,2 -112,6,14,50,1 -113,10,1,161,5 -114,2,2,73,4 -115,1,6,50,2 -116,5,6,63,2 -117,1,6,115,4 -118,4,10,50,4 -119,5,9,50,2 -120,10,15,37,0 -121,7,6,113,7 -122,1,6,125,2 -123,10,6,113,7 -124,8,10,125,4 
-125,5,6,125,2
-126,1,6,125,2
-127,8,10,69,2
-128,8,10,50,2
-129,6,8,152,3
-130,5,2,103,4
-131,3,11,30,-1
-132,9,6,97,2
-133,7,1,165,5
-134,8,6,125,2
-135,8,6,50,2
-136,1,9,125,4
-137,6,8,152,3
-138,1,10,50,2
-139,3,6,37,3
-140,6,15,111,1
-141,9,10,125,2
-142,10,2,147,4
-143,3,6,58,4
-144,1,6,102,4
-145,9,10,50,4
-146,6,15,50,1
-147,7,3,30,-1
-148,1,6,125,4
diff --git a/examples/quickstart/orders.csv b/examples/quickstart/orders.csv
deleted file mode 100644
index 08371dbb8..000000000
--- a/examples/quickstart/orders.csv
+++ /dev/null
@@ -1,11 +0,0 @@
-order_id,customer_id,order_total
-1,50,2310
-2,4,1507
-10,97338810,730
-6,55996144,730
-3,55996144,939
-4,50,2380
-5,97338810,1570
-7,50,730
-8,97338810,2336
-9,4,743
diff --git a/tutorials/01_Quickstart.ipynb b/tutorials/01_Quickstart.ipynb
new file mode 100644
index 000000000..5eb6cd8e8
--- /dev/null
+++ b/tutorials/01_Quickstart.ipynb
@@ -0,0 +1,838 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Quickstart\n",
+    "\n",
+    "In this short tutorial we will guide you through a series of steps that will help you\n",
+    "get started using **SDV**.\n",
+    "\n",
+    "## 1. Model the dataset using SDV\n",
+    "\n",
+    "To model a multi-table, relational dataset, we follow two steps. In the first step, we will load\n",
+    "the data and configure the metadata. In the second step, we will use the SDV API to fit and\n",
+    "save a hierarchical model. We will cover these two steps in this section using an example dataset.\n",
+    "\n",
+    "### Step 1: Load example data\n",
+    "\n",
+    "**SDV** comes with a toy dataset to play with, which can be loaded using the `sdv.load_demo`\n",
+    "function:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from sdv import load_demo\n",
+    "\n",
+    "metadata, tables = load_demo(metadata=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "This will return two objects:\n",
+    "\n",
+    "1. A `Metadata` object with all the information that **SDV** needs to know about the dataset."
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Metadata\n", + " root_path: /home/xals/Projects/MIT/SDV/tutorials\n", + " tables: ['users', 'sessions', 'transactions']\n", + " relationships:\n", + " sessions.user_id -> users.user_id\n", + " transactions.session_id -> sessions.session_id" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metadata" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Metadata\n", + "\n", + "\n", + "\n", + "users\n", + "\n", + "users\n", + "\n", + "user_id : id - integer\n", + "country : categorical\n", + "gender : categorical\n", + "age : numerical - integer\n", + "\n", + "Primary key: user_id\n", + "\n", + "\n", + "\n", + "sessions\n", + "\n", + "sessions\n", + "\n", + "session_id : id - integer\n", + "user_id : id - integer\n", + "device : categorical\n", + "os : categorical\n", + "\n", + "Primary key: session_id\n", + "Foreign key (users): user_id\n", + "\n", + "\n", + "\n", + "users->sessions\n", + "\n", + "\n", + "   sessions.user_id -> users.user_id\n", + "\n", + "\n", + "\n", + "transactions\n", + "\n", + "transactions\n", + "\n", + "transaction_id : id - integer\n", + "session_id : id - integer\n", + "timestamp : datetime\n", + "amount : numerical - float\n", + "approved : boolean\n", + "\n", + "Primary key: transaction_id\n", + "Foreign key (sessions): session_id\n", + "\n", + "\n", + "\n", + "sessions->transactions\n", + "\n", + "\n", + "   transactions.session_id -> sessions.session_id\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metadata.visualize()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For more details about how to build the `Metadata` for your own dataset, please refer to the\n", + "[Metadata](https://sdv-dev.github.io/SDV/metadata.html) section of the documentation." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "2. A dictionary containing three `pandas.DataFrames` with the tables described in the\n", + "metadata object." 
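The demo metadata comes prebuilt, but the same structure can also be assembled programmatically. The following is a minimal sketch, assuming the `add_table` and `add_relationship` helpers of the `Metadata` API; the call signatures shown are illustrative rather than exhaustive:

```python
# A minimal sketch (assumption: the Metadata API exposes add_table and
# add_relationship helpers as described in the Metadata documentation).
from sdv import Metadata

metadata = Metadata()

# Register each table, letting SDV infer the field types from the data.
metadata.add_table('users', data=tables['users'], primary_key='user_id')
metadata.add_table('sessions', data=tables['sessions'], primary_key='session_id')
metadata.add_table('transactions', data=tables['transactions'], primary_key='transaction_id')

# Declare the primary key-foreign key relationships between the tables.
metadata.add_relationship('users', 'sessions', foreign_key='user_id')
metadata.add_relationship('sessions', 'transactions', foreign_key='session_id')
```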
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'users': user_id country gender age\n", + " 0 0 USA M 34\n", + " 1 1 UK F 23\n", + " 2 2 ES None 44\n", + " 3 3 UK M 22\n", + " 4 4 USA F 54\n", + " 5 5 DE M 57\n", + " 6 6 BG F 45\n", + " 7 7 ES None 41\n", + " 8 8 FR F 23\n", + " 9 9 UK None 30,\n", + " 'sessions': session_id user_id device os\n", + " 0 0 0 mobile android\n", + " 1 1 1 tablet ios\n", + " 2 2 1 tablet android\n", + " 3 3 2 mobile android\n", + " 4 4 4 mobile ios\n", + " 5 5 5 mobile android\n", + " 6 6 6 mobile ios\n", + " 7 7 6 tablet ios\n", + " 8 8 6 mobile ios\n", + " 9 9 8 tablet ios,\n", + " 'transactions': transaction_id session_id timestamp amount approved\n", + " 0 0 0 2019-01-01 12:34:32 100.0 True\n", + " 1 1 0 2019-01-01 12:42:21 55.3 True\n", + " 2 2 1 2019-01-07 17:23:11 79.5 True\n", + " 3 3 3 2019-01-10 11:08:57 112.1 False\n", + " 4 4 5 2019-01-10 21:54:08 110.0 False\n", + " 5 5 5 2019-01-11 11:21:20 76.3 True\n", + " 6 6 7 2019-01-22 14:44:10 89.5 True\n", + " 7 7 8 2019-01-23 10:14:09 132.1 False\n", + " 8 8 9 2019-01-27 16:09:17 68.0 True\n", + " 9 9 9 2019-01-29 12:10:48 99.9 True}" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tables" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. Fit a model using the SDV API.\n", + "\n", + "First, we build a hierarchical statistical model of the data using **SDV**. For this we will\n", + "create an instance of the `sdv.SDV` class and use its `fit` method.\n", + "\n", + "During this process, **SDV** will traverse across all the tables in your dataset following the\n", + "primary key-foreign key relationships and learn the probability distributions of the values in\n", + "the columns." 
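To make that traversal order concrete, here is a small illustrative helper (not part of the notebook) that prints the parent-child hierarchy the fit will walk, using the `get_children` method of the `Metadata` object:

```python
# Illustrative helper: print the parent-child hierarchy that the
# hierarchical model traverses, starting from a root table.
def print_hierarchy(metadata, table_name, depth=0):
    print(' ' * depth + table_name)
    for child_name in metadata.get_children(table_name):
        print_hierarchy(metadata, child_name, depth + 4)

print_hierarchy(metadata, 'users')
# Expected output for the demo dataset:
# users
#     sessions
#         transactions
```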
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-07-09 20:57:19,919 - INFO - modeler - Modeling users\n", + "2020-07-09 20:57:19,920 - INFO - __init__ - Loading transformer CategoricalTransformer for field country\n", + "2020-07-09 20:57:19,920 - INFO - __init__ - Loading transformer CategoricalTransformer for field gender\n", + "2020-07-09 20:57:19,921 - INFO - __init__ - Loading transformer NumericalTransformer for field age\n", + "2020-07-09 20:57:19,933 - INFO - modeler - Modeling sessions\n", + "2020-07-09 20:57:19,934 - INFO - __init__ - Loading transformer CategoricalTransformer for field device\n", + "2020-07-09 20:57:19,934 - INFO - __init__ - Loading transformer CategoricalTransformer for field os\n", + "2020-07-09 20:57:19,944 - INFO - modeler - Modeling transactions\n", + "2020-07-09 20:57:19,944 - INFO - __init__ - Loading transformer DatetimeTransformer for field timestamp\n", + "2020-07-09 20:57:19,944 - INFO - __init__ - Loading transformer NumericalTransformer for field amount\n", + "2020-07-09 20:57:19,945 - INFO - __init__ - Loading transformer BooleanTransformer for field approved\n", + "2020-07-09 20:57:19,954 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:19,962 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/core/fromnumeric.py:58: FutureWarning: Series.nonzero() is deprecated and will be removed in a future version.Use Series.to_numpy().nonzero() instead\n", + " return bound(*args, **kwds)\n", + "/home/xals/Projects/MIT/SDV/sdv/models/copulas.py:83: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. 
If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n", + " self.model.covariance = np.array(values)\n", + "2020-07-09 20:57:19,968 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/pandas/core/frame.py:7143: RuntimeWarning: Degrees of freedom <= 0 for slice\n", + " baseCov = np.cov(mat.T)\n", + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/lib/function_base.py:2480: RuntimeWarning: divide by zero encountered in true_divide\n", + " c *= np.true_divide(1, fact)\n", + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/lib/function_base.py:2480: RuntimeWarning: invalid value encountered in multiply\n", + " c *= np.true_divide(1, fact)\n", + "2020-07-09 20:57:19,974 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:19,979 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:19,985 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:19,989 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:19,994 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,007 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,062 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,074 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,092 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,104 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,117 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,129 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,146 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,208 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,294 - INFO - modeler - Modeling Complete\n" + ] + } + ], + "source": [ + "from sdv import SDV\n", + "\n", + "sdv = SDV()\n", + "sdv.fit(metadata, tables)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. 
Sample data from the fitted model\n", + "\n", + "Once the modeling has finished you are ready to generate new synthetic data using the `sdv` instance that you have.\n", + "\n", + "For this, all you have to do is call the `sample_all` method from your instance passing the number of rows that you want to generate:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "sampled = sdv.sample_all(10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This will return a dictionary identical to the `tables` one that we passed to the SDV instance for learning, filled in with new synthetic data.\n", + "\n", + "**Note** that only the parent tables of your dataset will have the specified number of rows,\n", + "as the number of child rows that each row in the parent table has is also sampled following\n", + "the original distribution of your dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_idcountrygenderage
00UKNaN39
11UKF27
22UKNaN50
33USAF29
44UKNaN21
55ESNaN27
66UKF20
77USAF15
88DEF41
99ESF35
\n", + "
" + ], + "text/plain": [ + " user_id country gender age\n", + "0 0 UK NaN 39\n", + "1 1 UK F 27\n", + "2 2 UK NaN 50\n", + "3 3 USA F 29\n", + "4 4 UK NaN 21\n", + "5 5 ES NaN 27\n", + "6 6 UK F 20\n", + "7 7 USA F 15\n", + "8 8 DE F 41\n", + "9 9 ES F 35" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sampled['users']" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
session_iduser_iddeviceos
001mobileandroid
112tabletios
222tabletios
332mobileandroid
443mobileandroid
555mobileandroid
665mobileandroid
775tabletios
886tabletios
996mobileandroid
\n", + "
" + ], + "text/plain": [ + " session_id user_id device os\n", + "0 0 1 mobile android\n", + "1 1 2 tablet ios\n", + "2 2 2 tablet ios\n", + "3 3 2 mobile android\n", + "4 4 3 mobile android\n", + "5 5 5 mobile android\n", + "6 6 5 mobile android\n", + "7 7 5 tablet ios\n", + "8 8 6 tablet ios\n", + "9 9 6 mobile android" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sampled['sessions'].head(10)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
transaction_idsession_idtimestampamountapproved
0002019-01-13 01:11:06.08427596880.350553True
1102019-01-13 01:11:06.36870425682.116470True
2212019-01-25 15:05:07.73526297697.616590True
3312019-01-25 15:05:09.09895884896.505258True
4422019-01-25 15:05:09.12423142497.235452True
5522019-01-25 15:05:07.34731545696.615923True
6632019-01-25 15:05:09.19044198497.173195True
7732019-01-25 15:05:07.43791078496.995037True
8802019-01-13 01:11:06.56094003281.027625True
9912019-01-25 15:05:12.42074649697.340054True
\n", + "
" + ], + "text/plain": [ + " transaction_id session_id timestamp amount \\\n", + "0 0 0 2019-01-13 01:11:06.084275968 80.350553 \n", + "1 1 0 2019-01-13 01:11:06.368704256 82.116470 \n", + "2 2 1 2019-01-25 15:05:07.735262976 97.616590 \n", + "3 3 1 2019-01-25 15:05:09.098958848 96.505258 \n", + "4 4 2 2019-01-25 15:05:09.124231424 97.235452 \n", + "5 5 2 2019-01-25 15:05:07.347315456 96.615923 \n", + "6 6 3 2019-01-25 15:05:09.190441984 97.173195 \n", + "7 7 3 2019-01-25 15:05:07.437910784 96.995037 \n", + "8 8 0 2019-01-13 01:11:06.560940032 81.027625 \n", + "9 9 1 2019-01-25 15:05:12.420746496 97.340054 \n", + "\n", + " approved \n", + "0 True \n", + "1 True \n", + "2 True \n", + "3 True \n", + "4 True \n", + "5 True \n", + "6 True \n", + "7 True \n", + "8 True \n", + "9 True " + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sampled['transactions'].head(10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Saving and Loading your model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In some cases, you might want to save the fitted SDV instance to be able to generate synthetic data from\n", + "it later or on a different system.\n", + "\n", + "In order to do so, you can save your fitted `SDV` instance for later usage using the `save` method of your\n", + "instance." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "sdv.save('sdv.pkl')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The generated `pkl` file will not include any of the original data in it, so it can be\n", + "safely sent to where the synthetic data will be generated without any privacy concerns." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Later on, in order to sample data from the fitted model, we will first need to load it from its\n", + "`pkl` file." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "sdv = SDV.load('sdv.pkl')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After loading the instance, we can sample synthetic data using its `sample_all` method like before." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "sampled = sdv.sample_all(5)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/tutorials/02_Single_Table_Modeling.ipynb b/tutorials/02_Single_Table_Modeling.ipynb new file mode 100644 index 000000000..0f5649a3e --- /dev/null +++ b/tutorials/02_Single_Table_Modeling.ipynb @@ -0,0 +1,1204 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Single Table Modeling\n", + "\n", + "**SDV** has special support for modeling single table datasets using a variety of models.\n", + "\n", + "Currently, SDV implements:\n", + "\n", + "* GaussianCopula: A tool to model multivariate distributions using [copula functions](https://en.wikipedia.org/wiki/Copula_%28probability_theory%29). 
Based on our [Copulas Library](https://github.com/sdv-dev/Copulas).\n", + "* CTGAN: A GAN-based Deep Learning data synthesizer that can generate synthetic tabular data with high fidelity. Based on our [CTGAN Library](https://github.com/sdv-dev/CTGAN)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## GaussianCopula\n", + "\n", + "In this first part of the tutorial we will be using the GaussianCopula class to model the `users` table\n", + "from the toy dataset included in the **SDV** library.\n", + "\n", + "### 1. Load the Data" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from sdv import load_demo\n", + "\n", + "users = load_demo()['users']" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This will return a table with 4 fields:\n", + "\n", + "* `user_id`: A unique identifier of the user.\n", + "* `country`: A 2 letter code of the country of residence of the user.\n", + "* `gender`: A single letter code, `M` or `F`, indicating the user gender. Note that this demo simulates the case where some users did not indicate the gender, which resulted in empty data values in some rows.\n", + "* `age`: The age of the user, in years." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_idcountrygenderage
00USAM34
11UKF23
22ESNone44
33UKM22
44USAF54
55DEM57
66BGF45
77ESNone41
88FRF23
99UKNone30
\n", + "
" + ], + "text/plain": [ + " user_id country gender age\n", + "0 0 USA M 34\n", + "1 1 UK F 23\n", + "2 2 ES None 44\n", + "3 3 UK M 22\n", + "4 4 USA F 54\n", + "5 5 DE M 57\n", + "6 6 BG F 45\n", + "7 7 ES None 41\n", + "8 8 FR F 23\n", + "9 9 UK None 30" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "users" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. Prepare the model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In order to properly model our data we will need to provide some additional information to our model,\n", + "so let's prepare this information in some variables." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, let's indicate that the `user_id` field in our table is the primary key, so we do not want our\n", + "model to attempt to learn it." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "primary_key = 'user_id'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will also want to anonymize the countries of residence of our users, to avoid disclosing such information.\n", + "Let's make a variable indicating that the `country` field needs to be anonymized using fake `country_codes`." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "anonymize_fileds = {\n", + " 'country': 'contry_code'\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The full list of categories supported corresponds to the `Faker` library\n", + "[provider names](https://faker.readthedocs.io/en/master/providers.html)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Once we have prepared the arguments for our model we are ready to import it, create an instance\n", + "and fit it to our data." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-07-09 20:41:29,959 - INFO - table - Loading transformer CategoricalTransformer for field country\n", + "2020-07-09 20:41:29,961 - INFO - table - Loading transformer CategoricalTransformer for field gender\n", + "2020-07-09 20:41:29,962 - INFO - table - Loading transformer NumericalTransformer for field age\n", + "2020-07-09 20:41:29,981 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n" + ] + } + ], + "source": [ + "from sdv.tabular import GaussianCopula\n", + "\n", + "model = GaussianCopula(\n", + " primary_key=primary_key,\n", + " anonymize_fileds=anonymize_fileds\n", + ")\n", + "model.fit(users)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Notice** how the model took care of transforming the different fields using the appropriate\n", + "Reversible Data Transforms to ensure that the data has a format that the GaussianMultivariate model\n", + "from the [copulas](https://github.com/sdv-dev/Copulas) library can handle." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. Sample data from the fitted model\n", + "\n", + "Once the modeling has finished you are ready to generate new synthetic data by calling the `sample` method\n", + "from our model." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "sampled = model.sample()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This will return a table identical to the one which the model was fitted on, but filled with new data\n", + "which resembles the original one." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_idcountrygenderage
00USAF29
11FRF44
22UKM38
33ESM19
44USAF55
55USANaN27
66USAF27
77UKNaN7
88FRF55
99UKNaN43
\n", + "
" + ], + "text/plain": [ + " user_id country gender age\n", + "0 0 USA F 29\n", + "1 1 FR F 44\n", + "2 2 UK M 38\n", + "3 3 ES M 19\n", + "4 4 USA F 55\n", + "5 5 USA NaN 27\n", + "6 6 USA F 27\n", + "7 7 UK NaN 7\n", + "8 8 FR F 55\n", + "9 9 UK NaN 43" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sampled" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "scrolled": false + }, + "source": [ + "Notice, as well that the number of rows generated by default corresponds to the number of rows that\n", + "the original table had, but that this number can be changed by simply passing it:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_idcountrygenderage
00BGM30
11USANaN61
22USAM35
33ESNaN34
44USAF41
\n", + "
" + ], + "text/plain": [ + " user_id country gender age\n", + "0 0 BG M 30\n", + "1 1 USA NaN 61\n", + "2 2 USA M 35\n", + "3 3 ES NaN 34\n", + "4 4 USA F 41" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.sample(5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## CTGAN\n", + "\n", + "In this second part of the tutorial we will be using the CTGAN model to learn the data from the\n", + "demo dataset called `census`, which is based on the [UCI Adult Census Dataset]('https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data').\n", + "\n", + "### 1. Load the Data" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-07-09 20:48:49,870 - INFO - __init__ - Loading table census\n" + ] + } + ], + "source": [ + "from sdv import load_demo\n", + "\n", + "census = load_demo('census')['census']" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This will return a table with several rows of multiple data types:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ageworkclassfnlwgteducationeducation-nummarital-statusoccupationrelationshipracesexcapital-gaincapital-losshours-per-weeknative-countryincome
039State-gov77516Bachelors13Never-marriedAdm-clericalNot-in-familyWhiteMale2174040United-States<=50K
150Self-emp-not-inc83311Bachelors13Married-civ-spouseExec-managerialHusbandWhiteMale0013United-States<=50K
238Private215646HS-grad9DivorcedHandlers-cleanersNot-in-familyWhiteMale0040United-States<=50K
353Private23472111th7Married-civ-spouseHandlers-cleanersHusbandBlackMale0040United-States<=50K
428Private338409Bachelors13Married-civ-spouseProf-specialtyWifeBlackFemale0040Cuba<=50K
\n", + "
" + ], + "text/plain": [ + " age workclass fnlwgt education education-num \\\n", + "0 39 State-gov 77516 Bachelors 13 \n", + "1 50 Self-emp-not-inc 83311 Bachelors 13 \n", + "2 38 Private 215646 HS-grad 9 \n", + "3 53 Private 234721 11th 7 \n", + "4 28 Private 338409 Bachelors 13 \n", + "\n", + " marital-status occupation relationship race sex \\\n", + "0 Never-married Adm-clerical Not-in-family White Male \n", + "1 Married-civ-spouse Exec-managerial Husband White Male \n", + "2 Divorced Handlers-cleaners Not-in-family White Male \n", + "3 Married-civ-spouse Handlers-cleaners Husband Black Male \n", + "4 Married-civ-spouse Prof-specialty Wife Black Female \n", + "\n", + " capital-gain capital-loss hours-per-week native-country income \n", + "0 2174 0 40 United-States <=50K \n", + "1 0 0 13 United-States <=50K \n", + "2 0 0 40 United-States <=50K \n", + "3 0 0 40 United-States <=50K \n", + "4 0 0 40 Cuba <=50K " + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "census.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. Prepare the model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this case there is no primary key to setup and we will not be anonymizing anything, so the only\n", + "thing that we will pass to the CTGAN model is the number of epochs that we want it to perform when\n", + "it leanrs the data, which we will keep low to make this execution quick." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/sklearn/utils/deprecation.py:144: FutureWarning: The sklearn.utils.testing module is deprecated in version 0.22 and will be removed in version 0.24. The corresponding classes / functions should instead be imported from sklearn.utils. Anything that cannot be imported from sklearn.utils is now part of the private API.\n", + " warnings.warn(message, FutureWarning)\n" + ] + } + ], + "source": [ + "from sdv.tabular import CTGAN\n", + "\n", + "model = CTGAN(epochs=10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Once the instance is created, we can fit it to our data. Bear in mind that this process might take some\n", + "time to finish, especially on non-GPU enabled systems." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-07-09 20:51:21,668 - INFO - table - Loading transformer NumericalTransformer for field age\n", + "2020-07-09 20:51:21,669 - INFO - table - Loading transformer LabelEncodingTransformer for field workclass\n", + "2020-07-09 20:51:21,670 - INFO - table - Loading transformer NumericalTransformer for field fnlwgt\n", + "2020-07-09 20:51:21,670 - INFO - table - Loading transformer LabelEncodingTransformer for field education\n", + "2020-07-09 20:51:21,671 - INFO - table - Loading transformer NumericalTransformer for field education-num\n", + "2020-07-09 20:51:21,672 - INFO - table - Loading transformer LabelEncodingTransformer for field marital-status\n", + "2020-07-09 20:51:21,672 - INFO - table - Loading transformer LabelEncodingTransformer for field occupation\n", + "2020-07-09 20:51:21,673 - INFO - table - Loading transformer LabelEncodingTransformer for field relationship\n", + "2020-07-09 20:51:21,673 - INFO - table - Loading transformer LabelEncodingTransformer for field race\n", + "2020-07-09 20:51:21,674 - INFO - table - Loading transformer LabelEncodingTransformer for field sex\n", + "2020-07-09 20:51:21,674 - INFO - table - Loading transformer NumericalTransformer for field capital-gain\n", + "2020-07-09 20:51:21,675 - INFO - table - Loading transformer NumericalTransformer for field capital-loss\n", + "2020-07-09 20:51:21,675 - INFO - table - Loading transformer NumericalTransformer for field hours-per-week\n", + "2020-07-09 20:51:21,676 - INFO - table - Loading transformer LabelEncodingTransformer for field native-country\n", + "2020-07-09 20:51:21,678 - INFO - table - Loading transformer LabelEncodingTransformer for field income\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1, Loss G: 1.9379, Loss D: -0.1297\n", + "Epoch 2, Loss G: 1.6529, Loss D: 0.1716\n", + "Epoch 3, Loss G: 1.0939, Loss D: 0.0208\n", + "Epoch 4, Loss G: 0.9390, Loss D: -0.0918\n", + "Epoch 5, Loss G: -0.0550, Loss D: 0.0415\n", + "Epoch 6, Loss G: -0.0864, Loss D: -0.0104\n", + "Epoch 7, Loss G: -0.3865, Loss D: 0.0074\n", + "Epoch 8, Loss G: -1.0150, Loss D: -0.0708\n", + "Epoch 9, Loss G: -0.8685, Loss D: -0.1289\n", + "Epoch 10, Loss G: -1.1524, Loss D: -0.0487\n" + ] + } + ], + "source": [ + "model.fit(census)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. Sample data from the fitted model\n", + "\n", + "Once the modeling has finished you are ready to generate new synthetic data by calling the `sample` method\n", + "from our model just like we did with the GaussianCopula model." + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "sampled = model.sample()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This will return a table identical to the one which the model was fitted on, but filled with new data\n", + "which resembles the original one." + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ageworkclassfnlwgteducationeducation-nummarital-statusoccupationrelationshipracesexcapital-gaincapital-losshours-per-weeknative-countryincome
038Private212843Prof-school10Never-marriedTech-supportUnmarriedWhiteMale49-140United-States<=50K
130Federal-gov321819HS-grad14Married-civ-spouseCraft-repairUnmarriedBlackFemale51-339United-States<=50K
225Private169771HS-grad13Never-marriedOther-serviceNot-in-familyWhiteFemale-37-239United-States<=50K
347Private116751Some-college6DivorcedHandlers-cleanersNot-in-familyWhiteMale0205741United-States<=50K
438Private315119Assoc-acdm13Never-marriedCraft-repairHusbandBlackMale-11-149United-States<=50K
536State-gov172646Masters6Married-civ-spouseMachine-op-inspctWifeWhiteMale-50-240Japan<=50K
640Private163368Some-college4Married-civ-spouseTransport-movingOwn-childWhiteFemale-13539United-States<=50K
723Private369324Some-college11Married-civ-spouseTech-supportHusbandWhiteFemale-60540United-States<=50K
832Private192521Bachelors10Married-civ-spouseExec-managerialNot-in-familyWhiteFemale-15243United-States<=50K
928Private244118HS-grad9SeparatedProf-specialtyNot-in-familyWhiteMale-21039United-States>50K
\n", + "
" + ], + "text/plain": [ + " age workclass fnlwgt education education-num \\\n", + "0 38 Private 212843 Prof-school 10 \n", + "1 30 Federal-gov 321819 HS-grad 14 \n", + "2 25 Private 169771 HS-grad 13 \n", + "3 47 Private 116751 Some-college 6 \n", + "4 38 Private 315119 Assoc-acdm 13 \n", + "5 36 State-gov 172646 Masters 6 \n", + "6 40 Private 163368 Some-college 4 \n", + "7 23 Private 369324 Some-college 11 \n", + "8 32 Private 192521 Bachelors 10 \n", + "9 28 Private 244118 HS-grad 9 \n", + "\n", + " marital-status occupation relationship race sex \\\n", + "0 Never-married Tech-support Unmarried White Male \n", + "1 Married-civ-spouse Craft-repair Unmarried Black Female \n", + "2 Never-married Other-service Not-in-family White Female \n", + "3 Divorced Handlers-cleaners Not-in-family White Male \n", + "4 Never-married Craft-repair Husband Black Male \n", + "5 Married-civ-spouse Machine-op-inspct Wife White Male \n", + "6 Married-civ-spouse Transport-moving Own-child White Female \n", + "7 Married-civ-spouse Tech-support Husband White Female \n", + "8 Married-civ-spouse Exec-managerial Not-in-family White Female \n", + "9 Separated Prof-specialty Not-in-family White Male \n", + "\n", + " capital-gain capital-loss hours-per-week native-country income \n", + "0 49 -1 40 United-States <=50K \n", + "1 51 -3 39 United-States <=50K \n", + "2 -37 -2 39 United-States <=50K \n", + "3 0 2057 41 United-States <=50K \n", + "4 -11 -1 49 United-States <=50K \n", + "5 -50 -2 40 Japan <=50K \n", + "6 -13 5 39 United-States <=50K \n", + "7 -60 5 40 United-States <=50K \n", + "8 -15 2 43 United-States <=50K \n", + "9 -21 0 39 United-States >50K " + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sampled.head(10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 4. Evaluate how good the data is\n", + "\n", + "Finally, we will use the evaluation framework included in SDV to obtain a metric of how\n", + "similar the sampled data is to the original one.\n", + "\n", + "For this, we will simply import the `sdv.evaluation.evaluate` function and pass both\n", + "the synthetic and the real data to it." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "-28.22245129315498" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import warnings\n", + "warnings.filterwarnings('ignore')\n", + "\n", + "from sdv.evaluation import evaluate\n", + "\n", + "evaluate(sampled, census)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/Demo - Walmart.ipynb b/tutorials/03_Relational_Data_Modeling.ipynb similarity index 55% rename from examples/Demo - Walmart.ipynb rename to tutorials/03_Relational_Data_Modeling.ipynb index 2e3303259..230061b00 100644 --- a/examples/Demo - Walmart.ipynb +++ b/tutorials/03_Relational_Data_Modeling.ipynb @@ -4,7 +4,11 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Demo \"walmart\"\n", + "# Relational Data Modeling\n", + "\n", + "In this tutorial we will be showing how to model a real world multi-table dataset using SDV.\n", + "\n", + "## About the datset\n", "\n", "We have a store series, each of those have a size and a category and additional information in a given date: average temperature in the region, cost of fuel in the region, promotional data, the customer price index, the unemployment rate and whether the date is a special holiday.\n", "\n", @@ -91,7 +95,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": { "scrolled": true }, @@ -100,9 +104,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "2020-07-02 18:50:09,230 - INFO - metadata - Loading table stores\n", - "2020-07-02 18:50:09,237 - INFO - metadata - Loading table features\n", - "2020-07-02 18:50:09,255 - INFO - metadata - Loading table depts\n" + "2020-07-09 21:00:17,378 - INFO - __init__ - Loading table stores\n", + "2020-07-09 21:00:17,384 - INFO - __init__ - Loading table features\n", + "2020-07-09 21:00:17,402 - INFO - __init__ - Loading table depts\n" ] } ], @@ -128,7 +132,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -184,7 +188,7 @@ "\n", "stores->features\n", "\n", - "\n", + "\n", "   features.Store -> stores.Store\n", "\n", "\n", @@ -206,17 +210,17 @@ "\n", "stores->depts\n", "\n", - "\n", + "\n", "   depts.Store -> stores.Store\n", "\n", "\n", "\n" ], "text/plain": [ - "" + "" ] }, - "execution_count": 2, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -234,7 +238,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -252,7 +256,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": { "scrolled": false }, @@ -261,128 +265,130 @@ "name": "stderr", "output_type": "stream", "text": [ - "2020-07-02 18:50:09,639 - INFO - modeler - Modeling stores\n", - "2020-07-02 18:50:09,640 - INFO - metadata - Loading transformer CategoricalTransformer for field Type\n", - "2020-07-02 18:50:09,640 - INFO - metadata - Loading transformer NumericalTransformer for field Size\n", - "2020-07-02 18:50:09,653 - INFO - modeler - 
Modeling features\n", - "2020-07-02 18:50:09,653 - INFO - metadata - Loading transformer DatetimeTransformer for field Date\n", - "2020-07-02 18:50:09,654 - INFO - metadata - Loading transformer NumericalTransformer for field MarkDown1\n", - "2020-07-02 18:50:09,654 - INFO - metadata - Loading transformer BooleanTransformer for field IsHoliday\n", - "2020-07-02 18:50:09,654 - INFO - metadata - Loading transformer NumericalTransformer for field MarkDown4\n", - "2020-07-02 18:50:09,655 - INFO - metadata - Loading transformer NumericalTransformer for field MarkDown3\n", - "2020-07-02 18:50:09,656 - INFO - metadata - Loading transformer NumericalTransformer for field Fuel_Price\n", - "2020-07-02 18:50:09,656 - INFO - metadata - Loading transformer NumericalTransformer for field Unemployment\n", - "2020-07-02 18:50:09,657 - INFO - metadata - Loading transformer NumericalTransformer for field Temperature\n", - "2020-07-02 18:50:09,657 - INFO - metadata - Loading transformer NumericalTransformer for field MarkDown5\n", - "2020-07-02 18:50:09,657 - INFO - metadata - Loading transformer NumericalTransformer for field MarkDown2\n", - "2020-07-02 18:50:09,659 - INFO - metadata - Loading transformer NumericalTransformer for field CPI\n", - "2020-07-02 18:50:09,709 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:09,760 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:09,791 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:09,817 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:09,845 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:09,871 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:09,901 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:09,929 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:09,958 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:09,985 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,016 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,045 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,076 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,131 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,158 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,184 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,211 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,239 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,268 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,297 - INFO - gaussian - Fitting 
GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,325 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,352 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,378 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,410 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,440 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,467 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,494 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,519 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,547 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,572 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,602 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,628 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,656 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,686 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,713 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,740 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,768 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,799 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,827 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,858 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,886 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,912 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,939 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,967 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,996 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,028 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,058 - INFO - modeler - Modeling depts\n", - "2020-07-02 18:50:11,058 - INFO - metadata - Loading transformer DatetimeTransformer for field Date\n", - "2020-07-02 18:50:11,059 - INFO - metadata - Loading transformer NumericalTransformer for field Weekly_Sales\n", - "2020-07-02 18:50:11,059 - INFO - metadata - Loading transformer NumericalTransformer for field Dept\n", - "2020-07-02 18:50:11,059 - INFO - metadata - Loading 
transformer BooleanTransformer for field IsHoliday\n", - "2020-07-02 18:50:11,169 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,445 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,461 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,474 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,489 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,505 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,521 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,536 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,552 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,568 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,583 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,603 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n" + "2020-07-09 21:00:31,480 - INFO - modeler - Modeling stores\n", + "2020-07-09 21:00:31,481 - INFO - __init__ - Loading transformer CategoricalTransformer for field Type\n", + "2020-07-09 21:00:31,481 - INFO - __init__ - Loading transformer NumericalTransformer for field Size\n", + "2020-07-09 21:00:31,491 - INFO - modeler - Modeling features\n", + "2020-07-09 21:00:31,492 - INFO - __init__ - Loading transformer DatetimeTransformer for field Date\n", + "2020-07-09 21:00:31,493 - INFO - __init__ - Loading transformer NumericalTransformer for field MarkDown1\n", + "2020-07-09 21:00:31,493 - INFO - __init__ - Loading transformer BooleanTransformer for field IsHoliday\n", + "2020-07-09 21:00:31,493 - INFO - __init__ - Loading transformer NumericalTransformer for field MarkDown4\n", + "2020-07-09 21:00:31,494 - INFO - __init__ - Loading transformer NumericalTransformer for field MarkDown3\n", + "2020-07-09 21:00:31,494 - INFO - __init__ - Loading transformer NumericalTransformer for field Fuel_Price\n", + "2020-07-09 21:00:31,494 - INFO - __init__ - Loading transformer NumericalTransformer for field Unemployment\n", + "2020-07-09 21:00:31,495 - INFO - __init__ - Loading transformer NumericalTransformer for field Temperature\n", + "2020-07-09 21:00:31,495 - INFO - __init__ - Loading transformer NumericalTransformer for field MarkDown5\n", + "2020-07-09 21:00:31,495 - INFO - __init__ - Loading transformer NumericalTransformer for field MarkDown2\n", + "2020-07-09 21:00:31,495 - INFO - __init__ - Loading transformer NumericalTransformer for field CPI\n", + "2020-07-09 21:00:31,544 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,595 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "/home/xals/Projects/MIT/SDV/sdv/models/copulas.py:83: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. 
If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n", + " self.model.covariance = np.array(values)\n", + "2020-07-09 21:00:31,651 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,679 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,707 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,734 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,762 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,790 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,816 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,845 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,872 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,901 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,931 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,959 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,986 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,014 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,040 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,070 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,096 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,123 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,152 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,181 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,209 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,235 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,264 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,293 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,322 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,349 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,376 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,405 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,433 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 
21:00:32,463 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,492 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,521 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,552 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,583 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,612 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,644 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,674 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,704 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,732 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,762 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,791 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,821 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,852 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,882 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,917 - INFO - modeler - Modeling depts\n", + "2020-07-09 21:00:32,918 - INFO - __init__ - Loading transformer DatetimeTransformer for field Date\n", + "2020-07-09 21:00:32,918 - INFO - __init__ - Loading transformer NumericalTransformer for field Weekly_Sales\n", + "2020-07-09 21:00:32,919 - INFO - __init__ - Loading transformer NumericalTransformer for field Dept\n", + "2020-07-09 21:00:32,919 - INFO - __init__ - Loading transformer BooleanTransformer for field IsHoliday\n", + "2020-07-09 21:00:33,016 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,318 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,334 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,350 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,364 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,381 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,396 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,412 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,428 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2020-07-02 18:50:11,618 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,634 - INFO - gaussian - Fitting 
GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,652 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,667 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,683 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,698 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,713 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,728 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,741 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,754 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,769 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,782 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,800 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,814 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,829 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,844 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,859 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,874 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,887 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,900 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,914 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,928 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,940 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,955 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,968 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,979 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,991 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:12,003 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:12,017 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:12,032 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:12,045 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:12,058 - INFO - gaussian - Fitting 
GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:12,069 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:12,084 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:12,177 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/core/fromnumeric.py:56: FutureWarning: Series.nonzero() is deprecated and will be removed in a future version.Use Series.to_numpy().nonzero() instead\n", - " return getattr(obj, method)(*args, **kwds)\n", - "2020-07-02 18:50:12,380 - INFO - modeler - Modeling Complete\n" + "2020-07-09 21:00:33,447 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,464 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,479 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,495 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,511 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,526 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,541 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,554 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,567 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,582 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,596 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,609 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,624 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,639 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,652 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,667 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,682 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,696 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,709 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,724 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,740 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,753 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,766 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,779 - INFO - gaussian - Fitting 
GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,792 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,803 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,818 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,831 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,842 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,856 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,870 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,885 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,898 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,912 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,926 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,937 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,949 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:34,047 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/core/fromnumeric.py:58: FutureWarning: Series.nonzero() is deprecated and will be removed in a future version.Use Series.to_numpy().nonzero() instead\n", + " return bound(*args, **kwds)\n", + "2020-07-09 21:00:34,259 - INFO - modeler - Modeling Complete\n" ] } ], @@ -413,7 +419,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -422,7 +428,7 @@ "{'stores': 45, 'features': 8190, 'depts': 421570}" ] }, - "execution_count": 5, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -433,7 +439,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 7, "metadata": { "scrolled": false }, @@ -451,7 +457,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -484,32 +490,32 @@ " \n", " 0\n", " B\n", - " 85496\n", - " 3\n", + " 106106\n", + " 0\n", " \n", " \n", " 1\n", - " A\n", - " 178862\n", - " 4\n", + " B\n", + " 68071\n", + " 1\n", " \n", " \n", " 2\n", - " B\n", - " 69654\n", - " 5\n", + " C\n", + " 253603\n", + " 2\n", " \n", " \n", " 3\n", " A\n", - " 211981\n", - " 6\n", + " 223268\n", + " 3\n", " \n", " \n", " 4\n", - " A\n", - " 131188\n", - " 7\n", + " B\n", + " 166848\n", + " 4\n", " \n", " \n", "\n", @@ -517,14 +523,14 @@ ], "text/plain": [ " Type Size Store\n", - "0 B 85496 3\n", - "1 A 178862 4\n", - "2 B 69654 5\n", - "3 A 211981 6\n", - "4 A 131188 7" + "0 B 106106 0\n", + "1 B 68071 1\n", + "2 C 253603 2\n", + "3 A 223268 3\n", + "4 B 166848 4" ] }, - "execution_count": 10, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -535,7 +541,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -576,78 +582,78 
@@ " \n", " \n", " 0\n", - " 2012-04-23 03:09:52.638271488\n", - " NaN\n", - " 3\n", + " 2009-12-13 04:59:49.832576000\n", + " 222.268822\n", + " 0\n", " False\n", + " 758.900545\n", + " -1080.569435\n", + " 4.574966\n", + " 6.052795\n", + " 51.354606\n", + " -4016.012141\n", " NaN\n", - " -9803.940199\n", - " 3.561375\n", - " 8.838728\n", - " 67.162475\n", - " NaN\n", - " 2703.17729\n", - " 186.471991\n", + " 224.973722\n", " \n", " \n", " 1\n", - " 2011-04-19 08:45:12.429521664\n", - " 483.892955\n", - " 3\n", + " 2011-07-20 14:40:40.400244736\n", + " 6857.712028\n", + " 0\n", " False\n", - " 7504.524416\n", " NaN\n", - " 3.495118\n", - " 7.360667\n", - " 42.785730\n", - " 2772.597105\n", + " -551.045257\n", + " 3.018296\n", + " 6.720669\n", + " 80.723862\n", + " NaN\n", " NaN\n", - " 192.048268\n", + " 204.595072\n", " \n", " \n", " 2\n", - " 2011-01-30 22:13:59.841415680\n", - " NaN\n", - " 3\n", + " 2011-10-12 08:23:54.987658240\n", + " 5530.929425\n", + " 0\n", " False\n", " NaN\n", " NaN\n", - " 3.361946\n", - " 7.524812\n", - " 34.945770\n", + " 3.466358\n", + " 7.256570\n", + " 73.954938\n", " NaN\n", " NaN\n", - " 192.100673\n", + " 202.591941\n", " \n", " \n", " 3\n", - " 2011-10-08 08:16:00.977235968\n", - " 4661.175670\n", - " 3\n", + " 2012-11-26 09:22:59.377291776\n", + " NaN\n", + " 0\n", " False\n", - " 3392.028528\n", - " -5837.649003\n", - " 2.994273\n", - " 7.993152\n", - " 66.818180\n", " NaN\n", + " 3703.137191\n", + " 3.623725\n", + " 6.598615\n", + " 72.041454\n", " NaN\n", - " 197.921916\n", + " 3925.456611\n", + " 169.457078\n", " \n", " \n", " 4\n", - " 2011-09-29 23:18:43.912751616\n", + " 2013-05-13 15:22:12.213717760\n", " NaN\n", - " 3\n", + " 0\n", " False\n", + " 3773.942572\n", " NaN\n", + " 3.381795\n", + " 5.115002\n", + " 83.058322\n", + " 15018.944605\n", " NaN\n", - " 3.116659\n", - " 4.945393\n", - " 48.979036\n", - " NaN\n", - " 2957.77612\n", - " 192.677797\n", + " 182.801974\n", " \n", " \n", "\n", @@ -655,28 +661,28 @@ ], "text/plain": [ " Date MarkDown1 Store IsHoliday MarkDown4 \\\n", - "0 2012-04-23 03:09:52.638271488 NaN 3 False NaN \n", - "1 2011-04-19 08:45:12.429521664 483.892955 3 False 7504.524416 \n", - "2 2011-01-30 22:13:59.841415680 NaN 3 False NaN \n", - "3 2011-10-08 08:16:00.977235968 4661.175670 3 False 3392.028528 \n", - "4 2011-09-29 23:18:43.912751616 NaN 3 False NaN \n", + "0 2009-12-13 04:59:49.832576000 222.268822 0 False 758.900545 \n", + "1 2011-07-20 14:40:40.400244736 6857.712028 0 False NaN \n", + "2 2011-10-12 08:23:54.987658240 5530.929425 0 False NaN \n", + "3 2012-11-26 09:22:59.377291776 NaN 0 False NaN \n", + "4 2013-05-13 15:22:12.213717760 NaN 0 False 3773.942572 \n", "\n", - " MarkDown3 Fuel_Price Unemployment Temperature MarkDown5 \\\n", - "0 -9803.940199 3.561375 8.838728 67.162475 NaN \n", - "1 NaN 3.495118 7.360667 42.785730 2772.597105 \n", - "2 NaN 3.361946 7.524812 34.945770 NaN \n", - "3 -5837.649003 2.994273 7.993152 66.818180 NaN \n", - "4 NaN 3.116659 4.945393 48.979036 NaN \n", + " MarkDown3 Fuel_Price Unemployment Temperature MarkDown5 \\\n", + "0 -1080.569435 4.574966 6.052795 51.354606 -4016.012141 \n", + "1 -551.045257 3.018296 6.720669 80.723862 NaN \n", + "2 NaN 3.466358 7.256570 73.954938 NaN \n", + "3 3703.137191 3.623725 6.598615 72.041454 NaN \n", + "4 NaN 3.381795 5.115002 83.058322 15018.944605 \n", "\n", - " MarkDown2 CPI \n", - "0 2703.17729 186.471991 \n", - "1 NaN 192.048268 \n", - "2 NaN 192.100673 \n", - "3 NaN 197.921916 \n", - "4 2957.77612 192.677797 " + " MarkDown2 CPI \n", + 
"0 NaN 224.973722 \n", + "1 NaN 204.595072 \n", + "2 NaN 202.591941 \n", + "3 3925.456611 169.457078 \n", + "4 NaN 182.801974 " ] }, - "execution_count": 11, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -687,7 +693,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -721,42 +727,42 @@ " \n", " \n", " 0\n", - " 2011-04-22 15:24:18.057608704\n", - " 11196.989134\n", - " 3\n", - " 38\n", + " 2012-04-19 04:40:04.370382848\n", + " 7523.202470\n", + " 0\n", + " 36\n", " False\n", " \n", " \n", " 1\n", - " 2010-02-06 19:21:02.431538688\n", - " -14038.529503\n", - " 3\n", - " 29\n", + " 2012-10-24 12:49:45.630176512\n", + " 13675.425498\n", + " 0\n", + " 17\n", " False\n", " \n", " \n", " 2\n", - " 2012-06-04 16:23:09.227934976\n", - " -6519.738485\n", - " 3\n", - " 46\n", + " 2012-06-05 05:02:35.551761408\n", + " 2402.017324\n", + " 0\n", + " 31\n", " False\n", " \n", " \n", " 3\n", - " 2011-08-09 17:18:54.910250752\n", - " 23194.918038\n", - " 3\n", - " 45\n", + " 2012-05-31 22:15:14.903856896\n", + " -12761.073176\n", + " 0\n", + " 81\n", " False\n", " \n", " \n", " 4\n", - " 2010-09-01 23:10:54.986872576\n", - " 16761.426407\n", - " 3\n", - " 29\n", + " 2011-09-07 15:34:43.239835648\n", + " 3642.817612\n", + " 0\n", + " 58\n", " False\n", " \n", " \n", @@ -765,14 +771,14 @@ ], "text/plain": [ " Date Weekly_Sales Store Dept IsHoliday\n", - "0 2011-04-22 15:24:18.057608704 11196.989134 3 38 False\n", - "1 2010-02-06 19:21:02.431538688 -14038.529503 3 29 False\n", - "2 2012-06-04 16:23:09.227934976 -6519.738485 3 46 False\n", - "3 2011-08-09 17:18:54.910250752 23194.918038 3 45 False\n", - "4 2010-09-01 23:10:54.986872576 16761.426407 3 29 False" + "0 2012-04-19 04:40:04.370382848 7523.202470 0 36 False\n", + "1 2012-10-24 12:49:45.630176512 13675.425498 0 17 False\n", + "2 2012-06-05 05:02:35.551761408 2402.017324 0 31 False\n", + "3 2012-05-31 22:15:14.903856896 -12761.073176 0 81 False\n", + "4 2011-09-07 15:34:43.239835648 3642.817612 0 58 False" ] }, - "execution_count": 12, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -792,7 +798,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 11, "metadata": { "scrolled": false }, @@ -801,198 +807,198 @@ "data": { "text/plain": [ "{'features': Date MarkDown1 Store IsHoliday \\\n", - " 0 2012-03-09 19:29:04.121987584 4547.265066 48 False \n", - " 1 2012-01-12 12:05:27.772331264 NaN 51 False \n", - " 2 2012-10-09 23:52:58.454752256 13291.983302 51 False \n", - " 3 2011-07-26 23:58:33.986580992 NaN 49 False \n", - " 4 2012-01-17 02:27:51.821400832 6865.756770 50 False \n", - " 5 2009-12-27 02:32:16.625306880 NaN 51 False \n", - " 6 2011-07-12 23:01:03.955540224 NaN 48 False \n", - " 7 2011-09-25 10:16:19.542280704 7114.339579 48 False \n", - " 8 2011-05-06 02:00:24.919497728 NaN 48 False \n", - " 9 2013-10-23 10:17:59.945650432 2267.991754 50 False \n", - " 10 2011-05-18 08:31:22.419635968 NaN 48 False \n", - " 11 2010-07-27 18:16:54.600127744 NaN 49 False \n", - " 12 2010-11-04 00:37:47.996870656 NaN 51 False \n", - " 13 2012-03-23 15:04:51.172607488 10192.668685 48 False \n", - " 14 2007-11-14 22:28:51.711726848 NaN 48 False \n", - " 15 2013-02-07 20:55:46.684175872 9146.998153 49 False \n", - " 16 2011-07-18 00:03:42.998976512 NaN 50 False \n", - " 17 2010-10-02 21:34:09.917561088 NaN 49 False \n", - " 18 2010-03-12 18:25:31.509475328 NaN 51 False \n", - " 19 2013-06-24 02:50:15.027540736 2943.979267 
50 False \n", - " 20 2010-09-24 15:42:39.457930752 NaN 48 False \n", - " 21 2012-03-27 19:00:12.877346816 7750.605869 51 False \n", - " 22 2010-05-26 23:28:27.091251968 NaN 48 False \n", - " 23 2011-05-01 23:07:37.680983296 NaN 52 False \n", - " 24 2012-10-07 16:56:39.076689152 11838.898938 51 False \n", - " 25 2011-02-19 21:38:01.903840000 NaN 52 False \n", - " 26 2012-01-18 06:02:24.764590592 7428.686729 52 False \n", - " 27 2012-09-28 01:36:19.603453184 4962.148181 52 False \n", - " 28 2011-07-28 19:21:49.952510208 12773.901269 48 False \n", - " 29 2011-09-04 16:53:45.561477888 NaN 50 False \n", + " 0 2010-07-07 18:33:30.490365440 NaN 46 True \n", + " 1 2010-05-31 05:12:43.555245568 NaN 49 False \n", + " 2 2009-11-13 00:11:26.645775872 NaN 46 False \n", + " 3 2011-06-26 13:37:23.832223232 8480.251096 48 False \n", + " 4 2013-09-30 06:14:18.804104192 4586.374804 48 False \n", + " 5 2011-03-18 21:50:55.521650432 NaN 45 False \n", + " 6 2012-03-22 14:59:09.189622016 9747.212668 46 False \n", + " 7 2011-04-05 16:45:11.282816000 NaN 48 False \n", + " 8 2013-05-05 14:23:57.431958784 5430.557501 46 False \n", + " 9 2011-04-07 02:38:10.873993984 6735.454863 49 False \n", + " 10 2012-07-09 15:59:11.863023616 11818.737600 48 False \n", + " 11 2012-03-10 03:57:22.968428544 4603.635257 48 False \n", + " 12 2011-10-01 03:51:05.280273920 14596.976258 49 False \n", + " 13 2012-06-24 23:45:06.501742080 5816.163902 49 False \n", + " 14 2011-08-06 14:43:48.366935040 NaN 47 False \n", + " 15 2012-06-14 01:06:45.771771648 NaN 45 False \n", + " 16 2011-11-22 08:57:27.766619392 NaN 45 False \n", + " 17 2012-10-06 19:10:00.039407360 5887.384577 48 False \n", + " 18 2013-06-08 07:46:59.612567808 -5770.051172 48 False \n", + " 19 2012-05-27 11:08:47.184699648 6908.122060 46 False \n", + " 20 2010-10-19 21:42:09.919441408 NaN 46 False \n", + " 21 2013-01-01 17:57:15.273558784 6221.151492 48 False \n", + " 22 2012-07-13 14:40:51.181466112 1378.006687 47 False \n", + " 23 2012-02-13 18:37:30.161656320 NaN 48 False \n", + " 24 2010-11-27 09:56:34.540603392 NaN 46 False \n", + " 25 2011-05-21 03:49:37.053693440 NaN 49 False \n", + " 26 2010-06-12 20:31:32.339124224 NaN 49 True \n", + " 27 2011-08-12 21:10:37.406385152 NaN 47 False \n", + " 28 2012-02-07 03:31:37.558127360 NaN 47 False \n", + " 29 2011-05-28 22:39:00.441712128 NaN 48 False \n", " .. ... ... ... ... 
\n", - " 970 2011-06-14 18:48:47.646437632 7563.489728 48 False \n", - " 971 2011-05-11 09:27:03.321238016 3763.579757 50 False \n", - " 972 2012-10-27 08:44:16.987041024 -2107.559543 49 False \n", - " 973 2011-03-26 17:36:28.323915264 NaN 48 False \n", - " 974 2012-03-30 09:18:45.289797888 NaN 52 False \n", - " 975 2012-12-25 00:57:24.386340608 3483.807111 51 False \n", - " 976 2010-09-17 18:59:15.931293440 NaN 52 False \n", - " 977 2012-04-18 06:40:20.102041344 2869.173605 50 False \n", - " 978 2010-01-18 10:21:26.293353216 NaN 49 False \n", - " 979 2010-08-13 00:08:32.324416256 NaN 50 False \n", - " 980 2012-01-17 04:04:26.614297088 NaN 50 False \n", - " 981 2011-06-23 01:40:48.209999104 NaN 48 False \n", - " 982 2012-05-15 04:01:02.762140160 3935.159830 50 False \n", - " 983 2011-03-03 10:20:57.530126848 5704.369183 52 False \n", - " 984 2011-09-11 11:54:13.537061888 -3939.448245 49 False \n", - " 985 2012-07-01 08:38:12.122587648 NaN 49 False \n", - " 986 2013-01-15 14:24:51.253277952 7802.046675 52 False \n", - " 987 2011-03-16 12:56:22.305051648 NaN 49 False \n", - " 988 2012-02-09 21:21:51.383019776 3563.474259 50 False \n", - " 989 2011-04-28 07:35:57.312682752 NaN 49 False \n", - " 990 2012-07-06 11:32:42.706951424 4134.684319 48 False \n", - " 991 2012-09-10 11:27:22.375229184 2670.997722 48 False \n", - " 992 2010-04-19 00:59:34.282525696 NaN 51 False \n", - " 993 2012-05-21 02:19:56.773740288 725.161402 51 False \n", - " 994 2012-08-25 02:46:15.815232768 2736.817373 52 False \n", - " 995 2011-11-27 15:29:59.302161664 NaN 48 False \n", - " 996 2012-11-19 23:22:47.651802880 NaN 49 False \n", - " 997 2012-03-06 18:23:43.093342208 NaN 51 True \n", - " 998 2011-11-17 03:00:20.669487872 NaN 50 False \n", - " 999 2010-10-18 08:04:04.443984128 NaN 51 False \n", + " 970 2012-10-01 09:27:33.224137984 7367.091411 49 False \n", + " 971 2011-08-05 19:20:26.151251200 NaN 49 False \n", + " 972 2010-05-02 10:25:48.294526720 NaN 49 False \n", + " 973 2014-07-30 06:33:25.027884800 2869.286369 46 False \n", + " 974 2012-06-10 23:50:28.918937344 NaN 45 False \n", + " 975 2011-11-20 17:24:25.076207104 NaN 48 False \n", + " 976 2011-07-16 19:31:09.475053312 8465.620591 48 False \n", + " 977 2010-12-17 11:17:29.377529600 NaN 45 False \n", + " 978 2011-05-01 22:00:24.413698816 NaN 48 True \n", + " 979 2009-01-04 14:06:28.700405504 NaN 48 False \n", + " 980 2013-01-26 23:44:47.124111360 7711.715123 47 False \n", + " 981 2012-02-05 15:58:21.306519040 4771.829938 49 False \n", + " 982 2012-12-25 20:25:33.794860544 9052.115357 45 False \n", + " 983 2011-03-11 05:15:33.422215424 NaN 48 False \n", + " 984 2012-04-10 15:14:42.299491072 11469.698722 49 False \n", + " 985 2012-09-04 11:56:04.719035648 11985.730540 47 False \n", + " 986 2010-08-12 06:13:43.472920320 NaN 47 False \n", + " 987 2011-11-28 00:46:09.219015936 NaN 47 False \n", + " 988 2010-11-01 19:46:13.442969344 NaN 47 False \n", + " 989 2012-09-03 02:46:56.811711232 13448.706455 46 False \n", + " 990 2012-06-21 20:15:00.600609536 8515.644049 45 False \n", + " 991 2009-11-09 11:32:20.317559808 NaN 47 False \n", + " 992 2012-12-24 16:57:40.573141760 8785.735128 48 False \n", + " 993 2010-10-26 22:04:15.916646656 NaN 45 False \n", + " 994 2013-07-27 21:48:37.233461248 2218.444188 46 False \n", + " 995 2012-11-22 15:48:58.070053120 NaN 47 False \n", + " 996 2012-11-19 17:12:55.679613440 11318.264204 45 False \n", + " 997 2013-10-11 21:09:49.021492480 478.609430 47 False \n", + " 998 2013-01-23 23:17:20.493887232 14227.944357 49 False \n", + " 999 
2012-06-23 00:01:59.594749696 9121.193148 46 False \n", " \n", " MarkDown4 MarkDown3 Fuel_Price Unemployment Temperature \\\n", - " 0 2262.924323 4314.258484 2.944899 12.207513 44.866684 \n", - " 1 648.327530 NaN 3.415876 8.296209 79.755000 \n", - " 2 7464.648400 4818.694540 3.438865 10.375869 53.011465 \n", - " 3 7722.936675 NaN 3.238451 8.340640 81.331769 \n", - " 4 385.720439 2933.746873 3.352123 4.517977 33.216711 \n", - " 5 NaN NaN 2.595952 8.423675 58.153432 \n", - " 6 NaN NaN 3.337125 7.289977 68.198691 \n", - " 7 1376.144152 -1375.405701 3.586827 4.433586 57.457667 \n", - " 8 NaN 885.964999 3.132711 7.941049 82.624443 \n", - " 9 4380.021000 4861.726545 4.617105 7.636990 64.616449 \n", - " 10 NaN NaN 3.490697 9.699187 113.890285 \n", - " 11 NaN NaN 3.337917 10.428851 76.354063 \n", - " 12 NaN NaN 3.016818 10.044523 52.925513 \n", - " 13 NaN 1505.002284 3.244111 7.329469 34.127333 \n", - " 14 NaN NaN 2.497412 12.247448 100.350997 \n", - " 15 7982.863954 7539.825400 3.909380 5.689230 65.945165 \n", - " 16 NaN NaN 2.895094 7.718521 54.750702 \n", - " 17 NaN NaN 3.722465 7.339952 56.152632 \n", - " 18 NaN NaN 2.665366 8.045471 78.257558 \n", - " 19 -3310.004199 -808.378062 4.379650 5.320308 39.604265 \n", - " 20 NaN 4604.622951 3.058964 8.584222 44.177016 \n", - " 21 3790.217545 3053.642374 3.091990 5.069122 25.777591 \n", - " 22 NaN NaN 2.931042 9.062068 41.649615 \n", - " 23 NaN NaN 3.721348 6.586825 67.785117 \n", - " 24 4610.034450 -617.120861 4.010528 8.686732 35.445785 \n", - " 25 NaN NaN 3.249523 8.450699 38.639637 \n", - " 26 4005.965014 -28.792693 3.808513 9.188265 57.584018 \n", - " 27 6433.867544 -3996.702966 2.976900 8.991618 6.179927 \n", - " 28 NaN NaN 3.208463 11.875517 50.400717 \n", - " 29 NaN NaN 3.455848 7.377382 56.480212 \n", + " 0 NaN NaN 3.246268 6.712596 72.778272 \n", + " 1 NaN NaN 2.961909 8.775273 31.911285 \n", + " 2 NaN NaN 2.474694 8.693548 36.996590 \n", + " 3 2990.430314 450.886030 3.180972 7.482533 82.689207 \n", + " 4 1300.901307 1038.142954 3.979948 3.842990 61.626723 \n", + " 5 NaN NaN 3.190374 9.185961 43.775213 \n", + " 6 5068.165070 4033.503158 3.436991 10.002234 44.980268 \n", + " 7 2639.207053 NaN 3.126477 8.921070 60.901225 \n", + " 8 -1515.448599 1859.931318 4.101750 10.478649 81.003354 \n", + " 9 NaN -4013.843645 3.200662 8.549170 46.332113 \n", + " 10 6983.561124 9246.461293 3.637923 6.940256 66.208533 \n", + " 11 NaN 1872.845519 3.535785 5.360209 76.842613 \n", + " 12 5967.615987 -2084.078449 3.314418 NaN 73.216988 \n", + " 13 4181.309429 -7382.929438 3.170917 7.209879 27.887139 \n", + " 14 NaN NaN 4.084598 10.976241 75.512964 \n", + " 15 6522.458652 -2969.500556 3.592465 6.778646 45.304799 \n", + " 16 NaN NaN 3.691225 9.169614 61.590909 \n", + " 17 2195.064778 10031.652043 3.790657 5.655427 38.507717 \n", + " 18 -2522.448316 4530.988378 3.349566 4.743290 36.629555 \n", + " 19 NaN 3416.997827 3.641454 8.177166 89.916452 \n", + " 20 NaN NaN 3.075063 5.229354 61.706563 \n", + " 21 4647.376520 3247.164388 4.162684 5.033966 37.962542 \n", + " 22 NaN NaN 3.384827 4.921576 54.039707 \n", + " 23 NaN 7039.242263 3.187745 9.768689 74.453345 \n", + " 24 NaN NaN 3.045024 6.954771 63.069450 \n", + " 25 NaN NaN 4.164028 8.876290 71.070532 \n", + " 26 NaN NaN 2.719009 7.230035 57.364249 \n", + " 27 NaN NaN 3.257133 9.345917 68.652152 \n", + " 28 NaN -4190.690106 3.548902 9.494165 89.145578 \n", + " 29 NaN NaN 3.316924 6.697705 77.375684 \n", " .. ... ... ... ... ... 
\n", - " 970 7130.653517 613.434417 3.798558 8.849840 45.081481 \n", - " 971 -950.449794 8369.741392 3.691277 8.687173 21.625873 \n", - " 972 1423.082811 3915.095969 3.716980 7.812599 38.891102 \n", - " 973 NaN NaN 3.277952 7.881991 74.018427 \n", - " 974 NaN NaN 3.444190 4.147388 79.128181 \n", - " 975 5645.978692 3893.099917 3.821702 8.410268 87.162204 \n", - " 976 NaN NaN 2.640021 7.458197 35.588448 \n", - " 977 -2627.096894 -142.014936 3.320107 6.970687 49.700872 \n", - " 978 NaN NaN 2.710125 9.927722 83.348488 \n", - " 979 NaN NaN 3.139506 8.348364 46.708000 \n", - " 980 NaN NaN 3.578857 7.555184 25.595379 \n", - " 981 NaN -1147.726465 3.685507 9.407928 91.745352 \n", - " 982 375.025287 11130.134824 3.467141 6.285321 38.179293 \n", - " 983 NaN NaN 2.675670 10.258619 50.599203 \n", - " 984 NaN NaN 3.375990 9.175979 44.036130 \n", - " 985 NaN 649.171815 3.370836 6.948002 22.863602 \n", - " 986 1944.149217 896.663662 3.971550 6.242637 97.107161 \n", - " 987 NaN NaN 3.524965 7.059279 65.469921 \n", - " 988 2550.079436 121.672409 2.898533 5.929999 58.612442 \n", - " 989 NaN NaN 3.237647 6.992924 72.225708 \n", - " 990 NaN NaN 4.075160 8.392256 80.837964 \n", - " 991 -1543.390402 5264.766267 3.157407 7.017727 81.515109 \n", - " 992 NaN NaN 3.293817 8.719239 65.735676 \n", - " 993 932.568195 1845.259380 3.157982 7.494639 65.103553 \n", - " 994 -2905.284812 -1268.636611 3.776062 6.736775 46.083801 \n", - " 995 NaN NaN 3.424730 NaN 79.231321 \n", - " 996 9027.387381 NaN 3.960882 8.911136 26.948635 \n", - " 997 NaN NaN 3.764028 8.049413 111.723841 \n", - " 998 NaN 6261.182193 3.609560 6.024725 38.411214 \n", - " 999 NaN NaN 3.051021 7.855145 75.446009 \n", + " 970 2712.006496 -14.780688 3.618468 8.148930 30.483065 \n", + " 971 NaN NaN 3.012630 6.326113 47.975596 \n", + " 972 NaN NaN 2.841184 10.269877 81.946843 \n", + " 973 1937.104665 -1511.898278 4.759412 7.778287 30.662510 \n", + " 974 5253.689129 329.011494 3.343722 7.398721 21.129117 \n", + " 975 NaN 4930.780630 3.506178 9.701808 110.523151 \n", + " 976 5293.404197 -3121.976568 2.800094 5.843644 32.442600 \n", + " 977 9805.271519 NaN 3.136993 7.623560 72.572211 \n", + " 978 NaN NaN 2.714601 7.234621 52.291714 \n", + " 979 NaN NaN 2.043747 8.744450 46.628272 \n", + " 980 5415.328424 -4809.140910 3.287617 8.354515 61.384200 \n", + " 981 NaN NaN 3.337297 6.142083 36.829644 \n", + " 982 837.730915 4072.518501 3.894026 12.494251 94.673851 \n", + " 983 NaN NaN 3.420209 7.090230 47.694700 \n", + " 984 5134.594286 357.279124 3.325391 4.539604 37.460031 \n", + " 985 NaN 4056.157461 3.685698 7.230368 43.096815 \n", + " 986 3953.583585 805.676045 3.044291 7.884751 34.438372 \n", + " 987 4511.412982 NaN 3.394697 6.834802 36.636557 \n", + " 988 NaN NaN 3.154329 8.363893 49.751330 \n", + " 989 NaN 2084.191209 3.615679 9.596049 75.533229 \n", + " 990 4457.772828 218.985398 3.440906 8.526339 61.864859 \n", + " 991 NaN NaN 3.247303 7.423321 59.019784 \n", + " 992 NaN NaN 3.458168 6.459322 48.584436 \n", + " 993 NaN NaN 2.608670 10.259658 82.758016 \n", + " 994 719.333117 6631.998410 4.368128 7.588005 54.580465 \n", + " 995 3003.501922 NaN 3.800082 NaN 100.143258 \n", + " 996 NaN -681.708699 4.140327 10.717719 46.213930 \n", + " 997 735.529994 3856.959812 3.921162 7.439993 59.473702 \n", + " 998 2557.407176 2865.135441 3.971143 9.540309 67.952407 \n", + " 999 671.086509 2328.798014 3.487042 5.988222 63.846379 \n", " \n", - " MarkDown5 MarkDown2 CPI \n", - " 0 7218.910186 NaN 154.437715 \n", - " 1 4928.572337 1860.474257 126.006100 \n", - " 2 4381.712904 
6491.080062 182.505905 \n", - " 3 NaN 2075.661361 170.252793 \n", - " 4 -1363.927516 6383.011111 181.850615 \n", - " 5 NaN NaN 250.558631 \n", - " 6 NaN NaN 213.846371 \n", - " 7 1644.489130 4966.726612 189.105468 \n", - " 8 NaN 4689.948004 152.931090 \n", - " 9 3291.191563 2344.032770 152.780297 \n", - " 10 NaN NaN 195.602913 \n", - " 11 NaN NaN 176.381561 \n", - " 12 NaN NaN 147.881741 \n", - " 13 1668.882903 NaN 155.004680 \n", - " 14 NaN NaN 193.442985 \n", - " 15 1437.517079 4874.558939 152.281926 \n", - " 16 8439.806431 NaN 130.530240 \n", - " 17 NaN NaN 201.486725 \n", - " 18 NaN NaN 188.760741 \n", - " 19 -6907.068884 3246.451314 108.120768 \n", - " 20 NaN NaN 148.788886 \n", - " 21 5724.976324 4496.854628 190.036218 \n", - " 22 NaN NaN 181.746008 \n", - " 23 NaN NaN 200.412670 \n", - " 24 6637.078664 -1388.022496 87.164451 \n", - " 25 NaN NaN 195.622231 \n", - " 26 2922.749710 NaN 224.233432 \n", - " 27 -3042.801607 NaN 131.909287 \n", - " 28 7063.102432 NaN 146.156666 \n", - " 29 NaN NaN 176.694595 \n", - " .. ... ... ... \n", - " 970 2882.447318 NaN 148.323409 \n", - " 971 1918.923930 NaN 141.723698 \n", - " 972 151.136535 1221.530133 155.532651 \n", - " 973 NaN NaN 213.405340 \n", - " 974 NaN NaN 186.636286 \n", - " 975 2598.731935 -846.566855 190.957987 \n", - " 976 NaN NaN 153.917260 \n", - " 977 8372.618269 1768.408869 249.068471 \n", - " 978 NaN NaN 177.318454 \n", - " 979 NaN NaN 208.719040 \n", - " 980 NaN 754.021358 108.580497 \n", - " 981 NaN NaN 159.526588 \n", - " 982 7349.003146 -732.495611 173.559001 \n", - " 983 NaN NaN 140.910196 \n", - " 984 401.319810 NaN 175.340840 \n", - " 985 NaN NaN 171.683362 \n", - " 986 -465.292159 -744.114065 177.553651 \n", - " 987 NaN NaN 144.344477 \n", - " 988 567.400188 NaN 213.833680 \n", - " 989 NaN NaN 157.113353 \n", - " 990 NaN -232.801449 132.592027 \n", - " 991 -72.622621 NaN 207.754643 \n", - " 992 NaN NaN 132.575018 \n", - " 993 4955.778653 4683.646548 160.260401 \n", - " 994 2898.089751 NaN 129.801694 \n", - " 995 NaN NaN NaN \n", - " 996 NaN 608.026190 115.452177 \n", - " 997 NaN NaN 206.323679 \n", - " 998 NaN 7020.489423 178.966240 \n", - " 999 NaN NaN 152.636567 \n", + " MarkDown5 MarkDown2 CPI \n", + " 0 NaN NaN 172.666543 \n", + " 1 NaN NaN 119.214662 \n", + " 2 NaN NaN 143.990866 \n", + " 3 4354.651570 5019.468642 180.631779 \n", + " 4 -4772.570563 2947.674138 173.813790 \n", + " 5 NaN NaN 177.052581 \n", + " 6 7083.723804 3513.601560 185.921725 \n", + " 7 NaN -180.888893 153.966086 \n", + " 8 6265.833733 2401.076738 170.559516 \n", + " 9 3759.801559 4729.764492 200.039874 \n", + " 10 6567.512481 -1022.527044 150.002644 \n", + " 11 7765.142655 NaN 221.169299 \n", + " 12 NaN 3308.681495 NaN \n", + " 13 2826.442468 8385.400703 150.777231 \n", + " 14 NaN NaN 211.070762 \n", + " 15 6642.564583 NaN 154.828636 \n", + " 16 NaN NaN 138.236111 \n", + " 17 4118.446321 NaN 138.853628 \n", + " 18 1115.215425 -1526.587205 219.621502 \n", + " 19 4994.790024 NaN 197.086249 \n", + " 20 NaN NaN 143.384357 \n", + " 21 9934.092076 6538.282816 208.758478 \n", + " 22 396.251330 1170.864752 217.807128 \n", + " 23 1492.226506 NaN 144.261943 \n", + " 24 NaN NaN 198.046465 \n", + " 25 NaN NaN 152.795139 \n", + " 26 NaN NaN 192.354502 \n", + " 27 1804.646340 NaN 213.883682 \n", + " 28 NaN NaN 237.955017 \n", + " 29 NaN NaN 140.091500 \n", + " .. ... ... ... 
\n", + " 970 -208.184350 4224.751927 149.278618 \n", + " 971 NaN 3814.404768 167.707496 \n", + " 972 NaN NaN 115.688179 \n", + " 973 1199.031857 3033.592607 120.703360 \n", + " 974 NaN -351.113110 106.143265 \n", + " 975 NaN NaN 211.109816 \n", + " 976 3975.213749 NaN 152.402312 \n", + " 977 NaN NaN 121.333484 \n", + " 978 NaN 3607.788296 176.680147 \n", + " 979 NaN NaN 164.924902 \n", + " 980 5431.606087 7005.735955 208.516910 \n", + " 981 -285.659902 NaN 146.865904 \n", + " 982 5584.625275 2016.204322 165.913603 \n", + " 983 NaN NaN 136.591286 \n", + " 984 2929.718397 11730.460834 153.875137 \n", + " 985 4618.823214 1218.336997 191.245195 \n", + " 986 NaN 128.922827 209.229013 \n", + " 987 7757.467609 NaN 201.908861 \n", + " 988 NaN NaN 210.724201 \n", + " 989 5115.758017 1620.240724 149.542233 \n", + " 990 1789.565202 1350.171000 193.810848 \n", + " 991 NaN NaN 147.307937 \n", + " 992 -105.978005 NaN 229.688010 \n", + " 993 NaN NaN 217.839779 \n", + " 994 -907.055933 3253.948843 128.459172 \n", + " 995 NaN NaN NaN \n", + " 996 3043.948542 NaN 114.508859 \n", + " 997 1459.810600 -4003.468566 193.149902 \n", + " 998 12151.265277 -3099.180177 150.590590 \n", + " 999 2792.671569 NaN 201.355736 \n", " \n", " [1000 rows x 12 columns]}" ] }, - "execution_count": 13, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } diff --git a/examples/5. Generate Metadata from Dataframes.ipynb b/tutorials/04_Working_with_Metadata.ipynb similarity index 100% rename from examples/5. Generate Metadata from Dataframes.ipynb rename to tutorials/04_Working_with_Metadata.ipynb From 5cacaf6f4e8fcc29fe2ff8c1020028fada3eef2a Mon Sep 17 00:00:00 2001 From: Carles Sala Date: Thu, 9 Jul 2020 21:11:16 +0200 Subject: [PATCH 26/33] Remove unsupported constraints argument --- sdv/tabular/__init__.py | 1 - sdv/tabular/base.py | 10 ++-------- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/sdv/tabular/__init__.py b/sdv/tabular/__init__.py index 752184c61..dae118d1d 100644 --- a/sdv/tabular/__init__.py +++ b/sdv/tabular/__init__.py @@ -1,7 +1,6 @@ from sdv.tabular.copulas import GaussianCopula from sdv.tabular.ctgan import CTGAN - __all__ = [ 'GaussianCopula', 'CTGAN' diff --git a/sdv/tabular/base.py b/sdv/tabular/base.py index 70b5aeb7a..f21be7123 100644 --- a/sdv/tabular/base.py +++ b/sdv/tabular/base.py @@ -15,8 +15,7 @@ class BaseTabularModel(): _metadata = None def __init__(self, field_names=None, primary_key=None, field_types=None, - anonymize_fields=None, constraints=None, table_metadata=None, - *args, **kwargs): + anonymize_fields=None, table_metadata=None, *args, **kwargs): """Initialize a Tabular Model. Args: @@ -38,9 +37,6 @@ def __init__(self, field_names=None, primary_key=None, field_types=None, anonymize_fields (dict[str, str]): Dict specifying which fields to anonymize and what faker category they belong to. - constraints (list[dict]): - List of dicts specifying field and inter-field constraints. - TODO: Format TBD table_metadata (dict or metadata.Table): Table metadata instance or dict representation. 
                If given alongside any other metadata-related arguments, an
                exception will be raised.
@@ -54,7 +50,7 @@ def __init__(self, field_names=None, primary_key=None, field_types=None,
         if isinstance(table_metadata, dict):
             table_metadata = Table(table_metadata,)
 
-        for arg in (field_names, primary_key, field_types, anonymize_fields, constraints):
+        for arg in (field_names, primary_key, field_types, anonymize_fields):
             if arg:
                 raise ValueError(
                     'If table_metadata is given {} must be None'.format(arg.__name__))
@@ -66,7 +62,6 @@ def __init__(self, field_names=None, primary_key=None, field_types=None,
         self._primary_key = primary_key
         self._field_types = field_types
         self._anonymize_fields = anonymize_fields
-        self._constraints = constraints
 
     def _fit_metadata(self, data):
         """Generate a new Table metadata and fit it to the data.
@@ -84,7 +79,6 @@ def _fit_metadata(self, data):
             primary_key=self._primary_key,
             field_types=self._field_types,
             anonymize_fields=self._anonymize_fields,
-            constraints=self._constraints,
             transformer_templates=self.TRANSFORMER_TEMPLATES,
         )
         metadata.fit(data)

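For context, a minimal sketch of how the constructor is used after this change (argument
values are illustrative, and `metadata_dict` stands for any metadata dict; `GaussianCopula`
is one of the models exported by `sdv.tabular`):

    from sdv.tabular import GaussianCopula

    # Metadata-related arguments can still be given individually:
    model = GaussianCopula(
        primary_key='user_id',
        anonymize_fields={'country': 'country'},  # field name -> faker category
    )

    # Or wrapped in a single table_metadata, but not both at once; mixing
    # them raises a ValueError, and the removed constraints argument no
    # longer has any effect:
    # model = GaussianCopula(table_metadata=metadata_dict)
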
From 1bbb7c0aaa9b8ad3bcdc0409e81fee796f308fe0 Mon Sep 17 00:00:00 2001
From: Carles Sala
Date: Thu, 9 Jul 2020 21:13:55 +0200
Subject: [PATCH 27/33] Change dev version to actual releases

---
 setup.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/setup.py b/setup.py
index ef713040b..3c2c87521 100644
--- a/setup.py
+++ b/setup.py
@@ -15,8 +15,8 @@
     'exrex>=0.9.4,<0.11',
     'numpy>=1.15.4,<2',
     'pandas>=0.23.4,<0.25',
-    'copulas>=0.3.1.dev0,<0.4',
-    'rdt>=0.2.3.dev0,<0.3',
+    'copulas>=0.3.1,<0.4',
+    'rdt>=0.2.3,<0.3',
     'graphviz>=0.13.2',
     'sdmetrics>=0.0.2.dev0,<0.0.3',
     'scikit-learn<0.23,>=0.21',

From 53e044b5b0a5a06767d8cec4b20b47df3645671b Mon Sep 17 00:00:00 2001
From: Carles Sala
Date: Thu, 9 Jul 2020 21:21:15 +0200
Subject: [PATCH 28/33] Update the docs to use tutorials and cover Single
 Table Modeling

---
 .gitignore                                    |    2 +
 Makefile                                      |    4 +-
 docs/conf.py                                  |    4 +
 docs/index.rst                                |    8 +-
 docs/tutorials/01_Quickstart.ipynb            |  838 ++++++++++
 docs/tutorials/02_Single_Table_Modeling.ipynb | 1217 +++++++++++++++++
 .../03_Relational_Data_Modeling.ipynb         | 1039 ++++++++++++++
 docs/tutorials/04_Working_with_Metadata.ipynb |  460 +++++++
 docs/tutorials/demo_metadata.json             |   74 +
 docs/tutorials/sdv.pkl                        |  Bin 0 -> 255975 bytes
 setup.py                                      |    1 +
 tutorials/02_Single_Table_Modeling.ipynb      |  529 +++----
 12 files changed, 3914 insertions(+), 262 deletions(-)
 create mode 100644 docs/tutorials/01_Quickstart.ipynb
 create mode 100644 docs/tutorials/02_Single_Table_Modeling.ipynb
 create mode 100644 docs/tutorials/03_Relational_Data_Modeling.ipynb
 create mode 100644 docs/tutorials/04_Working_with_Metadata.ipynb
 create mode 100644 docs/tutorials/demo_metadata.json
 create mode 100644 docs/tutorials/sdv.pkl

diff --git a/.gitignore b/.gitignore
index 241e9f48f..2e25954c9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -106,3 +106,5 @@ ENV/
 .*.swp
 
 sdv/data/
+tutorials/sdv.pkl
+tutorials/demo_metadata.json
diff --git a/Makefile b/Makefile
index 31f1d6d6f..5fc2f37f3 100644
--- a/Makefile
+++ b/Makefile
@@ -50,6 +50,7 @@ clean-pyc: ## remove Python file artifacts
 .PHONY: clean-docs
 clean-docs: ## remove previously built docs
 	rm -f docs/api/*.rst
+	rm -rf docs/tutorials
 	-$(MAKE) -C docs clean 2>/dev/null # this fails if sphinx is not yet installed
 
 .PHONY: clean-coverage
@@ -110,7 +111,7 @@ test-readme: ## run the readme snippets
 
 .PHONY: test-tutorials
 test-tutorials: ## run the tutorial notebooks
-	jupyter nbconvert --execute --ExecutePreprocessor.timeout=600 examples/?.\ *.ipynb --stdout > /dev/null
+	jupyter nbconvert --execute --ExecutePreprocessor.timeout=600 tutorials/*.ipynb --stdout > /dev/null
 
 .PHONY: test
 test: test-unit test-readme test-tutorials ## test everything that needs test dependencies
@@ -134,6 +135,7 @@ coverage: ## check code coverage quickly with the default Python
 
 .PHONY: docs
 docs: clean-docs ## generate Sphinx HTML documentation, including API docs
+	cp -r tutorials docs/tutorials
 	sphinx-apidoc --separate --no-toc -o docs/api/ sdv
 	$(MAKE) -C docs html
diff --git a/docs/conf.py b/docs/conf.py
index 1c7e3ea90..c99d64607 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -32,6 +32,7 @@
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones.
 extensions = [
     'm2r',
+    'nbsphinx',
     'sphinx.ext.autodoc',
     'sphinx.ext.githubpages',
     'sphinx.ext.viewcode',
@@ -53,6 +54,9 @@
 # The master toctree document.
 master_doc = 'index'
 
+# Jupyter Notebooks
+nbsphinx_execute = 'never'
+
 # General information about the project.
 project = 'SDV'
 slug = 'sdv'
diff --git a/docs/index.rst b/docs/index.rst
index 4f136e234..c0f4a2cc6 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -7,15 +7,17 @@
    Overview
 
 .. toctree::
-   :caption: Advanced Usage
-   :maxdepth: 3
+   :caption: User Guides
+   :maxdepth: 2
 
-   metadata
+   tutorials/02_Single_Table_Modeling
+   tutorials/03_Relational_Data_Modeling
 
 .. toctree::
    :caption: Resources
    :hidden:
 
+   metadata
    API Reference
    contributing
    history
diff --git a/docs/tutorials/01_Quickstart.ipynb b/docs/tutorials/01_Quickstart.ipynb
new file mode 100644
index 000000000..5eb6cd8e8
--- /dev/null
+++ b/docs/tutorials/01_Quickstart.ipynb
@@ -0,0 +1,838 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Quickstart\n",
+    "\n",
+    "In this short tutorial we will guide you through a series of steps that will help you\n",
+    "get started using **SDV**.\n",
+    "\n",
+    "## 1. Model the dataset using SDV\n",
+    "\n",
+    "To model a multi-table, relational dataset, we follow two steps. In the first step, we will load\n",
+    "the data and configure the metadata. In the second step, we will use the SDV API to fit and\n",
+    "save a hierarchical model. We will cover these two steps in this section using an example dataset.\n",
+    "\n",
+    "### Step 1: Load example data\n",
+    "\n",
+    "**SDV** comes with a toy dataset to play with, which can be loaded using the `sdv.load_demo`\n",
+    "function:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from sdv import load_demo\n",
+    "\n",
+    "metadata, tables = load_demo(metadata=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "This will return two objects:\n",
+    "\n",
+    "1. A `Metadata` object with all the information that **SDV** needs to know about the dataset.\n",
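+    "\n",
+    "   As a quick, illustrative check, the relationship information can be queried directly\n",
+    "   (both methods exist on `Metadata`):\n",
+    "\n",
+    "   ```python\n",
+    "   metadata.get_primary_key('users')  # 'user_id'\n",
+    "   metadata.get_children('users')     # the 'sessions' table\n",
+    "   ```"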
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Metadata\n", + " root_path: /home/xals/Projects/MIT/SDV/tutorials\n", + " tables: ['users', 'sessions', 'transactions']\n", + " relationships:\n", + " sessions.user_id -> users.user_id\n", + " transactions.session_id -> sessions.session_id" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metadata" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Metadata\n", + "\n", + "\n", + "\n", + "users\n", + "\n", + "users\n", + "\n", + "user_id : id - integer\n", + "country : categorical\n", + "gender : categorical\n", + "age : numerical - integer\n", + "\n", + "Primary key: user_id\n", + "\n", + "\n", + "\n", + "sessions\n", + "\n", + "sessions\n", + "\n", + "session_id : id - integer\n", + "user_id : id - integer\n", + "device : categorical\n", + "os : categorical\n", + "\n", + "Primary key: session_id\n", + "Foreign key (users): user_id\n", + "\n", + "\n", + "\n", + "users->sessions\n", + "\n", + "\n", + "   sessions.user_id -> users.user_id\n", + "\n", + "\n", + "\n", + "transactions\n", + "\n", + "transactions\n", + "\n", + "transaction_id : id - integer\n", + "session_id : id - integer\n", + "timestamp : datetime\n", + "amount : numerical - float\n", + "approved : boolean\n", + "\n", + "Primary key: transaction_id\n", + "Foreign key (sessions): session_id\n", + "\n", + "\n", + "\n", + "sessions->transactions\n", + "\n", + "\n", + "   transactions.session_id -> sessions.session_id\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metadata.visualize()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For more details about how to build the `Metadata` for your own dataset, please refer to the\n", + "[Metadata](https://sdv-dev.github.io/SDV/metadata.html) section of the documentation." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "2. A dictionary containing three `pandas.DataFrames` with the tables described in the\n", + "metadata object." 
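For readers who prefer to see that in code: the same `Metadata` could also be assembled by hand instead of loaded from the demo. The following is only a sketch, assuming the `add_table` arguments (`data`, `primary_key`, `parent`, `foreign_key`) described in the Metadata documentation linked above:

```python
# Sketch only: building an equivalent Metadata by hand. The add_table
# signature is assumed from the Metadata documentation; field types are
# inferred from the data that is passed in.
from sdv import Metadata

metadata = Metadata()
metadata.add_table('users', data=tables['users'], primary_key='user_id')
metadata.add_table('sessions', data=tables['sessions'], primary_key='session_id',
                   parent='users', foreign_key='user_id')
metadata.add_table('transactions', data=tables['transactions'],
                   primary_key='transaction_id', parent='sessions',
                   foreign_key='session_id')
```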
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'users': user_id country gender age\n", + " 0 0 USA M 34\n", + " 1 1 UK F 23\n", + " 2 2 ES None 44\n", + " 3 3 UK M 22\n", + " 4 4 USA F 54\n", + " 5 5 DE M 57\n", + " 6 6 BG F 45\n", + " 7 7 ES None 41\n", + " 8 8 FR F 23\n", + " 9 9 UK None 30,\n", + " 'sessions': session_id user_id device os\n", + " 0 0 0 mobile android\n", + " 1 1 1 tablet ios\n", + " 2 2 1 tablet android\n", + " 3 3 2 mobile android\n", + " 4 4 4 mobile ios\n", + " 5 5 5 mobile android\n", + " 6 6 6 mobile ios\n", + " 7 7 6 tablet ios\n", + " 8 8 6 mobile ios\n", + " 9 9 8 tablet ios,\n", + " 'transactions': transaction_id session_id timestamp amount approved\n", + " 0 0 0 2019-01-01 12:34:32 100.0 True\n", + " 1 1 0 2019-01-01 12:42:21 55.3 True\n", + " 2 2 1 2019-01-07 17:23:11 79.5 True\n", + " 3 3 3 2019-01-10 11:08:57 112.1 False\n", + " 4 4 5 2019-01-10 21:54:08 110.0 False\n", + " 5 5 5 2019-01-11 11:21:20 76.3 True\n", + " 6 6 7 2019-01-22 14:44:10 89.5 True\n", + " 7 7 8 2019-01-23 10:14:09 132.1 False\n", + " 8 8 9 2019-01-27 16:09:17 68.0 True\n", + " 9 9 9 2019-01-29 12:10:48 99.9 True}" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tables" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. Fit a model using the SDV API.\n", + "\n", + "First, we build a hierarchical statistical model of the data using **SDV**. For this we will\n", + "create an instance of the `sdv.SDV` class and use its `fit` method.\n", + "\n", + "During this process, **SDV** will traverse across all the tables in your dataset following the\n", + "primary key-foreign key relationships and learn the probability distributions of the values in\n", + "the columns." 
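The traversal order mentioned here can be inspected directly, since `Metadata` exposes the relationship graph. A rough sketch, assuming its `get_tables`, `get_parents` and `get_children` helpers:

```python
# Sketch: list the tables in the parent-to-child order in which the modeler
# visits them, using Metadata's relationship helpers.
def modeling_order(metadata):
    order = []

    def visit(table_name):
        order.append(table_name)
        for child in metadata.get_children(table_name):
            visit(child)

    # Start from the tables that have no parents.
    for table_name in metadata.get_tables():
        if not metadata.get_parents(table_name):
            visit(table_name)

    return order

print(modeling_order(metadata))  # ['users', 'sessions', 'transactions'] for this demo
```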
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-07-09 20:57:19,919 - INFO - modeler - Modeling users\n", + "2020-07-09 20:57:19,920 - INFO - __init__ - Loading transformer CategoricalTransformer for field country\n", + "2020-07-09 20:57:19,920 - INFO - __init__ - Loading transformer CategoricalTransformer for field gender\n", + "2020-07-09 20:57:19,921 - INFO - __init__ - Loading transformer NumericalTransformer for field age\n", + "2020-07-09 20:57:19,933 - INFO - modeler - Modeling sessions\n", + "2020-07-09 20:57:19,934 - INFO - __init__ - Loading transformer CategoricalTransformer for field device\n", + "2020-07-09 20:57:19,934 - INFO - __init__ - Loading transformer CategoricalTransformer for field os\n", + "2020-07-09 20:57:19,944 - INFO - modeler - Modeling transactions\n", + "2020-07-09 20:57:19,944 - INFO - __init__ - Loading transformer DatetimeTransformer for field timestamp\n", + "2020-07-09 20:57:19,944 - INFO - __init__ - Loading transformer NumericalTransformer for field amount\n", + "2020-07-09 20:57:19,945 - INFO - __init__ - Loading transformer BooleanTransformer for field approved\n", + "2020-07-09 20:57:19,954 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:19,962 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/core/fromnumeric.py:58: FutureWarning: Series.nonzero() is deprecated and will be removed in a future version.Use Series.to_numpy().nonzero() instead\n", + " return bound(*args, **kwds)\n", + "/home/xals/Projects/MIT/SDV/sdv/models/copulas.py:83: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. 
If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n", + " self.model.covariance = np.array(values)\n", + "2020-07-09 20:57:19,968 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/pandas/core/frame.py:7143: RuntimeWarning: Degrees of freedom <= 0 for slice\n", + " baseCov = np.cov(mat.T)\n", + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/lib/function_base.py:2480: RuntimeWarning: divide by zero encountered in true_divide\n", + " c *= np.true_divide(1, fact)\n", + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/lib/function_base.py:2480: RuntimeWarning: invalid value encountered in multiply\n", + " c *= np.true_divide(1, fact)\n", + "2020-07-09 20:57:19,974 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:19,979 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:19,985 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:19,989 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:19,994 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,007 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,062 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,074 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,092 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,104 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,117 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,129 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,146 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,208 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,294 - INFO - modeler - Modeling Complete\n" + ] + } + ], + "source": [ + "from sdv import SDV\n", + "\n", + "sdv = SDV()\n", + "sdv.fit(metadata, tables)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. 
Sample data from the fitted model\n", + "\n", + "Once the modeling has finished you are ready to generate new synthetic data using the `sdv` instance that you have.\n", + "\n", + "For this, all you have to do is call the `sample_all` method from your instance passing the number of rows that you want to generate:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "sampled = sdv.sample_all(10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This will return a dictionary identical to the `tables` one that we passed to the SDV instance for learning, filled in with new synthetic data.\n", + "\n", + "**Note** that only the parent tables of your dataset will have the specified number of rows,\n", + "as the number of child rows that each row in the parent table has is also sampled following\n", + "the original distribution of your dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/html": [ + "
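A quick way to observe this is to compare the sizes of the sampled tables; a small sketch:

```python
# Sketch: only the parent table matches the requested size; child tables are
# sized by the learned number of children per parent row.
sampled = sdv.sample_all(10)
for name, table in sampled.items():
    print(name, len(table))
# 'users' has exactly 10 rows; 'sessions' and 'transactions' will vary per run.
```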
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_idcountrygenderage
00UKNaN39
11UKF27
22UKNaN50
33USAF29
44UKNaN21
55ESNaN27
66UKF20
77USAF15
88DEF41
99ESF35
\n", + "
" + ], + "text/plain": [ + " user_id country gender age\n", + "0 0 UK NaN 39\n", + "1 1 UK F 27\n", + "2 2 UK NaN 50\n", + "3 3 USA F 29\n", + "4 4 UK NaN 21\n", + "5 5 ES NaN 27\n", + "6 6 UK F 20\n", + "7 7 USA F 15\n", + "8 8 DE F 41\n", + "9 9 ES F 35" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sampled['users']" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
session_iduser_iddeviceos
001mobileandroid
112tabletios
222tabletios
332mobileandroid
443mobileandroid
555mobileandroid
665mobileandroid
775tabletios
886tabletios
996mobileandroid
\n", + "
" + ], + "text/plain": [ + " session_id user_id device os\n", + "0 0 1 mobile android\n", + "1 1 2 tablet ios\n", + "2 2 2 tablet ios\n", + "3 3 2 mobile android\n", + "4 4 3 mobile android\n", + "5 5 5 mobile android\n", + "6 6 5 mobile android\n", + "7 7 5 tablet ios\n", + "8 8 6 tablet ios\n", + "9 9 6 mobile android" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sampled['sessions'].head(10)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
transaction_idsession_idtimestampamountapproved
0002019-01-13 01:11:06.08427596880.350553True
1102019-01-13 01:11:06.36870425682.116470True
2212019-01-25 15:05:07.73526297697.616590True
3312019-01-25 15:05:09.09895884896.505258True
4422019-01-25 15:05:09.12423142497.235452True
5522019-01-25 15:05:07.34731545696.615923True
6632019-01-25 15:05:09.19044198497.173195True
7732019-01-25 15:05:07.43791078496.995037True
8802019-01-13 01:11:06.56094003281.027625True
9912019-01-25 15:05:12.42074649697.340054True
\n", + "
" + ], + "text/plain": [ + " transaction_id session_id timestamp amount \\\n", + "0 0 0 2019-01-13 01:11:06.084275968 80.350553 \n", + "1 1 0 2019-01-13 01:11:06.368704256 82.116470 \n", + "2 2 1 2019-01-25 15:05:07.735262976 97.616590 \n", + "3 3 1 2019-01-25 15:05:09.098958848 96.505258 \n", + "4 4 2 2019-01-25 15:05:09.124231424 97.235452 \n", + "5 5 2 2019-01-25 15:05:07.347315456 96.615923 \n", + "6 6 3 2019-01-25 15:05:09.190441984 97.173195 \n", + "7 7 3 2019-01-25 15:05:07.437910784 96.995037 \n", + "8 8 0 2019-01-13 01:11:06.560940032 81.027625 \n", + "9 9 1 2019-01-25 15:05:12.420746496 97.340054 \n", + "\n", + " approved \n", + "0 True \n", + "1 True \n", + "2 True \n", + "3 True \n", + "4 True \n", + "5 True \n", + "6 True \n", + "7 True \n", + "8 True \n", + "9 True " + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sampled['transactions'].head(10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Saving and Loading your model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In some cases, you might want to save the fitted SDV instance to be able to generate synthetic data from\n", + "it later or on a different system.\n", + "\n", + "In order to do so, you can save your fitted `SDV` instance for later usage using the `save` method of your\n", + "instance." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "sdv.save('sdv.pkl')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The generated `pkl` file will not include any of the original data in it, so it can be\n", + "safely sent to where the synthetic data will be generated without any privacy concerns." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Later on, in order to sample data from the fitted model, we will first need to load it from its\n", + "`pkl` file." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "sdv = SDV.load('sdv.pkl')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After loading the instance, we can sample synthetic data using its `sample_all` method like before." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "sampled = sdv.sample_all(5)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/tutorials/02_Single_Table_Modeling.ipynb b/docs/tutorials/02_Single_Table_Modeling.ipynb new file mode 100644 index 000000000..fa536bd3d --- /dev/null +++ b/docs/tutorials/02_Single_Table_Modeling.ipynb @@ -0,0 +1,1217 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Single Table Modeling\n", + "\n", + "**SDV** has special support for modeling single table datasets using a variety of models.\n", + "\n", + "Currently, SDV implements:\n", + "\n", + "* GaussianCopula: A tool to model multivariate distributions using [copula functions](https://en.wikipedia.org/wiki/Copula_%28probability_theory%29). 
Based on our [Copulas Library](https://github.com/sdv-dev/Copulas).\n", + "* CTGAN: A GAN-based Deep Learning data synthesizer that can generate synthetic tabular data with high fidelity. Based on our [CTGAN Library](https://github.com/sdv-dev/CTGAN)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## GaussianCopula\n", + "\n", + "In this first part of the tutorial we will be using the GaussianCopula class to model the `users` table\n", + "from the toy dataset included in the **SDV** library.\n", + "\n", + "### 1. Load the Data" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from sdv import load_demo\n", + "\n", + "users = load_demo()['users']" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This will return a table with 4 fields:\n", + "\n", + "* `user_id`: A unique identifier of the user.\n", + "* `country`: A 2 letter code of the country of residence of the user.\n", + "* `gender`: A single letter code, `M` or `F`, indicating the user gender. Note that this demo simulates the case where some users did not indicate the gender, which resulted in empty data values in some rows.\n", + "* `age`: The age of the user, in years." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_idcountrygenderage
00USAM34
11UKF23
22ESNone44
33UKM22
44USAF54
55DEM57
66BGF45
77ESNone41
88FRF23
99UKNone30
\n", + "
" + ], + "text/plain": [ + " user_id country gender age\n", + "0 0 USA M 34\n", + "1 1 UK F 23\n", + "2 2 ES None 44\n", + "3 3 UK M 22\n", + "4 4 USA F 54\n", + "5 5 DE M 57\n", + "6 6 BG F 45\n", + "7 7 ES None 41\n", + "8 8 FR F 23\n", + "9 9 UK None 30" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "users" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. Prepare the model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In order to properly model our data we will need to provide some additional information to our model,\n", + "so let's prepare this information in some variables." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, let's indicate that the `user_id` field in our table is the primary key, so we do not want our\n", + "model to attempt to learn it." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "primary_key = 'user_id'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will also want to anonymize the countries of residence of our users, to avoid disclosing such information.\n", + "Let's make a variable indicating that the `country` field needs to be anonymized using fake `country_codes`." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "anonymize_fileds = {\n", + " 'country': 'contry_code'\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The full list of categories supported corresponds to the `Faker` library\n", + "[provider names](https://faker.readthedocs.io/en/master/providers.html)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Once we have prepared the arguments for our model we are ready to import it, create an instance\n", + "and fit it to our data." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-07-09 21:18:32,974 - INFO - table - Loading transformer CategoricalTransformer for field country\n", + "2020-07-09 21:18:32,975 - INFO - table - Loading transformer CategoricalTransformer for field gender\n", + "2020-07-09 21:18:32,975 - INFO - table - Loading transformer NumericalTransformer for field age\n", + "2020-07-09 21:18:32,991 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n" + ] + } + ], + "source": [ + "from sdv.tabular import GaussianCopula\n", + "\n", + "model = GaussianCopula(\n", + " primary_key=primary_key,\n", + " anonymize_fileds=anonymize_fileds\n", + ")\n", + "model.fit(users)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Notice** how the model took care of transforming the different fields using the appropriate\n", + "Reversible Data Transforms to ensure that the data has a format that the GaussianMultivariate model\n", + "from the [copulas](https://github.com/sdv-dev/Copulas) library can handle." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. Sample data from the fitted model\n", + "\n", + "Once the modeling has finished you are ready to generate new synthetic data by calling the `sample` method\n", + "from our model." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "sampled = model.sample()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This will return a table identical to the one which the model was fitted on, but filled with new data\n", + "which resembles the original one." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
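Since the `country` field was marked for anonymization when the model was created, a simple sanity check is to look at the sampled values. A sketch, assuming the fit above was done with the correctly spelled `anonymize_fields` argument:

```python
# Sketch: with anonymization active, the sampled countries should be fake
# Faker 'country_code' values rather than the original ones (overlaps can
# still happen by chance, since fake codes share the same universe).
check = model.sample(100)
print(check['country'].value_counts().head())
print(set(check['country']) & set(users['country']))
```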
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_idcountrygenderage
00USAM38
11UKNaN23
22USAF34
33ESNaN47
44ESF29
55UKF39
66FRNaN40
77ESM38
88ESF32
99ESF36
\n", + "
" + ], + "text/plain": [ + " user_id country gender age\n", + "0 0 USA M 38\n", + "1 1 UK NaN 23\n", + "2 2 USA F 34\n", + "3 3 ES NaN 47\n", + "4 4 ES F 29\n", + "5 5 UK F 39\n", + "6 6 FR NaN 40\n", + "7 7 ES M 38\n", + "8 8 ES F 32\n", + "9 9 ES F 36" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sampled" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "scrolled": false + }, + "source": [ + "Notice, as well that the number of rows generated by default corresponds to the number of rows that\n", + "the original table had, but that this number can be changed by simply passing it:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_idcountrygenderage
00UKF48
11USANaN38
22USAM29
33BGM22
44USAM43
\n", + "
" + ], + "text/plain": [ + " user_id country gender age\n", + "0 0 UK F 48\n", + "1 1 USA NaN 38\n", + "2 2 USA M 29\n", + "3 3 BG M 22\n", + "4 4 USA M 43" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.sample(5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## CTGAN\n", + "\n", + "In this second part of the tutorial we will be using the CTGAN model to learn the data from the\n", + "demo dataset called `census`, which is based on the [UCI Adult Census Dataset]('https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data').\n", + "\n", + "### 1. Load the Data" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-07-09 21:18:33,085 - INFO - __init__ - Loading table census\n" + ] + } + ], + "source": [ + "from sdv import load_demo\n", + "\n", + "census = load_demo('census')['census']" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This will return a table with several rows of multiple data types:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ageworkclassfnlwgteducationeducation-nummarital-statusoccupationrelationshipracesexcapital-gaincapital-losshours-per-weeknative-countryincome
039State-gov77516Bachelors13Never-marriedAdm-clericalNot-in-familyWhiteMale2174040United-States<=50K
150Self-emp-not-inc83311Bachelors13Married-civ-spouseExec-managerialHusbandWhiteMale0013United-States<=50K
238Private215646HS-grad9DivorcedHandlers-cleanersNot-in-familyWhiteMale0040United-States<=50K
353Private23472111th7Married-civ-spouseHandlers-cleanersHusbandBlackMale0040United-States<=50K
428Private338409Bachelors13Married-civ-spouseProf-specialtyWifeBlackFemale0040Cuba<=50K
\n", + "
" + ], + "text/plain": [ + " age workclass fnlwgt education education-num \\\n", + "0 39 State-gov 77516 Bachelors 13 \n", + "1 50 Self-emp-not-inc 83311 Bachelors 13 \n", + "2 38 Private 215646 HS-grad 9 \n", + "3 53 Private 234721 11th 7 \n", + "4 28 Private 338409 Bachelors 13 \n", + "\n", + " marital-status occupation relationship race sex \\\n", + "0 Never-married Adm-clerical Not-in-family White Male \n", + "1 Married-civ-spouse Exec-managerial Husband White Male \n", + "2 Divorced Handlers-cleaners Not-in-family White Male \n", + "3 Married-civ-spouse Handlers-cleaners Husband Black Male \n", + "4 Married-civ-spouse Prof-specialty Wife Black Female \n", + "\n", + " capital-gain capital-loss hours-per-week native-country income \n", + "0 2174 0 40 United-States <=50K \n", + "1 0 0 13 United-States <=50K \n", + "2 0 0 40 United-States <=50K \n", + "3 0 0 40 United-States <=50K \n", + "4 0 0 40 Cuba <=50K " + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "census.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. Prepare the model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this case there is no primary key to setup and we will not be anonymizing anything, so the only\n", + "thing that we will pass to the CTGAN model is the number of epochs that we want it to perform when\n", + "it leanrs the data, which we will keep low to make this execution quick." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/sklearn/utils/deprecation.py:144: FutureWarning: The sklearn.utils.testing module is deprecated in version 0.22 and will be removed in version 0.24. The corresponding classes / functions should instead be imported from sklearn.utils. Anything that cannot be imported from sklearn.utils is now part of the private API.\n", + " warnings.warn(message, FutureWarning)\n" + ] + } + ], + "source": [ + "from sdv.tabular import CTGAN\n", + "\n", + "model = CTGAN(epochs=10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Once the instance is created, we can fit it to our data. Bear in mind that this process might take some\n", + "time to finish, especially on non-GPU enabled systems, so in this case we will be passing only a\n", + "subsample of the data to accelerate the process." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-07-09 21:18:33,488 - INFO - table - Loading transformer NumericalTransformer for field age\n", + "2020-07-09 21:18:33,489 - INFO - table - Loading transformer LabelEncodingTransformer for field workclass\n", + "2020-07-09 21:18:33,489 - INFO - table - Loading transformer NumericalTransformer for field fnlwgt\n", + "2020-07-09 21:18:33,490 - INFO - table - Loading transformer LabelEncodingTransformer for field education\n", + "2020-07-09 21:18:33,490 - INFO - table - Loading transformer NumericalTransformer for field education-num\n", + "2020-07-09 21:18:33,490 - INFO - table - Loading transformer LabelEncodingTransformer for field marital-status\n", + "2020-07-09 21:18:33,491 - INFO - table - Loading transformer LabelEncodingTransformer for field occupation\n", + "2020-07-09 21:18:33,491 - INFO - table - Loading transformer LabelEncodingTransformer for field relationship\n", + "2020-07-09 21:18:33,491 - INFO - table - Loading transformer LabelEncodingTransformer for field race\n", + "2020-07-09 21:18:33,492 - INFO - table - Loading transformer LabelEncodingTransformer for field sex\n", + "2020-07-09 21:18:33,492 - INFO - table - Loading transformer NumericalTransformer for field capital-gain\n", + "2020-07-09 21:18:33,493 - INFO - table - Loading transformer NumericalTransformer for field capital-loss\n", + "2020-07-09 21:18:33,493 - INFO - table - Loading transformer NumericalTransformer for field hours-per-week\n", + "2020-07-09 21:18:33,493 - INFO - table - Loading transformer LabelEncodingTransformer for field native-country\n", + "2020-07-09 21:18:33,494 - INFO - table - Loading transformer LabelEncodingTransformer for field income\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1, Loss G: 1.9512, Loss D: -0.0182\n", + "Epoch 2, Loss G: 1.9884, Loss D: -0.0663\n", + "Epoch 3, Loss G: 1.9710, Loss D: -0.1339\n", + "Epoch 4, Loss G: 1.8960, Loss D: -0.2061\n", + "Epoch 5, Loss G: 1.9155, Loss D: -0.3062\n", + "Epoch 6, Loss G: 1.9699, Loss D: -0.3906\n", + "Epoch 7, Loss G: 1.8614, Loss D: -0.5142\n", + "Epoch 8, Loss G: 1.8446, Loss D: -0.6448\n", + "Epoch 9, Loss G: 1.7619, Loss D: -0.7488\n", + "Epoch 10, Loss G: 1.6732, Loss D: -0.7961\n" + ] + } + ], + "source": [ + "model.fit(census.sample(1000))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. Sample data from the fitted model\n", + "\n", + "Once the modeling has finished you are ready to generate new synthetic data by calling the `sample` method\n", + "from our model just like we did with the GaussianCopula model." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "sampled = model.sample()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This will return a table identical to the one which the model was fitted on, but filled with new data\n", + "which resembles the original one." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
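Before the more formal evaluation in the last section, a quick pandas-level comparison already shows whether the synthetic table is in the right ballpark; a sketch using the `census` and `sampled` frames from above:

```python
# Sketch: compare basic statistics of the numerical columns side by side.
import pandas as pd

comparison = pd.DataFrame({
    'real_mean': census.mean(numeric_only=True),
    'synth_mean': sampled.mean(numeric_only=True),
    'real_std': census.std(numeric_only=True),
    'synth_std': sampled.std(numeric_only=True),
})
print(comparison.round(2))
```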
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ageworkclassfnlwgteducationeducation-nummarital-statusoccupationrelationshipracesexcapital-gaincapital-losshours-per-weeknative-countryincome
050Local-gov1697191st-4th9Widowed?HusbandWhiteMale114838Columbia<=50K
132?1524791st-4th9Never-marriedAdm-clericalWifeBlackMale-422021Jamaica>50K
222Private69617Bachelors0Separated?HusbandWhiteMale61138Guatemala<=50K
325?65285810th16Married-civ-spouseHandlers-cleanersNot-in-familyWhiteFemale152-2739Cuba<=50K
443Private301956Some-college8Married-civ-spouse?WifeWhiteMale-133-1239India<=50K
566Private401171Prof-school13SeparatedProtective-servUnmarriedBlackFemale-124-140Cuba<=50K
652Private278399Bachelors12Never-marriedProf-specialtyUnmarriedOtherMale122567-647Columbia<=50K
736Federal-gov229817HS-grad8Married-AF-spouseFarming-fishingNot-in-familyWhiteMale81938Portugal>50K
827Federal-gov306972Some-college8Never-marriedExec-managerialHusbandAsian-Pac-IslanderFemale42144339Japan>50K
928Local-gov4161611st-4th8DivorcedAdm-clericalUnmarriedWhiteFemale-349109061Guatemala>50K
\n", + "
" + ], + "text/plain": [ + " age workclass fnlwgt education education-num \\\n", + "0 50 Local-gov 169719 1st-4th 9 \n", + "1 32 ? 152479 1st-4th 9 \n", + "2 22 Private 69617 Bachelors 0 \n", + "3 25 ? 652858 10th 16 \n", + "4 43 Private 301956 Some-college 8 \n", + "5 66 Private 401171 Prof-school 13 \n", + "6 52 Private 278399 Bachelors 12 \n", + "7 36 Federal-gov 229817 HS-grad 8 \n", + "8 27 Federal-gov 306972 Some-college 8 \n", + "9 28 Local-gov 416161 1st-4th 8 \n", + "\n", + " marital-status occupation relationship \\\n", + "0 Widowed ? Husband \n", + "1 Never-married Adm-clerical Wife \n", + "2 Separated ? Husband \n", + "3 Married-civ-spouse Handlers-cleaners Not-in-family \n", + "4 Married-civ-spouse ? Wife \n", + "5 Separated Protective-serv Unmarried \n", + "6 Never-married Prof-specialty Unmarried \n", + "7 Married-AF-spouse Farming-fishing Not-in-family \n", + "8 Never-married Exec-managerial Husband \n", + "9 Divorced Adm-clerical Unmarried \n", + "\n", + " race sex capital-gain capital-loss hours-per-week \\\n", + "0 White Male 114 8 38 \n", + "1 Black Male -42 20 21 \n", + "2 White Male 6 11 38 \n", + "3 White Female 152 -27 39 \n", + "4 White Male -133 -12 39 \n", + "5 Black Female -124 -1 40 \n", + "6 Other Male 122567 -6 47 \n", + "7 White Male 8 19 38 \n", + "8 Asian-Pac-Islander Female 42144 3 39 \n", + "9 White Female -349 1090 61 \n", + "\n", + " native-country income \n", + "0 Columbia <=50K \n", + "1 Jamaica >50K \n", + "2 Guatemala <=50K \n", + "3 Cuba <=50K \n", + "4 India <=50K \n", + "5 Cuba <=50K \n", + "6 Columbia <=50K \n", + "7 Portugal >50K \n", + "8 Japan >50K \n", + "9 Guatemala >50K " + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sampled.head(10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 4. Evaluate how good the data is\n", + "\n", + "Finally, we will use the evaluation framework included in SDV to obtain a metric of how\n", + "similar the sampled data is to the original one.\n", + "\n", + "For this, we will simply import the `sdv.evaluation.evaluate` function and pass both\n", + "the synthetic and the real data to it." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "-144.971907591418"
+      ]
+     },
+     "execution_count": 15,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import warnings\n",
+    "warnings.filterwarnings('ignore')\n",
+    "\n",
+    "from sdv.evaluation import evaluate\n",
+    "\n",
+    "evaluate(sampled, census)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/docs/tutorials/03_Relational_Data_Modeling.ipynb b/docs/tutorials/03_Relational_Data_Modeling.ipynb
new file mode 100644
index 000000000..230061b00
--- /dev/null
+++ b/docs/tutorials/03_Relational_Data_Modeling.ipynb
@@ -0,0 +1,1039 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Relational Data Modeling\n",
+    "\n",
+    "In this tutorial we will be showing how to model a real-world multi-table dataset using SDV.\n",
+    "\n",
+    "## About the dataset\n",
+    "\n",
+    "We have a series of stores, each with a size and a category, plus additional information for any given date: the average temperature in the region, the cost of fuel in the region, promotional data, the consumer price index, the unemployment rate and whether the date is a special holiday.\n",
+    "\n",
+    "For those stores we have historical data\n",
+    "between 2010-02-05 and 2012-11-01. This historical data includes the sales of each department on a specific date.\n",
+    "In this notebook, we will show you step-by-step how to download the \"Walmart\" dataset, explain its structure and sample the data.\n",
+    "\n",
+    "In this demonstration we will show how SDV can be used to generate synthetic data, which can later be used to train machine learning models.\n",
+    "\n",
+    "*The dataset used in this example can be found on [Kaggle](https://www.kaggle.com/c/walmart-recruiting-store-sales-forecasting/data), but we will show how to download it from SDV.*"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Data model summary"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "

stores

\n", + "\n", + "| Field | Type | Subtype | Additional Properties |\n", + "|-------|-------------|---------|-----------------------|\n", + "| Store | id | integer | Primary key |\n", + "| Size | numerical | integer | |\n", + "| Type | categorical | | |\n", + "\n", + "Contains information about the 45 stores, indicating the type and size of store." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

features

\n", + "\n", + "| Fields | Type | Subtype | Additional Properties |\n", + "|--------------|-----------|---------|-----------------------------|\n", + "| Store | id | integer | foreign key (stores.Store) |\n", + "| Date | datetime | | format: \"%Y-%m-%d\" |\n", + "| IsHoliday | boolean | | |\n", + "| Fuel_Price | numerical | float | |\n", + "| Unemployment | numerical | float | |\n", + "| Temperature | numerical | float | |\n", + "| CPI | numerical | float | |\n", + "| MarkDown1 | numerical | float | |\n", + "| MarkDown2 | numerical | float | |\n", + "| MarkDown3 | numerical | float | |\n", + "| MarkDown4 | numerical | float | |\n", + "| MarkDown5 | numerical | float | |\n", + "\n", + "Contains historical training data, which covers to 2010-02-05 to 2012-11-01." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

depts

\n", + "\n", + "| Fields | Type | Subtype | Additional Properties |\n", + "|--------------|-----------|---------|------------------------------|\n", + "| Store | id | integer | foreign key (stores.Stores) |\n", + "| Date | datetime | | format: \"%Y-%m-%d\" |\n", + "| Weekly_Sales | numerical | float | |\n", + "| Dept | numerical | integer | |\n", + "| IsHoliday | boolean | | |\n", + "\n", + "Contains additional data related to the store, department, and regional activity for the given dates." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1. Load data\n", + "\n", + "Let's start downloading the data set. In this case, we will download the data set *walmart*. We will use the SDV function `load_demo`, we can specify the name of the dataset we want to use and if we want its Metadata object or not. To know more about the demo data [see the documentation](https://sdv-dev.github.io/SDV/api/sdv.demo.html)." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-07-09 21:00:17,378 - INFO - __init__ - Loading table stores\n", + "2020-07-09 21:00:17,384 - INFO - __init__ - Loading table features\n", + "2020-07-09 21:00:17,402 - INFO - __init__ - Loading table depts\n" + ] + } + ], + "source": [ + "from sdv import load_demo\n", + "\n", + "metadata, tables = load_demo(dataset_name='walmart', metadata=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Our dataset is downloaded from an [Amazon S3 bucket](http://sdv-datasets.s3.amazonaws.com/index.html) that contains all available data sets of the `load_demo` method." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can now visualize the metadata structure" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Metadata\n", + "\n", + "\n", + "\n", + "stores\n", + "\n", + "stores\n", + "\n", + "Type : categorical\n", + "Size : numerical - integer\n", + "Store : id - integer\n", + "\n", + "Primary key: Store\n", + "Data path: stores.csv\n", + "\n", + "\n", + "\n", + "features\n", + "\n", + "features\n", + "\n", + "Date : datetime\n", + "MarkDown1 : numerical - float\n", + "Store : id - integer\n", + "IsHoliday : boolean\n", + "MarkDown4 : numerical - float\n", + "MarkDown3 : numerical - float\n", + "Fuel_Price : numerical - float\n", + "Unemployment : numerical - float\n", + "Temperature : numerical - float\n", + "MarkDown5 : numerical - float\n", + "MarkDown2 : numerical - float\n", + "CPI : numerical - float\n", + "\n", + "Foreign key (stores): Store\n", + "Data path: features.csv\n", + "\n", + "\n", + "\n", + "stores->features\n", + "\n", + "\n", + "   features.Store -> stores.Store\n", + "\n", + "\n", + "\n", + "depts\n", + "\n", + "depts\n", + "\n", + "Date : datetime\n", + "Weekly_Sales : numerical - float\n", + "Store : id - integer\n", + "Dept : numerical - integer\n", + "IsHoliday : boolean\n", + "\n", + "Foreign key (stores): Store\n", + "Data path: train.csv\n", + "\n", + "\n", + "\n", + "stores->depts\n", + "\n", + "\n", + "   depts.Store -> stores.Store\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metadata.visualize()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + 
"source": [ + "And also validate that the metadata is correctly defined for our data" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "metadata.validate(tables)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. Create an instance of SDV and train the instance\n", + "\n", + "Once we download it, we have to create an SDV instance. With that instance, we have to analyze the loaded tables to generate a statistical model from the data. In this case, the process of adjusting the model is quickly because the dataset is small. However, with larger datasets it can be a slow process." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-07-09 21:00:31,480 - INFO - modeler - Modeling stores\n", + "2020-07-09 21:00:31,481 - INFO - __init__ - Loading transformer CategoricalTransformer for field Type\n", + "2020-07-09 21:00:31,481 - INFO - __init__ - Loading transformer NumericalTransformer for field Size\n", + "2020-07-09 21:00:31,491 - INFO - modeler - Modeling features\n", + "2020-07-09 21:00:31,492 - INFO - __init__ - Loading transformer DatetimeTransformer for field Date\n", + "2020-07-09 21:00:31,493 - INFO - __init__ - Loading transformer NumericalTransformer for field MarkDown1\n", + "2020-07-09 21:00:31,493 - INFO - __init__ - Loading transformer BooleanTransformer for field IsHoliday\n", + "2020-07-09 21:00:31,493 - INFO - __init__ - Loading transformer NumericalTransformer for field MarkDown4\n", + "2020-07-09 21:00:31,494 - INFO - __init__ - Loading transformer NumericalTransformer for field MarkDown3\n", + "2020-07-09 21:00:31,494 - INFO - __init__ - Loading transformer NumericalTransformer for field Fuel_Price\n", + "2020-07-09 21:00:31,494 - INFO - __init__ - Loading transformer NumericalTransformer for field Unemployment\n", + "2020-07-09 21:00:31,495 - INFO - __init__ - Loading transformer NumericalTransformer for field Temperature\n", + "2020-07-09 21:00:31,495 - INFO - __init__ - Loading transformer NumericalTransformer for field MarkDown5\n", + "2020-07-09 21:00:31,495 - INFO - __init__ - Loading transformer NumericalTransformer for field MarkDown2\n", + "2020-07-09 21:00:31,495 - INFO - __init__ - Loading transformer NumericalTransformer for field CPI\n", + "2020-07-09 21:00:31,544 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,595 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "/home/xals/Projects/MIT/SDV/sdv/models/copulas.py:83: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. 
If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n",
+      "  self.model.covariance = np.array(values)\n",
+      "2020-07-09 21:00:31,651 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n",
+      "... (the same INFO line repeats once per row of the parent table; repeated output trimmed) ...\n",
+      "2020-07-09 21:00:32,917 - INFO - modeler - Modeling depts\n",
+      "2020-07-09 21:00:32,918 - INFO - __init__ - Loading transformer DatetimeTransformer for field Date\n",
+      "2020-07-09 21:00:32,918 - INFO - __init__ - Loading transformer NumericalTransformer for field Weekly_Sales\n",
+      "2020-07-09 21:00:32,919 - INFO - __init__ - Loading transformer NumericalTransformer for field Dept\n",
+      "2020-07-09 21:00:32,919 - INFO - __init__ - Loading transformer BooleanTransformer for field IsHoliday\n",
+      "2020-07-09 21:00:33,016 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n",
+      "... (repeated fitting lines trimmed) ...\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "2020-07-09 21:00:33,447 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n",
+      "... (repeated fitting lines trimmed) ...\n",
+      "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/core/fromnumeric.py:58: FutureWarning: Series.nonzero() is deprecated and will be removed in a future version.Use Series.to_numpy().nonzero() instead\n",
+      "  return bound(*args, **kwds)\n",
+      "2020-07-09 21:00:34,259 - INFO - modeler - Modeling Complete\n"
+     ]
+    }
+   ],
+   "source": [
+    "from sdv import SDV\n",
+    "\n",
+    "sdv = SDV()\n",
+    "sdv.fit(metadata, tables=tables)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Note: We may not want to train the model every time we want to generate new synthetic data. Instead, we can [save](https://sdv-dev.github.io/SDV/api/sdv.sdv.html#sdv.sdv.SDV.save) the fitted SDV instance and [load](https://sdv-dev.github.io/SDV/api/sdv.sdv.html#sdv.sdv.SDV.load) it back later."
+   ]
+  },
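+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A minimal sketch of that save/load workflow could look like this (the `sdv.pkl` file name is just an example):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Persist the fitted instance to disk (the file name is just an example).\n",
+    "sdv.save('sdv.pkl')\n",
+    "\n",
+    "# Later on, restore it without having to fit it again.\n",
+    "sdv = SDV.load('sdv.pkl')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 3. Generate synthetic data\n",
+    "\n",
+    "Once the instance is trained, we are ready to generate the synthetic data.\n",
+    "\n",
+    "The easiest way to generate synthetic data for the entire dataset is to call the `sample_all` method. By default, this method generates only 5 rows, but we can specify the number of rows to generate with the `num_rows` argument. To learn more about the available arguments, see [sample_all](https://sdv-dev.github.io/SDV/api/sdv.sampler.html#sdv.sampler.Sampler.sample_all)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'stores': 45, 'features': 8190, 'depts': 421570}"
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "sdv.modeler.table_sizes"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {
+    "scrolled": false
+   },
+   "outputs": [],
+   "source": [
+    "samples = sdv.sample_all()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "This returns a dictionary with a `pandas.DataFrame` for each table."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<div>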

\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
TypeSizeStore
0B1061060
1B680711
2C2536032
3A2232683
4B1668484
\n", + "
" + ], + "text/plain": [ + " Type Size Store\n", + "0 B 106106 0\n", + "1 B 68071 1\n", + "2 C 253603 2\n", + "3 A 223268 3\n", + "4 B 166848 4" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "samples['stores'].head()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
DateMarkDown1StoreIsHolidayMarkDown4MarkDown3Fuel_PriceUnemploymentTemperatureMarkDown5MarkDown2CPI
02009-12-13 04:59:49.832576000222.2688220False758.900545-1080.5694354.5749666.05279551.354606-4016.012141NaN224.973722
12011-07-20 14:40:40.4002447366857.7120280FalseNaN-551.0452573.0182966.72066980.723862NaNNaN204.595072
22011-10-12 08:23:54.9876582405530.9294250FalseNaNNaN3.4663587.25657073.954938NaNNaN202.591941
32012-11-26 09:22:59.377291776NaN0FalseNaN3703.1371913.6237256.59861572.041454NaN3925.456611169.457078
42013-05-13 15:22:12.213717760NaN0False3773.942572NaN3.3817955.11500283.05832215018.944605NaN182.801974
\n", + "
" + ], + "text/plain": [ + " Date MarkDown1 Store IsHoliday MarkDown4 \\\n", + "0 2009-12-13 04:59:49.832576000 222.268822 0 False 758.900545 \n", + "1 2011-07-20 14:40:40.400244736 6857.712028 0 False NaN \n", + "2 2011-10-12 08:23:54.987658240 5530.929425 0 False NaN \n", + "3 2012-11-26 09:22:59.377291776 NaN 0 False NaN \n", + "4 2013-05-13 15:22:12.213717760 NaN 0 False 3773.942572 \n", + "\n", + " MarkDown3 Fuel_Price Unemployment Temperature MarkDown5 \\\n", + "0 -1080.569435 4.574966 6.052795 51.354606 -4016.012141 \n", + "1 -551.045257 3.018296 6.720669 80.723862 NaN \n", + "2 NaN 3.466358 7.256570 73.954938 NaN \n", + "3 3703.137191 3.623725 6.598615 72.041454 NaN \n", + "4 NaN 3.381795 5.115002 83.058322 15018.944605 \n", + "\n", + " MarkDown2 CPI \n", + "0 NaN 224.973722 \n", + "1 NaN 204.595072 \n", + "2 NaN 202.591941 \n", + "3 3925.456611 169.457078 \n", + "4 NaN 182.801974 " + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "samples['features'].head()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
DateWeekly_SalesStoreDeptIsHoliday
02012-04-19 04:40:04.3703828487523.202470036False
12012-10-24 12:49:45.63017651213675.425498017False
22012-06-05 05:02:35.5517614082402.017324031False
32012-05-31 22:15:14.903856896-12761.073176081False
42011-09-07 15:34:43.2398356483642.817612058False
\n", + "
" + ], + "text/plain": [ + " Date Weekly_Sales Store Dept IsHoliday\n", + "0 2012-04-19 04:40:04.370382848 7523.202470 0 36 False\n", + "1 2012-10-24 12:49:45.630176512 13675.425498 0 17 False\n", + "2 2012-06-05 05:02:35.551761408 2402.017324 0 31 False\n", + "3 2012-05-31 22:15:14.903856896 -12761.073176 0 81 False\n", + "4 2011-09-07 15:34:43.239835648 3642.817612 0 58 False" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "samples['depts'].head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We may not want to generate data for all tables in the dataset, rather for just one table. This is possible with SDV using the `sample` method. To use it we only need to specify the name of the table we want to synthesize and the row numbers to generate. In this case, the \"walmart\" data set has 3 tables: stores, features and depts.\n", + "\n", + "In the following example, we will generate 1000 rows of the \"features\" table." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'features': Date MarkDown1 Store IsHoliday \\\n", + " 0 2010-07-07 18:33:30.490365440 NaN 46 True \n", + " 1 2010-05-31 05:12:43.555245568 NaN 49 False \n", + " 2 2009-11-13 00:11:26.645775872 NaN 46 False \n", + " 3 2011-06-26 13:37:23.832223232 8480.251096 48 False \n", + " 4 2013-09-30 06:14:18.804104192 4586.374804 48 False \n", + " 5 2011-03-18 21:50:55.521650432 NaN 45 False \n", + " 6 2012-03-22 14:59:09.189622016 9747.212668 46 False \n", + " 7 2011-04-05 16:45:11.282816000 NaN 48 False \n", + " 8 2013-05-05 14:23:57.431958784 5430.557501 46 False \n", + " 9 2011-04-07 02:38:10.873993984 6735.454863 49 False \n", + " 10 2012-07-09 15:59:11.863023616 11818.737600 48 False \n", + " 11 2012-03-10 03:57:22.968428544 4603.635257 48 False \n", + " 12 2011-10-01 03:51:05.280273920 14596.976258 49 False \n", + " 13 2012-06-24 23:45:06.501742080 5816.163902 49 False \n", + " 14 2011-08-06 14:43:48.366935040 NaN 47 False \n", + " 15 2012-06-14 01:06:45.771771648 NaN 45 False \n", + " 16 2011-11-22 08:57:27.766619392 NaN 45 False \n", + " 17 2012-10-06 19:10:00.039407360 5887.384577 48 False \n", + " 18 2013-06-08 07:46:59.612567808 -5770.051172 48 False \n", + " 19 2012-05-27 11:08:47.184699648 6908.122060 46 False \n", + " 20 2010-10-19 21:42:09.919441408 NaN 46 False \n", + " 21 2013-01-01 17:57:15.273558784 6221.151492 48 False \n", + " 22 2012-07-13 14:40:51.181466112 1378.006687 47 False \n", + " 23 2012-02-13 18:37:30.161656320 NaN 48 False \n", + " 24 2010-11-27 09:56:34.540603392 NaN 46 False \n", + " 25 2011-05-21 03:49:37.053693440 NaN 49 False \n", + " 26 2010-06-12 20:31:32.339124224 NaN 49 True \n", + " 27 2011-08-12 21:10:37.406385152 NaN 47 False \n", + " 28 2012-02-07 03:31:37.558127360 NaN 47 False \n", + " 29 2011-05-28 22:39:00.441712128 NaN 48 False \n", + " .. ... ... ... ... 
\n", + " 970 2012-10-01 09:27:33.224137984 7367.091411 49 False \n", + " 971 2011-08-05 19:20:26.151251200 NaN 49 False \n", + " 972 2010-05-02 10:25:48.294526720 NaN 49 False \n", + " 973 2014-07-30 06:33:25.027884800 2869.286369 46 False \n", + " 974 2012-06-10 23:50:28.918937344 NaN 45 False \n", + " 975 2011-11-20 17:24:25.076207104 NaN 48 False \n", + " 976 2011-07-16 19:31:09.475053312 8465.620591 48 False \n", + " 977 2010-12-17 11:17:29.377529600 NaN 45 False \n", + " 978 2011-05-01 22:00:24.413698816 NaN 48 True \n", + " 979 2009-01-04 14:06:28.700405504 NaN 48 False \n", + " 980 2013-01-26 23:44:47.124111360 7711.715123 47 False \n", + " 981 2012-02-05 15:58:21.306519040 4771.829938 49 False \n", + " 982 2012-12-25 20:25:33.794860544 9052.115357 45 False \n", + " 983 2011-03-11 05:15:33.422215424 NaN 48 False \n", + " 984 2012-04-10 15:14:42.299491072 11469.698722 49 False \n", + " 985 2012-09-04 11:56:04.719035648 11985.730540 47 False \n", + " 986 2010-08-12 06:13:43.472920320 NaN 47 False \n", + " 987 2011-11-28 00:46:09.219015936 NaN 47 False \n", + " 988 2010-11-01 19:46:13.442969344 NaN 47 False \n", + " 989 2012-09-03 02:46:56.811711232 13448.706455 46 False \n", + " 990 2012-06-21 20:15:00.600609536 8515.644049 45 False \n", + " 991 2009-11-09 11:32:20.317559808 NaN 47 False \n", + " 992 2012-12-24 16:57:40.573141760 8785.735128 48 False \n", + " 993 2010-10-26 22:04:15.916646656 NaN 45 False \n", + " 994 2013-07-27 21:48:37.233461248 2218.444188 46 False \n", + " 995 2012-11-22 15:48:58.070053120 NaN 47 False \n", + " 996 2012-11-19 17:12:55.679613440 11318.264204 45 False \n", + " 997 2013-10-11 21:09:49.021492480 478.609430 47 False \n", + " 998 2013-01-23 23:17:20.493887232 14227.944357 49 False \n", + " 999 2012-06-23 00:01:59.594749696 9121.193148 46 False \n", + " \n", + " MarkDown4 MarkDown3 Fuel_Price Unemployment Temperature \\\n", + " 0 NaN NaN 3.246268 6.712596 72.778272 \n", + " 1 NaN NaN 2.961909 8.775273 31.911285 \n", + " 2 NaN NaN 2.474694 8.693548 36.996590 \n", + " 3 2990.430314 450.886030 3.180972 7.482533 82.689207 \n", + " 4 1300.901307 1038.142954 3.979948 3.842990 61.626723 \n", + " 5 NaN NaN 3.190374 9.185961 43.775213 \n", + " 6 5068.165070 4033.503158 3.436991 10.002234 44.980268 \n", + " 7 2639.207053 NaN 3.126477 8.921070 60.901225 \n", + " 8 -1515.448599 1859.931318 4.101750 10.478649 81.003354 \n", + " 9 NaN -4013.843645 3.200662 8.549170 46.332113 \n", + " 10 6983.561124 9246.461293 3.637923 6.940256 66.208533 \n", + " 11 NaN 1872.845519 3.535785 5.360209 76.842613 \n", + " 12 5967.615987 -2084.078449 3.314418 NaN 73.216988 \n", + " 13 4181.309429 -7382.929438 3.170917 7.209879 27.887139 \n", + " 14 NaN NaN 4.084598 10.976241 75.512964 \n", + " 15 6522.458652 -2969.500556 3.592465 6.778646 45.304799 \n", + " 16 NaN NaN 3.691225 9.169614 61.590909 \n", + " 17 2195.064778 10031.652043 3.790657 5.655427 38.507717 \n", + " 18 -2522.448316 4530.988378 3.349566 4.743290 36.629555 \n", + " 19 NaN 3416.997827 3.641454 8.177166 89.916452 \n", + " 20 NaN NaN 3.075063 5.229354 61.706563 \n", + " 21 4647.376520 3247.164388 4.162684 5.033966 37.962542 \n", + " 22 NaN NaN 3.384827 4.921576 54.039707 \n", + " 23 NaN 7039.242263 3.187745 9.768689 74.453345 \n", + " 24 NaN NaN 3.045024 6.954771 63.069450 \n", + " 25 NaN NaN 4.164028 8.876290 71.070532 \n", + " 26 NaN NaN 2.719009 7.230035 57.364249 \n", + " 27 NaN NaN 3.257133 9.345917 68.652152 \n", + " 28 NaN -4190.690106 3.548902 9.494165 89.145578 \n", + " 29 NaN NaN 3.316924 6.697705 77.375684 \n", + " 
.. ... ... ... ... ... \n", + " 970 2712.006496 -14.780688 3.618468 8.148930 30.483065 \n", + " 971 NaN NaN 3.012630 6.326113 47.975596 \n", + " 972 NaN NaN 2.841184 10.269877 81.946843 \n", + " 973 1937.104665 -1511.898278 4.759412 7.778287 30.662510 \n", + " 974 5253.689129 329.011494 3.343722 7.398721 21.129117 \n", + " 975 NaN 4930.780630 3.506178 9.701808 110.523151 \n", + " 976 5293.404197 -3121.976568 2.800094 5.843644 32.442600 \n", + " 977 9805.271519 NaN 3.136993 7.623560 72.572211 \n", + " 978 NaN NaN 2.714601 7.234621 52.291714 \n", + " 979 NaN NaN 2.043747 8.744450 46.628272 \n", + " 980 5415.328424 -4809.140910 3.287617 8.354515 61.384200 \n", + " 981 NaN NaN 3.337297 6.142083 36.829644 \n", + " 982 837.730915 4072.518501 3.894026 12.494251 94.673851 \n", + " 983 NaN NaN 3.420209 7.090230 47.694700 \n", + " 984 5134.594286 357.279124 3.325391 4.539604 37.460031 \n", + " 985 NaN 4056.157461 3.685698 7.230368 43.096815 \n", + " 986 3953.583585 805.676045 3.044291 7.884751 34.438372 \n", + " 987 4511.412982 NaN 3.394697 6.834802 36.636557 \n", + " 988 NaN NaN 3.154329 8.363893 49.751330 \n", + " 989 NaN 2084.191209 3.615679 9.596049 75.533229 \n", + " 990 4457.772828 218.985398 3.440906 8.526339 61.864859 \n", + " 991 NaN NaN 3.247303 7.423321 59.019784 \n", + " 992 NaN NaN 3.458168 6.459322 48.584436 \n", + " 993 NaN NaN 2.608670 10.259658 82.758016 \n", + " 994 719.333117 6631.998410 4.368128 7.588005 54.580465 \n", + " 995 3003.501922 NaN 3.800082 NaN 100.143258 \n", + " 996 NaN -681.708699 4.140327 10.717719 46.213930 \n", + " 997 735.529994 3856.959812 3.921162 7.439993 59.473702 \n", + " 998 2557.407176 2865.135441 3.971143 9.540309 67.952407 \n", + " 999 671.086509 2328.798014 3.487042 5.988222 63.846379 \n", + " \n", + " MarkDown5 MarkDown2 CPI \n", + " 0 NaN NaN 172.666543 \n", + " 1 NaN NaN 119.214662 \n", + " 2 NaN NaN 143.990866 \n", + " 3 4354.651570 5019.468642 180.631779 \n", + " 4 -4772.570563 2947.674138 173.813790 \n", + " 5 NaN NaN 177.052581 \n", + " 6 7083.723804 3513.601560 185.921725 \n", + " 7 NaN -180.888893 153.966086 \n", + " 8 6265.833733 2401.076738 170.559516 \n", + " 9 3759.801559 4729.764492 200.039874 \n", + " 10 6567.512481 -1022.527044 150.002644 \n", + " 11 7765.142655 NaN 221.169299 \n", + " 12 NaN 3308.681495 NaN \n", + " 13 2826.442468 8385.400703 150.777231 \n", + " 14 NaN NaN 211.070762 \n", + " 15 6642.564583 NaN 154.828636 \n", + " 16 NaN NaN 138.236111 \n", + " 17 4118.446321 NaN 138.853628 \n", + " 18 1115.215425 -1526.587205 219.621502 \n", + " 19 4994.790024 NaN 197.086249 \n", + " 20 NaN NaN 143.384357 \n", + " 21 9934.092076 6538.282816 208.758478 \n", + " 22 396.251330 1170.864752 217.807128 \n", + " 23 1492.226506 NaN 144.261943 \n", + " 24 NaN NaN 198.046465 \n", + " 25 NaN NaN 152.795139 \n", + " 26 NaN NaN 192.354502 \n", + " 27 1804.646340 NaN 213.883682 \n", + " 28 NaN NaN 237.955017 \n", + " 29 NaN NaN 140.091500 \n", + " .. ... ... ... 
\n",
+       "  970  -208.184350   4224.751927  149.278618  \n",
+       "  971          NaN   3814.404768  167.707496  \n",
+       "  972          NaN           NaN  115.688179  \n",
+       "  973  1199.031857   3033.592607  120.703360  \n",
+       "  974          NaN   -351.113110  106.143265  \n",
+       "  975          NaN           NaN  211.109816  \n",
+       "  976  3975.213749           NaN  152.402312  \n",
+       "  977          NaN           NaN  121.333484  \n",
+       "  978          NaN   3607.788296  176.680147  \n",
+       "  979          NaN           NaN  164.924902  \n",
+       "  980  5431.606087   7005.735955  208.516910  \n",
+       "  981  -285.659902           NaN  146.865904  \n",
+       "  982  5584.625275   2016.204322  165.913603  \n",
+       "  983          NaN           NaN  136.591286  \n",
+       "  984  2929.718397  11730.460834  153.875137  \n",
+       "  985  4618.823214   1218.336997  191.245195  \n",
+       "  986          NaN    128.922827  209.229013  \n",
+       "  987  7757.467609           NaN  201.908861  \n",
+       "  988          NaN           NaN  210.724201  \n",
+       "  989  5115.758017   1620.240724  149.542233  \n",
+       "  990  1789.565202   1350.171000  193.810848  \n",
+       "  991          NaN           NaN  147.307937  \n",
+       "  992  -105.978005           NaN  229.688010  \n",
+       "  993          NaN           NaN  217.839779  \n",
+       "  994  -907.055933   3253.948843  128.459172  \n",
+       "  995          NaN           NaN          NaN  \n",
+       "  996  3043.948542           NaN  114.508859  \n",
+       "  997  1459.810600  -4003.468566  193.149902  \n",
+       "  998 12151.265277  -3099.180177  150.590590  \n",
+       "  999  2792.671569           NaN  201.355736  \n",
+       " \n",
+       " [1000 rows x 12 columns]}"
+      ]
+     },
+     "execution_count": 11,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "sdv.sample('features', 1000)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "SDV also provides tools to evaluate the generated synthetic data and compare it with the original data. To learn more about the evaluation tools, [see the documentation](https://sdv-dev.github.io/SDV/api/sdv.evaluation.html#sdv.evaluation.evaluate)."
+   ]
+  },
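+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a quick sketch of how that could look (assuming `evaluate` accepts the dictionary of sampled tables together with the real tables and the metadata; check the API reference for the exact arguments):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from sdv.evaluation import evaluate\n",
+    "\n",
+    "# Compare the sampled tables against the real ones and obtain an overall\n",
+    "# score (the exact arguments here are an assumption; see the API reference).\n",
+    "score = evaluate(samples, real=tables, metadata=metadata)\n",
+    "score"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/docs/tutorials/04_Working_with_Metadata.ipynb b/docs/tutorials/04_Working_with_Metadata.ipynb
new file mode 100644
index 000000000..91b001036
--- /dev/null
+++ b/docs/tutorials/04_Working_with_Metadata.ipynb
@@ -0,0 +1,460 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from sdv import load_demo"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "metadata, tables = load_demo(metadata=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'tables': {'users': {'primary_key': 'user_id',\n",
+       "   'fields': {'user_id': {'type': 'id', 'subtype': 'integer'},\n",
+       "    'country': {'type': 'categorical'},\n",
+       "    'gender': {'type': 'categorical'},\n",
+       "    'age': {'type': 'numerical', 'subtype': 'integer'}}},\n",
+       "  'sessions': {'primary_key': 'session_id',\n",
+       "   'fields': {'session_id': {'type': 'id', 'subtype': 'integer'},\n",
+       "    'user_id': {'ref': {'field': 'user_id', 'table': 'users'},\n",
+       "     'type': 'id',\n",
+       "     'subtype': 'integer'},\n",
+       "    'device': {'type': 'categorical'},\n",
+       "    'os': {'type': 'categorical'}}},\n",
+       "  'transactions': {'primary_key': 'transaction_id',\n",
+       "   'fields': {'transaction_id': {'type': 'id', 'subtype': 'integer'},\n",
+       "    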
'session_id': {'ref': {'field': 'session_id', 'table': 'sessions'},\n", + " 'type': 'id',\n", + " 'subtype': 'integer'},\n", + " 'timestamp': {'type': 'datetime', 'format': '%Y-%m-%d'},\n", + " 'amount': {'type': 'numerical', 'subtype': 'float'},\n", + " 'approved': {'type': 'boolean'}}}}}" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metadata.to_dict()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'users': user_id country gender age\n", + " 0 0 USA M 34\n", + " 1 1 UK F 23\n", + " 2 2 ES None 44\n", + " 3 3 UK M 22\n", + " 4 4 USA F 54\n", + " 5 5 DE M 57\n", + " 6 6 BG F 45\n", + " 7 7 ES None 41\n", + " 8 8 FR F 23\n", + " 9 9 UK None 30,\n", + " 'sessions': session_id user_id device os\n", + " 0 0 0 mobile android\n", + " 1 1 1 tablet ios\n", + " 2 2 1 tablet android\n", + " 3 3 2 mobile android\n", + " 4 4 4 mobile ios\n", + " 5 5 5 mobile android\n", + " 6 6 6 mobile ios\n", + " 7 7 6 tablet ios\n", + " 8 8 6 mobile ios\n", + " 9 9 8 tablet ios,\n", + " 'transactions': transaction_id session_id timestamp amount approved\n", + " 0 0 0 2019-01-01 12:34:32 100.0 True\n", + " 1 1 0 2019-01-01 12:42:21 55.3 True\n", + " 2 2 1 2019-01-07 17:23:11 79.5 True\n", + " 3 3 3 2019-01-10 11:08:57 112.1 False\n", + " 4 4 5 2019-01-10 21:54:08 110.0 False\n", + " 5 5 5 2019-01-11 11:21:20 76.3 True\n", + " 6 6 7 2019-01-22 14:44:10 89.5 True\n", + " 7 7 8 2019-01-23 10:14:09 132.1 False\n", + " 8 8 9 2019-01-27 16:09:17 68.0 True\n", + " 9 9 9 2019-01-29 12:10:48 99.9 True}" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tables" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "from sdv import Metadata\n", + "\n", + "new_meta = Metadata()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "new_meta.add_table('users', data=tables['users'], primary_key='user_id')" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "new_meta.add_table('sessions', data=tables['sessions'], primary_key='session_id',\n", + " parent='users', foreign_key='user_id')" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "transactions_fields = {\n", + " 'timestamp': {\n", + " 'type': 'datetime',\n", + " 'format': '%Y-%m-%d'\n", + " }\n", + "}\n", + "new_meta.add_table('transactions', tables['transactions'], fields_metadata=transactions_fields,\n", + " primary_key='transaction_id', parent='sessions')" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'tables': {'users': {'fields': {'gender': {'type': 'categorical'},\n", + " 'user_id': {'type': 'id', 'subtype': 'integer'},\n", + " 'country': {'type': 'categorical'},\n", + " 'age': {'type': 'numerical', 'subtype': 'integer'}},\n", + " 'primary_key': 'user_id'},\n", + " 'sessions': {'fields': {'session_id': {'type': 'id', 'subtype': 'integer'},\n", + " 'user_id': {'type': 'id',\n", + " 'subtype': 'integer',\n", + " 'ref': {'table': 'users', 'field': 'user_id'}},\n", + " 'os': {'type': 'categorical'},\n", + " 'device': {'type': 'categorical'}},\n", + " 'primary_key': 'session_id'},\n", + " 'transactions': {'fields': {'timestamp': {'type': 'datetime',\n", 
+ " 'format': '%Y-%m-%d'},\n", + " 'session_id': {'type': 'id',\n", + " 'subtype': 'integer',\n", + " 'ref': {'table': 'sessions', 'field': 'session_id'}},\n", + " 'transaction_id': {'type': 'id', 'subtype': 'integer'},\n", + " 'approved': {'type': 'boolean'},\n", + " 'amount': {'type': 'numerical', 'subtype': 'float'}},\n", + " 'primary_key': 'transaction_id'}}}" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "new_meta.to_dict()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "new_meta.to_dict() == metadata.to_dict()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "new_meta.to_json('demo_metadata.json')" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "loaded = Metadata('demo_metadata.json')" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "loaded.to_dict() == new_meta.to_dict()" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Metadata\n", + "\n", + "\n", + "\n", + "users\n", + "\n", + "users\n", + "\n", + "user_id : id - integer\n", + "country : categorical\n", + "gender : categorical\n", + "age : numerical - integer\n", + "\n", + "Primary key: user_id\n", + "\n", + "\n", + "\n", + "sessions\n", + "\n", + "sessions\n", + "\n", + "session_id : id - integer\n", + "user_id : id - integer\n", + "device : categorical\n", + "os : categorical\n", + "\n", + "Primary key: session_id\n", + "Foreign key (users): user_id\n", + "\n", + "\n", + "\n", + "users->sessions\n", + "\n", + "\n", + "   sessions.user_id -> users.user_id\n", + "\n", + "\n", + "\n", + "transactions\n", + "\n", + "transactions\n", + "\n", + "transaction_id : id - integer\n", + "session_id : id - integer\n", + "timestamp : datetime\n", + "amount : numerical - float\n", + "approved : boolean\n", + "\n", + "Primary key: transaction_id\n", + "Foreign key (sessions): session_id\n", + "\n", + "\n", + "\n", + "sessions->transactions\n", + "\n", + "\n", + "   transactions.session_id -> sessions.session_id\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metadata.visualize()" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Metadata\n", + "\n", + "\n", + "\n", + "users\n", + "\n", + "users\n", + "\n", + "gender : categorical\n", + "user_id : id - integer\n", + "country : categorical\n", + "age : numerical - integer\n", + "\n", + "Primary key: user_id\n", + "\n", + "\n", + "\n", + "sessions\n", + "\n", + "sessions\n", + "\n", + "session_id : id - integer\n", + "user_id : id - integer\n", + "os : categorical\n", + "device : categorical\n", + "\n", + "Primary key: session_id\n", + "Foreign key (users): user_id\n", + "\n", + "\n", + "\n", + "users->sessions\n", + "\n", + "\n", + "   
sessions.user_id -> users.user_id\n", + "\n", + "\n", + "\n", + "transactions\n", + "\n", + "transactions\n", + "\n", + "timestamp : datetime\n", + "session_id : id - integer\n", + "transaction_id : id - integer\n", + "approved : boolean\n", + "amount : numerical - float\n", + "\n", + "Primary key: transaction_id\n", + "Foreign key (sessions): session_id\n", + "\n", + "\n", + "\n", + "sessions->transactions\n", + "\n", + "\n", + "   transactions.session_id -> sessions.session_id\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "new_meta.visualize()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/tutorials/demo_metadata.json b/docs/tutorials/demo_metadata.json new file mode 100644 index 000000000..2be7ecc2e --- /dev/null +++ b/docs/tutorials/demo_metadata.json @@ -0,0 +1,74 @@ +{ + "tables": { + "users": { + "fields": { + "age": { + "type": "numerical", + "subtype": "integer" + }, + "user_id": { + "type": "id", + "subtype": "integer" + }, + "gender": { + "type": "categorical" + }, + "country": { + "type": "categorical" + } + }, + "primary_key": "user_id" + }, + "sessions": { + "fields": { + "os": { + "type": "categorical" + }, + "user_id": { + "type": "id", + "subtype": "integer", + "ref": { + "table": "users", + "field": "user_id" + } + }, + "session_id": { + "type": "id", + "subtype": "integer" + }, + "device": { + "type": "categorical" + } + }, + "primary_key": "session_id" + }, + "transactions": { + "fields": { + "timestamp": { + "type": "datetime", + "format": "%Y-%m-%d" + }, + "transaction_id": { + "type": "id", + "subtype": "integer" + }, + "session_id": { + "type": "id", + "subtype": "integer", + "ref": { + "table": "sessions", + "field": "session_id" + } + }, + "approved": { + "type": "boolean" + }, + "amount": { + "type": "numerical", + "subtype": "float" + } + }, + "primary_key": "transaction_id" + } + } +} \ No newline at end of file diff --git a/docs/tutorials/sdv.pkl b/docs/tutorials/sdv.pkl new file mode 100644 index 0000000000000000000000000000000000000000..d3c88b5a4eca4045c109b46b6393d46f5be7a46d GIT binary patch literal 255975 zcmd@7cYIV;+x?GIq=nu=x(b9My($otUWHKvR18T7Oh!nUorI=C6Anc|njnZEponw? zR0IrFI*5Suj!F?wLG1Ef*V<=hGn{xk`S`r@`{U>H+}AmKt+n^r^PW93=S*gX;hAF+ zdkv_D|Ajr>{Mj&XhDsy68GYVN;oY)i$dI95La*4qULE0nVzror{>go#62n?WB_}4v zMa4H`JIwn)w_N@XkuMF5QhgG=S$y8CRC~oGCaJjSp)0@}l&DSj}tfZ*uzOjkk{6231eZ!LzV^yNJpwC+!`+3ug&}K;a(lwvxDl1Y-QhlOWenvm!On7+`s?YnV2E`lh zt>!Mm^+Jx`>h6Ncfu5Xjuh;=`F|pnnK5tDuVIh)`=&j}R)~0~T$?{23QSpgUF|tk) zy>)!vx-!8$!G=t-o=mcSxVxUu+u)u_V$Kc2-I&~C;qGegaCdcgg?lC_2 z_3Q6_+~;j1sAOefRmF@}y=Q)xW<=7!cdb+sx-}j$_DriB3r98ygkxZIOt3BC%URNRe(V4;9H9 z-5bZ0CD+pDZ58e=7^8Y6Rdd*3t);^=0!in?ysfe1+xWa~!woktMa3t?5AGK?B-Y!hqrSPxLxv3Yb{@x79@96jzt=U6 zH?tpZdt7`U@6$f-GhBwUz}0}LzKPy0KJT;Pd@OWn=j|HaE<>w`L$&!|tB9@soeyX7 zza(!rF19XDKjG~jeoycmxZ%H65yzRE)hc3Me*Tx_eU4+dc-s3sI-Fw+xo>;(zg7_! 
z{f7v2c!3>Sy1YHmA&|SGXV5C*vVVxl{v*7|4$WJ5Bhev{`+hI~YZY;gN2t{*BG92H zJ3QIS8-)&m+|9fEuT{hi|J(u{qLY)O-R0auKgQ?n74GiT(cO8R+cnM&ps{Ws>>bW0 zvzFdIBpAT3%g_9AVGtc|;0pgxk!+zB_`Cc49_$laHSG!e2pu)SVIY6CE} zfj;k`a6T>ej>9alc9UbhgJW>3`VYp-pNg&4FBvn4QYvb2SR$U5qg0r8i1x!ckudMe zyyJT}@V+u`>=>`ln8sO9);p2YW}2w<#>(6I@h$FO`6op!Y(Y!~2*+vKt&13XtiT&Ky6(58MP-hu{|ZydM(W;&Aul z&gJTVi3)3Ozg7h;+a(~s)aP9m?oM#`$I5XhyJaOW_jy-@y9Z;v3~iSEuaosQZ5_O} znRlh-UFGtAB(YKB+@swS-1Y)+H+9QHbG6UACNR)Fk2k)=g?ZPyliX|V1AOfBeiAsq z`5}dtx60qOsrOUMyUykPER_Q^bKWrfbusKo{rVVmv({ra8+_i4m<=B*=9Yc#^QPdy z0X6uBkoP8VbT7qiTJ3J;-r#4xnaq><@Q#de$Mx+MiI*(z7Mzs!?HjAlJi>a#_Kw0s zyjNTdo?*6{XPm@Xw6~4+Zg+XVz>BndTMzFJcZI0Y-kmP*E<6Cx-tDe~_8ymaue*-4 z_l@@McXz15AO4!XQwh3l?|+y&9~u*-Wycm0|}A9Z<;=|deK z?LFc0p2T|6ue5S*jNKgH$LyT`UDy0B)t;)MJC4LtmnIqGAz%YA|vHW zATrW>xspb1UmHNV;`3gWTq19cey&gSUhCogUhcT2?q;!MW-byEz1MqqZ-l$EzCZ$gas&V3JwUo7vh zF7I!Al9gxlSnuyT@w&C-((2OLn|c4Rynnj9f9Z3;NPS)OcT8eT99|P}LYb6UEmE#> z;v?hYd#fm1?j*+~C9BvlH6XILN*EFwALhMVX@tsv3lX2nh^j9TgpSIDW+gt^zcjE{ zqRPz71CnvyasTH2-HpS_O93yVvOo_$ZcIWvUJsKKk`wh562szgG8u`BgfNv=h904^ zameg4q`h|f$uqag0i|6=zAoWK42@QI?Wa>}(JH5C!~B{|Vo*{-UzJPb##u+bZOyCl zHt*=@s9t=bRJqy1K9S(GRRe+LALYA2-D9vR1by=^!YLWfpVqWSS8x|R<1-WZtBV&>hRG2D+Ja2u6 zXKFTVSi8O|3}Lh?g6dL5QGKcyYB*m8qvMihQyRI#|D^uh>Jr2@gQ_+I)p$q za;L;>pS`wnt!m|Sw6d}d{q*QJFHE0ZWySmxT@OEW>iohRS<1ff*!KD@-fZ6dZ1~g@ zn~rvFczDYL8~V&FmwDEsV~&p=ez(;xyK?QG-6C6^&R>+<7}ICOnlGoe+&it3JG{?_ zeJyqPdhW%$TRz;m!{1-DS=H&eHxjFtZ@=7AK2P5IAG#;3oAS%B`u*>W?-lv`A8T?H zXn*L`)nWseEZ=$j^70+$>UJ6Z{p#LxqhE|2o3&@kM~z#2RO^X9tM{He=%a?)7FQpf zYtUaAXWsp3|M0~XW}f<(j`%&HBsehR;mjF!1ulk!|-i{5YX-$3m+!AH30^a^{N1?pCke+vEE3iIMF} zwtRJ+>*SK3@(Pd*ACf;qIxZ$+y=oc&Jh5{pCt8-FRYIt6Ya3 zTfeIC+FO~j6{x#3eBQ<_3CHV#_gXp z&2n`4l#X?l^~nBOx0i}P`=G1D>Y8C&&o?}_p-=3-(vN>Oqjdet8A{BnfBX57LmsZ) z|Gntpo10WSyZJ!E?V`g7OJCvx2_IVVB#s?7752X8)Bds(X%b4zrdIJaNls@3CvFFEwV@ZDLy z`t`-z0~)MpJt_9$kt11ZcdNB*Vulro)u;b9de@nZ9ZDA%+~hWRGFq$kK-e{ zmdL#K&b9hCnjB27Q^jg@V%wBAdiv(Q@Jdw5%^SY>i0chzcdz--(Xi9+sA=)*s$5yQ z;Aa1QH$UG~v&YkA+&jKXs8PILch{9NC%zc`W?+Vhi=*UsA zGDWofY~9K&v%0ozRm{q;_C&*#)wi8#lY9NL1)G0cI`2sT1A8(T z`g`L$U;mvlp-$l`qc&tnywUH8H9Jc@5&c<*>0K(H+1z=~%?VRad(T9?R4kWoQKrm! 
zyZu^xO4NbMuic%tcE+LCdR{+t{r7C&d$!MAabomicNR2AD5AO--xJqnu<7d>l1wkp$Fj2)c&aMN@3j#T(5xt290a>A{tUliI9 z`Qe2;RkP-4(&OS&t3K~AZORR+#zPINPF;KTPOGmzS{MZEvm;kJ-JN`8^=D_kt?%pBI3efqPV=^RUi#GL_tx|s z_|^XR^Y+}lVE03p*8KVFtR-clD_$78uh=UW*LGNR?AtRFi@w%yRoH>os|dcovGL-$JXY{$92tG?smubj*k1= zz5J(o?N5K){zjSpy$5f8`Ndi9hGSc zHob$d-=sb7=Iweb>D_D>cXVCf`})KAx@3yjern0Ohj;cp7q|V93nee_Z%}B#XG6;M zoLHjKr4iq)&+HjjaaWn)3!nV<#a4Z*<~|Lv%M8xf6~x*H!LllWn`Xff1GH2aBj=dW8NtEZo3S#jxJdqon`v& zw_on^^k1#FZQMO=)F&Ipj%j`Ql?T5sKHyTv1zl!6(6e}(B9E--bhvfH{X3%S{JM8p z>E)Aa?ke46_twY(6O)T)xb@Aj<5jMYA76R*z9z-yx6jpa$H)`gkN@!c-~-+p?`Mx0 zb*%cx+uNI5eKW_JbvZNFS^Y`ED?3|$+Pd4RJ98hK-lEgTYsTd`k#ug<{?|L?e0;(a zn=a3~`gD!W^NweFrt+?D`@HsJ*7XZ>*BaODT*2Qzz5UwqZC{+bb70R?z6WYtDlnpD zwlU{+p1B^g`24Uk8_M3ESvqHq?`m$_UFT+=0_)#dR`yCnt4YnvZb{s+rSy!?e>{=5 ze#UZRy8ZC;l9yDz3*#0IcqGS)AseUUc{C}1jpeQL#$6n7a7}~HCw=thH^YAznLK=9 zv8+`(9m;&H(NBZt9LZ2JtlX?-^#@u-3cBCe8-4YSS_i+GH)6}gA*I$=ex~@8eS0bm z-~ZXA=SLnGm*=->%|9FEnU%A{%hmU7nYFuT=AIYE{giD|V(Y<$lj_fX{FgT(DiA_DX>3t*#SS7%29t#$4A;Hm!Cc`~(D}b**ZK7YA)h7TS{W6NbGal{ zNzUl_E@}USL{%BB4v`s?`KGHoDlw9OQc+bPx6`MsTBr)46;4;$gAR4r-U_Nm&m?JF+- zdjzTuR7=%Gb*Xxiv@f{c?ewPi`lY%}Rec5;FwhV++5NX$uAm;nzE3@l8t&nRlSf}K zH4?>r$D;@Hu!~17mOS<~P-Av%!jAg7=YQR3PtYZfC+MoiBNsCsxrXuZvc7{Bi9p_ia`Arww`-|Hbvg2owX?ozd4L7hWF!m6l5y`ZDWj?7Gx5sJP5(m$_YLOLalRJ#<{N)~e(V zZ^<1$MhyOH@r}+yzKke3EX!|qu1`tvw_(!nQorJyZ~9(K@%#R^zrT7GBPOY?sL}qr zoo}4t?HDB&+&YvBLYi-`YgvZf>>Lrm6GJwr*>k};z~Ox9GW~To%%n(E^ z^)jkUy&~P@&n0*hEk@qb?Vo@@M=2jZy627{JEqZpQv?H_z6Ly<{tWRANctMnTl8m&Z`jb+ z7-rFb8&xpO#x_+9@4&Fs9ERsg*n#0)s6I6hl^EtrD1c!BS~dpxQw{DFxo4WH_n=$q zeN>lPC>g^*#!e;6wV0#WxJ-YX(c`FBRV~psyJ$qrV(g7_7jS41#95 z5~`(EG5L`s9ROBC^{F+e1h7^@0RSIM3xj_`+dj8>x{LhxO^W&ys-@PUy3}Wqw9jo$ z+&Mh|8S94#wVr_u3~WU8pWA$neV^hfQT;#{L=6UdP!s^wgL&8`2=>ZOqy?bIY1=@b)hjeYoq%eolc+9rN|H9vlU$+aBGhRH&M_I>I+YB10X zq5!BK%)>50Krgc6C3Z|Npx@wVX#x5z=lGq>QGi~?jsbc_S^#>L_Wcj}YfvrqJ*rDx zm!u8!xF6^Z25vI&1F9eBkJ$IApHPE=-Vy~s^x0?_Z&J8OVT7=DAcxR6o#6*!QW- zsKG!V5CuT>U>*%9pt@9EN!mca?=q><{7spfseBCNXP^M8A80}B`&1#+V4#IX0Z=`dhh2hz z7GcMt?3i9ai{WT#0a~1Md{E{nKuci904*sk04+t^20A2T%z|ED=W41-L$y>HRF^6% zNgL?(X6r83xLK>ID#yS>3_OhL2l@#1eX2ZaFwhF30H_|!!!AKUE3#uaJEj-VN;q0t zfL7)ltH>M$XjSYOppQxmK&#QVfv)^@{EA;sbZx4tL$y>5RF|qLNgHUrL)F4+ZLZc- z)ncGF19edSK=FdD0XsHi$MgdF7>Plb*U#MX#=f4Bw=`yx@DTGW(+iEparTQXiMz- zR4deApshs#P(7H3U4nqNVaK-Ym|j5J;b>_A+MaXlAafL;9kFA8J|!&x?L^xKTE5Mc zN7_ue6rnmpwUi6hrJk0g4YYfc-=Ew*^=O28hJh{&Jd5fF+7QcQVX#;(6a-D2TPUdf_Vj1YoKp#{;P&f8{Dh@Rms7Dk4)q{E1B?#zC?AVtb(+g-n z94#$C<2lC!nWF&hj~xTlD=h$3v~8fcO2U7*vP^_(sU%dFN|vMzw6%)3J#^RQ2sMC# zfeZ{n^#dJ@eV-bF8VvMhQ2*&n8h9W5$?Wnq2>v&`iGT0^H{1XJJ$}EEJ0{DYGR{svdwlMQ z^AYj<)5ZCe1icNkS4_{?YWljX5f#sPkJhQaB%;Ik8DmDSemkXE^x&=y#`{v9t(I_W z?6sW{*$PzO+voYSDYKhBaLLtZ$%cQrKmEsx{i|)p=TpA2kAE=bJm+yR(4O~c`CAQ! 
zoQybHYs%fR3qDV&|97hvVU3nYlsmiqRK+S=Qz}$%J_UYpvOhKEszrX&IjkNivo86=4|Fm-YmXwLz8uu-+V^+keVr!aY9=b2$^05`i zhCl30>Bif8EJfemn-L`|Jv}C&qz2xZliM9XnxT)M;-ah~F{kQ-B`u6(Qi~s)fkMF;I{%`ul{1~yf|A~5s|2Q5|aPXx(?xgb( zbD!xuZ1C+15&v}m`X3Me_32+P{`=40_TN7L@%)d!AxWA4=(Om3%J|jYzKV(YA*J*e z*;*f3y)H_{p-`eUi|l;pQr!!`H$ycf9h4dC?=^1OBb)AR1mBi|e#7FY=7d}(@L-1<||IZZv2>f~oshVoSzoq!c z--#ilYN|IvNYzr4{w-Bo{@gq{gj7v6C4^KhH8q4(Z8h!RQgvirZ-$VnsiudJs-Pj2J8A*5=m zc_E}~srezKYO4hyr0S^mLP*tB?}w18rxyNOs=ko;AcRy+wJ3yCE%jjtsoH9B2&p=1 zNeHRBYH0|mdTLn+srqX9zoi-ouN5JrYO0kXq-v>EA*5=nk3vY*QL95p)m3XkNYzto zLrB$EABT`?pg#GxR70WqX$YyBYF!AaTI#b9Qnl6k5K?v2h7eM9)y5D~_0;Dfr0T1b z5K;})rVvsM)n-ZYLnZ%krcjfg8QT&{tftx;O01UJ7D}wP+8#=*j`|{$SY5Rvlvq8r zGn80;wJVfZ1GPJpSVOf(Vx|Ab`I_7dHEXE7|2iB`+4hAJtEu*f604;Sgc7T*z6>Q+ zM;#0$R#$x$O01qb6iTeVIvh%@fjSaOtfBf^V*m5r!o7fd>u5;1n(ElU9+oH1$3uzL zR3}1-)lw%ziPcu8LW$K;r$dR=RcAtp)l+9fiPcxG4A5+fZUP)pw!9YN^Yi#A>T6p~UK_tD(f|s%xRd>Z$KTiPcxv zLy0v|H$sUuR5vB|Kkt3q3%K`x2q{-n{TNcNmij5ATy1qra{u#nnwOQgLy6T?KZg>l zrS60htF3+sC00lM8cM9L`Yn`LJ@tDivHI$dP+|?#pP|GWs=p-mKUWFv1+0?4L(0`u zcSFk6QW@|K^M5~XZIv;kTpg83a*y11z7h0cdHs`J`h(?}C0!wuwEjr>1ClTMzw!x* zdgLsUeCXejdep3vZSude|MtxXeavh!PSgK3j-FBW|1={xQ`0lb@t!*? zVDjCFod0!37)Q@2*MFLkocHJ%<^E4I(&OkE<@v8OvKi(5uQLiT%J*MqgqtV#TK@k! zBaEXN6_AnIq~$RZ^yCrzO&dMAf-+>g|Fa?WJPZAw4a)KKObg4P9sVz7Y7eUCT119z z`TsoZJ>MwO(=00Ewf?^uPtUN}|ILW^pe!yUw*Aj@4C46Uf1G9D`%ietXp$xV<1B+m z)GSN>UyXSW%~CRE`~P>Efq4e;EG?rpNXw|^n_-$?85!ZR|JM;Tud*^mk^9Cl2U6wW z^DAf%VMl)j?P2Ws)FW);V?E`uV?NeXL0b4&Pes~6AM3%#@x!57suHS8RhFdvv7R5g z?s$Ag&aWa=6$Yv@@F=SPV?EWd?^D%LgFn_&Llk_hM-S##2>S zh4`@^exsxg=U7+fD4&z6haCg7zO(?e0qy&NHiT-a$537BaY@=h--@ex=2-WG5vmab zjTvZy>IWKueV=NI8VvLaQ2cbpv`f#v;b|vIkuEJ3eZ;AF+f{O z3qae@wt@P822!XpF4bO=HqeXxcdieQ4{xeEFwl{Kr%?SsJ7M3aI->>yb%_F? zdN2>W1Oa`T9iL&x^a9!iM@tLPXF12NGDiU_zkXnVc1J7tTXc9j8Qw%a2i;Q7qq@`! 
zk}=<+lkx0t(e#^oR1B5s~zJif_I}M_+%7E{cA%_~u)5^!2ysV(9lm zl{-BaTihG^SLQHHJ=5M$E!BrfwQg>c0vIZx0Dxg=IRIb<42N#15vVRTQZfdBjAsLQmHsFZ3;_BXz-aoG z_yz!d4PXrYvEmy5^fiES^v9zLfY-3S4*+KRI#f$dU~-})9RS{d>Qj?Y31G5>0sy9< zl=>RYvs!uIIC4l!N6aer(TB!k82;EX2pt{r|$ru1Ko(3=1@0YF~^I7I)j_yz!d4d4j~N;0#)+0XPfYQs+=z>bztO02$8)aDo0s5exwO8o(v` z--vGj(ANOIrT?Ay1^|5x;4=Lyr~=?BwoU-opyJD)OkS6y1HcWaK6Mk70Dh2A0KkuE zr3T<9=$5*L>Qc8QV*to_Hh`b$-x0w8psxY^LjPCs4FLKYz;E<_7vBJ&uL1l)|4&o_ z@E5jD0DnWZ)LkYs;0I5g05U@LsZ6K@kXb?j01u$$0DwnU7U-7Bit19?Bx3-`cs79S z^mB+{0MOR}!szD|-vFSm0py~eTYLk6z6OwoeqK}okPqAY0AQy1p<1c{lLaN|SOJBg z`cz?50w^M(0Dz)sIRN01RSdeNile&JgOV`-WIP){3Hl{PFaYRl0Hx@c7T*A%uK|>y zUsikrfW8J$j{ZZa0^nh6?*o9DJ_6NJ<(aG?Ne6(6P<<*Kl>jP9C;*@`TB!l30^L$o zQC;d$$ru1Ko(-TH{punZ0Q5D08uV+5ZvfEO0BX^%ExrLjUjwK^zb>i(sE4f+Kz*o| zYQSVeNjd;L1{J>qhDrd9BoqM97_HO*G=Yxav=$48_b*bKxF#u#d8$ciW zZV?Os`WiqSeUJDC0DTSMCHj5EHvs5s0R8C4qY8ioZ0`ernf8ZjDKC>sk`4fgP<<*1 zl>m|@6aX*)Ee8NRwG4!AsX?eNHCQqRfQ)AY7()ML5exwO8o(>`ec~Gc^fiE?^oNOW z0MOR}hSMK`DgZ`edmjMI^i`;q8pY&jNjd;nPN zElD~6%!KMwvrq}(Z3zVc%tk9U0PjG@uN9%X)Lh9J05YBp;9dIjL@)s8YXI}TpbCIR*g64x2-Q-HnOq`C2Y{teeQFsh0W6nL0Kf{gQUkCO zI)1GP)uld?i~%6y*#K74Un7D6Kwkq`OaEi>4FLKYz$f%S72g1$uK}#1{~4+PSdXm} zzy_$6+Q{VRl5_w_f$CG6PzhkOgaQDzpp_bctnYx^fiFJ^!K3(fc@Ay0UUs8sV|v4C`ku^ub}$WAyfi5ETI5^ zBWO7Q;JNi{=$1N)>QcufV*to_Hh|;wPl#Xu(ANM?(my4>0YF~^I8FbI_yz!d4d5*O zbEpE~Jht}%z)UYdwbVr>FG06aerCTB!l}6FPnn2NgevBN+og#J0hFg-K?DPU zz6MZ{ez^Ds0DTRh68*~J8vyh*fGYH>q6&aVv2_Bd2Gvs4nXDm62Y{MTeX15J0o0aI z06-nIQUg#Ix~1x&x>SA17yvS!4WI%2h9VdM^fiFT=szyL0YF~^XhgrU_yz!d4WJ4A z2vh;k6k8{NC!kvDNhX^~(gC13RG(^rN&qb-6adf)t<(UthHj}gs4mr3G6sN*X9H+Q zzr6?s0DTRh1O1NT8vyh*fT!qp65jxzuK{$X??M#-Ph;x@@C;P^Br%iEO40$KD^&a> z3Mv6~mrww}b7(mL;Kk*6=$3i`)unn!#sHAJ0eI;v5exwO8bBibB=HRZ`Wiqo{Q=?|0Q5D0f%FHV3V^}b-Uk3P9Rd|U zKg;ARl5_y@LG`Jjs01)fLID87(Mk=#2bBwy@~2l(z6S6%{n_Fh0Q5D0cj(VS6##Rw zbpm)7s-@;JIbV_v01Ke{)O)A|@VJ0W7A!L<9qX zz6P+A{xb0m0Qwrha{4R8Hvs5s04wRQLKOfXVe1628mgt%Fu7Ke4geoR#jnVs62PYt z3IJG#R%!q~gKnwys4lfZG6sN*X9L(s|8o%x0Qwq03jIyu8vyh*fX(!`h;IPU*8sNC z--apxwqxrA@C8&$?O<}JBpm>DLG`KKs06S_LID7K(Mk=#KIoR(kLpqfBx3-`cs779 z=^qrq0HCh{d`16|_yz!d4d5{SBjOtX^fiF5=^sTE0LQR(0yqxUQYV-^DM<%_Q&4^C zG%5j{kx&4@S+r6Ea1Oeq&ZD~21<4oyGM)|KBK=Du7y$G&fN$u3E4~3hUjz7#{$=qE z0Qwrh75Z0E1;90IodCXvYN_i?-jJjNz)h$=^#dva{3xLSfS=HE0Klv3E$EiIjp|ZA zOU3|@@oWHh=>H;u0YF~^_?7-|;u`?;HGtpg{~^8sKwks+lm1_*0^o0K?*o9D-i2zZ z4ERlR{7fmT6F?@Y_`y`yJ|;;AfXAWw zR3lUZXe^-sfF@|A1|R~urJAC;)Dx010AxHHz?1ZwiC_TG*8rN+Zy~+`KwkrBNxzl& z1^|5xpf&wAr~;rZwoU-;pjxUulN}`K0MHStPd$Z70G%Wh0MHq&)Bw1kTk2_4{7W#B zF#u#d8$cKO&x&9G(ANOE(wBd9;T{0=HGuB)pA+8zpsxWuPyYo}0nh_mCx92BS}Kys zo|1F`h=S@<(WnFvBcT9*UTCETAQrl%dZW5jAITU1GM){Bpm>fp!!rYDgg|TPyoO{v{C~w2)d;P zqq@`($ru1Ko(YF1`UkUjrCHeAxX@0YF~^m_&cF z_yz!d4PXlWsp1;|^fiEK^xs4k0MoH`0+<2SQg1OiQ<4qESn+2pAl*GnKBB)`1hb;(>lL+z{#xs(Rmp;~GKlN%-JSW%xt#Xl#4N&uTA6j)K4(Mk=#7U-7Rit19^Bx3-`cs79T z^uG|n0HCh{?4Z9>d;@^K2C$3%Zt)EO`WnC<`g>6Yz&>o90QN(*)Bz^Hl%xZ|L8w0U z6)FK7l28D^VYE^Ma0I%gzD9MaqmnTIWIP+dG5W_vFaYRl04M056yE@#uK}E*e_DJ4 zfW8KBhW=Sp0dNjmCxG)%Ep>s(i;{EzxCGUwzCk5`ZzU7}@EuyI0k{m^Qddx2>Z)W6 z02$8)aE<==A{YSlHGu2%Z-{RI(ANNN(*Hqx1Ax8;@FV@7PzAs(Y@GmZL$%b;Ox}^C z1Hdm(edVRhpyS^=Ma92&Dj5Sn#(sXavFpj*m=>QXOB#ymu2Jo_QqmwrDH%tMsEeu&1?PY~ZcMCt2?Xn*=%@y$b& zzJ7=*`iZD=rzc_Se26AP#lL>Z%z=tOB%%_)yAlcjn1@zs0Omut z)B;qOdQUP2fQ)AYc%S}45exwO8o&qi7m05G(ANMyq`z2v1Ax8;u!R0nQ~|IITPJ|! 
zP%X8B$(53H09XYTe@H|nfYlNT09b=oY5>+kx75d|F7=6I3;-F=2Jk8Ubs`u5^fiFb z=&u*w0HCh{Y@okUd;@^K2Jkuk6jTAQ30o(C%}_11g~_dwbO6`}6@N%XC4es^6acUT zt<(VQgl?%_s4lfzG6sN*X9L(nf3FAz0DTQ$AN~E}8vyh*fCKcu6yE@#uK^sS{}rkL zIE1Yez+tGCI>O}Fl5_w#3Kf4yL?wXZ5()q~fmUh&PC~cTDO8s_Eg1tq#ulkA{YSlHGo_6Z;NjL(ANNdrhi9#1Ax8; z@C*H4Q3b$n*g66H4%JeBF!`q>9RU7b*Tp=X`fiFEt_qhdLO?Kl7)e+3}i#~pIBwbzE9;q4L-3769p$$dN2>W z1f5vrWXD|Wm|iDVxpB0#oLJ@I9P`Q?<-{r0Oc7^)v=aqQ!hp{T(?ONauXdN2>W1OY9{j-}W! zy?~a+(b5964Ch!@<|shRVaEV{NLm2;Fl`$seknq~N&N^^d_@q|r7B3$2726o>R*w8 za0V)&`hix)zE4#_4F+0O6adwOdDtZg=%egdjUCeqXmuPdEkJ8VSEpj)an zs!O$zi~%6y*#O$oZzqBQKwkrBPrrls1^|5xpd=xYEm^n0NSfLLsu0D42kZ^SU^mZSqf98~;93@QPI7d0T>P)zY&9q--wZn0U+bq0A8g(N(2Ld zz6LOwz9qf^KwkqGLw~IJ1^|5xU>yDNr~=?MY@GmJhl<~bVRE7*9RS{diro+KRr=0nAA#Gn$udlCu&cpt6Q04#)#--toQZ^TH(0Fd!) z03Xs{EP??*UjtY|f2sHe0DTQ$8U5wr8vyh*fEDytq6&aj*g64x1Qov#!{i!CIsmMN zirUc0$E(#Gn$uZV3ed?2#4*-%I=cH>vkQ z#m8Y$@o`v5+9y_lH>tm5;2;BEq54m(4q@M?4xh@H>q*bf{#<2foiF!945|1oS(0yv&a21@sDzmKLB_Imc@J2xcAo<@k#2EP%Tx8$QiM=380*W0xu#Dp_LkdhoM{Q z5mc8dFBtZ7_;1IZWwGM){fA^pchFaYRl0FTpe zB)$PaUjt}Nzlrz;0DTQ0f__s}0q_L2P5@6rwNx`En@iFGpaoQ)YKck!tt1ox&>F4O z0JMQ_skW#t)lM=7fQ)AYXivX`2nGOs4WJ|ar^Gh^=xYF-=yw+10HCh{xadEPDgd6r z)(M~sR7*X}WLHT#0Ca=uQ{7Ps;5i8e06dRYY5-n^t#Ww)xYXIZuj~CwnpsxYEM*nqG0WbkuCxD4iE%gSIlO*W?Fd3>(O+h7q zsS*kRn1)tr0N#Xdsp+ULHA6B6fQ)AYc#Hl_5exwO8o(_2Z;NjL(ANNF(|<>N1Ax8; zFo*tJQ~~fVwoU-^pjv7^lM5v20Pr4EpL!pa02WFp0N?{@Vemz??bp}`-KYBu{p+U` z^&wRJfDfumEs>;srrPz-#qQlLjz*}Z3@l?{IjaA}Y6bRvY9(s$iPb7maAKtg^RP?M ziPcB!xSAc)>%?jej+T}atF@ft$1+DbvHAo%2I!~K0?>7|ZJ?Wu)NJ!WvF}sVXHfBF zRaAUgRgyN)>+i-a+>r-AZNHI$&lyNT^#k36eV^Kl8Vq!cC;+Ml^RP=0(5>vajUCeq z=yn_}EkM8E9CyeZ1?W!f7@)hP1)#fW+dzAK*(zh9<$t88Jy0#R7uBWqNzw)y+qQbe zE>}K}Q2QA;z`&QNexL`jj}Jqm1_M1L3V`atJnRw#^e{UfVaN0W`ZbQ07NAEt$73=_ z0eT!e2IvWC0q9BE_pi`XP%U*D)uqlz^8OWimVt8&oJaKoy?}k6x`-MK^pYq5st5D1 zOAyd+*zsF-OfR6{;b>_AdYN;)B6Ad=SFvM&UXvDpeoxy5`uE+tWq5IV9jc{npt{se zN!maI7pFfk@FN31q56T|!ajag2{jn#&!PaR9?ZioK|t@Y<1g%(UO<1v(b59+H_q{Q znWF&x13L!jpJ=7_s{9vpOZ|=NQgpt8-LbolC|5knwB)73fzK z!2qDI0ff`9B)$PaUjwL2zl!(<0DTRhD*Z=M1wb`yodBvs#nm~JH6`f)Pzx%q&QS@V zj)VdL>Y|kzfO^nzb&iUwbIBM0GM){fA^pchFaYRl0FTpeB)$PaUjt}Nzlrz;0DTQ0 zf__s}0q_L2P5@6r#nm~J%_Zpo&;ly1&QS@Vm4pHSTBDU3fHu%^b&iUwbIBM0GM){f zJ^cKqkU=aMl1WIP){B>kQu7y$G&fGGOW;u`?;HGmlUy~H;F z=xYG6^n0TUfIiqd0l1;!>YPcBBpm==f{LqiR08NHp#Xq*v{D0*03BE7sJJ?pi~%6y z*#H#%L=g-C`WiqI{bcbC0Qwrh0Qv*PHvs5s0E6fcMil@87i*Mne<800bnRp zT%Ds5z;Fo#0E|E@H2@=_D z0MOR}UZejyssNaPtrNgRsJJ?3a*`w+0477l)j28wOqEaoz%*%L@Hc7Oud#t|qfCd2 zpVvioskbC)pI8O{&^METSq!|5>OZlXjeVbb2Q~P_YK|y4vC@Ni*d^%1YA!pz%Z}-F zVl@v(OUsGXe9m!!%u!CP-ouUo`o6RPbRlgUD84tX|IqgVRQ!|_Dt<~zlJ}pfE@of} z14~i;K$l_Pr=FcYB|ENS$MgdF5ssD?psP8@H8Mv5x)wVI=*Q9m z&`)UFKm#9|{1mFC)}gx8XOgsm20k>oo`DSvY((_~{T%!FsxoRY&`qKMs21?ab&<99Mg0eTrb2Iv)O0q9lQW`*iMMPGx8fBXj3rLIfz{uO$Iftw8c zfa(YOBldmjC)8k|w?qL@J(!1Gf`Hy;$Di3Ty@1}q(b59+7tZllnWF&x4Lb(t?`WlV zn)3&AeD53;-#eF#In9yr?9-gT>E9K>oaWHir#TsNO8<Ac6rvUjryezmWLmJdeHxP?&xZ@eKg_8bDF{#ZU!6acrFc9)ybT zoikZdk`4f+pyGSys02_(LID6}(MkVx02$8)P@jGS5exwO8bCw(kBM&p(ANMSr{73?1Ax8;(3pM`Q~?ly ztrI{~sQ6?%lTS+00iYRFe6k&t09r^W0H7sWsR3vO9iMDR#V6Y(V*to_Hh{MD+lgQR z(ANOk)9)a@0YF~^=t%!5@eKg_8bBxdolymV3tK0Er=jAL?M!x&qyxaSQ1QujR08NG zp#XsHXr%_=Iq3LgJ1RceE*S$r#1^|5xAd-Gh@eKg_8bB2NXz>jI`WiqC z{a&a7AQoFEfZkB?$#y2)l5_xwgNjeKqY}VN5()t5i&km?`a#Dh+fnh!cF7n3GM){f zKYgzV1^|5xK+#VW-vFSm0VL5+7T*A%uK^68KM+*_48qn4U@%mCvYpA7CFub03RHZu z9hCrvN++a+TF$apq@SLu%u!2qDI0gR?^iEjYV*8s-QA1l5A zKwkqGM}Iu30C){sCxF+X;*;%6PL!krz#CBU$#zr%m@J_HfGN_#;8SVaud$En-=0WO z)1cyqo>B2b&yuuHtQz=#dtwFyZ!s_v)qi3&3;XyXG1TA_tJ$L9#7Yn5VV9s2t9RIO 
z4m+mTiPc;jEiETj?{bdwWR7xTH6J?$=mKd0=zFw-PFnsw)IwC3`aqI4(69Yp zKwHGXhYT!6^#fgkeV+dEkIXsjvvV!1?Xz* z7@%vU1)ytb+d%EB{S@^vR7-t=>QbLd(gqs1+F!@OXAG=I^#k32eV^Kh8VvMvQ2=Fd@OLjcSj_C#TD;zB?Ko4<_hh>fe^ayqg z(66Nhphs!jK<%sj2z3l9{%KlNmpUOy8))Ea|0Dya7&wjU2YLqkK6Mr~80a}s08|g= zVV5AF=h^WBJEj-Vi#S?ZfL`JpzmYi#&~LG0fPN<}0KH86{uO!!D*hQJRQxkclDvO~ ze$T*l25zAGf!@SEe&!oB80e3p0H_|!!!AKUe`3d5?3i9aZ{uib0s1rNct_?aK!3rG z0s5=70Q5K72B?1fXn}NHiexMn03DAtF!9X*K0-$;@ z54!{b&CHGuuw!}w&4Qz)1!z{zF`LX$fM&;z0h$A?)Xp2jpj#>@s!QdPj5%+V@$B=) z-1PH^V9p!q>+{CE^z(^t&Kv3L^Tz!23y5#d8|mxw#)9+tTn8vyh*fC}_0q6&a;Z0`l2XIcrW zr7AO7MUoBxRiWaamqI0gY7z$FA{YSl zHGsPG>xpjw(ANO!({CWY0YF~^Xh{DtQ~~ffw)X9-Nz0HCh{w58upd;@^K z2GE{<2UG#j5!-tK=$SqR)l!|9>?}zK02fsJ^HQh;@Qj240J@+R3?LuCHdW6;w^UbD zm+B@N13<>J0d%MToCpR0eGTAw`Y(uY0MOR}deDDSd;@^K1`tWVC#nF5!uDPOdZy7( zEfvFLFG)H8#6rbCFNI0~eIyhB;6^JLfaV$p9lr{PieH73i~%6y*#P>|?*ah&6LnWMa{zJ?tG^mVjSTTv6B z~^nu4_9sHB#>}@IC_zQT;3G1MK7L2B^U+>O)bmqV!-Mb_rThi`j7rJEqr) zT8g8kWkoIH9GA-+Wks#Pjsdz7t<+Z3D(IH_2-T%lOUA4y8P8r(Yv`{P!K^6ydPRLq z{}b`eilVPq)Ti{_ zU`1_}76#u&+g?%6>gT8kwH+#c6adwwc1Y4*Q78Evg|9~JWMCHqyHWisY7h4DV*seZ zD{7x8SW$W~54!}dsQv7CfF09oMSY2*rDa7O;0f{GszKqY{0B@|du-$@IDU#4xZsAv4IdRL(0>KqkU=aLM1)x(PV zo`LHO+(7lOsGHcw)j4YLiuzF$tSCL0hh2hJ)KBbqiyhNzMcu~H(z2p{<{aYT|cl60)7s!(xtj!FR4 zBotUt)un~OYtXh=RD}!8f1EaWZ8KF9s-RjR+AHevjAA3IHqOvY)nTA61NBh- zE2=*BeX0R!@QP|E3RaXJ%)>50E9x*89uvSymf`OI{v_ka*ZH;}OYJ(aK zw5=!rst5D1OAyd@?AV?i(+g+^94#$CJ93Us$s7e}C+rxYouvh!F4{KGfjxIdZ&T%( zs;8migAJ%I)kTsv(EUd?yl`Ptb-c;IKvxF3q56S#$G%TJhZ+p@c~JmV59VQ)AfPX> zV-I#rFQ6~tXlVf&$vO6vISSAy>=>ZY(gM&J+BQ&6rb1n+`O0rny`Wku7S*MCOVS3K z{hPv->x?MTRP|xN%|INgAE*cW_$RVagMs!H1wi#+9(D-=+K(OM*)hFfB?RdAbYVGdYR3$^jTPLV4HBgc^&{~-*4=VfFLrv8n1_m=Q z1l14pW$fcE6Vza!K2ZQv59VQ)AfQ9paTq(M7trB2T3UdP;2cNF90lmB*fBsyNee(n z)3$*gc>DF1XS!mAT2L)D2Gym;O40^e@I=Cn`QKuNj$>dv1Fxa_fxeD?d=~^Y80bV% z08|g=VV5AFZ?NMec1$mzlX0}P0G+})PL(+d&}rB)K;M)WfKI1v1MPV=a(J#xMK`G# zQ1Q>*p}N#eN!mca%e*1$#)ltls%A0pHUqO!{XpNrK7L0LH5lkzQ26JX)q{E1B?#y$cKnDP(+lWo94#$C*Km$& zWsU;$W9%59pGXTpKc#I0tyr+@XZ7E`6rt8Z#g9j#;>RN;X#*|y<#&G_ZF4X}ZD3#{ z1D~V%fu>*||H=VsFwo7S0H_|!!!AKUx3J?@c1$mz+i%Pz&pE$?jNamM&IK=g@7ewCcV_1Iee;`hKIYDultxX9PJ=J#nBG7oEo7aJ6e1c(YZ<|>p3OEH%f*K@d-jR#rG)dx8gG= zp;@*of>0GKL@APk(DO>c3rfNKBJ`rL<_Mve6dy0!e6$gIMSRI7^r}r0gnp+aYlN2H z++}>Ty-oI6uZgSTFnaOnXZ^vJQzLY9;IGtKPnlni%$^xllU&qAQYcD z3B6&vA_!H%LX;vo2>nGVcvC5uUxeNg)*K=9w<jHXm(--VtB23H?8tCJ6maN!AFB zSN``Cv;HowitQ=Hr`VptmQy2iVVk$Uyz-FXn}oG1Nm`6$@a*r&s`E%c&9i_aT>mT3MocCu3?|ZS&DasF(PXP3WUGO%Pf~N!AEeCRr2{ zuX*c=t8ZHCiBCUkeOpeAQ1#G>Ys}g}$A&`5lO(8l6>l(mWY%t>fd+Z92m3KpUi z$wBC2O2KAI!Tch$xv-W?LWhU+7SoJE9<_bc5g*-)&9b%-#Y5@|Ww!Nk@zu}zgw6Zz z#k()-zD%s~5lYaM=x6I)C?WPBOi}l^7 zKldSbo2(8wwIl)eQ+TO28V6oZ)6+Fwub-Sg_GYJ8?^n*S7UwgT`j**zI>J&qzxMnykz09IzmzN9Gmf&b=k6>XG-)eu+}1baqo*_UI5tj; zYrm~tf1&-#ZMxs64@q+wi1V$OwEgl}i&4}nFKpUqx=8|QmzN=Bi*Py;Q6#a9@wOF@9$`6?7?pIIXGXVwH0|g0{CNAm%O(juY{bq}aZ}P*bhX;cG7?X7tN*X%`|f?8 zH=0$oi!Hs^smp8A6Epa&-bIe5T-|{y;o+#2!wE5=|@sIg5ZrjeL zd(MmhV`;0)dVdUYEL@9gM%#hP|2YF;6}xy2jXSH~4}YH9erl+O18F~J?o!|vq*=JY zfixRGMCF0mldzjR7VPH&4y0TU*hjj3q}-(dFG#8%bl8DDJdgL)=@lbiymL82zaO_; z4SlD-Tbi!>4*9V4;87PCyzZvzcGD*hemt;O-xThD?YpjPC#Lh3yBlnuAo@|o-yXj{ zE;^leN=X^w;hw?5+(#~Wabp_a^Zk#V_8-{F6ORonIqp3I{=&O9_^UOMMzJ8ZvD?hD;N_?c5q?U++1 zjTj&Xhyhn;fa02~w;=gX3=jh{0~FU}Mo1?Hhyhn;fa02~w;=gX3=jh{0~FU}Mo1?H zhyhn;fa02~w;=gX3=jh{0~FU}Mo1?Hhyhn;fa02~w;=gX3=jh{0~FU}Mo1?Hhyhn; zfa02~w;=gX3=jh{0~FU}Mo1?Hhyhn;fa02~w;=gX3=jh{0~FU}Mo1?Hhyhn;fa02~ zw;=gX3=jh{0~FU}Mo1?Hhyhn;fa02~w;=gX3=jh{0~FU}Mo1?Hhyhn;fa02~w;=gX z3=jh{0~FU}Mo1?Hhyhn;fa02~w;=gX3=jh{0~FU}Mo1?Hhyhn;fa02~w;=gX3=jh{ 
z0~FU}Mo1?Hhyhn;fa02~w;=gX3=jh{0~FU}Mo1?Hhyhn;fa02~w;=gX3=jh{0~FU} zMo1?Hhyhn;fa02~w;=gX3=jh{0~FU}Mo1?Hhyhn;fa02~w;=gX3=jh{0~FU}Mo1?H zhyhn;fa02~w;=gX3=jh{0~FU}Mo1?Hhyhn;fa02~w;=gX3=jh{0~FU}Mo1?Hhyhn; zfa02~w;=gX3=jh{0~FU}Mo1?Hhyhn;fa02~w;=gX3=jh{0~FU}Mo1?Hhyhn;fa02~ zw;=gX3=jh{0~FU}Mo1?Hhyhn;fa02~w;=gX3=jh{0~FU}Mo1?Hhyhn;fa02~w;=id za56CMaF3oxVlMJH--=1wFORi^RqWz9H0~^aIK3@AUcKg(ywEH8Ix||MNU!8@8xEvV zJmFUkJhO0tV^13OTf^m$T0L|fc1^OQ^9@`{!me5E%dKj=09V!-(lzO@-S7(o9$!_j z;x6zLlhx0&)EA$A2Hh+9$rHEdeBc9pS6??vzxDOoL6_6(ZO`sTI%8Y|Og!MpyvbKc`RZ@$b*^wk7xW*d7|o>kJ<``c~_BcBAg`$-#bstVVLn zPu%OIvc%`zeqPZtj-Oum_3x{K7qe@h@2nJ^K9og1G4*nT&`mtL+O3#|$A_^A3O{KS zx6mL3DyhF#p!Ab*Vt^QM0s|D+oIol;5(C6QfiggGtw8A~V!#OuP+W5YsRT(35Ca9u0L8TerJsxw1H^z67@)Z31X2mog&Dwn z+A1$>+I=%THX)Se)8Gzp(;`1-t4=&);& z9R@r*5`H*5uT}OB{)@)7wJ&hFq%-K6bl9$+*W`arn9yIJ9~&RPAR)b~)dGL+q^>_1 zmeEw7ZyP@9*G>MF`{?=Q^g7yn+S(rKeA+(ReA=ij>U`RL+I-sIN~rT`D{Av;x71hX z(^l8!(|(}Mr#-CAr;Sfi=hIHs=F|Fms`F`8p3*qQGqhO#O2iGLR_}S97Mi)<;2vD@ zh!|hTAmhopNj&J!yT-4s!=6cY81U>!*mXFr$o^^ga{Q8Yif6PwK)!Vfkxsuu3sGZI zRfqv%Kw*I58r=xQ05MS57@)XT*jkk8LktiD6xZk;AO?tm!o~o_wZhh-R3Bo17@)XD z_W&_K3=}p7D6SQ@7NzGzni{7eM=MJI3hxCz{19pVC{O?_c3*&st~J2J7ns>X%vl$N(R2 zeSJW{rl7&^4(O%lm(#J|_8;!*`hdRL`hZoN)%5|1+WLSy+WG(wZGFHFZG8at+s68U zue9|68>RID-L>@r>V1{HKHzwp=^JM^SV!^9vD+Zex$#`*89*y2&rdgs2fqZavfcpr zMsa~-Pnv{}dHH%=)9m58S@ZP^-THuX*IvEws&DpsC(qtrwu+h;qQ|o?Zg+;P7-iAp zS;{vh=DxQwd%csKj`dD6wRq;I#j}OddZ*i3Jag0H*+DIyJ)*65TBEIZI;h37y;?k* zs;ze#KertG$TNKY6B($)s-*m9r zl)XqZK;vsgy7{T{#6Y2BfZ|%AY*wliF;Ju#ptx3~o1ZFA3=~QRD6SRCW~E9I14Wtv zifcu>`Kj{6K%r!S;##3>R;m;+P^1~4xK^Z_pDIrb6iNmtt`*8=eF!Us^-gpiO8zzu z=B*mFYWBQT_+>V(z{|1TzzO)adAc;8Hjn4x6qfn46wi!Cl)RNk+{jz$>{4dO1^8t) z?mc)9jBD97%G);$pW+$Cvz|v{E^=DGnYZUCucZ(-@>(9{?uir1IwD-ptAt->MzP^_!Q6b7|&?^ChjDXd0_x?BQIs|EhEQ?d(-mr9e$aOE92dxD{(C^zTR6# z!=reX$9Ps)esA7;PtbQ_pkNuG_XY*aKUpURhyn8%ptxo}r-Voh5Ca9v0L8U}<)5q* z1H^#&3{YG%pHo6428e-zWq{&Z!SYYmi2-83d%;&t zU_JvB*UabCH4(M-PVDM}fjgUxS(6Y_v5V)>xU+`O!EsGC^3Q8+-YPe0QJzO?6c2ux zjVt5Hx>3IcUXJw!8a|Vx>3qxeaFRV{)#k|pU(z{ru}DPxyY6s3=f=PKW|={L`uTIk z+V>yjekoi(-|0^3`jcT9P4)AgVp@;-Z>{Os=iABY+Ie;CRCDFLIyPQC--%z^pqy97 zwe#xO*phpDY!8iPE!Fd#xL2LylY{*Nxpux2Kc}AW#4f4lJMq6I?yMA@K9rAE&v#;* ztK5odczhV2sMI@+Q9LvHB;~C<;zr&|XP1(nr|guxeS=?SNo@{KEvbf7V~?pLKRu`)9ML ze`fSnGgKb;p&<}aSi8UVb&E#v;FsCBGM=m(^;^S3zvy0hZ(MWuIhXyC{kEgSzFGFD zf96>KjMM&S@}28kU5gvcs1^A+Hi`$o%*K`RWZkIW8lE249}t(7xMq?a#4CFeewfw1 zTy#3czwG_pB%MLmq{DV;J-X35gUS;F1;qfp4=pI>$u2QK3^Sc8LLEz*!7XTyqw)X5rHLn`SZT7!KyGI#%lb+-A>9g_Kr2ez;6LD=h0bY5%jF&qaP|THN49 z?a9xvQ9Sr%Hm-~(>qh<7@Yr8S)A^R`A@qBHhni)N_Sc~OHHFzcoDciwdHDhJKn&+Wb)BnpyP{pukn6q|^Qy>i%bj|9%*4FDg$A6n+M1 zo=4$tW2zxBP}CTpxK`9!oa##q6n+LMt`+_^rWz6hMU4T9YelWaslLQO;b(y2TH$YF zsv$8@)EJ<+R@7RY>Prk1eg-J675+A+8WICVjRA^ljV(j3?_x{nqdvhV=oq`L6K8@cd3uuZ1=|oEMCGN;-#*dA5%7IXC{*H_Ht2)6WMk z*1rEJ_e>51>uKi+pK7k2Cp=y|Px#UX^*rH3 z?L6VJCDrqUTWaSCd)3K#p76Qn)$@ce<#K(%>xX;vJQ8z}`LsVC*tFhfmY|AVJcq`e zWsdSe`5j{*ujg3HR&f79%IS{k@C$g5j^fDqfn!g~`5o2qJlm77>nN_AA2{|T>?2(+ z@2C#EFe$DHU6{D8uM2>#Nr&U88qchHyt=;P=I^Win3({7)p+Hr$FG-v-gRXErBQl* zIlZP9&$6!0nY7aC&6;TO?Cz&(Jln0sv)Luocy?~s$kDf2$8+z-YCQ8yZuyCOom3vI z#j|xioNL?SmsDFkQ%U_9pkWs5<_jBNv*|;*EteeCG`bGAn z!_T?wm+ZG49rn$#cW?h};RE_>)IZbW8THTPJ2roG?eDJP`yKrNNTc4vFW^BM#RZN% zX%rv*7<&?SjoJZ@Jqi0rH);=f=zra#FQoq2CT6F9#;d%rY4^?W*o08^_om_J%xF6n z_VQXkh#9S?);~ahz%U;CG8XO{wwJqi0r z&n}s_Z@_y%|4f+3F@CntrWf_`vncg_057@eqIJUlJ(l6rKg-*V=U^$^rw$goFT3{K zU|iU6{|tVajVt5H`h691@PdY?$91h9x(<5|!rqr1IDa8w*Km8FUWik6bUluVd<*l^ zFN_YKIi=YfBlPk8z!USX)f)At{yt##$!e7Xo^PP1tM50%bzM$>RDB=7yLVIGZ}Rum 
z_W`Wtzm@j^EJ1xAz%Q%sH~Guz`%TufmhwJ;RaM^yu-aXf_nW-TqOX5n6}*^r-J_l# z&`&!*;9Ga~{D7(I_?hAU!F>%*7l;93;2~jv=C3^@4MA=Y1H^!73{YG%jYXAy*G!|5fQSKN;2~jv;@U&f5ab3iKn$41K;emNTNb3c#rp1J z&rUA?edBE>_~{iRU%YcUgWh|ZcCS1bDEz{J$5+*>sK4)iaIU|;ig}Qbe_i6wbF6=& zSA`zT%Eq`m!HDgL5JNxul--A0tZsAhuQXkr%LNCS%)7; zm*-vlb&b>{rf<94)|c`T1H?d4VZe3!Yaw@=tPVM~B;m1Ijh;PIJc)0)8v0Iuw=_Kt7nK&w z?|Q~$_46$C#iyScwHqJ5AR)b~)dFc0ufo%wwa%;!ws`jbvQ^Z)5YuqsAEep%`u;ZB zJ~vAiY?$BSh1nKdhh3xm00;8EdT92MZqy$9*Y~^7@S(#Fba@_ZZtctcu@S%jqxe9c zuJ_N>^z68hsr28{?4}ogb9Lk0JI*opR4?DYQ@62z9d4N~efb&hHgL_^mxLaCwPDcF zHCx%5(7Fd+p0bU7)W2(Loi{$`AFo=l%$DBpqp)&mWjJ3Troi)`v&< zy3UCWZRyt(`-!|%^A=zE3qAa$fz3Bn4%YWK%5A+VA2C1-6g381r@t0n>W#*Mm($pD z%J^9_^BwYG>%pTg=yA8GwOoGH*DLwS6SwDl;A7Mt_R)hhidVnP>PH6nc<<`sc4x?n zQI=h1;Q|NJ+4%ojHoebNCpzi=!>&<&fCDMl1J@zVwddOoxT?Hgm)9NT*1kMGCilnW zbyQ>3b?^G`IJJEC`uj4=uMO=y{s15Mx7}rrZuPdz*sw3A`~PB@+xa`g2WJej9I871 zx#(w3vWnK77dy=wY`Oi$&Ashji{NeJfBpS~fB;JkW&A9Id;PqxVaP{*3F`RSkF0vH zf1POe%OeS1&HQJTh!3-DnA3OC^REuI#D|WEiksZl@OX4?`J#Np05MQl7es@5xb}8SK znET#J^KjuGq}lk`{~2=8jN9v0`GvQG4!cJA0S=^G53}t75BtH&I{e6)p1n>yxAm3Z zSIXmKa(_&$i_fv{T~05j_RrYL5v^FOj7gTa)czT-TW!XUfmsoj6t#cGtIk-_yLZP3 zORCyGV+U5QT(LT2oTaAPKVvaF)2}2fvReMDyx;th|2C)PZ{5bfZaEgZ?SS90xd|zO zEh_kjS}k$!_7B}~ekF~c<+i~oA2C1-*kRz|>95iFnVlwb@?{!7bME*Yjh{Jo1LtD= zOmxYz$H(OUm^{8l<7bZXPq~Q!V&GwApy2vzV{dkf^?v0HTRN-QMgOnP@yUZ95A4-9 zg~rdEe5+{Zc^T%pj<#G>VV&(ID$<=L$-Mgy5Z*x^(6UV2C93h`*uD!HTPF*d#E^&xi9(PoPW5hcr!I= z7A|lg&Bn($kB~!qD)nH{VfVqc&iz&3K+5$n+aB<6?whQ`52Ry1*xc7h&53nMD2I8l z&@t}z~-%jt+?;yec>uDv&TfAendX7S+UHUF+Dez%tC+BNfs_mkP$JyYfc9k|YF z+^F%xv9(z|K4xBsE&bXE@3sD&wzEeo?f7zb(slN7U}*Url~Q^8vcqrI5b3camRyP~ zmdTgz{%4#`uYc~94_du-k)^&n)_>}i11vMT{iA6=ujegyH`qR5(pHLRXecMik=p*l zM&bKtIFLs1`f7N7CgZ|CNVD;AUM=Lus{b51>>A|}#=EyyEpM``r(mV|S}XCwCI*p*@Gif7baL z8#iP8yf#8V^yrI?udF%6gCZiMJ%hLM(uZf{uBx&@t}XYeRmFF`G9{V$45{x^PxQ||ztJ)L z%7nPf@-wq8vLVy1&V9GTaTYdW=c%|UX{=L9$`B9t3qs^pbR%t{+xn6fN__T7ZDvzCz%vQ{=P@0(YnA5Acx%_B zZ7irmOtaFW|JAzV*w<%PPvzUTJl3eXNI(4jJGDmtn8Do^d3UKM^pp*Ys#y+RWVa_) zOt92RWLE!Q%lF;;KHu~GkDc}(pz$-ble1(U?R+ky_!TuANTYZj8lIcUxbP3sZ2Xq0 zzZ`Oiq{FUJet-ih*8}#EZqy#|&J7zm`c~_By?)SP2RhDubJQ;OA456pcL5#iE1_dP zKXf@A?<+CRkMS}5uGITuS{y@vZFSS{wwx2cJ68<;X?#$`MYd!^ivag6S^RY1d%uM4 zKg(8r5%BhZM0({8UY^qiU1P5w+|_NL(A_roEaCmxDIV^VSasRxt-R5!s$Fd3XNkuK zmK^vNjh~^NoF#A6RnHkTiod14iUVmBZ-<6=&}3Zr2Wd9GR~>~14!OQ&P91iQ@&g=5 zxgM~Obfflw=iNAGeW1e*bev!5s9o%%h;rDs2Rhb?L&v%#=yE#dUEqBs#`!TmhToOw zk0FjBu9aBZ;>7jg+u87u!Tv48_}M>d2mE?XjGtYP+a6(kLGM zdc9 z&`!>hiO;L&J{!gVK*NDF3vZ2+aN(!f_~%rA8DvxSTy0&4U9FNjfvyJC{K7D)cxAcA!4do8ybHXq#5h03$MCxn{V~L`nOa;s_DSic=fv;MZ#o?qTDIj0<~3|qoufkk{X&JN z31tuPl+Ld`|4igo?pyxG#}mc)nP1uX-xkcJ@iVlOv*fkUcUFo{AL{UZC;Jnhq~buz z>5l3fHN3qh*%z$oYW-Dc1w`kuH~aR0p0%O3wN~haKo$ z_o(-Sqjqr42+Cn!Yv|ZV5jytmfi9g-dAFrALC>AU5Wk};#i~> z*Lrlj+BQLqpKa{;z1ZsLI%Dn2-&rZf&$9ZZT=AZ9l$~hv&m-a=YjNX$?uU6Yjh~^R zoFpTkn0mQE=qAHFPu6fCjk*h*sZPRm7XE0}Uk2G<(qY#uet;+Q1p7!gY7cnxr20UI z9q9eEc5~D&&Pzf$?9UG!`@urT{$tSPbgZw$x+Ki!$2?ffAH%#0ysyMKKgP%KyAu5| z#4)kihcYjB)IYxa?h6suzVvy2O^ew_dDwGz77r5Zn|vC#ZD;GBMR%NeXvjS3pP`+c zC8Mj|ifMR!nBo4JucwLwY1H5D8r}etap51N+4!5QC_He;&h9yN*fq)za3JM+z&_HA z+5=v4jhyv?4m;4lRqO5WeT8%1;4jV@fsTExp<^FK=yE#NiDP{w)+J#+Kjy(={ut(6 z;C&^=`7u6*-<9Z(A&%jFXzd>UpAHiJv&Lu>C^0FC69r z^&*)&1|}%`uN!~wBIYBCKUHZQGI!SP=Mw$-9C&8o0>_@T?@{hjCaQij$XSvOySYQN zp9?tlB&$f=OxMe^UM3e%IVne0_()Fz7p$_ zFrOduU@?CT^Dgkd665?BAH(lT^v4j#1P@Ao&HAKAv0@R{r^IJqNXz2Iip_}(6#ow@ zA5(Eiw!pH|({R+66a%HNnl=Mx->GW^JqBJgwl^>}3RH26;wVpHVJU z8yXxj+!`@vOz1ebp<^P$BHXM#BB8t6ENeR<{jBZ9XHb(6<)X(d>)(}&7Td)(>cy~G z)(*=0.2.0,<0.3', + 'nbsphinx>=0.5.0,<0.7', 'Sphinx>=1.7.1,<3', 
'sphinx_rtd_theme>=0.2.4,<0.5', 'autodocsumm>=0.1.10', diff --git a/tutorials/02_Single_Table_Modeling.ipynb b/tutorials/02_Single_Table_Modeling.ipynb index 0f5649a3e..fa536bd3d 100644 --- a/tutorials/02_Single_Table_Modeling.ipynb +++ b/tutorials/02_Single_Table_Modeling.ipynb @@ -248,17 +248,17 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "2020-07-09 20:41:29,959 - INFO - table - Loading transformer CategoricalTransformer for field country\n", - "2020-07-09 20:41:29,961 - INFO - table - Loading transformer CategoricalTransformer for field gender\n", - "2020-07-09 20:41:29,962 - INFO - table - Loading transformer NumericalTransformer for field age\n", - "2020-07-09 20:41:29,981 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n" + "2020-07-09 21:18:32,974 - INFO - table - Loading transformer CategoricalTransformer for field country\n", + "2020-07-09 21:18:32,975 - INFO - table - Loading transformer CategoricalTransformer for field gender\n", + "2020-07-09 21:18:32,975 - INFO - table - Loading transformer NumericalTransformer for field age\n", + "2020-07-09 21:18:32,991 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n" ] } ], @@ -293,7 +293,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -310,7 +310,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -345,71 +345,71 @@ " 0\n", " 0\n", " USA\n", - " F\n", - " 29\n", + " M\n", + " 38\n", " \n", " \n", " 1\n", " 1\n", - " FR\n", - " F\n", - " 44\n", + " UK\n", + " NaN\n", + " 23\n", " \n", " \n", " 2\n", " 2\n", - " UK\n", - " M\n", - " 38\n", + " USA\n", + " F\n", + " 34\n", " \n", " \n", " 3\n", " 3\n", " ES\n", - " M\n", - " 19\n", + " NaN\n", + " 47\n", " \n", " \n", " 4\n", " 4\n", - " USA\n", + " ES\n", " F\n", - " 55\n", + " 29\n", " \n", " \n", " 5\n", " 5\n", - " USA\n", - " NaN\n", - " 27\n", + " UK\n", + " F\n", + " 39\n", " \n", " \n", " 6\n", " 6\n", - " USA\n", - " F\n", - " 27\n", + " FR\n", + " NaN\n", + " 40\n", " \n", " \n", " 7\n", " 7\n", - " UK\n", - " NaN\n", - " 7\n", + " ES\n", + " M\n", + " 38\n", " \n", " \n", " 8\n", " 8\n", - " FR\n", + " ES\n", " F\n", - " 55\n", + " 32\n", " \n", " \n", " 9\n", " 9\n", - " UK\n", - " NaN\n", - " 43\n", + " ES\n", + " F\n", + " 36\n", " \n", " \n", "\n", @@ -417,19 +417,19 @@ ], "text/plain": [ " user_id country gender age\n", - "0 0 USA F 29\n", - "1 1 FR F 44\n", - "2 2 UK M 38\n", - "3 3 ES M 19\n", - "4 4 USA F 55\n", - "5 5 USA NaN 27\n", - "6 6 USA F 27\n", - "7 7 UK NaN 7\n", - "8 8 FR F 55\n", - "9 9 UK NaN 43" + "0 0 USA M 38\n", + "1 1 UK NaN 23\n", + "2 2 USA F 34\n", + "3 3 ES NaN 47\n", + "4 4 ES F 29\n", + "5 5 UK F 39\n", + "6 6 FR NaN 40\n", + "7 7 ES M 38\n", + "8 8 ES F 32\n", + "9 9 ES F 36" ] }, - "execution_count": 11, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -450,7 +450,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -484,37 +484,37 @@ " \n", " 0\n", " 0\n", - " BG\n", - " M\n", - " 30\n", + " UK\n", + " F\n", + " 48\n", " \n", " \n", " 1\n", " 1\n", " USA\n", " NaN\n", - " 61\n", + " 38\n", " \n", " \n", " 2\n", " 2\n", " USA\n", " M\n", - " 35\n", + " 29\n", " \n", " \n", " 3\n", " 3\n", - " ES\n", - " NaN\n", - " 34\n", + " 
BG\n", + " M\n", + " 22\n", " \n", " \n", " 4\n", " 4\n", " USA\n", - " F\n", - " 41\n", + " M\n", + " 43\n", " \n", " \n", "\n", @@ -522,14 +522,14 @@ ], "text/plain": [ " user_id country gender age\n", - "0 0 BG M 30\n", - "1 1 USA NaN 61\n", - "2 2 USA M 35\n", - "3 3 ES NaN 34\n", - "4 4 USA F 41" + "0 0 UK F 48\n", + "1 1 USA NaN 38\n", + "2 2 USA M 29\n", + "3 3 BG M 22\n", + "4 4 USA M 43" ] }, - "execution_count": 12, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -552,14 +552,14 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "2020-07-09 20:48:49,870 - INFO - __init__ - Loading table census\n" + "2020-07-09 21:18:33,085 - INFO - __init__ - Loading table census\n" ] } ], @@ -578,7 +578,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 10, "metadata": { "scrolled": false }, @@ -739,7 +739,7 @@ "4 0 0 40 Cuba <=50K " ] }, - "execution_count": 15, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -766,7 +766,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -789,54 +789,55 @@ "metadata": {}, "source": [ "Once the instance is created, we can fit it to our data. Bear in mind that this process might take some\n", - "time to finish, especially on non-GPU enabled systems." + "time to finish, especially on non-GPU enabled systems, so in this case we will be passing only a\n", + "subsample of the data to accelerate the process." ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "2020-07-09 20:51:21,668 - INFO - table - Loading transformer NumericalTransformer for field age\n", - "2020-07-09 20:51:21,669 - INFO - table - Loading transformer LabelEncodingTransformer for field workclass\n", - "2020-07-09 20:51:21,670 - INFO - table - Loading transformer NumericalTransformer for field fnlwgt\n", - "2020-07-09 20:51:21,670 - INFO - table - Loading transformer LabelEncodingTransformer for field education\n", - "2020-07-09 20:51:21,671 - INFO - table - Loading transformer NumericalTransformer for field education-num\n", - "2020-07-09 20:51:21,672 - INFO - table - Loading transformer LabelEncodingTransformer for field marital-status\n", - "2020-07-09 20:51:21,672 - INFO - table - Loading transformer LabelEncodingTransformer for field occupation\n", - "2020-07-09 20:51:21,673 - INFO - table - Loading transformer LabelEncodingTransformer for field relationship\n", - "2020-07-09 20:51:21,673 - INFO - table - Loading transformer LabelEncodingTransformer for field race\n", - "2020-07-09 20:51:21,674 - INFO - table - Loading transformer LabelEncodingTransformer for field sex\n", - "2020-07-09 20:51:21,674 - INFO - table - Loading transformer NumericalTransformer for field capital-gain\n", - "2020-07-09 20:51:21,675 - INFO - table - Loading transformer NumericalTransformer for field capital-loss\n", - "2020-07-09 20:51:21,675 - INFO - table - Loading transformer NumericalTransformer for field hours-per-week\n", - "2020-07-09 20:51:21,676 - INFO - table - Loading transformer LabelEncodingTransformer for field native-country\n", - "2020-07-09 20:51:21,678 - INFO - table - Loading transformer LabelEncodingTransformer for field income\n" + "2020-07-09 21:18:33,488 - INFO - table - Loading transformer NumericalTransformer for field 
age\n", + "2020-07-09 21:18:33,489 - INFO - table - Loading transformer LabelEncodingTransformer for field workclass\n", + "2020-07-09 21:18:33,489 - INFO - table - Loading transformer NumericalTransformer for field fnlwgt\n", + "2020-07-09 21:18:33,490 - INFO - table - Loading transformer LabelEncodingTransformer for field education\n", + "2020-07-09 21:18:33,490 - INFO - table - Loading transformer NumericalTransformer for field education-num\n", + "2020-07-09 21:18:33,490 - INFO - table - Loading transformer LabelEncodingTransformer for field marital-status\n", + "2020-07-09 21:18:33,491 - INFO - table - Loading transformer LabelEncodingTransformer for field occupation\n", + "2020-07-09 21:18:33,491 - INFO - table - Loading transformer LabelEncodingTransformer for field relationship\n", + "2020-07-09 21:18:33,491 - INFO - table - Loading transformer LabelEncodingTransformer for field race\n", + "2020-07-09 21:18:33,492 - INFO - table - Loading transformer LabelEncodingTransformer for field sex\n", + "2020-07-09 21:18:33,492 - INFO - table - Loading transformer NumericalTransformer for field capital-gain\n", + "2020-07-09 21:18:33,493 - INFO - table - Loading transformer NumericalTransformer for field capital-loss\n", + "2020-07-09 21:18:33,493 - INFO - table - Loading transformer NumericalTransformer for field hours-per-week\n", + "2020-07-09 21:18:33,493 - INFO - table - Loading transformer LabelEncodingTransformer for field native-country\n", + "2020-07-09 21:18:33,494 - INFO - table - Loading transformer LabelEncodingTransformer for field income\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 1, Loss G: 1.9379, Loss D: -0.1297\n", - "Epoch 2, Loss G: 1.6529, Loss D: 0.1716\n", - "Epoch 3, Loss G: 1.0939, Loss D: 0.0208\n", - "Epoch 4, Loss G: 0.9390, Loss D: -0.0918\n", - "Epoch 5, Loss G: -0.0550, Loss D: 0.0415\n", - "Epoch 6, Loss G: -0.0864, Loss D: -0.0104\n", - "Epoch 7, Loss G: -0.3865, Loss D: 0.0074\n", - "Epoch 8, Loss G: -1.0150, Loss D: -0.0708\n", - "Epoch 9, Loss G: -0.8685, Loss D: -0.1289\n", - "Epoch 10, Loss G: -1.1524, Loss D: -0.0487\n" + "Epoch 1, Loss G: 1.9512, Loss D: -0.0182\n", + "Epoch 2, Loss G: 1.9884, Loss D: -0.0663\n", + "Epoch 3, Loss G: 1.9710, Loss D: -0.1339\n", + "Epoch 4, Loss G: 1.8960, Loss D: -0.2061\n", + "Epoch 5, Loss G: 1.9155, Loss D: -0.3062\n", + "Epoch 6, Loss G: 1.9699, Loss D: -0.3906\n", + "Epoch 7, Loss G: 1.8614, Loss D: -0.5142\n", + "Epoch 8, Loss G: 1.8446, Loss D: -0.6448\n", + "Epoch 9, Loss G: 1.7619, Loss D: -0.7488\n", + "Epoch 10, Loss G: 1.6732, Loss D: -0.7961\n" ] } ], "source": [ - "model.fit(census)" + "model.fit(census.sample(1000))" ] }, { @@ -851,7 +852,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -868,7 +869,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -912,182 +913,182 @@ " \n", " \n", " 0\n", - " 38\n", - " Private\n", - " 212843\n", - " Prof-school\n", - " 10\n", - " Never-married\n", - " Tech-support\n", - " Unmarried\n", + " 50\n", + " Local-gov\n", + " 169719\n", + " 1st-4th\n", + " 9\n", + " Widowed\n", + " ?\n", + " Husband\n", " White\n", " Male\n", - " 49\n", - " -1\n", - " 40\n", - " United-States\n", + " 114\n", + " 8\n", + " 38\n", + " Columbia\n", " <=50K\n", " \n", " \n", " 1\n", - " 30\n", - " Federal-gov\n", - " 321819\n", - " HS-grad\n", - " 14\n", - " Married-civ-spouse\n", - " Craft-repair\n", - " Unmarried\n", + " 
32\n", + " ?\n", + " 152479\n", + " 1st-4th\n", + " 9\n", + " Never-married\n", + " Adm-clerical\n", + " Wife\n", " Black\n", - " Female\n", - " 51\n", - " -3\n", - " 39\n", - " United-States\n", - " <=50K\n", + " Male\n", + " -42\n", + " 20\n", + " 21\n", + " Jamaica\n", + " >50K\n", " \n", " \n", " 2\n", - " 25\n", + " 22\n", " Private\n", - " 169771\n", - " HS-grad\n", - " 13\n", - " Never-married\n", - " Other-service\n", - " Not-in-family\n", + " 69617\n", + " Bachelors\n", + " 0\n", + " Separated\n", + " ?\n", + " Husband\n", " White\n", - " Female\n", - " -37\n", - " -2\n", - " 39\n", - " United-States\n", + " Male\n", + " 6\n", + " 11\n", + " 38\n", + " Guatemala\n", " <=50K\n", " \n", " \n", " 3\n", - " 47\n", - " Private\n", - " 116751\n", - " Some-college\n", - " 6\n", - " Divorced\n", + " 25\n", + " ?\n", + " 652858\n", + " 10th\n", + " 16\n", + " Married-civ-spouse\n", " Handlers-cleaners\n", " Not-in-family\n", " White\n", - " Male\n", - " 0\n", - " 2057\n", - " 41\n", - " United-States\n", + " Female\n", + " 152\n", + " -27\n", + " 39\n", + " Cuba\n", " <=50K\n", " \n", " \n", " 4\n", - " 38\n", + " 43\n", " Private\n", - " 315119\n", - " Assoc-acdm\n", - " 13\n", - " Never-married\n", - " Craft-repair\n", - " Husband\n", - " Black\n", + " 301956\n", + " Some-college\n", + " 8\n", + " Married-civ-spouse\n", + " ?\n", + " Wife\n", + " White\n", " Male\n", - " -11\n", - " -1\n", - " 49\n", - " United-States\n", + " -133\n", + " -12\n", + " 39\n", + " India\n", " <=50K\n", " \n", " \n", " 5\n", - " 36\n", - " State-gov\n", - " 172646\n", - " Masters\n", - " 6\n", - " Married-civ-spouse\n", - " Machine-op-inspct\n", - " Wife\n", - " White\n", - " Male\n", - " -50\n", - " -2\n", + " 66\n", + " Private\n", + " 401171\n", + " Prof-school\n", + " 13\n", + " Separated\n", + " Protective-serv\n", + " Unmarried\n", + " Black\n", + " Female\n", + " -124\n", + " -1\n", " 40\n", - " Japan\n", + " Cuba\n", " <=50K\n", " \n", " \n", " 6\n", - " 40\n", + " 52\n", " Private\n", - " 163368\n", - " Some-college\n", - " 4\n", - " Married-civ-spouse\n", - " Transport-moving\n", - " Own-child\n", - " White\n", - " Female\n", - " -13\n", - " 5\n", - " 39\n", - " United-States\n", + " 278399\n", + " Bachelors\n", + " 12\n", + " Never-married\n", + " Prof-specialty\n", + " Unmarried\n", + " Other\n", + " Male\n", + " 122567\n", + " -6\n", + " 47\n", + " Columbia\n", " <=50K\n", " \n", " \n", " 7\n", - " 23\n", - " Private\n", - " 369324\n", - " Some-college\n", - " 11\n", - " Married-civ-spouse\n", - " Tech-support\n", - " Husband\n", + " 36\n", + " Federal-gov\n", + " 229817\n", + " HS-grad\n", + " 8\n", + " Married-AF-spouse\n", + " Farming-fishing\n", + " Not-in-family\n", " White\n", - " Female\n", - " -60\n", - " 5\n", - " 40\n", - " United-States\n", - " <=50K\n", + " Male\n", + " 8\n", + " 19\n", + " 38\n", + " Portugal\n", + " >50K\n", " \n", " \n", " 8\n", - " 32\n", - " Private\n", - " 192521\n", - " Bachelors\n", - " 10\n", - " Married-civ-spouse\n", + " 27\n", + " Federal-gov\n", + " 306972\n", + " Some-college\n", + " 8\n", + " Never-married\n", " Exec-managerial\n", - " Not-in-family\n", - " White\n", + " Husband\n", + " Asian-Pac-Islander\n", " Female\n", - " -15\n", - " 2\n", - " 43\n", - " United-States\n", - " <=50K\n", + " 42144\n", + " 3\n", + " 39\n", + " Japan\n", + " >50K\n", " \n", " \n", " 9\n", " 28\n", - " Private\n", - " 244118\n", - " HS-grad\n", - " 9\n", - " Separated\n", - " Prof-specialty\n", - " Not-in-family\n", + " Local-gov\n", + " 416161\n", + " 1st-4th\n", + " 
8\n", + " Divorced\n", + " Adm-clerical\n", + " Unmarried\n", " White\n", - " Male\n", - " -21\n", - " 0\n", - " 39\n", - " United-States\n", + " Female\n", + " -349\n", + " 1090\n", + " 61\n", + " Guatemala\n", " >50K\n", " \n", " \n", @@ -1096,43 +1097,55 @@ ], "text/plain": [ " age workclass fnlwgt education education-num \\\n", - "0 38 Private 212843 Prof-school 10 \n", - "1 30 Federal-gov 321819 HS-grad 14 \n", - "2 25 Private 169771 HS-grad 13 \n", - "3 47 Private 116751 Some-college 6 \n", - "4 38 Private 315119 Assoc-acdm 13 \n", - "5 36 State-gov 172646 Masters 6 \n", - "6 40 Private 163368 Some-college 4 \n", - "7 23 Private 369324 Some-college 11 \n", - "8 32 Private 192521 Bachelors 10 \n", - "9 28 Private 244118 HS-grad 9 \n", + "0 50 Local-gov 169719 1st-4th 9 \n", + "1 32 ? 152479 1st-4th 9 \n", + "2 22 Private 69617 Bachelors 0 \n", + "3 25 ? 652858 10th 16 \n", + "4 43 Private 301956 Some-college 8 \n", + "5 66 Private 401171 Prof-school 13 \n", + "6 52 Private 278399 Bachelors 12 \n", + "7 36 Federal-gov 229817 HS-grad 8 \n", + "8 27 Federal-gov 306972 Some-college 8 \n", + "9 28 Local-gov 416161 1st-4th 8 \n", "\n", - " marital-status occupation relationship race sex \\\n", - "0 Never-married Tech-support Unmarried White Male \n", - "1 Married-civ-spouse Craft-repair Unmarried Black Female \n", - "2 Never-married Other-service Not-in-family White Female \n", - "3 Divorced Handlers-cleaners Not-in-family White Male \n", - "4 Never-married Craft-repair Husband Black Male \n", - "5 Married-civ-spouse Machine-op-inspct Wife White Male \n", - "6 Married-civ-spouse Transport-moving Own-child White Female \n", - "7 Married-civ-spouse Tech-support Husband White Female \n", - "8 Married-civ-spouse Exec-managerial Not-in-family White Female \n", - "9 Separated Prof-specialty Not-in-family White Male \n", + " marital-status occupation relationship \\\n", + "0 Widowed ? Husband \n", + "1 Never-married Adm-clerical Wife \n", + "2 Separated ? Husband \n", + "3 Married-civ-spouse Handlers-cleaners Not-in-family \n", + "4 Married-civ-spouse ? 
Wife \n", + "5 Separated Protective-serv Unmarried \n", + "6 Never-married Prof-specialty Unmarried \n", + "7 Married-AF-spouse Farming-fishing Not-in-family \n", + "8 Never-married Exec-managerial Husband \n", + "9 Divorced Adm-clerical Unmarried \n", "\n", - " capital-gain capital-loss hours-per-week native-country income \n", - "0 49 -1 40 United-States <=50K \n", - "1 51 -3 39 United-States <=50K \n", - "2 -37 -2 39 United-States <=50K \n", - "3 0 2057 41 United-States <=50K \n", - "4 -11 -1 49 United-States <=50K \n", - "5 -50 -2 40 Japan <=50K \n", - "6 -13 5 39 United-States <=50K \n", - "7 -60 5 40 United-States <=50K \n", - "8 -15 2 43 United-States <=50K \n", - "9 -21 0 39 United-States >50K " + " race sex capital-gain capital-loss hours-per-week \\\n", + "0 White Male 114 8 38 \n", + "1 Black Male -42 20 21 \n", + "2 White Male 6 11 38 \n", + "3 White Female 152 -27 39 \n", + "4 White Male -133 -12 39 \n", + "5 Black Female -124 -1 40 \n", + "6 Other Male 122567 -6 47 \n", + "7 White Male 8 19 38 \n", + "8 Asian-Pac-Islander Female 42144 3 39 \n", + "9 White Female -349 1090 61 \n", + "\n", + " native-country income \n", + "0 Columbia <=50K \n", + "1 Jamaica >50K \n", + "2 Guatemala <=50K \n", + "3 Cuba <=50K \n", + "4 India <=50K \n", + "5 Cuba <=50K \n", + "6 Columbia <=50K \n", + "7 Portugal >50K \n", + "8 Japan >50K \n", + "9 Guatemala >50K " ] }, - "execution_count": 24, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -1156,16 +1169,16 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "-28.22245129315498" + "-144.971907591418" ] }, - "execution_count": 25, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } From bf2b3940af23034988eb7dc9a3d58f0d2ebf9186 Mon Sep 17 00:00:00 2001 From: Carles Sala Date: Thu, 9 Jul 2020 21:29:58 +0200 Subject: [PATCH 29/33] Add pandoc dependency --- .github/workflows/docs.yml | 1 + .github/workflows/tests.yml | 4 ++-- .travis.yml | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 12d63d205..02e92fd67 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -18,6 +18,7 @@ jobs: - name: Build run: | + sudo apt install pandoc python -m pip install --upgrade pip pip install -e .[dev] make docs diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 0f882e819..7357cc190 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -24,12 +24,12 @@ jobs: - if: matrix.os == 'ubuntu-latest' name: Install graphviz - Ubuntu run: | - sudo apt-get install graphviz + sudo apt-get install graphviz pandoc - if: matrix.os == 'macos-latest' name: Install graphviz - MacOS run: | - brew install graphviz + brew install graphviz pandoc - name: Install dependencies run: | diff --git a/.travis.yml b/.travis.yml index 8204e8dfc..8345682c4 100644 --- a/.travis.yml +++ b/.travis.yml @@ -9,7 +9,7 @@ python: # Command to install dependencies install: - sudo apt-get update - - sudo apt-get install graphviz + - sudo apt-get install graphviz pandoc - pip install -U tox-travis codecov after_success: codecov From 815111bd1d7c15b1976d2ac2f608db052d852e33 Mon Sep 17 00:00:00 2001 From: Carles Sala Date: Thu, 9 Jul 2020 21:39:57 +0200 Subject: [PATCH 30/33] Update the tutorials --- docs/index.rst | 2 +- docs/metadata.rst | 4 +- docs/tutorials/04_Working_with_Metadata.ipynb | 463 +++++++++--------- 
docs/tutorials/demo_metadata.json             |  24 +-
 docs/tutorials/sdv.pkl                        | Bin 255975 -> 0 bytes
 tutorials/04_Working_with_Metadata.ipynb      | 463 +++++++++---------
 6 files changed, 465 insertions(+), 491 deletions(-)
 delete mode 100644 docs/tutorials/sdv.pkl

diff --git a/docs/index.rst b/docs/index.rst
index c0f4a2cc6..b79a5add0 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -12,12 +12,12 @@
    tutorials/02_Single_Table_Modeling
    tutorials/03_Relational_Data_Modeling
+   metadata

 .. toctree::
    :caption: Resources
    :hidden:

-   metadata
    API Reference
    contributing
    history
diff --git a/docs/metadata.rst b/docs/metadata.rst
index c9f4d8607..d7f7dd933 100644
--- a/docs/metadata.rst
+++ b/docs/metadata.rst
@@ -1,5 +1,5 @@
-Metadata
-========
+Working with Metadata
+=====================

 In order to use **SDV** you will need a ``Metadata`` object alongside your data.
diff --git a/docs/tutorials/04_Working_with_Metadata.ipynb b/docs/tutorials/04_Working_with_Metadata.ipynb
index 91b001036..5c79f8ebf 100644
--- a/docs/tutorials/04_Working_with_Metadata.ipynb
+++ b/docs/tutorials/04_Working_with_Metadata.ipynb
@@ -1,21 +1,72 @@
 {
  "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Working with Metadata\n",
+    "\n",
+    "In order to work with complex dataset structures you will need to pass additional information\n",
+    "about your data to SDV.\n",
+    "\n",
+    "This is done by using a Metadata object.\n",
+    "\n",
+    "Let's have a quick look at how to do it.\n",
+    "\n",
+    "## Load demo data\n",
+    "\n",
+    "We will load the demo dataset included in SDV for the rest of the session."
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 1,
    "metadata": {},
    "outputs": [],
    "source": [
-    "from sdv import load_demo"
+    "from sdv import load_demo\n",
+    "\n",
+    "metadata, tables = load_demo(metadata=True)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Apart from the tables dict, this returns a Metadata object that contains all the information\n",
+    "about our dataset."
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "Metadata\n", + " root_path: /home/xals/Projects/MIT/SDV/tutorials\n", + " tables: ['users', 'sessions', 'transactions']\n", + " relationships:\n", + " sessions.user_id -> users.user_id\n", + " transactions.session_id -> sessions.session_id" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "metadata, tables = load_demo(metadata=True)" + "metadata" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This Metadata can also be represented by using a dict object:" ] }, { @@ -58,60 +109,22 @@ ] }, { - "cell_type": "code", - "execution_count": 4, + "cell_type": "markdown", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'users': user_id country gender age\n", - " 0 0 USA M 34\n", - " 1 1 UK F 23\n", - " 2 2 ES None 44\n", - " 3 3 UK M 22\n", - " 4 4 USA F 54\n", - " 5 5 DE M 57\n", - " 6 6 BG F 45\n", - " 7 7 ES None 41\n", - " 8 8 FR F 23\n", - " 9 9 UK None 30,\n", - " 'sessions': session_id user_id device os\n", - " 0 0 0 mobile android\n", - " 1 1 1 tablet ios\n", - " 2 2 1 tablet android\n", - " 3 3 2 mobile android\n", - " 4 4 4 mobile ios\n", - " 5 5 5 mobile android\n", - " 6 6 6 mobile ios\n", - " 7 7 6 tablet ios\n", - " 8 8 6 mobile ios\n", - " 9 9 8 tablet ios,\n", - " 'transactions': transaction_id session_id timestamp amount approved\n", - " 0 0 0 2019-01-01 12:34:32 100.0 True\n", - " 1 1 0 2019-01-01 12:42:21 55.3 True\n", - " 2 2 1 2019-01-07 17:23:11 79.5 True\n", - " 3 3 3 2019-01-10 11:08:57 112.1 False\n", - " 4 4 5 2019-01-10 21:54:08 110.0 False\n", - " 5 5 5 2019-01-11 11:21:20 76.3 True\n", - " 6 6 7 2019-01-22 14:44:10 89.5 True\n", - " 7 7 8 2019-01-23 10:14:09 132.1 False\n", - " 8 8 9 2019-01-27 16:09:17 68.0 True\n", - " 9 9 9 2019-01-29 12:10:48 99.9 True}" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], "source": [ - "tables" + "## Creating a Metadata object from scratch\n", + "\n", + "In this section we will have a look at how to create a Metadata object from scratch.\n", + "\n", + "The simplest way to do it is by populating it passing the tables of your dataset together\n", + "with some additional information.\n", + "\n", + "Let's start by creating an empty metadata object." ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -121,141 +134,92 @@ ] }, { - "cell_type": "code", - "execution_count": 6, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "new_meta.add_table('users', data=tables['users'], primary_key='user_id')" + "Now we can start by adding the parent table from our dataset, `users`,\n", + "indicating that the primary key is the field called `user_id`." 
] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ - "new_meta.add_table('sessions', data=tables['sessions'], primary_key='session_id',\n", - " parent='users', foreign_key='user_id')" + "users_data = tables['users']\n", + "new_meta.add_table('users', data=users_data, primary_key='user_id')" ] }, { - "cell_type": "code", - "execution_count": 8, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "transactions_fields = {\n", - " 'timestamp': {\n", - " 'type': 'datetime',\n", - " 'format': '%Y-%m-%d'\n", - " }\n", - "}\n", - "new_meta.add_table('transactions', tables['transactions'], fields_metadata=transactions_fields,\n", - " primary_key='transaction_id', parent='sessions')" + "Next, let's add the sessions table, indicating that:\n", + "- The primary key is the field `session_id`\n", + "- The `users` table is parent to this table\n", + "- The relationship between the `users` and `sessions` table is created by the field called `user_id`." ] }, { "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'tables': {'users': {'fields': {'gender': {'type': 'categorical'},\n", - " 'user_id': {'type': 'id', 'subtype': 'integer'},\n", - " 'country': {'type': 'categorical'},\n", - " 'age': {'type': 'numerical', 'subtype': 'integer'}},\n", - " 'primary_key': 'user_id'},\n", - " 'sessions': {'fields': {'session_id': {'type': 'id', 'subtype': 'integer'},\n", - " 'user_id': {'type': 'id',\n", - " 'subtype': 'integer',\n", - " 'ref': {'table': 'users', 'field': 'user_id'}},\n", - " 'os': {'type': 'categorical'},\n", - " 'device': {'type': 'categorical'}},\n", - " 'primary_key': 'session_id'},\n", - " 'transactions': {'fields': {'timestamp': {'type': 'datetime',\n", - " 'format': '%Y-%m-%d'},\n", - " 'session_id': {'type': 'id',\n", - " 'subtype': 'integer',\n", - " 'ref': {'table': 'sessions', 'field': 'session_id'}},\n", - " 'transaction_id': {'type': 'id', 'subtype': 'integer'},\n", - " 'approved': {'type': 'boolean'},\n", - " 'amount': {'type': 'numerical', 'subtype': 'float'}},\n", - " 'primary_key': 'transaction_id'}}}" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "new_meta.to_dict()" - ] - }, - { - "cell_type": "code", - "execution_count": 10, + "execution_count": 6, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "new_meta.to_dict() == metadata.to_dict()" + "sessions_data = tables['sessions']\n", + "new_meta.add_table(\n", + " 'sessions',\n", + " data=sessions_data,\n", + " primary_key='session_id',\n", + " parent='users',\n", + " foreign_key='user_id'\n", + ")" ] }, { - "cell_type": "code", - "execution_count": 11, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "new_meta.to_json('demo_metadata.json')" + "Finally, let's add the transactions table.\n", + "\n", + "In this case, we will pass some additional information to indicate that\n", + "the `timestamp` field should be actually parsed and interpreted as a\n", + "datetime field." 
] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ - "loaded = Metadata('demo_metadata.json')" + "transactions_fields = {\n", + " 'timestamp': {\n", + " 'type': 'datetime',\n", + " 'format': '%Y-%m-%d'\n", + " }\n", + "}\n", + "transactions_data = tables['transactions']\n", + "new_meta.add_table(\n", + " 'transactions',\n", + " transactions_data,\n", + " fields_metadata=transactions_fields,\n", + " primary_key='transaction_id',\n", + " parent='sessions'\n", + ")" ] }, { - "cell_type": "code", - "execution_count": 13, + "cell_type": "markdown", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], "source": [ - "loaded.to_dict() == new_meta.to_dict()" + "Let's see what our Metadata looks like right now:" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -278,10 +242,10 @@ "\n", "users\n", "\n", - "user_id : id - integer\n", - "country : categorical\n", - "gender : categorical\n", - "age : numerical - integer\n", + "gender : categorical\n", + "age : numerical - integer\n", + "user_id : id - integer\n", + "country : categorical\n", "\n", "Primary key: user_id\n", "\n", @@ -291,10 +255,10 @@ "\n", "sessions\n", "\n", - "session_id : id - integer\n", - "user_id : id - integer\n", - "device : categorical\n", - "os : categorical\n", + "os : categorical\n", + "device : categorical\n", + "user_id : id - integer\n", + "session_id : id - integer\n", "\n", "Primary key: session_id\n", "Foreign key (users): user_id\n", @@ -303,7 +267,7 @@ "\n", "users->sessions\n", "\n", - "\n", + "\n", "   sessions.user_id -> users.user_id\n", "\n", "\n", @@ -312,11 +276,11 @@ "\n", "transactions\n", "\n", - "transaction_id : id - integer\n", - "session_id : id - integer\n", - "timestamp : datetime\n", - "amount : numerical - float\n", - "approved : boolean\n", + "timestamp : datetime\n", + "amount : numerical - float\n", + "session_id : id - integer\n", + "approved : boolean\n", + "transaction_id : id - integer\n", "\n", "Primary key: transaction_id\n", "Foreign key (sessions): session_id\n", @@ -325,114 +289,137 @@ "\n", "sessions->transactions\n", "\n", - "\n", + "\n", "   transactions.session_id -> sessions.session_id\n", "\n", "\n", "\n" ], "text/plain": [ - "" + "" ] }, - "execution_count": 14, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "metadata.visualize()" + "new_meta.visualize()" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'tables': {'users': {'fields': {'gender': {'type': 'categorical'},\n", + " 'age': {'type': 'numerical', 'subtype': 'integer'},\n", + " 'user_id': {'type': 'id', 'subtype': 'integer'},\n", + " 'country': {'type': 'categorical'}},\n", + " 'primary_key': 'user_id'},\n", + " 'sessions': {'fields': {'os': {'type': 'categorical'},\n", + " 'device': {'type': 'categorical'},\n", + " 'user_id': {'type': 'id',\n", + " 'subtype': 'integer',\n", + " 'ref': {'table': 'users', 'field': 'user_id'}},\n", + " 'session_id': {'type': 'id', 'subtype': 'integer'}},\n", + " 'primary_key': 'session_id'},\n", + " 'transactions': {'fields': {'timestamp': {'type': 'datetime',\n", + " 'format': '%Y-%m-%d'},\n", + " 'amount': {'type': 'numerical', 'subtype': 'float'},\n", + " 'session_id': {'type': 'id',\n", + " 'subtype': 
'integer',\n", + " 'ref': {'table': 'sessions', 'field': 'session_id'}},\n", + " 'approved': {'type': 'boolean'},\n", + " 'transaction_id': {'type': 'id', 'subtype': 'integer'}},\n", + " 'primary_key': 'transaction_id'}}}" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "new_meta.to_dict()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Pretty similar to the original metadata, right?" + ] + }, + { + "cell_type": "code", + "execution_count": 10, "metadata": {}, "outputs": [ { "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "Metadata\n", - "\n", - "\n", - "\n", - "users\n", - "\n", - "users\n", - "\n", - "gender : categorical\n", - "user_id : id - integer\n", - "country : categorical\n", - "age : numerical - integer\n", - "\n", - "Primary key: user_id\n", - "\n", - "\n", - "\n", - "sessions\n", - "\n", - "sessions\n", - "\n", - "session_id : id - integer\n", - "user_id : id - integer\n", - "os : categorical\n", - "device : categorical\n", - "\n", - "Primary key: session_id\n", - "Foreign key (users): user_id\n", - "\n", - "\n", - "\n", - "users->sessions\n", - "\n", - "\n", - "   sessions.user_id -> users.user_id\n", - "\n", - "\n", - "\n", - "transactions\n", - "\n", - "transactions\n", - "\n", - "timestamp : datetime\n", - "session_id : id - integer\n", - "transaction_id : id - integer\n", - "approved : boolean\n", - "amount : numerical - float\n", - "\n", - "Primary key: transaction_id\n", - "Foreign key (sessions): session_id\n", - "\n", - "\n", - "\n", - "sessions->transactions\n", - "\n", - "\n", - "   transactions.session_id -> sessions.session_id\n", - "\n", - "\n", - "\n" - ], "text/plain": [ - "" + "True" ] }, - "execution_count": 15, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "new_meta.visualize()" + "new_meta.to_dict() == metadata.to_dict()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Saving the Metadata as a JSON file\n", + "\n", + "The Metadata object can also be saved as a JSON file, which later on we can load:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "new_meta.to_json('demo_metadata.json')" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "loaded = Metadata('demo_metadata.json')" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "loaded.to_dict() == new_meta.to_dict()" ] } ], diff --git a/docs/tutorials/demo_metadata.json b/docs/tutorials/demo_metadata.json index 2be7ecc2e..f59aa96da 100644 --- a/docs/tutorials/demo_metadata.json +++ b/docs/tutorials/demo_metadata.json @@ -2,6 +2,9 @@ "tables": { "users": { "fields": { + "gender": { + "type": "categorical" + }, "age": { "type": "numerical", "subtype": "integer" @@ -10,9 +13,6 @@ "type": "id", "subtype": "integer" }, - "gender": { - "type": "categorical" - }, "country": { "type": "categorical" } @@ -24,6 +24,9 @@ "os": { "type": "categorical" }, + "device": { + "type": "categorical" + }, "user_id": { "type": "id", "subtype": "integer", @@ -35,9 +38,6 @@ "session_id": { "type": "id", "subtype": "integer" - }, - "device": { - "type": "categorical" } }, "primary_key": "session_id" @@ -48,9 +48,9 @@ 
"type": "datetime", "format": "%Y-%m-%d" }, - "transaction_id": { - "type": "id", - "subtype": "integer" + "amount": { + "type": "numerical", + "subtype": "float" }, "session_id": { "type": "id", @@ -63,9 +63,9 @@ "approved": { "type": "boolean" }, - "amount": { - "type": "numerical", - "subtype": "float" + "transaction_id": { + "type": "id", + "subtype": "integer" } }, "primary_key": "transaction_id" diff --git a/docs/tutorials/sdv.pkl b/docs/tutorials/sdv.pkl deleted file mode 100644 index d3c88b5a4eca4045c109b46b6393d46f5be7a46d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 255975 zcmd@7cYIV;+x?GIq=nu=x(b9My($otUWHKvR18T7Oh!nUorI=C6Anc|njnZEponw? zR0IrFI*5Suj!F?wLG1Ef*V<=hGn{xk`S`r@`{U>H+}AmKt+n^r^PW93=S*gX;hAF+ zdkv_D|Ajr>{Mj&XhDsy68GYVN;oY)i$dI95La*4qULE0nVzror{>go#62n?WB_}4v zMa4H`JIwn)w_N@XkuMF5QhgG=S$y8CRC~oGCaJjSp)0@}l&DSj}tfZ*uzOjkk{6231eZ!LzV^yNJpwC+!`+3ug&}K;a(lwvxDl1Y-QhlOWenvm!On7+`s?YnV2E`lh zt>!Mm^+Jx`>h6Ncfu5Xjuh;=`F|pnnK5tDuVIh)`=&j}R)~0~T$?{23QSpgUF|tk) zy>)!vx-!8$!G=t-o=mcSxVxUu+u)u_V$Kc2-I&~C;qGegaCdcgg?lC_2 z_3Q6_+~;j1sAOefRmF@}y=Q)xW<=7!cdb+sx-}j$_DriB3r98ygkxZIOt3BC%URNRe(V4;9H9 z-5bZ0CD+pDZ58e=7^8Y6Rdd*3t);^=0!in?ysfe1+xWa~!woktMa3t?5AGK?B-Y!hqrSPxLxv3Yb{@x79@96jzt=U6 zH?tpZdt7`U@6$f-GhBwUz}0}LzKPy0KJT;Pd@OWn=j|HaE<>w`L$&!|tB9@soeyX7 zza(!rF19XDKjG~jeoycmxZ%H65yzRE)hc3Me*Tx_eU4+dc-s3sI-Fw+xo>;(zg7_! z{f7v2c!3>Sy1YHmA&|SGXV5C*vVVxl{v*7|4$WJ5Bhev{`+hI~YZY;gN2t{*BG92H zJ3QIS8-)&m+|9fEuT{hi|J(u{qLY)O-R0auKgQ?n74GiT(cO8R+cnM&ps{Ws>>bW0 zvzFdIBpAT3%g_9AVGtc|;0pgxk!+zB_`Cc49_$laHSG!e2pu)SVIY6CE} zfj;k`a6T>ej>9alc9UbhgJW>3`VYp-pNg&4FBvn4QYvb2SR$U5qg0r8i1x!ckudMe zyyJT}@V+u`>=>`ln8sO9);p2YW}2w<#>(6I@h$FO`6op!Y(Y!~2*+vKt&13XtiT&Ky6(58MP-hu{|ZydM(W;&Aul z&gJTVi3)3Ozg7h;+a(~s)aP9m?oM#`$I5XhyJaOW_jy-@y9Z;v3~iSEuaosQZ5_O} znRlh-UFGtAB(YKB+@swS-1Y)+H+9QHbG6UACNR)Fk2k)=g?ZPyliX|V1AOfBeiAsq z`5}dtx60qOsrOUMyUykPER_Q^bKWrfbusKo{rVVmv({ra8+_i4m<=B*=9Yc#^QPdy z0X6uBkoP8VbT7qiTJ3J;-r#4xnaq><@Q#de$Mx+MiI*(z7Mzs!?HjAlJi>a#_Kw0s zyjNTdo?*6{XPm@Xw6~4+Zg+XVz>BndTMzFJcZI0Y-kmP*E<6Cx-tDe~_8ymaue*-4 z_l@@McXz15AO4!XQwh3l?|+y&9~u*-Wycm0|}A9Z<;=|deK z?LFc0p2T|6ue5S*jNKgH$LyT`UDy0B)t;)MJC4LtmnIqGAz%YA|vHW zATrW>xspb1UmHNV;`3gWTq19cey&gSUhCogUhcT2?q;!MW-byEz1MqqZ-l$EzCZ$gas&V3JwUo7vh zF7I!Al9gxlSnuyT@w&C-((2OLn|c4Rynnj9f9Z3;NPS)OcT8eT99|P}LYb6UEmE#> z;v?hYd#fm1?j*+~C9BvlH6XILN*EFwALhMVX@tsv3lX2nh^j9TgpSIDW+gt^zcjE{ zqRPz71CnvyasTH2-HpS_O93yVvOo_$ZcIWvUJsKKk`wh562szgG8u`BgfNv=h904^ zameg4q`h|f$uqag0i|6=zAoWK42@QI?Wa>}(JH5C!~B{|Vo*{-UzJPb##u+bZOyCl zHt*=@s9t=bRJqy1K9S(GRRe+LALYA2-D9vR1by=^!YLWfpVqWSS8x|R<1-WZtBV&>hRG2D+Ja2u6 zXKFTVSi8O|3}Lh?g6dL5QGKcyYB*m8qvMihQyRI#|D^uh>Jr2@gQ_+I)p$q za;L;>pS`wnt!m|Sw6d}d{q*QJFHE0ZWySmxT@OEW>iohRS<1ff*!KD@-fZ6dZ1~g@ zn~rvFczDYL8~V&FmwDEsV~&p=ez(;xyK?QG-6C6^&R>+<7}ICOnlGoe+&it3JG{?_ zeJyqPdhW%$TRz;m!{1-DS=H&eHxjFtZ@=7AK2P5IAG#;3oAS%B`u*>W?-lv`A8T?H zXn*L`)nWseEZ=$j^70+$>UJ6Z{p#LxqhE|2o3&@kM~z#2RO^X9tM{He=%a?)7FQpf zYtUaAXWsp3|M0~XW}f<(j`%&HBsehR;mjF!1ulk!|-i{5YX-$3m+!AH30^a^{N1?pCke+vEE3iIMF} zwtRJ+>*SK3@(Pd*ACf;qIxZ$+y=oc&Jh5{pCt8-FRYIt6Ya3 zTfeIC+FO~j6{x#3eBQ<_3CHV#_gXp z&2n`4l#X?l^~nBOx0i}P`=G1D>Y8C&&o?}_p-=3-(vN>Oqjdet8A{BnfBX57LmsZ) z|Gntpo10WSyZJ!E?V`g7OJCvx2_IVVB#s?7752X8)Bds(X%b4zrdIJaNls@3CvFFEwV@ZDLy z`t`-z0~)MpJt_9$kt11ZcdNB*Vulro)u;b9de@nZ9ZDA%+~hWRGFq$kK-e{ zmdL#K&b9hCnjB27Q^jg@V%wBAdiv(Q@Jdw5%^SY>i0chzcdz--(Xi9+sA=)*s$5yQ z;Aa1QH$UG~v&YkA+&jKXs8PILch{9NC%zc`W?+Vhi=*UsA zGDWofY~9K&v%0ozRm{q;_C&*#)wi8#lY9NL1)G0cI`2sT1A8(T z`g`L$U;mvlp-$l`qc&tnywUH8H9Jc@5&c<*>0K(H+1z=~%?VRad(T9?R4kWoQKrm! 
zyZu^xO4NbMuic%tcE+LCdR{+t{r7C&d$!MAabomicNR2AD5AO--xJqnu<7d>l1wkp$Fj2)c&aMN@3j#T(5xt290a>A{tUliI9 z`Qe2;RkP-4(&OS&t3K~AZORR+#zPINPF;KTPOGmzS{MZEvm;kJ-JN`8^=D_kt?%pBI3efqPV=^RUi#GL_tx|s z_|^XR^Y+}lVE03p*8KVFtR-clD_$78uh=UW*LGNR?AtRFi@w%yRoH>os|dcovGL-$JXY{$92tG?smubj*k1= zz5J(o?N5K){zjSpy$5f8`Ndi9hGSc zHob$d-=sb7=Iweb>D_D>cXVCf`})KAx@3yjern0Ohj;cp7q|V93nee_Z%}B#XG6;M zoLHjKr4iq)&+HjjaaWn)3!nV<#a4Z*<~|Lv%M8xf6~x*H!LllWn`Xff1GH2aBj=dW8NtEZo3S#jxJdqon`v& zw_on^^k1#FZQMO=)F&Ipj%j`Ql?T5sKHyTv1zl!6(6e}(B9E--bhvfH{X3%S{JM8p z>E)Aa?ke46_twY(6O)T)xb@Aj<5jMYA76R*z9z-yx6jpa$H)`gkN@!c-~-+p?`Mx0 zb*%cx+uNI5eKW_JbvZNFS^Y`ED?3|$+Pd4RJ98hK-lEgTYsTd`k#ug<{?|L?e0;(a zn=a3~`gD!W^NweFrt+?D`@HsJ*7XZ>*BaODT*2Qzz5UwqZC{+bb70R?z6WYtDlnpD zwlU{+p1B^g`24Uk8_M3ESvqHq?`m$_UFT+=0_)#dR`yCnt4YnvZb{s+rSy!?e>{=5 ze#UZRy8ZC;l9yDz3*#0IcqGS)AseUUc{C}1jpeQL#$6n7a7}~HCw=thH^YAznLK=9 zv8+`(9m;&H(NBZt9LZ2JtlX?-^#@u-3cBCe8-4YSS_i+GH)6}gA*I$=ex~@8eS0bm z-~ZXA=SLnGm*=->%|9FEnU%A{%hmU7nYFuT=AIYE{giD|V(Y<$lj_fX{FgT(DiA_DX>3t*#SS7%29t#$4A;Hm!Cc`~(D}b**ZK7YA)h7TS{W6NbGal{ zNzUl_E@}USL{%BB4v`s?`KGHoDlw9OQc+bPx6`MsTBr)46;4;$gAR4r-U_Nm&m?JF+- zdjzTuR7=%Gb*Xxiv@f{c?ewPi`lY%}Rec5;FwhV++5NX$uAm;nzE3@l8t&nRlSf}K zH4?>r$D;@Hu!~17mOS<~P-Av%!jAg7=YQR3PtYZfC+MoiBNsCsxrXuZvc7{Bi9p_ia`Arww`-|Hbvg2owX?ozd4L7hWF!m6l5y`ZDWj?7Gx5sJP5(m$_YLOLalRJ#<{N)~e(V zZ^<1$MhyOH@r}+yzKke3EX!|qu1`tvw_(!nQorJyZ~9(K@%#R^zrT7GBPOY?sL}qr zoo}4t?HDB&+&YvBLYi-`YgvZf>>Lrm6GJwr*>k};z~Ox9GW~To%%n(E^ z^)jkUy&~P@&n0*hEk@qb?Vo@@M=2jZy627{JEqZpQv?H_z6Ly<{tWRANctMnTl8m&Z`jb+ z7-rFb8&xpO#x_+9@4&Fs9ERsg*n#0)s6I6hl^EtrD1c!BS~dpxQw{DFxo4WH_n=$q zeN>lPC>g^*#!e;6wV0#WxJ-YX(c`FBRV~psyJ$qrV(g7_7jS41#95 z5~`(EG5L`s9ROBC^{F+e1h7^@0RSIM3xj_`+dj8>x{LhxO^W&ys-@PUy3}Wqw9jo$ z+&Mh|8S94#wVr_u3~WU8pWA$neV^hfQT;#{L=6UdP!s^wgL&8`2=>ZOqy?bIY1=@b)hjeYoq%eolc+9rN|H9vlU$+aBGhRH&M_I>I+YB10X zq5!BK%)>50Krgc6C3Z|Npx@wVX#x5z=lGq>QGi~?jsbc_S^#>L_Wcj}YfvrqJ*rDx zm!u8!xF6^Z25vI&1F9eBkJ$IApHPE=-Vy~s^x0?_Z&J8OVT7=DAcxR6o#6*!QW- zsKG!V5CuT>U>*%9pt@9EN!mca?=q><{7spfseBCNXP^M8A80}B`&1#+V4#IX0Z=`dhh2hz z7GcMt?3i9ai{WT#0a~1Md{E{nKuci904*sk04+t^20A2T%z|ED=W41-L$y>HRF^6% zNgL?(X6r83xLK>ID#yS>3_OhL2l@#1eX2ZaFwhF30H_|!!!AKUE3#uaJEj-VN;q0t zfL7)ltH>M$XjSYOppQxmK&#QVfv)^@{EA;sbZx4tL$y>5RF|qLNgHUrL)F4+ZLZc- z)ncGF19edSK=FdD0XsHi$MgdF7>Plb*U#MX#=f4Bw=`yx@DTGW(+iEparTQXiMz- zR4deApshs#P(7H3U4nqNVaK-Ym|j5J;b>_A+MaXlAafL;9kFA8J|!&x?L^xKTE5Mc zN7_ue6rnmpwUi6hrJk0g4YYfc-=Ew*^=O28hJh{&Jd5fF+7QcQVX#;(6a-D2TPUdf_Vj1YoKp#{;P&f8{Dh@Rms7Dk4)q{E1B?#zC?AVtb(+g-n z94#$C<2lC!nWF&hj~xTlD=h$3v~8fcO2U7*vP^_(sU%dFN|vMzw6%)3J#^RQ2sMC# zfeZ{n^#dJ@eV-bF8VvMhQ2*&n8h9W5$?Wnq2>v&`iGT0^H{1XJJ$}EEJ0{DYGR{svdwlMQ z^AYj<)5ZCe1icNkS4_{?YWljX5f#sPkJhQaB%;Ik8DmDSemkXE^x&=y#`{v9t(I_W z?6sW{*$PzO+voYSDYKhBaLLtZ$%cQrKmEsx{i|)p=TpA2kAE=bJm+yR(4O~c`CAQ! 
zP%X8B$(53H09XYTe@H|nfYlNT09b=oY5>+kx75d|F7=6I3;-F=2Jk8Ubs`u5^fiFb z=&u*w0HCh{Y@okUd;@^K2Jkuk6jTAQ30o(C%}_11g~_dwbO6`}6@N%XC4es^6acUT zt<(VQgl?%_s4lfzG6sN*X9L(nf3FAz0DTQ$AN~E}8vyh*fCKcu6yE@#uK^sS{}rkL zIE1Yez+tGCI>O}Fl5_w#3Kf4yL?wXZ5()q~fmUh&PC~cTDO8s_Eg1tq#ulkA{YSlHGo_6Z;NjL(ANNdrhi9#1Ax8; z@C*H4Q3b$n*g66H4%JeBF!`q>9RU7b*Tp=X`fiFEt_qhdLO?Kl7)e+3}i#~pIBwbzE9;q4L-3769p$$dN2>W z1f5vrWXD|Wm|iDVxpB0#oLJ@I9P`Q?<-{r0Oc7^)v=aqQ!hp{T(?ONauXdN2>W1OY9{j-}W! zy?~a+(b5964Ch!@<|shRVaEV{NLm2;Fl`$seknq~N&N^^d_@q|r7B3$2726o>R*w8 za0V)&`hix)zE4#_4F+0O6adwOdDtZg=%egdjUCeqXmuPdEkJ8VSEpj)an zs!O$zi~%6y*#O$oZzqBQKwkrBPrrls1^|5xpd=xYEm^n0NSfLLsu0D42kZ^SU^mZSqf98~;93@QPI7d0T>P)zY&9q--wZn0U+bq0A8g(N(2Ld zz6LOwz9qf^KwkqGLw~IJ1^|5xU>yDNr~=?MY@GmJhl<~bVRE7*9RS{diro+KRr=0nAA#Gn$udlCu&cpt6Q04#)#--toQZ^TH(0Fd!) z03Xs{EP??*UjtY|f2sHe0DTQ$8U5wr8vyh*fEDytq6&aj*g64x1Qov#!{i!CIsmMN zirUc0$E(#Gn$uZV3ed?2#4*-%I=cH>vkQ z#m8Y$@o`v5+9y_lH>tm5;2;BEq54m(4q@M?4xh@H>q*bf{#<2foiF!945|1oS(0yv&a21@sDzmKLB_Imc@J2xcAo<@k#2EP%Tx8$QiM=380*W0xu#Dp_LkdhoM{Q z5mc8dFBtZ7_;1IZWwGM){fA^pchFaYRl0FTpe zB)$PaUjt}Nzlrz;0DTQ0f__s}0q_L2P5@6rwNx`En@iFGpaoQ)YKck!tt1ox&>F4O z0JMQ_skW#t)lM=7fQ)AYXivX`2nGOs4WJ|ar^Gh^=xYF-=yw+10HCh{xadEPDgd6r z)(M~sR7*X}WLHT#0Ca=uQ{7Ps;5i8e06dRYY5-n^t#Ww)xYXIZuj~CwnpsxYEM*nqG0WbkuCxD4iE%gSIlO*W?Fd3>(O+h7q zsS*kRn1)tr0N#Xdsp+ULHA6B6fQ)AYc#Hl_5exwO8o(_2Z;NjL(ANNF(|<>N1Ax8; zFo*tJQ~~fVwoU-^pjv7^lM5v20Pr4EpL!pa02WFp0N?{@Vemz??bp}`-KYBu{p+U` z^&wRJfDfumEs>;srrPz-#qQlLjz*}Z3@l?{IjaA}Y6bRvY9(s$iPb7maAKtg^RP?M ziPcB!xSAc)>%?jej+T}atF@ft$1+DbvHAo%2I!~K0?>7|ZJ?Wu)NJ!WvF}sVXHfBF zRaAUgRgyN)>+i-a+>r-AZNHI$&lyNT^#k36eV^Kl8Vq!cC;+Ml^RP=0(5>vajUCeq z=yn_}EkM8E9CyeZ1?W!f7@)hP1)#fW+dzAK*(zh9<$t88Jy0#R7uBWqNzw)y+qQbe zE>}K}Q2QA;z`&QNexL`jj}Jqm1_M1L3V`atJnRw#^e{UfVaN0W`ZbQ07NAEt$73=_ z0eT!e2IvWC0q9BE_pi`XP%U*D)uqlz^8OWimVt8&oJaKoy?}k6x`-MK^pYq5st5D1 zOAyd+*zsF-OfR6{;b>_AdYN;)B6Ad=SFvM&UXvDpeoxy5`uE+tWq5IV9jc{npt{se zN!maI7pFfk@FN31q56T|!ajag2{jn#&!PaR9?ZioK|t@Y<1g%(UO<1v(b59+H_q{Q znWF&x13L!jpJ=7_s{9vpOZ|=NQgpt8-LbolC|5knwB)73fzK z!2qDI0ff`9B)$PaUjwL2zl!(<0DTRhD*Z=M1wb`yodBvs#nm~JH6`f)Pzx%q&QS@V zj)VdL>Y|kzfO^nzb&iUwbIBM0GM){fA^pchFaYRl0FTpeB)$PaUjt}Nzlrz;0DTQ0 zf__s}0q_L2P5@6r#nm~J%_Zpo&;ly1&QS@Vm4pHSTBDU3fHu%^b&iUwbIBM0GM){f zJ^cKqkU=aMl1WIP){B>kQu7y$G&fGGOW;u`?;HGmlUy~H;F z=xYG6^n0TUfIiqd0l1;!>YPcBBpm==f{LqiR08NHp#Xq*v{D0*03BE7sJJ?pi~%6y z*#H#%L=g-C`WiqI{bcbC0Qwrh0Qv*PHvs5s0E6fcMil@87i*Mne<800bnRp zT%Ds5z;Fo#0E|E@H2@=_D z0MOR}UZejyssNaPtrNgRsJJ?3a*`w+0477l)j28wOqEaoz%*%L@Hc7Oud#t|qfCd2 zpVvioskbC)pI8O{&^METSq!|5>OZlXjeVbb2Q~P_YK|y4vC@Ni*d^%1YA!pz%Z}-F zVl@v(OUsGXe9m!!%u!CP-ouUo`o6RPbRlgUD84tX|IqgVRQ!|_Dt<~zlJ}pfE@of} z14~i;K$l_Pr=FcYB|ENS$MgdF5ssD?psP8@H8Mv5x)wVI=*Q9m z&`)UFKm#9|{1mFC)}gx8XOgsm20k>oo`DSvY((_~{T%!FsxoRY&`qKMs21?ab&<99Mg0eTrb2Iv)O0q9lQW`*iMMPGx8fBXj3rLIfz{uO$Iftw8c zfa(YOBldmjC)8k|w?qL@J(!1Gf`Hy;$Di3Ty@1}q(b59+7tZllnWF&x4Lb(t?`WlV zn)3&AeD53;-#eF#In9yr?9-gT>E9K>oaWHir#TsNO8<Ac6rvUjryezmWLmJdeHxP?&xZ@eKg_8bDF{#ZU!6acrFc9)ybT zoikZdk`4f+pyGSys02_(LID6}(MkVx02$8)P@jGS5exwO8bCw(kBM&p(ANMSr{73?1Ax8;(3pM`Q~?ly ztrI{~sQ6?%lTS+00iYRFe6k&t09r^W0H7sWsR3vO9iMDR#V6Y(V*to_Hh{MD+lgQR z(ANOk)9)a@0YF~^=t%!5@eKg_8bBxdolymV3tK0Er=jAL?M!x&qyxaSQ1QujR08NG zp#XsHXr%_=Iq3LgJ1RceE*S$r#1^|5xAd-Gh@eKg_8bB2NXz>jI`WiqC z{a&a7AQoFEfZkB?$#y2)l5_xwgNjeKqY}VN5()t5i&km?`a#Dh+fnh!cF7n3GM){f zKYgzV1^|5xK+#VW-vFSm0VL5+7T*A%uK^68KM+*_48qn4U@%mCvYpA7CFub03RHZu z9hCrvN++a+TF$apq@SLu%u!2qDI0gR?^iEjYV*8s-QA1l5A zKwkqGM}Iu30C){sCxF+X;*;%6PL!krz#CBU$#zr%m@J_HfGN_#;8SVaud$En-=0WO z)1cyqo>B2b&yuuHtQz=#dtwFyZ!s_v)qi3&3;XyXG1TA_tJ$L9#7Yn5VV9s2t9RIO 
z4m+mTiPc;jEiETj?{bdwWR7xTH6J?$=mKd0=zFw-PFnsw)IwC3`aqI4(69Yp zKwHGXhYT!6^#fgkeV+dEkIXsjvvV!1?Xz* z7@%vU1)ytb+d%EB{S@^vR7-t=>QbLd(gqs1+F!@OXAG=I^#k32eV^Kh8VvMvQ2=Fd@OLjcSj_C#TD;zB?Ko4<_hh>fe^ayqg z(66Nhphs!jK<%sj2z3l9{%KlNmpUOy8))Ea|0Dya7&wjU2YLqkK6Mr~80a}s08|g= zVV5AF=h^WBJEj-Vi#S?ZfL`JpzmYi#&~LG0fPN<}0KH86{uO!!D*hQJRQxkclDvO~ ze$T*l25zAGf!@SEe&!oB80e3p0H_|!!!AKUe`3d5?3i9aZ{uib0s1rNct_?aK!3rG z0s5=70Q5K72B?1fXn}NHiexMn03DAtF!9X*K0-$;@ z54!{b&CHGuuw!}w&4Qz)1!z{zF`LX$fM&;z0h$A?)Xp2jpj#>@s!QdPj5%+V@$B=) z-1PH^V9p!q>+{CE^z(^t&Kv3L^Tz!23y5#d8|mxw#)9+tTn8vyh*fC}_0q6&a;Z0`l2XIcrW zr7AO7MUoBxRiWaamqI0gY7z$FA{YSl zHGsPG>xpjw(ANO!({CWY0YF~^Xh{DtQ~~ffw)X9-Nz0HCh{w58upd;@^K z2GE{<2UG#j5!-tK=$SqR)l!|9>?}zK02fsJ^HQh;@Qj240J@+R3?LuCHdW6;w^UbD zm+B@N13<>J0d%MToCpR0eGTAw`Y(uY0MOR}deDDSd;@^K1`tWVC#nF5!uDPOdZy7( zEfvFLFG)H8#6rbCFNI0~eIyhB;6^JLfaV$p9lr{PieH73i~%6y*#P>|?*ah&6LnWMa{zJ?tG^mVjSTTv6B z~^nu4_9sHB#>}@IC_zQT;3G1MK7L2B^U+>O)bmqV!-Mb_rThi`j7rJEqr) zT8g8kWkoIH9GA-+Wks#Pjsdz7t<+Z3D(IH_2-T%lOUA4y8P8r(Yv`{P!K^6ydPRLq z{}b`eilVPq)Ti{_ zU`1_}76#u&+g?%6>gT8kwH+#c6adwwc1Y4*Q78Evg|9~JWMCHqyHWisY7h4DV*seZ zD{7x8SW$W~54!}dsQv7CfF09oMSY2*rDa7O;0f{GszKqY{0B@|du-$@IDU#4xZsAv4IdRL(0>KqkU=aLM1)x(PV zo`LHO+(7lOsGHcw)j4YLiuzF$tSCL0hh2hJ)KBbqiyhNzMcu~H(z2p{<{aYT|cl60)7s!(xtj!FR4 zBotUt)un~OYtXh=RD}!8f1EaWZ8KF9s-RjR+AHevjAA3IHqOvY)nTA61NBh- zE2=*BeX0R!@QP|E3RaXJ%)>50E9x*89uvSymf`OI{v_ka*ZH;}OYJ(aK zw5=!rst5D1OAyd@?AV?i(+g+^94#$CJ93Us$s7e}C+rxYouvh!F4{KGfjxIdZ&T%( zs;8migAJ%I)kTsv(EUd?yl`Ptb-c;IKvxF3q56S#$G%TJhZ+p@c~JmV59VQ)AfPX> zV-I#rFQ6~tXlVf&$vO6vISSAy>=>ZY(gM&J+BQ&6rb1n+`O0rny`Wku7S*MCOVS3K z{hPv->x?MTRP|xN%|INgAE*cW_$RVagMs!H1wi#+9(D-=+K(OM*)hFfB?RdAbYVGdYR3$^jTPLV4HBgc^&{~-*4=VfFLrv8n1_m=Q z1l14pW$fcE6Vza!K2ZQv59VQ)AfQ9paTq(M7trB2T3UdP;2cNF90lmB*fBsyNee(n z)3$*gc>DF1XS!mAT2L)D2Gym;O40^e@I=Cn`QKuNj$>dv1Fxa_fxeD?d=~^Y80bV% z08|g=VV5AFZ?NMec1$mzlX0}P0G+})PL(+d&}rB)K;M)WfKI1v1MPV=a(J#xMK`G# zQ1Q>*p}N#eN!mca%e*1$#)ltls%A0pHUqO!{XpNrK7L0LH5lkzQ26JX)q{E1B?#y$cKnDP(+lWo94#$C*Km$& zWsU;$W9%59pGXTpKc#I0tyr+@XZ7E`6rt8Z#g9j#;>RN;X#*|y<#&G_ZF4X}ZD3#{ z1D~V%fu>*||H=VsFwo7S0H_|!!!AKUx3J?@c1$mz+i%Pz&pE$?jNamM&IK=g@7ewCcV_1Iee;`hKIYDultxX9PJ=J#nBG7oEo7aJ6e1c(YZ<|>p3OEH%f*K@d-jR#rG)dx8gG= zp;@*of>0GKL@APk(DO>c3rfNKBJ`rL<_Mve6dy0!e6$gIMSRI7^r}r0gnp+aYlN2H z++}>Ty-oI6uZgSTFnaOnXZ^vJQzLY9;IGtKPnlni%$^xllU&qAQYcD z3B6&vA_!H%LX;vo2>nGVcvC5uUxeNg)*K=9w<jHXm(--VtB23H?8tCJ6maN!AFB zSN``Cv;HowitQ=Hr`VptmQy2iVVk$Uyz-FXn}oG1Nm`6$@a*r&s`E%c&9i_aT>mT3MocCu3?|ZS&DasF(PXP3WUGO%Pf~N!AEeCRr2{ zuX*c=t8ZHCiBCUkeOpeAQ1#G>Ys}g}$A&`5lO(8l6>l(mWY%t>fd+Z92m3KpUi z$wBC2O2KAI!Tch$xv-W?LWhU+7SoJE9<_bc5g*-)&9b%-#Y5@|Ww!Nk@zu}zgw6Zz z#k()-zD%s~5lYaM=x6I)C?WPBOi}l^7 zKldSbo2(8wwIl)eQ+TO28V6oZ)6+Fwub-Sg_GYJ8?^n*S7UwgT`j**zI>J&qzxMnykz09IzmzN9Gmf&b=k6>XG-)eu+}1baqo*_UI5tj; zYrm~tf1&-#ZMxs64@q+wi1V$OwEgl}i&4}nFKpUqx=8|QmzN=Bi*Py;Q6#a9@wOF@9$`6?7?pIIXGXVwH0|g0{CNAm%O(juY{bq}aZ}P*bhX;cG7?X7tN*X%`|f?8 zH=0$oi!Hs^smp8A6Epa&-bIe5T-|{y;o+#2!wE5=|@sIg5ZrjeL zd(MmhV`;0)dVdUYEL@9gM%#hP|2YF;6}xy2jXSH~4}YH9erl+O18F~J?o!|vq*=JY zfixRGMCF0mldzjR7VPH&4y0TU*hjj3q}-(dFG#8%bl8DDJdgL)=@lbiymL82zaO_; z4SlD-Tbi!>4*9V4;87PCyzZvzcGD*hemt;O-xThD?YpjPC#Lh3yBlnuAo@|o-yXj{ zE;^leN=X^w;hw?5+(#~Wabp_a^Zk#V_8-{F6ORonIqp3I{=&O9_^UOMMzJ8ZvD?hD;N_?c5q?U++1 zjTj&Xhyhn;fa02~w;=gX3=jh{0~FU}Mo1?Hhyhn;fa02~w;=gX3=jh{0~FU}Mo1?H zhyhn;fa02~w;=gX3=jh{0~FU}Mo1?Hhyhn;fa02~w;=gX3=jh{0~FU}Mo1?Hhyhn; zfa02~w;=gX3=jh{0~FU}Mo1?Hhyhn;fa02~w;=gX3=jh{0~FU}Mo1?Hhyhn;fa02~ zw;=gX3=jh{0~FU}Mo1?Hhyhn;fa02~w;=gX3=jh{0~FU}Mo1?Hhyhn;fa02~w;=gX z3=jh{0~FU}Mo1?Hhyhn;fa02~w;=gX3=jh{0~FU}Mo1?Hhyhn;fa02~w;=gX3=jh{ 
diff --git a/tutorials/04_Working_with_Metadata.ipynb b/tutorials/04_Working_with_Metadata.ipynb
index 91b001036..5c79f8ebf 100644
--- a/tutorials/04_Working_with_Metadata.ipynb
+++ b/tutorials/04_Working_with_Metadata.ipynb
@@ -1,21 +1,72 @@
 {
  "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Working with Metadata\n",
+    "\n",
+    "In order to work with complex dataset structures you will need to pass additional information\n",
+    "about your data to SDV.\n",
+    "\n",
+    "This is done by using a Metadata object.\n",
+    "\n",
+    "Let's have a quick look at how to do it.\n",
+    "\n",
+    "## Load demo data\n",
+    "\n",
+    "We will load the demo dataset included in SDV for the rest of the session."
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 1,
    "metadata": {},
    "outputs": [],
    "source": [
-    "from sdv import load_demo"
+    "from sdv import load_demo\n",
+    "\n",
+    "metadata, tables = load_demo(metadata=True)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Apart from the tables dict, this returns a Metadata object that contains all the information\n",
+    "about our dataset."
z_W`Wtzm@j^EJ1xAz%Q%sH~Guz`%TufmhwJ;RaM^yu-aXf_nW-TqOX5n6}*^r-J_l# z&`&!*;9Ga~{D7(I_?hAU!F>%*7l;93;2~jv=C3^@4MA=Y1H^!73{YG%jYXAy*G!|5fQSKN;2~jv;@U&f5ab3iKn$41K;emNTNb3c#rp1J z&rUA?edBE>_~{iRU%YcUgWh|ZcCS1bDEz{J$5+*>sK4)iaIU|;ig}Qbe_i6wbF6=& zSA`zT%Eq`m!HDgL5JNxul--A0tZsAhuQXkr%LNCS%)7; zm*-vlb&b>{rf<94)|c`T1H?d4VZe3!Yaw@=tPVM~B;m1Ijh;PIJc)0)8v0Iuw=_Kt7nK&w z?|Q~$_46$C#iyScwHqJ5AR)b~)dFc0ufo%wwa%;!ws`jbvQ^Z)5YuqsAEep%`u;ZB zJ~vAiY?$BSh1nKdhh3xm00;8EdT92MZqy$9*Y~^7@S(#Fba@_ZZtctcu@S%jqxe9c zuJ_N>^z68hsr28{?4}ogb9Lk0JI*opR4?DYQ@62z9d4N~efb&hHgL_^mxLaCwPDcF zHCx%5(7Fd+p0bU7)W2(Loi{$`AFo=l%$DBpqp)&mWjJ3Troi)`v&< zy3UCWZRyt(`-!|%^A=zE3qAa$fz3Bn4%YWK%5A+VA2C1-6g381r@t0n>W#*Mm($pD z%J^9_^BwYG>%pTg=yA8GwOoGH*DLwS6SwDl;A7Mt_R)hhidVnP>PH6nc<<`sc4x?n zQI=h1;Q|NJ+4%ojHoebNCpzi=!>&<&fCDMl1J@zVwddOoxT?Hgm)9NT*1kMGCilnW zbyQ>3b?^G`IJJEC`uj4=uMO=y{s15Mx7}rrZuPdz*sw3A`~PB@+xa`g2WJej9I871 zx#(w3vWnK77dy=wY`Oi$&Ashji{NeJfBpS~fB;JkW&A9Id;PqxVaP{*3F`RSkF0vH zf1POe%OeS1&HQJTh!3-DnA3OC^REuI#D|WEiksZl@OX4?`J#Np05MQl7es@5xb}8SK znET#J^KjuGq}lk`{~2=8jN9v0`GvQG4!cJA0S=^G53}t75BtH&I{e6)p1n>yxAm3Z zSIXmKa(_&$i_fv{T~05j_RrYL5v^FOj7gTa)czT-TW!XUfmsoj6t#cGtIk-_yLZP3 zORCyGV+U5QT(LT2oTaAPKVvaF)2}2fvReMDyx;th|2C)PZ{5bfZaEgZ?SS90xd|zO zEh_kjS}k$!_7B}~ekF~c<+i~oA2C1-*kRz|>95iFnVlwb@?{!7bME*Yjh{Jo1LtD= zOmxYz$H(OUm^{8l<7bZXPq~Q!V&GwApy2vzV{dkf^?v0HTRN-QMgOnP@yUZ95A4-9 zg~rdEe5+{Zc^T%pj<#G>VV&(ID$<=L$-Mgy5Z*x^(6UV2C93h`*uD!HTPF*d#E^&xi9(PoPW5hcr!I= z7A|lg&Bn($kB~!qD)nH{VfVqc&iz&3K+5$n+aB<6?whQ`52Ry1*xc7h&53nMD2I8l z&@t}z~-%jt+?;yec>uDv&TfAendX7S+UHUF+Dez%tC+BNfs_mkP$JyYfc9k|YF z+^F%xv9(z|K4xBsE&bXE@3sD&wzEeo?f7zb(slN7U}*Url~Q^8vcqrI5b3camRyP~ zmdTgz{%4#`uYc~94_du-k)^&n)_>}i11vMT{iA6=ujegyH`qR5(pHLRXecMik=p*l zM&bKtIFLs1`f7N7CgZ|CNVD;AUM=Lus{b51>>A|}#=EyyEpM``r(mV|S}XCwCI*p*@Gif7baL z8#iP8yf#8V^yrI?udF%6gCZiMJ%hLM(uZf{uBx&@t}XYeRmFF`G9{V$45{x^PxQ||ztJ)L z%7nPf@-wq8vLVy1&V9GTaTYdW=c%|UX{=L9$`B9t3qs^pbR%t{+xn6fN__T7ZDvzCz%vQ{=P@0(YnA5Acx%_B zZ7irmOtaFW|JAzV*w<%PPvzUTJl3eXNI(4jJGDmtn8Do^d3UKM^pp*Ys#y+RWVa_) zOt92RWLE!Q%lF;;KHu~GkDc}(pz$-ble1(U?R+ky_!TuANTYZj8lIcUxbP3sZ2Xq0 zzZ`Oiq{FUJet-ih*8}#EZqy#|&J7zm`c~_By?)SP2RhDubJQ;OA456pcL5#iE1_dP zKXf@A?<+CRkMS}5uGITuS{y@vZFSS{wwx2cJ68<;X?#$`MYd!^ivag6S^RY1d%uM4 zKg(8r5%BhZM0({8UY^qiU1P5w+|_NL(A_roEaCmxDIV^VSasRxt-R5!s$Fd3XNkuK zmK^vNjh~^NoF#A6RnHkTiod14iUVmBZ-<6=&}3Zr2Wd9GR~>~14!OQ&P91iQ@&g=5 zxgM~Obfflw=iNAGeW1e*bev!5s9o%%h;rDs2Rhb?L&v%#=yE#dUEqBs#`!TmhToOw zk0FjBu9aBZ;>7jg+u87u!Tv48_}M>d2mE?XjGtYP+a6(kLGM zdc9 z&`!>hiO;L&J{!gVK*NDF3vZ2+aN(!f_~%rA8DvxSTy0&4U9FNjfvyJC{K7D)cxAcA!4do8ybHXq#5h03$MCxn{V~L`nOa;s_DSic=fv;MZ#o?qTDIj0<~3|qoufkk{X&JN z31tuPl+Ld`|4igo?pyxG#}mc)nP1uX-xkcJ@iVlOv*fkUcUFo{AL{UZC;Jnhq~buz z>5l3fHN3qh*%z$oYW-Dc1w`kuH~aR0p0%O3wN~haKo$ z_o(-Sqjqr42+Cn!Yv|ZV5jytmfi9g-dAFrALC>AU5Wk};#i~> z*Lrlj+BQLqpKa{;z1ZsLI%Dn2-&rZf&$9ZZT=AZ9l$~hv&m-a=YjNX$?uU6Yjh~^R zoFpTkn0mQE=qAHFPu6fCjk*h*sZPRm7XE0}Uk2G<(qY#uet;+Q1p7!gY7cnxr20UI z9q9eEc5~D&&Pzf$?9UG!`@urT{$tSPbgZw$x+Ki!$2?ffAH%#0ysyMKKgP%KyAu5| z#4)kihcYjB)IYxa?h6suzVvy2O^ew_dDwGz77r5Zn|vC#ZD;GBMR%NeXvjS3pP`+c zC8Mj|ifMR!nBo4JucwLwY1H5D8r}etap51N+4!5QC_He;&h9yN*fq)za3JM+z&_HA z+5=v4jhyv?4m;4lRqO5WeT8%1;4jV@fsTExp<^FK=yE#NiDP{w)+J#+Kjy(={ut(6 z;C&^=`7u6*-<9Z(A&%jFXzd>UpAHiJv&Lu>C^0FC69r z^&*)&1|}%`uN!~wBIYBCKUHZQGI!SP=Mw$-9C&8o0>_@T?@{hjCaQij$XSvOySYQN zp9?tlB&$f=OxMe^UM3e%IVne0_()Fz7p$_ zFrOduU@?CT^Dgkd665?BAH(lT^v4j#1P@Ao&HAKAv0@R{r^IJqNXz2Iip_}(6#ow@ zA5(Eiw!pH|({R+66a%HNnl=Mx->GW^JqBJgwl^>}3RH26;wVpHVJU z8yXxj+!`@vOz1ebp<^P$BHXM#BB8t6ENeR<{jBZ9XHb(6<)X(d>)(}&7Td)(>cy~G z)(* users.user_id\n", + " transactions.session_id -> 
sessions.session_id" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "metadata, tables = load_demo(metadata=True)" + "metadata" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This Metadata can also be represented by using a dict object:" ] }, { @@ -58,60 +109,22 @@ ] }, { - "cell_type": "code", - "execution_count": 4, + "cell_type": "markdown", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'users': user_id country gender age\n", - " 0 0 USA M 34\n", - " 1 1 UK F 23\n", - " 2 2 ES None 44\n", - " 3 3 UK M 22\n", - " 4 4 USA F 54\n", - " 5 5 DE M 57\n", - " 6 6 BG F 45\n", - " 7 7 ES None 41\n", - " 8 8 FR F 23\n", - " 9 9 UK None 30,\n", - " 'sessions': session_id user_id device os\n", - " 0 0 0 mobile android\n", - " 1 1 1 tablet ios\n", - " 2 2 1 tablet android\n", - " 3 3 2 mobile android\n", - " 4 4 4 mobile ios\n", - " 5 5 5 mobile android\n", - " 6 6 6 mobile ios\n", - " 7 7 6 tablet ios\n", - " 8 8 6 mobile ios\n", - " 9 9 8 tablet ios,\n", - " 'transactions': transaction_id session_id timestamp amount approved\n", - " 0 0 0 2019-01-01 12:34:32 100.0 True\n", - " 1 1 0 2019-01-01 12:42:21 55.3 True\n", - " 2 2 1 2019-01-07 17:23:11 79.5 True\n", - " 3 3 3 2019-01-10 11:08:57 112.1 False\n", - " 4 4 5 2019-01-10 21:54:08 110.0 False\n", - " 5 5 5 2019-01-11 11:21:20 76.3 True\n", - " 6 6 7 2019-01-22 14:44:10 89.5 True\n", - " 7 7 8 2019-01-23 10:14:09 132.1 False\n", - " 8 8 9 2019-01-27 16:09:17 68.0 True\n", - " 9 9 9 2019-01-29 12:10:48 99.9 True}" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], "source": [ - "tables" + "## Creating a Metadata object from scratch\n", + "\n", + "In this section we will have a look at how to create a Metadata object from scratch.\n", + "\n", + "The simplest way to do it is by populating it passing the tables of your dataset together\n", + "with some additional information.\n", + "\n", + "Let's start by creating an empty metadata object." ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -121,141 +134,92 @@ ] }, { - "cell_type": "code", - "execution_count": 6, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "new_meta.add_table('users', data=tables['users'], primary_key='user_id')" + "Now we can start by adding the parent table from our dataset, `users`,\n", + "indicating that the primary key is the field called `user_id`." 
] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ - "new_meta.add_table('sessions', data=tables['sessions'], primary_key='session_id',\n", - " parent='users', foreign_key='user_id')" + "users_data = tables['users']\n", + "new_meta.add_table('users', data=users_data, primary_key='user_id')" ] }, { - "cell_type": "code", - "execution_count": 8, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "transactions_fields = {\n", - " 'timestamp': {\n", - " 'type': 'datetime',\n", - " 'format': '%Y-%m-%d'\n", - " }\n", - "}\n", - "new_meta.add_table('transactions', tables['transactions'], fields_metadata=transactions_fields,\n", - " primary_key='transaction_id', parent='sessions')" + "Next, let's add the sessions table, indicating that:\n", + "- The primary key is the field `session_id`\n", + "- The `users` table is parent to this table\n", + "- The relationship between the `users` and `sessions` table is created by the field called `user_id`." ] }, { "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'tables': {'users': {'fields': {'gender': {'type': 'categorical'},\n", - " 'user_id': {'type': 'id', 'subtype': 'integer'},\n", - " 'country': {'type': 'categorical'},\n", - " 'age': {'type': 'numerical', 'subtype': 'integer'}},\n", - " 'primary_key': 'user_id'},\n", - " 'sessions': {'fields': {'session_id': {'type': 'id', 'subtype': 'integer'},\n", - " 'user_id': {'type': 'id',\n", - " 'subtype': 'integer',\n", - " 'ref': {'table': 'users', 'field': 'user_id'}},\n", - " 'os': {'type': 'categorical'},\n", - " 'device': {'type': 'categorical'}},\n", - " 'primary_key': 'session_id'},\n", - " 'transactions': {'fields': {'timestamp': {'type': 'datetime',\n", - " 'format': '%Y-%m-%d'},\n", - " 'session_id': {'type': 'id',\n", - " 'subtype': 'integer',\n", - " 'ref': {'table': 'sessions', 'field': 'session_id'}},\n", - " 'transaction_id': {'type': 'id', 'subtype': 'integer'},\n", - " 'approved': {'type': 'boolean'},\n", - " 'amount': {'type': 'numerical', 'subtype': 'float'}},\n", - " 'primary_key': 'transaction_id'}}}" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "new_meta.to_dict()" - ] - }, - { - "cell_type": "code", - "execution_count": 10, + "execution_count": 6, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "new_meta.to_dict() == metadata.to_dict()" + "sessions_data = tables['sessions']\n", + "new_meta.add_table(\n", + " 'sessions',\n", + " data=sessions_data,\n", + " primary_key='session_id',\n", + " parent='users',\n", + " foreign_key='user_id'\n", + ")" ] }, { - "cell_type": "code", - "execution_count": 11, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "new_meta.to_json('demo_metadata.json')" + "Finally, let's add the transactions table.\n", + "\n", + "In this case, we will pass some additional information to indicate that\n", + "the `timestamp` field should be actually parsed and interpreted as a\n", + "datetime field." 
] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ - "loaded = Metadata('demo_metadata.json')" + "transactions_fields = {\n", + " 'timestamp': {\n", + " 'type': 'datetime',\n", + " 'format': '%Y-%m-%d'\n", + " }\n", + "}\n", + "transactions_data = tables['transactions']\n", + "new_meta.add_table(\n", + " 'transactions',\n", + " transactions_data,\n", + " fields_metadata=transactions_fields,\n", + " primary_key='transaction_id',\n", + " parent='sessions'\n", + ")" ] }, { - "cell_type": "code", - "execution_count": 13, + "cell_type": "markdown", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], "source": [ - "loaded.to_dict() == new_meta.to_dict()" + "Let's see what our Metadata looks like right now:" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -278,10 +242,10 @@ "\n", "users\n", "\n", - "user_id : id - integer\n", - "country : categorical\n", - "gender : categorical\n", - "age : numerical - integer\n", + "gender : categorical\n", + "age : numerical - integer\n", + "user_id : id - integer\n", + "country : categorical\n", "\n", "Primary key: user_id\n", "\n", @@ -291,10 +255,10 @@ "\n", "sessions\n", "\n", - "session_id : id - integer\n", - "user_id : id - integer\n", - "device : categorical\n", - "os : categorical\n", + "os : categorical\n", + "device : categorical\n", + "user_id : id - integer\n", + "session_id : id - integer\n", "\n", "Primary key: session_id\n", "Foreign key (users): user_id\n", @@ -303,7 +267,7 @@ "\n", "users->sessions\n", "\n", - "\n", + "\n", "   sessions.user_id -> users.user_id\n", "\n", "\n", @@ -312,11 +276,11 @@ "\n", "transactions\n", "\n", - "transaction_id : id - integer\n", - "session_id : id - integer\n", - "timestamp : datetime\n", - "amount : numerical - float\n", - "approved : boolean\n", + "timestamp : datetime\n", + "amount : numerical - float\n", + "session_id : id - integer\n", + "approved : boolean\n", + "transaction_id : id - integer\n", "\n", "Primary key: transaction_id\n", "Foreign key (sessions): session_id\n", @@ -325,114 +289,137 @@ "\n", "sessions->transactions\n", "\n", - "\n", + "\n", "   transactions.session_id -> sessions.session_id\n", "\n", "\n", "\n" ], "text/plain": [ - "" + "" ] }, - "execution_count": 14, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "metadata.visualize()" + "new_meta.visualize()" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'tables': {'users': {'fields': {'gender': {'type': 'categorical'},\n", + " 'age': {'type': 'numerical', 'subtype': 'integer'},\n", + " 'user_id': {'type': 'id', 'subtype': 'integer'},\n", + " 'country': {'type': 'categorical'}},\n", + " 'primary_key': 'user_id'},\n", + " 'sessions': {'fields': {'os': {'type': 'categorical'},\n", + " 'device': {'type': 'categorical'},\n", + " 'user_id': {'type': 'id',\n", + " 'subtype': 'integer',\n", + " 'ref': {'table': 'users', 'field': 'user_id'}},\n", + " 'session_id': {'type': 'id', 'subtype': 'integer'}},\n", + " 'primary_key': 'session_id'},\n", + " 'transactions': {'fields': {'timestamp': {'type': 'datetime',\n", + " 'format': '%Y-%m-%d'},\n", + " 'amount': {'type': 'numerical', 'subtype': 'float'},\n", + " 'session_id': {'type': 'id',\n", + " 'subtype': 
'integer',\n", + " 'ref': {'table': 'sessions', 'field': 'session_id'}},\n", + " 'approved': {'type': 'boolean'},\n", + " 'transaction_id': {'type': 'id', 'subtype': 'integer'}},\n", + " 'primary_key': 'transaction_id'}}}" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "new_meta.to_dict()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Pretty similar to the original metadata, right?" + ] + }, + { + "cell_type": "code", + "execution_count": 10, "metadata": {}, "outputs": [ { "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "Metadata\n", - "\n", - "\n", - "\n", - "users\n", - "\n", - "users\n", - "\n", - "gender : categorical\n", - "user_id : id - integer\n", - "country : categorical\n", - "age : numerical - integer\n", - "\n", - "Primary key: user_id\n", - "\n", - "\n", - "\n", - "sessions\n", - "\n", - "sessions\n", - "\n", - "session_id : id - integer\n", - "user_id : id - integer\n", - "os : categorical\n", - "device : categorical\n", - "\n", - "Primary key: session_id\n", - "Foreign key (users): user_id\n", - "\n", - "\n", - "\n", - "users->sessions\n", - "\n", - "\n", - "   sessions.user_id -> users.user_id\n", - "\n", - "\n", - "\n", - "transactions\n", - "\n", - "transactions\n", - "\n", - "timestamp : datetime\n", - "session_id : id - integer\n", - "transaction_id : id - integer\n", - "approved : boolean\n", - "amount : numerical - float\n", - "\n", - "Primary key: transaction_id\n", - "Foreign key (sessions): session_id\n", - "\n", - "\n", - "\n", - "sessions->transactions\n", - "\n", - "\n", - "   transactions.session_id -> sessions.session_id\n", - "\n", - "\n", - "\n" - ], "text/plain": [ - "" + "True" ] }, - "execution_count": 15, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "new_meta.visualize()" + "new_meta.to_dict() == metadata.to_dict()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Saving the Metadata as a JSON file\n", + "\n", + "The Metadata object can also be saved as a JSON file, which later on we can load:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "new_meta.to_json('demo_metadata.json')" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "loaded = Metadata('demo_metadata.json')" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "loaded.to_dict() == new_meta.to_dict()" ] } ], From 5503b16a96eadff82d78d523db1420aeee801437 Mon Sep 17 00:00:00 2001 From: Carles Sala Date: Thu, 9 Jul 2020 21:42:07 +0200 Subject: [PATCH 31/33] Remove unneeded files --- sdv/metadata.py.composite | 1054 --------------------------------- sdv/metadata.py.non_primary | 1109 ----------------------------------- 2 files changed, 2163 deletions(-) delete mode 100644 sdv/metadata.py.composite delete mode 100644 sdv/metadata.py.non_primary diff --git a/sdv/metadata.py.composite b/sdv/metadata.py.composite deleted file mode 100644 index 10dda41e3..000000000 --- a/sdv/metadata.py.composite +++ /dev/null @@ -1,1054 +0,0 @@ -import copy -import json -import logging -import os -from collections import defaultdict - -import graphviz -import numpy as np -import pandas as pd -from rdt 
import HyperTransformer, transformers - -LOGGER = logging.getLogger(__name__) - - -def _read_csv_dtypes(table_meta): - """Get the dtypes specification that needs to be passed to read_csv.""" - dtypes = dict() - for name, field in table_meta['fields'].items(): - field_type = field['type'] - if field_type == 'categorical': - dtypes[name] = str - elif field_type == 'id' and field.get('subtype', 'integer') == 'string': - dtypes[name] = str - - return dtypes - - -def _parse_dtypes(data, table_meta): - """Convert the data columns to the right dtype after loading the CSV.""" - for name, field in table_meta['fields'].items(): - field_type = field['type'] - if field_type == 'datetime': - datetime_format = field.get('format') - data[name] = pd.to_datetime(data[name], format=datetime_format, exact=False) - elif field_type == 'numerical' and field.get('subtype') == 'integer': - data[name] = data[name].dropna().astype(int) - elif field_type == 'id' and field.get('subtype', 'integer') == 'integer': - data[name] = data[name].dropna().astype(int) - - return data - - -def _load_csv(root_path, table_meta): - """Load a CSV with the right dtypes and then parse the columns.""" - relative_path = os.path.join(root_path, table_meta['path']) - dtypes = _read_csv_dtypes(table_meta) - - data = pd.read_csv(relative_path, dtype=dtypes) - data = _parse_dtypes(data, table_meta) - - return data - - -class MetadataError(Exception): - pass - - -class Metadata: - """Dataset Metadata. - - The Metadata class provides a unified layer of abstraction over the dataset - metadata, which includes both the necessary details to load the data from - the hdd and to know how to parse and transform it to numerical data. - - Args: - metadata (str or dict): - Path to a ``json`` file that contains the metadata or a ``dict`` representation - of ``metadata`` following the same structure. - - root_path (str): - The path where the ``metadata.json`` is located. Defaults to ``None``. - """ - - _child_map = None - _hyper_transformers = None - _metadata = None - _parent_map = None - - root_path = None - - _FIELD_TEMPLATES = { - 'i': { - 'type': 'numerical', - 'subtype': 'integer', - }, - 'f': { - 'type': 'numerical', - 'subtype': 'float', - }, - 'O': { - 'type': 'categorical', - }, - 'b': { - 'type': 'boolean', - }, - 'M': { - 'type': 'datetime', - } - } - _DTYPES = { - ('categorical', None): 'object', - ('boolean', None): 'bool', - ('numerical', None): 'float', - ('numerical', 'float'): 'float', - ('numerical', 'integer'): 'int', - ('datetime', None): 'datetime64', - ('id', None): 'int', - ('id', 'integer'): 'int', - ('id', 'string'): 'str' - } - - def _analyze_relationships(self): - """Extract information about child-parent relationships. - - Creates the following attributes: - * ``_child_map``: set of child tables that each table has. - * ``_parent_map``: set ot parents that each table has. - """ - self._child_map = defaultdict(set) - self._parent_map = defaultdict(set) - - for table, table_meta in self._metadata['tables'].items(): - if table_meta.get('use', True): - for field_meta in table_meta['fields'].values(): - ref = field_meta.get('ref') - if ref: - parent = ref['table'] - self._child_map[parent].add(table) - self._parent_map[table].add(parent) - - @staticmethod - def _transform_metadata(metadata): - """Ensure metadata has the internal SDV format. - - Convert list of tables and list of fields to dicts. - Ensure primary keys are defined as lists of fields. - - Args: - metadata (dict): - Original metadata to format. 
- - Returns: - dict: - Formated metadata dict. - """ - new_metadata = copy.deepcopy(metadata) - tables = new_metadata['tables'] - if isinstance(tables, dict): - new_metadata['tables'] = { - table: meta - for table, meta in tables.items() - if meta.pop('use', True) - } - return new_metadata - - new_tables = dict() - for table in tables: - if table.pop('use', True): - new_tables[table.pop('name')] = table - - fields = table['fields'] - new_fields = dict() - for field in fields: - new_fields[field.pop('name')] = field - - table['fields'] = new_fields - - primary_key = table.get('primary_key') - if isinstance(primary_key, str): - table['primary_key'] = [primary_key] - - new_metadata['tables'] = new_tables - - return new_metadata - - def __init__(self, metadata=None, root_path=None): - if isinstance(metadata, str): - self.root_path = root_path or os.path.dirname(metadata) - with open(metadata) as metadata_file: - metadata = json.load(metadata_file) - else: - self.root_path = root_path or '.' - - if metadata is not None: - self._metadata = self._transform_metadata(metadata) - else: - self._metadata = {'tables': {}} - - self._hyper_transformers = dict() - self._analyze_relationships() - - def get_children(self, table_name): - """Get tables for which the given table is parent. - - Args: - table_name (str): - Name of the table from which to get the children. - - Returns: - set: - Set of children for the given table. - """ - return self._child_map[table_name] - - def get_parents(self, table_name): - """Get tables for with the given table is child. - - Args: - table_name (str): - Name of the table from which to get the parents. - - Returns: - set: - Set of parents for the given table. - """ - return self._parent_map[table_name] - - def get_table_meta(self, table_name): - """Get the metadata dict for a table. - - Args: - table_name (str): - Name of table to get data for. - - Returns: - dict: - table metadata - - Raises: - ValueError: - If table does not exist in this metadata. - """ - table = self._metadata['tables'].get(table_name) - if table is None: - raise ValueError('Table "{}" does not exist'.format(table_name)) - - return copy.deepcopy(table) - - def get_tables(self): - """Get the list of table names. - - Returns: - list: - table names. - """ - return list(self._metadata['tables'].keys()) - - def get_fields(self, table_name): - """Get table fields metadata. - - Args: - table_name (str): - Name of the table to get the fields from. - - Returns: - dict: - Mapping of field names and their metadata dicts. - - Raises: - ValueError: - If table does not exist in this metadata. - """ - return self.get_table_meta(table_name)['fields'] - - def get_primary_key(self, table_name): - """Get the primary key name of the indicated table. - - Args: - table_name (str): - Name of table for which to get the primary key field. - - Returns: - list or None: - Primary key field names. ``None`` if the table has no primary key. - - Raises: - ValueError: - If table does not exist in this metadata. - """ - return self.get_table_meta(table_name).get('primary_key') - - def get_foreign_key(self, parent, child): - """Get table foreign key field name. - - Args: - parent (str): - Name of the parent table. - child (str): - Name of the child table. - - Returns: - str or None: - Foreign key field name. - - Raises: - ValueError: - If the relationship does not exist. 
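-
-        Example:
-            An illustrative sketch (not part of the original docstring), assuming
-            the demo tables where ``sessions.user_id`` carries a ``ref`` to the
-            ``users`` primary key::
-
-                >>> metadata.get_foreign_key('users', 'sessions')
-                'user_id'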
- """ - primary = self.get_primary_key(parent) - - for name, field in self.get_fields(child).items(): - ref = field.get('ref') - if ref and ref['field'] == primary: - return name - - raise ValueError('{} is not parent of {}'.format(parent, child)) - - def load_table(self, table_name): - """Load table data. - - Args: - table_name (str): - Name of the table to load. - - Returns: - pandas.DataFrame: - DataFrame with the contents of the table. - - Raises: - ValueError: - If table does not exist in this metadata. - """ - LOGGER.info('Loading table %s', table_name) - table_meta = self.get_table_meta(table_name) - return _load_csv(self.root_path, table_meta) - - def load_tables(self, tables=None): - """Get a dictionary with data from multiple tables. - - If a ``tables`` list is given, only load the indicated tables. - Otherwise, load all the tables from this metadata. - - Args: - tables (list): - List of table names. Defaults to ``None``. - - Returns: - dict(str, pandasd.DataFrame): - mapping of table names and their data loaded as ``pandas.DataFrame`` instances. - """ - return { - table_name: self.load_table(table_name) - for table_name in tables or self.get_tables() - } - - def get_dtypes(self, table_name, ids=False): - """Get a ``dict`` with the ``dtypes`` for each field of a given table. - - Args: - table_name (str): - Table name for which to retrive the ``dtypes``. - ids (bool): - Whether or not include the id fields. Defaults to ``False``. - - Returns: - dict: - Dictionary that contains the field names and data types from a table. - - Raises: - ValueError: - If a field has an invalid type or subtype or if the table does not - exist in this metadata. - """ - dtypes = dict() - table_meta = self.get_table_meta(table_name) - for name, field in table_meta['fields'].items(): - field_type = field['type'] - field_subtype = field.get('subtype') - dtype = self._DTYPES.get((field_type, field_subtype)) - if not dtype: - raise MetadataError( - 'Invalid type and subtype combination for field {}: ({}, {})'.format( - name, field_type, field_subtype) - ) - - if ids and field_type == 'id': - if (name not in table_meta.get('primary_key', [])) and not field.get('ref'): - raise MetadataError( - 'id field `{}` is neither a primary or a foreign key'.format(name)) - - if ids or (field_type != 'id'): - dtypes[name] = dtype - - return dtypes - - def _get_pii_fields(self, table_name): - """Get the ``pii_category`` for each field that contains PII. - - Args: - table_name (str): - Table name for which to get the pii fields. - - Returns: - dict: - pii field names and categories. - """ - pii_fields = dict() - for name, field in self.get_table_meta(table_name)['fields'].items(): - if field['type'] == 'categorical' and field.get('pii', False): - pii_fields[name] = field['pii_category'] - - return pii_fields - - @staticmethod - def _get_transformers(dtypes, pii_fields): - """Create the transformer instances needed to process the given dtypes. - - Temporary drop-in replacement of ``HyperTransformer._analyze`` method, - before RDT catches up. - - Args: - dtypes (dict): - mapping of field names and dtypes. - pii_fields (dict): - mapping of pii field names and categories. - - Returns: - dict: - mapping of field names and transformer instances. 
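-
-        Example:
-            A minimal sketch: an ``int`` dtype yields a
-            ``NumericalTransformer(dtype=int)``, while an ``object`` dtype
-            listed in ``pii_fields`` yields an anonymized
-            ``CategoricalTransformer`` (field names here are illustrative)::
-
-                >>> result = Metadata._get_transformers(
-                ...     {'age': 'int', 'country': 'object'},
-                ...     {'country': 'country_code'})
-                >>> sorted(result.keys())
-                ['age', 'country']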
- """ - transformers_dict = dict() - for name, dtype in dtypes.items(): - dtype = np.dtype(dtype) - if dtype.kind == 'i': - transformer = transformers.NumericalTransformer(dtype=int) - elif dtype.kind == 'f': - transformer = transformers.NumericalTransformer(dtype=float) - elif dtype.kind == 'O': - anonymize = pii_fields.get(name) - transformer = transformers.CategoricalTransformer(anonymize=anonymize) - elif dtype.kind == 'b': - transformer = transformers.BooleanTransformer() - elif dtype.kind == 'M': - transformer = transformers.DatetimeTransformer() - else: - raise ValueError('Unsupported dtype: {}'.format(dtype)) - - LOGGER.info('Loading transformer %s for field %s', - transformer.__class__.__name__, name) - transformers_dict[name] = transformer - - return transformers_dict - - def _load_hyper_transformer(self, table_name): - """Create and return a new ``rdt.HyperTransformer`` instance for a table. - - First get the ``dtypes`` and ``pii fields`` from a given table, then use - those to build a transformer dictionary to be used by the ``HyperTransformer``. - - Args: - table_name (str): - Table name for which to load the HyperTransformer. - - Returns: - rdt.HyperTransformer: - Instance of ``rdt.HyperTransformer`` for the given table. - """ - dtypes = self.get_dtypes(table_name) - pii_fields = self._get_pii_fields(table_name) - transformers_dict = self._get_transformers(dtypes, pii_fields) - return HyperTransformer(transformers=transformers_dict) - - def transform(self, table_name, data): - """Transform data for a given table. - - If the ``HyperTransformer`` for a table is ``None`` it is created. - - Args: - table_name (str): - Name of the table that is being transformer. - data (pandas.DataFrame): - Table data. - - Returns: - pandas.DataFrame: - Transformed data. - """ - hyper_transformer = self._hyper_transformers.get(table_name) - if hyper_transformer is None: - hyper_transformer = self._load_hyper_transformer(table_name) - fields = list(hyper_transformer.transformers.keys()) - hyper_transformer.fit(data[fields]) - self._hyper_transformers[table_name] = hyper_transformer - - hyper_transformer = self._hyper_transformers.get(table_name) - fields = list(hyper_transformer.transformers.keys()) - return hyper_transformer.transform(data[fields]) - - def reverse_transform(self, table_name, data): - """Reverse the transformed data for a given table. - - Args: - table_name (str): - Name of the table to reverse transform. - data (pandas.DataFrame): - Data to be reversed. - - Returns: - pandas.DataFrame - """ - hyper_transformer = self._hyper_transformers[table_name] - reversed_data = hyper_transformer.reverse_transform(data) - - for name, dtype in self.get_dtypes(table_name, ids=True).items(): - reversed_data[name] = reversed_data[name].dropna().astype(dtype) - - return reversed_data - - # ################### # - # Metadata Validation # - # ################### # - - def _validate_table(self, table_name, table_meta, table_data=None): - """Validate table metadata. - - Validate the type and subtype combination for each field in ``table_meta``. - If a field has type ``id``, validate that it either is the ``primary_key`` or - has a ``ref`` entry. - - If the table has ``primary_key``, make sure that the corresponding field exists - and its type is ``id``. - - If ``table_data`` is provided, also check that the list of columns corresponds - to the ones indicated in the metadata and that all the dtypes are valid. - - Args: - table_name (str): - Name of the table to validate. 
-            table_meta (dict):
-                Metadata of the table to validate.
-            table_data (pandas.DataFrame):
-                If provided, make sure that the data matches the one described
-                in the metadata.
-
-        Raises:
-            MetadataError:
-                If there is any error in the metadata or the data does not
-                match the metadata description.
-        """
-        dtypes = self.get_dtypes(table_name, ids=True)
-
-        # Primary key field exists and its type is 'id'
-        primary_key_fields = table_meta.get('primary_key', [])
-        for primary_key in primary_key_fields:
-            pk_field = table_meta['fields'].get(primary_key)
-
-            if not pk_field:
-                raise MetadataError('Primary key is not an existing field.')
-
-            if pk_field['type'] != 'id':
-                raise MetadataError('Primary key is not of type `id`.')
-
-        if table_data is not None:
-            for column in table_data:
-                try:
-                    dtype = dtypes.pop(column)
-                    table_data[column].dropna().astype(dtype)
-                except KeyError:
-                    message = 'Unexpected column in table `{}`: `{}`'.format(table_name, column)
-                    raise MetadataError(message) from None
-                except ValueError as ve:
-                    message = 'Invalid values found in column `{}` of table `{}`: `{}`'.format(
-                        column, table_name, ve)
-                    raise MetadataError(message) from None
-
-            # assert all dtypes are in data
-            if dtypes:
-                raise MetadataError(
-                    'Missing columns on table {}: {}.'.format(table_name, list(dtypes.keys()))
-                )
-
-    def _validate_circular_relationships(self, parent, children=None):
-        """Validate that there is no circular relationship in the metadata."""
-        if children is None:
-            children = self.get_children(parent)
-
-        if parent in children:
-            raise MetadataError('Circular relationship found for table "{}"'.format(parent))
-
-        for child in children:
-            self._validate_circular_relationships(parent, self.get_children(child))
-
-    def validate(self, tables=None):
-        """Validate this metadata.
-
-        For each table in the metadata ``tables`` entry:
-        * Validate that the table metadata is correct.
-
-            * If ``tables`` are provided or they have been loaded, check
-              that all the metadata tables exist in the ``tables`` dictionary.
-            * Validate the type/subtype combination for each field and,
-              if a field of type ``id`` exists, it must be the ``primary_key``
-              or must have a ``ref`` entry.
-            * If a ``primary_key`` entry exists, check that it is an existing
-              field and its type is ``id``.
-            * If ``tables`` are provided or they have been loaded, check that
-              all the data types for the table correspond to each column and
-              that all the data types exist on the table.
-        * Validate that there is no circular relationship in the metadata.
-        * Check that all the tables have at most one parent.
-
-        Args:
-            tables (bool, dict):
-                If a dict of tables is passed, validate that the columns and
-                dtypes match the metadata. If ``True`` is passed, load the
-                tables from the Metadata instead. If ``None``, omit the data
-                validation. Defaults to ``None``.
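-
-        Example:
-            A hedged sketch, assuming ``tables`` is a dict of
-            ``pandas.DataFrame`` instances that matches this metadata::
-
-                >>> metadata.validate(tables)  # silent on success, MetadataError otherwise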
- """ - tables_meta = self._metadata.get('tables') - if not tables_meta: - raise MetadataError('"tables" entry not found in Metadata.') - - if tables and not isinstance(tables, dict): - tables = self.load_tables() - - for table_name, table_meta in tables_meta.items(): - if tables: - table = tables.get(table_name) - if table is None: - raise MetadataError('Table `{}` not found in tables'.format(table_name)) - - else: - table = None - - self._validate_table(table_name, table_meta, table) - self._validate_circular_relationships(table_name) - - def _check_field(self, table, field, exists=False): - """Validate the existance of the table and existance (or not) of field.""" - table_fields = self.get_fields(table) - if exists and (field not in table_fields): - raise ValueError('Field "{}" does not exist in table "{}"'.format(field, table)) - - if not exists and (field in table_fields): - raise ValueError('Field "{}" already exists in table "{}"'.format(field, table)) - - # ################# # - # Metadata Creation # - # ################# # - - def add_field(self, table, field, field_type, field_subtype=None, properties=None): - """Add a new field to the indicated table. - - Args: - table (str): - Table name to add the new field, it must exist. - field (str): - Field name to be added, it must not exist. - field_type (str): - Data type of field to be added. Required. - field_subtype (str): - Data subtype of field to be added. Optional. - Defaults to ``None``. - properties (dict): - Extra properties of field like: ref, format, min, max, etc. Optional. - Defaults to ``None``. - - Raises: - ValueError: - If the table does not exist or it already contains the field. - """ - self._check_field(table, field, exists=False) - - field_details = { - 'type': field_type - } - - if field_subtype: - field_details['subtype'] = field_subtype - - if properties: - field_details.update(properties) - - self._metadata['tables'][table]['fields'][field] = field_details - - @staticmethod - def _get_key_subtype(field_meta): - """Get the appropriate key subtype.""" - field_type = field_meta['type'] - if field_type == 'categorical': - field_subtype = 'string' - elif field_type in ('numerical', 'id'): - field_subtype = field_meta['subtype'] - if field_subtype not in ('integer', 'string'): - raise ValueError( - 'Invalid field "subtype" for key field: "{}"'.format(field_subtype) - ) - else: - raise ValueError( - 'Invalid field "type" for key field: "{}"'.format(field_type) - ) - - return field_subtype - - def set_primary_key(self, table, field): - """Set the primary key field of the indicated table. - - The field must exist and either be an integer or categorical field. - - Args: - table (str): - Name of the table where the primary key will be set. - field (str): - Field to be used as the new primary key. - - Raises: - ValueError: - If the table or the field do not exist or if the field has an - invalid type or subtype. - """ - self._check_field(table, field, exists=True) - - field_meta = self.get_fields(table).get(field) - field_subtype = self._get_key_subtype(field_meta) - - table_meta = self._metadata['tables'][table] - table_meta['fields'][field] = { - 'type': 'id', - 'subtype': field_subtype - } - table_meta['primary_key'] = field - - def add_relationship(self, parent, child, foreign_key=None): - """Add a new relationship between the parent and child tables. - - The relationship is created by adding a reference (``ref``) on the ``foreign_key`` - field of the ``child`` table pointing at the ``parent`` primary key. 
- - Args: - parent (str): - Name of the parent table. - child (str): - Name of the child table. - foreign_key (str): - Field in the child table through which the relationship is created. - If ``None``, use the parent primary key name. - - Raises: - ValueError: - If any of the following happens: - * The parent table does not exist. - * The child table does not exist. - * The parent table does not have a primary key. - * The foreign_key field already exists in the child table. - * The child table already has a parent. - * The new relationship closes a relationship circle. - """ - # Validate table and field names - primary_key = self.get_primary_key(parent) - if not primary_key: - raise ValueError('Parent table "{}" does not have a primary key'.format(parent)) - - if foreign_key is None: - foreign_key = primary_key - - # Validate relationships - if self.get_parents(child): - raise ValueError('Table "{}" already has a parent'.format(child)) - - grandchildren = self.get_children(child) - if grandchildren: - self._validate_circular_relationships(parent, grandchildren) - - # Copy primary key details over to the foreign key - foreign_key_details = copy.deepcopy(self.get_fields(parent)[primary_key]) - foreign_key_details['ref'] = { - 'table': parent, - 'field': primary_key - } - - # Make sure that key subtypes are the same - foreign_meta = self.get_fields(child).get(foreign_key) - if foreign_meta: - foreign_subtype = self._get_key_subtype(foreign_meta) - if foreign_subtype != foreign_key_details['subtype']: - raise ValueError('Primary and Foreign key subtypes mismatch') - - self._metadata['tables'][child]['fields'][foreign_key] = foreign_key_details - - # Re-analyze the relationships - self._analyze_relationships() - - def _get_field_details(self, data, fields): - """Get or build all the fields metadata. - - Analyze a ``pandas.DataFrame`` to build a ``dict`` with the name of the column, and - their data type and subtype. If ``columns`` are provided, only those columns will be - analyzed. - - Args: - data (pandas.DataFrame): - Table to be analyzed. - fields (set): - Set of field names or field specifications. - - Returns: - dict: - Dict of valid fields. - - Raises: - TypeError: - If a field specification is not a str or a dict. - ValueError: - If a column from the data analyzed is an unsupported data type or - """ - fields_metadata = dict() - for field in fields: - dtype = data[field].dtype - field_template = self._FIELD_TEMPLATES.get(dtype.kind) - if not field_template: - raise ValueError('Unsupported dtype {} in column {}'.format(dtype, field)) - - field_details = copy.deepcopy(field_template) - fields_metadata[field] = field_details - - return fields_metadata - - def add_table(self, name, data=None, fields=None, fields_metadata=None, - primary_key=None, parent=None, foreign_key=None): - """Add a new table to this metadata. - - ``fields`` list can be a mixture of field names, which will be build automatically - from the data, or dictionaries specifying the field details. If a field needs to be - analyzed, data has to be also passed. - - If ``parent`` is given, a relationship will be established between this table - and the specified parent. - - Args: - name (str): - Name of the new table. - data (str or pandas.DataFrame): - Table to be analyzed or path to the csv file. - If it's a relative path, use ``root_path`` to find the file. - Only used if fields is not ``None``. - Defaults to ``None``. - fields (list): - List of field names to build. 
If ``None`` is given, all the fields
-                found in the data will be used.
-                Defaults to ``None``.
-            fields_metadata (dict):
-                Metadata to be used when creating fields. This will overwrite the
-                metadata built from the fields found in data.
-                Defaults to ``None``.
-            primary_key (str):
-                Field name to add as primary key, it must not exist. Defaults to ``None``.
-            parent (str):
-                Name of the parent table that the foreign key field will reference.
-                Defaults to ``None``.
-            foreign_key (str):
-                Foreign key field name that references the ``parent`` table primary key.
-                Defaults to ``None``.
-
-        Raises:
-            ValueError:
-                If the table ``name`` already exists or ``data`` is not passed and
-                fields need to be built from it.
-        """
-        if name in self.get_tables():
-            raise ValueError('Table "{}" already exists.'.format(name))
-
-        path = None
-        if data is not None:
-            if isinstance(data, str):
-                path = data
-                if not os.path.isabs(data):
-                    data = os.path.join(self.root_path, data)
-
-                data = pd.read_csv(data)
-
-            fields = set(fields or data.columns)
-            if fields_metadata:
-                fields = fields - set(fields_metadata.keys())
-            else:
-                fields_metadata = dict()
-
-            fields_metadata.update(self._get_field_details(data, fields))
-
-        elif fields_metadata is None:
-            fields_metadata = dict()
-
-        table_metadata = {'fields': fields_metadata}
-        if path:
-            table_metadata['path'] = path
-
-        self._metadata['tables'][name] = table_metadata
-
-        try:
-            if primary_key:
-                self.set_primary_key(name, primary_key)
-
-            if parent:
-                self.add_relationship(parent, name, foreign_key)
-
-        except ValueError:
-            # Cleanup
-            del self._metadata['tables'][name]
-            raise
-
-    # ###################### #
-    # Metadata Serialization #
-    # ###################### #
-
-    def to_dict(self):
-        """Get a dict representation of this metadata.
-
-        Returns:
-            dict:
-                dict representation of this metadata.
-        """
-        return copy.deepcopy(self._metadata)
-
-    def to_json(self, path):
-        """Dump this metadata into a JSON file.
-
-        Args:
-            path (str):
-                Path of the JSON file where this metadata will be stored.
-        """
-        with open(path, 'w') as out_file:
-            json.dump(self._metadata, out_file, indent=4)
-
-    @staticmethod
-    def _get_graphviz_extension(path):
-        if path:
-            path_splitted = path.split('.')
-            if len(path_splitted) == 1:
-                raise ValueError('Path without graphviz extension.')
-
-            graphviz_extension = path_splitted[-1]
-
-            if graphviz_extension not in graphviz.backend.FORMATS:
-                raise ValueError(
-                    '"{}" not a valid graphviz extension format.'.format(graphviz_extension)
-                )
-
-            return '.'.join(path_splitted[:-1]), graphviz_extension
-
-        return None, None
-
-    def _visualize_add_nodes(self, plot):
-        """Add nodes into a `graphviz.Digraph`.
-
-        Each node represents a metadata table.
-
-        Args:
-            plot (graphviz.Digraph)
-        """
-        for table in self.get_tables():
-            # Append table fields
-            fields = []
-
-            for name, value in self.get_fields(table).items():
-                if value.get('subtype') is not None:
-                    fields.append('{} : {} - {}'.format(name, value['type'], value['subtype']))
-
-                else:
-                    fields.append('{} : {}'.format(name, value['type']))
-
-            fields = r'\l'.join(fields)
-
-            # Append table extra information
-            extras = []
-
-            primary_key = self.get_primary_key(table)
-            if primary_key is not None:
-                extras.append('Primary key: {}'.format(primary_key))
-
-            parents = self.get_parents(table)
-            for parent in parents:
-                foreign_key = self.get_foreign_key(parent, table)
-                extras.append('Foreign key ({}): {}'.format(parent, foreign_key))
-
-            path = self.get_table_meta(table).get('path')
-            if path is not None:
-                extras.append('Data path: {}'.format(path))
-
-            extras = r'\l'.join(extras)
-
-            # Add table node
-            title = r'{%s|%s\l|%s\l}' % (table, fields, extras)
-            plot.node(table, label=title)
-
-    def _visualize_add_edges(self, plot):
-        """Add edges into a `graphviz.Digraph`.
-
-        Each edge represents a relationship between two metadata tables.
-
-        Args:
-            plot (graphviz.Digraph)
-        """
-        for table in self.get_tables():
-            for parent in list(self.get_parents(table)):
-                plot.edge(
-                    parent,
-                    table,
-                    label='   {}.{} -> {}.{}'.format(
-                        table, self.get_foreign_key(parent, table),
-                        parent, self.get_primary_key(parent)
-                    ),
-                    arrowhead='crow'
-                )
-
-    def visualize(self, path=None):
-        """Plot metadata using graphviz.
-
-        Try to generate a plot using graphviz.
-        If a ``path`` is provided, save the output into a file.
-
-        Args:
-            path (str):
-                Output file path to save the plot; it requires a graphviz
-                supported extension. If ``None``, do not save the plot.
-                Defaults to ``None``.
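-
-        Example:
-            An illustrative call that renders the graph into ``metadata.png``
-            (any graphviz-supported extension works)::
-
-                >>> metadata.visualize('metadata.png')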
- """ - filename, graphviz_extension = self._get_graphviz_extension(path) - plot = graphviz.Digraph( - 'Metadata', - format=graphviz_extension, - node_attr={ - "shape": "Mrecord", - "fillcolor": "lightgoldenrod1", - "style": "filled" - }, - ) - - self._visualize_add_nodes(plot) - self._visualize_add_edges(plot) - - if filename: - plot.render(filename=filename, cleanup=True, format=graphviz_extension) - else: - return plot - - def __str__(self): - tables = self.get_tables() - relationships = [ - ' {}.{} -> {}.{}'.format( - table, self.get_foreign_key(parent, table), - parent, self.get_primary_key(parent) - ) - for table in tables - for parent in list(self.get_parents(table)) - ] - - return ( - "Metadata\n" - " root_path: {}\n" - " tables: {}\n" - " relationships:\n" - "{}" - ).format( - os.path.abspath(self.root_path), - tables, - '\n'.join(relationships) - ) diff --git a/sdv/metadata.py.non_primary b/sdv/metadata.py.non_primary deleted file mode 100644 index 2df337bd2..000000000 --- a/sdv/metadata.py.non_primary +++ /dev/null @@ -1,1109 +0,0 @@ -import copy -import json -import logging -import os -from collections import defaultdict - -import graphviz -import numpy as np -import pandas as pd -from rdt import HyperTransformer, transformers - -LOGGER = logging.getLogger(__name__) - - -def _read_csv_dtypes(table_meta): - """Get the dtypes specification that needs to be passed to read_csv.""" - dtypes = dict() - for name, field in table_meta['fields'].items(): - field_type = field['type'] - if field_type == 'categorical': - dtypes[name] = str - elif field_type == 'id' and field.get('subtype', 'integer') == 'string': - dtypes[name] = str - - return dtypes - - -def _parse_dtypes(data, table_meta): - """Convert the data columns to the right dtype after loading the CSV.""" - for name, field in table_meta['fields'].items(): - field_type = field['type'] - if field_type == 'datetime': - datetime_format = field.get('format') - data[name] = pd.to_datetime(data[name], format=datetime_format, exact=False) - elif field_type == 'numerical' and field.get('subtype') == 'integer': - data[name] = data[name].dropna().astype(int) - elif field_type == 'id' and field.get('subtype', 'integer') == 'integer': - data[name] = data[name].dropna().astype(int) - - return data - - -def _load_csv(root_path, table_meta): - """Load a CSV with the right dtypes and then parse the columns.""" - relative_path = os.path.join(root_path, table_meta['path']) - dtypes = _read_csv_dtypes(table_meta) - - data = pd.read_csv(relative_path, dtype=dtypes) - data = _parse_dtypes(data, table_meta) - - return data - - -class MetadataError(Exception): - pass - - -class Metadata: - """Dataset Metadata. - - The Metadata class provides a unified layer of abstraction over the dataset - metadata, which includes both the necessary details to load the data from - the hdd and to know how to parse and transform it to numerical data. - - Args: - metadata (str or dict): - Path to a ``json`` file that contains the metadata or a ``dict`` representation - of ``metadata`` following the same structure. - - root_path (str): - The path where the ``metadata.json`` is located. Defaults to ``None``. 
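-
-    Example:
-        A minimal sketch, assuming a ``metadata.json`` file that describes the
-        demo ``users``, ``sessions`` and ``transactions`` tables::
-
-            >>> metadata = Metadata('metadata.json')
-            >>> metadata.get_tables()
-            ['users', 'sessions', 'transactions']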
- """ - - _child_map = None - _hyper_transformers = None - _metadata = None - _parent_map = None - - root_path = None - - _FIELD_TEMPLATES = { - 'i': { - 'type': 'numerical', - 'subtype': 'integer', - }, - 'f': { - 'type': 'numerical', - 'subtype': 'float', - }, - 'O': { - 'type': 'categorical', - }, - 'b': { - 'type': 'boolean', - }, - 'M': { - 'type': 'datetime', - } - } - _DTYPES = { - ('categorical', None): 'object', - ('boolean', None): 'bool', - ('numerical', None): 'float', - ('numerical', 'float'): 'float', - ('numerical', 'integer'): 'int', - ('datetime', None): 'datetime64', - ('id', None): 'int', - ('id', 'integer'): 'int', - ('id', 'string'): 'str' - } - - def _analyze_relationships(self): - """Extract information about child-parent relationships. - - Creates the following attributes: - * ``_child_map``: set of child tables that each table has. - * ``_parent_map``: set ot parents that each table has. - """ - self._child_map = defaultdict(set) - self._parent_map = defaultdict(set) - - for table, table_meta in self._metadata['tables'].items(): - if table_meta.get('use', True): - for field_meta in table_meta['fields'].values(): - ref = field_meta.get('ref') - if ref: - parent = ref['table'] - self._child_map[parent].add(table) - self._parent_map[table].add(parent) - - @staticmethod - def _dict_metadata(metadata): - """Get a metadata ``dict`` with SDV format. - - For each table create a dict of fields from a previous list of fields. - - Args: - metadata (dict): - Original metadata to format. - - Returns: - dict: - Formated metadata dict. - """ - new_metadata = copy.deepcopy(metadata) - tables = new_metadata['tables'] - if isinstance(tables, dict): - new_metadata['tables'] = { - table: meta - for table, meta in tables.items() - if meta.pop('use', True) - } - return new_metadata - - new_tables = dict() - for table in tables: - if table.pop('use', True): - new_tables[table.pop('name')] = table - - fields = table['fields'] - new_fields = dict() - for field in fields: - new_fields[field.pop('name')] = field - - table['fields'] = new_fields - - new_metadata['tables'] = new_tables - - return new_metadata - - def __init__(self, metadata=None, root_path=None): - if isinstance(metadata, str): - self.root_path = root_path or os.path.dirname(metadata) - with open(metadata) as metadata_file: - metadata = json.load(metadata_file) - else: - self.root_path = root_path or '.' - - if metadata is not None: - self._metadata = self._dict_metadata(metadata) - else: - self._metadata = {'tables': {}} - - self._hyper_transformers = dict() - self._analyze_relationships() - - def get_children(self, table_name): - """Get tables for which the given table is parent. - - Args: - table_name (str): - Name of the table from which to get the children. - - Returns: - set: - Set of children for the given table. - """ - return self._child_map[table_name] - - def get_parents(self, table_name): - """Get tables for with the given table is child. - - Args: - table_name (str): - Name of the table from which to get the parents. - - Returns: - set: - Set of parents for the given table. - """ - return self._parent_map[table_name] - - def get_table_meta(self, table_name): - """Get the metadata dict for a table. - - Args: - table_name (str): - Name of table to get data for. - - Returns: - dict: - table metadata - - Raises: - ValueError: - If table does not exist in this metadata. 
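-
-        Example:
-            An illustrative call, assuming the demo ``users`` table::
-
-                >>> metadata.get_table_meta('users')['primary_key']
-                'user_id'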
- """ - table = self._metadata['tables'].get(table_name) - if table is None: - raise ValueError('Table "{}" does not exist'.format(table_name)) - - return copy.deepcopy(table) - - def get_tables(self): - """Get the list of table names. - - Returns: - list: - table names. - """ - return list(self._metadata['tables'].keys()) - - def get_field_meta(self, table_name, field_name): - """Get the metadata dict for a table. - - Args: - table_name (str): - Name of the table to which the field belongs. - field_name (str): - Name of the field to get data for. - - Returns: - dict: - field metadata - - Raises: - ValueError: - If the table or the field do not exist in this metadata. - """ - field_meta = self.get_fields(table_name).get(field_name) - if field_meta is None: - raise ValueError( - 'Table "{}" does not contain a field name "{}"'.format(table_name, field_name)) - - return copy.deepcopy(field_meta) - - def get_fields(self, table_name): - """Get table fields metadata. - - Args: - table_name (str): - Name of the table to get the fields from. - - Returns: - dict: - Mapping of field names and their metadata dicts. - - Raises: - ValueError: - If table does not exist in this metadata. - """ - return self.get_table_meta(table_name)['fields'] - - def get_primary_key(self, table_name): - """Get the primary key name of the indicated table. - - Args: - table_name (str): - Name of table for which to get the primary key field. - - Returns: - str or None: - Primary key field name. ``None`` if the table has no primary key. - - Raises: - ValueError: - If table does not exist in this metadata. - """ - return self.get_table_meta(table_name).get('primary_key') - - def get_foreign_key(self, parent, child): - """Get table foreign key field name. - - Args: - parent (str): - Name of the parent table. - child (str): - Name of the child table. - - Returns: - str or None: - Foreign key field name. - - Raises: - ValueError: - If the relationship does not exist. - """ - for name, field in self.get_fields(child).items(): - ref = field.get('ref') - if ref and ref['table'] == parent: - return name - - raise ValueError('{} is not parent of {}'.format(parent, child)) - - def load_table(self, table_name): - """Load table data. - - Args: - table_name (str): - Name of the table to load. - - Returns: - pandas.DataFrame: - DataFrame with the contents of the table. - - Raises: - ValueError: - If table does not exist in this metadata. - """ - LOGGER.info('Loading table %s', table_name) - table_meta = self.get_table_meta(table_name) - return _load_csv(self.root_path, table_meta) - - def load_tables(self, tables=None): - """Get a dictionary with data from multiple tables. - - If a ``tables`` list is given, only load the indicated tables. - Otherwise, load all the tables from this metadata. - - Args: - tables (list): - List of table names. Defaults to ``None``. - - Returns: - dict(str, pandasd.DataFrame): - mapping of table names and their data loaded as ``pandas.DataFrame`` instances. - """ - return { - table_name: self.load_table(table_name) - for table_name in tables or self.get_tables() - } - - def get_dtypes(self, table_name, ids=False): - """Get a ``dict`` with the ``dtypes`` for each field of a given table. - - Args: - table_name (str): - Table name for which to retrive the ``dtypes``. - ids (bool): - Whether or not include the id fields. Defaults to ``False``. - - Returns: - dict: - Dictionary that contains the field names and data types from a table. 
- - Raises: - ValueError: - If a field has an invalid type or subtype or if the table does not - exist in this metadata. - """ - dtypes = dict() - table_meta = self.get_table_meta(table_name) - for name, field in table_meta['fields'].items(): - field_type = field['type'] - field_subtype = field.get('subtype') - dtype = self._DTYPES.get((field_type, field_subtype)) - if not dtype: - raise MetadataError( - 'Invalid type and subtype combination for field {}: ({}, {})'.format( - name, field_type, field_subtype) - ) - - if ids and field_type == 'id': - if (name != table_meta.get('primary_key')) and not field.get('ref'): - for child_table in self.get_children(table_name): - if name == self.get_foreign_key(table_name, child_table): - break - - else: - raise MetadataError( - 'id field `{}` is neither a primary or a foreign key'.format(name)) - - if ids or (field_type != 'id'): - dtypes[name] = dtype - - return dtypes - - def _get_pii_fields(self, table_name): - """Get the ``pii_category`` for each field that contains PII. - - Args: - table_name (str): - Table name for which to get the pii fields. - - Returns: - dict: - pii field names and categories. - """ - pii_fields = dict() - for name, field in self.get_table_meta(table_name)['fields'].items(): - if field['type'] == 'categorical' and field.get('pii', False): - pii_fields[name] = field['pii_category'] - - return pii_fields - - @staticmethod - def _get_transformers(dtypes, pii_fields): - """Create the transformer instances needed to process the given dtypes. - - Temporary drop-in replacement of ``HyperTransformer._analyze`` method, - before RDT catches up. - - Args: - dtypes (dict): - mapping of field names and dtypes. - pii_fields (dict): - mapping of pii field names and categories. - - Returns: - dict: - mapping of field names and transformer instances. - """ - transformers_dict = dict() - for name, dtype in dtypes.items(): - dtype = np.dtype(dtype) - if dtype.kind == 'i': - transformer = transformers.NumericalTransformer(dtype=int) - elif dtype.kind == 'f': - transformer = transformers.NumericalTransformer(dtype=float) - elif dtype.kind == 'O': - anonymize = pii_fields.get(name) - transformer = transformers.CategoricalTransformer(anonymize=anonymize) - elif dtype.kind == 'b': - transformer = transformers.BooleanTransformer() - elif dtype.kind == 'M': - transformer = transformers.DatetimeTransformer() - else: - raise ValueError('Unsupported dtype: {}'.format(dtype)) - - LOGGER.info('Loading transformer %s for field %s', - transformer.__class__.__name__, name) - transformers_dict[name] = transformer - - return transformers_dict - - def _load_hyper_transformer(self, table_name): - """Create and return a new ``rdt.HyperTransformer`` instance for a table. - - First get the ``dtypes`` and ``pii fields`` from a given table, then use - those to build a transformer dictionary to be used by the ``HyperTransformer``. - - Args: - table_name (str): - Table name for which to load the HyperTransformer. - - Returns: - rdt.HyperTransformer: - Instance of ``rdt.HyperTransformer`` for the given table. - """ - dtypes = self.get_dtypes(table_name) - pii_fields = self._get_pii_fields(table_name) - transformers_dict = self._get_transformers(dtypes, pii_fields) - return HyperTransformer(transformers=transformers_dict) - - def transform(self, table_name, data): - """Transform data for a given table. - - If the ``HyperTransformer`` for a table is ``None`` it is created. - - Args: - table_name (str): - Name of the table that is being transformer. 
-            data (pandas.DataFrame):
-                Table data.
-
-        Returns:
-            pandas.DataFrame:
-                Transformed data.
-        """
-        hyper_transformer = self._hyper_transformers.get(table_name)
-        if hyper_transformer is None:
-            hyper_transformer = self._load_hyper_transformer(table_name)
-            fields = list(hyper_transformer.transformers.keys())
-            hyper_transformer.fit(data[fields])
-            self._hyper_transformers[table_name] = hyper_transformer
-
-        fields = list(hyper_transformer.transformers.keys())
-        return hyper_transformer.transform(data[fields])
-
-    def reverse_transform(self, table_name, data):
-        """Reverse the transformed data for a given table.
-
-        Args:
-            table_name (str):
-                Name of the table to reverse transform.
-            data (pandas.DataFrame):
-                Data to be reversed.
-
-        Returns:
-            pandas.DataFrame:
-                Reversed data.
-        """
-        hyper_transformer = self._hyper_transformers[table_name]
-        reversed_data = hyper_transformer.reverse_transform(data)
-
-        for name, dtype in self.get_dtypes(table_name, ids=True).items():
-            reversed_data[name] = reversed_data[name].dropna().astype(dtype)
-
-        return reversed_data
-
-    # ################### #
-    # Metadata Validation #
-    # ################### #
-
-    def _validate_table(self, table_name, table_meta, table_data=None):
-        """Validate table metadata.
-
-        Validate the type and subtype combination for each field in ``table_meta``.
-        If a field has type ``id``, validate that it either is the ``primary_key`` or
-        has a ``ref`` entry.
-
-        If the table has a ``primary_key``, make sure that the corresponding field exists
-        and its type is ``id``.
-
-        If ``table_data`` is provided, also check that the list of columns corresponds
-        to the ones indicated in the metadata and that all the dtypes are valid.
-
-        Args:
-            table_name (str):
-                Name of the table to validate.
-            table_meta (dict):
-                Metadata of the table to validate.
-            table_data (pandas.DataFrame):
-                If provided, make sure that the data matches the description
-                given in the metadata.
-
-        Raises:
-            MetadataError:
-                If there is any error in the metadata or the data does not
-                match the metadata description.
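-
-        Example:
-            A minimal sketch of how a mismatch surfaces, assuming a hypothetical
-            ``users`` table whose metadata does not describe a ``signup`` column::
-
-                >>> users = pd.DataFrame({'user_id': [1, 2], 'signup': ['a', 'b']})
-                >>> meta = metadata.get_table_meta('users')
-                >>> metadata._validate_table('users', meta, users)
-                Traceback (most recent call last):
-                    ...
-                MetadataError: Unexpected column in table `users`: `signup`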
-        """
-        dtypes = self.get_dtypes(table_name, ids=True)
-
-        # The primary key field must exist and be of type 'id'
-        primary_key = table_meta.get('primary_key')
-        if primary_key:
-            pk_field = table_meta['fields'].get(primary_key)
-
-            if not pk_field:
-                raise MetadataError('Primary key is not an existing field.')
-
-            if pk_field['type'] != 'id':
-                raise MetadataError('Primary key is not of type `id`.')
-
-        if table_data is not None:
-            for column in table_data:
-                try:
-                    dtype = dtypes.pop(column)
-                    table_data[column].dropna().astype(dtype)
-                except KeyError:
-                    message = 'Unexpected column in table `{}`: `{}`'.format(table_name, column)
-                    raise MetadataError(message) from None
-                except ValueError as ve:
-                    message = 'Invalid values found in column `{}` of table `{}`: `{}`'.format(
-                        column, table_name, ve)
-                    raise MetadataError(message) from None
-
-            # Any dtype that is left over points at a column missing from the data
-            if dtypes:
-                raise MetadataError(
-                    'Missing columns on table {}: {}.'.format(table_name, list(dtypes.keys()))
-                )
-
-    def _validate_circular_relationships(self, parent, children=None):
-        """Validate that there is no circular relationship in the metadata."""
-        if children is None:
-            children = self.get_children(parent)
-
-        if parent in children:
-            raise MetadataError('Circular relationship found for table "{}"'.format(parent))
-
-        for child in children:
-            self._validate_circular_relationships(parent, self.get_children(child))
-
-    def validate(self, tables=None):
-        """Validate this metadata.
-
-        For each table in the metadata ``tables`` entry:
-          * Validate that the table metadata is correct:
-
-          * If ``tables`` are provided or they have been loaded, check
-            that all the metadata tables exist in the ``tables`` dictionary.
-          * Validate the type/subtype combination for each field; if a field
-            of type ``id`` exists, it must be the ``primary_key`` or must
-            have a ``ref`` entry.
-          * If a ``primary_key`` entry exists, check that it is an existing
-            field and that its type is ``id``.
-          * If ``tables`` are provided or they have been loaded, check that
-            the data type of each column matches the metadata and that no
-            column described in the metadata is missing from the table.
-          * Validate that there is no circular relationship in the metadata.
-          * Check that all the tables have at most one parent.
-
-        Args:
-            tables (bool, dict):
-                If a dict of tables is passed, validate that the columns and
-                dtypes match the metadata. If ``True`` is passed, load the
-                tables from the Metadata instead. If ``None``, omit the data
-                validation. Defaults to ``None``.
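-
-        Example:
-            A minimal usage sketch; passing ``True`` loads the CSV files declared
-            in the metadata and checks them against it::
-
-                >>> metadata.validate(tables=True)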
-        """
-        tables_meta = self._metadata.get('tables')
-        if not tables_meta:
-            raise MetadataError('"tables" entry not found in Metadata.')
-
-        if tables and not isinstance(tables, dict):
-            tables = self.load_tables()
-
-        for table_name, table_meta in tables_meta.items():
-            if tables:
-                table = tables.get(table_name)
-                if table is None:
-                    raise MetadataError('Table `{}` not found in tables'.format(table_name))
-
-            else:
-                table = None
-
-            self._validate_table(table_name, table_meta, table)
-            self._validate_circular_relationships(table_name)
-
-    def _check_field(self, table, field, exists=False):
-        """Validate the existence of the table and the existence (or not) of the field."""
-        table_fields = self.get_fields(table)
-        if exists and (field not in table_fields):
-            raise ValueError('Field "{}" does not exist in table "{}"'.format(field, table))
-
-        if not exists and (field in table_fields):
-            raise ValueError('Field "{}" already exists in table "{}"'.format(field, table))
-
-    # ################# #
-    # Metadata Creation #
-    # ################# #
-
-    def add_field(self, table, field, field_type, field_subtype=None, properties=None):
-        """Add a new field to the indicated table.
-
-        Args:
-            table (str):
-                Name of the table where the new field will be added; it must exist.
-            field (str):
-                Name of the field to be added; it must not exist yet.
-            field_type (str):
-                Data type of the field to be added. Required.
-            field_subtype (str):
-                Data subtype of the field to be added. Optional.
-                Defaults to ``None``.
-            properties (dict):
-                Extra properties of the field, such as: ref, format, min, max, etc. Optional.
-                Defaults to ``None``.
-
-        Raises:
-            ValueError:
-                If the table does not exist or it already contains the field.
-        """
-        self._check_field(table, field, exists=False)
-
-        field_details = {
-            'type': field_type
-        }
-
-        if field_subtype:
-            field_details['subtype'] = field_subtype
-
-        if properties:
-            field_details.update(properties)
-
-        self._metadata['tables'][table]['fields'][field] = field_details
-
-    @staticmethod
-    def _get_key_subtype(field_meta):
-        """Get the appropriate key subtype."""
-        field_type = field_meta['type']
-
-        if field_type == 'categorical':
-            field_subtype = 'string'
-
-        elif field_type in ('numerical', 'id'):
-            field_subtype = field_meta['subtype']
-            if field_subtype not in ('integer', 'string'):
-                raise ValueError(
-                    'Invalid field "subtype" for key field: "{}"'.format(field_subtype)
-                )
-
-        else:
-            raise ValueError(
-                'Invalid field "type" for key field: "{}"'.format(field_type)
-            )
-
-        return field_subtype
-
-    def set_primary_key(self, table, field):
-        """Set the primary key field of the indicated table.
-
-        The field must exist and be a categorical, numerical or id field
-        with an integer or string key subtype.
-
-        Args:
-            table (str):
-                Name of the table where the primary key will be set.
-            field (str):
-                Field to be used as the new primary key.
-
-        Raises:
-            ValueError:
-                If the table or the field do not exist or if the field has an
-                invalid type or subtype.
-        """
-        self._check_field(table, field, exists=True)
-
-        field_meta = self.get_fields(table).get(field)
-        field_subtype = self._get_key_subtype(field_meta)
-
-        table_meta = self._metadata['tables'][table]
-        table_meta['fields'][field] = {
-            'type': 'id',
-            'subtype': field_subtype
-        }
-        table_meta['primary_key'] = field
-
-    def add_relationship(self, parent, child, parent_key=None, child_key=None):
-        """Add a new relationship between the parent and child tables.
-
-        The relationship is created by adding a reference (``ref``) on the ``child_key``
-        field of the ``child`` table pointing at the ``parent_key`` field from the
-        ``parent`` table.
-
-        Args:
-            parent (str):
-                Name of the parent table.
-            child (str):
-                Name of the child table.
-            parent_key (str):
-                Field in the parent table through which the relationship is created.
-                If ``None``, use the parent primary key name.
-            child_key (str):
-                Field in the child table through which the relationship is created.
-                If ``None``, use the name of the parent key.
-
-        Raises:
-            ValueError:
-                If any of the following happens:
-                    * The parent or child tables do not exist.
-                    * The parent_key or child_key fields do not exist.
-                    * The child_key is already a foreign key.
-                    * The new relationship closes a relationship circle.
-        """
-        # Validate that the tables exist
-        self.get_table_meta(parent)
-        self.get_table_meta(child)
-
-        # Validate that the fields exist
-        if parent_key is None:
-            parent_key = self.get_primary_key(parent)
-            if not parent_key:
-                msg = 'If table "{}" does not have a primary key, a `parent_key` must be given'
-                raise ValueError(msg.format(parent))
-
-        if child_key is None:
-            child_key = parent_key
-
-        parent_key_meta = copy.deepcopy(self.get_field_meta(parent, parent_key))
-        child_key_meta = copy.deepcopy(self.get_field_meta(child, child_key))
-
-        # Validate the relationships
-        child_ref = child_key_meta.get('ref')
-        if child_ref:
-            raise ValueError(
-                'Field "{}.{}" already defines a relationship'.format(child, child_key))
-
-        grandchildren = self.get_children(child)
-        if grandchildren:
-            self._validate_circular_relationships(parent, grandchildren)
-
-        # Make sure that the parent key is an id
-        if parent_key_meta['type'] != 'id':
-            parent_key_meta['subtype'] = self._get_key_subtype(parent_key_meta)
-            parent_key_meta['type'] = 'id'
-
-        # Update the child key meta
-        child_key_meta['subtype'] = self._get_key_subtype(child_key_meta)
-        child_key_meta['type'] = 'id'
-        child_key_meta['ref'] = {
-            'table': parent,
-            'field': parent_key
-        }
-
-        # Make sure that the key subtypes are the same
-        if child_key_meta['subtype'] != parent_key_meta['subtype']:
-            raise ValueError('Parent and Child key subtypes mismatch')
-
-        # Make a backup
-        metadata_backup = copy.deepcopy(self._metadata)
-
-        self._metadata['tables'][parent]['fields'][parent_key] = parent_key_meta
-        self._metadata['tables'][child]['fields'][child_key] = child_key_meta
-
-        # Re-analyze the relationships
-        self._analyze_relationships()
-
-        try:
-            self.validate()
-        except MetadataError:
-            self._metadata = metadata_backup
-            raise
-
-    def _get_field_details(self, data, fields):
-        """Get or build all the fields metadata.
-
-        Analyze a ``pandas.DataFrame`` to build a ``dict`` with the name, data type
-        and subtype of each column. If ``fields`` are provided, only those columns
-        will be analyzed.
-
-        Args:
-            data (pandas.DataFrame):
-                Table to be analyzed.
-            fields (set):
-                Set of field names or field specifications.
-
-        Returns:
-            dict:
-                Dict of valid fields.
-
-        Raises:
-            TypeError:
-                If a field specification is not a str or a dict.
-            ValueError:
-                If a column from the analyzed data has an unsupported data type.
-        """
-        fields_metadata = dict()
-        for field in fields:
-            dtype = data[field].dtype
-            field_template = self._FIELD_TEMPLATES.get(dtype.kind)
-            if not field_template:
-                raise ValueError('Unsupported dtype {} in column {}'.format(dtype, field))
-
-            field_details = copy.deepcopy(field_template)
-            fields_metadata[field] = field_details
-
-        return fields_metadata
-
-    def add_table(self, name, data=None, fields=None, fields_metadata=None,
-                  primary_key=None, parent=None, parent_key=None, foreign_key=None):
-        """Add a new table to this metadata.
-
-        The ``fields`` list can be a mixture of field names, which will be built
-        automatically from the data, and dictionaries specifying the field details.
-        If a field needs to be analyzed, the data must also be passed.
-
-        If ``parent`` is given, a relationship will be established between this table
-        and the specified parent.
-
-        Args:
-            name (str):
-                Name of the new table.
-            data (str or pandas.DataFrame):
-                Table to be analyzed or path to the csv file.
-                If it's a relative path, use ``root_path`` to find the file.
-                Only used if the fields need to be built from the data.
-                Defaults to ``None``.
-            fields (list):
-                List of field names to build. If ``None`` is given, all the fields
-                found in the data will be used.
-                Defaults to ``None``.
-            fields_metadata (dict):
-                Metadata to be used when creating fields. This will overwrite the
-                metadata built from the fields found in the data.
-                Defaults to ``None``.
-            primary_key (str):
-                Field name to add as primary key; it must not exist yet. Defaults to ``None``.
-            parent (str):
-                Name of the parent table that the new foreign key field will reference.
-                Defaults to ``None``.
-            parent_key (str):
-                Name of the field from the ``parent`` table that is pointed at by the given
-                ``foreign_key``. Defaults to the ``parent`` primary key.
-            foreign_key (str):
-                Name of the field from the added table that forms a relationship with
-                the ``parent`` table. Defaults to the same name as ``parent_key``.
-
-        Raises:
-            ValueError:
-                If the table ``name`` already exists, or if ``data`` is not passed and
-                fields need to be built from it.
-        """
-        if name in self.get_tables():
-            raise ValueError('Table "{}" already exists.'.format(name))
-
-        path = None
-        if data is not None:
-            if isinstance(data, str):
-                path = data
-                if not os.path.isabs(data):
-                    data = os.path.join(self.root_path, data)
-
-                data = pd.read_csv(data)
-
-            fields = set(fields or data.columns)
-            if fields_metadata:
-                fields = fields - set(fields_metadata.keys())
-            else:
-                fields_metadata = dict()
-
-            fields_metadata.update(self._get_field_details(data, fields))
-
-        elif fields_metadata is None:
-            fields_metadata = dict()
-
-        table_metadata = {'fields': fields_metadata}
-        if path:
-            table_metadata['path'] = path
-
-        self._metadata['tables'][name] = table_metadata
-
-        try:
-            if primary_key:
-                self.set_primary_key(name, primary_key)
-
-            if parent:
-                self.add_relationship(parent, name, parent_key, foreign_key)
-
-        except ValueError:
-            # Cleanup
-            del self._metadata['tables'][name]
-            raise
-
-    # ###################### #
-    # Metadata Serialization #
-    # ###################### #
-
-    def to_dict(self):
-        """Get a dict representation of this metadata.
-
-        Returns:
-            dict:
-                dict representation of this metadata.
-        """
-        return copy.deepcopy(self._metadata)
-
-    def to_json(self, path):
-        """Dump this metadata into a JSON file.
-
-        Args:
-            path (str):
-                Path of the JSON file where this metadata will be stored.
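-
-        Example:
-            A minimal usage sketch (the path is hypothetical)::
-
-                >>> metadata.to_json('demo_metadata.json')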
-        """
-        with open(path, 'w') as out_file:
-            json.dump(self._metadata, out_file, indent=4)
-
-    @staticmethod
-    def _get_graphviz_extension(path):
-        if path:
-            path_splitted = path.split('.')
-            if len(path_splitted) == 1:
-                raise ValueError('Path without a graphviz extension.')
-
-            graphviz_extension = path_splitted[-1]
-
-            if graphviz_extension not in graphviz.backend.FORMATS:
-                raise ValueError(
-                    '"{}" is not a valid graphviz extension format.'.format(graphviz_extension)
-                )
-
-            return '.'.join(path_splitted[:-1]), graphviz_extension
-
-        return None, None
-
-    def _visualize_add_nodes(self, plot):
-        """Add nodes into a `graphviz.Digraph`.
-
-        Each node represents a metadata table.
-
-        Args:
-            plot (graphviz.Digraph)
-        """
-        for table in self.get_tables():
-            # Append table fields
-            fields = []
-
-            for name, value in self.get_fields(table).items():
-                if value.get('subtype') is not None:
-                    fields.append('{} : {} - {}'.format(name, value['type'], value['subtype']))
-
-                else:
-                    fields.append('{} : {}'.format(name, value['type']))
-
-            fields = r'\l'.join(fields)
-
-            # Append table extra information
-            extras = []
-
-            primary_key = self.get_primary_key(table)
-            if primary_key is not None:
-                extras.append('Primary key: {}'.format(primary_key))
-
-            parents = self.get_parents(table)
-            for parent in parents:
-                foreign_key = self.get_foreign_key(parent, table)
-                extras.append('Foreign key ({}): {}'.format(parent, foreign_key))
-
-            path = self.get_table_meta(table).get('path')
-            if path is not None:
-                extras.append('Data path: {}'.format(path))
-
-            extras = r'\l'.join(extras)
-
-            # Add table node
-            title = r'{%s|%s\l|%s\l}' % (table, fields, extras)
-            plot.node(table, label=title)
-
-    def _visualize_add_edges(self, plot):
-        """Add edges into a `graphviz.Digraph`.
-
-        Each edge represents a relationship between two metadata tables.
-
-        Args:
-            plot (graphviz.Digraph)
-        """
-        for table in self.get_tables():
-            for parent in list(self.get_parents(table)):
-                plot.edge(
-                    parent,
-                    table,
-                    label='   {}.{} -> {}.{}'.format(
-                        table, self.get_foreign_key(parent, table),
-                        parent, self.get_primary_key(parent)
-                    ),
-                    arrowhead='crow'
-                )
-
-    def visualize(self, path=None):
-        """Plot the metadata using graphviz.
-
-        Try to generate a plot using graphviz.
-        If a ``path`` is provided, save the output into a file.
-
-        Args:
-            path (str):
-                Output file path to save the plot; it requires a graphviz
-                supported extension. If ``None``, do not save the plot.
-                Defaults to ``None``.
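-
-        Example:
-            A minimal usage sketch; the ``png`` extension selects the graphviz
-            output format (the path is hypothetical)::
-
-                >>> metadata.visualize('metadata.png')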
-        """
-        filename, graphviz_extension = self._get_graphviz_extension(path)
-        plot = graphviz.Digraph(
-            'Metadata',
-            format=graphviz_extension,
-            node_attr={
-                "shape": "Mrecord",
-                "fillcolor": "lightgoldenrod1",
-                "style": "filled"
-            },
-        )
-
-        self._visualize_add_nodes(plot)
-        self._visualize_add_edges(plot)
-
-        if filename:
-            plot.render(filename=filename, cleanup=True, format=graphviz_extension)
-        else:
-            return plot
-
-    def __str__(self):
-        tables = self.get_tables()
-        relationships = [
-            '    {}.{} -> {}.{}'.format(
-                table, self.get_foreign_key(parent, table),
-                parent, self.get_primary_key(parent)
-            )
-            for table in tables
-            for parent in list(self.get_parents(table))
-        ]
-
-        return (
-            "Metadata\n"
-            "  root_path: {}\n"
-            "  tables: {}\n"
-            "  relationships:\n"
-            "{}"
-        ).format(
-            os.path.abspath(self.root_path),
-            tables,
-            '\n'.join(relationships)
-        )

From 96395e89ea9789210f4aaf90e58fcbd0d1b7ec65 Mon Sep 17 00:00:00 2001
From: Carles Sala
Date: Thu, 9 Jul 2020 22:21:22 +0200
Subject: [PATCH 32/33] Recover parameters properly

---
 sdv/metadata/table.py                     | 11 ++++
 sdv/tabular/copulas.py                    | 61 +++++++++++++++++------
 tests/integration/tabular/test_copulas.py | 24 ++++++---
 3 files changed, 74 insertions(+), 22 deletions(-)

diff --git a/sdv/metadata/table.py b/sdv/metadata/table.py
index 893549272..598a4a725 100644
--- a/sdv/metadata/table.py
+++ b/sdv/metadata/table.py
@@ -386,3 +386,14 @@ def to_json(self, path):
         """
         with open(path, 'w') as out_file:
             json.dump(self._metadata, out_file, indent=4)
+
+    @classmethod
+    def from_json(cls, path):
+        """Load a Table from a JSON file.
+
+        Args:
+            path (str):
+                Path of the JSON file to load.
+        """
+        with open(path, 'r') as in_file:
+            return cls(json.load(in_file))
diff --git a/sdv/tabular/copulas.py b/sdv/tabular/copulas.py
index dc82dd862..7ae00120f 100644
--- a/sdv/tabular/copulas.py
+++ b/sdv/tabular/copulas.py
@@ -20,8 +20,9 @@ class GaussianCopula(BaseTabularModel):
             ``categorical_fuzzy``.
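+
+    Example:
+        A minimal usage sketch (``data`` is a hypothetical ``pandas.DataFrame``;
+        ``fit`` and ``sample`` come from ``BaseTabularModel``)::
+
+            >>> model = GaussianCopula(categorical_transformer='categorical_fuzzy')
+            >>> model.fit(data)
+            >>> sampled = model.sample()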
""" - DISTRIBUTION = copulas.univariate.GaussianUnivariate + DEFAULT_DISTRIBUTION = copulas.univariate.Univariate _distribution = None + _categorical_transformer = None _model = None HYPERPARAMETERS = { @@ -51,6 +52,7 @@ class GaussianCopula(BaseTabularModel): ] } } + DEFAULT_TRANSFORMER = 'one_hot_encoding' CATEGORICAL_TRANSFORMERS = { 'categorical': rdt.transformers.CategoricalTransformer(fuzzy=False), 'categorical_fuzzy': rdt.transformers.CategoricalTransformer(fuzzy=True), @@ -61,10 +63,24 @@ class GaussianCopula(BaseTabularModel): 'O': rdt.transformers.OneHotEncodingTransformer } - def __init__(self, distribution=None, categorical_transformer='categorical', - *args, **kwargs): + def __init__(self, distribution=None, categorical_transformer=None, *args, **kwargs): super().__init__(*args, **kwargs) - self._distribution = distribution or self.DISTRIBUTION + + if self._metadata is not None and 'model_kwargs' in self._metadata._metadata: + model_kwargs = self._metadata._metadata['model_kwargs'] + if distribution is None: + distribution = model_kwargs['distribution'] + + if categorical_transformer is None: + categorical_transformer = model_kwargs['categorical_transformer'] + + self._size = model_kwargs['_size'] + + self._distribution = distribution or self.DEFAULT_DISTRIBUTION + + categorical_transformer = categorical_transformer or self.DEFAULT_TRANSFORMER + self._categorical_transformer = categorical_transformer + self.TRANSFORMER_TEMPLATES['O'] = self.CATEGORICAL_TRANSFORMERS[categorical_transformer] def _update_metadata(self): @@ -72,10 +88,15 @@ def _update_metadata(self): univariates = parameters['univariates'] columns = parameters['columns'] - fields = self._metadata.get_fields() - for field_name, univariate in zip(columns, univariates): - field_meta = fields[field_name] - field_meta['distribution'] = univariate['type'] + distributions = {} + for column, univariate in zip(columns, univariates): + distributions[column] = univariate['type'] + + self._metadata._metadata['model_kwargs'] = { + 'distribution': distributions, + 'categorical_transformer': self._categorical_transformer, + '_size': self._size + } def _fit(self, data): """Fit the model to the table. @@ -86,6 +107,7 @@ def _fit(self, data): """ self._model = copulas.multivariate.GaussianMultivariate(distribution=self._distribution) self._model.fit(data) + self._update_metadata() def _sample(self, size): """Sample ``size`` rows from the model. @@ -100,16 +122,24 @@ def _sample(self, size): """ return self._model.sample(size) - def get_parameters(self): + def get_parameters(self, flatten=False): """Get copula model parameters. Compute model ``covariance`` and ``distribution.std`` before it returns the flatten dict. + Args: + flatten (bool): + Whether to flatten the parameters or not before + returning them. + Returns: dict: - Copula flatten parameters. + Copula parameters. """ + if not flatten: + return self._model.to_dict() + values = list() triangle = np.tril(self._model.covariance) @@ -195,7 +225,7 @@ def _unflatten_gaussian_copula(self, model_parameters): return model_parameters - def set_parameters(self, parameters): + def set_parameters(self, parameters, unflatten=False): """Set copula model parameters. Add additional keys after unflatte the parameters @@ -204,10 +234,13 @@ def set_parameters(self, parameters): Args: dict: Copula flatten parameters. + unflatten (bool): + Whether the parameters need to be unflattened or not. 
""" - parameters = unflatten_dict(parameters) - parameters.setdefault('distribution', self.distribution) + if unflatten: + parameters = unflatten_dict(parameters) + parameters.setdefault('distribution', self._distribution) - parameters = self._unflatten_gaussian_copula(parameters) + parameters = self._unflatten_gaussian_copula(parameters) self._model = copulas.multivariate.GaussianMultivariate.from_dict(parameters) diff --git a/tests/integration/tabular/test_copulas.py b/tests/integration/tabular/test_copulas.py index a40fbe7bb..0a34a0e15 100644 --- a/tests/integration/tabular/test_copulas.py +++ b/tests/integration/tabular/test_copulas.py @@ -27,7 +27,14 @@ def test_gaussian_copula(): ) gc.fit(users) - sampled = gc.sample() + parameters = gc.get_parameters() + new_gc = GaussianCopula( + table_metadata=gc.get_metadata(), + categorical_transformer='one_hot_encoding', + ) + new_gc.set_parameters(parameters) + + sampled = new_gc.sample() # test shape is right assert sampled.shape == users.shape @@ -38,11 +45,12 @@ def test_gaussian_copula(): # country codes have been replaced with new ones assert set(sampled.country.unique()) != set(users.country.unique()) - assert gc.get_metadata().to_dict() == { - 'fields': { - 'user_id': {'type': 'id', 'subtype': 'integer'}, - 'country': {'type': 'categorical'}, - 'gender': {'type': 'categorical'}, - 'age': {'type': 'numerical', 'subtype': 'integer'} - } + metadata = gc.get_metadata().to_dict() + assert metadata['fields'] == { + 'user_id': {'type': 'id', 'subtype': 'integer'}, + 'country': {'type': 'categorical'}, + 'gender': {'type': 'categorical'}, + 'age': {'type': 'numerical', 'subtype': 'integer'} } + + assert 'model_kwargs' in metadata From d9acc57e5e0f48750a22934c1e611bb48c8df9a4 Mon Sep 17 00:00:00 2001 From: Carles Sala Date: Thu, 9 Jul 2020 22:42:09 +0200 Subject: [PATCH 33/33] Add release notes for v0.3.5 --- HISTORY.md | 18 ++++++++ docs/tutorials/demo_metadata.json | 74 ------------------------------- 2 files changed, 18 insertions(+), 74 deletions(-) delete mode 100644 docs/tutorials/demo_metadata.json diff --git a/HISTORY.md b/HISTORY.md index 00b164949..3b9e68395 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,23 @@ # History +## 0.3.5 - 2020-07-09 + +This release introduces a new subpackage `sdv.tabular` with models designed specifically +for single table modeling, while still providing all the usual conveniences from SDV, such +as: + +* Seamless multi-type support +* Missing data handling +* PII anonymization + +Currently implemented models are: + +* GaussianCopula: Multivariate distributions modeled using copula functions. This is stronger + version, with more marginal distributions and options, than the one used to model multi-table + datasets. +* CTGAN: GAN-based data synthesizer that can generate synthetic tabular data with high fidelity. 
+ + ## 0.3.4 - 2020-07-04 ## New Features diff --git a/docs/tutorials/demo_metadata.json b/docs/tutorials/demo_metadata.json deleted file mode 100644 index f59aa96da..000000000 --- a/docs/tutorials/demo_metadata.json +++ /dev/null @@ -1,74 +0,0 @@ -{ - "tables": { - "users": { - "fields": { - "gender": { - "type": "categorical" - }, - "age": { - "type": "numerical", - "subtype": "integer" - }, - "user_id": { - "type": "id", - "subtype": "integer" - }, - "country": { - "type": "categorical" - } - }, - "primary_key": "user_id" - }, - "sessions": { - "fields": { - "os": { - "type": "categorical" - }, - "device": { - "type": "categorical" - }, - "user_id": { - "type": "id", - "subtype": "integer", - "ref": { - "table": "users", - "field": "user_id" - } - }, - "session_id": { - "type": "id", - "subtype": "integer" - } - }, - "primary_key": "session_id" - }, - "transactions": { - "fields": { - "timestamp": { - "type": "datetime", - "format": "%Y-%m-%d" - }, - "amount": { - "type": "numerical", - "subtype": "float" - }, - "session_id": { - "type": "id", - "subtype": "integer", - "ref": { - "table": "sessions", - "field": "session_id" - } - }, - "approved": { - "type": "boolean" - }, - "transaction_id": { - "type": "id", - "subtype": "integer" - } - }, - "primary_key": "transaction_id" - } - } -} \ No newline at end of file