diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 12d63d205..02e92fd67 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -18,6 +18,7 @@ jobs: - name: Build run: | + sudo apt install pandoc python -m pip install --upgrade pip pip install -e .[dev] make docs diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 0f882e819..7357cc190 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -24,12 +24,12 @@ jobs: - if: matrix.os == 'ubuntu-latest' name: Install graphviz - Ubuntu run: | - sudo apt-get install graphviz + sudo apt-get install graphviz pandoc - if: matrix.os == 'macos-latest' name: Install graphviz - MacOS run: | - brew install graphviz + brew install graphviz pandoc - name: Install dependencies run: | diff --git a/.gitignore b/.gitignore index 241e9f48f..2e25954c9 100644 --- a/.gitignore +++ b/.gitignore @@ -106,3 +106,5 @@ ENV/ .*.swp sdv/data/ +tutorials/sdv.pkl +tutorials/demo_metadata.json diff --git a/.travis.yml b/.travis.yml index 8204e8dfc..8345682c4 100644 --- a/.travis.yml +++ b/.travis.yml @@ -9,7 +9,7 @@ python: # Command to install dependencies install: - sudo apt-get update - - sudo apt-get install graphviz + - sudo apt-get install graphviz pandoc - pip install -U tox-travis codecov after_success: codecov diff --git a/EVALUATION.md b/EVALUATION.md index 5d64ac6db..a8343184f 100644 --- a/EVALUATION.md +++ b/EVALUATION.md @@ -12,7 +12,7 @@ generate a simple standardized score. After you have modeled your databased and generated samples out of the SDV models you will be left with a dictionary that contains table names and dataframes. -For exmple, if we model and sample the demo dataset: +For example, if we model and sample the demo dataset: ```python3 from sdv import SDV @@ -44,3 +44,94 @@ the value will be negative. For further options, including visualizations and more detailed reports, please refer to the [SDMetrics](https://github.com/sdv-dev/SDMetrics) library. + + +## SDV Benchmark + +SDV also provides a simple functionality to evaluate the performance of SDV across a +collection of demo datasets or custom datasets hosted in a local folder. + +In order to execute this evaluation you can execute the function `sdv.benchmark.run_benchmark`: + +```python3 +from sdv.benchmark import run_benchmark + +scores = run_benchmark() +``` + +This function has the following arguments: + +* `datasets`: List of dataset names, which can either be names of demo datasets or + names of custom datasets stored in a local folder. +* `datasets_path`: Path where the custom datasets are stored. If not provided, the + dataset names are interpreted as demo datasets. +* `distributed`: Whether to execute the benchmark using Dask. Defaults to True. +* `timeout`: Maximum time allowed for each dataset to be modeled, sampled and evaluated. + Any dataset that takes longer to run will return a score of `None`. 
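+
+As an illustration, here is a minimal sketch of running the benchmark on custom datasets stored
+in a local folder (the folder path and dataset names below are hypothetical, used only for
+illustration):
+
+```python3
+from sdv.benchmark import run_benchmark
+
+# When `datasets_path` is given, the dataset names are resolved
+# inside that folder instead of the demo collection.
+scores = run_benchmark(
+    datasets=['my_dataset_1', 'my_dataset_2'],   # hypothetical dataset names
+    datasets_path='/path/to/my_datasets',        # hypothetical local folder
+    distributed=False,
+    timeout=120
+)
+```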
+
+For example, the following command will run the SDV benchmark on the given demo datasets
+using `dask` and a timeout of 60 seconds:
+
+```python
+scores = run_benchmark(
+    datasets=['DCG_v1', 'trains_v1', 'UTube_v1'],
+    distributed=True,
+    timeout=60
+)
+```
+
+And the result will be a DataFrame with the columns `dataset` and `score`:
+
+| dataset | score |
+|:-------:|:-----:|
+| DCG_v1 | -14.49341665631863 |
+| trains_v1 | -30.26840342069557 |
+| UTube_v1 | -8.57618576332235 |
+
+Additionally, if some dataset has raised an error or has reached the timeout, an `error`
+column will be added indicating the details.
+
+### Demo Datasets
+
+The collection of available datasets can be seen using the `sdv.demo.get_available_demos` function,
+which returns a table with a description of the dataset properties:
+
+```python3
+from sdv.demo import get_available_demos
+
+demos = get_available_demos()
+```
+
+The result is a table indicating the name of the dataset and a few properties, such as the
+number of tables that compose the dataset and the total number of rows and columns:
+
+| name | tables | rows | columns |
+|-----------------------|----------|---------|-----------|
+| UTube_v1 | 2 | 2735 | 10 |
+| SAP_v1 | 4 | 3841029 | 71 |
+| NCAA_v1 | 9 | 202305 | 333 |
+| airbnb-simplified | 2 | 5751408 | 22 |
+| Atherosclerosis_v1 | 4 | 12781 | 307 |
+| rossmann | 3 | 2035533 | 21 |
+| walmart | 4 | 544869 | 24 |
+| AustralianFootball_v1 | 4 | 139179 | 193 |
+| Pyrimidine_v1 | 2 | 296 | 38 |
+| world_v1 | 3 | 5302 | 39 |
+| Accidents_v1 | 3 | 1463093 | 87 |
+| trains_v1 | 2 | 83 | 15 |
+| legalActs_v1 | 5 | 1754397 | 50 |
+| DCG_v1 | 2 | 8258 | 9 |
+| imdb_ijs_v1 | 7 | 5647694 | 50 |
+| SalesDB_v1 | 4 | 6735507 | 35 |
+| MuskSmall_v1 | 2 | 568 | 173 |
+| KRK_v1 | 1 | 1000 | 9 |
+| Chess_v1 | 2 | 2052 | 57 |
+| Telstra_v1 | 5 | 148021 | 23 |
+| mutagenesis_v1 | 3 | 10324 | 26 |
+| PremierLeague_v1 | 4 | 11308 | 250 |
+| census | 1 | 32561 | 15 |
+| FNHK_v1 | 3 | 2113275 | 43 |
+| imdb_MovieLens_v1 | 7 | 1249411 | 58 |
+| financial_v1 | 8 | 1079680 | 84 |
+| ftp_v1 | 2 | 96491 | 13 |
+| Triazine_v1 | 2 | 1302 | 35 |
diff --git a/HISTORY.md b/HISTORY.md
index b2a2e996d..3b9e68395 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -1,5 +1,34 @@
 # History
 
+## 0.3.5 - 2020-07-09
+
+This release introduces a new subpackage `sdv.tabular` with models designed specifically
+for single-table modeling, while still providing all the usual conveniences from SDV, such
+as:
+
+* Seamless multi-type support
+* Missing data handling
+* PII anonymization
+
+Currently implemented models are:
+
+* GaussianCopula: Multivariate distributions modeled using copula functions. This is a stronger
+  version, with more marginal distributions and options, than the one used to model multi-table
+  datasets.
+* CTGAN: GAN-based data synthesizer that can generate synthetic tabular data with high fidelity.
+
+
+## 0.3.4 - 2020-07-04
+
+### New Features
+
+* Support for Multiple Parents - [Issue #162](https://github.com/sdv-dev/SDV/issues/162) by @csala
+* Sample by default the same number of rows as in the original table - [Issue #163](https://github.com/sdv-dev/SDV/issues/163) by @csala
+
+### General Improvements
+
+* Add benchmark - [Issue #165](https://github.com/sdv-dev/SDV/issues/165) by @csala
+
 ## 0.3.3 - 2020-06-26
 
 ### General Improvements
diff --git a/Makefile b/Makefile
index bdcbeac21..5fc2f37f3 100644
--- a/Makefile
+++ b/Makefile
@@ -50,6 +50,7 @@ clean-pyc: ## remove Python file artifacts
 .PHONY: clean-docs
 clean-docs: ## remove previously built docs
 	rm -f docs/api/*.rst
+	rm -rf docs/tutorials
 	-$(MAKE) -C docs clean 2>/dev/null # this fails if sphinx is not yet installed
 
 .PHONY: clean-coverage
@@ -110,7 +111,7 @@ test-readme: ## run the readme snippets
 .PHONY: test-tutorials
 test-tutorials: ## run the tutorial notebooks
-	jupyter nbconvert --execute --ExecutePreprocessor.timeout=600 examples/*.ipynb --stdout > /dev/null
+	jupyter nbconvert --execute --ExecutePreprocessor.timeout=600 tutorials/*.ipynb --stdout > /dev/null
 
 .PHONY: test
 test: test-unit test-readme test-tutorials ## test everything that needs test dependencies
@@ -134,6 +135,7 @@ coverage: ## check code coverage quickly with the default Python
 .PHONY: docs
 docs: clean-docs ## generate Sphinx HTML documentation, including API docs
+	cp -r tutorials docs/tutorials
 	sphinx-apidoc --separate --no-toc -o docs/api/ sdv
 	$(MAKE) -C docs html
diff --git a/docs/conf.py b/docs/conf.py
index 1c7e3ea90..c99d64607 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -32,6 +32,7 @@
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones.
 extensions = [
     'm2r',
+    'nbsphinx',
     'sphinx.ext.autodoc',
     'sphinx.ext.githubpages',
     'sphinx.ext.viewcode',
@@ -53,6 +54,9 @@
 # The master toctree document.
 master_doc = 'index'
 
+# Jupyter Notebooks
+nbsphinx_execute = 'never'
+
 # General information about the project.
 project = 'SDV'
 slug = 'sdv'
diff --git a/docs/index.rst b/docs/index.rst
index 4f136e234..b79a5add0 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -7,9 +7,11 @@ Overview
 
 .. toctree::
-    :caption: Advanced Usage
-    :maxdepth: 3
+    :caption: User Guides
+    :maxdepth: 2
 
+    tutorials/02_Single_Table_Modeling
+    tutorials/03_Relational_Data_Modeling
     metadata
 
 .. toctree::
diff --git a/docs/metadata.rst b/docs/metadata.rst
index c9f4d8607..d7f7dd933 100644
--- a/docs/metadata.rst
+++ b/docs/metadata.rst
@@ -1,5 +1,5 @@
-Metadata
-========
+Working with Metadata
+=====================
 
 In order to use **SDV** you will need a ``Metadata`` object alongside your data.
diff --git a/docs/tutorials/01_Quickstart.ipynb b/docs/tutorials/01_Quickstart.ipynb
new file mode 100644
index 000000000..5eb6cd8e8
--- /dev/null
+++ b/docs/tutorials/01_Quickstart.ipynb
@@ -0,0 +1,838 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Quickstart\n",
+    "\n",
+    "In this short tutorial we will guide you through a series of steps that will help you\n",
+    "get started using **SDV**.\n",
+    "\n",
+    "## 1. Model the dataset using SDV\n",
+    "\n",
+    "To model a multi-table, relational dataset, we follow two steps. In the first step, we will load\n",
+    "the data and configure the metadata. In the second step, we will use the SDV API to fit and\n",
+    "save a hierarchical model. 
We will cover these two steps in this section using an example dataset.\n", + "\n", + "### Step 1: Load example data\n", + "\n", + "**SDV** comes with a toy dataset to play with, which can be loaded using the `sdv.load_demo`\n", + "function:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from sdv import load_demo\n", + "\n", + "metadata, tables = load_demo(metadata=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This will return two objects:\n", + "\n", + "1. A `Metadata` object with all the information that **SDV** needs to know about the dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Metadata\n", + " root_path: /home/xals/Projects/MIT/SDV/tutorials\n", + " tables: ['users', 'sessions', 'transactions']\n", + " relationships:\n", + " sessions.user_id -> users.user_id\n", + " transactions.session_id -> sessions.session_id" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metadata" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Metadata\n", + "\n", + "\n", + "\n", + "users\n", + "\n", + "users\n", + "\n", + "user_id : id - integer\n", + "country : categorical\n", + "gender : categorical\n", + "age : numerical - integer\n", + "\n", + "Primary key: user_id\n", + "\n", + "\n", + "\n", + "sessions\n", + "\n", + "sessions\n", + "\n", + "session_id : id - integer\n", + "user_id : id - integer\n", + "device : categorical\n", + "os : categorical\n", + "\n", + "Primary key: session_id\n", + "Foreign key (users): user_id\n", + "\n", + "\n", + "\n", + "users->sessions\n", + "\n", + "\n", + "   sessions.user_id -> users.user_id\n", + "\n", + "\n", + "\n", + "transactions\n", + "\n", + "transactions\n", + "\n", + "transaction_id : id - integer\n", + "session_id : id - integer\n", + "timestamp : datetime\n", + "amount : numerical - float\n", + "approved : boolean\n", + "\n", + "Primary key: transaction_id\n", + "Foreign key (sessions): session_id\n", + "\n", + "\n", + "\n", + "sessions->transactions\n", + "\n", + "\n", + "   transactions.session_id -> sessions.session_id\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metadata.visualize()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For more details about how to build the `Metadata` for your own dataset, please refer to the\n", + "[Metadata](https://sdv-dev.github.io/SDV/metadata.html) section of the documentation." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "2. A dictionary containing three `pandas.DataFrames` with the tables described in the\n", + "metadata object." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'users':    user_id country gender  age\n",
+       " 0        0     USA      M   34\n",
+       " 1        1      UK      F   23\n",
+       " 2        2      ES   None   44\n",
+       " 3        3      UK      M   22\n",
+       " 4        4     USA      F   54\n",
+       " 5        5      DE      M   57\n",
+       " 6        6      BG      F   45\n",
+       " 7        7      ES   None   41\n",
+       " 8        8      FR      F   23\n",
+       " 9        9      UK   None   30,\n",
+       " 'sessions':    session_id  user_id  device       os\n",
+       " 0           0        0  mobile  android\n",
+       " 1           1        1  tablet      ios\n",
+       " 2           2        1  tablet  android\n",
+       " 3           3        2  mobile  android\n",
+       " 4           4        4  mobile      ios\n",
+       " 5           5        5  mobile  android\n",
+       " 6           6        6  mobile      ios\n",
+       " 7           7        6  tablet      ios\n",
+       " 8           8        6  mobile      ios\n",
+       " 9           9        8  tablet      ios,\n",
+       " 'transactions':    transaction_id  session_id           timestamp  amount  approved\n",
+       " 0               0           0 2019-01-01 12:34:32   100.0      True\n",
+       " 1               1           0 2019-01-01 12:42:21    55.3      True\n",
+       " 2               2           1 2019-01-07 17:23:11    79.5      True\n",
+       " 3               3           3 2019-01-10 11:08:57   112.1     False\n",
+       " 4               4           5 2019-01-10 21:54:08   110.0     False\n",
+       " 5               5           5 2019-01-11 11:21:20    76.3      True\n",
+       " 6               6           7 2019-01-22 14:44:10    89.5      True\n",
+       " 7               7           8 2019-01-23 10:14:09   132.1     False\n",
+       " 8               8           9 2019-01-27 16:09:17    68.0      True\n",
+       " 9               9           9 2019-01-29 12:10:48    99.9      True}"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "tables"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Step 2: Fit a model using the SDV API\n",
+    "\n",
+    "First, we build a hierarchical statistical model of the data using **SDV**. For this, we will\n",
+    "create an instance of the `sdv.SDV` class and use its `fit` method.\n",
+    "\n",
+    "During this process, **SDV** will traverse all the tables in your dataset following the\n",
+    "primary key-foreign key relationships and learn the probability distributions of the values in\n",
+    "the columns."
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-07-09 20:57:19,919 - INFO - modeler - Modeling users\n", + "2020-07-09 20:57:19,920 - INFO - __init__ - Loading transformer CategoricalTransformer for field country\n", + "2020-07-09 20:57:19,920 - INFO - __init__ - Loading transformer CategoricalTransformer for field gender\n", + "2020-07-09 20:57:19,921 - INFO - __init__ - Loading transformer NumericalTransformer for field age\n", + "2020-07-09 20:57:19,933 - INFO - modeler - Modeling sessions\n", + "2020-07-09 20:57:19,934 - INFO - __init__ - Loading transformer CategoricalTransformer for field device\n", + "2020-07-09 20:57:19,934 - INFO - __init__ - Loading transformer CategoricalTransformer for field os\n", + "2020-07-09 20:57:19,944 - INFO - modeler - Modeling transactions\n", + "2020-07-09 20:57:19,944 - INFO - __init__ - Loading transformer DatetimeTransformer for field timestamp\n", + "2020-07-09 20:57:19,944 - INFO - __init__ - Loading transformer NumericalTransformer for field amount\n", + "2020-07-09 20:57:19,945 - INFO - __init__ - Loading transformer BooleanTransformer for field approved\n", + "2020-07-09 20:57:19,954 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:19,962 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/core/fromnumeric.py:58: FutureWarning: Series.nonzero() is deprecated and will be removed in a future version.Use Series.to_numpy().nonzero() instead\n", + " return bound(*args, **kwds)\n", + "/home/xals/Projects/MIT/SDV/sdv/models/copulas.py:83: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. 
If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n", + " self.model.covariance = np.array(values)\n", + "2020-07-09 20:57:19,968 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/pandas/core/frame.py:7143: RuntimeWarning: Degrees of freedom <= 0 for slice\n", + " baseCov = np.cov(mat.T)\n", + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/lib/function_base.py:2480: RuntimeWarning: divide by zero encountered in true_divide\n", + " c *= np.true_divide(1, fact)\n", + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/lib/function_base.py:2480: RuntimeWarning: invalid value encountered in multiply\n", + " c *= np.true_divide(1, fact)\n", + "2020-07-09 20:57:19,974 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:19,979 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:19,985 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:19,989 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:19,994 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,007 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,062 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,074 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,092 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,104 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,117 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,129 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,146 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,208 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,294 - INFO - modeler - Modeling Complete\n" + ] + } + ], + "source": [ + "from sdv import SDV\n", + "\n", + "sdv = SDV()\n", + "sdv.fit(metadata, tables)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. 
Sample data from the fitted model\n",
+    "\n",
+    "Once the modeling has finished, you are ready to generate new synthetic data using the fitted `sdv` instance.\n",
+    "\n",
+    "For this, all you have to do is call the `sample_all` method of your instance, passing the number of rows that you want to generate:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sampled = sdv.sample_all(10)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "This will return a dictionary identical to the `tables` one that we passed to the SDV instance for learning, filled in with new synthetic data.\n",
+    "\n",
+    "**Note** that only the parent tables of your dataset will have the specified number of rows,\n",
+    "as the number of child rows that each row in the parent table has is also sampled following\n",
+    "the original distribution of your dataset."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {
+    "scrolled": false
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<div>
" + ], + "text/plain": [ + " user_id country gender age\n", + "0 0 UK NaN 39\n", + "1 1 UK F 27\n", + "2 2 UK NaN 50\n", + "3 3 USA F 29\n", + "4 4 UK NaN 21\n", + "5 5 ES NaN 27\n", + "6 6 UK F 20\n", + "7 7 USA F 15\n", + "8 8 DE F 41\n", + "9 9 ES F 35" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sampled['users']" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + " session_id user_id device os\n", + "0 0 1 mobile android\n", + "1 1 2 tablet ios\n", + "2 2 2 tablet ios\n", + "3 3 2 mobile android\n", + "4 4 3 mobile android\n", + "5 5 5 mobile android\n", + "6 6 5 mobile android\n", + "7 7 5 tablet ios\n", + "8 8 6 tablet ios\n", + "9 9 6 mobile android" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sampled['sessions'].head(10)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + " transaction_id session_id timestamp amount \\\n", + "0 0 0 2019-01-13 01:11:06.084275968 80.350553 \n", + "1 1 0 2019-01-13 01:11:06.368704256 82.116470 \n", + "2 2 1 2019-01-25 15:05:07.735262976 97.616590 \n", + "3 3 1 2019-01-25 15:05:09.098958848 96.505258 \n", + "4 4 2 2019-01-25 15:05:09.124231424 97.235452 \n", + "5 5 2 2019-01-25 15:05:07.347315456 96.615923 \n", + "6 6 3 2019-01-25 15:05:09.190441984 97.173195 \n", + "7 7 3 2019-01-25 15:05:07.437910784 96.995037 \n", + "8 8 0 2019-01-13 01:11:06.560940032 81.027625 \n", + "9 9 1 2019-01-25 15:05:12.420746496 97.340054 \n", + "\n", + " approved \n", + "0 True \n", + "1 True \n", + "2 True \n", + "3 True \n", + "4 True \n", + "5 True \n", + "6 True \n", + "7 True \n", + "8 True \n", + "9 True " + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sampled['transactions'].head(10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Saving and Loading your model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In some cases, you might want to save the fitted SDV instance to be able to generate synthetic data from\n", + "it later or on a different system.\n", + "\n", + "In order to do so, you can save your fitted `SDV` instance for later usage using the `save` method of your\n", + "instance." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "sdv.save('sdv.pkl')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The generated `pkl` file will not include any of the original data in it, so it can be\n", + "safely sent to where the synthetic data will be generated without any privacy concerns." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Later on, in order to sample data from the fitted model, we will first need to load it from its\n", + "`pkl` file." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "sdv = SDV.load('sdv.pkl')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After loading the instance, we can sample synthetic data using its `sample_all` method like before." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "sampled = sdv.sample_all(5)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/tutorials/02_Single_Table_Modeling.ipynb b/docs/tutorials/02_Single_Table_Modeling.ipynb new file mode 100644 index 000000000..fa536bd3d --- /dev/null +++ b/docs/tutorials/02_Single_Table_Modeling.ipynb @@ -0,0 +1,1217 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Single Table Modeling\n", + "\n", + "**SDV** has special support for modeling single table datasets using a variety of models.\n", + "\n", + "Currently, SDV implements:\n", + "\n", + "* GaussianCopula: A tool to model multivariate distributions using [copula functions](https://en.wikipedia.org/wiki/Copula_%28probability_theory%29). 
Based on our [Copulas Library](https://github.com/sdv-dev/Copulas).\n",
+    "* CTGAN: A GAN-based Deep Learning data synthesizer that can generate synthetic tabular data with high fidelity. Based on our [CTGAN Library](https://github.com/sdv-dev/CTGAN)."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## GaussianCopula\n",
+    "\n",
+    "In this first part of the tutorial we will be using the GaussianCopula class to model the `users` table\n",
+    "from the toy dataset included in the **SDV** library.\n",
+    "\n",
+    "### 1. Load the Data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from sdv import load_demo\n",
+    "\n",
+    "users = load_demo()['users']"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "This will return a table with 4 fields:\n",
+    "\n",
+    "* `user_id`: A unique identifier of the user.\n",
+    "* `country`: A two-letter code of the user's country of residence.\n",
+    "* `gender`: A single-letter code, `M` or `F`, indicating the user's gender. Note that this demo simulates the case where some users did not indicate their gender, which results in empty values in some rows.\n",
+    "* `age`: The age of the user, in years."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<div>
" ],
+      "text/plain": [
+       "   user_id country gender  age\n",
+       "0        0     USA      M   34\n",
+       "1        1      UK      F   23\n",
+       "2        2      ES   None   44\n",
+       "3        3      UK      M   22\n",
+       "4        4     USA      F   54\n",
+       "5        5      DE      M   57\n",
+       "6        6      BG      F   45\n",
+       "7        7      ES   None   41\n",
+       "8        8      FR      F   23\n",
+       "9        9      UK   None   30"
+      ]
+     },
+     "execution_count": 2,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "users"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 2. Prepare the model"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "In order to properly model our data we will need to provide some additional information to our model,\n",
+    "so let's prepare this information in some variables."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "First, let's indicate that the `user_id` field in our table is the primary key, so we do not want our\n",
+    "model to attempt to learn it."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "primary_key = 'user_id'"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We will also want to anonymize the countries of residence of our users, to avoid disclosing such information.\n",
+    "Let's make a variable indicating that the `country` field needs to be anonymized using fake `country_codes`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "anonymize_fields = {\n",
+    "    'country': 'country_code'\n",
+    "}"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The full list of categories supported corresponds to the `Faker` library\n",
+    "[provider names](https://faker.readthedocs.io/en/master/providers.html)."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Once we have prepared the arguments for our model we are ready to import it, create an instance\n",
+    "and fit it to our data."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "2020-07-09 21:18:32,974 - INFO - table - Loading transformer CategoricalTransformer for field country\n",
+      "2020-07-09 21:18:32,975 - INFO - table - Loading transformer CategoricalTransformer for field gender\n",
+      "2020-07-09 21:18:32,975 - INFO - table - Loading transformer NumericalTransformer for field age\n",
+      "2020-07-09 21:18:32,991 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n"
+     ]
+    }
+   ],
+   "source": [
+    "from sdv.tabular import GaussianCopula\n",
+    "\n",
+    "model = GaussianCopula(\n",
+    "    primary_key=primary_key,\n",
+    "    anonymize_fields=anonymize_fields\n",
+    ")\n",
+    "model.fit(users)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "**Notice** how the model took care of transforming the different fields using the appropriate\n",
+    "Reversible Data Transforms to ensure that the data has a format that the GaussianMultivariate model\n",
+    "from the [copulas](https://github.com/sdv-dev/Copulas) library can handle."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 3. Sample data from the fitted model\n",
+    "\n",
+    "Once the modeling has finished, you are ready to generate new synthetic data by calling the `sample` method\n",
+    "from our model."
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "sampled = model.sample()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This will return a table identical to the one which the model was fitted on, but filled with new data\n", + "which resembles the original one." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + " user_id country gender age\n", + "0 0 USA M 38\n", + "1 1 UK NaN 23\n", + "2 2 USA F 34\n", + "3 3 ES NaN 47\n", + "4 4 ES F 29\n", + "5 5 UK F 39\n", + "6 6 FR NaN 40\n", + "7 7 ES M 38\n", + "8 8 ES F 32\n", + "9 9 ES F 36" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sampled" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "scrolled": false + }, + "source": [ + "Notice, as well that the number of rows generated by default corresponds to the number of rows that\n", + "the original table had, but that this number can be changed by simply passing it:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" ],
+      "text/plain": [
+       "   user_id country gender  age\n",
+       "0        0      UK      F   48\n",
+       "1        1     USA    NaN   38\n",
+       "2        2     USA      M   29\n",
+       "3        3      BG      M   22\n",
+       "4        4     USA      M   43"
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "model.sample(5)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## CTGAN\n",
+    "\n",
+    "In this second part of the tutorial we will be using the CTGAN model to learn the data from the\n",
+    "demo dataset called `census`, which is based on the [UCI Adult Census Dataset](https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data).\n",
+    "\n",
+    "### 1. Load the Data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "2020-07-09 21:18:33,085 - INFO - __init__ - Loading table census\n"
+     ]
+    }
+   ],
+   "source": [
+    "from sdv import load_demo\n",
+    "\n",
+    "census = load_demo('census')['census']"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "This will return a table with several columns of multiple data types:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {
+    "scrolled": false
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<div>
" ],
+      "text/plain": [
+       "   age         workclass  fnlwgt  education  education-num  \\\n",
+       "0   39         State-gov   77516  Bachelors             13   \n",
+       "1   50  Self-emp-not-inc   83311  Bachelors             13   \n",
+       "2   38           Private  215646    HS-grad              9   \n",
+       "3   53           Private  234721       11th              7   \n",
+       "4   28           Private  338409  Bachelors             13   \n",
+       "\n",
+       "       marital-status         occupation   relationship   race     sex  \\\n",
+       "0       Never-married       Adm-clerical  Not-in-family  White    Male   \n",
+       "1  Married-civ-spouse    Exec-managerial        Husband  White    Male   \n",
+       "2            Divorced  Handlers-cleaners  Not-in-family  White    Male   \n",
+       "3  Married-civ-spouse  Handlers-cleaners        Husband  Black    Male   \n",
+       "4  Married-civ-spouse     Prof-specialty           Wife  Black  Female   \n",
+       "\n",
+       "   capital-gain  capital-loss  hours-per-week native-country income  \n",
+       "0          2174             0              40  United-States  <=50K  \n",
+       "1             0             0              13  United-States  <=50K  \n",
+       "2             0             0              40  United-States  <=50K  \n",
+       "3             0             0              40  United-States  <=50K  \n",
+       "4             0             0              40           Cuba  <=50K  "
+      ]
+     },
+     "execution_count": 10,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "census.head()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 2. Prepare the model"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "In this case there is no primary key to set up and we will not be anonymizing anything, so the only\n",
+    "thing that we will pass to the CTGAN model is the number of epochs that we want it to perform when\n",
+    "it learns the data, which we will keep low to make this execution quick."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/sklearn/utils/deprecation.py:144: FutureWarning: The sklearn.utils.testing module is deprecated in version 0.22 and will be removed in version 0.24. The corresponding classes / functions should instead be imported from sklearn.utils. Anything that cannot be imported from sklearn.utils is now part of the private API.\n",
+      "  warnings.warn(message, FutureWarning)\n"
+     ]
+    }
+   ],
+   "source": [
+    "from sdv.tabular import CTGAN\n",
+    "\n",
+    "model = CTGAN(epochs=10)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Once the instance is created, we can fit it to our data. Bear in mind that this process might take some\n",
+    "time to finish, especially on non-GPU enabled systems, so in this case we will be passing only a\n",
+    "subsample of the data to accelerate the process."
+ ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-07-09 21:18:33,488 - INFO - table - Loading transformer NumericalTransformer for field age\n", + "2020-07-09 21:18:33,489 - INFO - table - Loading transformer LabelEncodingTransformer for field workclass\n", + "2020-07-09 21:18:33,489 - INFO - table - Loading transformer NumericalTransformer for field fnlwgt\n", + "2020-07-09 21:18:33,490 - INFO - table - Loading transformer LabelEncodingTransformer for field education\n", + "2020-07-09 21:18:33,490 - INFO - table - Loading transformer NumericalTransformer for field education-num\n", + "2020-07-09 21:18:33,490 - INFO - table - Loading transformer LabelEncodingTransformer for field marital-status\n", + "2020-07-09 21:18:33,491 - INFO - table - Loading transformer LabelEncodingTransformer for field occupation\n", + "2020-07-09 21:18:33,491 - INFO - table - Loading transformer LabelEncodingTransformer for field relationship\n", + "2020-07-09 21:18:33,491 - INFO - table - Loading transformer LabelEncodingTransformer for field race\n", + "2020-07-09 21:18:33,492 - INFO - table - Loading transformer LabelEncodingTransformer for field sex\n", + "2020-07-09 21:18:33,492 - INFO - table - Loading transformer NumericalTransformer for field capital-gain\n", + "2020-07-09 21:18:33,493 - INFO - table - Loading transformer NumericalTransformer for field capital-loss\n", + "2020-07-09 21:18:33,493 - INFO - table - Loading transformer NumericalTransformer for field hours-per-week\n", + "2020-07-09 21:18:33,493 - INFO - table - Loading transformer LabelEncodingTransformer for field native-country\n", + "2020-07-09 21:18:33,494 - INFO - table - Loading transformer LabelEncodingTransformer for field income\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1, Loss G: 1.9512, Loss D: -0.0182\n", + "Epoch 2, Loss G: 1.9884, Loss D: -0.0663\n", + "Epoch 3, Loss G: 1.9710, Loss D: -0.1339\n", + "Epoch 4, Loss G: 1.8960, Loss D: -0.2061\n", + "Epoch 5, Loss G: 1.9155, Loss D: -0.3062\n", + "Epoch 6, Loss G: 1.9699, Loss D: -0.3906\n", + "Epoch 7, Loss G: 1.8614, Loss D: -0.5142\n", + "Epoch 8, Loss G: 1.8446, Loss D: -0.6448\n", + "Epoch 9, Loss G: 1.7619, Loss D: -0.7488\n", + "Epoch 10, Loss G: 1.6732, Loss D: -0.7961\n" + ] + } + ], + "source": [ + "model.fit(census.sample(1000))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. Sample data from the fitted model\n", + "\n", + "Once the modeling has finished you are ready to generate new synthetic data by calling the `sample` method\n", + "from our model just like we did with the GaussianCopula model." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "sampled = model.sample()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This will return a table identical to the one which the model was fitted on, but filled with new data\n", + "which resembles the original one." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + " age workclass fnlwgt education education-num \\\n", + "0 50 Local-gov 169719 1st-4th 9 \n", + "1 32 ? 152479 1st-4th 9 \n", + "2 22 Private 69617 Bachelors 0 \n", + "3 25 ? 652858 10th 16 \n", + "4 43 Private 301956 Some-college 8 \n", + "5 66 Private 401171 Prof-school 13 \n", + "6 52 Private 278399 Bachelors 12 \n", + "7 36 Federal-gov 229817 HS-grad 8 \n", + "8 27 Federal-gov 306972 Some-college 8 \n", + "9 28 Local-gov 416161 1st-4th 8 \n", + "\n", + " marital-status occupation relationship \\\n", + "0 Widowed ? Husband \n", + "1 Never-married Adm-clerical Wife \n", + "2 Separated ? Husband \n", + "3 Married-civ-spouse Handlers-cleaners Not-in-family \n", + "4 Married-civ-spouse ? Wife \n", + "5 Separated Protective-serv Unmarried \n", + "6 Never-married Prof-specialty Unmarried \n", + "7 Married-AF-spouse Farming-fishing Not-in-family \n", + "8 Never-married Exec-managerial Husband \n", + "9 Divorced Adm-clerical Unmarried \n", + "\n", + " race sex capital-gain capital-loss hours-per-week \\\n", + "0 White Male 114 8 38 \n", + "1 Black Male -42 20 21 \n", + "2 White Male 6 11 38 \n", + "3 White Female 152 -27 39 \n", + "4 White Male -133 -12 39 \n", + "5 Black Female -124 -1 40 \n", + "6 Other Male 122567 -6 47 \n", + "7 White Male 8 19 38 \n", + "8 Asian-Pac-Islander Female 42144 3 39 \n", + "9 White Female -349 1090 61 \n", + "\n", + " native-country income \n", + "0 Columbia <=50K \n", + "1 Jamaica >50K \n", + "2 Guatemala <=50K \n", + "3 Cuba <=50K \n", + "4 India <=50K \n", + "5 Cuba <=50K \n", + "6 Columbia <=50K \n", + "7 Portugal >50K \n", + "8 Japan >50K \n", + "9 Guatemala >50K " + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sampled.head(10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 4. Evaluate how good the data is\n", + "\n", + "Finally, we will use the evaluation framework included in SDV to obtain a metric of how\n", + "similar the sampled data is to the original one.\n", + "\n", + "For this, we will simply import the `sdv.evaluation.evaluate` function and pass both\n", + "the synthetic and the real data to it." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "-144.971907591418"
+      ]
+     },
+     "execution_count": 15,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import warnings\n",
+    "warnings.filterwarnings('ignore')\n",
+    "\n",
+    "from sdv.evaluation import evaluate\n",
+    "\n",
+    "evaluate(sampled, census)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/docs/tutorials/03_Relational_Data_Modeling.ipynb b/docs/tutorials/03_Relational_Data_Modeling.ipynb
new file mode 100644
index 000000000..230061b00
--- /dev/null
+++ b/docs/tutorials/03_Relational_Data_Modeling.ipynb
@@ -0,0 +1,1039 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Relational Data Modeling\n",
+    "\n",
+    "In this tutorial we will be showing how to model a real-world multi-table dataset using SDV.\n",
+    "\n",
+    "## About the dataset\n",
+    "\n",
+    "We have a series of stores, each with a size and a category, plus additional information for a given date: the average temperature in the region, the cost of fuel in the region, promotional data, the consumer price index, the unemployment rate and whether the date is a special holiday.\n",
+    "\n",
+    "From those stores we obtained a training set of historical data\n",
+    "between 2010-02-05 and 2012-11-01. This historical data includes the sales of each department on a specific date.\n",
+    "In this notebook, we will show you step-by-step how to download the \"Walmart\" dataset, explain its structure and sample the data.\n",
+    "\n",
+    "In this demonstration we will show how SDV can be used to generate synthetic data that can later be used to train machine learning models.\n",
+    "\n",
+    "*The dataset used in this example can be found in [Kaggle](https://www.kaggle.com/c/walmart-recruiting-store-sales-forecasting/data), but we will show how to download it from SDV.*"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Data model summary"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "<div>

stores

\n",
+    "\n",
+    "| Field | Type        | Subtype | Additional Properties |\n",
+    "|-------|-------------|---------|-----------------------|\n",
+    "| Store | id          | integer | Primary key           |\n",
+    "| Size  | numerical   | integer |                       |\n",
+    "| Type  | categorical |         |                       |\n",
+    "\n",
+    "Contains information about the 45 stores, indicating the type and size of each store."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "<div>

features

\n",
+    "\n",
+    "| Fields       | Type      | Subtype | Additional Properties       |\n",
+    "|--------------|-----------|---------|-----------------------------|\n",
+    "| Store        | id        | integer | foreign key (stores.Store)  |\n",
+    "| Date         | datetime  |         | format: \"%Y-%m-%d\"          |\n",
+    "| IsHoliday    | boolean   |         |                             |\n",
+    "| Fuel_Price   | numerical | float   |                             |\n",
+    "| Unemployment | numerical | float   |                             |\n",
+    "| Temperature  | numerical | float   |                             |\n",
+    "| CPI          | numerical | float   |                             |\n",
+    "| MarkDown1    | numerical | float   |                             |\n",
+    "| MarkDown2    | numerical | float   |                             |\n",
+    "| MarkDown3    | numerical | float   |                             |\n",
+    "| MarkDown4    | numerical | float   |                             |\n",
+    "| MarkDown5    | numerical | float   |                             |\n",
+    "\n",
+    "Contains historical training data, which covers 2010-02-05 to 2012-11-01."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "<div>

depts

\n",
+    "\n",
+    "| Fields       | Type      | Subtype | Additional Properties        |\n",
+    "|--------------|-----------|---------|------------------------------|\n",
+    "| Store        | id        | integer | foreign key (stores.Store)   |\n",
+    "| Date         | datetime  |         | format: \"%Y-%m-%d\"           |\n",
+    "| Weekly_Sales | numerical | float   |                              |\n",
+    "| Dept         | numerical | integer |                              |\n",
+    "| IsHoliday    | boolean   |         |                              |\n",
+    "\n",
+    "Contains additional data related to the store, department, and regional activity for the given dates."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 1. Load data\n",
+    "\n",
+    "Let's start by downloading the dataset. In this case, we will download the *walmart* dataset. We will use the SDV function `load_demo`, with which we can specify the name of the dataset we want to use and whether we want its Metadata object or not. To know more about the demo data [see the documentation](https://sdv-dev.github.io/SDV/api/sdv.demo.html)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "2020-07-09 21:00:17,378 - INFO - __init__ - Loading table stores\n",
+      "2020-07-09 21:00:17,384 - INFO - __init__ - Loading table features\n",
+      "2020-07-09 21:00:17,402 - INFO - __init__ - Loading table depts\n"
+     ]
+    }
+   ],
+   "source": [
+    "from sdv import load_demo\n",
+    "\n",
+    "metadata, tables = load_demo(dataset_name='walmart', metadata=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Our dataset is downloaded from an [Amazon S3 bucket](http://sdv-datasets.s3.amazonaws.com/index.html) that contains all the available datasets of the `load_demo` method."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We can now visualize the metadata structure:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/svg+xml": [
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "\n",
+       "Metadata\n",
+       "\n",
+       "\n",
+       "\n",
+       "stores\n",
+       "\n",
+       "stores\n",
+       "\n",
+       "Type : categorical\n",
+       "Size : numerical - integer\n",
+       "Store : id - integer\n",
+       "\n",
+       "Primary key: Store\n",
+       "Data path: stores.csv\n",
+       "\n",
+       "\n",
+       "\n",
+       "features\n",
+       "\n",
+       "features\n",
+       "\n",
+       "Date : datetime\n",
+       "MarkDown1 : numerical - float\n",
+       "Store : id - integer\n",
+       "IsHoliday : boolean\n",
+       "MarkDown4 : numerical - float\n",
+       "MarkDown3 : numerical - float\n",
+       "Fuel_Price : numerical - float\n",
+       "Unemployment : numerical - float\n",
+       "Temperature : numerical - float\n",
+       "MarkDown5 : numerical - float\n",
+       "MarkDown2 : numerical - float\n",
+       "CPI : numerical - float\n",
+       "\n",
+       "Foreign key (stores): Store\n",
+       "Data path: features.csv\n",
+       "\n",
+       "\n",
+       "\n",
+       "stores->features\n",
+       "\n",
+       "\n",
+       "   features.Store -> stores.Store\n",
+       "\n",
+       "\n",
+       "\n",
+       "depts\n",
+       "\n",
+       "depts\n",
+       "\n",
+       "Date : datetime\n",
+       "Weekly_Sales : numerical - float\n",
+       "Store : id - integer\n",
+       "Dept : numerical - integer\n",
+       "IsHoliday : boolean\n",
+       "\n",
+       "Foreign key (stores): Store\n",
+       "Data path: train.csv\n",
+       "\n",
+       "\n",
+       "\n",
+       "stores->depts\n",
+       "\n",
+       "\n",
+       "   depts.Store -> stores.Store\n",
+       "\n",
+       "\n",
+       "\n"
+      ],
+      "text/plain": [
+       ""
+      ]
+     },
+     "execution_count": 3,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "metadata.visualize()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "And also validate that the metadata is correctly defined for our data:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "metadata.validate(tables)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 2. Create an instance of SDV and train the instance\n",
+    "\n",
+    "Once we have downloaded the data, we have to create an SDV instance. With that instance, we have to analyze the loaded tables to generate a statistical model from the data. In this case, the process of fitting the model is quick because the dataset is small. However, with larger datasets it can be a slow process."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {
+    "scrolled": false
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "2020-07-09 21:00:31,480 - INFO - modeler - Modeling stores\n",
+      "2020-07-09 21:00:31,481 - INFO - __init__ - Loading transformer CategoricalTransformer for field Type\n",
+      "2020-07-09 21:00:31,481 - INFO - __init__ - Loading transformer NumericalTransformer for field Size\n",
+      "2020-07-09 21:00:31,491 - INFO - modeler - Modeling features\n",
+      "2020-07-09 21:00:31,492 - INFO - __init__ - Loading transformer DatetimeTransformer for field Date\n",
+      "2020-07-09 21:00:31,493 - INFO - __init__ - Loading transformer NumericalTransformer for field MarkDown1\n",
+      "2020-07-09 21:00:31,493 - INFO - __init__ - Loading transformer BooleanTransformer for field IsHoliday\n",
+      "2020-07-09 21:00:31,493 - INFO - __init__ - Loading transformer NumericalTransformer for field MarkDown4\n",
+      "2020-07-09 21:00:31,494 - INFO - __init__ - Loading transformer NumericalTransformer for field MarkDown3\n",
+      "2020-07-09 21:00:31,494 - INFO - __init__ - Loading transformer NumericalTransformer for field Fuel_Price\n",
+      "2020-07-09 21:00:31,494 - INFO - __init__ - Loading transformer NumericalTransformer for field Unemployment\n",
+      "2020-07-09 21:00:31,495 - INFO - __init__ - Loading transformer NumericalTransformer for field Temperature\n",
+      "2020-07-09 21:00:31,495 - INFO - __init__ - Loading transformer NumericalTransformer for field MarkDown5\n",
+      "2020-07-09 21:00:31,495 - INFO - __init__ - Loading transformer NumericalTransformer for field MarkDown2\n",
+      "2020-07-09 21:00:31,495 - INFO - __init__ - Loading transformer NumericalTransformer for field CPI\n",
+      "2020-07-09 21:00:31,544 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n",
+      "2020-07-09 21:00:31,595 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n",
+      "/home/xals/Projects/MIT/SDV/sdv/models/copulas.py:83: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. 
If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n", + " self.model.covariance = np.array(values)\n", + "2020-07-09 21:00:31,651 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,679 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,707 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,734 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,762 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,790 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,816 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,845 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,872 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,901 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,931 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,959 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,986 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,014 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,040 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,070 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,096 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,123 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,152 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,181 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,209 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,235 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,264 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,293 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,322 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,349 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,376 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,405 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,433 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 
21:00:32,463 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,492 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,521 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,552 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,583 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,612 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,644 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,674 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,704 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,732 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,762 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,791 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,821 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,852 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,882 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,917 - INFO - modeler - Modeling depts\n", + "2020-07-09 21:00:32,918 - INFO - __init__ - Loading transformer DatetimeTransformer for field Date\n", + "2020-07-09 21:00:32,918 - INFO - __init__ - Loading transformer NumericalTransformer for field Weekly_Sales\n", + "2020-07-09 21:00:32,919 - INFO - __init__ - Loading transformer NumericalTransformer for field Dept\n", + "2020-07-09 21:00:32,919 - INFO - __init__ - Loading transformer BooleanTransformer for field IsHoliday\n", + "2020-07-09 21:00:33,016 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,318 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,334 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,350 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,364 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,381 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,396 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,412 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,428 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-07-09 21:00:33,447 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,464 - INFO - gaussian - Fitting 
GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,479 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,495 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,511 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,526 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,541 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,554 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,567 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,582 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,596 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,609 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,624 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,639 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,652 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,667 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,682 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,696 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,709 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,724 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,740 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,753 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,766 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,779 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,792 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,803 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,818 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,831 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,842 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,856 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,870 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,885 - INFO - gaussian - Fitting 
GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,898 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,912 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,926 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,937 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,949 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:34,047 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/core/fromnumeric.py:58: FutureWarning: Series.nonzero() is deprecated and will be removed in a future version.Use Series.to_numpy().nonzero() instead\n", + " return bound(*args, **kwds)\n", + "2020-07-09 21:00:34,259 - INFO - modeler - Modeling Complete\n" + ] + } + ], + "source": [ + "from sdv import SDV\n", + "\n", + "sdv = SDV()\n", + "sdv.fit(metadata, tables=tables)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note: we may not want to train the model every time we need to generate new synthetic data. Instead, we can [save](https://sdv-dev.github.io/SDV/api/sdv.sdv.html#sdv.sdv.SDV.save) the SDV instance and [load](https://sdv-dev.github.io/SDV/api/sdv.sdv.html#sdv.sdv.SDV.load) it back later, as sketched below." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. Generate synthetic data\n", + "\n", + "Once the instance is trained, we are ready to generate the synthetic data.\n", + "\n", + "The easiest way to generate synthetic data for the entire dataset is to call the `sample_all` method. By default, this method generates only 5 rows, but we can specify the number of rows to generate with the `num_rows` argument. To learn more about the available arguments, see [sample_all](https://sdv-dev.github.io/SDV/api/sdv.sampler.html#sdv.sampler.Sampler.sample_all)." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'stores': 45, 'features': 8190, 'depts': 421570}" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sdv.modeler.table_sizes" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "samples = sdv.sample_all()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This returns a dictionary with a `pandas.DataFrame` for each table." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
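To make the save/load note above concrete, here is a minimal sketch of that workflow using the `save` and `load` methods linked in the note; the `sdv.pkl` filename is just an illustration:

```python3
# Persist the fitted SDV instance so the model does not have to be
# trained again in a later session. The filename is arbitrary.
sdv.save('sdv.pkl')

# Later, restore the instance and sample without refitting.
from sdv import SDV

sdv = SDV.load('sdv.pkl')
samples = sdv.sample_all(num_rows=5)
```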

\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
TypeSizeStore
0B1061060
1B680711
2C2536032
3A2232683
4B1668484
\n", + "
" + ], + "text/plain": [ + " Type Size Store\n", + "0 B 106106 0\n", + "1 B 68071 1\n", + "2 C 253603 2\n", + "3 A 223268 3\n", + "4 B 166848 4" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "samples['stores'].head()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
DateMarkDown1StoreIsHolidayMarkDown4MarkDown3Fuel_PriceUnemploymentTemperatureMarkDown5MarkDown2CPI
02009-12-13 04:59:49.832576000222.2688220False758.900545-1080.5694354.5749666.05279551.354606-4016.012141NaN224.973722
12011-07-20 14:40:40.4002447366857.7120280FalseNaN-551.0452573.0182966.72066980.723862NaNNaN204.595072
22011-10-12 08:23:54.9876582405530.9294250FalseNaNNaN3.4663587.25657073.954938NaNNaN202.591941
32012-11-26 09:22:59.377291776NaN0FalseNaN3703.1371913.6237256.59861572.041454NaN3925.456611169.457078
42013-05-13 15:22:12.213717760NaN0False3773.942572NaN3.3817955.11500283.05832215018.944605NaN182.801974
\n", + "
" + ], + "text/plain": [ + " Date MarkDown1 Store IsHoliday MarkDown4 \\\n", + "0 2009-12-13 04:59:49.832576000 222.268822 0 False 758.900545 \n", + "1 2011-07-20 14:40:40.400244736 6857.712028 0 False NaN \n", + "2 2011-10-12 08:23:54.987658240 5530.929425 0 False NaN \n", + "3 2012-11-26 09:22:59.377291776 NaN 0 False NaN \n", + "4 2013-05-13 15:22:12.213717760 NaN 0 False 3773.942572 \n", + "\n", + " MarkDown3 Fuel_Price Unemployment Temperature MarkDown5 \\\n", + "0 -1080.569435 4.574966 6.052795 51.354606 -4016.012141 \n", + "1 -551.045257 3.018296 6.720669 80.723862 NaN \n", + "2 NaN 3.466358 7.256570 73.954938 NaN \n", + "3 3703.137191 3.623725 6.598615 72.041454 NaN \n", + "4 NaN 3.381795 5.115002 83.058322 15018.944605 \n", + "\n", + " MarkDown2 CPI \n", + "0 NaN 224.973722 \n", + "1 NaN 204.595072 \n", + "2 NaN 202.591941 \n", + "3 3925.456611 169.457078 \n", + "4 NaN 182.801974 " + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "samples['features'].head()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
DateWeekly_SalesStoreDeptIsHoliday
02012-04-19 04:40:04.3703828487523.202470036False
12012-10-24 12:49:45.63017651213675.425498017False
22012-06-05 05:02:35.5517614082402.017324031False
32012-05-31 22:15:14.903856896-12761.073176081False
42011-09-07 15:34:43.2398356483642.817612058False
\n", + "
" + ], + "text/plain": [ + " Date Weekly_Sales Store Dept IsHoliday\n", + "0 2012-04-19 04:40:04.370382848 7523.202470 0 36 False\n", + "1 2012-10-24 12:49:45.630176512 13675.425498 0 17 False\n", + "2 2012-06-05 05:02:35.551761408 2402.017324 0 31 False\n", + "3 2012-05-31 22:15:14.903856896 -12761.073176 0 81 False\n", + "4 2011-09-07 15:34:43.239835648 3642.817612 0 58 False" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "samples['depts'].head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We may not want to generate data for all tables in the dataset, rather for just one table. This is possible with SDV using the `sample` method. To use it we only need to specify the name of the table we want to synthesize and the row numbers to generate. In this case, the \"walmart\" data set has 3 tables: stores, features and depts.\n", + "\n", + "In the following example, we will generate 1000 rows of the \"features\" table." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'features': Date MarkDown1 Store IsHoliday \\\n", + " 0 2010-07-07 18:33:30.490365440 NaN 46 True \n", + " 1 2010-05-31 05:12:43.555245568 NaN 49 False \n", + " 2 2009-11-13 00:11:26.645775872 NaN 46 False \n", + " 3 2011-06-26 13:37:23.832223232 8480.251096 48 False \n", + " 4 2013-09-30 06:14:18.804104192 4586.374804 48 False \n", + " 5 2011-03-18 21:50:55.521650432 NaN 45 False \n", + " 6 2012-03-22 14:59:09.189622016 9747.212668 46 False \n", + " 7 2011-04-05 16:45:11.282816000 NaN 48 False \n", + " 8 2013-05-05 14:23:57.431958784 5430.557501 46 False \n", + " 9 2011-04-07 02:38:10.873993984 6735.454863 49 False \n", + " 10 2012-07-09 15:59:11.863023616 11818.737600 48 False \n", + " 11 2012-03-10 03:57:22.968428544 4603.635257 48 False \n", + " 12 2011-10-01 03:51:05.280273920 14596.976258 49 False \n", + " 13 2012-06-24 23:45:06.501742080 5816.163902 49 False \n", + " 14 2011-08-06 14:43:48.366935040 NaN 47 False \n", + " 15 2012-06-14 01:06:45.771771648 NaN 45 False \n", + " 16 2011-11-22 08:57:27.766619392 NaN 45 False \n", + " 17 2012-10-06 19:10:00.039407360 5887.384577 48 False \n", + " 18 2013-06-08 07:46:59.612567808 -5770.051172 48 False \n", + " 19 2012-05-27 11:08:47.184699648 6908.122060 46 False \n", + " 20 2010-10-19 21:42:09.919441408 NaN 46 False \n", + " 21 2013-01-01 17:57:15.273558784 6221.151492 48 False \n", + " 22 2012-07-13 14:40:51.181466112 1378.006687 47 False \n", + " 23 2012-02-13 18:37:30.161656320 NaN 48 False \n", + " 24 2010-11-27 09:56:34.540603392 NaN 46 False \n", + " 25 2011-05-21 03:49:37.053693440 NaN 49 False \n", + " 26 2010-06-12 20:31:32.339124224 NaN 49 True \n", + " 27 2011-08-12 21:10:37.406385152 NaN 47 False \n", + " 28 2012-02-07 03:31:37.558127360 NaN 47 False \n", + " 29 2011-05-28 22:39:00.441712128 NaN 48 False \n", + " .. ... ... ... ... 
\n", + " 970 2012-10-01 09:27:33.224137984 7367.091411 49 False \n", + " 971 2011-08-05 19:20:26.151251200 NaN 49 False \n", + " 972 2010-05-02 10:25:48.294526720 NaN 49 False \n", + " 973 2014-07-30 06:33:25.027884800 2869.286369 46 False \n", + " 974 2012-06-10 23:50:28.918937344 NaN 45 False \n", + " 975 2011-11-20 17:24:25.076207104 NaN 48 False \n", + " 976 2011-07-16 19:31:09.475053312 8465.620591 48 False \n", + " 977 2010-12-17 11:17:29.377529600 NaN 45 False \n", + " 978 2011-05-01 22:00:24.413698816 NaN 48 True \n", + " 979 2009-01-04 14:06:28.700405504 NaN 48 False \n", + " 980 2013-01-26 23:44:47.124111360 7711.715123 47 False \n", + " 981 2012-02-05 15:58:21.306519040 4771.829938 49 False \n", + " 982 2012-12-25 20:25:33.794860544 9052.115357 45 False \n", + " 983 2011-03-11 05:15:33.422215424 NaN 48 False \n", + " 984 2012-04-10 15:14:42.299491072 11469.698722 49 False \n", + " 985 2012-09-04 11:56:04.719035648 11985.730540 47 False \n", + " 986 2010-08-12 06:13:43.472920320 NaN 47 False \n", + " 987 2011-11-28 00:46:09.219015936 NaN 47 False \n", + " 988 2010-11-01 19:46:13.442969344 NaN 47 False \n", + " 989 2012-09-03 02:46:56.811711232 13448.706455 46 False \n", + " 990 2012-06-21 20:15:00.600609536 8515.644049 45 False \n", + " 991 2009-11-09 11:32:20.317559808 NaN 47 False \n", + " 992 2012-12-24 16:57:40.573141760 8785.735128 48 False \n", + " 993 2010-10-26 22:04:15.916646656 NaN 45 False \n", + " 994 2013-07-27 21:48:37.233461248 2218.444188 46 False \n", + " 995 2012-11-22 15:48:58.070053120 NaN 47 False \n", + " 996 2012-11-19 17:12:55.679613440 11318.264204 45 False \n", + " 997 2013-10-11 21:09:49.021492480 478.609430 47 False \n", + " 998 2013-01-23 23:17:20.493887232 14227.944357 49 False \n", + " 999 2012-06-23 00:01:59.594749696 9121.193148 46 False \n", + " \n", + " MarkDown4 MarkDown3 Fuel_Price Unemployment Temperature \\\n", + " 0 NaN NaN 3.246268 6.712596 72.778272 \n", + " 1 NaN NaN 2.961909 8.775273 31.911285 \n", + " 2 NaN NaN 2.474694 8.693548 36.996590 \n", + " 3 2990.430314 450.886030 3.180972 7.482533 82.689207 \n", + " 4 1300.901307 1038.142954 3.979948 3.842990 61.626723 \n", + " 5 NaN NaN 3.190374 9.185961 43.775213 \n", + " 6 5068.165070 4033.503158 3.436991 10.002234 44.980268 \n", + " 7 2639.207053 NaN 3.126477 8.921070 60.901225 \n", + " 8 -1515.448599 1859.931318 4.101750 10.478649 81.003354 \n", + " 9 NaN -4013.843645 3.200662 8.549170 46.332113 \n", + " 10 6983.561124 9246.461293 3.637923 6.940256 66.208533 \n", + " 11 NaN 1872.845519 3.535785 5.360209 76.842613 \n", + " 12 5967.615987 -2084.078449 3.314418 NaN 73.216988 \n", + " 13 4181.309429 -7382.929438 3.170917 7.209879 27.887139 \n", + " 14 NaN NaN 4.084598 10.976241 75.512964 \n", + " 15 6522.458652 -2969.500556 3.592465 6.778646 45.304799 \n", + " 16 NaN NaN 3.691225 9.169614 61.590909 \n", + " 17 2195.064778 10031.652043 3.790657 5.655427 38.507717 \n", + " 18 -2522.448316 4530.988378 3.349566 4.743290 36.629555 \n", + " 19 NaN 3416.997827 3.641454 8.177166 89.916452 \n", + " 20 NaN NaN 3.075063 5.229354 61.706563 \n", + " 21 4647.376520 3247.164388 4.162684 5.033966 37.962542 \n", + " 22 NaN NaN 3.384827 4.921576 54.039707 \n", + " 23 NaN 7039.242263 3.187745 9.768689 74.453345 \n", + " 24 NaN NaN 3.045024 6.954771 63.069450 \n", + " 25 NaN NaN 4.164028 8.876290 71.070532 \n", + " 26 NaN NaN 2.719009 7.230035 57.364249 \n", + " 27 NaN NaN 3.257133 9.345917 68.652152 \n", + " 28 NaN -4190.690106 3.548902 9.494165 89.145578 \n", + " 29 NaN NaN 3.316924 6.697705 77.375684 \n", + " 
.. ... ... ... ... ... \n", + " 970 2712.006496 -14.780688 3.618468 8.148930 30.483065 \n", + " 971 NaN NaN 3.012630 6.326113 47.975596 \n", + " 972 NaN NaN 2.841184 10.269877 81.946843 \n", + " 973 1937.104665 -1511.898278 4.759412 7.778287 30.662510 \n", + " 974 5253.689129 329.011494 3.343722 7.398721 21.129117 \n", + " 975 NaN 4930.780630 3.506178 9.701808 110.523151 \n", + " 976 5293.404197 -3121.976568 2.800094 5.843644 32.442600 \n", + " 977 9805.271519 NaN 3.136993 7.623560 72.572211 \n", + " 978 NaN NaN 2.714601 7.234621 52.291714 \n", + " 979 NaN NaN 2.043747 8.744450 46.628272 \n", + " 980 5415.328424 -4809.140910 3.287617 8.354515 61.384200 \n", + " 981 NaN NaN 3.337297 6.142083 36.829644 \n", + " 982 837.730915 4072.518501 3.894026 12.494251 94.673851 \n", + " 983 NaN NaN 3.420209 7.090230 47.694700 \n", + " 984 5134.594286 357.279124 3.325391 4.539604 37.460031 \n", + " 985 NaN 4056.157461 3.685698 7.230368 43.096815 \n", + " 986 3953.583585 805.676045 3.044291 7.884751 34.438372 \n", + " 987 4511.412982 NaN 3.394697 6.834802 36.636557 \n", + " 988 NaN NaN 3.154329 8.363893 49.751330 \n", + " 989 NaN 2084.191209 3.615679 9.596049 75.533229 \n", + " 990 4457.772828 218.985398 3.440906 8.526339 61.864859 \n", + " 991 NaN NaN 3.247303 7.423321 59.019784 \n", + " 992 NaN NaN 3.458168 6.459322 48.584436 \n", + " 993 NaN NaN 2.608670 10.259658 82.758016 \n", + " 994 719.333117 6631.998410 4.368128 7.588005 54.580465 \n", + " 995 3003.501922 NaN 3.800082 NaN 100.143258 \n", + " 996 NaN -681.708699 4.140327 10.717719 46.213930 \n", + " 997 735.529994 3856.959812 3.921162 7.439993 59.473702 \n", + " 998 2557.407176 2865.135441 3.971143 9.540309 67.952407 \n", + " 999 671.086509 2328.798014 3.487042 5.988222 63.846379 \n", + " \n", + " MarkDown5 MarkDown2 CPI \n", + " 0 NaN NaN 172.666543 \n", + " 1 NaN NaN 119.214662 \n", + " 2 NaN NaN 143.990866 \n", + " 3 4354.651570 5019.468642 180.631779 \n", + " 4 -4772.570563 2947.674138 173.813790 \n", + " 5 NaN NaN 177.052581 \n", + " 6 7083.723804 3513.601560 185.921725 \n", + " 7 NaN -180.888893 153.966086 \n", + " 8 6265.833733 2401.076738 170.559516 \n", + " 9 3759.801559 4729.764492 200.039874 \n", + " 10 6567.512481 -1022.527044 150.002644 \n", + " 11 7765.142655 NaN 221.169299 \n", + " 12 NaN 3308.681495 NaN \n", + " 13 2826.442468 8385.400703 150.777231 \n", + " 14 NaN NaN 211.070762 \n", + " 15 6642.564583 NaN 154.828636 \n", + " 16 NaN NaN 138.236111 \n", + " 17 4118.446321 NaN 138.853628 \n", + " 18 1115.215425 -1526.587205 219.621502 \n", + " 19 4994.790024 NaN 197.086249 \n", + " 20 NaN NaN 143.384357 \n", + " 21 9934.092076 6538.282816 208.758478 \n", + " 22 396.251330 1170.864752 217.807128 \n", + " 23 1492.226506 NaN 144.261943 \n", + " 24 NaN NaN 198.046465 \n", + " 25 NaN NaN 152.795139 \n", + " 26 NaN NaN 192.354502 \n", + " 27 1804.646340 NaN 213.883682 \n", + " 28 NaN NaN 237.955017 \n", + " 29 NaN NaN 140.091500 \n", + " .. ... ... ... 
\n", + " 970 -208.184350 4224.751927 149.278618 \n", + " 971 NaN 3814.404768 167.707496 \n", + " 972 NaN NaN 115.688179 \n", + " 973 1199.031857 3033.592607 120.703360 \n", + " 974 NaN -351.113110 106.143265 \n", + " 975 NaN NaN 211.109816 \n", + " 976 3975.213749 NaN 152.402312 \n", + " 977 NaN NaN 121.333484 \n", + " 978 NaN 3607.788296 176.680147 \n", + " 979 NaN NaN 164.924902 \n", + " 980 5431.606087 7005.735955 208.516910 \n", + " 981 -285.659902 NaN 146.865904 \n", + " 982 5584.625275 2016.204322 165.913603 \n", + " 983 NaN NaN 136.591286 \n", + " 984 2929.718397 11730.460834 153.875137 \n", + " 985 4618.823214 1218.336997 191.245195 \n", + " 986 NaN 128.922827 209.229013 \n", + " 987 7757.467609 NaN 201.908861 \n", + " 988 NaN NaN 210.724201 \n", + " 989 5115.758017 1620.240724 149.542233 \n", + " 990 1789.565202 1350.171000 193.810848 \n", + " 991 NaN NaN 147.307937 \n", + " 992 -105.978005 NaN 229.688010 \n", + " 993 NaN NaN 217.839779 \n", + " 994 -907.055933 3253.948843 128.459172 \n", + " 995 NaN NaN NaN \n", + " 996 3043.948542 NaN 114.508859 \n", + " 997 1459.810600 -4003.468566 193.149902 \n", + " 998 12151.265277 -3099.180177 150.590590 \n", + " 999 2792.671569 NaN 201.355736 \n", + " \n", + " [1000 rows x 12 columns]}" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sdv.sample('features', 1000)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "SDV has tools to evaluate the synthetic data generated and compare them with the original data. To see more about the evaluation tools, [see the documentation](https://sdv-dev.github.io/SDV/api/sdv.evaluation.html#sdv.evaluation.evaluate)." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/5. Generate Metadata from Dataframes.ipynb b/docs/tutorials/04_Working_with_Metadata.ipynb similarity index 53% rename from examples/5. Generate Metadata from Dataframes.ipynb rename to docs/tutorials/04_Working_with_Metadata.ipynb index 5768ce084..5c79f8ebf 100644 --- a/examples/5. Generate Metadata from Dataframes.ipynb +++ b/docs/tutorials/04_Working_with_Metadata.ipynb @@ -1,21 +1,72 @@ { "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Working with Metadata\n", + "\n", + "In order to work with complex dataset structures you will need to pass additional information\n", + "about you data to SDV.\n", + "\n", + "This is done by using a Metadata.\n", + "\n", + "Let's have a quick look at how to do it.\n", + "\n", + "## Load demo data\n", + "\n", + "We will load the demo dataset included in SDV for the rest of the session." + ] + }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ - "from sdv import load_demo" + "from sdv import load_demo\n", + "\n", + "metadata, tables = load_demo(metadata=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A part from the tables dict, this is returning a Metadata object that contains all the information\n", + "about our dataset." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "Metadata\n", + " root_path: /home/xals/Projects/MIT/SDV/tutorials\n", + " tables: ['users', 'sessions', 'transactions']\n", + " relationships:\n", + " sessions.user_id -> users.user_id\n", + " transactions.session_id -> sessions.session_id" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "metadata, tables = load_demo(metadata=True)" + "metadata" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This Metadata can also be represented as a dict object:" ] }, { @@ -58,60 +109,22 @@ ] }, { - "cell_type": "code", - "execution_count": 4, + "cell_type": "markdown", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'users': user_id country gender age\n", - " 0 0 USA M 34\n", - " 1 1 UK F 23\n", - " 2 2 ES None 44\n", - " 3 3 UK M 22\n", - " 4 4 USA F 54\n", - " 5 5 DE M 57\n", - " 6 6 BG F 45\n", - " 7 7 ES None 41\n", - " 8 8 FR F 23\n", - " 9 9 UK None 30,\n", - " 'sessions': session_id user_id device os\n", - " 0 0 0 mobile android\n", - " 1 1 1 tablet ios\n", - " 2 2 1 tablet android\n", - " 3 3 2 mobile android\n", - " 4 4 4 mobile ios\n", - " 5 5 5 mobile android\n", - " 6 6 6 mobile ios\n", - " 7 7 6 tablet ios\n", - " 8 8 6 mobile ios\n", - " 9 9 8 tablet ios,\n", - " 'transactions': transaction_id session_id timestamp amount approved\n", - " 0 0 0 2019-01-01 12:34:32 100.0 True\n", - " 1 1 0 2019-01-01 12:42:21 55.3 True\n", - " 2 2 1 2019-01-07 17:23:11 79.5 True\n", - " 3 3 3 2019-01-10 11:08:57 112.1 False\n", - " 4 4 5 2019-01-10 21:54:08 110.0 False\n", - " 5 5 5 2019-01-11 11:21:20 76.3 True\n", - " 6 6 7 2019-01-22 14:44:10 89.5 True\n", - " 7 7 8 2019-01-23 10:14:09 132.1 False\n", - " 8 8 9 2019-01-27 16:09:17 68.0 True\n", - " 9 9 9 2019-01-29 12:10:48 99.9 True}" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], "source": [ - "tables" + "## Creating a Metadata object from scratch\n", + "\n", + "In this section we will have a look at how to create a Metadata object from scratch.\n", + "\n", + "The simplest way to do it is to populate it by passing in the tables of your dataset together\n", + "with some additional information.\n", + "\n", + "Let's start by creating an empty metadata object." ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -121,142 +134,92 @@ ] }, { - "cell_type": "code", - "execution_count": 6, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "new_meta.add_table('users', data=tables['users'], primary_key='user_id')" + "Now we can start by adding the parent table from our dataset, `users`,\n", + "indicating that the primary key is the field called `user_id`."
] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ - "new_meta.add_table('sessions', data=tables['sessions'], primary_key='session_id',\n", - " parent='users', foreign_key='user_id')" + "users_data = tables['users']\n", + "new_meta.add_table('users', data=users_data, primary_key='user_id')" ] }, { - "cell_type": "code", - "execution_count": 8, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "transactions_fields = {\n", - " 'timestamp': {\n", - " 'type': 'datetime',\n", - " 'format': '%Y-%m-%d'\n", - " }\n", - "}\n", - "new_meta.add_table('transactions', tables['transactions'], fields_metadata=transactions_fields,\n", - " primary_key='transaction_id', parent='sessions')" + "Next, let's add the sessions table, indicating that:\n", + "- The primary key is the field `session_id`\n", + "- The `users` table is the parent of this table\n", + "- The relationship between the `users` and `sessions` tables is created by the field called `user_id`." ] }, { "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'tables': {'users': {'fields': {'user_id': {'type': 'id',\n", - " 'subtype': 'integer'},\n", - " 'gender': {'type': 'categorical'},\n", - " 'age': {'type': 'numerical', 'subtype': 'integer'},\n", - " 'country': {'type': 'categorical'}},\n", - " 'primary_key': 'user_id'},\n", - " 'sessions': {'fields': {'user_id': {'type': 'id',\n", - " 'subtype': 'integer',\n", - " 'ref': {'table': 'users', 'field': 'user_id'}},\n", - " 'os': {'type': 'categorical'},\n", - " 'device': {'type': 'categorical'},\n", - " 'session_id': {'type': 'id', 'subtype': 'integer'}},\n", - " 'primary_key': 'session_id'},\n", - " 'transactions': {'fields': {'timestamp': {'type': 'datetime',\n", - " 'format': '%Y-%m-%d'},\n", - " 'approved': {'type': 'boolean'},\n", - " 'amount': {'type': 'numerical', 'subtype': 'float'},\n", - " 'transaction_id': {'type': 'id', 'subtype': 'integer'},\n", - " 'session_id': {'type': 'id',\n", - " 'subtype': 'integer',\n", - " 'ref': {'table': 'sessions', 'field': 'session_id'}}},\n", - " 'primary_key': 'transaction_id'}}}" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "new_meta.to_dict()" - ] - }, - { - "cell_type": "code", - "execution_count": 10, + "execution_count": 6, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "new_meta.to_dict() == metadata.to_dict()" + "sessions_data = tables['sessions']\n", + "new_meta.add_table(\n", + " 'sessions',\n", + " data=sessions_data,\n", + " primary_key='session_id',\n", + " parent='users',\n", + " foreign_key='user_id'\n", + ")" ] }, { - "cell_type": "code", - "execution_count": 11, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "new_meta.to_json('demo_metadata.json')" + "Finally, let's add the transactions table.\n", + "\n", + "In this case, we will pass some additional information to indicate that\n", + "the `timestamp` field should actually be parsed and interpreted as a\n", + "datetime field."
] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ - "loaded = Metadata('demo_metadata.json')" + "transactions_fields = {\n", + " 'timestamp': {\n", + " 'type': 'datetime',\n", + " 'format': '%Y-%m-%d'\n", + " }\n", + "}\n", + "transactions_data = tables['transactions']\n", + "new_meta.add_table(\n", + " 'transactions',\n", + " transactions_data,\n", + " fields_metadata=transactions_fields,\n", + " primary_key='transaction_id',\n", + " parent='sessions'\n", + ")" ] }, { - "cell_type": "code", - "execution_count": 13, + "cell_type": "markdown", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], "source": [ - "loaded.to_dict() == new_meta.to_dict()" + "Let's see what our Metadata looks like right now:" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -279,10 +242,10 @@ "\n", "users\n", "\n", - "user_id : id - integer\n", - "country : categorical\n", - "gender : categorical\n", - "age : numerical - integer\n", + "gender : categorical\n", + "age : numerical - integer\n", + "user_id : id - integer\n", + "country : categorical\n", "\n", "Primary key: user_id\n", "\n", @@ -292,10 +255,10 @@ "\n", "sessions\n", "\n", - "session_id : id - integer\n", - "user_id : id - integer\n", - "device : categorical\n", - "os : categorical\n", + "os : categorical\n", + "device : categorical\n", + "user_id : id - integer\n", + "session_id : id - integer\n", "\n", "Primary key: session_id\n", "Foreign key (users): user_id\n", @@ -304,7 +267,7 @@ "\n", "users->sessions\n", "\n", - "\n", + "\n", "   sessions.user_id -> users.user_id\n", "\n", "\n", @@ -313,11 +276,11 @@ "\n", "transactions\n", "\n", - "transaction_id : id - integer\n", - "session_id : id - integer\n", - "timestamp : datetime\n", - "amount : numerical - float\n", - "approved : boolean\n", + "timestamp : datetime\n", + "amount : numerical - float\n", + "session_id : id - integer\n", + "approved : boolean\n", + "transaction_id : id - integer\n", "\n", "Primary key: transaction_id\n", "Foreign key (sessions): session_id\n", @@ -326,114 +289,137 @@ "\n", "sessions->transactions\n", "\n", - "\n", + "\n", "   transactions.session_id -> sessions.session_id\n", "\n", "\n", "\n" ], "text/plain": [ - "" + "" ] }, - "execution_count": 14, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "metadata.visualize()" + "new_meta.visualize()" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 9, "metadata": {}, "outputs": [ { "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "Metadata\n", - "\n", - "\n", - "\n", - "users\n", - "\n", - "users\n", - "\n", - "user_id : id - integer\n", - "gender : categorical\n", - "age : numerical - integer\n", - "country : categorical\n", - "\n", - "Primary key: user_id\n", - "\n", - "\n", - "\n", - "sessions\n", - "\n", - "sessions\n", - "\n", - "user_id : id - integer\n", - "os : categorical\n", - "device : categorical\n", - "session_id : id - integer\n", - "\n", - "Primary key: session_id\n", - "Foreign key (users): user_id\n", - "\n", - "\n", - "\n", - "users->sessions\n", - "\n", - "\n", - "   sessions.user_id -> users.user_id\n", - "\n", - "\n", - "\n", - "transactions\n", - "\n", - "transactions\n", - "\n", - "timestamp : datetime\n", - "approved : boolean\n", - "amount : numerical - 
float\n", - "transaction_id : id - integer\n", - "session_id : id - integer\n", - "\n", - "Primary key: transaction_id\n", - "Foreign key (sessions): session_id\n", - "\n", - "\n", - "\n", - "sessions->transactions\n", - "\n", - "\n", - "   transactions.session_id -> sessions.session_id\n", - "\n", - "\n", - "\n" - ], "text/plain": [ - "" + "{'tables': {'users': {'fields': {'gender': {'type': 'categorical'},\n", + " 'age': {'type': 'numerical', 'subtype': 'integer'},\n", + " 'user_id': {'type': 'id', 'subtype': 'integer'},\n", + " 'country': {'type': 'categorical'}},\n", + " 'primary_key': 'user_id'},\n", + " 'sessions': {'fields': {'os': {'type': 'categorical'},\n", + " 'device': {'type': 'categorical'},\n", + " 'user_id': {'type': 'id',\n", + " 'subtype': 'integer',\n", + " 'ref': {'table': 'users', 'field': 'user_id'}},\n", + " 'session_id': {'type': 'id', 'subtype': 'integer'}},\n", + " 'primary_key': 'session_id'},\n", + " 'transactions': {'fields': {'timestamp': {'type': 'datetime',\n", + " 'format': '%Y-%m-%d'},\n", + " 'amount': {'type': 'numerical', 'subtype': 'float'},\n", + " 'session_id': {'type': 'id',\n", + " 'subtype': 'integer',\n", + " 'ref': {'table': 'sessions', 'field': 'session_id'}},\n", + " 'approved': {'type': 'boolean'},\n", + " 'transaction_id': {'type': 'id', 'subtype': 'integer'}},\n", + " 'primary_key': 'transaction_id'}}}" ] }, - "execution_count": 15, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "new_meta.visualize()" + "new_meta.to_dict()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Pretty similar to the original metadata, right?" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "new_meta.to_dict() == metadata.to_dict()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Saving the Metadata as a JSON file\n", + "\n", + "The Metadata object can also be saved as a JSON file, which later on we can load:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "new_meta.to_json('demo_metadata.json')" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "loaded = Metadata('demo_metadata.json')" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "loaded.to_dict() == new_meta.to_dict()" ] } ], diff --git a/examples/0. Quickstart - README.ipynb b/examples/0. Quickstart - README.ipynb deleted file mode 100644 index fa7ba905b..000000000 --- a/examples/0. 
Quickstart - README.ipynb +++ /dev/null @@ -1,590 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "from sdv import load_demo" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "metadata, tables = load_demo(metadata=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'tables': {'users': {'primary_key': 'user_id',\n", - " 'fields': {'user_id': {'type': 'id', 'subtype': 'integer'},\n", - " 'country': {'type': 'categorical'},\n", - " 'gender': {'type': 'categorical'},\n", - " 'age': {'type': 'numerical', 'subtype': 'integer'}}},\n", - " 'sessions': {'primary_key': 'session_id',\n", - " 'fields': {'session_id': {'type': 'id', 'subtype': 'integer'},\n", - " 'user_id': {'ref': {'field': 'user_id', 'table': 'users'},\n", - " 'type': 'id',\n", - " 'subtype': 'integer'},\n", - " 'device': {'type': 'categorical'},\n", - " 'os': {'type': 'categorical'}}},\n", - " 'transactions': {'primary_key': 'transaction_id',\n", - " 'fields': {'transaction_id': {'type': 'id', 'subtype': 'integer'},\n", - " 'session_id': {'ref': {'field': 'session_id', 'table': 'sessions'},\n", - " 'type': 'id',\n", - " 'subtype': 'integer'},\n", - " 'timestamp': {'type': 'datetime', 'format': '%Y-%m-%d'},\n", - " 'amount': {'type': 'numerical', 'subtype': 'float'},\n", - " 'approved': {'type': 'boolean'}}}}}" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "metadata.to_dict()" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'users': user_id country gender age\n", - " 0 0 USA M 34\n", - " 1 1 UK F 23\n", - " 2 2 ES None 44\n", - " 3 3 UK M 22\n", - " 4 4 USA F 54\n", - " 5 5 DE M 57\n", - " 6 6 BG F 45\n", - " 7 7 ES None 41\n", - " 8 8 FR F 23\n", - " 9 9 UK None 30,\n", - " 'sessions': session_id user_id device os\n", - " 0 0 0 mobile android\n", - " 1 1 1 tablet ios\n", - " 2 2 1 tablet android\n", - " 3 3 2 mobile android\n", - " 4 4 4 mobile ios\n", - " 5 5 5 mobile android\n", - " 6 6 6 mobile ios\n", - " 7 7 6 tablet ios\n", - " 8 8 6 mobile ios\n", - " 9 9 8 tablet ios,\n", - " 'transactions': transaction_id session_id timestamp amount approved\n", - " 0 0 0 2019-01-01 12:34:32 100.0 True\n", - " 1 1 0 2019-01-01 12:42:21 55.3 True\n", - " 2 2 1 2019-01-07 17:23:11 79.5 True\n", - " 3 3 3 2019-01-10 11:08:57 112.1 False\n", - " 4 4 5 2019-01-10 21:54:08 110.0 False\n", - " 5 5 5 2019-01-11 11:21:20 76.3 True\n", - " 6 6 7 2019-01-22 14:44:10 89.5 True\n", - " 7 7 8 2019-01-23 10:14:09 132.1 False\n", - " 8 8 9 2019-01-27 16:09:17 68.0 True\n", - " 9 9 9 2019-01-29 12:10:48 99.9 True}" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tables" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "Metadata\n", - "\n", - "\n", - "\n", - "users\n", - "\n", - "users\n", - "\n", - "user_id : id - integer\n", - "country : categorical\n", - "gender : categorical\n", - "age : numerical - integer\n", - "\n", - "Primary key: user_id\n", - "\n", - "\n", - "\n", - "sessions\n", - "\n", - "sessions\n", - "\n", - "session_id : id - integer\n", - "user_id : id - integer\n", - "device : 
categorical\n", - "os : categorical\n", - "\n", - "Primary key: session_id\n", - "Foreign key (users): user_id\n", - "\n", - "\n", - "\n", - "users->sessions\n", - "\n", - "\n", - "   sessions.user_id -> users.user_id\n", - "\n", - "\n", - "\n", - "transactions\n", - "\n", - "transactions\n", - "\n", - "transaction_id : id - integer\n", - "session_id : id - integer\n", - "timestamp : datetime\n", - "amount : numerical - float\n", - "approved : boolean\n", - "\n", - "Primary key: transaction_id\n", - "Foreign key (sessions): session_id\n", - "\n", - "\n", - "\n", - "sessions->transactions\n", - "\n", - "\n", - "   transactions.session_id -> sessions.session_id\n", - "\n", - "\n", - "\n" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "metadata.visualize()" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2020-06-25 16:03:03,773 - INFO - modeler - Modeling users\n", - "2020-06-25 16:03:03,778 - INFO - modeler - Modeling sessions\n", - "2020-06-25 16:03:03,783 - INFO - modeler - Modeling transactions\n", - "2020-06-25 16:03:03,791 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-06-25 16:03:03,803 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-06-25 16:03:03,812 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-06-25 16:03:03,819 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-06-25 16:03:03,825 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-06-25 16:03:03,834 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-06-25 16:03:03,840 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-06-25 16:03:03,845 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-06-25 16:03:03,861 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-06-25 16:03:03,891 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-06-25 16:03:03,909 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-06-25 16:03:03,935 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-06-25 16:03:03,954 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-06-25 16:03:03,972 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-06-25 16:03:03,990 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-06-25 16:03:04,018 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-06-25 16:03:04,141 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-06-25 16:03:04,326 - INFO - modeler - Modeling Complete\n" - ] - } - ], - "source": [ - "from sdv import SDV\n", - "\n", - "sdv = SDV()\n", - "sdv.fit(metadata, tables)" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "name": "stdout", - 
"output_type": "stream", - "text": [ - "> \u001b[0;32m/home/xals/Projects/MIT/SDV/sdv/models/copulas.py\u001b[0m(102)\u001b[0;36m_prepare_sampled_covariance\u001b[0;34m()\u001b[0m\n", - "\u001b[0;32m 101 \u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mipdb\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mipdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\u001b[0;32m--> 102 \u001b[0;31m \u001b[0mcovariance\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msquare_matrix\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\u001b[0;32m 103 \u001b[0;31m \u001b[0mcovariance\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mcovariance\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mT\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0midentity\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mcovariance\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\n", - "ipdb> covariance\n", - "[[1.726615061615442], [1.0454201815611962e-17, 0.8202611657476363], [0.906087951592798, 0.0, 0.9063498818810053], [0.0, 0.0, 0.0, 1.1920928955078125e-07], [0.0, 0.0, 0.0, 0.0, 1.1920928955078125e-07], [0.0, 0.0, 0.0, 0.0, 0.0, 1.1920928955078125e-07], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1920928955078125e-07], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1920928955078125e-07], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1920928955078125e-07], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1920928955078125e-07], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1920928955078125e-07], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1920928955078125e-07], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1920928955078125e-07], [1.7211659999181501, 0.09407387578039639, 0.9062265705680096, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.7265698690747444], [0.1962595829950089, -0.4101010550927784, 0.9064318910717786, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.15363369110349148, 1.7265890069169954], [1.6501160558524515, 0.34588621287899607, 0.9060859813924124, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.6848330864411356, 0.089532714183855, 1.7265426703918063], [0.19506154837332199, -0.41017105995034014, 0.9064258143689862, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.15319193719168775, 1.7265715342431274, 0.08933484748381948, 1.726601450683982], [1.558118079154194, -0.49788243931513426, 0.9061755927291784, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.4967074003515342, 0.5906071587381734, 1.287493411569124, 0.5910547002051788, 1.7265085917492808], [0.19622238674880996, -0.4101312643965461, 0.9063303749716511, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.15392250979132835, 1.7265715878572472, 0.08897912790316595, 1.726633522580159, 0.5911723429044428, 1.7265328687279475], [-1.426683586477038, 0.6340766049086322, -0.9064943110193431, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.350356266672644, 
-0.7728329555360014, -1.1105383964628412, -0.7729561856331211, -1.7046109781197039, -0.7727141956567691, 1.7265200912725471], [0.19613513187055676, -0.4101447540378119, 0.9060999824108501, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1536514208926178, 1.726564276607309, 0.08894497870076296, 1.7266086430031098, 0.5907859264695859, 1.7266240785900109, -0.7726518888264287, 1.7265194550742498], [-0.19600331514505734, 0.4100687094013292, -0.9060002638533913, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.15335524134834988, -1.7265840276560094, -0.08909093050930667, -1.7265435834346627, -0.5903898337829075, -1.7265249406712062, 0.7728873327661734, -1.7266499172273617, 1.726589663707634]]\n", - "ipdb> n\n", - "> \u001b[0;32m/home/xals/Projects/MIT/SDV/sdv/models/copulas.py\u001b[0m(103)\u001b[0;36m_prepare_sampled_covariance\u001b[0;34m()\u001b[0m\n", - "\u001b[0;32m 102 \u001b[0;31m \u001b[0mcovariance\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msquare_matrix\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\u001b[0;32m--> 103 \u001b[0;31m \u001b[0mcovariance\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mcovariance\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mT\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0midentity\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mcovariance\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\u001b[0;32m 104 \u001b[0;31m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\n", - "ipdb> covariance\n", - "array([[ 1.72661506e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 1.04542018e-17, 8.20261166e-01, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 9.06087952e-01, 0.00000000e+00, 9.06349882e-01,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 1.19209290e-07, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 
0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 1.19209290e-07, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 1.19209290e-07,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 1.19209290e-07, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 1.19209290e-07, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 1.19209290e-07,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 1.19209290e-07, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 1.19209290e-07, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 1.19209290e-07,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 0.00000000e+00, 0.00000000e+00, 
0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 1.19209290e-07, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 1.72116600e+00, 9.40738758e-02, 9.06226571e-01,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 1.72656987e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 1.96259583e-01, -4.10101055e-01, 9.06431891e-01,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 1.53633691e-01, 1.72658901e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 1.65011606e+00, 3.45886213e-01, 9.06085981e-01,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 1.68483309e+00, 8.95327142e-02,\n", - " 1.72654267e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 1.95061548e-01, -4.10171060e-01, 9.06425814e-01,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 1.53191937e-01, 1.72657153e+00,\n", - " 8.93348475e-02, 1.72660145e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 1.55811808e+00, -4.97882439e-01, 9.06175593e-01,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 1.49670740e+00, 5.90607159e-01,\n", - " 1.28749341e+00, 5.91054700e-01, 1.72650859e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 1.96222387e-01, -4.10131264e-01, 9.06330375e-01,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 1.53922510e-01, 1.72657159e+00,\n", - " 8.89791279e-02, 1.72663352e+00, 5.91172343e-01,\n", - " 1.72653287e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [-1.42668359e+00, 6.34076605e-01, -9.06494311e-01,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, -1.35035627e+00, -7.72832956e-01,\n", - " -1.11053840e+00, -7.72956186e-01, -1.70461098e+00,\n", - " -7.72714196e-01, 1.72652009e+00, 0.00000000e+00,\n", - " 0.00000000e+00],\n", - " [ 1.96135132e-01, -4.10144754e-01, 9.06099982e-01,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - 
" 0.00000000e+00, 1.53651421e-01, 1.72656428e+00,\n", - " 8.89449787e-02, 1.72660864e+00, 5.90785926e-01,\n", - " 1.72662408e+00, -7.72651889e-01, 1.72651946e+00,\n", - " 0.00000000e+00],\n", - " [-1.96003315e-01, 4.10068709e-01, -9.06000264e-01,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", - " 0.00000000e+00, -1.53355241e-01, -1.72658403e+00,\n", - " -8.90909305e-02, -1.72654358e+00, -5.90389834e-01,\n", - " -1.72652494e+00, 7.72887333e-01, -1.72664992e+00,\n", - " 1.72658966e+00]])\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "ipdb> ll\n", - "\u001b[1;32m 90 \u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_prepare_sampled_covariance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcovariance\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[1;32m 91 \u001b[0m \"\"\"Prepare a covariance matrix.\n", - "\u001b[1;32m 92 \u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[1;32m 93 \u001b[0m \u001b[0mArgs\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[1;32m 94 \u001b[0m \u001b[0mcovariance\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[1;32m 95 \u001b[0m \u001b[0mcovariance\u001b[0m \u001b[0mafter\u001b[0m \u001b[0munflattening\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0mparameters\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[1;32m 96 \u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[1;32m 97 \u001b[0m \u001b[0mResult\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[1;32m 98 \u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[1;32m 99 \u001b[0m \u001b[0msymmetric\u001b[0m \u001b[0mPositive\u001b[0m \u001b[0msemi\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mdefinite\u001b[0m \u001b[0mmatrix\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[1;32m 100 \u001b[0m \"\"\"\n", - "\u001b[1;32m 101 \u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mipdb\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mipdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[1;32m 102 \u001b[0m \u001b[0mcovariance\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msquare_matrix\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m--> 103 \u001b[0;31m \u001b[0mcovariance\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mcovariance\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mT\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0midentity\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m 
\u001b[0;34m*\u001b[0m \u001b[0mcovariance\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\u001b[1;32m 104 \u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[1;32m 105 \u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mcheck_matrix_symmetric_positive_definite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[1;32m 106 \u001b[0m \u001b[0mcovariance\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmake_positive_definite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[1;32m 107 \u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[1;32m 108 \u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mcovariance\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtolist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[1;32m 109 \u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\n", - "ipdb> n\n", - "> \u001b[0;32m/home/xals/Projects/MIT/SDV/sdv/models/copulas.py\u001b[0m(105)\u001b[0;36m_prepare_sampled_covariance\u001b[0;34m()\u001b[0m\n", - "\u001b[0;32m 104 \u001b[0;31m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\u001b[0;32m--> 105 \u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mcheck_matrix_symmetric_positive_definite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\u001b[0;32m 106 \u001b[0;31m \u001b[0mcovariance\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmake_positive_definite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\n", - "ipdb> n\n", - "> \u001b[0;32m/home/xals/Projects/MIT/SDV/sdv/models/copulas.py\u001b[0m(106)\u001b[0;36m_prepare_sampled_covariance\u001b[0;34m()\u001b[0m\n", - "\u001b[0;32m 105 \u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mcheck_matrix_symmetric_positive_definite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\u001b[0;32m--> 106 \u001b[0;31m \u001b[0mcovariance\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmake_positive_definite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\u001b[0;32m 107 \u001b[0;31m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\n", - "ipdb> n\n", - "> \u001b[0;32m/home/xals/Projects/MIT/SDV/sdv/models/copulas.py\u001b[0m(108)\u001b[0;36m_prepare_sampled_covariance\u001b[0;34m()\u001b[0m\n", - "\u001b[0;32m 107 \u001b[0;31m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\u001b[0;32m--> 108 \u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mcovariance\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtolist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\u001b[0;32m 109 \u001b[0;31m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\n", - "ipdb> n\n", - "--Return--\n", - "[[1.726928289354223, -6.284976501719013e-05, 0.9060437009442075, 1.1799638567951807e-18, -1.7530700091068226e-20, 5.17606189171064e-21, ...], [-6.284976501719013e-05, 0.8203273652660589, 
-3.601422526398147e-05, 5.414945862916836e-17, 3.245206248848917e-20, -1.2006964972373577e-20, ...], [0.9060437009442075, -3.601422526398147e-05, 0.9064526869220875, 1.933376368463285e-17, -2.6655267441977074e-20, 2.5829833603018656e-21, ...], [1.1799638567951807e-18, 5.414945862916836e-17, 1.933376368463285e-17, 1.192092915444104e-07, -1.1365716705948445e-20, 4.7277921079772496e-21, ...], [-1.7530700091068226e-20, 3.245206248848917e-20, -2.6655267441977074e-20, -1.1365716705948445e-20, 1.1920929154436114e-07, 5.437058680680646e-21, ...], [5.17606189171064e-21, -1.2006964972373577e-20, 2.5829833603018656e-21, 4.7277921079772496e-21, 5.437058680680646e-21, 1.1920929154437684e-07, ...], ...]\n", - "> \u001b[0;32m/home/xals/Projects/MIT/SDV/sdv/models/copulas.py\u001b[0m(108)\u001b[0;36m_prepare_sampled_covariance\u001b[0;34m()\u001b[0m\n", - "\u001b[0;32m 107 \u001b[0;31m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\u001b[0;32m--> 108 \u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mcovariance\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtolist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\u001b[0;32m 109 \u001b[0;31m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\n", - "ipdb> n\n", - "> \u001b[0;32m/home/xals/Projects/MIT/SDV/sdv/models/copulas.py\u001b[0m(145)\u001b[0;36m_unflatten_gaussian_copula\u001b[0;34m()\u001b[0m\n", - "\u001b[0;32m 144 \u001b[0;31m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\u001b[0;32m--> 145 \u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mmodel_parameters\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\u001b[0;32m 146 \u001b[0;31m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\n", - "ipdb> c\n", - "> \u001b[0;32m/home/xals/Projects/MIT/SDV/sdv/models/copulas.py\u001b[0m(102)\u001b[0;36m_prepare_sampled_covariance\u001b[0;34m()\u001b[0m\n", - "\u001b[0;32m 101 \u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mipdb\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mipdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\u001b[0;32m--> 102 \u001b[0;31m \u001b[0mcovariance\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msquare_matrix\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\u001b[0;32m 103 \u001b[0;31m \u001b[0mcovariance\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mcovariance\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mT\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0midentity\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mcovariance\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\n", - "ipdb> covariance\n", - "[[1.9735777830655135], [1.0081332489157226e-18, 0.07922245304646469], [1.8944415499749927, 0.0, 1.8945678304393692], [0.0, 0.0, 0.0, 1.1920928955078125e-07], [0.0, 0.0, 0.0, 0.0, 1.1920928955078125e-07], [0.0, 0.0, 0.0, 0.0, 0.0, 1.1920928955078125e-07], [0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 1.1920928955078125e-07], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1920928955078125e-07], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1920928955078125e-07], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1920928955078125e-07], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1920928955078125e-07], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1920928955078125e-07], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1920928955078125e-07], [1.9731403682292472, 0.009080369609414676, 1.8945052811636178, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.9736633430154003], [1.8257571570695899, -0.03951915463265521, 1.8943818752505202, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.8223719529889824, 1.9735960424607375], [1.9662149071445283, 0.03332947346115095, 1.8943702463439986, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.9696074413161604, 1.8155029069801496, 1.9735786685725154], [1.8259699330878305, -0.039591165816158624, 1.8946306798075692, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.8222344112809918, 1.973660492989475, 1.8155381885055653, 1.9736509880140303], [1.9574634628424201, -0.0480675268053235, 1.8944800309305712, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.9514242514903646, 1.8641779050506904, 1.931397843569703, 1.8644699021432583, 1.97360470087791], [1.8261599075098556, -0.039602178471882965, 1.8944795379664088, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.8221689799836889, 1.9736216215949314, 1.8159480875771647, 1.9736090474517491, 1.8643658311250595, 1.9735713191144897], [-1.9447451350012899, 0.06124906062526714, -1.8941152019092415, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.9372479886935285, -1.8815955695505688, -1.9143447992000526, -1.8819069289711103, -1.9715806948911605, -1.8818122561418829, 1.9735728616891475], [1.826265421380993, -0.03969242272676249, 1.8947570439794252, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.821544416662928, 1.9736431265524037, 1.815795088268001, 1.9735745553644184, 1.8640973954156812, 1.9736683290141885, -1.8818260533698774, 1.9736371138967417], [-1.8260997940060133, 0.039593106594922045, -1.8945455094034567, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.8221756212795035, -1.9736399789373342, -1.8157747174291041, -1.973587031691748, -1.863968052730401, -1.9737028348459618, 1.8816417352261723, -1.9736384778314904, 1.9735905131918667]]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "ipdb> c\n", - "> \u001b[0;32m/home/xals/Projects/MIT/SDV/sdv/models/copulas.py\u001b[0m(102)\u001b[0;36m_prepare_sampled_covariance\u001b[0;34m()\u001b[0m\n", - "\u001b[0;32m 101 \u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mipdb\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mipdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\u001b[0;32m--> 102 \u001b[0;31m \u001b[0mcovariance\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msquare_matrix\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\u001b[0;32m 103 \u001b[0;31m \u001b[0mcovariance\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mcovariance\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mT\u001b[0m \u001b[0;34m-\u001b[0m 
\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0midentity\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mcovariance\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\n", - "ipdb> covariance\n", - "[[1.7798122161501428], [8.418645096312187e-18, 0.6606961152453757], [1.1193765416016843, 0.0, 1.119151425141485], [0.0, 0.0, 0.0, 1.1920928955078125e-07], [0.0, 0.0, 0.0, 0.0, 1.1920928955078125e-07], [0.0, 0.0, 0.0, 0.0, 0.0, 1.1920928955078125e-07], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1920928955078125e-07], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1920928955078125e-07], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1920928955078125e-07], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1920928955078125e-07], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1920928955078125e-07], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1920928955078125e-07], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1920928955078125e-07], [1.7754388866407922, 0.07580182865822005, 1.1193005781250343, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.7797523138925522], [0.5470801791882579, -0.3302484465683945, 1.1192475956082857, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5129529949615216, 1.779801879956687], [1.718206887546536, 0.2784760244595479, 1.119072525637323, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.746240556959611, 0.4609048552481457, 1.7798064162607745], [0.5470062693558146, -0.3303697163214912, 1.1192392868938212, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5128535772731618, 1.7797855289798796, 0.4614315710843969, 1.7798640319936492], [1.6441999827502363, -0.4009517141082254, 1.1191904874368754, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.5946523925660905, 0.8650486386888843, 1.4262448365733487, 0.8650813439515379, 1.7798317994703388], [0.546920059994199, -0.3302738771226835, 1.1191467637807095, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5131699084433272, 1.7797358524773224, 0.4613957209455638, 1.7798293469122717, 0.865178931907089, 1.779727156681622], [-1.5381861979492537, 0.5108352793295525, -1.119103691856403, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.4767659633854602, -1.011712169525542, -1.2837980435462297, -1.0115506970319363, -1.7621372594453895, -1.0115631089643444, 1.7798513072348816], [0.547088849554847, -0.3303011887943328, 1.119126575120807, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5129183354605175, 1.779756666332634, 0.46107358199142606, 1.7797463173063606, 0.8648122390861165, 1.7797904461596725, -1.0117639532832405, 1.7798273537990834], [-0.5470493658634797, 0.33018152099185216, -1.119089754226156, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.5129960724566422, -1.7798135506798376, -0.4609291114365542, -1.779840914364154, -0.8650801853648311, -1.7798114433503558, 1.0116215056600828, -1.7798428944094702, 1.7798232994357222]]\n", - "ipdb> c\n", - "> \u001b[0;32m/home/xals/Projects/MIT/SDV/sdv/models/copulas.py\u001b[0m(102)\u001b[0;36m_prepare_sampled_covariance\u001b[0;34m()\u001b[0m\n", - "\u001b[0;32m 101 \u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mipdb\u001b[0m\u001b[0;34m;\u001b[0m 
\u001b[0mipdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\u001b[0;32m--> 102 \u001b[0;31m \u001b[0mcovariance\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msquare_matrix\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\u001b[0;32m 103 \u001b[0;31m \u001b[0mcovariance\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mcovariance\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mT\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0midentity\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mcovariance\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0m\n", - "ipdb> covariance\n", - "[[None], [None, None], [None, None, None], [None, None, None, None]]\n", - "ipdb> q\n" - ] - }, - { - "ename": "BdbQuit", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mBdbQuit\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0msdv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample_all\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m~/Projects/MIT/SDV/sdv/sdv.py\u001b[0m in \u001b[0;36msample_all\u001b[0;34m(self, num_rows, reset_primary_keys)\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mNotFittedError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'SDV instance has not been fitted'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 115\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msampler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample_all\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnum_rows\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreset_primary_keys\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mreset_primary_keys\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 116\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 117\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0msave\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Projects/MIT/SDV/sdv/sampler.py\u001b[0m in \u001b[0;36msample_all\u001b[0;34m(self, num_rows, reset_primary_keys)\u001b[0m\n\u001b[1;32m 265\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mtable\u001b[0m \u001b[0;32min\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetadata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_tables\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 266\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetadata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_parents\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtable\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 267\u001b[0;31m \u001b[0msampled_data\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_rows\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 268\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 269\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0msampled_data\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Projects/MIT/SDV/sdv/sampler.py\u001b[0m in \u001b[0;36msample\u001b[0;34m(self, table_name, num_rows, reset_primary_keys, sample_children)\u001b[0m\n\u001b[1;32m 228\u001b[0m }\n\u001b[1;32m 229\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 230\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_sample_children\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtable_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msampled_data\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 231\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 232\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mtable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msampled_rows\u001b[0m \u001b[0;32min\u001b[0m \u001b[0msampled_data\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Projects/MIT/SDV/sdv/sampler.py\u001b[0m in \u001b[0;36m_sample_children\u001b[0;34m(self, table_name, sampled)\u001b[0m\n\u001b[1;32m 162\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mchild_name\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetadata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_children\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtable_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 163\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrow\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtable_rows\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miterrows\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 164\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_sample_table\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mchild_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtable_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrow\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msampled\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 165\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 166\u001b[0m \u001b[0;32mdef\u001b[0m 
\u001b[0m_sample_table\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtable_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparent_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparent_row\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msampled\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Projects/MIT/SDV/sdv/sampler.py\u001b[0m in \u001b[0;36m_sample_table\u001b[0;34m(self, table_name, parent_name, parent_row, sampled)\u001b[0m\n\u001b[1;32m 183\u001b[0m \u001b[0msampled\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtable_name\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mprevious\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msampled_rows\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreset_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdrop\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 184\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 185\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_sample_children\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtable_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msampled\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 186\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 187\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0msample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtable_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_rows\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreset_primary_keys\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msample_children\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Projects/MIT/SDV/sdv/sampler.py\u001b[0m in \u001b[0;36m_sample_children\u001b[0;34m(self, table_name, sampled)\u001b[0m\n\u001b[1;32m 162\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mchild_name\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetadata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_children\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtable_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 163\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrow\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtable_rows\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miterrows\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 164\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_sample_table\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mchild_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtable_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrow\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msampled\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 165\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 166\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_sample_table\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mtable_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparent_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparent_row\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msampled\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Projects/MIT/SDV/sdv/sampler.py\u001b[0m in \u001b[0;36m_sample_table\u001b[0;34m(self, table_name, parent_name, parent_row, sampled)\u001b[0m\n\u001b[1;32m 168\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 169\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 170\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_parameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mextension\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 171\u001b[0m \u001b[0mnum_rows\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mround\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mextension\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'child_rows'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 172\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Projects/MIT/SDV/sdv/models/copulas.py\u001b[0m in \u001b[0;36mset_parameters\u001b[0;34m(self, parameters)\u001b[0m\n\u001b[1;32m 159\u001b[0m \u001b[0mparameters\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msetdefault\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'distribution'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdistribution\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 160\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 161\u001b[0;31m \u001b[0mparameters\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_unflatten_gaussian_copula\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 162\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mparam\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mparameters\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'univariates'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 163\u001b[0m \u001b[0mparam\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msetdefault\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'type'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdistribution\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Projects/MIT/SDV/sdv/models/copulas.py\u001b[0m in \u001b[0;36m_unflatten_gaussian_copula\u001b[0;34m(self, model_parameters)\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 142\u001b[0m \u001b[0mcovariance\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel_parameters\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'covariance'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 143\u001b[0;31m 
\u001b[0mmodel_parameters\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'covariance'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_prepare_sampled_covariance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 144\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 145\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmodel_parameters\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Projects/MIT/SDV/sdv/models/copulas.py\u001b[0m in \u001b[0;36m_prepare_sampled_covariance\u001b[0;34m(self, covariance)\u001b[0m\n\u001b[1;32m 100\u001b[0m \"\"\"\n\u001b[1;32m 101\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mipdb\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mipdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 102\u001b[0;31m \u001b[0mcovariance\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msquare_matrix\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 103\u001b[0m \u001b[0mcovariance\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mcovariance\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mT\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0midentity\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mcovariance\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 104\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Projects/MIT/SDV/sdv/models/copulas.py\u001b[0m in \u001b[0;36m_prepare_sampled_covariance\u001b[0;34m(self, covariance)\u001b[0m\n\u001b[1;32m 100\u001b[0m \"\"\"\n\u001b[1;32m 101\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mipdb\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mipdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 102\u001b[0;31m \u001b[0mcovariance\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msquare_matrix\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 103\u001b[0m \u001b[0mcovariance\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mcovariance\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mT\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0midentity\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcovariance\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m 
\u001b[0mcovariance\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 104\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/lib/python3.6/bdb.py\u001b[0m in \u001b[0;36mtrace_dispatch\u001b[0;34m(self, frame, event, arg)\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;31m# None\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mevent\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'line'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 51\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdispatch_line\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mframe\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 52\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mevent\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'call'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdispatch_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mframe\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/lib/python3.6/bdb.py\u001b[0m in \u001b[0;36mdispatch_line\u001b[0;34m(self, frame)\u001b[0m\n\u001b[1;32m 68\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstop_here\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mframe\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbreak_here\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mframe\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 69\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muser_line\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mframe\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 70\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mquitting\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mBdbQuit\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 71\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrace_dispatch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mBdbQuit\u001b[0m: " - ] - } - ], - "source": [ - "sdv.sample_all(10)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/1. Quickstart - Single Table - In Memory.ipynb b/examples/1. Quickstart - Single Table - In Memory.ipynb deleted file mode 100644 index 6a07eeb32..000000000 --- a/examples/1. 
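Aside on the session above: `_prepare_sampled_covariance` receives the covariance flattened as a lower-triangular list of lists and must rebuild the full symmetric matrix before checking positive definiteness. Below is a minimal sketch of that reconstruction, assuming a hypothetical `square_matrix` helper in place of the one the method actually calls, and omitting the `make_positive_definite` fallback:

```python3
import numpy as np

def square_matrix(triangular):
    # Hypothetical stand-in: pad each lower-triangular row with zeros
    # so the list of lists becomes a full square matrix.
    size = len(triangular)
    return [row + [0.0] * (size - len(row)) for row in triangular]

def prepare_sampled_covariance(covariance):
    # cov + cov.T mirrors the lower triangle into the upper one but
    # doubles the diagonal, so one diagonal copy is subtracted back out.
    covariance = np.array(square_matrix(covariance))
    return covariance + covariance.T - np.identity(covariance.shape[0]) * covariance

# Toy flattened covariance with variances on the diagonal:
print(prepare_sampled_covariance([[1.0], [0.5, 2.0], [0.1, 0.3, 3.0]]))
```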
diff --git a/examples/1. Quickstart - Single Table - In Memory.ipynb b/examples/1. Quickstart - Single Table - In Memory.ipynb
deleted file mode 100644
index 6a07eeb32..000000000
--- a/examples/1. Quickstart - Single Table - In Memory.ipynb
+++ /dev/null
@@ -1,393 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import pandas as pd\n",
-    "import numpy as np\n",
-    "\n",
-    "data = pd.DataFrame({\n",
-    "    'index': [1, 2, 3, 4, 5, 6, 7, 8],\n",
-    "    'integer': [1, None, 1, 2, 1, 2, 3, 2],\n",
-    "    'float': [0.1, None, 0.1, 0.2, 0.1, 0.2, 0.3, 0.1],\n",
-    "    'categorical': ['a', 'b', 'a', 'b', 'a', None, 'c', None],\n",
-    "    'bool': [False, True, False, True, False, False, False, None],\n",
-    "    'nullable': [1, None, 3, None, 5, None, 7, None],\n",
-    "    'datetime': [\n",
-    "        '2010-01-01', '2010-02-01', '2010-01-01', '2010-02-01',\n",
-    "        '2010-01-01', '2010-02-01', '2010-03-01', None\n",
-    "    ]\n",
-    "})\n",
-    "data['datetime'] = pd.to_datetime(data['datetime'])\n",
-    "\n",
-    "tables = {\n",
-    "    'data': data\n",
-    "}"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/html": [
-       "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
indexintegerfloatcategoricalboolnullabledatetime
011.00.1aFalse1.02010-01-01
12NaNNaNbTrueNaN2010-02-01
231.00.1aFalse3.02010-01-01
342.00.2bTrueNaN2010-02-01
451.00.1aFalse5.02010-01-01
562.00.2NoneFalseNaN2010-02-01
673.00.3cFalse7.02010-03-01
782.00.1NoneNoneNaNNaT
\n", - "
" - ], - "text/plain": [ - " index integer float categorical bool nullable datetime\n", - "0 1 1.0 0.1 a False 1.0 2010-01-01\n", - "1 2 NaN NaN b True NaN 2010-02-01\n", - "2 3 1.0 0.1 a False 3.0 2010-01-01\n", - "3 4 2.0 0.2 b True NaN 2010-02-01\n", - "4 5 1.0 0.1 a False 5.0 2010-01-01\n", - "5 6 2.0 0.2 None False NaN 2010-02-01\n", - "6 7 3.0 0.3 c False 7.0 2010-03-01\n", - "7 8 2.0 0.1 None None NaN NaT" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "data" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "metadata = {\n", - " \"tables\": [\n", - " {\n", - " \"fields\": [\n", - " {\n", - " \"name\": \"index\",\n", - " \"type\": \"id\"\n", - " },\n", - " {\n", - " \"name\": \"integer\",\n", - " \"type\": \"numerical\",\n", - " \"subtype\": \"integer\",\n", - " },\n", - " {\n", - " \"name\": \"float\",\n", - " \"type\": \"numerical\",\n", - " \"subtype\": \"float\",\n", - " },\n", - " {\n", - " \"name\": \"categorical\",\n", - " \"type\": \"categorical\",\n", - " \"pii\": False,\n", - " \"pii_category\": \"email\"\n", - " },\n", - " {\n", - " \"name\": \"bool\",\n", - " \"type\": \"boolean\",\n", - " },\n", - " {\n", - " \"name\": \"nullable\",\n", - " \"type\": \"numerical\",\n", - " \"subtype\": \"float\",\n", - " },\n", - " {\n", - " \"name\": \"datetime\",\n", - " \"type\": \"datetime\",\n", - " \"format\": \"%Y-%m-%d\"\n", - " },\n", - " ],\n", - " \"name\": \"data\",\n", - " \"primary_key\": \"index\"\n", - " }\n", - " ]\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2020-02-01 19:45:45,775 - INFO - modeler - Modeling data\n", - "2020-02-01 19:45:45,776 - INFO - metadata - Loading transformer NumericalTransformer for field integer\n", - "2020-02-01 19:45:45,777 - INFO - metadata - Loading transformer NumericalTransformer for field float\n", - "2020-02-01 19:45:45,777 - INFO - metadata - Loading transformer CategoricalTransformer for field categorical\n", - "2020-02-01 19:45:45,778 - INFO - metadata - Loading transformer BooleanTransformer for field bool\n", - "2020-02-01 19:45:45,779 - INFO - metadata - Loading transformer NumericalTransformer for field nullable\n", - "2020-02-01 19:45:45,779 - INFO - metadata - Loading transformer DatetimeTransformer for field datetime\n", - "2020-02-01 19:45:45,824 - INFO - modeler - Modeling Complete\n" - ] - } - ], - "source": [ - "from sdv import SDV\n", - "\n", - "sdv = SDV()\n", - "sdv.fit(metadata, tables={'data': data})" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
indexintegerfloatcategoricalboolnullabledatetime
0010.080051aNaNNaNNaT
110-0.015712aFalseNaN2009-12-08 11:20:58.439345408
2210.142979aTrue4.9714082010-01-08 00:32:21.629585920
3310.133913aTrueNaN2010-01-14 21:37:22.051623936
4410.159480aFalse4.2940802010-01-09 21:09:11.245925888
\n", - "
" - ], - "text/plain": [ - " index integer float categorical bool nullable \\\n", - "0 0 1 0.080051 a NaN NaN \n", - "1 1 0 -0.015712 a False NaN \n", - "2 2 1 0.142979 a True 4.971408 \n", - "3 3 1 0.133913 a True NaN \n", - "4 4 1 0.159480 a False 4.294080 \n", - "\n", - " datetime \n", - "0 NaT \n", - "1 2009-12-08 11:20:58.439345408 \n", - "2 2010-01-08 00:32:21.629585920 \n", - "3 2010-01-14 21:37:22.051623936 \n", - "4 2010-01-09 21:09:11.245925888 " - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "samples = sdv.sample_all()\n", - "samples['data']" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/2. Quickstart - Single Table - Census.ipynb b/examples/2. Quickstart - Single Table - Census.ipynb deleted file mode 100644 index 595fe502c..000000000 --- a/examples/2. Quickstart - Single Table - Census.ipynb +++ /dev/null @@ -1,561 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import warnings\n", - "warnings.filterwarnings('ignore')\n", - "\n", - "import pandas as pd" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data'\n", - "columns = [\n", - " 'age',\n", - " 'workclass',\n", - " 'fnlwgt',\n", - " 'education',\n", - " 'education-num',\n", - " 'marital-status',\n", - " 'occupation',\n", - " 'relationship',\n", - " 'race',\n", - " 'sex',\n", - " 'capital-gain',\n", - " 'capital-loss',\n", - " 'hours-per-week',\n", - " 'native-country',\n", - " 'income'\n", - "]" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
ageworkclassfnlwgteducationeducation-nummarital-statusoccupationrelationshipracesexcapital-gaincapital-losshours-per-weeknative-countryincome
039State-gov77516Bachelors13Never-marriedAdm-clericalNot-in-familyWhiteMale2174040United-States<=50K
150Self-emp-not-inc83311Bachelors13Married-civ-spouseExec-managerialHusbandWhiteMale0013United-States<=50K
238Private215646HS-grad9DivorcedHandlers-cleanersNot-in-familyWhiteMale0040United-States<=50K
353Private23472111th7Married-civ-spouseHandlers-cleanersHusbandBlackMale0040United-States<=50K
428Private338409Bachelors13Married-civ-spouseProf-specialtyWifeBlackFemale0040Cuba<=50K
\n", - "
" - ], - "text/plain": [ - " age workclass fnlwgt education education-num \\\n", - "0 39 State-gov 77516 Bachelors 13 \n", - "1 50 Self-emp-not-inc 83311 Bachelors 13 \n", - "2 38 Private 215646 HS-grad 9 \n", - "3 53 Private 234721 11th 7 \n", - "4 28 Private 338409 Bachelors 13 \n", - "\n", - " marital-status occupation relationship race sex \\\n", - "0 Never-married Adm-clerical Not-in-family White Male \n", - "1 Married-civ-spouse Exec-managerial Husband White Male \n", - "2 Divorced Handlers-cleaners Not-in-family White Male \n", - "3 Married-civ-spouse Handlers-cleaners Husband Black Male \n", - "4 Married-civ-spouse Prof-specialty Wife Black Female \n", - "\n", - " capital-gain capital-loss hours-per-week native-country income \n", - "0 2174 0 40 United-States <=50K \n", - "1 0 0 13 United-States <=50K \n", - "2 0 0 40 United-States <=50K \n", - "3 0 0 40 United-States <=50K \n", - "4 0 0 40 Cuba <=50K " - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df = pd.read_csv(url, names=columns)\n", - "df.head()" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "tables = {\n", - " 'census': df\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "metadata = {\n", - " \"tables\": [\n", - " {\n", - " \"fields\": [\n", - " {\n", - " \"name\": \"age\",\n", - " \"type\": \"numerical\",\n", - " \"subtype\": \"integer\",\n", - " },\n", - " {\n", - " \"name\": \"workclass\",\n", - " \"type\": \"categorical\",\n", - " },\n", - " {\n", - " \"name\": \"fnlwgt\",\n", - " \"type\": \"numerical\",\n", - " \"subtype\": \"integer\",\n", - " },\n", - " {\n", - " \"name\": \"education\",\n", - " \"type\": \"categorical\",\n", - " },\n", - " {\n", - " \"name\": \"education-num\",\n", - " \"type\": \"numerical\",\n", - " \"subtype\": \"integer\",\n", - " },\n", - " {\n", - " \"name\": \"marital-status\",\n", - " \"type\": \"categorical\",\n", - " },\n", - " {\n", - " \"name\": \"occupation\",\n", - " \"type\": \"categorical\",\n", - " },\n", - " {\n", - " \"name\": \"relationship\",\n", - " \"type\": \"categorical\",\n", - " },\n", - " {\n", - " \"name\": \"race\",\n", - " \"type\": \"categorical\",\n", - " },\n", - " {\n", - " \"name\": \"sex\",\n", - " \"type\": \"categorical\",\n", - " },\n", - " {\n", - " \"name\": \"capital-gain\",\n", - " \"type\": \"numerical\",\n", - " \"subtype\": \"integer\",\n", - " },\n", - " {\n", - " \"name\": \"capital-loss\",\n", - " \"type\": \"numerical\",\n", - " \"subtype\": \"integer\",\n", - " },\n", - " {\n", - " \"name\": \"hours-per-week\",\n", - " \"type\": \"numerical\",\n", - " \"subtype\": \"integer\",\n", - " },\n", - " {\n", - " \"name\": \"native-country\",\n", - " \"type\": \"categorical\",\n", - " },\n", - " {\n", - " \"name\": \"income\",\n", - " \"type\": \"categorical\",\n", - " }\n", - " ],\n", - " \"name\": \"census\",\n", - " }\n", - " ]\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2020-06-25 23:39:05,857 - INFO - modeler - Modeling census\n", - "2020-06-25 23:39:05,857 - INFO - metadata - Loading transformer NumericalTransformer for field age\n", - "2020-06-25 23:39:05,858 - INFO - metadata - Loading transformer CategoricalTransformer for field workclass\n", - "2020-06-25 23:39:05,858 - INFO - metadata - Loading transformer 
NumericalTransformer for field fnlwgt\n", - "2020-06-25 23:39:05,859 - INFO - metadata - Loading transformer CategoricalTransformer for field education\n", - "2020-06-25 23:39:05,859 - INFO - metadata - Loading transformer NumericalTransformer for field education-num\n", - "2020-06-25 23:39:05,860 - INFO - metadata - Loading transformer CategoricalTransformer for field marital-status\n", - "2020-06-25 23:39:05,860 - INFO - metadata - Loading transformer CategoricalTransformer for field occupation\n", - "2020-06-25 23:39:05,860 - INFO - metadata - Loading transformer CategoricalTransformer for field relationship\n", - "2020-06-25 23:39:05,861 - INFO - metadata - Loading transformer CategoricalTransformer for field race\n", - "2020-06-25 23:39:05,861 - INFO - metadata - Loading transformer CategoricalTransformer for field sex\n", - "2020-06-25 23:39:05,862 - INFO - metadata - Loading transformer NumericalTransformer for field capital-gain\n", - "2020-06-25 23:39:05,863 - INFO - metadata - Loading transformer NumericalTransformer for field capital-loss\n", - "2020-06-25 23:39:05,863 - INFO - metadata - Loading transformer NumericalTransformer for field hours-per-week\n", - "2020-06-25 23:39:05,863 - INFO - metadata - Loading transformer CategoricalTransformer for field native-country\n", - "2020-06-25 23:39:05,864 - INFO - metadata - Loading transformer CategoricalTransformer for field income\n", - "2020-06-25 23:39:06,119 - INFO - modeler - Modeling Complete\n" - ] - } - ], - "source": [ - "from sdv import SDV\n", - "\n", - "sdv = SDV()\n", - "sdv.fit(metadata, tables)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
ageworkclassfnlwgteducationeducation-nummarital-statusoccupationrelationshipracesexcapital-gaincapital-losshours-per-weeknative-countryincome
042Private47585Assoc-voc9Never-marriedProf-specialtyNot-in-familyWhiteMale2398344United-States<=50K
156Private92870HS-grad13Married-civ-spouseSalesNot-in-familyWhiteMale362426930United-States<=50K
254Private218711HS-grad13Never-marriedOther-serviceNot-in-familyWhiteFemale4420-26943United-States<=50K
331?71625Some-college14Married-civ-spouseProf-specialtyHusbandWhiteMale319689851United-States<=50K
437Private184276Bachelors14Never-marriedCraft-repairHusbandWhiteMale478590732United-States<=50K
\n", - "
" - ], - "text/plain": [ - " age workclass fnlwgt education education-num marital-status \\\n", - "0 42 Private 47585 Assoc-voc 9 Never-married \n", - "1 56 Private 92870 HS-grad 13 Married-civ-spouse \n", - "2 54 Private 218711 HS-grad 13 Never-married \n", - "3 31 ? 71625 Some-college 14 Married-civ-spouse \n", - "4 37 Private 184276 Bachelors 14 Never-married \n", - "\n", - " occupation relationship race sex capital-gain \\\n", - "0 Prof-specialty Not-in-family White Male 2398 \n", - "1 Sales Not-in-family White Male 3624 \n", - "2 Other-service Not-in-family White Female 4420 \n", - "3 Prof-specialty Husband White Male 3196 \n", - "4 Craft-repair Husband White Male 4785 \n", - "\n", - " capital-loss hours-per-week native-country income \n", - "0 3 44 United-States <=50K \n", - "1 269 30 United-States <=50K \n", - "2 -269 43 United-States <=50K \n", - "3 898 51 United-States <=50K \n", - "4 907 32 United-States <=50K " - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "sampled = sdv.sample('census', num_rows=len(df))\n", - "sampled['census'].head()" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "data": { - "text/plain": [ - "-43.51455835774797" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from sdv.evaluation import evaluate\n", - "\n", - "samples = sdv.sample_all(len(tables['census']))\n", - "\n", - "evaluate(samples, real=tables, metadata=sdv.metadata)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/3. Quickstart - Multitable - Files.ipynb b/examples/3. Quickstart - Multitable - Files.ipynb deleted file mode 100644 index b4183f965..000000000 --- a/examples/3. 
Quickstart - Multitable - Files.ipynb +++ /dev/null @@ -1,636 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2020-02-01 19:45:58,147 - INFO - modeler - Modeling customers\n", - "2020-02-01 19:45:58,148 - INFO - metadata - Loading table customers\n", - "2020-02-01 19:45:58,156 - INFO - metadata - Loading transformer CategoricalTransformer for field cust_postal_code\n", - "2020-02-01 19:45:58,157 - INFO - metadata - Loading transformer NumericalTransformer for field phone_number1\n", - "2020-02-01 19:45:58,158 - INFO - metadata - Loading transformer NumericalTransformer for field credit_limit\n", - "2020-02-01 19:45:58,158 - INFO - metadata - Loading transformer CategoricalTransformer for field country\n", - "2020-02-01 19:45:58,174 - INFO - modeler - Modeling orders\n", - "2020-02-01 19:45:58,175 - INFO - metadata - Loading table orders\n", - "2020-02-01 19:45:58,179 - INFO - metadata - Loading transformer NumericalTransformer for field order_total\n", - "2020-02-01 19:45:58,183 - INFO - modeler - Modeling order_items\n", - "2020-02-01 19:45:58,183 - INFO - metadata - Loading table order_items\n", - "2020-02-01 19:45:58,187 - INFO - metadata - Loading transformer CategoricalTransformer for field product_id\n", - "2020-02-01 19:45:58,187 - INFO - metadata - Loading transformer NumericalTransformer for field unit_price\n", - "2020-02-01 19:45:58,187 - INFO - metadata - Loading transformer NumericalTransformer for field quantity\n", - "2020-02-01 19:45:58,824 - INFO - modeler - Modeling Complete\n" - ] - } - ], - "source": [ - "from sdv import SDV\n", - "\n", - "sdv = SDV()\n", - "sdv.fit('quickstart/metadata.json')" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2020-02-01 19:45:58,874 - INFO - metadata - Loading table customers\n", - "2020-02-01 19:45:58,878 - INFO - metadata - Loading table orders\n", - "2020-02-01 19:45:58,881 - INFO - metadata - Loading table order_items\n" - ] - } - ], - "source": [ - "real = sdv.metadata.load_tables()\n", - "\n", - "samples = sdv.sample_all(len(real['customers']), reset_primary_keys=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
customer_idcust_postal_codephone_number1credit_limitcountry
006314559414875351694UK
11113718369261990501UK
221137161061572741287CANADA
33609642432704621584UK
4460965705933008584UK
\n", - "
" - ], - "text/plain": [ - " customer_id cust_postal_code phone_number1 credit_limit country\n", - "0 0 63145 5941487535 1694 UK\n", - "1 1 11371 8369261990 501 UK\n", - "2 2 11371 6106157274 1287 CANADA\n", - "3 3 6096 4243270462 1584 UK\n", - "4 4 6096 5705933008 584 UK" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "samples['customers'].head()" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
customer_idcust_postal_codephone_number1credit_limitcountry
0501137161755532951000UK
14631458605551835500US
297338810609670355521431000CANADA
36304076314570355521432000US
48263621137161755532951000UK
\n", - "
" - ], - "text/plain": [ - " customer_id cust_postal_code phone_number1 credit_limit country\n", - "0 50 11371 6175553295 1000 UK\n", - "1 4 63145 8605551835 500 US\n", - "2 97338810 6096 7035552143 1000 CANADA\n", - "3 630407 63145 7035552143 2000 US\n", - "4 826362 11371 6175553295 1000 UK" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "real['customers'].head()" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
order_idcustomer_idorder_total
001581
111707
221878
332285
443840
\n", - "
" - ], - "text/plain": [ - " order_id customer_id order_total\n", - "0 0 1 581\n", - "1 1 1 707\n", - "2 2 1 878\n", - "3 3 2 285\n", - "4 4 3 840" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "samples['orders'].head()" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
order_idcustomer_idorder_total
01502310
1241507
21097338810730
3655996144730
4355996144939
\n", - "
" - ], - "text/plain": [ - " order_id customer_id order_total\n", - "0 1 50 2310\n", - "1 2 4 1507\n", - "2 10 97338810 730\n", - "3 6 55996144 730\n", - "4 3 55996144 939" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "real['orders'].head()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
order_item_idorder_idproduct_idunit_pricequantity
0006623
110101085
2206573
331101146
44110814
\n", - "
" - ], - "text/plain": [ - " order_item_id order_id product_id unit_price quantity\n", - "0 0 0 6 62 3\n", - "1 1 0 10 108 5\n", - "2 2 0 6 57 3\n", - "3 3 1 10 114 6\n", - "4 4 1 10 81 4" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "samples['order_items'].head()" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
order_item_idorder_idproduct_idunit_pricequantity
0100107528
1101861254
2102161254
3103491254
4104191134
\n", - "
" - ], - "text/plain": [ - " order_item_id order_id product_id unit_price quantity\n", - "0 100 10 7 52 8\n", - "1 101 8 6 125 4\n", - "2 102 1 6 125 4\n", - "3 103 4 9 125 4\n", - "4 104 1 9 113 4" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "real['order_items'].head()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/4. Anonymization.ipynb b/examples/4. Anonymization.ipynb deleted file mode 100644 index 65382e218..000000000 --- a/examples/4. Anonymization.ipynb +++ /dev/null @@ -1,410 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import pandas as pd\n", - "\n", - "data = pd.DataFrame([\n", - " {\n", - " 'index': 1,\n", - " 'name': 'Bill',\n", - " 'credit_card_number': '1111222233334444'\n", - " },\n", - " {\n", - " 'index': 2,\n", - " 'name': 'Jeff',\n", - " 'credit_card_number': '0000000000000000'\n", - " },\n", - " {\n", - " 'index': 3,\n", - " 'name': 'Bill',\n", - " 'credit_card_number': '9999999999999999'\n", - " },\n", - " {\n", - " 'index': 4,\n", - " 'name': 'Joe',\n", - " 'credit_card_number': '8888888888888888'\n", - " },\n", - "])" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
credit_card_numberindexname
011112222333344441Bill
100000000000000002Jeff
299999999999999993Bill
388888888888888884Joe
\n", - "
" - ], - "text/plain": [ - " credit_card_number index name\n", - "0 1111222233334444 1 Bill\n", - "1 0000000000000000 2 Jeff\n", - "2 9999999999999999 3 Bill\n", - "3 8888888888888888 4 Joe" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "data" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "metadata = {\n", - " \"tables\": [\n", - " {\n", - " \"fields\": [\n", - " {\n", - " \"name\": \"index\",\n", - " \"type\": \"id\"\n", - " },\n", - " {\n", - " \"name\": \"name\",\n", - " \"type\": \"categorical\",\n", - " \"pii\": True,\n", - " \"pii_category\": \"first_name\"\n", - " },\n", - " {\n", - " \"name\": \"credit_card_number\",\n", - " \"type\": \"categorical\",\n", - " \"pii\": True,\n", - " \"pii_category\": [\n", - " \"credit_card_number\",\n", - " \"visa\"\n", - " ]\n", - " }\n", - " ],\n", - " \"name\": \"anonymized\",\n", - " \"primary_key\": \"index\",\n", - " },\n", - " {\n", - " \"fields\": [\n", - " {\n", - " \"name\": \"index\",\n", - " \"type\": \"id\"\n", - " },\n", - " {\n", - " \"name\": \"name\",\n", - " \"type\": \"categorical\"\n", - " },\n", - " {\n", - " \"name\": \"credit_card_number\",\n", - " \"type\": \"categorical\"\n", - " }\n", - " ],\n", - " \"name\": \"normal\",\n", - " \"primary_key\": \"index\",\n", - " }\n", - " ]\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "tables = {\n", - " 'anonymized': data,\n", - " 'normal': data.copy()\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2020-02-01 19:46:01,587 - INFO - modeler - Modeling anonymized\n", - "2020-02-01 19:46:01,589 - INFO - metadata - Loading transformer CategoricalTransformer for field name\n", - "2020-02-01 19:46:01,589 - INFO - metadata - Loading transformer CategoricalTransformer for field credit_card_number\n", - "2020-02-01 19:46:01,640 - INFO - modeler - Modeling normal\n", - "2020-02-01 19:46:01,641 - INFO - metadata - Loading transformer CategoricalTransformer for field name\n", - "2020-02-01 19:46:01,641 - INFO - metadata - Loading transformer CategoricalTransformer for field credit_card_number\n", - "2020-02-01 19:46:01,658 - INFO - modeler - Modeling Complete\n" - ] - } - ], - "source": [ - "from sdv import SDV\n", - "\n", - "sdv = SDV()\n", - "sdv.fit(metadata, tables)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "scrolled": false - }, - "outputs": [], - "source": [ - "sampled = sdv.sample_all()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
indexnamecredit_card_number
00Brandon4134540887507040
11Larry4288700396707168
22Brandon4869967025488786
33Brandon4288700396707168
44Larry4869967025488786
\n", - "
" - ], - "text/plain": [ - " index name credit_card_number\n", - "0 0 Brandon 4134540887507040\n", - "1 1 Larry 4288700396707168\n", - "2 2 Brandon 4869967025488786\n", - "3 3 Brandon 4288700396707168\n", - "4 4 Larry 4869967025488786" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "sampled['anonymized']" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
indexnamecredit_card_number
00Joe1111222233334444
11Bill0000000000000000
22Jeff1111222233334444
33Bill8888888888888888
44Jeff0000000000000000
\n", - "
" - ], - "text/plain": [ - " index name credit_card_number\n", - "0 0 Joe 1111222233334444\n", - "1 1 Bill 0000000000000000\n", - "2 2 Jeff 1111222233334444\n", - "3 3 Bill 8888888888888888\n", - "4 4 Jeff 0000000000000000" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "sampled['normal']" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/6. Metadata Validation.ipynb b/examples/6. Metadata Validation.ipynb deleted file mode 100644 index 70ab58913..000000000 --- a/examples/6. Metadata Validation.ipynb +++ /dev/null @@ -1,239 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "from sdv import load_demo" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "metadata, tables = load_demo(metadata=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'tables': {'users': {'primary_key': 'user_id',\n", - " 'fields': {'user_id': {'type': 'id', 'subtype': 'integer'},\n", - " 'country': {'type': 'categorical'},\n", - " 'gender': {'type': 'categorical'},\n", - " 'age': {'type': 'numerical', 'subtype': 'integer'}}},\n", - " 'sessions': {'primary_key': 'session_id',\n", - " 'fields': {'session_id': {'type': 'id', 'subtype': 'integer'},\n", - " 'user_id': {'ref': {'field': 'user_id', 'table': 'users'},\n", - " 'type': 'id',\n", - " 'subtype': 'integer'},\n", - " 'device': {'type': 'categorical'},\n", - " 'os': {'type': 'categorical'}}},\n", - " 'transactions': {'primary_key': 'transaction_id',\n", - " 'fields': {'transaction_id': {'type': 'id', 'subtype': 'integer'},\n", - " 'session_id': {'ref': {'field': 'session_id', 'table': 'sessions'},\n", - " 'type': 'id',\n", - " 'subtype': 'integer'},\n", - " 'timestamp': {'type': 'datetime', 'format': '%Y-%m-%d'},\n", - " 'amount': {'type': 'numerical', 'subtype': 'float'},\n", - " 'approved': {'type': 'boolean'}}}}}" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "metadata.to_dict()" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'users': user_id country gender age\n", - " 0 0 USA M 34\n", - " 1 1 UK F 23\n", - " 2 2 ES None 44\n", - " 3 3 UK M 22\n", - " 4 4 USA F 54\n", - " 5 5 DE M 57\n", - " 6 6 BG F 45\n", - " 7 7 ES None 41\n", - " 8 8 FR F 23\n", - " 9 9 UK None 30,\n", - " 'sessions': session_id user_id device os\n", - " 0 0 0 mobile android\n", - " 1 1 1 tablet ios\n", - " 2 2 1 tablet android\n", - " 3 3 2 mobile android\n", - " 4 4 4 mobile ios\n", - " 5 5 5 mobile android\n", - " 6 6 6 mobile ios\n", - " 7 7 6 tablet ios\n", - " 8 8 6 mobile ios\n", - " 9 9 8 tablet ios,\n", - " 'transactions': transaction_id session_id timestamp amount approved\n", - " 0 0 0 2019-01-01 12:34:32 100.0 True\n", - " 1 1 0 2019-01-01 12:42:21 55.3 True\n", - " 2 2 1 2019-01-07 17:23:11 79.5 True\n", - " 3 3 3 2019-01-10 11:08:57 112.1 False\n", - " 4 4 5 
2019-01-10 21:54:08 110.0 False\n", - " 5 5 5 2019-01-11 11:21:20 76.3 True\n", - " 6 6 7 2019-01-22 14:44:10 89.5 True\n", - " 7 7 8 2019-01-23 10:14:09 132.1 False\n", - " 8 8 9 2019-01-27 16:09:17 68.0 True\n", - " 9 9 9 2019-01-29 12:10:48 99.9 True}" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tables" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "metadata.validate()" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "metadata.validate(tables)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "metadata._metadata['tables']['users']['primary_key'] = 'country'" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Error: id field `user_id` is neither a primary or a foreign key\n" - ] - } - ], - "source": [ - "from sdv.metadata import MetadataError\n", - "\n", - "try:\n", - " metadata.validate()\n", - "except MetadataError as me:\n", - " print('Error: {}'.format(me))" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "metadata._metadata['tables']['users']['primary_key'] = 'user_id'" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "metadata.validate()" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "metadata._metadata['tables']['users']['fields']['gender']['type'] = 'numerical'" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "metadata.validate()" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Error: Invalid values found in column `gender` of table `users`: `could not convert string to float: 'M'`\n" - ] - } - ], - "source": [ - "try:\n", - " metadata.validate(tables)\n", - "except MetadataError as me:\n", - " print('Error: {}'.format(me))" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/Demo - Walmart.ipynb b/examples/Demo - Walmart.ipynb deleted file mode 100644 index a4015bb0f..000000000 --- a/examples/Demo - Walmart.ipynb +++ /dev/null @@ -1,908 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Demo \"walmart\"\n", - "\n", - "We have a store series, each of those have a size and a category and additional information in a given date: average temperature in the region, cost of fuel in the region, promotional data, the customer price index, the unemployment rate and whether the date is a special holiday.\n", - "\n", - "From those stores we obtained a training of historical data \n", - "between 2010-02-05 and 2012-11-01. 
This historical data includes the sales of each department on a specific date.\n", - "In this notebook, we will show, step by step, how to download the \"Walmart\" dataset, explain its structure, and sample synthetic data.\n", - "\n", - "In this demonstration we will show how SDV can be used to generate synthetic data, which can later be used to train machine learning models.\n", - "\n", - "*The dataset used in this example can be found on [Kaggle](https://www.kaggle.com/c/walmart-recruiting-store-sales-forecasting/data), but we will show how to download it from SDV.*" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Data model summary" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "
stores
\n", - "\n", - "| Field | Type | Subtype | Additional Properties |\n", - "|-------|-------------|---------|-----------------------|\n", - "| Store | id | integer | Primary key |\n", - "| Size | numerical | integer | |\n", - "| Type | categorical | | |\n", - "\n", - "Contains information about the 45 stores, indicating the type and size of store." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "
features
\n", - "\n", - "| Fields | Type | Subtype | Additional Properties |\n", - "|--------------|-----------|---------|-----------------------------|\n", - "| Store | id | integer | foreign key (stores.Store) |\n", - "| Date | datetime | | format: \"%Y-%m-%d\" |\n", - "| IsHoliday | boolean | | |\n", - "| Fuel_Price | numerical | float | |\n", - "| Unemployment | numerical | float | |\n", - "| Temperature | numerical | float | |\n", - "| CPI | numerical | float | |\n", - "| MarkDown1 | numerical | float | |\n", - "| MarkDown2 | numerical | float | |\n", - "| MarkDown3 | numerical | float | |\n", - "| MarkDown4 | numerical | float | |\n", - "| MarkDown5 | numerical | float | |\n", - "\n", - "Contains historical training data, which covers to 2010-02-05 to 2012-11-01." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "
- { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "depts
\n", - "\n", - "| Fields | Type | Subtype | Additional Properties |\n", - "|--------------|-----------|---------|------------------------------|\n", - "| Store | id | integer | foreign key (stores.Stores) |\n", - "| Date | datetime | | format: \"%Y-%m-%d\" |\n", - "| Weekly_Sales | numerical | float | |\n", - "| Dept | numerical | integer | |\n", - "| IsHoliday | boolean | | |\n", - "\n", - "Contains additional data related to the store, department, and regional activity for the given dates." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 1. Load data\n", - "\n", - "Let's start downloading the data set. In this case, we will download the data set *walmart*. We will use the SDV function `load_demo`, we can specify the name of the dataset we want to use and if we want its Metadata object or not. To know more about the demo data [see the documentation](https://sdv-dev.github.io/SDV/api/sdv.demo.html)." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2020-02-01 19:46:10,378 - INFO - metadata - Loading table stores\n", - "2020-02-01 19:46:10,386 - INFO - metadata - Loading table features\n", - "2020-02-01 19:46:10,401 - INFO - metadata - Loading table depts\n" - ] - } - ], - "source": [ - "from sdv import load_demo\n", - "\n", - "metadata, tables = load_demo(dataset_name='walmart', metadata=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Our dataset is downloaded from an [Amazon S3 bucket](http://sdv-datasets.s3.amazonaws.com/index.html) that contains all available data sets of the `load_demo` method." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can now visualize the metadata structure" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "Metadata\n", - "\n", - "\n", - "\n", - "stores\n", - "\n", - "stores\n", - "\n", - "Type : categorical\n", - "Size : numerical - integer\n", - "Store : id - integer\n", - "\n", - "Primary key: Store\n", - "Data path: stores.csv\n", - "\n", - "\n", - "\n", - "features\n", - "\n", - "features\n", - "\n", - "Date : datetime\n", - "MarkDown1 : numerical - float\n", - "Store : id - integer\n", - "IsHoliday : boolean\n", - "MarkDown4 : numerical - float\n", - "MarkDown3 : numerical - float\n", - "Fuel_Price : numerical - float\n", - "Unemployment : numerical - float\n", - "Temperature : numerical - float\n", - "MarkDown5 : numerical - float\n", - "MarkDown2 : numerical - float\n", - "CPI : numerical - float\n", - "\n", - "Foreign key (stores): Store\n", - "Data path: features.csv\n", - "\n", - "\n", - "\n", - "stores->features\n", - "\n", - "\n", - "   features.Store -> stores.Store\n", - "\n", - "\n", - "\n", - "depts\n", - "\n", - "depts\n", - "\n", - "Date : datetime\n", - "Weekly_Sales : numerical - float\n", - "Store : id - integer\n", - "Dept : numerical - integer\n", - "IsHoliday : boolean\n", - "\n", - "Foreign key (stores): Store\n", - "Data path: train.csv\n", - "\n", - "\n", - "\n", - "stores->depts\n", - "\n", - "\n", - "   depts.Store -> stores.Store\n", - "\n", - "\n", - "\n" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "metadata.visualize()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - 
"source": [ - "And also validate that the metadata is correctly defined for our data" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "metadata.validate(tables)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 2. Create an instance of SDV and train the instance\n", - "\n", - "Once we download it, we have to create an SDV instance. With that instance, we have to analyze the loaded tables to generate a statistical model from the data. In this case, the process of adjusting the model is quickly because the dataset is small. However, with larger datasets it can be a slow process." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2020-02-01 19:46:10,769 - INFO - modeler - Modeling stores\n", - "2020-02-01 19:46:10,769 - INFO - metadata - Loading transformer CategoricalTransformer for field Type\n", - "2020-02-01 19:46:10,770 - INFO - metadata - Loading transformer NumericalTransformer for field Size\n", - "2020-02-01 19:46:10,781 - INFO - modeler - Modeling features\n", - "2020-02-01 19:46:10,781 - INFO - metadata - Loading transformer DatetimeTransformer for field Date\n", - "2020-02-01 19:46:10,782 - INFO - metadata - Loading transformer NumericalTransformer for field MarkDown1\n", - "2020-02-01 19:46:10,782 - INFO - metadata - Loading transformer BooleanTransformer for field IsHoliday\n", - "2020-02-01 19:46:10,783 - INFO - metadata - Loading transformer NumericalTransformer for field MarkDown4\n", - "2020-02-01 19:46:10,783 - INFO - metadata - Loading transformer NumericalTransformer for field MarkDown3\n", - "2020-02-01 19:46:10,784 - INFO - metadata - Loading transformer NumericalTransformer for field Fuel_Price\n", - "2020-02-01 19:46:10,785 - INFO - metadata - Loading transformer NumericalTransformer for field Unemployment\n", - "2020-02-01 19:46:10,785 - INFO - metadata - Loading transformer NumericalTransformer for field Temperature\n", - "2020-02-01 19:46:10,786 - INFO - metadata - Loading transformer NumericalTransformer for field MarkDown5\n", - "2020-02-01 19:46:10,786 - INFO - metadata - Loading transformer NumericalTransformer for field MarkDown2\n", - "2020-02-01 19:46:10,787 - INFO - metadata - Loading transformer NumericalTransformer for field CPI\n", - "2020-02-01 19:46:12,119 - INFO - modeler - Modeling depts\n", - "2020-02-01 19:46:12,119 - INFO - metadata - Loading transformer DatetimeTransformer for field Date\n", - "2020-02-01 19:46:12,119 - INFO - metadata - Loading transformer NumericalTransformer for field Weekly_Sales\n", - "2020-02-01 19:46:12,120 - INFO - metadata - Loading transformer NumericalTransformer for field Dept\n", - "2020-02-01 19:46:12,120 - INFO - metadata - Loading transformer BooleanTransformer for field IsHoliday\n", - "/home/xals/.virtualenvs/SDV.merge/lib/python3.6/site-packages/numpy/core/fromnumeric.py:56: FutureWarning: Series.nonzero() is deprecated and will be removed in a future version.Use Series.to_numpy().nonzero() instead\n", - " return getattr(obj, method)(*args, **kwds)\n", - "2020-02-01 19:46:13,498 - INFO - modeler - Modeling Complete\n" - ] - } - ], - "source": [ - "from sdv import SDV\n", - "\n", - "sdv = SDV()\n", - "sdv.fit(metadata, tables=tables)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Note: We may not want to train the model every time we want to generate new synthetic data. 
- { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 3. Generate synthetic data\n", - "\n", - "Once the instance is trained, we are ready to generate the synthetic data.\n", - "\n", - "The easiest way to generate synthetic data for the entire dataset is to call the `sample_all` method. By default, this method generates only 5 rows, but we can specify the number of rows to generate with the `num_rows` argument. To learn more about the available arguments, see [sample_all](https://sdv-dev.github.io/SDV/api/sdv.sampler.html#sdv.sampler.Sampler.sample_all)." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "samples = sdv.sample_all()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This returns a dictionary with a `pandas.DataFrame` for each table." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
" - ], - "text/plain": [ - " Type Size Store\n", - "0 B 108725 0\n", - "1 B 71407 1\n", - "2 A 192291 2\n", - "3 B 91790 3\n", - "4 B 147700 4" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "samples['stores'].head()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
DateMarkDown1StoreIsHolidayMarkDown4MarkDown3Fuel_PriceUnemploymentTemperatureMarkDown5MarkDown2CPI
02012-01-06 12:07:18.407483136NaN0FalseNaN3371.6625343.5182528.59522682.638090NaNNaN190.969301
12012-02-04 17:43:32.634723840NaN0FalseNaNNaN3.6901708.56454375.3415515178.361034NaN197.207944
22011-02-28 11:10:43.4057825283991.0260020FalseNaNNaN3.7361817.73696652.1317442939.5482982017.730801198.774994
32013-10-07 10:01:00.746535424NaN0FalseNaN2969.0194543.0046139.75943973.57087910744.7800134958.619653173.358556
42011-12-29 13:02:32.8859901446395.9424020FalseNaN-1840.9299553.88208810.15859477.868550NaNNaN197.546955
\n", - "
" - ], - "text/plain": [ - " Date MarkDown1 Store IsHoliday MarkDown4 \\\n", - "0 2012-01-06 12:07:18.407483136 NaN 0 False NaN \n", - "1 2012-02-04 17:43:32.634723840 NaN 0 False NaN \n", - "2 2011-02-28 11:10:43.405782528 3991.026002 0 False NaN \n", - "3 2013-10-07 10:01:00.746535424 NaN 0 False NaN \n", - "4 2011-12-29 13:02:32.885990144 6395.942402 0 False NaN \n", - "\n", - " MarkDown3 Fuel_Price Unemployment Temperature MarkDown5 \\\n", - "0 3371.662534 3.518252 8.595226 82.638090 NaN \n", - "1 NaN 3.690170 8.564543 75.341551 5178.361034 \n", - "2 NaN 3.736181 7.736966 52.131744 2939.548298 \n", - "3 2969.019454 3.004613 9.759439 73.570879 10744.780013 \n", - "4 -1840.929955 3.882088 10.158594 77.868550 NaN \n", - "\n", - " MarkDown2 CPI \n", - "0 NaN 190.969301 \n", - "1 NaN 197.207944 \n", - "2 2017.730801 198.774994 \n", - "3 4958.619653 173.358556 \n", - "4 NaN 197.546955 " - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "samples['features'].head()" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
DateWeekly_SalesStoreDeptIsHoliday
02012-05-01 02:58:51.70299468816832.327095077False
12011-01-06 18:03:36.33304422452593.7513250-35False
22012-09-19 18:41:44.62330752041113.554799048False
32012-02-18 19:35:16.60373632046877.378050021False
42011-10-04 10:57:22.12522521647731.413911060False
\n", - "
" - ], - "text/plain": [ - " Date Weekly_Sales Store Dept IsHoliday\n", - "0 2012-05-01 02:58:51.702994688 16832.327095 0 77 False\n", - "1 2011-01-06 18:03:36.333044224 52593.751325 0 -35 False\n", - "2 2012-09-19 18:41:44.623307520 41113.554799 0 48 False\n", - "3 2012-02-18 19:35:16.603736320 46877.378050 0 21 False\n", - "4 2011-10-04 10:57:22.125225216 47731.413911 0 60 False" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "samples['depts'].head()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We may not want to generate data for all tables in the dataset, rather for just one table. This is possible with SDV using the `sample` method. To use it we only need to specify the name of the table we want to synthesize and the row numbers to generate. In this case, the \"walmart\" data set has 3 tables: stores, features and depts.\n", - "\n", - "In the following example, we will generate 1000 rows of the \"features\" table." - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'features': Date MarkDown1 Store IsHoliday \\\n", - " 0 2012-05-01 16:47:44.635517440 9856.327932 5 False \n", - " 1 2011-03-02 09:17:23.001815040 NaN 5 False \n", - " 2 2009-12-09 12:09:15.280852480 NaN 5 False \n", - " 3 2012-09-27 22:00:49.812289792 7575.545504 5 False \n", - " 4 2011-05-09 00:56:06.155403520 NaN 5 False \n", - " 5 2011-09-07 10:54:38.274752256 -319.511710 5 False \n", - " 6 2011-07-15 04:51:41.685670656 NaN 5 False \n", - " 7 2010-10-26 23:57:49.155473408 NaN 5 False \n", - " 8 2013-01-07 03:08:40.389892352 -25.338635 5 False \n", - " 9 2014-02-11 19:59:19.362542592 8154.995622 5 False \n", - " 10 2009-08-16 21:08:36.649604352 NaN 5 True \n", - " 11 2011-04-07 23:12:30.381310976 NaN 5 True \n", - " 12 2011-06-01 23:40:25.271012608 NaN 5 False \n", - " 13 2010-11-25 16:27:01.937167872 NaN 5 False \n", - " 14 2012-04-01 23:49:59.091133440 NaN 5 False \n", - " 15 2011-11-29 06:44:29.486765568 5394.076595 5 False \n", - " 16 2012-12-23 01:54:43.755820800 616.736543 5 False \n", - " 17 2011-02-22 12:01:12.161310464 NaN 5 False \n", - " 18 2012-09-14 09:12:48.080929792 4264.005722 5 False \n", - " 19 2011-09-29 04:49:59.894278656 NaN 5 True \n", - " 20 2011-07-06 18:19:10.024774400 NaN 5 False \n", - " 21 2010-12-20 03:51:25.390490880 8089.734544 5 False \n", - " 22 2011-12-13 08:37:11.030729472 9237.405853 5 False \n", - " 23 2011-05-10 05:02:03.746472704 NaN 5 False \n", - " 24 2011-09-16 10:58:14.326885632 2452.122989 5 False \n", - " 25 2013-08-04 01:41:06.983268352 3136.407925 5 False \n", - " 26 2012-02-27 13:58:51.523597568 4152.240138 5 False \n", - " 27 2011-02-14 11:07:31.025720320 NaN 5 False \n", - " 28 2010-12-31 09:46:24.324750592 NaN 5 False \n", - " 29 2011-11-01 17:13:08.654945280 3392.664690 5 False \n", - " .. ... ... ... ... 
\n", - " 970 2011-10-05 18:36:28.353068288 NaN 5 False \n", - " 971 2012-01-12 04:51:28.610967552 NaN 5 False \n", - " 972 2012-03-24 12:11:51.862262272 NaN 5 False \n", - " 973 2013-07-12 16:33:09.626688000 6561.766212 5 False \n", - " 974 2011-07-07 11:50:18.363575808 NaN 5 False \n", - " 975 2012-04-22 12:03:07.605280512 451.002188 5 False \n", - " 976 2011-04-24 16:40:02.879920896 NaN 5 False \n", - " 977 2013-05-10 19:00:49.186777088 14456.109439 5 False \n", - " 978 2011-09-16 12:39:50.775564288 NaN 5 False \n", - " 979 2011-01-27 23:51:25.038483968 NaN 5 False \n", - " 980 2011-04-08 07:11:55.072327424 NaN 5 False \n", - " 981 2012-04-14 11:31:04.547722496 NaN 5 False \n", - " 982 2010-12-31 16:22:41.385283072 NaN 5 False \n", - " 983 2014-05-03 06:44:59.234656512 7187.928500 5 False \n", - " 984 2010-09-10 17:16:27.645561344 NaN 5 True \n", - " 985 2012-02-26 19:44:58.053891840 12936.660878 5 True \n", - " 986 2011-08-12 15:34:48.039582976 1592.820792 5 False \n", - " 987 2011-01-04 14:20:20.553192192 NaN 5 False \n", - " 988 2011-10-07 13:59:02.167729408 NaN 5 False \n", - " 989 2012-11-21 05:42:10.223450880 5971.323886 5 False \n", - " 990 2011-10-01 05:02:17.844523520 NaN 5 False \n", - " 991 2012-02-23 17:58:47.762346752 -2849.677150 5 False \n", - " 992 2011-04-20 10:01:13.460115456 NaN 5 False \n", - " 993 2010-12-21 01:41:52.356143872 NaN 5 False \n", - " 994 2011-01-17 09:10:37.772929536 NaN 5 False \n", - " 995 2011-06-03 03:47:46.385799936 NaN 5 False \n", - " 996 2009-01-13 10:16:37.091725312 NaN 5 False \n", - " 997 2012-03-11 15:48:35.927031808 NaN 5 True \n", - " 998 2010-11-17 15:02:44.981103616 NaN 5 False \n", - " 999 2010-09-07 12:56:31.595172096 NaN 5 False \n", - " \n", - " MarkDown4 MarkDown3 Fuel_Price Unemployment Temperature \\\n", - " 0 2843.666588 2084.265105 3.539294 8.026048 67.279415 \n", - " 1 NaN NaN 3.136235 7.464250 35.839727 \n", - " 2 NaN NaN 3.117027 8.804691 55.531072 \n", - " 3 6291.467352 7644.651968 3.950191 8.449126 63.312543 \n", - " 4 NaN NaN 3.151100 10.724954 78.355700 \n", - " 5 319.073405 8048.361448 3.214550 6.730897 46.004140 \n", - " 6 NaN NaN 3.073315 8.794787 46.230595 \n", - " 7 NaN NaN 2.851889 8.455094 73.007249 \n", - " 8 -842.343741 1973.521805 3.219424 NaN 83.146804 \n", - " 9 4753.392021 2744.820760 4.389613 6.049940 55.535728 \n", - " 10 NaN NaN 2.921131 10.411630 81.934195 \n", - " 11 -380.767318 NaN 3.588675 10.915397 41.862765 \n", - " 12 3841.062967 -692.534491 3.183958 8.449273 91.520699 \n", - " 13 NaN NaN 2.676381 10.171808 59.799911 \n", - " 14 NaN NaN 4.068388 11.061642 49.266088 \n", - " 15 4759.123311 3958.005275 3.860776 8.559916 54.371928 \n", - " 16 -2880.580355 6861.020784 3.241783 7.662614 44.933106 \n", - " 17 NaN NaN 3.287047 8.232218 73.642978 \n", - " 18 NaN NaN 3.553487 4.932510 70.156403 \n", - " 19 NaN 3498.711115 3.650312 10.192201 64.554441 \n", - " 20 NaN NaN 3.711555 5.975218 45.371293 \n", - " 21 NaN 8008.053972 2.965420 6.715859 65.736442 \n", - " 22 NaN 5226.209147 3.770431 4.689591 48.907799 \n", - " 23 NaN NaN 3.148174 7.048833 63.931931 \n", - " 24 NaN NaN 3.228123 7.125082 54.541421 \n", - " 25 -2311.380219 9527.093555 4.376510 6.227941 39.120836 \n", - " 26 625.303990 NaN 3.789428 5.716800 99.128378 \n", - " 27 NaN NaN 3.259632 8.651978 75.544242 \n", - " 28 NaN NaN 3.703069 10.753015 40.004955 \n", - " 29 NaN NaN 3.741760 6.060680 100.625665 \n", - " .. ... ... ... ... ... 
\n", - " 970 NaN NaN 3.572695 7.966283 57.537589 \n", - " 971 NaN NaN 3.828349 9.731679 82.126278 \n", - " 972 NaN NaN 3.313500 NaN 46.785653 \n", - " 973 3054.127592 -724.702300 4.287166 9.651842 51.824356 \n", - " 974 NaN NaN 3.761812 8.287437 88.088342 \n", - " 975 -3266.187881 8215.493689 3.642633 9.500544 29.831979 \n", - " 976 NaN NaN 2.819650 7.832418 40.807527 \n", - " 977 4122.302492 6313.274115 4.174946 8.727237 70.737165 \n", - " 978 5838.158736 NaN 3.066125 NaN 31.147061 \n", - " 979 NaN NaN 3.151944 7.658218 61.645832 \n", - " 980 NaN 630.957174 3.183844 9.862215 17.948335 \n", - " 981 NaN NaN 3.271366 9.679093 61.315823 \n", - " 982 NaN NaN 3.867379 8.191903 51.461519 \n", - " 983 2608.840965 -562.271212 3.781071 NaN 42.556422 \n", - " 984 NaN NaN 2.956717 7.515493 75.537151 \n", - " 985 3915.581744 1272.550074 3.473630 5.244393 78.032536 \n", - " 986 2982.566125 NaN 2.981562 9.439555 62.315341 \n", - " 987 NaN NaN 2.839158 9.016700 58.018581 \n", - " 988 3139.374304 NaN 3.952791 6.206879 68.542122 \n", - " 989 4102.145584 -2255.694045 4.091854 9.904070 48.245039 \n", - " 990 NaN NaN 3.318221 9.150930 100.355027 \n", - " 991 NaN -852.368857 3.281618 6.774089 63.196870 \n", - " 992 NaN NaN 3.199449 6.715256 26.426601 \n", - " 993 NaN NaN 3.339197 11.524258 56.354388 \n", - " 994 NaN NaN 3.422232 7.571462 65.796130 \n", - " 995 1692.546761 NaN 3.178186 5.160924 43.172044 \n", - " 996 NaN NaN 2.748483 11.166695 34.682053 \n", - " 997 NaN NaN 3.713086 8.024620 35.345627 \n", - " 998 NaN NaN 3.683783 8.076637 43.072280 \n", - " 999 NaN NaN 2.804057 9.006230 79.267188 \n", - " \n", - " MarkDown5 MarkDown2 CPI \n", - " 0 10204.869614 2819.319182 194.172590 \n", - " 1 NaN 1682.652584 137.004267 \n", - " 2 NaN NaN 179.678323 \n", - " 3 7501.331616 4715.605494 161.759497 \n", - " 4 NaN NaN 128.911283 \n", - " 5 6750.992820 797.112925 205.683144 \n", - " 6 NaN NaN 203.634987 \n", - " 7 NaN NaN 230.663986 \n", - " 8 3871.914754 3988.518657 NaN \n", - " 9 1624.808895 735.714943 133.713539 \n", - " 10 NaN NaN 197.575244 \n", - " 11 NaN NaN 139.060197 \n", - " 12 NaN NaN 194.510387 \n", - " 13 NaN NaN 78.886026 \n", - " 14 NaN NaN 81.669058 \n", - " 15 7758.050721 NaN 198.052999 \n", - " 16 2853.381375 9269.494345 104.932028 \n", - " 17 NaN NaN 204.210702 \n", - " 18 3343.712458 NaN 178.168110 \n", - " 19 NaN NaN 161.572206 \n", - " 20 NaN NaN 128.629900 \n", - " 21 NaN NaN 181.929594 \n", - " 22 3656.616824 1231.941589 193.633886 \n", - " 23 NaN 33.178516 198.074896 \n", - " 24 -1626.757019 NaN 105.550579 \n", - " 25 644.446061 10905.729749 137.971932 \n", - " 26 6373.039633 NaN 192.460300 \n", - " 27 NaN NaN 213.233609 \n", - " 28 NaN NaN 123.195502 \n", - " 29 2826.315133 NaN 195.701117 \n", - " .. ... ... ... 
\n", - " 970 NaN NaN 148.906214 \n", - " 971 NaN NaN 204.262748 \n", - " 972 NaN NaN NaN \n", - " 973 -1499.994553 16305.139629 140.188752 \n", - " 974 NaN NaN 161.099549 \n", - " 975 2685.234955 4271.729391 108.874111 \n", - " 976 NaN NaN 188.576448 \n", - " 977 3914.364495 3982.010819 196.334916 \n", - " 978 NaN NaN NaN \n", - " 979 NaN NaN 158.554149 \n", - " 980 NaN NaN 163.242219 \n", - " 981 NaN NaN 184.637472 \n", - " 982 NaN NaN 95.502993 \n", - " 983 5061.605364 4816.687238 NaN \n", - " 984 NaN NaN 150.901411 \n", - " 985 5375.719303 10141.431209 250.053444 \n", - " 986 6132.068608 7917.959484 120.500763 \n", - " 987 NaN NaN 127.836518 \n", - " 988 NaN NaN 206.249045 \n", - " 989 9976.937232 3037.291454 126.800358 \n", - " 990 NaN NaN 239.016886 \n", - " 991 4704.812744 NaN 213.808756 \n", - " 992 NaN NaN 136.904711 \n", - " 993 NaN NaN 177.862208 \n", - " 994 NaN NaN 166.085349 \n", - " 995 NaN NaN 163.449486 \n", - " 996 NaN NaN 136.797738 \n", - " 997 464.410675 NaN 186.260545 \n", - " 998 NaN NaN 140.825171 \n", - " 999 NaN NaN 168.683814 \n", - " \n", - " [1000 rows x 12 columns]}" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "sdv.sample('features', 1000)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "SDV has tools to evaluate the synthetic data generated and compare them with the original data. To see more about the evaluation tools, [see the documentation](https://sdv-dev.github.io/SDV/api/sdv.evaluation.html#sdv.evaluation.evaluate)." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/airbnb-recruiting-new-user-bookings/Aibnb Dataset Subsampling.ipynb b/examples/airbnb-recruiting-new-user-bookings/Aibnb Dataset Subsampling.ipynb deleted file mode 100644 index d4b39df8b..000000000 --- a/examples/airbnb-recruiting-new-user-bookings/Aibnb Dataset Subsampling.ipynb +++ /dev/null @@ -1,206 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Airbnb Dataset Subsampling\n", - "\n", - "This notebooks shows how to generate the `users_sample.csv` and `sessions_sample.csv` files.\n", - "Before running this notebook, make sure to having downloaded the `train_users_2.csv` and `sessions.csv`\n", - "files from https://www.kaggle.com/c/airbnb-recruiting-new-user-bookings/data" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import pandas as pd" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "users = pd.read_csv('train_users_2.csv')\n", - "sessions = pd.read_csv('sessions.csv')" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(213451, 16)" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "users.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(10567737, 6)" - ] - }, - "execution_count": 4, - "metadata": {}, - 
"output_type": "execute_result" - } - ], - "source": [ - "sessions.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "users_sample = users.sample(1000).reset_index(drop=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "sessions_sample = sessions[\n", - " sessions['user_id'].isin(users_sample['id'])\n", - "].reset_index(drop=True).copy()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(24988, 6)" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "sessions_sample.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "user_ids = users_sample['id'].reset_index().set_index('id')['index']" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "id\n", - "c2cga16ny5 0\n", - "08yitv7nz0 1\n", - "3oobewtnls 2\n", - "owd2csvj6u 3\n", - "i66dlbdyrt 4\n", - "Name: index, dtype: int64" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "user_ids.head()" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "users_sample['id'] = users_sample['id'].map(user_ids)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "sessions_sample['user_id'] = sessions_sample['user_id'].map(user_ids)" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "users_sample.to_csv('users_sample.csv', index=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "sessions_sample.to_csv('sessions_sample.csv', index=False)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.8" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/airbnb-recruiting-new-user-bookings/Aibnb Example.ipynb b/examples/airbnb-recruiting-new-user-bookings/Aibnb Example.ipynb deleted file mode 100644 index e4b43f916..000000000 --- a/examples/airbnb-recruiting-new-user-bookings/Aibnb Example.ipynb +++ /dev/null @@ -1,730 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Airbnb Dataset Synthesis\n", - "\n", - "This notebook shows a usage example over a sample of the Airbnb New User Bookings dataset\n", - "available on Kaggle: https://www.kaggle.com/c/airbnb-recruiting-new-user-bookings\n", - "\n", - "Before running this notebook, make sure to having downloaded the data and run the\n", - "`Airbnb Dataset Sampling` notebook which can be found in the same folder as this one." 
- ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2019-11-03 16:07:17,764 - INFO - modeler - Modeling users\n", - "2019-11-03 16:07:17,764 - INFO - metadata - Loading table users\n", - "2019-11-03 16:07:17,781 - INFO - metadata - Loading transformer DatetimeTransformer for field date_account_created\n", - "2019-11-03 16:07:17,782 - INFO - metadata - Loading transformer DatetimeTransformer for field timestamp_first_active\n", - "2019-11-03 16:07:17,782 - INFO - metadata - Loading transformer DatetimeTransformer for field date_first_booking\n", - "2019-11-03 16:07:17,783 - INFO - metadata - Loading transformer CategoricalTransformer for field gender\n", - "2019-11-03 16:07:17,783 - INFO - metadata - Loading transformer NumericalTransformer for field age\n", - "2019-11-03 16:07:17,783 - INFO - metadata - Loading transformer CategoricalTransformer for field signup_method\n", - "2019-11-03 16:07:17,783 - INFO - metadata - Loading transformer CategoricalTransformer for field signup_flow\n", - "2019-11-03 16:07:17,784 - INFO - metadata - Loading transformer CategoricalTransformer for field language\n", - "2019-11-03 16:07:17,784 - INFO - metadata - Loading transformer CategoricalTransformer for field affiliate_channel\n", - "2019-11-03 16:07:17,784 - INFO - metadata - Loading transformer CategoricalTransformer for field affiliate_provider\n", - "2019-11-03 16:07:17,784 - INFO - metadata - Loading transformer CategoricalTransformer for field first_affiliate_tracked\n", - "2019-11-03 16:07:17,785 - INFO - metadata - Loading transformer CategoricalTransformer for field signup_app\n", - "2019-11-03 16:07:17,785 - INFO - metadata - Loading transformer CategoricalTransformer for field first_device_type\n", - "2019-11-03 16:07:17,785 - INFO - metadata - Loading transformer CategoricalTransformer for field first_browser\n", - "2019-11-03 16:07:17,785 - INFO - metadata - Loading transformer CategoricalTransformer for field country_destination\n", - "2019-11-03 16:07:17,852 - INFO - modeler - Modeling sessions\n", - "2019-11-03 16:07:17,853 - INFO - metadata - Loading table sessions\n", - "2019-11-03 16:07:17,872 - INFO - metadata - Loading transformer CategoricalTransformer for field action\n", - "2019-11-03 16:07:17,873 - INFO - metadata - Loading transformer CategoricalTransformer for field action_type\n", - "2019-11-03 16:07:17,873 - INFO - metadata - Loading transformer CategoricalTransformer for field action_detail\n", - "2019-11-03 16:07:17,873 - INFO - metadata - Loading transformer CategoricalTransformer for field device_type\n", - "2019-11-03 16:07:17,874 - INFO - metadata - Loading transformer NumericalTransformer for field secs_elapsed\n", - "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/pandas/core/frame.py:7143: RuntimeWarning: Degrees of freedom <= 0 for slice\n", - " baseCov = np.cov(mat.T)\n", - "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/lib/function_base.py:2451: RuntimeWarning: divide by zero encountered in true_divide\n", - " c *= np.true_divide(1, fact)\n", - "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/lib/function_base.py:2451: RuntimeWarning: invalid value encountered in multiply\n", - " c *= np.true_divide(1, fact)\n", - "2019-11-03 16:07:22,394 - INFO - modeler - Modeling Complete\n" - ] - } - ], - "source": [ - "from sdv import SDV\n", - "\n", - "sdv = SDV()\n", - "sdv.fit('metadata.json')" - ] - }, - { 
- "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2019-11-03 16:07:22,443 - INFO - metadata - Loading table users\n", - "2019-11-03 16:07:22,454 - INFO - metadata - Loading table sessions\n" - ] - } - ], - "source": [ - "real = sdv.metadata.get_tables()" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
iddate_account_createdtimestamp_first_activedate_first_bookinggenderagesignup_methodsignup_flowlanguageaffiliate_channelaffiliate_providerfirst_affiliate_trackedsignup_appfirst_device_typefirst_browsercountry_destination
7907902011-04-302011-04-30 17:58:152012-03-22FEMALE36.0basic2ensem-non-brandgoogleuntrackedWebMac DesktopFirefoxUS
5175172011-09-202011-09-20 08:14:062011-09-22-unknown-NaNbasic2ensem-brandgoogleomgWebMac DesktopChromeUS
9029022014-05-172014-05-17 01:03:00NaTMALE48.0basic24endirectdirectuntrackedMowebiPhoneMobile SafariNDF
70702013-10-212013-10-21 21:50:15NaT-unknown-NaNbasic0endirectdirectlinkedWebMac DesktopFirefoxNDF
6086082013-04-012013-04-01 20:05:132013-04-02-unknown-NaNbasic24endirectdirectlinkedMowebMac DesktopSafariother
\n", - "
" - ], - "text/plain": [ - " id date_account_created timestamp_first_active date_first_booking \\\n", - "790 790 2011-04-30 2011-04-30 17:58:15 2012-03-22 \n", - "517 517 2011-09-20 2011-09-20 08:14:06 2011-09-22 \n", - "902 902 2014-05-17 2014-05-17 01:03:00 NaT \n", - "70 70 2013-10-21 2013-10-21 21:50:15 NaT \n", - "608 608 2013-04-01 2013-04-01 20:05:13 2013-04-02 \n", - "\n", - " gender age signup_method signup_flow language affiliate_channel \\\n", - "790 FEMALE 36.0 basic 2 en sem-non-brand \n", - "517 -unknown- NaN basic 2 en sem-brand \n", - "902 MALE 48.0 basic 24 en direct \n", - "70 -unknown- NaN basic 0 en direct \n", - "608 -unknown- NaN basic 24 en direct \n", - "\n", - " affiliate_provider first_affiliate_tracked signup_app first_device_type \\\n", - "790 google untracked Web Mac Desktop \n", - "517 google omg Web Mac Desktop \n", - "902 direct untracked Moweb iPhone \n", - "70 direct linked Web Mac Desktop \n", - "608 direct linked Moweb Mac Desktop \n", - "\n", - " first_browser country_destination \n", - "790 Firefox US \n", - "517 Chrome US \n", - "902 Mobile Safari NDF \n", - "70 Firefox NDF \n", - "608 Safari other " - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "real['users'].sample(5)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_idactionaction_typeaction_detaildevice_typesecs_elapsed
6345101search_resultsclickview_search_resultsMac Desktop1592.0
541772lookupNaNNaNMac Desktop651.0
22406823showviewp3Android Phone1198.0
6279101search_resultsclickview_search_resultsMac Desktop3389.0
6561883ajax_refresh_subtotalclickchange_trip_characteristicsMac Desktop268.0
\n", - "
" - ], - "text/plain": [ - " user_id action action_type \\\n", - "6345 101 search_results click \n", - "5417 72 lookup NaN \n", - "22406 823 show view \n", - "6279 101 search_results click \n", - "6561 883 ajax_refresh_subtotal click \n", - "\n", - " action_detail device_type secs_elapsed \n", - "6345 view_search_results Mac Desktop 1592.0 \n", - "5417 NaN Mac Desktop 651.0 \n", - "22406 p3 Android Phone 1198.0 \n", - "6279 view_search_results Mac Desktop 3389.0 \n", - "6561 change_trip_characteristics Mac Desktop 268.0 " - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "real['sessions'].sample(5)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "length = len(real['users'])" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "1000" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "length" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "samples = sdv.sample_all(length, reset_primary_keys=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
iddate_account_createdtimestamp_first_activedate_first_bookinggenderagesignup_methodsignup_flowlanguageaffiliate_channelaffiliate_providerfirst_affiliate_trackedsignup_appfirst_device_typefirst_browsercountry_destination
8578572014-03-16 21:04:00.3134963202014-03-15 05:39:43.249984000NaT-unknown-45.0facebook0endirectdirectuntrackedWebiPhoneIENDF
4894892014-10-03 17:46:33.6981690882014-10-14 06:11:35.937526784NaTFEMALE56.0basic0endirectdirectuntrackedWebWindows Desktop-unknown-NDF
7807802014-03-02 14:25:57.2513733122014-03-05 14:28:40.767416832NaT-unknown-NaNbasic0endirectdirectlinkedWebWindows DesktopSafariNDF
5965962014-01-12 05:40:57.1618867202014-01-18 03:57:01.785481984NaTMALE51.0basic0ensem-brandgooglelinkedWebMac DesktopSafariNDF
7767762013-05-28 14:15:41.3111449602013-05-28 06:13:45.2424908802013-08-07 07:45:31.016155904FEMALE46.0basic0endirectdirectuntrackedWebMac DesktopChromeNDF
\n", - "
" - ], - "text/plain": [ - " id date_account_created timestamp_first_active \\\n", - "857 857 2014-03-16 21:04:00.313496320 2014-03-15 05:39:43.249984000 \n", - "489 489 2014-10-03 17:46:33.698169088 2014-10-14 06:11:35.937526784 \n", - "780 780 2014-03-02 14:25:57.251373312 2014-03-05 14:28:40.767416832 \n", - "596 596 2014-01-12 05:40:57.161886720 2014-01-18 03:57:01.785481984 \n", - "776 776 2013-05-28 14:15:41.311144960 2013-05-28 06:13:45.242490880 \n", - "\n", - " date_first_booking gender age signup_method signup_flow \\\n", - "857 NaT -unknown- 45.0 facebook 0 \n", - "489 NaT FEMALE 56.0 basic 0 \n", - "780 NaT -unknown- NaN basic 0 \n", - "596 NaT MALE 51.0 basic 0 \n", - "776 2013-08-07 07:45:31.016155904 FEMALE 46.0 basic 0 \n", - "\n", - " language affiliate_channel affiliate_provider first_affiliate_tracked \\\n", - "857 en direct direct untracked \n", - "489 en direct direct untracked \n", - "780 en direct direct linked \n", - "596 en sem-brand google linked \n", - "776 en direct direct untracked \n", - "\n", - " signup_app first_device_type first_browser country_destination \n", - "857 Web iPhone IE NDF \n", - "489 Web Windows Desktop -unknown- NDF \n", - "780 Web Windows Desktop Safari NDF \n", - "596 Web Mac Desktop Safari NDF \n", - "776 Web Mac Desktop Chrome NDF " - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "samples['users'].sample(5)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_idactionaction_typeaction_detaildevice_typesecs_elapsed
7985222ajax_refresh_subtotalNaNlisting_reviewsWindows Desktop6730.0
23883629ajax_refresh_subtotaldata-unknown-Windows Desktop4842.0
15908418termsclickview_search_resultsMac Desktop111772.0
8900250personalizedataNaNMac Desktop20486.0
26026691search-unknown--unknown-Mac Desktop149526.0
\n", - "
" - ], - "text/plain": [ - " user_id action action_type action_detail \\\n", - "7985 222 ajax_refresh_subtotal NaN listing_reviews \n", - "23883 629 ajax_refresh_subtotal data -unknown- \n", - "15908 418 terms click view_search_results \n", - "8900 250 personalize data NaN \n", - "26026 691 search -unknown- -unknown- \n", - "\n", - " device_type secs_elapsed \n", - "7985 Windows Desktop 6730.0 \n", - "23883 Windows Desktop 4842.0 \n", - "15908 Mac Desktop 111772.0 \n", - "8900 Mac Desktop 20486.0 \n", - "26026 Mac Desktop 149526.0 " - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "samples['sessions'].sample(5)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.8" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/airbnb-recruiting-new-user-bookings/metadata.json b/examples/airbnb-recruiting-new-user-bookings/metadata.json deleted file mode 100644 index 4c03506bf..000000000 --- a/examples/airbnb-recruiting-new-user-bookings/metadata.json +++ /dev/null @@ -1,116 +0,0 @@ -{ - "tables": [ - { - "headers": true, - "name": "users", - "path": "users_sample.csv", - "primary_key": "id", - "fields": [ - { - "name": "id", - "type": "id" - }, - { - "name": "date_account_created", - "type": "datetime", - "format": "%Y-%m-%d" - }, - { - "name": "timestamp_first_active", - "type": "datetime", - "format": "%Y%m%d%H%M%S" - }, - { - "name": "date_first_booking", - "type": "datetime", - "format": "%Y-%m-%d" - }, - { - "name": "gender", - "type": "categorical" - }, - { - "name": "age", - "type": "numerical", - "subtype": "integer" - }, - { - "name": "signup_method", - "type": "categorical" - }, - { - "name": "signup_flow", - "type": "categorical" - }, - { - "name": "language", - "type": "categorical" - }, - { - "name": "affiliate_channel", - "type": "categorical" - }, - { - "name": "affiliate_provider", - "type": "categorical" - }, - { - "name": "first_affiliate_tracked", - "type": "categorical" - }, - { - "name": "signup_app", - "type": "categorical" - }, - { - "name": "first_device_type", - "type": "categorical" - }, - { - "name": "first_browser", - "type": "categorical" - }, - { - "name": "country_destination", - "type": "categorical" - } - ] - }, - { - "headers": true, - "name": "sessions", - "path": "sessions_sample.csv", - "fields": [ - { - "name": "user_id", - "ref": { - "field": "id", - "table": "users" - }, - "type": "id" - }, - { - "name": "action", - "type": "categorical" - }, - { - "name": "action_type", - "type": "categorical" - }, - { - "name": "action_detail", - "type": "categorical" - }, - { - "name": "device_type", - "type": "categorical" - }, - { - "name": "secs_elapsed", - "type": "numerical", - "subtype": "integer" - } - ] - } - ] -} diff --git a/examples/demo_metadata.json b/examples/demo_metadata.json deleted file mode 100644 index a8b345822..000000000 --- a/examples/demo_metadata.json +++ /dev/null @@ -1,74 +0,0 @@ -{ - "tables": { - "users": { - "fields": { - "user_id": { - "type": "id", - "subtype": "integer" - }, - "age": { - "type": "numerical", - "subtype": "integer" - }, - "gender": { - "type": "categorical" - }, - "country": { - "type": "categorical" - } - }, - "primary_key": 
"user_id" - }, - "sessions": { - "fields": { - "os": { - "type": "categorical" - }, - "user_id": { - "type": "id", - "subtype": "integer", - "ref": { - "table": "users", - "field": "user_id" - } - }, - "device": { - "type": "categorical" - }, - "session_id": { - "type": "id", - "subtype": "integer" - } - }, - "primary_key": "session_id" - }, - "transactions": { - "fields": { - "timestamp": { - "type": "datetime", - "format": "%Y-%m-%d" - }, - "transaction_id": { - "type": "id", - "subtype": "integer" - }, - "approved": { - "type": "boolean" - }, - "session_id": { - "type": "id", - "subtype": "integer", - "ref": { - "table": "sessions", - "field": "session_id" - } - }, - "amount": { - "type": "numerical", - "subtype": "float" - } - }, - "primary_key": "transaction_id" - } - } -} \ No newline at end of file diff --git a/examples/quickstart/customers.csv b/examples/quickstart/customers.csv deleted file mode 100644 index 14922885f..000000000 --- a/examples/quickstart/customers.csv +++ /dev/null @@ -1,8 +0,0 @@ -customer_id,cust_postal_code,phone_number1,credit_limit,country -50,11371,6175553295,1000,UK -4,63145,8605551835,500,US -97338810,6096,7035552143,1000,CANADA -630407,63145,7035552143,2000,US -826362,11371,6175553295,1000,UK -55996144,20166,4045553285,1000,FRANCE -598112,63145,6175553295,1000,SPAIN diff --git a/examples/quickstart/metadata.json b/examples/quickstart/metadata.json deleted file mode 100644 index 119d30c76..000000000 --- a/examples/quickstart/metadata.json +++ /dev/null @@ -1,96 +0,0 @@ -{ - "tables": [ - { - "fields": [ - { - "name": "customer_id", - "type": "id" - }, - { - "name": "cust_postal_code", - "type": "categorical" - }, - { - "name": "phone_number1", - "type": "numerical", - "subtype": "integer" - }, - { - "name": "credit_limit", - "type": "numerical", - "subtype": "integer" - }, - { - "name": "country", - "type": "categorical" - } - ], - "headers": true, - "name": "customers", - "path": "customers.csv", - "primary_key": "customer_id", - "use": true - }, - { - "fields": [ - { - "name": "order_id", - "type": "id" - }, - { - "name": "customer_id", - "ref": { - "field": "customer_id", - "table": "customers" - }, - "type": "id" - }, - { - "name": "order_total", - "type": "numerical", - "subtype": "integer" - } - ], - "headers": true, - "name": "orders", - "path": "orders.csv", - "primary_key": "order_id", - "use": true - }, - { - "fields": [ - { - "name": "order_item_id", - "type": "id" - }, - { - "name": "order_id", - "ref": { - "field": "order_id", - "table": "orders" - }, - "type": "id" - }, - { - "name": "product_id", - "type": "categorical" - }, - { - "name": "unit_price", - "type": "numerical", - "subtype": "integer" - }, - { - "name": "quantity", - "type": "numerical", - "subtype": "integer" - } - ], - "headers": true, - "name": "order_items", - "path": "order_items.csv", - "primary_key": "order_item_id", - "use": true - } - ] -} diff --git a/examples/quickstart/order_items.csv b/examples/quickstart/order_items.csv deleted file mode 100644 index 7f47c3d69..000000000 --- a/examples/quickstart/order_items.csv +++ /dev/null @@ -1,50 +0,0 @@ -order_item_id,order_id,product_id,unit_price,quantity -100,10,7,52,8 -101,8,6,125,4 -102,1,6,125,4 -103,4,9,125,4 -104,1,9,113,4 -105,9,10,87,2 -106,10,6,39,4 -107,1,6,50,4 -108,2,3,31,2 -109,4,6,37,3 -110,10,10,50,4 -111,1,10,50,2 -112,6,14,50,1 -113,10,1,161,5 -114,2,2,73,4 -115,1,6,50,2 -116,5,6,63,2 -117,1,6,115,4 -118,4,10,50,4 -119,5,9,50,2 -120,10,15,37,0 -121,7,6,113,7 -122,1,6,125,2 -123,10,6,113,7 -124,8,10,125,4 
-125,5,6,125,2 -126,1,6,125,2 -127,8,10,69,2 -128,8,10,50,2 -129,6,8,152,3 -130,5,2,103,4 -131,3,11,30,-1 -132,9,6,97,2 -133,7,1,165,5 -134,8,6,125,2 -135,8,6,50,2 -136,1,9,125,4 -137,6,8,152,3 -138,1,10,50,2 -139,3,6,37,3 -140,6,15,111,1 -141,9,10,125,2 -142,10,2,147,4 -143,3,6,58,4 -144,1,6,102,4 -145,9,10,50,4 -146,6,15,50,1 -147,7,3,30,-1 -148,1,6,125,4 diff --git a/examples/quickstart/orders.csv b/examples/quickstart/orders.csv deleted file mode 100644 index 08371dbb8..000000000 --- a/examples/quickstart/orders.csv +++ /dev/null @@ -1,11 +0,0 @@ -order_id,customer_id,order_total -1,50,2310 -2,4,1507 -10,97338810,730 -6,55996144,730 -3,55996144,939 -4,50,2380 -5,97338810,1570 -7,50,730 -8,97338810,2336 -9,4,743 diff --git a/sdv/__init__.py b/sdv/__init__.py index 5213dd4ca..d8e94dd37 100644 --- a/sdv/__init__.py +++ b/sdv/__init__.py @@ -6,7 +6,7 @@ __author__ = """MIT Data To AI Lab""" __email__ = 'dailabmit@gmail.com' -__version__ = '0.3.3' +__version__ = '0.3.5.dev0' from sdv.demo import get_available_demos, load_demo from sdv.metadata import Metadata diff --git a/sdv/benchmark.py b/sdv/benchmark.py new file mode 100644 index 000000000..df4365d0e --- /dev/null +++ b/sdv/benchmark.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- + +import logging +import multiprocessing +import os +from datetime import datetime + +import pandas as pd + +from sdv import SDV, Metadata +from sdv.demo import get_available_demos, load_demo +from sdv.evaluation import evaluate + +LOGGER = logging.getLogger(__name__) + + +def _score_dataset(dataset, datasets_path, output): + start = datetime.now() + + try: + if datasets_path is None: + metadata, tables = load_demo(dataset, metadata=True) + else: + metadata = Metadata(os.path.join(datasets_path, dataset, 'metadata.json')) + tables = metadata.load_tables() + + sdv = SDV() + LOGGER.info('Modeling dataset %s', dataset) + sdv.fit(metadata, tables) + + LOGGER.info('Sampling dataset %s', dataset) + sampled = sdv.sample_all(10) + + LOGGER.info('Evaluating dataset %s', dataset) + score = evaluate(sampled, metadata=metadata) + + LOGGER.info('%s: %s - ELAPSED: %s', dataset, score, datetime.now() - start) + output.update({ + 'dataset': dataset, + 'score': score, + }) + + except Exception as ex: + error = '{}: {}'.format(type(ex).__name__, str(ex)) + LOGGER.error('%s: %s - ELAPSED: %s', dataset, error, datetime.now() - start) + output.update({ + 'dataset': dataset, + 'error': error + }) + + +def score_dataset(dataset, datasets_path, timeout=None): + with multiprocessing.Manager() as manager: + output = manager.dict() + process = multiprocessing.Process( + target=_score_dataset, + args=(dataset, datasets_path, output) + ) + + process.start() + process.join(timeout) + process.terminate() + + if not output: + LOGGER.warn('%s: TIMEOUT', dataset) + return { + 'dataset': dataset, + 'error': 'timeout' + } + + return dict(output) + + +def benchmark(datasets=None, datasets_path=None, distributed=True, timeout=None): + if datasets is None: + if datasets_path is None: + datasets = get_available_demos().name + else: + datasets = os.listdir(datasets_path) + + if distributed: + import dask + + global score_dataset + score_dataset = dask.delayed(score_dataset) + + scores = list() + for dataset in datasets: + scores.append(score_dataset(dataset, datasets_path, timeout)) + + if distributed: + scores = dask.compute(*scores) + + return pd.DataFrame(scores) diff --git a/sdv/data/got_families/character_families.csv b/sdv/data/got_families/character_families.csv new file mode 100644 index 
000000000..6b5560db5 --- /dev/null +++ b/sdv/data/got_families/character_families.csv @@ -0,0 +1,13 @@ +character_id,family_id,generation,type +1,1,8,father +1,4,5,mother +2,1,10,father +2,2,4,mother +3,3,7,both +4,4,12,both +5,1,10,father +5,2,4,mother +6,1,10,father +6,2,4,mother +7,1,10,father +7,2,4,mother diff --git a/sdv/data/got_families/characters.csv b/sdv/data/got_families/characters.csv new file mode 100644 index 000000000..f37e21af5 --- /dev/null +++ b/sdv/data/got_families/characters.csv @@ -0,0 +1,8 @@ +age,character_id,name +20,1,Jon +16,2,Arya +35,3,Tyrion +18,4,Daenerys +19,5,Sansa +24,6,Robb +15,7,Bran diff --git a/sdv/data/got_families/families.csv b/sdv/data/got_families/families.csv new file mode 100644 index 000000000..51745d59e --- /dev/null +++ b/sdv/data/got_families/families.csv @@ -0,0 +1,5 @@ +family_id,name +1,Stark +2,Tully +3,Lannister +4,Targaryen diff --git a/sdv/data/got_families/metadata.json b/sdv/data/got_families/metadata.json new file mode 100644 index 000000000..9ecff24a5 --- /dev/null +++ b/sdv/data/got_families/metadata.json @@ -0,0 +1,79 @@ +{ + "path": "", + "tables": [ + { + "use": true, + "primary_key": "character_id", + "fields": [ + { + "regex": "^[1-9]{1,2}$", + "type": "id", + "name": "character_id" + }, + { + "type": "categorical", + "name": "name" + }, + { + "subtype": "integer", + "type": "numerical", + "name": "age" + } + ], + "headers": true, + "path": "characters.csv", + "name": "characters" + }, + { + "use": true, + "primary_key": "family_id", + "fields": [ + { + "regex": "^[1-9]$", + "type": "id", + "name": "family_id" + }, + { + "type": "categorical", + "name": "name" + } + ], + "headers": true, + "path": "families.csv", + "name": "families" + }, + { + "headers": true, + "path": "character_families.csv", + "use": true, + "name": "character_families", + "fields": [ + { + "type": "id", + "ref": { + "field": "character_id", + "table": "characters" + }, + "name": "character_id" + }, + { + "type": "id", + "ref": { + "field": "family_id", + "table": "families" + }, + "name": "family_id" + }, + { + "type": "categorical", + "name": "type" + }, + { + "subtype": "integer", + "type": "numerical", + "name": "generation" + } + ] + } + ] +} diff --git a/sdv/demo.py b/sdv/demo.py index 672130ada..9ef350904 100644 --- a/sdv/demo.py +++ b/sdv/demo.py @@ -104,9 +104,9 @@ def _download(dataset_name, data_path): zf.extractall(data_path) -def _load(dataset_name, data_path): - if not os.path.exists(DATA_PATH): - os.makedirs(DATA_PATH) +def _get_dataset_path(dataset_name, data_path): + if not os.path.exists(data_path): + os.makedirs(data_path) if not os.path.exists(os.path.join(data_path, dataset_name)): _download(dataset_name, data_path) @@ -151,8 +151,8 @@ def _load_dummy(): def _load_demo_dataset(dataset_name, data_path): - data_path = _load(dataset_name, data_path) - meta = Metadata(metadata=os.path.join(data_path, 'metadata.json')) + dataset_path = _get_dataset_path(dataset_name, data_path) + meta = Metadata(metadata=os.path.join(dataset_path, 'metadata.json')) tables = meta.load_tables() return meta, tables diff --git a/sdv/metadata.py b/sdv/metadata/__init__.py similarity index 87% rename from sdv/metadata.py rename to sdv/metadata/__init__.py index dbf01ca60..d9b3ad368 100644 --- a/sdv/metadata.py +++ b/sdv/metadata/__init__.py @@ -4,11 +4,21 @@ import os from collections import defaultdict -import graphviz import numpy as np import pandas as pd from rdt import HyperTransformer, transformers +from sdv.metadata import visualization +from 
sdv.metadata.errors import MetadataError
+from sdv.metadata.table import Table
+
+__all__ = [
+    'Metadata',
+    'MetadataError',
+    'Table',
+    'visualization'
+]
+
 LOGGER = logging.getLogger(__name__)
 
@@ -51,10 +61,6 @@ def _load_csv(root_path, table_meta):
 
     return data
 
 
-class MetadataError(Exception):
-    pass
-
-
 class Metadata:
     """Dataset Metadata.
 
@@ -240,6 +246,30 @@ def get_tables(self):
         """
         return list(self._metadata['tables'].keys())
 
+    def get_field_meta(self, table_name, field_name):
+        """Get the metadata dict for a field.
+
+        Args:
+            table_name (str):
+                Name of the table to which the field belongs.
+            field_name (str):
+                Name of the field to get data for.
+
+        Returns:
+            dict:
+                Field metadata.
+
+        Raises:
+            ValueError:
+                If the table or the field do not exist in this metadata.
+        """
+        field_meta = self.get_fields(table_name).get(field_name)
+        if field_meta is None:
+            raise ValueError(
+                'Table "{}" does not contain a field named "{}"'.format(table_name, field_name))
+
+        return copy.deepcopy(field_meta)
+
     def get_fields(self, table_name):
         """Get table fields metadata.
 
@@ -291,11 +321,9 @@ def get_foreign_key(self, parent, child):
             ValueError:
                 If the relationship does not exist.
         """
-        primary = self.get_primary_key(parent)
-
         for name, field in self.get_fields(child).items():
             ref = field.get('ref')
-            if ref and ref['field'] == primary:
+            if ref and ref['table'] == parent:
                 return name
 
         raise ValueError('{} is not parent of {}'.format(parent, child))
 
@@ -370,8 +398,13 @@ def get_dtypes(self, table_name, ids=False):
 
             if ids and field_type == 'id':
                 if (name != table_meta.get('primary_key')) and not field.get('ref'):
-                    raise MetadataError(
-                        'id field `{}` is neither a primary or a foreign key'.format(name))
+                    for child_table in self.get_children(table_name):
+                        if name == self.get_foreign_key(table_name, child_table):
+                            break
+
+                    else:
+                        raise MetadataError(
+                            'id field `{}` is neither a primary or a foreign key'.format(name))
 
             if ids or (field_type != 'id'):
                 dtypes[name] = dtype
 
@@ -575,11 +608,6 @@ def _validate_circular_relationships(self, parent, children=None):
             for child in children:
                 self._validate_circular_relationships(parent, self.get_children(child))
 
-    def _validate_parents(self, table_name):
-        """Make sure that the table has only one parent."""
-        if len(self.get_parents(table_name)) > 1:
-            raise MetadataError('Table {} has more than one parent.'.format(table_name))
-
     def validate(self, tables=None):
         """Validate this metadata.
 
@@ -624,7 +652,6 @@ def validate(self, tables=None):
 
             self._validate_table(table_name, table_meta, table)
             self._validate_circular_relationships(table_name)
-            self._validate_parents(table_name)
 
     def _check_field(self, table, field, exists=False):
         """Validate the existance of the table and existance (or not) of field."""
@@ -678,14 +705,17 @@ def add_field(self, table, field, field_type, field_subtype=None, properties=Non
     def _get_key_subtype(field_meta):
         """Get the appropriate key subtype."""
         field_type = field_meta['type']
+
         if field_type == 'categorical':
            field_subtype = 'string'
+
         elif field_type in ('numerical', 'id'):
             field_subtype = field_meta['subtype']
             if field_subtype not in ('integer', 'string'):
                 raise ValueError(
                     'Invalid field "subtype" for key field: "{}"'.format(field_subtype)
                 )
+
         else:
             raise ValueError(
                 'Invalid field "type" for key field: "{}"'.format(field_type)
@@ -746,7 +776,11 @@ def add_relationship(self, parent, child, foreign_key=None):
                 * The child table already has a parent.
                 * The new relationship closes a relationship circle.
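+
+        Example:
+            A minimal usage sketch (table and field names are illustrative;
+            both tables are assumed to exist in this metadata and ``users``
+            is assumed to have ``user_id`` as its primary key):
+
+            >>> metadata.add_relationship('users', 'sessions', 'user_id')
+
+            After this call, ``sessions.user_id`` is typed as an ``id``
+            field whose ``ref`` points at ``users.user_id``.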
""" - # Validate table and field names + # Validate tables exists + self.get_table_meta(parent) + self.get_table_meta(child) + + # Validate field names primary_key = self.get_primary_key(parent) if not primary_key: raise ValueError('Parent table "{}" does not have a primary key'.format(parent)) @@ -754,33 +788,51 @@ def add_relationship(self, parent, child, foreign_key=None): if foreign_key is None: foreign_key = primary_key + parent_key_meta = copy.deepcopy(self.get_field_meta(parent, primary_key)) + child_key_meta = copy.deepcopy(self.get_field_meta(child, foreign_key)) + # Validate relationships - if self.get_parents(child): - raise ValueError('Table "{}" already has a parent'.format(child)) + child_ref = child_key_meta.get('ref') + if child_ref: + raise ValueError( + 'Field "{}.{}" already defines a relationship'.format(child, foreign_key)) grandchildren = self.get_children(child) if grandchildren: self._validate_circular_relationships(parent, grandchildren) - # Copy primary key details over to the foreign key - foreign_key_details = copy.deepcopy(self.get_fields(parent)[primary_key]) - foreign_key_details['ref'] = { + # Make sure that the parent key is an id + if parent_key_meta['type'] != 'id': + parent_key_meta['subtype'] = self._get_key_subtype(parent_key_meta) + parent_key_meta['type'] = 'id' + + # Update the child key meta + child_key_meta['subtype'] = self._get_key_subtype(child_key_meta) + child_key_meta['type'] = 'id' + child_key_meta['ref'] = { 'table': parent, 'field': primary_key } # Make sure that key subtypes are the same - foreign_meta = self.get_fields(child).get(foreign_key) - if foreign_meta: - foreign_subtype = self._get_key_subtype(foreign_meta) - if foreign_subtype != foreign_key_details['subtype']: - raise ValueError('Primary and Foreign key subtypes mismatch') + if child_key_meta['subtype'] != parent_key_meta['subtype']: + raise ValueError('Parent and Child key subtypes mismatch') - self._metadata['tables'][child]['fields'][foreign_key] = foreign_key_details + # Make a backup + metadata_backup = copy.deepcopy(self._metadata) + + self._metadata['tables'][parent]['fields'][primary_key] = parent_key_meta + self._metadata['tables'][child]['fields'][foreign_key] = child_key_meta # Re-analyze the relationships self._analyze_relationships() + try: + self.validate() + except MetadataError: + self._metadata = metadata_backup + raise + def _get_field_details(self, data, fields): """Get or build all the fields metadata. @@ -919,119 +971,7 @@ def to_json(self, path): with open(path, 'w') as out_file: json.dump(self._metadata, out_file, indent=4) - @staticmethod - def _get_graphviz_extension(path): - if path: - path_splitted = path.split('.') - if len(path_splitted) == 1: - raise ValueError('Path without graphviz extansion.') - - graphviz_extension = path_splitted[-1] - - if graphviz_extension not in graphviz.backend.FORMATS: - raise ValueError( - '"{}" not a valid graphviz extension format.'.format(graphviz_extension) - ) - - return '.'.join(path_splitted[:-1]), graphviz_extension - - return None, None - - def _visualize_add_nodes(self, plot): - """Add nodes into a `graphviz.Digraph`. - - Each node represent a metadata table. 
- - Args: - plot (graphviz.Digraph) - """ - for table in self.get_tables(): - # Append table fields - fields = [] - - for name, value in self.get_fields(table).items(): - if value.get('subtype') is not None: - fields.append('{} : {} - {}'.format(name, value['type'], value['subtype'])) - - else: - fields.append('{} : {}'.format(name, value['type'])) - - fields = r'\l'.join(fields) - - # Append table extra information - extras = [] - - primary_key = self.get_primary_key(table) - if primary_key is not None: - extras.append('Primary key: {}'.format(primary_key)) - - parents = self.get_parents(table) - for parent in parents: - foreign_key = self.get_foreign_key(parent, table) - extras.append('Foreign key ({}): {}'.format(parent, foreign_key)) - - path = self.get_table_meta(table).get('path') - if path is not None: - extras.append('Data path: {}'.format(path)) - - extras = r'\l'.join(extras) - - # Add table node - title = r'{%s|%s\l|%s\l}' % (table, fields, extras) - plot.node(table, label=title) - - def _visualize_add_edges(self, plot): - """Add edges into a `graphviz.Digraph`. - - Each edge represents a relationship between two metadata tables. - - Args: - plot (graphviz.Digraph) - """ - for table in self.get_tables(): - for parent in list(self.get_parents(table)): - plot.edge( - parent, - table, - label=' {}.{} -> {}.{}'.format( - table, self.get_foreign_key(parent, table), - parent, self.get_primary_key(parent) - ), - arrowhead='crow' - ) - - def visualize(self, path=None): - """Plot metadata usign graphviz. - - Try to generate a plot using graphviz. - If a ``path`` is provided save the output into a file. - - Args: - path (str): - Output file path to save the plot, it requires a graphviz - supported extension. If ``None`` do not save the plot. - Defaults to ``None``. - """ - filename, graphviz_extension = self._get_graphviz_extension(path) - plot = graphviz.Digraph( - 'Metadata', - format=graphviz_extension, - node_attr={ - "shape": "Mrecord", - "fillcolor": "lightgoldenrod1", - "style": "filled" - }, - ) - - self._visualize_add_nodes(plot) - self._visualize_add_edges(plot) - - if filename: - plot.render(filename=filename, cleanup=True, format=graphviz_extension) - else: - return plot - - def __str__(self): + def __repr__(self): tables = self.get_tables() relationships = [ ' {}.{} -> {}.{}'.format( @@ -1053,3 +993,18 @@ def __str__(self): tables, '\n'.join(relationships) ) + + def visualize(self, path=None): + """Plot metadata usign graphviz. + + Generate a plot using graphviz. + If a ``path`` is provided save the output into a file. + + Args: + path (str): + Output file path to save the plot. It requires a graphviz + supported extension. If ``None`` do not save the plot and + just return the ``graphviz.Digraph`` object. + Defaults to ``None``. + """ + return visualization.visualize(self, path) diff --git a/sdv/metadata/errors.py b/sdv/metadata/errors.py new file mode 100644 index 000000000..8e1bada44 --- /dev/null +++ b/sdv/metadata/errors.py @@ -0,0 +1,2 @@ +class MetadataError(Exception): + pass diff --git a/sdv/metadata/table.py b/sdv/metadata/table.py new file mode 100644 index 000000000..598a4a725 --- /dev/null +++ b/sdv/metadata/table.py @@ -0,0 +1,399 @@ +import copy +import json +import logging + +import numpy as np +import pandas as pd +import rdt +from faker import Faker + +from sdv.metadata.errors import MetadataError + +LOGGER = logging.getLogger(__name__) + + +class Table: + """Table Metadata. 
+
+    The Table class provides a unified layer of abstraction over the metadata
+    of a single table, which includes both the necessary details to load the data
+    from the filesystem and to know how to parse and transform it to numerical data.
+    """
+
+    _metadata = None
+    _hyper_transformer = None
+    _anonymization_mappings = None
+
+    _FIELD_TEMPLATES = {
+        'i': {
+            'type': 'numerical',
+            'subtype': 'integer',
+        },
+        'f': {
+            'type': 'numerical',
+            'subtype': 'float',
+        },
+        'O': {
+            'type': 'categorical',
+        },
+        'b': {
+            'type': 'boolean',
+        },
+        'M': {
+            'type': 'datetime',
+        }
+    }
+    _DTYPES = {
+        ('categorical', None): 'object',
+        ('boolean', None): 'bool',
+        ('numerical', None): 'float',
+        ('numerical', 'float'): 'float',
+        ('numerical', 'integer'): 'int',
+        ('datetime', None): 'datetime64',
+        ('id', None): 'int',
+        ('id', 'integer'): 'int',
+        ('id', 'string'): 'str'
+    }
+
+    def _get_faker(self, category):
+        """Return the faker object to anonymize data.
+
+        Args:
+            category (str or tuple):
+                Fake category to use. If a tuple is passed, the first element is
+                the category and the rest are additional arguments for the Faker.
+
+        Returns:
+            function:
+                Faker function to generate new fake data instances.
+
+        Raises:
+            ValueError:
+                A ``ValueError`` is raised if the faker category we want doesn't exist.
+        """
+        if isinstance(category, (tuple, list)):
+            category, *args = category
+        else:
+            args = tuple()
+
+        try:
+            faker_method = getattr(Faker(), category)
+
+            if not args:
+                return faker_method
+
+            def faker():
+                return faker_method(*args)
+
+            return faker
+
+        except AttributeError:
+            raise ValueError('Category "{}" couldn\'t be found on faker'.format(category))
+
+    def __init__(self, metadata=None, field_names=None, primary_key=None,
+                 field_types=None, anonymize_fields=None, constraints=None,
+                 transformer_templates=None):
+        # TODO: Validate that the given metadata is a valid dict
+        self._metadata = metadata or dict()
+        self._hyper_transformer = None
+
+        self._field_names = field_names
+        self._primary_key = primary_key
+        self._field_types = field_types or {}
+        self._anonymize_fields = anonymize_fields
+        self._constraints = constraints
+
+        self._transformer_templates = transformer_templates or {}
+
+        self._fakers = {
+            name: self._get_faker(category)
+            for name, category in (anonymize_fields or {}).items()
+        }
+
+    def _get_field_dtype(self, field_name, field_metadata):
+        field_type = field_metadata['type']
+        field_subtype = field_metadata.get('subtype')
+        dtype = self._DTYPES.get((field_type, field_subtype))
+        if not dtype:
+            raise MetadataError(
+                'Invalid type and subtype combination for field {}: ({}, {})'.format(
+                    field_name, field_type, field_subtype)
+            )
+
+        return dtype
+
+    def get_fields(self):
+        """Get fields metadata.
+
+        Returns:
+            dict:
+                Dictionary of fields metadata for this table.
+        """
+        return copy.deepcopy(self._metadata['fields'])
+
+    def get_dtypes(self, ids=False):
+        """Get a ``dict`` with the ``dtypes`` for each field of the table.
+
+        Args:
+            ids (bool):
+                Whether or not to include the id fields. Defaults to ``False``.
+
+        Returns:
+            dict:
+                Dictionary that contains the field names and data types.
+
+        Raises:
+            ValueError:
+                If a field has an invalid type or subtype.
+        """
+        dtypes = dict()
+        for name, field_meta in self._metadata['fields'].items():
+            field_type = field_meta['type']
+
+            if ids or (field_type != 'id'):
+                dtypes[name] = self._get_field_dtype(name, field_meta)
+
+        return dtypes
+
+    def _build_fields_metadata(self, data):
+        """Build all the fields metadata.
+
+        Args:
+            data (pandas.DataFrame):
+                Data to be analyzed.
+
+        Returns:
+            dict:
+                Dict of valid fields.
+
+        Raises:
+            ValueError:
+                If a column from the data analyzed is an unsupported data type.
+        """
+        field_names = self._field_names or data.columns
+
+        fields_metadata = dict()
+        for field_name in field_names:
+            if field_name not in data:
+                raise ValueError('Field {} not found in given data'.format(field_name))
+
+            field_meta = self._field_types.get(field_name)
+            if field_meta:
+                # Validate the given meta
+                self._get_field_dtype(field_name, field_meta)
+            else:
+                dtype = data[field_name].dtype
+                field_template = self._FIELD_TEMPLATES.get(dtype.kind)
+                if field_template is None:
+                    raise ValueError('Unsupported dtype {} in column {}'.format(dtype, field_name))
+
+                field_meta = copy.deepcopy(field_template)
+
+            fields_metadata[field_name] = field_meta
+
+        return fields_metadata
+
+    def _get_transformers(self, dtypes):
+        """Create the transformer instances needed to process the given dtypes.
+
+        Args:
+            dtypes (dict):
+                Mapping of field names to dtypes.
+
+        Returns:
+            dict:
+                Mapping of field names to transformer instances.
+        """
+        transformer_templates = {
+            'i': rdt.transformers.NumericalTransformer(dtype=int),
+            'f': rdt.transformers.NumericalTransformer(dtype=float),
+            'O': rdt.transformers.CategoricalTransformer,
+            'b': rdt.transformers.BooleanTransformer,
+            'M': rdt.transformers.DatetimeTransformer,
+        }
+        transformer_templates.update(self._transformer_templates)
+
+        transformers = dict()
+        for name, dtype in dtypes.items():
+            transformer_template = transformer_templates[np.dtype(dtype).kind]
+            if isinstance(transformer_template, type):
+                transformer = transformer_template()
+            else:
+                transformer = copy.deepcopy(transformer_template)
+
+            LOGGER.info('Loading transformer %s for field %s',
+                        transformer.__class__.__name__, name)
+            transformers[name] = transformer
+
+        return transformers
+
+    def _fit_hyper_transformer(self, data):
+        """Create and fit a new ``rdt.HyperTransformer`` instance.
+
+        First get the ``dtypes`` and then use them to build a transformer
+        dictionary to be used by the ``HyperTransformer``. The fitted
+        instance is stored in ``self._hyper_transformer``.
+        """
+        dtypes = self.get_dtypes(ids=False)
+        transformers_dict = self._get_transformers(dtypes)
+        self._hyper_transformer = rdt.HyperTransformer(transformers=transformers_dict)
+        self._hyper_transformer.fit(data[list(dtypes.keys())])
+
+    @staticmethod
+    def _get_key_subtype(field_meta):
+        """Get the appropriate key subtype."""
+        field_type = field_meta['type']
+
+        if field_type == 'categorical':
+            field_subtype = 'string'
+
+        elif field_type in ('numerical', 'id'):
+            field_subtype = field_meta['subtype']
+            if field_subtype not in ('integer', 'string'):
+                raise ValueError(
+                    'Invalid field "subtype" for key field: "{}"'.format(field_subtype)
+                )
+
+        else:
+            raise ValueError(
+                'Invalid field "type" for key field: "{}"'.format(field_type)
+            )
+
+        return field_subtype
+
+    def set_primary_key(self, field_name):
+        """Set the primary key of this table.
+
+        The field must exist and be either an integer or a categorical field.
+
+        Args:
+            field_name (str):
+                Name of the field to be used as the new primary key.
+
+        Raises:
+            ValueError:
+                If the field does not exist in this table or if it has an
+                invalid type or subtype.
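+
+        Example:
+            A minimal usage sketch (the field name is illustrative; the table
+            is assumed to have been fitted already, so that its fields
+            metadata exists):
+
+            >>> table.set_primary_key('user_id')
+
+            Afterwards, ``user_id`` is retyped as an ``id`` field with the
+            subtype derived from its previous type.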
+ """ + if field_name is not None: + if field_name not in self._metadata['fields']: + raise ValueError('Field "{}" does not exist in this table'.format(field_name)) + + field_metadata = self._metadata['fields'][field_name] + field_subtype = self._get_key_subtype(field_metadata) + + field_metadata.update({ + 'type': 'id', + 'subtype': field_subtype + }) + + self._primary_key = field_name + + def _make_anonymization_mappings(self, data): + mappings = {} + for name, faker in self._fakers.items(): + uniques = data[name].unique() + fake_values = [faker() for x in range(len(uniques))] + mappings[name] = dict(zip(uniques, fake_values)) + + self._anonymization_mappings = mappings + + def _anonymize(self, data): + if self._anonymization_mappings: + data = data.copy() + for name, mapping in self._anonymization_mappings.items(): + data[name] = data[name].map(mapping) + + return data + + def fit(self, data): + """Fit this metadata to the given data. + + Args: + data (pandas.DataFrame): + Table to be analyzed. + """ + self._field_names = self._field_names or list(data.columns) + self._metadata['fields'] = self._build_fields_metadata(data) + self.set_primary_key(self._primary_key) + + if self._fakers: + self._make_anonymization_mappings(data) + data = self._anonymize(data) + + # TODO: Treat/Learn constraints + self._fit_hyper_transformer(data) + + def transform(self, data): + """Transform the given data. + + Args: + data (pandas.DataFrame): + Table data. + + Returns: + pandas.DataFrame: + Transformed data. + """ + data = self._anonymize(data[self._field_names]) + return self._hyper_transformer.transform(data) + + def reverse_transform(self, data): + """Reverse the transformed data to the original format. + + Args: + data (pandas.DataFrame): + Data to be reverse transformed. + + Returns: + pandas.DataFrame + """ + reversed_data = self._hyper_transformer.reverse_transform(data) + + fields = self._metadata['fields'] + for name, dtype in self.get_dtypes(ids=True).items(): + field_type = fields[name]['type'] + if field_type == 'id': + field_data = pd.Series(np.arange(len(reversed_data))) + else: + field_data = reversed_data[name] + + reversed_data[name] = field_data.dropna().astype(dtype) + + return reversed_data[self._field_names] + + # ###################### # + # Metadata Serialization # + # ###################### # + + def to_dict(self): + """Get a dict representation of this metadata. + + Returns: + dict: + dict representation of this metadata. + """ + return copy.deepcopy(self._metadata) + + def to_json(self, path): + """Dump this metadata into a JSON file. + + Args: + path (str): + Path of the JSON file where this metadata will be stored. 
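+
+        Example:
+            A minimal round-trip sketch (the file path is illustrative):
+
+            >>> table.to_json('table_metadata.json')
+            >>> restored = Table.from_json('table_metadata.json')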
+ """ + with open(path, 'w') as out_file: + json.dump(self._metadata, out_file, indent=4) + + @classmethod + def from_json(cls, path): + """Load a Table from a JSON + + Args: + path (str): + Path of the JSON file to load + """ + with open(path, 'r') as in_file: + return cls(json.load(in_file)) diff --git a/sdv/metadata/visualization.py b/sdv/metadata/visualization.py new file mode 100644 index 000000000..b14fb160a --- /dev/null +++ b/sdv/metadata/visualization.py @@ -0,0 +1,121 @@ +import graphviz + + +def _get_graphviz_extension(path): + if path: + path_splitted = path.split('.') + if len(path_splitted) == 1: + raise ValueError('Path without graphviz extansion.') + + graphviz_extension = path_splitted[-1] + + if graphviz_extension not in graphviz.backend.FORMATS: + raise ValueError( + '"{}" not a valid graphviz extension format.'.format(graphviz_extension) + ) + + return '.'.join(path_splitted[:-1]), graphviz_extension + + return None, None + + +def _add_nodes(metadata, digraph): + """Add nodes into a `graphviz.Digraph`. + + Each node represent a metadata table. + + Args: + metadata (Metadata): + Metadata object to plot. + digraph (graphviz.Digraph): + graphviz.Digraph being built + """ + for table in metadata.get_tables(): + # Append table fields + fields = [] + + for name, value in metadata.get_fields(table).items(): + if value.get('subtype') is not None: + fields.append('{} : {} - {}'.format(name, value['type'], value['subtype'])) + + else: + fields.append('{} : {}'.format(name, value['type'])) + + fields = r'\l'.join(fields) + + # Append table extra information + extras = [] + + primary_key = metadata.get_primary_key(table) + if primary_key is not None: + extras.append('Primary key: {}'.format(primary_key)) + + parents = metadata.get_parents(table) + for parent in parents: + foreign_key = metadata.get_foreign_key(parent, table) + extras.append('Foreign key ({}): {}'.format(parent, foreign_key)) + + path = metadata.get_table_meta(table).get('path') + if path is not None: + extras.append('Data path: {}'.format(path)) + + extras = r'\l'.join(extras) + + # Add table node + title = r'{%s|%s\l|%s\l}' % (table, fields, extras) + digraph.node(table, label=title) + + +def _add_edges(metadata, digraph): + """Add edges into a `graphviz.Digraph`. + + Each edge represents a relationship between two metadata tables. + + Args: + digraph (graphviz.Digraph) + """ + for table in metadata.get_tables(): + for parent in list(metadata.get_parents(table)): + digraph.edge( + parent, + table, + label=' {}.{} -> {}.{}'.format( + table, metadata.get_foreign_key(parent, table), + parent, metadata.get_primary_key(parent) + ), + arrowhead='oinv' + ) + + +def visualize(metadata, path=None): + """Plot metadata usign graphviz. + + Try to generate a plot using graphviz. + If a ``path`` is provided save the output into a file. + + Args: + metadata (Metadata): + Metadata object to plot. + path (str): + Output file path to save the plot, it requires a graphviz + supported extension. If ``None`` do not save the plot. + Defaults to ``None``. 
+ """ + filename, graphviz_extension = _get_graphviz_extension(path) + digraph = graphviz.Digraph( + 'Metadata', + format=graphviz_extension, + node_attr={ + "shape": "Mrecord", + "fillcolor": "lightgoldenrod1", + "style": "filled" + }, + ) + + _add_nodes(metadata, digraph) + _add_edges(metadata, digraph) + + if filename: + digraph.render(filename=filename, cleanup=True, format=graphviz_extension) + else: + return digraph diff --git a/sdv/modeler.py b/sdv/modeler.py index aecfbd91a..7e766deb6 100644 --- a/sdv/modeler.py +++ b/sdv/modeler.py @@ -27,6 +27,7 @@ def __init__(self, metadata, model=GaussianCopula, model_kwargs=None): self.metadata = metadata self.model = model self.model_kwargs = dict() if model_kwargs is None else model_kwargs + self.table_sizes = dict() def _get_extension(self, child_name, child_table, foreign_key): """Generate list of extension for child tables. @@ -78,7 +79,7 @@ def cpa(self, table_name, tables, foreign_key=None): table_name (str): Name of the table to model. tables (dict): - Dict of tables tha have been already modeled. + Dict of original tables. foreign_key (str): Name of the foreign key that references this table. Used only when applying CPA on a child table. @@ -87,6 +88,7 @@ def cpa(self, table_name, tables, foreign_key=None): pandas.DataFrame: table data with the extensions created while modeling its children. """ + LOGGER.info('Modeling %s', table_name) if tables: @@ -94,6 +96,8 @@ def cpa(self, table_name, tables, foreign_key=None): else: table = self.metadata.load_table(table_name) + self.table_sizes[table_name] = len(table) + extended = self.metadata.transform(table_name, table) primary_key = self.metadata.get_primary_key(table_name) diff --git a/sdv/models/copulas.py b/sdv/models/copulas.py index 61c72bd9f..fcd8a3c8b 100644 --- a/sdv/models/copulas.py +++ b/sdv/models/copulas.py @@ -1,5 +1,7 @@ import numpy as np -from copulas import multivariate, univariate +from copulas import EPSILON +from copulas.multivariate import GaussianMultivariate +from copulas.univariate import GaussianUnivariate from sdv.models.base import SDVModel from sdv.models.utils import ( @@ -29,7 +31,7 @@ class GaussianCopula(SDVModel): 4 1.925887 """ - DISTRIBUTION = univariate.GaussianUnivariate + DISTRIBUTION = GaussianUnivariate distribution = None model = None @@ -46,7 +48,7 @@ def fit(self, table_data): Data to be fitted. """ table_data = impute(table_data) - self.model = multivariate.GaussianMultivariate(distribution=self.distribution) + self.model = GaussianMultivariate(distribution=self.distribution) self.model.fit(table_data) def sample(self, num_samples): @@ -79,11 +81,20 @@ def get_parameters(self): values.append(row[:index + 1]) self.model.covariance = np.array(values) - for distribution in self.model.distribs.values(): - if distribution.std is not None: - distribution.std = np.log(distribution.std) + params = self.model.to_dict() + univariates = dict() + for name, univariate in zip(params.pop('columns'), params['univariates']): + univariates[name] = univariate + if 'scale' in univariate: + scale = univariate['scale'] + if scale == 0: + scale = EPSILON - return flatten_dict(self.model.to_dict()) + univariate['scale'] = np.log(scale) + + params['univariates'] = univariates + + return flatten_dict(params) def _prepare_sampled_covariance(self, covariance): """Prepare a covariance matrix. @@ -126,17 +137,22 @@ def _unflatten_gaussian_copula(self, model_parameters): Model parameters ready to recreate the model. 
""" - distribution_kwargs = { - 'fitted': True, + univariate_kwargs = { 'type': model_parameters['distribution'] } - distribs = model_parameters['distribs'] - for distribution in distribs.values(): - distribution.update(distribution_kwargs) - distribution['std'] = np.exp(distribution['std']) + columns = list() + univariates = list() + for column, univariate in model_parameters['univariates'].items(): + columns.append(column) + univariate.update(univariate_kwargs) + univariate['scale'] = np.exp(univariate['scale']) + univariates.append(univariate) + + model_parameters['univariates'] = univariates + model_parameters['columns'] = columns - covariance = model_parameters['covariance'] + covariance = model_parameters.get('covariance') model_parameters['covariance'] = self._prepare_sampled_covariance(covariance) return model_parameters @@ -156,8 +172,5 @@ def set_parameters(self, parameters): parameters.setdefault('distribution', self.distribution) parameters = self._unflatten_gaussian_copula(parameters) - for param in parameters['distribs'].values(): - param.setdefault('type', self.distribution) - param.setdefault('fitted', True) - self.model = multivariate.GaussianMultivariate.from_dict(parameters) + self.model = GaussianMultivariate.from_dict(parameters) diff --git a/sdv/models/utils.py b/sdv/models/utils.py index a839ecfb5..241a20048 100644 --- a/sdv/models/utils.py +++ b/sdv/models/utils.py @@ -20,11 +20,15 @@ def flatten_array(nested, prefix=''): for index in range(len(nested)): prefix_key = '__'.join([prefix, str(index)]) if len(prefix) else str(index) - if isinstance(nested[index], (list, np.ndarray)): - result.update(flatten_array(nested[index], prefix=prefix_key)) + value = nested[index] + if isinstance(value, (list, np.ndarray)): + result.update(flatten_array(value, prefix=prefix_key)) + + elif isinstance(value, dict): + result.update(flatten_dict(value, prefix=prefix_key)) else: - result[prefix_key] = nested[index] + result[prefix_key] = value return result diff --git a/sdv/sampler.py b/sdv/sampler.py index b67f3e87a..dc6be988f 100644 --- a/sdv/sampler.py +++ b/sdv/sampler.py @@ -23,37 +23,52 @@ class Sampler: primary_key = None remaining_primary_key = None - def __init__(self, metadata, models, model, model_kwargs): + def __init__(self, metadata, models, model, model_kwargs, table_sizes): self.metadata = metadata self.models = models self.primary_key = dict() self.remaining_primary_key = dict() self.model = model self.model_kwargs = model_kwargs + self.table_sizes = table_sizes def _reset_primary_keys_generators(self): """Reset the primary key generators.""" self.primary_key = dict() self.remaining_primary_key = dict() - def _transform_synthesized_rows(self, synthesized, table_name): - """Reverse transform synthetized data. + def _finalize(self, sampled_data): + """Do the final touches to the generated data. + + This method reverts the previous transformations to go back + to values in the original space and also adds the parent + keys in case foreign key relationships exist between the tables. Args: - synthesized (pandas.DataFrame): - Generated data from model - table_name (str): - Name of the table. + sampled_data (dict): + Generated data Return: pandas.DataFrame: Formatted synthesized data. 
""" - reversed_data = self.metadata.reverse_transform(table_name, synthesized) + final_data = dict() + for table_name, table_rows in sampled_data.items(): + parents = self.metadata.get_parents(table_name) + if parents: + for parent_name in parents: + foreign_key = self.metadata.get_foreign_key(parent_name, table_name) + if foreign_key not in table_rows: + parent_ids = self._find_parent_ids(table_name, parent_name, sampled_data) + table_rows[foreign_key] = parent_ids + + reversed_data = self.metadata.reverse_transform(table_name, table_rows) + + fields = self.metadata.get_fields(table_name) - fields = self.metadata.get_fields(table_name) + final_data[table_name] = reversed_data[list(fields.keys())] - return reversed_data[list(fields.keys())] + return final_data def _get_primary_keys(self, table_name, num_rows): """Return the primary key and amount of values for the requested table. @@ -118,7 +133,7 @@ def _get_primary_keys(self, table_name, num_rows): return primary_key, primary_key_values - def _get_extension(self, parent_row, table_name): + def _extract_parameters(self, parent_row, table_name): """Get the params from a generated parent row. Args: @@ -127,7 +142,6 @@ def _get_extension(self, parent_row, table_name): table_name (str): Name of the table to make the model for. """ - prefix = '__{}__'.format(table_name) keys = [key for key in parent_row.keys() if key.startswith(prefix)] new_keys = {key: key[len(prefix):] for key in keys} @@ -157,34 +171,85 @@ def _sample_rows(self, model, num_rows, table_name): return sampled - def _sample_children(self, table_name, sampled): - table_rows = sampled[table_name] + def _sample_children(self, table_name, sampled_data): + table_rows = sampled_data[table_name] for child_name in self.metadata.get_children(table_name): for _, row in table_rows.iterrows(): - self._sample_table(child_name, table_name, row, sampled) + self._sample_child_rows(child_name, table_name, row, sampled_data) - def _sample_table(self, table_name, parent_name, parent_row, sampled): - extension = self._get_extension(parent_row, table_name) + def _sample_child_rows(self, table_name, parent_name, parent_row, sampled_data): + parameters = self._extract_parameters(parent_row, table_name) model = self.model(**self.model_kwargs) - model.set_parameters(extension) - num_rows = max(round(extension['child_rows']), 0) + model.set_parameters(parameters) + num_rows = max(round(parameters['child_rows']), 0) - sampled_rows = self._sample_rows(model, num_rows, table_name) + table_rows = self._sample_rows(model, num_rows, table_name) parent_key = self.metadata.get_primary_key(parent_name) foreign_key = self.metadata.get_foreign_key(parent_name, table_name) - sampled_rows[foreign_key] = parent_row[parent_key] + table_rows[foreign_key] = parent_row[parent_key] - previous = sampled.get(table_name) + previous = sampled_data.get(table_name) if previous is None: - sampled[table_name] = sampled_rows + sampled_data[table_name] = table_rows + else: + sampled_data[table_name] = pd.concat([previous, table_rows]).reset_index(drop=True) + + self._sample_children(table_name, sampled_data) + + @staticmethod + def _find_parent_id(likelihoods, num_rows): + mean = likelihoods.mean() + if (likelihoods == 0).all(): + # All rows got 0 likelihood, fallback to num_rows + likelihoods = num_rows + elif pd.isnull(mean) or mean == 0: + # Some rows got singlar matrix error and the rest were 0 + # Fallback to num_rows on the singular matrix rows and + # keep 0s on the rest. 
+ likelihoods = likelihoods.fillna(num_rows) + else: + # at least one row got a valid likelihood, so fill the + # rows that got a singular matrix error with the mean + likelihoods = likelihoods.fillna(mean) + + weights = likelihoods.values / likelihoods.sum() + + return np.random.choice(likelihoods.index, p=weights) + + def _get_likelihoods(self, table_rows, parent_rows, table_name): + likelihoods = dict() + for parent_id, row in parent_rows.iterrows(): + parameters = self._extract_parameters(row, table_name) + model = self.model(**self.model_kwargs) + model.set_parameters(parameters) + try: + likelihoods[parent_id] = model.model.probability_density(table_rows) + except np.linalg.LinAlgError: + likelihoods[parent_id] = None + + return pd.DataFrame(likelihoods, index=table_rows.index) + + def _find_parent_ids(self, table_name, parent_name, sampled_data): + table_rows = sampled_data[table_name] + if parent_name in sampled_data: + parent_rows = sampled_data[parent_name] else: - sampled[table_name] = pd.concat([previous, sampled_rows]).reset_index(drop=True) + ratio = self.table_sizes[parent_name] / self.table_sizes[table_name] + num_parent_rows = max(int(round(len(table_rows) * ratio)), 1) + parent_model = self.models[parent_name] + parent_rows = self._sample_rows(parent_model, num_parent_rows, parent_name) - self._sample_children(table_name, sampled) + primary_key = self.metadata.get_primary_key(parent_name) + parent_rows = parent_rows.set_index(primary_key) + num_rows = parent_rows['__' + table_name + '__child_rows'].clip(0) - def sample(self, table_name, num_rows, reset_primary_keys=False, sample_children=True): + likelihoods = self._get_likelihoods(table_rows, parent_rows, table_name) + return likelihoods.apply(self._find_parent_id, axis=1, num_rows=num_rows) + + def sample(self, table_name, num_rows=None, reset_primary_keys=False, + sample_children=True, sample_parents=True): """Sample one table. Child tables will be sampled when ``sample_children`` is ``True``. @@ -195,11 +260,14 @@ def sample(self, table_name, num_rows, reset_primary_keys=False, sample_children table_name (str): Table name to sample. num_rows (int): - Amount of rows to sample. + Amount of rows to sample. If ``None``, sample the same number of rows + as there were in the original table. reset_primary_keys (bool): Whether or not reset the primary keys generators. Defaults to ``False``. sample_children (bool): Whether or not sample child tables. Defaults to ``True``. + sample_parents (bool): + Whether or not sample parent tables. Defaults to ``True``. Returns: dict or pandas.DataFrame: @@ -207,37 +275,27 @@ def sample(self, table_name, num_rows, reset_primary_keys=False, sample_children and child tables. - Returns a ``pandas.DataFrame`` when ``sample_children`` is ``False``. 
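+
+        Example:
+            A usage sketch, using table names from the SDV demo dataset
+            (``sampler`` is assumed to be an already fitted ``Sampler``)::
+
+                # Sample a table together with all of its children
+                sampled = sampler.sample('users')
+
+                # Sample only the table itself, keeping the original size
+                users = sampler.sample('users', sample_children=False)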
""" - if reset_primary_keys: self._reset_primary_keys_generators() - model = self.models[table_name] - - sampled_rows = self._sample_rows(model, num_rows, table_name) + if num_rows is None: + num_rows = self.table_sizes[table_name] - parents = self.metadata.get_parents(table_name) - if parents: - parent_name = list(parents)[0] - foreign_key = self.metadata.get_foreign_key(parent_name, table_name) - parent_id = self._get_primary_keys(parent_name, 1)[1][0] - sampled_rows[foreign_key] = parent_id + model = self.models[table_name] + table_rows = self._sample_rows(model, num_rows, table_name) if sample_children: sampled_data = { - table_name: sampled_rows + table_name: table_rows } self._sample_children(table_name, sampled_data) - - for table, sampled_rows in sampled_data.items(): - sampled_data[table] = self._transform_synthesized_rows(sampled_rows, table) - - return sampled_data + return self._finalize(sampled_data) else: - return self._transform_synthesized_rows(sampled_rows, table_name) + return self._finalize({table_name: table_rows})[table_name] - def sample_all(self, num_rows=5, reset_primary_keys=False): + def sample_all(self, num_rows=None, reset_primary_keys=False): """Samples the entire dataset. ``sample_all`` returns a dictionary with all the tables of the dataset sampled. @@ -249,9 +307,10 @@ def sample_all(self, num_rows=5, reset_primary_keys=False): Args: num_rows (int): - Number of rows to be sampled on the parent tables. + Number of rows to be sampled on the first parent tables. If ``None``, + sample the same number of rows as in the original tables. reset_primary_keys (bool): - Wheter or not reset the primary key generators. + Whether or not reset the primary key generators. Returns: dict: diff --git a/sdv/sdv.py b/sdv/sdv.py index f309801a0..4f05282b7 100644 --- a/sdv/sdv.py +++ b/sdv/sdv.py @@ -67,16 +67,18 @@ def fit(self, metadata, tables=None, root_path=None): self.modeler = Modeler(self.metadata, self.model, self.model_kwargs) self.modeler.model_database(tables) - self.sampler = Sampler(self.metadata, self.modeler.models, self.model, self.model_kwargs) + self.sampler = Sampler(self.metadata, self.modeler.models, self.model, + self.model_kwargs, self.modeler.table_sizes) - def sample(self, table_name, num_rows, sample_children=True, reset_primary_keys=False): + def sample(self, table_name, num_rows=None, sample_children=True, reset_primary_keys=False): """Sample ``num_rows`` rows from the indicated table. Args: table_name (str): Name of the table to sample from. num_rows (int): - Amount of rows to sample. + Amount of rows to sample. If ``None``, sample the same number of rows + as there were in the original table. sample_children (bool): Whether or not to sample children tables. Defaults to ``True``. reset_primary_keys (bool): @@ -100,12 +102,13 @@ def sample(self, table_name, num_rows, sample_children=True, reset_primary_keys= reset_primary_keys=reset_primary_keys ) - def sample_all(self, num_rows=5, reset_primary_keys=False): + def sample_all(self, num_rows=None, reset_primary_keys=False): """Sample the entire dataset. Args: num_rows (int): - Amount of rows to sample. Defaults to ``5``. + Number of rows to be sampled on the first parent tables. If ``None``, + sample the same number of rows as in the original tables. reset_primary_keys (bool): Wheter or not reset the primary key generators. Defaults to ``False``. 
diff --git a/sdv/tabular/__init__.py b/sdv/tabular/__init__.py
new file mode 100644
index 000000000..dae118d1d
--- /dev/null
+++ b/sdv/tabular/__init__.py
@@ -0,0 +1,7 @@
+from sdv.tabular.copulas import GaussianCopula
+from sdv.tabular.ctgan import CTGAN
+
+__all__ = [
+    'GaussianCopula',
+    'CTGAN'
+]
diff --git a/sdv/tabular/base.py b/sdv/tabular/base.py
new file mode 100644
index 000000000..f21be7123
--- /dev/null
+++ b/sdv/tabular/base.py
@@ -0,0 +1,213 @@
+import pickle
+
+from sdv.metadata import Table
+
+
+class BaseTabularModel:
+    """Base class for all the tabular models.
+
+    The ``BaseTabularModel`` class defines the common API that all the
+    TabularModels need to implement, as well as common functionality.
+    """
+
+    TRANSFORMER_TEMPLATES = None
+
+    _metadata = None
+
+    def __init__(self, field_names=None, primary_key=None, field_types=None,
+                 anonymize_fields=None, table_metadata=None, *args, **kwargs):
+        """Initialize a Tabular Model.
+
+        Args:
+            field_names (list[str]):
+                List of names of the fields that need to be modeled
+                and included in the generated output data. Any additional
+                fields found in the data will be ignored and will not be
+                included in the generated output, except if they have
+                been added as primary keys or fields to anonymize.
+                If ``None``, all the fields found in the data are used.
+            primary_key (str, list[str] or dict[str, dict]):
+                Specification about which field or fields are the
+                primary key of the table and information about how
+                to generate them.
+            field_types (dict[str, dict]):
+                Dictionary specifying the data types and subtypes
+                of the fields that will be modeled. Field types and subtypes
+                combinations must be compatible with the SDV Metadata Schema.
+            anonymize_fields (dict[str, str]):
+                Dict specifying which fields to anonymize and what faker
+                category they belong to.
+            table_metadata (dict or metadata.Table):
+                Table metadata instance or dict representation.
+                If given alongside any other metadata-related arguments, an
+                exception will be raised.
+                If not given at all, it will be built using the other
+                arguments or learned from the data.
+            *args, **kwargs:
+                Subclasses will add any arguments or keyword arguments needed.
+        """
+        if table_metadata is not None:
+            if isinstance(table_metadata, dict):
+                table_metadata = Table(table_metadata)
+
+            # Plain values do not carry their argument names, so pair them explicitly.
+            other_args = {
+                'field_names': field_names,
+                'primary_key': primary_key,
+                'field_types': field_types,
+                'anonymize_fields': anonymize_fields,
+            }
+            for name, arg in other_args.items():
+                if arg:
+                    raise ValueError(
+                        'If table_metadata is given, {} must be None'.format(name))
+
+            self._metadata = table_metadata
+
+        else:
+            self._field_names = field_names
+            self._primary_key = primary_key
+            self._field_types = field_types
+            self._anonymize_fields = anonymize_fields
+
+    def _fit_metadata(self, data):
+        """Generate a new Table metadata and fit it to the data.
+
+        The information provided will be used to create the Table instance
+        and then the rest of the information will be learned from the given
+        data.
+
+        Args:
+            data (pandas.DataFrame):
+                Data to learn from.
+        """
+        metadata = Table(
+            field_names=self._field_names,
+            primary_key=self._primary_key,
+            field_types=self._field_types,
+            anonymize_fields=self._anonymize_fields,
+            transformer_templates=self.TRANSFORMER_TEMPLATES,
+        )
+        metadata.fit(data)
+
+        self._metadata = metadata
+
+    def fit(self, data):
+        """Fit this model to the data.
+
+        If the table metadata has not been given, learn it from the data.
+
+        Args:
+            data (pandas.DataFrame or str):
+                Data to fit the model to. It can be passed as a
+                ``pandas.DataFrame`` or as a ``str``.
+                If a ``str`` is passed, it is assumed to be
+                the path to a CSV file which can be loaded using
+                ``pandas.read_csv``.
+        """
+        if self._metadata is None:
+            self._fit_metadata(data)
+
+        self._size = len(data)
+
+        transformed = self._metadata.transform(data)
+        self._fit(transformed)
+
+    def get_metadata(self):
+        """Get metadata about the table.
+
+        This will return an ``sdv.metadata.Table`` object containing
+        the information about the data that this model has learned.
+
+        This Table metadata will contain some common information,
+        such as field names and data types, as well as additional
+        information that each Sub-class might add, such as the
+        observed data field distributions and their parameters.
+
+        Returns:
+            sdv.metadata.Table:
+                Table metadata.
+        """
+        return self._metadata
+
+    def sample(self, size=None, values=None):
+        """Sample rows from this table.
+
+        Args:
+            size (int):
+                Number of rows to sample. If not given the model
+                will generate as many rows as there were in the
+                data passed to the ``fit`` method.
+            values (dict):
+                (FUTURE) Fixed values to use for knowledge-based sampling.
+                In case the model does not support knowledge-based
+                sampling, a discard+resample strategy will be used.
+
+        Returns:
+            pandas.DataFrame:
+                Sampled data.
+        """
+        size = size or self._size
+        sampled = self._sample(size)
+        return self._metadata.reverse_transform(sampled)
+
+    def get_parameters(self):
+        """Get the parameters learned from the data.
+
+        The result is a flat dict (single level) which contains
+        all the necessary parameters to be able to reproduce
+        this model.
+
+        Subclasses which are not parametric, such as deep learning
+        based models, raise a NonParametricError indicating that
+        this method is not supported for their implementation.
+
+        Returns:
+            parameters (dict):
+                flat dict (single level) which contains all the
+                necessary parameters to be able to reproduce
+                this model.
+
+        Raises:
+            NonParametricError:
+                If the model is not parametric or cannot be described
+                using a simple dictionary.
+        """
+        raise NotImplementedError()
+
+    @classmethod
+    def from_parameters(cls):
+        """Regenerate a previously learned model from its parameters.
+
+        Subclasses which are not parametric, such as deep learning
+        based models, raise a NonParametricError indicating that
+        this method is not supported for their implementation.
+
+        Returns:
+            BaseTabularModel:
+                New instance with its parameters set.
+
+        Raises:
+            NonParametricError:
+                If the model is not parametric or cannot be described
+                using a simple dictionary.
+        """
+        raise NotImplementedError()
+
+    def save(self, path):
+        """Save this model instance to the given path using pickle.
+
+        Args:
+            path (str):
+                Path where the SDV instance will be serialized.
+        """
+        with open(path, 'wb') as output:
+            pickle.dump(self, output)
+
+    @classmethod
+    def load(cls, path):
+        """Load a TabularModel instance from a given path.
+
+        Args:
+            path (str):
+                Path from which to load the instance.
+
+        Returns:
+            TabularModel:
+                The loaded tabular model.
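+
+        Example:
+            A minimal save/load round trip sketch (the path and model
+            class are illustrative)::
+
+                model.save('model.pkl')
+                loaded = GaussianCopula.load('model.pkl')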
+ """ + with open(path, 'rb') as f: + return pickle.load(f) diff --git a/sdv/tabular/copulas.py b/sdv/tabular/copulas.py new file mode 100644 index 000000000..7ae00120f --- /dev/null +++ b/sdv/tabular/copulas.py @@ -0,0 +1,246 @@ +import copulas +import numpy as np +import rdt + +from sdv.tabular.base import BaseTabularModel +from sdv.tabular.utils import ( + check_matrix_symmetric_positive_definite, flatten_dict, make_positive_definite, square_matrix, + unflatten_dict) + + +class GaussianCopula(BaseTabularModel): + """Model wrapping ``copulas.multivariate.GaussianMultivariate`` copula. + + Args: + distribution (copulas.univariate.Univariate or str): + Copulas univariate distribution to use. + categorical_transformer (str): + Type of transformer to use for the categorical variables, to choose + from ``one_hot_encoding``, ``label_encoding``, ``categorical`` and + ``categorical_fuzzy``. + """ + + DEFAULT_DISTRIBUTION = copulas.univariate.Univariate + _distribution = None + _categorical_transformer = None + _model = None + + HYPERPARAMETERS = { + 'distribution': { + 'type': 'str or copulas.univariate.Univariate', + 'default': 'copulas.univariate.Univariate', + 'description': 'Univariate distribution to use to model each column', + 'choices': [ + 'copulas.univariate.Univariate', + 'copulas.univariate.GaussianUnivariate', + 'copulas.univariate.GammaUnivariate', + 'copulas.univariate.BetaUnivariate', + 'copulas.univariate.StudentTUnivariate', + 'copulas.univariate.GaussianKDE', + 'copulas.univariate.TruncatedGaussian', + ] + }, + 'categorical_transformer': { + 'type': 'str', + 'default': 'categoircal_fuzzy', + 'description': 'Type of transformer to use for the categorical variables', + 'choices': [ + 'categorical', + 'categorical_fuzzy', + 'one_hot_encoding', + 'label_encoding' + ] + } + } + DEFAULT_TRANSFORMER = 'one_hot_encoding' + CATEGORICAL_TRANSFORMERS = { + 'categorical': rdt.transformers.CategoricalTransformer(fuzzy=False), + 'categorical_fuzzy': rdt.transformers.CategoricalTransformer(fuzzy=True), + 'one_hot_encoding': rdt.transformers.OneHotEncodingTransformer, + 'label_encoding': rdt.transformers.LabelEncodingTransformer, + } + TRANSFORMER_TEMPLATES = { + 'O': rdt.transformers.OneHotEncodingTransformer + } + + def __init__(self, distribution=None, categorical_transformer=None, *args, **kwargs): + super().__init__(*args, **kwargs) + + if self._metadata is not None and 'model_kwargs' in self._metadata._metadata: + model_kwargs = self._metadata._metadata['model_kwargs'] + if distribution is None: + distribution = model_kwargs['distribution'] + + if categorical_transformer is None: + categorical_transformer = model_kwargs['categorical_transformer'] + + self._size = model_kwargs['_size'] + + self._distribution = distribution or self.DEFAULT_DISTRIBUTION + + categorical_transformer = categorical_transformer or self.DEFAULT_TRANSFORMER + self._categorical_transformer = categorical_transformer + + self.TRANSFORMER_TEMPLATES['O'] = self.CATEGORICAL_TRANSFORMERS[categorical_transformer] + + def _update_metadata(self): + parameters = self._model.to_dict() + univariates = parameters['univariates'] + columns = parameters['columns'] + + distributions = {} + for column, univariate in zip(columns, univariates): + distributions[column] = univariate['type'] + + self._metadata._metadata['model_kwargs'] = { + 'distribution': distributions, + 'categorical_transformer': self._categorical_transformer, + '_size': self._size + } + + def _fit(self, data): + """Fit the model to the table. 
+
+        Args:
+            data (pandas.DataFrame):
+                Data to be fitted.
+        """
+        self._model = copulas.multivariate.GaussianMultivariate(distribution=self._distribution)
+        self._model.fit(data)
+        self._update_metadata()
+
+    def _sample(self, size):
+        """Sample ``size`` rows from the model.
+
+        Args:
+            size (int):
+                Amount of rows to sample.
+
+        Returns:
+            pandas.DataFrame:
+                Sampled data.
+        """
+        return self._model.sample(size)
+
+    def get_parameters(self, flatten=False):
+        """Get copula model parameters.
+
+        When ``flatten`` is ``True``, reduce the model ``covariance`` to its
+        lower triangle and log-transform the univariate ``scale`` values
+        before returning the flattened dict.
+
+        Args:
+            flatten (bool):
+                Whether to flatten the parameters or not before
+                returning them.
+
+        Returns:
+            dict:
+                Copula parameters.
+        """
+        if not flatten:
+            return self._model.to_dict()
+
+        values = list()
+        triangle = np.tril(self._model.covariance)
+
+        for index, row in enumerate(triangle.tolist()):
+            values.append(row[:index + 1])
+
+        self._model.covariance = np.array(values)
+        params = self._model.to_dict()
+        univariates = dict()
+        for name, univariate in zip(params.pop('columns'), params['univariates']):
+            univariates[name] = univariate
+            if 'scale' in univariate:
+                scale = univariate['scale']
+                if scale == 0:
+                    scale = copulas.EPSILON
+
+                univariate['scale'] = np.log(scale)
+
+        params['univariates'] = univariates
+
+        return flatten_dict(params)
+
+    def _prepare_sampled_covariance(self, covariance):
+        """Prepare a covariance matrix.
+
+        Args:
+            covariance (list):
+                covariance after unflattening model parameters.
+
+        Returns:
+            list[list]:
+                Symmetric positive semi-definite matrix.
+        """
+        covariance = np.array(square_matrix(covariance))
+        covariance = (covariance + covariance.T - (np.identity(covariance.shape[0]) * covariance))
+
+        if not check_matrix_symmetric_positive_definite(covariance):
+            covariance = make_positive_definite(covariance)
+
+        return covariance.tolist()
+
+    def _unflatten_gaussian_copula(self, model_parameters):
+        """Prepare unflattened model params to recreate Gaussian Multivariate instance.
+
+        The preparation consists basically of:
+
+        - Transforming sampled negative standard deviations from distributions into
+          positive numbers.
+
+        - Ensuring the covariance matrix is a valid symmetric positive-semidefinite matrix.
+
+        - Adding string parameters kept inside the class (as they can't be modelled),
+          like ``distribution_type``.
+
+        Args:
+            model_parameters (dict):
+                Sampled and restructured model parameters.
+
+        Returns:
+            dict:
+                Model parameters ready to recreate the model.
+        """
+
+        univariate_kwargs = {
+            'type': model_parameters['distribution']
+        }
+
+        columns = list()
+        univariates = list()
+        for column, univariate in model_parameters['univariates'].items():
+            columns.append(column)
+            univariate.update(univariate_kwargs)
+            if 'scale' in univariate:
+                univariate['scale'] = np.exp(univariate['scale'])
+
+            univariates.append(univariate)
+
+        model_parameters['univariates'] = univariates
+        model_parameters['columns'] = columns
+
+        covariance = model_parameters.get('covariance')
+        model_parameters['covariance'] = self._prepare_sampled_covariance(covariance)
+
+        return model_parameters
+
+    def set_parameters(self, parameters, unflatten=False):
+        """Set copula model parameters.
+
+        Add additional keys after unflattening the parameters
+        in order to set the expected parameters for the copula.
+
+        Args:
+            parameters (dict):
+                Copula flattened parameters.
+            unflatten (bool):
+                Whether the parameters need to be unflattened or not.
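+
+        Example:
+            A sketch of a round trip through the flattened representation
+            (``model`` is assumed to be an already fitted ``GaussianCopula``)::
+
+                parameters = model.get_parameters(flatten=True)
+                new_model = GaussianCopula(table_metadata=model.get_metadata())
+                new_model.set_parameters(parameters, unflatten=True)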
+ """ + if unflatten: + parameters = unflatten_dict(parameters) + parameters.setdefault('distribution', self._distribution) + + parameters = self._unflatten_gaussian_copula(parameters) + + self._model = copulas.multivariate.GaussianMultivariate.from_dict(parameters) diff --git a/sdv/tabular/ctgan.py b/sdv/tabular/ctgan.py new file mode 100644 index 000000000..eb1f87d33 --- /dev/null +++ b/sdv/tabular/ctgan.py @@ -0,0 +1,142 @@ +import rdt + +from sdv.tabular.base import BaseTabularModel + + +class CTGAN(BaseTabularModel): + """Model wrapping ``CTGANSynthesizer`` copula. + + Args: + TBD + """ + + _CTGAN_CLASS = None + _model = None + + HYPERPARAMETERS = { + 'TBD' + } + TRANSFORMER_TEMPLATES = { + 'O': rdt.transformers.LabelEncodingTransformer + } + + def __init__(self, field_names=None, primary_key=None, field_types=None, + anonymize_fields=None, constraints=None, table_metadata=None, + epochs=300, log_frequency=True, embedding_dim=128, gen_dim=(256, 256), + dis_dim=(256, 256), l2scale=1e-6, batch_size=500): + """Initialize this CTGAN model. + + Args: + field_names (list[str]): + List of names of the fields that need to be modeled + and included in the generated output data. Any additional + fields found in the data will be ignored and will not be + included in the generated output, except if they have + been added as primary keys or fields to anonymize. + If ``None``, all the fields found in the data are used. + primary_key (str, list[str] or dict[str, dict]): + Specification about which field or fields are the + primary key of the table and information about how + to generate them. + field_types (dict[str, dict]): + Dictinary specifying the data types and subtypes + of the fields that will be modeled. Field types and subtypes + combinations must be compatible with the SDV Metadata Schema. + anonymize_fields (dict[str, str or tuple]): + Dict specifying which fields to anonymize and what faker + category they belong to. If arguments for the faker need to be + passed to fine tune the value generation a tuple can be passed, + where the first element is the category and the rest are additional + positional arguments for the Faker. + constraints (list[dict]): + List of dicts specifying field and inter-field constraints. + TODO: Format TBD + table_metadata (dict or metadata.Table): + Table metadata instance or dict representation. + If given alongside any other metadata-related arguments, an + exception will be raised. + If not given at all, it will be built using the other + arguments or learned from the data. + epochs (int): + Number of training epochs. Defaults to 300. + log_frequency (boolean): + Whether to use log frequency of categorical levels in conditional + sampling. Defaults to ``True``. + embedding_dim (int): + Size of the random sample passed to the Generator. Defaults to 128. + gen_dim (tuple or list of ints): + Size of the output samples for each one of the Residuals. A Resiudal Layer + will be created for each one of the values provided. Defaults to (256, 256). + dis_dim (tuple or list of ints): + Size of the output samples for each one of the Discriminator Layers. A Linear + Layer will be created for each one of the values provided. Defaults to (256, 256). + l2scale (float): + Wheight Decay for the Adam Optimizer. Defaults to 1e-6. + batch_size (int): + Number of data samples to process in each step. 
+ """ + super().__init__( + field_names=field_names, + primary_key=primary_key, + field_types=field_types, + anonymize_fields=anonymize_fields, + constraints=constraints, + table_metadata=table_metadata + ) + try: + from ctgan import CTGANSynthesizer # Lazy import to make dependency optional + + self._CTGAN_CLASS = CTGANSynthesizer + except ImportError as ie: + ie.msg += ( + '\n\nIt seems like `ctgan` is not installed.\n' + 'Please install it using:\n\n pip install ctgan' + ) + raise + + self._embedding_dim = embedding_dim + self._gen_dim = gen_dim + self._dis_dim = dis_dim + self._l2scale = l2scale + self._batch_size = batch_size + self._epochs = epochs + self._log_frequency = log_frequency + + def _fit(self, data): + """Fit the model to the table. + + Args: + data (pandas.DataFrame): + Data to be learned. + """ + self._model = self._CTGAN_CLASS( + embedding_dim=self._embedding_dim, + gen_dim=self._gen_dim, + dis_dim=self._dis_dim, + l2scale=self._l2scale, + batch_size=self._batch_size, + ) + categoricals = [ + field + for field, meta in self._metadata.get_fields().items() + if meta['type'] == 'categorical' + ] + self._model.fit( + data, + epochs=self._epochs, + discrete_columns=categoricals, + log_frequency=self._log_frequency, + ) + + def _sample(self, size): + """Sample ``size`` rows from the model. + + Args: + size (int): + Amount of rows to sample. + + Returns: + pandas.DataFrame: + Sampled data. + """ + return self._model.sample(size) diff --git a/sdv/tabular/utils.py b/sdv/tabular/utils.py new file mode 100644 index 000000000..241a20048 --- /dev/null +++ b/sdv/tabular/utils.py @@ -0,0 +1,219 @@ +import numpy as np + +IGNORED_DICT_KEYS = ['fitted', 'distribution', 'type'] + + +def flatten_array(nested, prefix=''): + """Flatten an array as a dict. + + Args: + nested (list, numpy.array): + Iterable to flatten. + prefix (str): + Name to append to the array indices. Defaults to ``''``. + + Returns: + dict: + Flattened array. + """ + result = dict() + for index in range(len(nested)): + prefix_key = '__'.join([prefix, str(index)]) if len(prefix) else str(index) + + value = nested[index] + if isinstance(value, (list, np.ndarray)): + result.update(flatten_array(value, prefix=prefix_key)) + + elif isinstance(value, dict): + result.update(flatten_dict(value, prefix=prefix_key)) + + else: + result[prefix_key] = value + + return result + + +def flatten_dict(nested, prefix=''): + """Flatten a dictionary. + + This method returns a flatten version of a dictionary, concatenating key names with + double underscores. + + Args: + nested (dict): + Original dictionary to flatten. + prefix (str): + Prefix to append to key name. Defaults to ``''``. + + Returns: + dict: + Flattened dictionary. + """ + result = dict() + + for key, value in nested.items(): + prefix_key = '__'.join([prefix, str(key)]) if len(prefix) else key + + if key in IGNORED_DICT_KEYS and not isinstance(value, (dict, list)): + continue + + elif isinstance(value, dict): + result.update(flatten_dict(value, prefix_key)) + + elif isinstance(value, (np.ndarray, list)): + result.update(flatten_array(value, prefix_key)) + + else: + result[prefix_key] = value + + return result + + +def _key_order(key_value): + parts = list() + for part in key_value[0].split('__'): + if part.isdigit(): + part = int(part) + + parts.append(part) + + return parts + + +def unflatten_dict(flat): + """Transform a flattened dict into its original form. + + Args: + flat (dict): + Flattened dict. 
+
+    Returns:
+        dict:
+            Nested dict (if applicable).
+    """
+    unflattened = dict()
+
+    for key, value in sorted(flat.items(), key=_key_order):
+        if '__' in key:
+            key, subkey = key.split('__', 1)
+            subkey, name = subkey.rsplit('__', 1)
+
+            if name.isdigit():
+                column_index = int(name)
+                row_index = int(subkey)
+
+                array = unflattened.setdefault(key, list())
+
+                if len(array) == row_index:
+                    row = list()
+                    array.append(row)
+                elif len(array) == row_index + 1:
+                    row = array[row_index]
+                else:
+                    # This should never happen
+                    raise ValueError('There was an error unflattening the extension.')
+
+                if len(row) == column_index:
+                    row.append(value)
+                else:
+                    # This should never happen
+                    raise ValueError('There was an error unflattening the extension.')
+
+            else:
+                subdict = unflattened.setdefault(key, dict())
+                if subkey.isdigit():
+                    subkey = int(subkey)
+
+                inner = subdict.setdefault(subkey, dict())
+                inner[name] = value
+
+        else:
+            unflattened[key] = value
+
+    return unflattened
+
+
+def impute(data):
+    """Fill missing values with the column mean (numerical) or mode (otherwise)."""
+    for column in data:
+        column_data = data[column]
+        if column_data.dtype in (np.int, np.float):
+            fill_value = column_data.mean()
+        else:
+            fill_value = column_data.mode()[0]
+
+        data[column] = data[column].fillna(fill_value)
+
+    return data
+
+
+def square_matrix(triangular_matrix):
+    """Fill with zeros a triangular matrix to reshape it to a square one.
+
+    Args:
+        triangular_matrix (list[list[float]]):
+            Lower triangular matrix, given as a list of rows of increasing length.
+
+    Returns:
+        list:
+            Square matrix.
+    """
+    length = len(triangular_matrix)
+    zero = [0.0]
+
+    for item in triangular_matrix:
+        item.extend(zero * (length - len(item)))
+
+    return triangular_matrix
+
+
+def check_matrix_symmetric_positive_definite(matrix):
+    """Check if a matrix is symmetric positive-definite.
+
+    Args:
+        matrix (list or numpy.ndarray):
+            Matrix to evaluate.
+
+    Returns:
+        bool
+    """
+    try:
+        if len(matrix.shape) != 2 or matrix.shape[0] != matrix.shape[1]:
+            # Not 2-dimensional or square, so not symmetric.
+            return False
+
+        np.linalg.cholesky(matrix)
+        return True
+
+    except np.linalg.LinAlgError:
+        return False
+
+
+def make_positive_definite(matrix):
+    """Find the nearest positive-definite matrix to input.
+
+    Args:
+        matrix (numpy.ndarray):
+            Matrix to transform.
+
+    Returns:
+        numpy.ndarray:
+            Closest symmetric positive-definite matrix.
+    """
+    symmetric_matrix = (matrix + matrix.T) / 2
+    _, s, V = np.linalg.svd(symmetric_matrix)
+    symmetric_polar = np.dot(V.T, np.dot(np.diag(s), V))
+    A2 = (symmetric_matrix + symmetric_polar) / 2
+    A3 = (A2 + A2.T) / 2
+
+    if check_matrix_symmetric_positive_definite(A3):
+        return A3
+
+    spacing = np.spacing(np.linalg.norm(matrix))
+    identity = np.eye(matrix.shape[0])
+    iterations = 1
+    while not check_matrix_symmetric_positive_definite(A3):
+        min_eigenvals = np.min(np.real(np.linalg.eigvals(A3)))
+        A3 += identity * (-min_eigenvals * iterations**2 + spacing)
+        iterations += 1
+
+    return A3
diff --git a/setup.cfg b/setup.cfg
index 24389c418..e4bf6d2c7 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.3.3
+current_version = 0.3.5.dev0
 commit = True
 tag = True
 parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\.(?P<release>[a-z]+)(?P<candidate>\d+))?
diff --git a/setup.py b/setup.py index 0c8189cdf..ad3d04e62 100644 --- a/setup.py +++ b/setup.py @@ -13,12 +13,17 @@ install_requires = [ 'exrex>=0.9.4,<0.11', - 'numpy>=1.15.4,<1.17', + 'numpy>=1.15.4,<2', 'pandas>=0.23.4,<0.25', - 'copulas>=0.2.5,<0.3', - 'rdt>=0.2.1,<0.3', + 'copulas>=0.3.1,<0.4', + 'rdt>=0.2.3,<0.3', 'graphviz>=0.13.2', - 'sdmetrics>=0.0.1,<0.0.2' + 'sdmetrics>=0.0.2.dev0,<0.0.3', + 'scikit-learn<0.23,>=0.21', +] + +ctgan_requires = [ + 'ctgan>=0.2.2.dev0,<0.3', ] setup_requires = [ @@ -40,6 +45,7 @@ # docs 'm2r>=0.2.0,<0.3', + 'nbsphinx>=0.5.0,<0.7', 'Sphinx>=1.7.1,<3', 'sphinx_rtd_theme>=0.2.4,<0.5', 'autodocsumm>=0.1.10', @@ -76,8 +82,8 @@ ], description='Automated Generative Modeling and Sampling', extras_require={ - 'test': tests_require, - 'dev': development_requires + tests_require, + 'test': tests_require + ctgan_requires, + 'dev': development_requires + tests_require + ctgan_requires, }, install_package_data=True, install_requires=install_requires, @@ -93,6 +99,6 @@ test_suite='tests', tests_require=tests_require, url='https://github.com/sdv-dev/SDV', - version='0.3.3', + version='0.3.5.dev0', zip_safe=False, ) diff --git a/tests/integration/tabular/test_copulas.py b/tests/integration/tabular/test_copulas.py new file mode 100644 index 000000000..0a34a0e15 --- /dev/null +++ b/tests/integration/tabular/test_copulas.py @@ -0,0 +1,56 @@ +from sdv.demo import load_demo +from sdv.tabular.copulas import GaussianCopula + + +def test_gaussian_copula(): + users = load_demo(metadata=False)['users'] + + field_types = { + 'age': { + 'type': 'numerical', + 'subtype': 'integer', + }, + 'country': { + 'type': 'categorical' + } + } + anonymize_fields = { + 'country': 'country_code' + } + + gc = GaussianCopula( + field_names=['user_id', 'country', 'gender', 'age'], + field_types=field_types, + primary_key='user_id', + anonymize_fields=anonymize_fields, + categorical_transformer='one_hot_encoding', + ) + gc.fit(users) + + parameters = gc.get_parameters() + new_gc = GaussianCopula( + table_metadata=gc.get_metadata(), + categorical_transformer='one_hot_encoding', + ) + new_gc.set_parameters(parameters) + + sampled = new_gc.sample() + + # test shape is right + assert sampled.shape == users.shape + + # test user_id has been generated as an ID field + assert list(sampled['user_id']) == list(range(0, len(users))) + + # country codes have been replaced with new ones + assert set(sampled.country.unique()) != set(users.country.unique()) + + metadata = gc.get_metadata().to_dict() + assert metadata['fields'] == { + 'user_id': {'type': 'id', 'subtype': 'integer'}, + 'country': {'type': 'categorical'}, + 'gender': {'type': 'categorical'}, + 'age': {'type': 'numerical', 'subtype': 'integer'} + } + + assert 'model_kwargs' in metadata diff --git a/tests/integration/tabular/test_ctgan.py b/tests/integration/tabular/test_ctgan.py new file mode 100644 index 000000000..8cc8d4ad6 --- /dev/null +++ b/tests/integration/tabular/test_ctgan.py @@ -0,0 +1,29 @@ +from sdv.demo import load_demo +from sdv.tabular.ctgan import CTGAN + + +def test_ctgan(): + users = load_demo(metadata=False)['users'] + + gc = CTGAN( + primary_key='user_id', + epochs=1 + ) + gc.fit(users) + + sampled = gc.sample() + + # test shape is right + assert sampled.shape == users.shape + + # test user_id has been generated as an ID field + assert list(sampled['user_id']) == list(range(0, len(users))) + + assert gc.get_metadata().to_dict() == { + 'fields': { + 'user_id': {'type': 'id', 'subtype': 'integer'}, + 'country': {'type': 
'categorical'}, + 'gender': {'type': 'categorical'}, + 'age': {'type': 'numerical', 'subtype': 'integer'} + } + } diff --git a/tests/integration/test_sdv.py b/tests/integration/test_sdv.py new file mode 100644 index 000000000..f72cb22b4 --- /dev/null +++ b/tests/integration/test_sdv.py @@ -0,0 +1,72 @@ +from sdv import SDV, load_demo + + +def test_sdv(): + metadata, tables = load_demo(metadata=True) + + sdv = SDV() + sdv.fit(metadata, tables) + + # Sample all + sampled = sdv.sample_all() + + assert set(sampled.keys()) == {'users', 'sessions', 'transactions'} + assert len(sampled['users']) == 10 + + # Sample with children + sampled = sdv.sample('users', reset_primary_keys=True) + + assert set(sampled.keys()) == {'users', 'sessions', 'transactions'} + assert len(sampled['users']) == 10 + + # Sample without children + users = sdv.sample('users', sample_children=False) + + assert users.shape == tables['users'].shape + assert set(users.columns) == set(tables['users'].columns) + + sessions = sdv.sample('sessions', sample_children=False) + + assert sessions.shape == tables['sessions'].shape + assert set(sessions.columns) == set(tables['sessions'].columns) + + transactions = sdv.sample('transactions', sample_children=False) + + assert transactions.shape == tables['transactions'].shape + assert set(transactions.columns) == set(tables['transactions'].columns) + + +def test_sdv_multiparent(): + metadata, tables = load_demo('got_families', metadata=True) + + sdv = SDV() + sdv.fit(metadata, tables) + + # Sample all + sampled = sdv.sample_all() + + assert set(sampled.keys()) == {'characters', 'families', 'character_families'} + assert len(sampled['characters']) == 7 + + # Sample with children + sampled = sdv.sample('characters', reset_primary_keys=True) + + assert set(sampled.keys()) == {'characters', 'character_families'} + assert len(sampled['characters']) == 7 + assert 'family_id' in sampled['character_families'] + + # Sample without children + characters = sdv.sample('characters', sample_children=False) + + assert characters.shape == tables['characters'].shape + assert set(characters.columns) == set(tables['characters'].columns) + + families = sdv.sample('families', sample_children=False) + + assert families.shape == tables['families'].shape + assert set(families.columns) == set(tables['families'].columns) + + character_families = sdv.sample('character_families', sample_children=False) + + assert character_families.shape == tables['character_families'].shape + assert set(character_families.columns) == set(tables['character_families'].columns) diff --git a/tests/test_metadata.py b/tests/metadata/test___init__.py similarity index 84% rename from tests/test_metadata.py rename to tests/metadata/test___init__.py index a322bc7c7..9968597ae 100644 --- a/tests/test_metadata.py +++ b/tests/metadata/test___init__.py @@ -1,7 +1,6 @@ from unittest import TestCase -from unittest.mock import MagicMock, Mock, call, patch +from unittest.mock import Mock, call, patch -import graphviz import pandas as pd import pytest @@ -218,28 +217,6 @@ def test__dict_metadata_dict(self): } assert result == expected - def test__validate_parents_no_error(self): - """Test that any error is raised with a supported structure""" - # Setup - mock = MagicMock(spec_set=Metadata) - mock.get_parents.return_value = [] - - # Run - Metadata._validate_parents(mock, 'demo') - - # Asserts - mock.get_parents.assert_called_once_with('demo') - - def test__validate_parents_raise_error(self): - """Test that a ValueError is raised because the bad 
structure""" - # Setup - mock = MagicMock(spec_set=Metadata) - mock.get_parents.return_value = ['foo', 'bar'] - - # Run - with pytest.raises(MetadataError): - Metadata._validate_parents(mock, 'demo') - @patch('sdv.metadata.Metadata._analyze_relationships') @patch('sdv.metadata.Metadata._dict_metadata') def test___init__default_metadata_dict(self, mock_meta, mock_relationships): @@ -403,6 +380,7 @@ def test_get_dtypes_error_id(self): } metadata = Mock(spec_set=Metadata) metadata.get_table_meta.return_value = table_meta + metadata.get_children.return_value = [] metadata._DTYPES = Metadata._DTYPES # Run @@ -630,23 +608,16 @@ def test_get_primary_key(self): def test_get_foreign_key(self): """Test get foreign key""" # Setup - primary_key = 'a_primary_key' fields = { 'a_field': { 'ref': { + 'table': 'parent', 'field': 'a_primary_key' }, 'name': 'a_field' - }, - 'p_field': { - 'ref': { - 'field': 'another_key_field' - }, - 'name': 'p_field' } } metadata = Mock(spec_set=Metadata) - metadata.get_primary_key.return_value = primary_key metadata.get_fields.return_value = fields # Run @@ -654,7 +625,6 @@ def test_get_foreign_key(self): # Asserts assert result == 'a_field' - metadata.get_primary_key.assert_called_once_with('parent') metadata.get_fields.assert_called_once_with('child') def test_reverse_transform(self): @@ -992,121 +962,3 @@ def test_add_field(self): assert metadata._metadata == expected_metadata metadata._check_field.assert_called_once_with('a_table', 'a_field', exists=False) - - def test__get_graphviz_extension_path_without_extension(self): - """Raises a ValueError when the path doesn't contains an extension.""" - with pytest.raises(ValueError): - Metadata._get_graphviz_extension('/some/path') - - def test__get_graphviz_extension_invalid_extension(self): - """Raises a ValueError when the path contains an invalid extension.""" - with pytest.raises(ValueError): - Metadata._get_graphviz_extension('/some/path.foo') - - def test__get_graphviz_extension_none(self): - """Get graphviz with path equals to None.""" - # Run - result = Metadata._get_graphviz_extension(None) - - # Asserts - assert result == (None, None) - - def test__get_graphviz_extension_valid(self): - """Get a valid graphviz extension.""" - # Run - result = Metadata._get_graphviz_extension('/some/path.png') - - # Asserts - assert result == ('/some/path', 'png') - - def test__visualize_add_nodes(self): - """Add nodes into a graphviz digraph.""" - # Setup - metadata = MagicMock(spec_set=Metadata) - minimock = Mock() - - # pass tests in python3.5 - minimock.items.return_value = ( - ('a_field', {'type': 'numerical', 'subtype': 'integer'}), - ('b_field', {'type': 'id'}), - ('c_field', {'type': 'id', 'ref': {'table': 'other', 'field': 'pk_field'}}) - ) - - metadata.get_tables.return_value = ['demo'] - metadata.get_fields.return_value = minimock - - metadata.get_primary_key.return_value = 'b_field' - metadata.get_parents.return_value = set(['other']) - metadata.get_foreign_key.return_value = 'c_field' - - metadata.get_table_meta.return_value = {'path': None} - - plot = Mock() - - # Run - Metadata._visualize_add_nodes(metadata, plot) - - # Asserts - expected_node_label = r"{demo|a_field : numerical - integer\lb_field : id\l" \ - r"c_field : id\l|Primary key: b_field\l" \ - r"Foreign key (other): c_field\l}" - - metadata.get_fields.assert_called_once_with('demo') - metadata.get_primary_key.assert_called_once_with('demo') - metadata.get_parents.assert_called_once_with('demo') - metadata.get_table_meta.assert_called_once_with('demo') - 
metadata.get_foreign_key.assert_called_once_with('other', 'demo')
-        metadata.get_table_meta.assert_called_once_with('demo')
-
-        plot.node.assert_called_once_with('demo', label=expected_node_label)
-
-    def test__visualize_add_edges(self):
-        """Add edges into a graphviz digraph."""
-        # Setup
-        metadata = MagicMock(spec_set=Metadata)
-
-        metadata.get_tables.return_value = ['demo', 'other']
-        metadata.get_parents.side_effect = [set(['other']), set()]
-
-        metadata.get_foreign_key.return_value = 'fk'
-        metadata.get_primary_key.return_value = 'pk'
-
-        plot = Mock()
-
-        # Run
-        Metadata._visualize_add_edges(metadata, plot)
-
-        # Asserts
-        expected_edge_label = ' {}.{} -> {}.{}'.format('demo', 'fk', 'other', 'pk')
-
-        metadata.get_tables.assert_called_once_with()
-        metadata.get_foreign_key.assert_called_once_with('other', 'demo')
-        metadata.get_primary_key.assert_called_once_with('other')
-        assert metadata.get_parents.call_args_list == [call('demo'), call('other')]
-
-        plot.edge.assert_called_once_with(
-            'other',
-            'demo',
-            label=expected_edge_label,
-            arrowhead='crow'
-        )
-
-    @patch('sdv.metadata.graphviz')
-    def test_visualize(self, graphviz_mock):
-        """Metadata visualize digraph"""
-        # Setup
-        plot = Mock(spec_set=graphviz.Digraph)
-        graphviz_mock.Digraph.return_value = plot
-
-        metadata = MagicMock(spec_set=Metadata)
-        metadata._get_graphviz_extension.return_value = ('output', 'png')
-
-        # Run
-        Metadata.visualize(metadata, path='output.png')
-
-        # Asserts
-        metadata._get_graphviz_extension.assert_called_once_with('output.png')
-        metadata._visualize_add_nodes.assert_called_once_with(plot)
-        metadata._visualize_add_edges.assert_called_once_with(plot)
-
-        plot.render.assert_called_once_with(filename='output', cleanup=True, format='png')
diff --git a/tests/metadata/test_visualization.py b/tests/metadata/test_visualization.py
new file mode 100644
index 000000000..b9f588593
--- /dev/null
+++ b/tests/metadata/test_visualization.py
@@ -0,0 +1,132 @@
+from unittest.mock import MagicMock, Mock, call, patch
+
+import graphviz
+import pytest
+
+from sdv.metadata import Metadata, visualization
+
+
+def test__get_graphviz_extension_path_without_extension():
+    """Raises a ValueError when the path doesn't contain an extension."""
+    with pytest.raises(ValueError):
+        visualization._get_graphviz_extension('/some/path')
+
+
+def test__get_graphviz_extension_invalid_extension():
+    """Raises a ValueError when the path contains an invalid extension."""
+    with pytest.raises(ValueError):
+        visualization._get_graphviz_extension('/some/path.foo')
+
+
+def test__get_graphviz_extension_none():
+    """Get graphviz with path equals to None."""
+    # Run
+    result = visualization._get_graphviz_extension(None)
+
+    # Asserts
+    assert result == (None, None)
+
+
+def test__get_graphviz_extension_valid():
+    """Get a valid graphviz extension."""
+    # Run
+    result = visualization._get_graphviz_extension('/some/path.png')
+
+    # Asserts
+    assert result == ('/some/path', 'png')
+
+
+def test__add_nodes():
+    """Add nodes into a graphviz digraph."""
+    # Setup
+    metadata = MagicMock(spec_set=Metadata)
+    minimock = Mock()
+
+    # pass tests in python3.5
+    minimock.items.return_value = (
+        ('a_field', {'type': 'numerical', 'subtype': 'integer'}),
+        ('b_field', {'type': 'id'}),
+        ('c_field', {'type': 'id', 'ref': {'table': 'other', 'field': 'pk_field'}})
+    )
+
+    metadata.get_tables.return_value = ['demo']
+    metadata.get_fields.return_value = minimock
+
+    metadata.get_primary_key.return_value = 'b_field'
+    metadata.get_parents.return_value = set(['other'])
+    metadata.get_foreign_key.return_value = 'c_field'
+
+    metadata.get_table_meta.return_value = {'path': None}
+
+    plot = Mock()
+
+    # Run
+    visualization._add_nodes(metadata, plot)
+
+    # Asserts
+    expected_node_label = r"{demo|a_field : numerical - integer\lb_field : id\l" \
+                          r"c_field : id\l|Primary key: b_field\l" \
+                          r"Foreign key (other): c_field\l}"
+
+    metadata.get_fields.assert_called_once_with('demo')
+    metadata.get_primary_key.assert_called_once_with('demo')
+    metadata.get_parents.assert_called_once_with('demo')
+    metadata.get_table_meta.assert_called_once_with('demo')
+    metadata.get_foreign_key.assert_called_once_with('other', 'demo')
+    metadata.get_table_meta.assert_called_once_with('demo')
+
+    plot.node.assert_called_once_with('demo', label=expected_node_label)
+
+
+def test__add_edges():
+    """Add edges into a graphviz digraph."""
+    # Setup
+    metadata = MagicMock(spec_set=Metadata)
+
+    metadata.get_tables.return_value = ['demo', 'other']
+    metadata.get_parents.side_effect = [set(['other']), set()]
+
+    metadata.get_foreign_key.return_value = 'fk'
+    metadata.get_primary_key.return_value = 'pk'
+
+    plot = Mock()
+
+    # Run
+    visualization._add_edges(metadata, plot)
+
+    # Asserts
+    expected_edge_label = ' {}.{} -> {}.{}'.format('demo', 'fk', 'other', 'pk')
+
+    metadata.get_tables.assert_called_once_with()
+    metadata.get_foreign_key.assert_called_once_with('other', 'demo')
+    metadata.get_primary_key.assert_called_once_with('other')
+    assert metadata.get_parents.call_args_list == [call('demo'), call('other')]
+
+    plot.edge.assert_called_once_with(
+        'other',
+        'demo',
+        label=expected_edge_label,
+        arrowhead='oinv'
+    )
+
+
+@patch('sdv.metadata.visualization.graphviz.Digraph', spec_set=graphviz.Digraph)
+@patch('sdv.metadata.visualization._add_nodes')
+@patch('sdv.metadata.visualization._add_edges')
+def test_visualize(add_nodes_mock, add_edges_mock, digraph_mock):
+    """Metadata visualize digraph"""
+    # Setup
+    metadata = MagicMock(spec_set=Metadata)
+
+    # Run
+    visualization.visualize(metadata, path='output.png')
+
+    # Asserts
+    digraph = digraph_mock.return_value
+    add_nodes_mock.assert_called_once_with(metadata, digraph)
+    add_edges_mock.assert_called_once_with(metadata, digraph)
+
+    digraph.render.assert_called_once_with(filename='output', cleanup=True, format='png')
diff --git a/tests/models/test_copulas.py b/tests/models/test_copulas.py
index 0dd0579d1..56a14215b 100644
--- a/tests/models/test_copulas.py
+++ b/tests/models/test_copulas.py
@@ -26,8 +26,11 @@ def test__unflatten_gaussian_copula(self):
 
         # Run
         model_parameters = {
-            'distribs': {
-                'foo': {'std': 0.5}
+            'univariates': {
+                'foo': {
+                    'scale': 0.0,
+                    'loc': 5
+                },
             },
             'covariance': [[0.4, 0.1], [0.1]],
             'distribution': 'GaussianUnivariate'
@@ -36,13 +39,14 @@
 
         # Asserts
         expected = {
-            'distribs': {
-                'foo': {
-                    'fitted': True,
-                    'std': 1.6487212707001282,
+            'univariates': [
+                {
+                    'scale': 1.0,
+                    'loc': 5,
                     'type': 'GaussianUnivariate'
                 }
-            },
+            ],
+            'columns': ['foo'],
             'distribution': 'GaussianUnivariate',
             'covariance': [[0.4, 0.2], [0.2, 0.0]]
         }
diff --git a/tests/test_modeler.py b/tests/test_modeler.py
index bf162d6c8..f701e2469 100644
--- a/tests/test_modeler.py
+++ b/tests/test_modeler.py
@@ -69,6 +69,7 @@ def test_cpa_with_tables_no_primary_key(self):
         modeler.model = Mock(spec=SDVModel)
         modeler.model_kwargs = dict()
         modeler.models = dict()
+        modeler.table_sizes = {'data':
5} modeler.metadata.transform.return_value = pd.DataFrame({'data': [1, 2, 3]}) modeler.metadata.get_primary_key.return_value = None diff --git a/tests/test_sampler.py b/tests/test_sampler.py index a5eddf75f..6474780ad 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -1,6 +1,6 @@ -from unittest import TestCase -from unittest.mock import Mock +from unittest.mock import Mock, patch +import numpy as np import pandas as pd import pytest @@ -9,13 +9,19 @@ from sdv.sampler import Sampler -class TestSampler(TestCase): +class TestSampler: def test___init__(self): """Test create a default instance of Sampler class""" # Run models = {'test': Mock()} - sampler = Sampler('test_metadata', models, SDVModel, dict()) + sampler = Sampler( + 'test_metadata', + models, + SDVModel, + {'model': 'kwargs'}, + {'table': 'sizes'} + ) # Asserts assert sampler.metadata == 'test_metadata' @@ -23,7 +29,8 @@ def test___init__(self): assert sampler.primary_key == dict() assert sampler.remaining_primary_key == dict() assert sampler.model == SDVModel - assert sampler.model_kwargs == dict() + assert sampler.model_kwargs == {'model': 'kwargs'} + assert sampler.table_sizes == {'table': 'sizes'} def test__reset_primary_keys_generators(self): """Test reset values""" @@ -39,23 +46,49 @@ def test__reset_primary_keys_generators(self): assert sampler.primary_key == dict() assert sampler.remaining_primary_key == dict() - def test__transform_synthesized_rows(self): - """Test transform synthesized rows""" + def test__finalize(self): + """Test finalize""" # Setup - metadata_reverse_transform = pd.DataFrame({'foo': [0, 1], 'bar': [2, 3], 'tar': [4, 5]}) - sampler = Mock(spec=Sampler) sampler.metadata = Mock(spec=Metadata) - sampler.metadata.reverse_transform.return_value = metadata_reverse_transform - sampler.metadata.get_fields.return_value = {'foo': 'some data', 'tar': 'some data'} + sampler.metadata.get_parents.return_value = ['b', 'c'] + + sampler.metadata.reverse_transform.side_effect = lambda x, y: y + + sampler.metadata.get_fields.return_value = { + 'a': 'some data', + 'b': 'some data', # fk + 'c': 'some data' # fk + } + + sampler._find_parent_ids.return_value = [4, 5] + sampler.metadata.get_foreign_key.side_effect = [ + 'b', + 'c', + ] # Run - synthesized = pd.DataFrame({'data': [1, 2, 3]}) - result = Sampler._transform_synthesized_rows(sampler, synthesized, 'test') + sampled_data = { + 'test': pd.DataFrame({ + 'a': [0, 1], # actual data + 'b': [2, 3], # existing fk key + 'z': [6, 7] # not used + }) + } + result = Sampler._finalize(sampler, sampled_data) # Asserts - expected = pd.DataFrame({'foo': [0, 1], 'tar': [4, 5]}) - pd.testing.assert_frame_equal(result.sort_index(axis=1), expected.sort_index(axis=1)) + assert isinstance(result, dict) + expected = pd.DataFrame({ + 'a': [0, 1], + 'b': [2, 3], + 'c': [4, 5], + }) + pd.testing.assert_frame_equal( + result['test'].sort_index(axis=1), + expected.sort_index(axis=1) + ) + assert sampler._find_parent_ids.call_count == 1 def test__get_primary_keys_none(self): """Test returns a tuple of none when a table doesn't have a primary key""" @@ -133,15 +166,15 @@ def test__get_primary_keys_raises_value_error_remaining(self): with pytest.raises(ValueError): Sampler._get_primary_keys(sampler, 'test', 5) - def test__get_extension(self): - """Test get extension""" + def test__extract_parameters(self): + """Test extract parameters""" # Setup sampler = Mock(spec=Sampler) # Run parent_row = pd.Series([[0, 1], [1, 0]], index=['__foo__field', '__foo__field2']) table_name = 'foo' - 
result = Sampler._get_extension(sampler, parent_row, table_name) + result = Sampler._extract_parameters(sampler, parent_row, table_name) # Asserts expected = {'field': [0, 1], 'field2': [1, 0]} @@ -172,33 +205,33 @@ def test__sample_children(self): sampler.metadata.get_children.return_value = ['child A', 'child B', 'child C'] # Run - sampled = { + sampled_data = { 'test': pd.DataFrame({'field': [11, 22, 33]}) } - Sampler._sample_children(sampler, 'test', sampled) + Sampler._sample_children(sampler, 'test', sampled_data) # Asserts sampler.metadata.get_children.assert_called_once_with('test') expected_calls = [ - ['child A', 'test', pd.Series([11], index=['field'], name=0), sampled], - ['child A', 'test', pd.Series([22], index=['field'], name=1), sampled], - ['child A', 'test', pd.Series([33], index=['field'], name=2), sampled], - ['child B', 'test', pd.Series([11], index=['field'], name=0), sampled], - ['child B', 'test', pd.Series([22], index=['field'], name=1), sampled], - ['child B', 'test', pd.Series([33], index=['field'], name=2), sampled], - ['child C', 'test', pd.Series([11], index=['field'], name=0), sampled], - ['child C', 'test', pd.Series([22], index=['field'], name=1), sampled], - ['child C', 'test', pd.Series([33], index=['field'], name=2), sampled], + ['child A', 'test', pd.Series([11], index=['field'], name=0), sampled_data], + ['child A', 'test', pd.Series([22], index=['field'], name=1), sampled_data], + ['child A', 'test', pd.Series([33], index=['field'], name=2), sampled_data], + ['child B', 'test', pd.Series([11], index=['field'], name=0), sampled_data], + ['child B', 'test', pd.Series([22], index=['field'], name=1), sampled_data], + ['child B', 'test', pd.Series([33], index=['field'], name=2), sampled_data], + ['child C', 'test', pd.Series([11], index=['field'], name=0), sampled_data], + ['child C', 'test', pd.Series([22], index=['field'], name=1), sampled_data], + ['child C', 'test', pd.Series([33], index=['field'], name=2), sampled_data], ] - actual_calls = sampler._sample_table.call_args_list + actual_calls = sampler._sample_child_rows.call_args_list for result_call, expected_call in zip(actual_calls, expected_calls): assert result_call[0][0] == expected_call[0] assert result_call[0][1] == expected_call[1] assert result_call[0][3] == expected_call[3] pd.testing.assert_series_equal(result_call[0][2], expected_call[2]) - def test__sample_table_sampled_empty(self): + def test__sample_child_rows_sampled_empty(self): """Test sample table when sampled is still an empty dict.""" # Setup model = Mock(spec=SDVModel) @@ -208,7 +241,7 @@ def test__sample_table_sampled_empty(self): sampler.model = model sampler.model_kwargs = dict() - sampler._get_extension.return_value = {'child_rows': 5} + sampler._extract_parameters.return_value = {'child_rows': 5} table_model_mock = Mock() sampler.models = {'test': table_model_mock} @@ -222,10 +255,10 @@ def test__sample_table_sampled_empty(self): # Run parent_row = pd.Series({'id': 0}) sampled = dict() - Sampler._sample_table(sampler, 'test', 'parent', parent_row, sampled) + Sampler._sample_child_rows(sampler, 'test', 'parent', parent_row, sampled) # Asserts - sampler._get_extension.assert_called_once_with(parent_row, 'test') + sampler._extract_parameters.assert_called_once_with(parent_row, 'test') sampler._sample_rows.assert_called_once_with(model, 5, 'test') assert sampler._sample_children.call_count == 1 @@ -240,7 +273,7 @@ def test__sample_table_sampled_empty(self): expected_sampled ) - def test__sample_table_sampled_not_empty(self): + 
def test__sample_child_rows_sampled_not_empty(self): """Test sample table when sampled previous sampled rows exist.""" # Setup model = Mock(spec=SDVModel) @@ -249,7 +282,7 @@ sampler = Mock(spec=Sampler) sampler.model = model sampler.model_kwargs = dict() - sampler._get_extension.return_value = {'child_rows': 5} + sampler._extract_parameters.return_value = {'child_rows': 5} table_model_mock = Mock() sampler.models = {'test': table_model_mock} @@ -270,11 +303,10 @@ 'parent_id': [0, 0, 0, 0, 0] }) } - Sampler._sample_table(sampler, 'test', 'parent', parent_row, sampled) + Sampler._sample_child_rows(sampler, 'test', 'parent', parent_row, sampled) # Asserts - sampler._get_extension.assert_called_once_with(parent_row, 'test') - # sampler._get_model.assert_called_once_with({'child_rows': 5}, table_model_mock) + sampler._extract_parameters.assert_called_once_with(parent_row, 'test') sampler._sample_rows.assert_called_once_with(model, 5, 'test') assert sampler._sample_children.call_count == 1 @@ -309,30 +341,72 @@ def sample_side_effect(table, num_rows): pd.testing.assert_frame_equal(result['table a'], pd.DataFrame({'foo': range(3)})) pd.testing.assert_frame_equal(result['table c'], pd.DataFrame({'foo': range(3)})) - def test_sample_table_with_parents(self): - """Test sample table with parents.""" - sampler = Mock(spec=Sampler) - sampler.metadata = Mock(spec=Metadata) - sampler.metadata.get_parents.return_value = ['test_parent'] - sampler.metadata.get_foreign_key.return_value = 'id' - sampler.models = {'test': 'some model'} - sampler._get_primary_keys.return_value = None, pd.Series({'id': 0}) - sampler._sample_rows.return_value = pd.DataFrame({'id': [0, 1]}) - - Sampler.sample(sampler, 'test', 5) - sampler.metadata.get_parents.assert_called_once_with('test') - sampler.metadata.get_foreign_key.assert_called_once_with('test_parent', 'test') - - def test_sample_no_sample_children(self): - """Test sample no sample children""" - # Setup - sampler = Mock(spec=Sampler) - sampler.models = {'test': 'model'} - sampler.metadata.get_parents.return_value = None + @patch('sdv.sampler.np.random.choice') + def test__find_parent_id_all_0(self, choice_mock): + """If all likelihoods are 0, use num_rows.""" + likelihoods = pd.Series([0, 0, 0, 0]) + num_rows = pd.Series([1, 2, 3, 4]) - # Run - Sampler.sample(sampler, 'test', 5, sample_children=False) - sampler._transform_synthesized_rows.assert_called_once_with( - sampler._sample_rows.return_value, - 'test' - ) + Sampler._find_parent_id(likelihoods, num_rows) + + expected_weights = np.array([1 / 10, 2 / 10, 3 / 10, 4 / 10]) + + assert choice_mock.call_count == 1 + assert list(choice_mock.call_args[0][0]) == list(likelihoods.index) + np.testing.assert_array_equal(choice_mock.call_args[1]['p'], expected_weights) + + @patch('sdv.sampler.np.random.choice') + def test__find_parent_id_all_singular_matrix(self, choice_mock): + """If all likelihoods hit a singular matrix, use num_rows.""" + likelihoods = pd.Series([None, None, None, None]) + num_rows = pd.Series([1, 2, 3, 4]) + + Sampler._find_parent_id(likelihoods, num_rows) + + expected_weights = np.array([1 / 10, 2 / 10, 3 / 10, 4 / 10]) + + assert choice_mock.call_count == 1 + assert list(choice_mock.call_args[0][0]) == list(likelihoods.index) + np.testing.assert_array_equal(choice_mock.call_args[1]['p'], expected_weights) + + @patch('sdv.sampler.np.random.choice') + def test__find_parent_id_all_0_or_singular_matrix(self, 
choice_mock): + """If likelihoods are either 0 or NaN, fill the gaps with num_rows.""" + likelihoods = pd.Series([0, None, 0, None]) + num_rows = pd.Series([1, 2, 3, 4]) + + Sampler._find_parent_id(likelihoods, num_rows) + + expected_weights = np.array([0, 2 / 6, 0, 4 / 6]) + + assert choice_mock.call_count == 1 + assert list(choice_mock.call_args[0][0]) == list(likelihoods.index) + np.testing.assert_array_equal(choice_mock.call_args[1]['p'], expected_weights) + + @patch('sdv.sampler.np.random.choice') + def test__find_parent_id_some_good(self, choice_mock): + """If some likelihoods are good, fill the gaps with their mean.""" + likelihoods = pd.Series([0.5, None, 1.5, None]) + num_rows = pd.Series([1, 2, 3, 4]) + + Sampler._find_parent_id(likelihoods, num_rows) + + expected_weights = np.array([0.5 / 4, 1 / 4, 1.5 / 4, 1 / 4]) + + assert choice_mock.call_count == 1 + assert list(choice_mock.call_args[0][0]) == list(likelihoods.index) + np.testing.assert_array_equal(choice_mock.call_args[1]['p'], expected_weights) + + @patch('sdv.sampler.np.random.choice') + def test__find_parent_id_all_good(self, choice_mock): + """If all are good, use the likelihoods unmodified.""" + likelihoods = pd.Series([0.5, 1, 1.5, 2]) + num_rows = pd.Series([1, 2, 3, 4]) + + Sampler._find_parent_id(likelihoods, num_rows) + + expected_weights = np.array([0.5 / 5, 1 / 5, 1.5 / 5, 2 / 5]) + + assert choice_mock.call_count == 1 + assert list(choice_mock.call_args[0][0]) == list(likelihoods.index) + np.testing.assert_array_equal(choice_mock.call_args[1]['p'], expected_weights) diff --git a/tests/test_sdv.py b/tests/test_sdv.py index f1373129e..21c61f7ef 100644 --- a/tests/test_sdv.py +++ b/tests/test_sdv.py @@ -86,7 +86,7 @@ def test_sample_all_fitted(self): # Asserts assert result == 'test' - sdv.sampler.sample_all.assert_called_once_with(5, reset_primary_keys=False) + sdv.sampler.sample_all.assert_called_once_with(None, reset_primary_keys=False) def test_sample_all_not_fitted(self): """Check that the sample_all raise an exception when is not fitted.""" diff --git a/tutorials/01_Quickstart.ipynb b/tutorials/01_Quickstart.ipynb new file mode 100644 index 000000000..5eb6cd8e8 --- /dev/null +++ b/tutorials/01_Quickstart.ipynb @@ -0,0 +1,838 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Quickstart\n", + "\n", + "In this short tutorial we will guide you through a series of steps that will help you\n", + "get started using **SDV**.\n", + "\n", + "## 1. Model the dataset using SDV\n", + "\n", + "To model a multi-table, relational dataset, we follow two steps. In the first step, we will load\n", + "the data and configure the metadata. In the second step, we will use the SDV API to fit and\n", + "save a hierarchical model. We will cover these two steps in this section using an example dataset.\n", + "\n", + "### Step 1: Load example data\n", + "\n", + "**SDV** comes with a toy dataset to play with, which can be loaded using the `sdv.load_demo`\n", + "function:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from sdv import load_demo\n", + "\n", + "metadata, tables = load_demo(metadata=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This will return two objects:\n", + "\n", + "1. A `Metadata` object with all the information that **SDV** needs to know about the dataset."
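A note on the `_find_parent_id` tests added in the `test_sampler.py` hunk above: together, the five cases pin down a fallback scheme for picking a parent row. Below is a minimal sketch that satisfies all of those assertions; it is reconstructed from the tests and is not necessarily the exact SDV implementation.

```python3
import numpy as np
import pandas as pd


def find_parent_id(likelihoods, num_rows):
    """Pick a parent row index, weighting rows by likelihood with fallbacks."""
    mean = likelihoods.mean()
    if (likelihoods == 0).all():
        # No parent attracted any likelihood: weight by table sizes instead.
        likelihoods = num_rows
    elif pd.isnull(mean) or mean == 0:
        # Every non-null likelihood is 0: fill the NaN gaps with num_rows.
        nulls = likelihoods.isnull()
        likelihoods = likelihoods.fillna(0)
        likelihoods[nulls] = num_rows[nulls]
    elif likelihoods.isnull().any():
        # Some likelihoods are usable: fill the NaN gaps with their mean.
        likelihoods = likelihoods.fillna(mean)

    weights = likelihoods.values / likelihoods.sum()
    return np.random.choice(likelihoods.index, p=weights)
```

For example, `find_parent_id(pd.Series([0.5, None, 1.5, None]), pd.Series([1, 2, 3, 4]))` draws with weights `[0.5/4, 1/4, 1.5/4, 1/4]`, which is exactly what `test__find_parent_id_some_good` asserts.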
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Metadata\n", + " root_path: /home/xals/Projects/MIT/SDV/tutorials\n", + " tables: ['users', 'sessions', 'transactions']\n", + " relationships:\n", + " sessions.user_id -> users.user_id\n", + " transactions.session_id -> sessions.session_id" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metadata" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Metadata\n", + "\n", + "\n", + "\n", + "users\n", + "\n", + "users\n", + "\n", + "user_id : id - integer\n", + "country : categorical\n", + "gender : categorical\n", + "age : numerical - integer\n", + "\n", + "Primary key: user_id\n", + "\n", + "\n", + "\n", + "sessions\n", + "\n", + "sessions\n", + "\n", + "session_id : id - integer\n", + "user_id : id - integer\n", + "device : categorical\n", + "os : categorical\n", + "\n", + "Primary key: session_id\n", + "Foreign key (users): user_id\n", + "\n", + "\n", + "\n", + "users->sessions\n", + "\n", + "\n", + "   sessions.user_id -> users.user_id\n", + "\n", + "\n", + "\n", + "transactions\n", + "\n", + "transactions\n", + "\n", + "transaction_id : id - integer\n", + "session_id : id - integer\n", + "timestamp : datetime\n", + "amount : numerical - float\n", + "approved : boolean\n", + "\n", + "Primary key: transaction_id\n", + "Foreign key (sessions): session_id\n", + "\n", + "\n", + "\n", + "sessions->transactions\n", + "\n", + "\n", + "   transactions.session_id -> sessions.session_id\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metadata.visualize()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For more details about how to build the `Metadata` for your own dataset, please refer to the\n", + "[Metadata](https://sdv-dev.github.io/SDV/metadata.html) section of the documentation." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "2. A dictionary containing three `pandas.DataFrames` with the tables described in the\n", + "metadata object." 
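Regarding the pointer above to building a `Metadata` object for your own dataset: as a hedged sketch, the demo schema could be described programmatically along the following lines. The `add_table` method names and arguments are as documented for this version of SDV, but treat the exact signatures as something to verify against the Metadata section of the documentation.

```python3
from sdv import Metadata

metadata = Metadata()
# Describe each table, its primary key and, for child tables, the
# parent table and foreign key column that link them.
metadata.add_table('users', data=tables['users'], primary_key='user_id')
metadata.add_table('sessions', data=tables['sessions'], primary_key='session_id',
                   parent='users', foreign_key='user_id')
metadata.add_table('transactions', data=tables['transactions'],
                   primary_key='transaction_id', parent='sessions')
```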
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'users': user_id country gender age\n", + " 0 0 USA M 34\n", + " 1 1 UK F 23\n", + " 2 2 ES None 44\n", + " 3 3 UK M 22\n", + " 4 4 USA F 54\n", + " 5 5 DE M 57\n", + " 6 6 BG F 45\n", + " 7 7 ES None 41\n", + " 8 8 FR F 23\n", + " 9 9 UK None 30,\n", + " 'sessions': session_id user_id device os\n", + " 0 0 0 mobile android\n", + " 1 1 1 tablet ios\n", + " 2 2 1 tablet android\n", + " 3 3 2 mobile android\n", + " 4 4 4 mobile ios\n", + " 5 5 5 mobile android\n", + " 6 6 6 mobile ios\n", + " 7 7 6 tablet ios\n", + " 8 8 6 mobile ios\n", + " 9 9 8 tablet ios,\n", + " 'transactions': transaction_id session_id timestamp amount approved\n", + " 0 0 0 2019-01-01 12:34:32 100.0 True\n", + " 1 1 0 2019-01-01 12:42:21 55.3 True\n", + " 2 2 1 2019-01-07 17:23:11 79.5 True\n", + " 3 3 3 2019-01-10 11:08:57 112.1 False\n", + " 4 4 5 2019-01-10 21:54:08 110.0 False\n", + " 5 5 5 2019-01-11 11:21:20 76.3 True\n", + " 6 6 7 2019-01-22 14:44:10 89.5 True\n", + " 7 7 8 2019-01-23 10:14:09 132.1 False\n", + " 8 8 9 2019-01-27 16:09:17 68.0 True\n", + " 9 9 9 2019-01-29 12:10:48 99.9 True}" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tables" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. Fit a model using the SDV API.\n", + "\n", + "First, we build a hierarchical statistical model of the data using **SDV**. For this we will\n", + "create an instance of the `sdv.SDV` class and use its `fit` method.\n", + "\n", + "During this process, **SDV** will traverse across all the tables in your dataset following the\n", + "primary key-foreign key relationships and learn the probability distributions of the values in\n", + "the columns." 
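To make the traversal idea concrete: fitting conceptually walks the tables from parents to children along the foreign keys. Here is a toy sketch of that walk; it relies only on the `Metadata.get_children` call that also appears in the tests above, and everything else is illustrative.

```python3
def walk_schema(metadata, table, depth=0):
    # Visit a table, then recurse into its children along the FK edges.
    print('  ' * depth + table)
    for child in metadata.get_children(table):
        walk_schema(metadata, child, depth + 1)

walk_schema(metadata, 'users')
# users
#   sessions
#     transactions
```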
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-07-09 20:57:19,919 - INFO - modeler - Modeling users\n", + "2020-07-09 20:57:19,920 - INFO - __init__ - Loading transformer CategoricalTransformer for field country\n", + "2020-07-09 20:57:19,920 - INFO - __init__ - Loading transformer CategoricalTransformer for field gender\n", + "2020-07-09 20:57:19,921 - INFO - __init__ - Loading transformer NumericalTransformer for field age\n", + "2020-07-09 20:57:19,933 - INFO - modeler - Modeling sessions\n", + "2020-07-09 20:57:19,934 - INFO - __init__ - Loading transformer CategoricalTransformer for field device\n", + "2020-07-09 20:57:19,934 - INFO - __init__ - Loading transformer CategoricalTransformer for field os\n", + "2020-07-09 20:57:19,944 - INFO - modeler - Modeling transactions\n", + "2020-07-09 20:57:19,944 - INFO - __init__ - Loading transformer DatetimeTransformer for field timestamp\n", + "2020-07-09 20:57:19,944 - INFO - __init__ - Loading transformer NumericalTransformer for field amount\n", + "2020-07-09 20:57:19,945 - INFO - __init__ - Loading transformer BooleanTransformer for field approved\n", + "2020-07-09 20:57:19,954 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:19,962 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/core/fromnumeric.py:58: FutureWarning: Series.nonzero() is deprecated and will be removed in a future version.Use Series.to_numpy().nonzero() instead\n", + " return bound(*args, **kwds)\n", + "/home/xals/Projects/MIT/SDV/sdv/models/copulas.py:83: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. 
If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n", + " self.model.covariance = np.array(values)\n", + "2020-07-09 20:57:19,968 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/pandas/core/frame.py:7143: RuntimeWarning: Degrees of freedom <= 0 for slice\n", + " baseCov = np.cov(mat.T)\n", + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/lib/function_base.py:2480: RuntimeWarning: divide by zero encountered in true_divide\n", + " c *= np.true_divide(1, fact)\n", + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/lib/function_base.py:2480: RuntimeWarning: invalid value encountered in multiply\n", + " c *= np.true_divide(1, fact)\n", + "2020-07-09 20:57:19,974 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:19,979 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:19,985 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:19,989 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:19,994 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,007 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,062 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,074 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,092 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,104 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,117 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,129 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,146 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,208 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,294 - INFO - modeler - Modeling Complete\n" + ] + } + ], + "source": [ + "from sdv import SDV\n", + "\n", + "sdv = SDV()\n", + "sdv.fit(metadata, tables)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. 
Sample data from the fitted model\n", + "\n", + "Once the modeling has finished you are ready to generate new synthetic data using the `sdv` instance that you have.\n", + "\n", + "For this, all you have to do is call the `sample_all` method from your instance passing the number of rows that you want to generate:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "sampled = sdv.sample_all(10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This will return a dictionary identical to the `tables` one that we passed to the SDV instance for learning, filled in with new synthetic data.\n", + "\n", + "**Note** that only the parent tables of your dataset will have the specified number of rows,\n", + "as the number of child rows that each row in the parent table has is also sampled following\n", + "the original distribution of your dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_idcountrygenderage
00UKNaN39
11UKF27
22UKNaN50
33USAF29
44UKNaN21
55ESNaN27
66UKF20
77USAF15
88DEF41
99ESF35
\n", + "
" + ], + "text/plain": [ + " user_id country gender age\n", + "0 0 UK NaN 39\n", + "1 1 UK F 27\n", + "2 2 UK NaN 50\n", + "3 3 USA F 29\n", + "4 4 UK NaN 21\n", + "5 5 ES NaN 27\n", + "6 6 UK F 20\n", + "7 7 USA F 15\n", + "8 8 DE F 41\n", + "9 9 ES F 35" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sampled['users']" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
session_iduser_iddeviceos
001mobileandroid
112tabletios
222tabletios
332mobileandroid
443mobileandroid
555mobileandroid
665mobileandroid
775tabletios
886tabletios
996mobileandroid
\n", + "
" + ], + "text/plain": [ + " session_id user_id device os\n", + "0 0 1 mobile android\n", + "1 1 2 tablet ios\n", + "2 2 2 tablet ios\n", + "3 3 2 mobile android\n", + "4 4 3 mobile android\n", + "5 5 5 mobile android\n", + "6 6 5 mobile android\n", + "7 7 5 tablet ios\n", + "8 8 6 tablet ios\n", + "9 9 6 mobile android" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sampled['sessions'].head(10)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
transaction_idsession_idtimestampamountapproved
0002019-01-13 01:11:06.08427596880.350553True
1102019-01-13 01:11:06.36870425682.116470True
2212019-01-25 15:05:07.73526297697.616590True
3312019-01-25 15:05:09.09895884896.505258True
4422019-01-25 15:05:09.12423142497.235452True
5522019-01-25 15:05:07.34731545696.615923True
6632019-01-25 15:05:09.19044198497.173195True
7732019-01-25 15:05:07.43791078496.995037True
8802019-01-13 01:11:06.56094003281.027625True
9912019-01-25 15:05:12.42074649697.340054True
\n", + "
" + ], + "text/plain": [ + " transaction_id session_id timestamp amount \\\n", + "0 0 0 2019-01-13 01:11:06.084275968 80.350553 \n", + "1 1 0 2019-01-13 01:11:06.368704256 82.116470 \n", + "2 2 1 2019-01-25 15:05:07.735262976 97.616590 \n", + "3 3 1 2019-01-25 15:05:09.098958848 96.505258 \n", + "4 4 2 2019-01-25 15:05:09.124231424 97.235452 \n", + "5 5 2 2019-01-25 15:05:07.347315456 96.615923 \n", + "6 6 3 2019-01-25 15:05:09.190441984 97.173195 \n", + "7 7 3 2019-01-25 15:05:07.437910784 96.995037 \n", + "8 8 0 2019-01-13 01:11:06.560940032 81.027625 \n", + "9 9 1 2019-01-25 15:05:12.420746496 97.340054 \n", + "\n", + " approved \n", + "0 True \n", + "1 True \n", + "2 True \n", + "3 True \n", + "4 True \n", + "5 True \n", + "6 True \n", + "7 True \n", + "8 True \n", + "9 True " + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sampled['transactions'].head(10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Saving and Loading your model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In some cases, you might want to save the fitted SDV instance to be able to generate synthetic data from\n", + "it later or on a different system.\n", + "\n", + "In order to do so, you can save your fitted `SDV` instance for later usage using the `save` method of your\n", + "instance." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "sdv.save('sdv.pkl')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The generated `pkl` file will not include any of the original data in it, so it can be\n", + "safely sent to where the synthetic data will be generated without any privacy concerns." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Later on, in order to sample data from the fitted model, we will first need to load it from its\n", + "`pkl` file." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "sdv = SDV.load('sdv.pkl')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After loading the instance, we can sample synthetic data using its `sample_all` method like before." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "sampled = sdv.sample_all(5)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/tutorials/02_Single_Table_Modeling.ipynb b/tutorials/02_Single_Table_Modeling.ipynb new file mode 100644 index 000000000..fa536bd3d --- /dev/null +++ b/tutorials/02_Single_Table_Modeling.ipynb @@ -0,0 +1,1217 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Single Table Modeling\n", + "\n", + "**SDV** has special support for modeling single table datasets using a variety of models.\n", + "\n", + "Currently, SDV implements:\n", + "\n", + "* GaussianCopula: A tool to model multivariate distributions using [copula functions](https://en.wikipedia.org/wiki/Copula_%28probability_theory%29). 
Based on our [Copulas Library](https://github.com/sdv-dev/Copulas).\n", + "* CTGAN: A GAN-based Deep Learning data synthesizer that can generate synthetic tabular data with high fidelity. Based on our [CTGAN Library](https://github.com/sdv-dev/CTGAN)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## GaussianCopula\n", + "\n", + "In this first part of the tutorial we will be using the GaussianCopula class to model the `users` table\n", + "from the toy dataset included in the **SDV** library.\n", + "\n", + "### 1. Load the Data" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from sdv import load_demo\n", + "\n", + "users = load_demo()['users']" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This will return a table with 4 fields:\n", + "\n", + "* `user_id`: A unique identifier of the user.\n", + "* `country`: A 2 letter code of the country of residence of the user.\n", + "* `gender`: A single letter code, `M` or `F`, indicating the user gender. Note that this demo simulates the case where some users did not indicate the gender, which resulted in empty data values in some rows.\n", + "* `age`: The age of the user, in years." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_idcountrygenderage
00USAM34
11UKF23
22ESNone44
33UKM22
44USAF54
55DEM57
66BGF45
77ESNone41
88FRF23
99UKNone30
\n", + "
" + ], + "text/plain": [ + " user_id country gender age\n", + "0 0 USA M 34\n", + "1 1 UK F 23\n", + "2 2 ES None 44\n", + "3 3 UK M 22\n", + "4 4 USA F 54\n", + "5 5 DE M 57\n", + "6 6 BG F 45\n", + "7 7 ES None 41\n", + "8 8 FR F 23\n", + "9 9 UK None 30" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "users" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. Prepare the model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In order to properly model our data we will need to provide some additional information to our model,\n", + "so let's prepare this information in some variables." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, let's indicate that the `user_id` field in our table is the primary key, so we do not want our\n", + "model to attempt to learn it." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "primary_key = 'user_id'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will also want to anonymize the countries of residence of our users, to avoid disclosing such information.\n", + "Let's make a variable indicating that the `country` field needs to be anonymized using fake `country_codes`." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "anonymize_fileds = {\n", + " 'country': 'contry_code'\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The full list of categories supported corresponds to the `Faker` library\n", + "[provider names](https://faker.readthedocs.io/en/master/providers.html)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Once we have prepared the arguments for our model we are ready to import it, create an instance\n", + "and fit it to our data." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-07-09 21:18:32,974 - INFO - table - Loading transformer CategoricalTransformer for field country\n", + "2020-07-09 21:18:32,975 - INFO - table - Loading transformer CategoricalTransformer for field gender\n", + "2020-07-09 21:18:32,975 - INFO - table - Loading transformer NumericalTransformer for field age\n", + "2020-07-09 21:18:32,991 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n" + ] + } + ], + "source": [ + "from sdv.tabular import GaussianCopula\n", + "\n", + "model = GaussianCopula(\n", + " primary_key=primary_key,\n", + " anonymize_fileds=anonymize_fileds\n", + ")\n", + "model.fit(users)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Notice** how the model took care of transforming the different fields using the appropriate\n", + "Reversible Data Transforms to ensure that the data has a format that the GaussianMultivariate model\n", + "from the [copulas](https://github.com/sdv-dev/Copulas) library can handle." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. Sample data from the fitted model\n", + "\n", + "Once the modeling has finished you are ready to generate new synthetic data by calling the `sample` method\n", + "from our model." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "sampled = model.sample()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This will return a table identical to the one which the model was fitted on, but filled with new data\n", + "which resembles the original one." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_idcountrygenderage
00USAM38
11UKNaN23
22USAF34
33ESNaN47
44ESF29
55UKF39
66FRNaN40
77ESM38
88ESF32
99ESF36
\n", + "
" + ], + "text/plain": [ + " user_id country gender age\n", + "0 0 USA M 38\n", + "1 1 UK NaN 23\n", + "2 2 USA F 34\n", + "3 3 ES NaN 47\n", + "4 4 ES F 29\n", + "5 5 UK F 39\n", + "6 6 FR NaN 40\n", + "7 7 ES M 38\n", + "8 8 ES F 32\n", + "9 9 ES F 36" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sampled" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "scrolled": false + }, + "source": [ + "Notice, as well that the number of rows generated by default corresponds to the number of rows that\n", + "the original table had, but that this number can be changed by simply passing it:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_idcountrygenderage
00UKF48
11USANaN38
22USAM29
33BGM22
44USAM43
\n", + "
" + ], + "text/plain": [ + " user_id country gender age\n", + "0 0 UK F 48\n", + "1 1 USA NaN 38\n", + "2 2 USA M 29\n", + "3 3 BG M 22\n", + "4 4 USA M 43" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.sample(5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## CTGAN\n", + "\n", + "In this second part of the tutorial we will be using the CTGAN model to learn the data from the\n", + "demo dataset called `census`, which is based on the [UCI Adult Census Dataset]('https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data').\n", + "\n", + "### 1. Load the Data" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-07-09 21:18:33,085 - INFO - __init__ - Loading table census\n" + ] + } + ], + "source": [ + "from sdv import load_demo\n", + "\n", + "census = load_demo('census')['census']" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This will return a table with several rows of multiple data types:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ageworkclassfnlwgteducationeducation-nummarital-statusoccupationrelationshipracesexcapital-gaincapital-losshours-per-weeknative-countryincome
039State-gov77516Bachelors13Never-marriedAdm-clericalNot-in-familyWhiteMale2174040United-States<=50K
150Self-emp-not-inc83311Bachelors13Married-civ-spouseExec-managerialHusbandWhiteMale0013United-States<=50K
238Private215646HS-grad9DivorcedHandlers-cleanersNot-in-familyWhiteMale0040United-States<=50K
353Private23472111th7Married-civ-spouseHandlers-cleanersHusbandBlackMale0040United-States<=50K
428Private338409Bachelors13Married-civ-spouseProf-specialtyWifeBlackFemale0040Cuba<=50K
\n", + "
" + ], + "text/plain": [ + " age workclass fnlwgt education education-num \\\n", + "0 39 State-gov 77516 Bachelors 13 \n", + "1 50 Self-emp-not-inc 83311 Bachelors 13 \n", + "2 38 Private 215646 HS-grad 9 \n", + "3 53 Private 234721 11th 7 \n", + "4 28 Private 338409 Bachelors 13 \n", + "\n", + " marital-status occupation relationship race sex \\\n", + "0 Never-married Adm-clerical Not-in-family White Male \n", + "1 Married-civ-spouse Exec-managerial Husband White Male \n", + "2 Divorced Handlers-cleaners Not-in-family White Male \n", + "3 Married-civ-spouse Handlers-cleaners Husband Black Male \n", + "4 Married-civ-spouse Prof-specialty Wife Black Female \n", + "\n", + " capital-gain capital-loss hours-per-week native-country income \n", + "0 2174 0 40 United-States <=50K \n", + "1 0 0 13 United-States <=50K \n", + "2 0 0 40 United-States <=50K \n", + "3 0 0 40 United-States <=50K \n", + "4 0 0 40 Cuba <=50K " + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "census.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. Prepare the model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this case there is no primary key to setup and we will not be anonymizing anything, so the only\n", + "thing that we will pass to the CTGAN model is the number of epochs that we want it to perform when\n", + "it leanrs the data, which we will keep low to make this execution quick." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/sklearn/utils/deprecation.py:144: FutureWarning: The sklearn.utils.testing module is deprecated in version 0.22 and will be removed in version 0.24. The corresponding classes / functions should instead be imported from sklearn.utils. Anything that cannot be imported from sklearn.utils is now part of the private API.\n", + " warnings.warn(message, FutureWarning)\n" + ] + } + ], + "source": [ + "from sdv.tabular import CTGAN\n", + "\n", + "model = CTGAN(epochs=10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Once the instance is created, we can fit it to our data. Bear in mind that this process might take some\n", + "time to finish, especially on non-GPU enabled systems, so in this case we will be passing only a\n", + "subsample of the data to accelerate the process." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-07-09 21:18:33,488 - INFO - table - Loading transformer NumericalTransformer for field age\n", + "2020-07-09 21:18:33,489 - INFO - table - Loading transformer LabelEncodingTransformer for field workclass\n", + "2020-07-09 21:18:33,489 - INFO - table - Loading transformer NumericalTransformer for field fnlwgt\n", + "2020-07-09 21:18:33,490 - INFO - table - Loading transformer LabelEncodingTransformer for field education\n", + "2020-07-09 21:18:33,490 - INFO - table - Loading transformer NumericalTransformer for field education-num\n", + "2020-07-09 21:18:33,490 - INFO - table - Loading transformer LabelEncodingTransformer for field marital-status\n", + "2020-07-09 21:18:33,491 - INFO - table - Loading transformer LabelEncodingTransformer for field occupation\n", + "2020-07-09 21:18:33,491 - INFO - table - Loading transformer LabelEncodingTransformer for field relationship\n", + "2020-07-09 21:18:33,491 - INFO - table - Loading transformer LabelEncodingTransformer for field race\n", + "2020-07-09 21:18:33,492 - INFO - table - Loading transformer LabelEncodingTransformer for field sex\n", + "2020-07-09 21:18:33,492 - INFO - table - Loading transformer NumericalTransformer for field capital-gain\n", + "2020-07-09 21:18:33,493 - INFO - table - Loading transformer NumericalTransformer for field capital-loss\n", + "2020-07-09 21:18:33,493 - INFO - table - Loading transformer NumericalTransformer for field hours-per-week\n", + "2020-07-09 21:18:33,493 - INFO - table - Loading transformer LabelEncodingTransformer for field native-country\n", + "2020-07-09 21:18:33,494 - INFO - table - Loading transformer LabelEncodingTransformer for field income\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1, Loss G: 1.9512, Loss D: -0.0182\n", + "Epoch 2, Loss G: 1.9884, Loss D: -0.0663\n", + "Epoch 3, Loss G: 1.9710, Loss D: -0.1339\n", + "Epoch 4, Loss G: 1.8960, Loss D: -0.2061\n", + "Epoch 5, Loss G: 1.9155, Loss D: -0.3062\n", + "Epoch 6, Loss G: 1.9699, Loss D: -0.3906\n", + "Epoch 7, Loss G: 1.8614, Loss D: -0.5142\n", + "Epoch 8, Loss G: 1.8446, Loss D: -0.6448\n", + "Epoch 9, Loss G: 1.7619, Loss D: -0.7488\n", + "Epoch 10, Loss G: 1.6732, Loss D: -0.7961\n" + ] + } + ], + "source": [ + "model.fit(census.sample(1000))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. Sample data from the fitted model\n", + "\n", + "Once the modeling has finished you are ready to generate new synthetic data by calling the `sample` method\n", + "from our model just like we did with the GaussianCopula model." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "sampled = model.sample()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This will return a table identical to the one which the model was fitted on, but filled with new data\n", + "which resembles the original one." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ageworkclassfnlwgteducationeducation-nummarital-statusoccupationrelationshipracesexcapital-gaincapital-losshours-per-weeknative-countryincome
050Local-gov1697191st-4th9Widowed?HusbandWhiteMale114838Columbia<=50K
132?1524791st-4th9Never-marriedAdm-clericalWifeBlackMale-422021Jamaica>50K
222Private69617Bachelors0Separated?HusbandWhiteMale61138Guatemala<=50K
325?65285810th16Married-civ-spouseHandlers-cleanersNot-in-familyWhiteFemale152-2739Cuba<=50K
443Private301956Some-college8Married-civ-spouse?WifeWhiteMale-133-1239India<=50K
566Private401171Prof-school13SeparatedProtective-servUnmarriedBlackFemale-124-140Cuba<=50K
652Private278399Bachelors12Never-marriedProf-specialtyUnmarriedOtherMale122567-647Columbia<=50K
736Federal-gov229817HS-grad8Married-AF-spouseFarming-fishingNot-in-familyWhiteMale81938Portugal>50K
827Federal-gov306972Some-college8Never-marriedExec-managerialHusbandAsian-Pac-IslanderFemale42144339Japan>50K
928Local-gov4161611st-4th8DivorcedAdm-clericalUnmarriedWhiteFemale-349109061Guatemala>50K
\n", + "
" + ], + "text/plain": [ + " age workclass fnlwgt education education-num \\\n", + "0 50 Local-gov 169719 1st-4th 9 \n", + "1 32 ? 152479 1st-4th 9 \n", + "2 22 Private 69617 Bachelors 0 \n", + "3 25 ? 652858 10th 16 \n", + "4 43 Private 301956 Some-college 8 \n", + "5 66 Private 401171 Prof-school 13 \n", + "6 52 Private 278399 Bachelors 12 \n", + "7 36 Federal-gov 229817 HS-grad 8 \n", + "8 27 Federal-gov 306972 Some-college 8 \n", + "9 28 Local-gov 416161 1st-4th 8 \n", + "\n", + " marital-status occupation relationship \\\n", + "0 Widowed ? Husband \n", + "1 Never-married Adm-clerical Wife \n", + "2 Separated ? Husband \n", + "3 Married-civ-spouse Handlers-cleaners Not-in-family \n", + "4 Married-civ-spouse ? Wife \n", + "5 Separated Protective-serv Unmarried \n", + "6 Never-married Prof-specialty Unmarried \n", + "7 Married-AF-spouse Farming-fishing Not-in-family \n", + "8 Never-married Exec-managerial Husband \n", + "9 Divorced Adm-clerical Unmarried \n", + "\n", + " race sex capital-gain capital-loss hours-per-week \\\n", + "0 White Male 114 8 38 \n", + "1 Black Male -42 20 21 \n", + "2 White Male 6 11 38 \n", + "3 White Female 152 -27 39 \n", + "4 White Male -133 -12 39 \n", + "5 Black Female -124 -1 40 \n", + "6 Other Male 122567 -6 47 \n", + "7 White Male 8 19 38 \n", + "8 Asian-Pac-Islander Female 42144 3 39 \n", + "9 White Female -349 1090 61 \n", + "\n", + " native-country income \n", + "0 Columbia <=50K \n", + "1 Jamaica >50K \n", + "2 Guatemala <=50K \n", + "3 Cuba <=50K \n", + "4 India <=50K \n", + "5 Cuba <=50K \n", + "6 Columbia <=50K \n", + "7 Portugal >50K \n", + "8 Japan >50K \n", + "9 Guatemala >50K " + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sampled.head(10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 4. Evaluate how good the data is\n", + "\n", + "Finally, we will use the evaluation framework included in SDV to obtain a metric of how\n", + "similar the sampled data is to the original one.\n", + "\n", + "For this, we will simply import the `sdv.evaluation.evaluate` function and pass both\n", + "the synthetic and the real data to it." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "-144.971907591418" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import warnings\n", + "warnings.filterwarnings('ignore')\n", + "\n", + "from sdv.evaluation import evaluate\n", + "\n", + "evaluate(sampled, census)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/tutorials/03_Relational_Data_Modeling.ipynb b/tutorials/03_Relational_Data_Modeling.ipynb new file mode 100644 index 000000000..230061b00 --- /dev/null +++ b/tutorials/03_Relational_Data_Modeling.ipynb @@ -0,0 +1,1039 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Relational Data Modeling\n", + "\n", + "In this tutorial we will be showing how to model a real-world multi-table dataset using SDV.\n", + "\n", + "## About the dataset\n", + "\n", + "We have a series of stores, each with a size and a category, plus additional information for a given date: average temperature in the region, cost of fuel in the region, promotional data, the customer price index, the unemployment rate and whether the date is a special holiday.\n", + "\n", + "From those stores we obtained a training set of historical data\n", + "between 2010-02-05 and 2012-11-01. This historical data includes the sales of each department on a specific date.\n", + "In this notebook, we will show you step-by-step how to download the \"Walmart\" dataset, explain its structure and sample the data.\n", + "\n", + "In this demonstration we will show how SDV can be used to generate synthetic data, which can later be used to train machine learning models.\n", + "\n", + "*The dataset used in this example can be found on [Kaggle](https://www.kaggle.com/c/walmart-recruiting-store-sales-forecasting/data), but we will show how to download it from SDV.*" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Data model summary" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

stores

\n", + "\n", + "| Field | Type | Subtype | Additional Properties |\n", + "|-------|-------------|---------|-----------------------|\n", + "| Store | id | integer | Primary key |\n", + "| Size | numerical | integer | |\n", + "| Type | categorical | | |\n", + "\n", + "Contains information about the 45 stores, indicating the type and size of store." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

features

\n", + "\n", + "| Fields | Type | Subtype | Additional Properties |\n", + "|--------------|-----------|---------|-----------------------------|\n", + "| Store | id | integer | foreign key (stores.Store) |\n", + "| Date | datetime | | format: \"%Y-%m-%d\" |\n", + "| IsHoliday | boolean | | |\n", + "| Fuel_Price | numerical | float | |\n", + "| Unemployment | numerical | float | |\n", + "| Temperature | numerical | float | |\n", + "| CPI | numerical | float | |\n", + "| MarkDown1 | numerical | float | |\n", + "| MarkDown2 | numerical | float | |\n", + "| MarkDown3 | numerical | float | |\n", + "| MarkDown4 | numerical | float | |\n", + "| MarkDown5 | numerical | float | |\n", + "\n", + "Contains historical training data, which covers to 2010-02-05 to 2012-11-01." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

depts

\n", + "\n", + "| Fields | Type | Subtype | Additional Properties |\n", + "|--------------|-----------|---------|------------------------------|\n", + "| Store | id | integer | foreign key (stores.Stores) |\n", + "| Date | datetime | | format: \"%Y-%m-%d\" |\n", + "| Weekly_Sales | numerical | float | |\n", + "| Dept | numerical | integer | |\n", + "| IsHoliday | boolean | | |\n", + "\n", + "Contains additional data related to the store, department, and regional activity for the given dates." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1. Load data\n", + "\n", + "Let's start downloading the data set. In this case, we will download the data set *walmart*. We will use the SDV function `load_demo`, we can specify the name of the dataset we want to use and if we want its Metadata object or not. To know more about the demo data [see the documentation](https://sdv-dev.github.io/SDV/api/sdv.demo.html)." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-07-09 21:00:17,378 - INFO - __init__ - Loading table stores\n", + "2020-07-09 21:00:17,384 - INFO - __init__ - Loading table features\n", + "2020-07-09 21:00:17,402 - INFO - __init__ - Loading table depts\n" + ] + } + ], + "source": [ + "from sdv import load_demo\n", + "\n", + "metadata, tables = load_demo(dataset_name='walmart', metadata=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Our dataset is downloaded from an [Amazon S3 bucket](http://sdv-datasets.s3.amazonaws.com/index.html) that contains all available data sets of the `load_demo` method." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can now visualize the metadata structure" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Metadata\n", + "\n", + "\n", + "\n", + "stores\n", + "\n", + "stores\n", + "\n", + "Type : categorical\n", + "Size : numerical - integer\n", + "Store : id - integer\n", + "\n", + "Primary key: Store\n", + "Data path: stores.csv\n", + "\n", + "\n", + "\n", + "features\n", + "\n", + "features\n", + "\n", + "Date : datetime\n", + "MarkDown1 : numerical - float\n", + "Store : id - integer\n", + "IsHoliday : boolean\n", + "MarkDown4 : numerical - float\n", + "MarkDown3 : numerical - float\n", + "Fuel_Price : numerical - float\n", + "Unemployment : numerical - float\n", + "Temperature : numerical - float\n", + "MarkDown5 : numerical - float\n", + "MarkDown2 : numerical - float\n", + "CPI : numerical - float\n", + "\n", + "Foreign key (stores): Store\n", + "Data path: features.csv\n", + "\n", + "\n", + "\n", + "stores->features\n", + "\n", + "\n", + "   features.Store -> stores.Store\n", + "\n", + "\n", + "\n", + "depts\n", + "\n", + "depts\n", + "\n", + "Date : datetime\n", + "Weekly_Sales : numerical - float\n", + "Store : id - integer\n", + "Dept : numerical - integer\n", + "IsHoliday : boolean\n", + "\n", + "Foreign key (stores): Store\n", + "Data path: train.csv\n", + "\n", + "\n", + "\n", + "stores->depts\n", + "\n", + "\n", + "   depts.Store -> stores.Store\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metadata.visualize()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + 
"source": [ + "And also validate that the metadata is correctly defined for our data" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "metadata.validate(tables)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. Create an instance of SDV and train the instance\n", + "\n", + "Once we download it, we have to create an SDV instance. With that instance, we have to analyze the loaded tables to generate a statistical model from the data. In this case, the process of adjusting the model is quickly because the dataset is small. However, with larger datasets it can be a slow process." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-07-09 21:00:31,480 - INFO - modeler - Modeling stores\n", + "2020-07-09 21:00:31,481 - INFO - __init__ - Loading transformer CategoricalTransformer for field Type\n", + "2020-07-09 21:00:31,481 - INFO - __init__ - Loading transformer NumericalTransformer for field Size\n", + "2020-07-09 21:00:31,491 - INFO - modeler - Modeling features\n", + "2020-07-09 21:00:31,492 - INFO - __init__ - Loading transformer DatetimeTransformer for field Date\n", + "2020-07-09 21:00:31,493 - INFO - __init__ - Loading transformer NumericalTransformer for field MarkDown1\n", + "2020-07-09 21:00:31,493 - INFO - __init__ - Loading transformer BooleanTransformer for field IsHoliday\n", + "2020-07-09 21:00:31,493 - INFO - __init__ - Loading transformer NumericalTransformer for field MarkDown4\n", + "2020-07-09 21:00:31,494 - INFO - __init__ - Loading transformer NumericalTransformer for field MarkDown3\n", + "2020-07-09 21:00:31,494 - INFO - __init__ - Loading transformer NumericalTransformer for field Fuel_Price\n", + "2020-07-09 21:00:31,494 - INFO - __init__ - Loading transformer NumericalTransformer for field Unemployment\n", + "2020-07-09 21:00:31,495 - INFO - __init__ - Loading transformer NumericalTransformer for field Temperature\n", + "2020-07-09 21:00:31,495 - INFO - __init__ - Loading transformer NumericalTransformer for field MarkDown5\n", + "2020-07-09 21:00:31,495 - INFO - __init__ - Loading transformer NumericalTransformer for field MarkDown2\n", + "2020-07-09 21:00:31,495 - INFO - __init__ - Loading transformer NumericalTransformer for field CPI\n", + "2020-07-09 21:00:31,544 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,595 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "/home/xals/Projects/MIT/SDV/sdv/models/copulas.py:83: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. 
If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n", + " self.model.covariance = np.array(values)\n", + "2020-07-09 21:00:31,651 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,679 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,707 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,734 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,762 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,790 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,816 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,845 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,872 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,901 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,931 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,959 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,986 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,014 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,040 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,070 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,096 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,123 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,152 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,181 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,209 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,235 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,264 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,293 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,322 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,349 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,376 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,405 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,433 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 
21:00:32,463 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,492 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,521 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,552 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,583 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,612 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,644 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,674 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,704 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,732 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,762 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,791 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,821 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,852 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,882 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,917 - INFO - modeler - Modeling depts\n", + "2020-07-09 21:00:32,918 - INFO - __init__ - Loading transformer DatetimeTransformer for field Date\n", + "2020-07-09 21:00:32,918 - INFO - __init__ - Loading transformer NumericalTransformer for field Weekly_Sales\n", + "2020-07-09 21:00:32,919 - INFO - __init__ - Loading transformer NumericalTransformer for field Dept\n", + "2020-07-09 21:00:32,919 - INFO - __init__ - Loading transformer BooleanTransformer for field IsHoliday\n", + "2020-07-09 21:00:33,016 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,318 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,334 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,350 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,364 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,381 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,396 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,412 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,428 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-07-09 21:00:33,447 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,464 - INFO - gaussian - Fitting 
GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,479 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,495 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,511 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,526 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,541 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,554 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,567 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,582 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,596 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,609 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,624 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,639 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,652 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,667 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,682 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,696 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,709 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,724 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,740 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,753 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,766 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,779 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,792 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,803 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,818 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,831 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,842 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,856 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,870 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,885 - INFO - gaussian - Fitting 
GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,898 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,912 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,926 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,937 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,949 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:34,047 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/core/fromnumeric.py:58: FutureWarning: Series.nonzero() is deprecated and will be removed in a future version.Use Series.to_numpy().nonzero() instead\n", + " return bound(*args, **kwds)\n", + "2020-07-09 21:00:34,259 - INFO - modeler - Modeling Complete\n" + ] + } + ], + "source": [ + "from sdv import SDV\n", + "\n", + "sdv = SDV()\n", + "sdv.fit(metadata, tables=tables)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note: We may not want to train the model every time we want to generate new synthetic data. We can [save](https://sdv-dev.github.io/SDV/api/sdv.sdv.html#sdv.sdv.SDV.save) the SDV instance to [load](https://sdv-dev.github.io/SDV/api/sdv.sdv.html#sdv.sdv.SDV.save) it later." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. Generate synthetic data\n", + "\n", + "Once the instance is trained, we are ready to generate the synthetic data.\n", + "\n", + "The easiest way to generate synthetic data for the entire dataset is to call the `sample_all` method. By default, this method generates only 5 rows, but we can specify the row number that will be generated with the `num_rows` argument. To learn more about the available arguments, see [sample_all](https://sdv-dev.github.io/SDV/api/sdv.sampler.html#sdv.sampler.Sampler.sample_all)." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'stores': 45, 'features': 8190, 'depts': 421570}" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sdv.modeler.table_sizes" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "samples = sdv.sample_all()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This returns a dictionary with a `pandas.DataFrame` for each table." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "

\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
TypeSizeStore
0B1061060
1B680711
2C2536032
3A2232683
4B1668484
\n", + "
" + ], + "text/plain": [ + " Type Size Store\n", + "0 B 106106 0\n", + "1 B 68071 1\n", + "2 C 253603 2\n", + "3 A 223268 3\n", + "4 B 166848 4" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "samples['stores'].head()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
DateMarkDown1StoreIsHolidayMarkDown4MarkDown3Fuel_PriceUnemploymentTemperatureMarkDown5MarkDown2CPI
02009-12-13 04:59:49.832576000222.2688220False758.900545-1080.5694354.5749666.05279551.354606-4016.012141NaN224.973722
12011-07-20 14:40:40.4002447366857.7120280FalseNaN-551.0452573.0182966.72066980.723862NaNNaN204.595072
22011-10-12 08:23:54.9876582405530.9294250FalseNaNNaN3.4663587.25657073.954938NaNNaN202.591941
32012-11-26 09:22:59.377291776NaN0FalseNaN3703.1371913.6237256.59861572.041454NaN3925.456611169.457078
42013-05-13 15:22:12.213717760NaN0False3773.942572NaN3.3817955.11500283.05832215018.944605NaN182.801974
\n", + "
" + ], + "text/plain": [ + " Date MarkDown1 Store IsHoliday MarkDown4 \\\n", + "0 2009-12-13 04:59:49.832576000 222.268822 0 False 758.900545 \n", + "1 2011-07-20 14:40:40.400244736 6857.712028 0 False NaN \n", + "2 2011-10-12 08:23:54.987658240 5530.929425 0 False NaN \n", + "3 2012-11-26 09:22:59.377291776 NaN 0 False NaN \n", + "4 2013-05-13 15:22:12.213717760 NaN 0 False 3773.942572 \n", + "\n", + " MarkDown3 Fuel_Price Unemployment Temperature MarkDown5 \\\n", + "0 -1080.569435 4.574966 6.052795 51.354606 -4016.012141 \n", + "1 -551.045257 3.018296 6.720669 80.723862 NaN \n", + "2 NaN 3.466358 7.256570 73.954938 NaN \n", + "3 3703.137191 3.623725 6.598615 72.041454 NaN \n", + "4 NaN 3.381795 5.115002 83.058322 15018.944605 \n", + "\n", + " MarkDown2 CPI \n", + "0 NaN 224.973722 \n", + "1 NaN 204.595072 \n", + "2 NaN 202.591941 \n", + "3 3925.456611 169.457078 \n", + "4 NaN 182.801974 " + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "samples['features'].head()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
DateWeekly_SalesStoreDeptIsHoliday
02012-04-19 04:40:04.3703828487523.202470036False
12012-10-24 12:49:45.63017651213675.425498017False
22012-06-05 05:02:35.5517614082402.017324031False
32012-05-31 22:15:14.903856896-12761.073176081False
42011-09-07 15:34:43.2398356483642.817612058False
\n", + "
" + ], + "text/plain": [ + " Date Weekly_Sales Store Dept IsHoliday\n", + "0 2012-04-19 04:40:04.370382848 7523.202470 0 36 False\n", + "1 2012-10-24 12:49:45.630176512 13675.425498 0 17 False\n", + "2 2012-06-05 05:02:35.551761408 2402.017324 0 31 False\n", + "3 2012-05-31 22:15:14.903856896 -12761.073176 0 81 False\n", + "4 2011-09-07 15:34:43.239835648 3642.817612 0 58 False" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "samples['depts'].head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We may not want to generate data for all tables in the dataset, rather for just one table. This is possible with SDV using the `sample` method. To use it we only need to specify the name of the table we want to synthesize and the row numbers to generate. In this case, the \"walmart\" data set has 3 tables: stores, features and depts.\n", + "\n", + "In the following example, we will generate 1000 rows of the \"features\" table." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'features': Date MarkDown1 Store IsHoliday \\\n", + " 0 2010-07-07 18:33:30.490365440 NaN 46 True \n", + " 1 2010-05-31 05:12:43.555245568 NaN 49 False \n", + " 2 2009-11-13 00:11:26.645775872 NaN 46 False \n", + " 3 2011-06-26 13:37:23.832223232 8480.251096 48 False \n", + " 4 2013-09-30 06:14:18.804104192 4586.374804 48 False \n", + " 5 2011-03-18 21:50:55.521650432 NaN 45 False \n", + " 6 2012-03-22 14:59:09.189622016 9747.212668 46 False \n", + " 7 2011-04-05 16:45:11.282816000 NaN 48 False \n", + " 8 2013-05-05 14:23:57.431958784 5430.557501 46 False \n", + " 9 2011-04-07 02:38:10.873993984 6735.454863 49 False \n", + " 10 2012-07-09 15:59:11.863023616 11818.737600 48 False \n", + " 11 2012-03-10 03:57:22.968428544 4603.635257 48 False \n", + " 12 2011-10-01 03:51:05.280273920 14596.976258 49 False \n", + " 13 2012-06-24 23:45:06.501742080 5816.163902 49 False \n", + " 14 2011-08-06 14:43:48.366935040 NaN 47 False \n", + " 15 2012-06-14 01:06:45.771771648 NaN 45 False \n", + " 16 2011-11-22 08:57:27.766619392 NaN 45 False \n", + " 17 2012-10-06 19:10:00.039407360 5887.384577 48 False \n", + " 18 2013-06-08 07:46:59.612567808 -5770.051172 48 False \n", + " 19 2012-05-27 11:08:47.184699648 6908.122060 46 False \n", + " 20 2010-10-19 21:42:09.919441408 NaN 46 False \n", + " 21 2013-01-01 17:57:15.273558784 6221.151492 48 False \n", + " 22 2012-07-13 14:40:51.181466112 1378.006687 47 False \n", + " 23 2012-02-13 18:37:30.161656320 NaN 48 False \n", + " 24 2010-11-27 09:56:34.540603392 NaN 46 False \n", + " 25 2011-05-21 03:49:37.053693440 NaN 49 False \n", + " 26 2010-06-12 20:31:32.339124224 NaN 49 True \n", + " 27 2011-08-12 21:10:37.406385152 NaN 47 False \n", + " 28 2012-02-07 03:31:37.558127360 NaN 47 False \n", + " 29 2011-05-28 22:39:00.441712128 NaN 48 False \n", + " .. ... ... ... ... 
\n", + " 970 2012-10-01 09:27:33.224137984 7367.091411 49 False \n", + " 971 2011-08-05 19:20:26.151251200 NaN 49 False \n", + " 972 2010-05-02 10:25:48.294526720 NaN 49 False \n", + " 973 2014-07-30 06:33:25.027884800 2869.286369 46 False \n", + " 974 2012-06-10 23:50:28.918937344 NaN 45 False \n", + " 975 2011-11-20 17:24:25.076207104 NaN 48 False \n", + " 976 2011-07-16 19:31:09.475053312 8465.620591 48 False \n", + " 977 2010-12-17 11:17:29.377529600 NaN 45 False \n", + " 978 2011-05-01 22:00:24.413698816 NaN 48 True \n", + " 979 2009-01-04 14:06:28.700405504 NaN 48 False \n", + " 980 2013-01-26 23:44:47.124111360 7711.715123 47 False \n", + " 981 2012-02-05 15:58:21.306519040 4771.829938 49 False \n", + " 982 2012-12-25 20:25:33.794860544 9052.115357 45 False \n", + " 983 2011-03-11 05:15:33.422215424 NaN 48 False \n", + " 984 2012-04-10 15:14:42.299491072 11469.698722 49 False \n", + " 985 2012-09-04 11:56:04.719035648 11985.730540 47 False \n", + " 986 2010-08-12 06:13:43.472920320 NaN 47 False \n", + " 987 2011-11-28 00:46:09.219015936 NaN 47 False \n", + " 988 2010-11-01 19:46:13.442969344 NaN 47 False \n", + " 989 2012-09-03 02:46:56.811711232 13448.706455 46 False \n", + " 990 2012-06-21 20:15:00.600609536 8515.644049 45 False \n", + " 991 2009-11-09 11:32:20.317559808 NaN 47 False \n", + " 992 2012-12-24 16:57:40.573141760 8785.735128 48 False \n", + " 993 2010-10-26 22:04:15.916646656 NaN 45 False \n", + " 994 2013-07-27 21:48:37.233461248 2218.444188 46 False \n", + " 995 2012-11-22 15:48:58.070053120 NaN 47 False \n", + " 996 2012-11-19 17:12:55.679613440 11318.264204 45 False \n", + " 997 2013-10-11 21:09:49.021492480 478.609430 47 False \n", + " 998 2013-01-23 23:17:20.493887232 14227.944357 49 False \n", + " 999 2012-06-23 00:01:59.594749696 9121.193148 46 False \n", + " \n", + " MarkDown4 MarkDown3 Fuel_Price Unemployment Temperature \\\n", + " 0 NaN NaN 3.246268 6.712596 72.778272 \n", + " 1 NaN NaN 2.961909 8.775273 31.911285 \n", + " 2 NaN NaN 2.474694 8.693548 36.996590 \n", + " 3 2990.430314 450.886030 3.180972 7.482533 82.689207 \n", + " 4 1300.901307 1038.142954 3.979948 3.842990 61.626723 \n", + " 5 NaN NaN 3.190374 9.185961 43.775213 \n", + " 6 5068.165070 4033.503158 3.436991 10.002234 44.980268 \n", + " 7 2639.207053 NaN 3.126477 8.921070 60.901225 \n", + " 8 -1515.448599 1859.931318 4.101750 10.478649 81.003354 \n", + " 9 NaN -4013.843645 3.200662 8.549170 46.332113 \n", + " 10 6983.561124 9246.461293 3.637923 6.940256 66.208533 \n", + " 11 NaN 1872.845519 3.535785 5.360209 76.842613 \n", + " 12 5967.615987 -2084.078449 3.314418 NaN 73.216988 \n", + " 13 4181.309429 -7382.929438 3.170917 7.209879 27.887139 \n", + " 14 NaN NaN 4.084598 10.976241 75.512964 \n", + " 15 6522.458652 -2969.500556 3.592465 6.778646 45.304799 \n", + " 16 NaN NaN 3.691225 9.169614 61.590909 \n", + " 17 2195.064778 10031.652043 3.790657 5.655427 38.507717 \n", + " 18 -2522.448316 4530.988378 3.349566 4.743290 36.629555 \n", + " 19 NaN 3416.997827 3.641454 8.177166 89.916452 \n", + " 20 NaN NaN 3.075063 5.229354 61.706563 \n", + " 21 4647.376520 3247.164388 4.162684 5.033966 37.962542 \n", + " 22 NaN NaN 3.384827 4.921576 54.039707 \n", + " 23 NaN 7039.242263 3.187745 9.768689 74.453345 \n", + " 24 NaN NaN 3.045024 6.954771 63.069450 \n", + " 25 NaN NaN 4.164028 8.876290 71.070532 \n", + " 26 NaN NaN 2.719009 7.230035 57.364249 \n", + " 27 NaN NaN 3.257133 9.345917 68.652152 \n", + " 28 NaN -4190.690106 3.548902 9.494165 89.145578 \n", + " 29 NaN NaN 3.316924 6.697705 77.375684 \n", + " 
.. ... ... ... ... ... \n", + " 970 2712.006496 -14.780688 3.618468 8.148930 30.483065 \n", + " 971 NaN NaN 3.012630 6.326113 47.975596 \n", + " 972 NaN NaN 2.841184 10.269877 81.946843 \n", + " 973 1937.104665 -1511.898278 4.759412 7.778287 30.662510 \n", + " 974 5253.689129 329.011494 3.343722 7.398721 21.129117 \n", + " 975 NaN 4930.780630 3.506178 9.701808 110.523151 \n", + " 976 5293.404197 -3121.976568 2.800094 5.843644 32.442600 \n", + " 977 9805.271519 NaN 3.136993 7.623560 72.572211 \n", + " 978 NaN NaN 2.714601 7.234621 52.291714 \n", + " 979 NaN NaN 2.043747 8.744450 46.628272 \n", + " 980 5415.328424 -4809.140910 3.287617 8.354515 61.384200 \n", + " 981 NaN NaN 3.337297 6.142083 36.829644 \n", + " 982 837.730915 4072.518501 3.894026 12.494251 94.673851 \n", + " 983 NaN NaN 3.420209 7.090230 47.694700 \n", + " 984 5134.594286 357.279124 3.325391 4.539604 37.460031 \n", + " 985 NaN 4056.157461 3.685698 7.230368 43.096815 \n", + " 986 3953.583585 805.676045 3.044291 7.884751 34.438372 \n", + " 987 4511.412982 NaN 3.394697 6.834802 36.636557 \n", + " 988 NaN NaN 3.154329 8.363893 49.751330 \n", + " 989 NaN 2084.191209 3.615679 9.596049 75.533229 \n", + " 990 4457.772828 218.985398 3.440906 8.526339 61.864859 \n", + " 991 NaN NaN 3.247303 7.423321 59.019784 \n", + " 992 NaN NaN 3.458168 6.459322 48.584436 \n", + " 993 NaN NaN 2.608670 10.259658 82.758016 \n", + " 994 719.333117 6631.998410 4.368128 7.588005 54.580465 \n", + " 995 3003.501922 NaN 3.800082 NaN 100.143258 \n", + " 996 NaN -681.708699 4.140327 10.717719 46.213930 \n", + " 997 735.529994 3856.959812 3.921162 7.439993 59.473702 \n", + " 998 2557.407176 2865.135441 3.971143 9.540309 67.952407 \n", + " 999 671.086509 2328.798014 3.487042 5.988222 63.846379 \n", + " \n", + " MarkDown5 MarkDown2 CPI \n", + " 0 NaN NaN 172.666543 \n", + " 1 NaN NaN 119.214662 \n", + " 2 NaN NaN 143.990866 \n", + " 3 4354.651570 5019.468642 180.631779 \n", + " 4 -4772.570563 2947.674138 173.813790 \n", + " 5 NaN NaN 177.052581 \n", + " 6 7083.723804 3513.601560 185.921725 \n", + " 7 NaN -180.888893 153.966086 \n", + " 8 6265.833733 2401.076738 170.559516 \n", + " 9 3759.801559 4729.764492 200.039874 \n", + " 10 6567.512481 -1022.527044 150.002644 \n", + " 11 7765.142655 NaN 221.169299 \n", + " 12 NaN 3308.681495 NaN \n", + " 13 2826.442468 8385.400703 150.777231 \n", + " 14 NaN NaN 211.070762 \n", + " 15 6642.564583 NaN 154.828636 \n", + " 16 NaN NaN 138.236111 \n", + " 17 4118.446321 NaN 138.853628 \n", + " 18 1115.215425 -1526.587205 219.621502 \n", + " 19 4994.790024 NaN 197.086249 \n", + " 20 NaN NaN 143.384357 \n", + " 21 9934.092076 6538.282816 208.758478 \n", + " 22 396.251330 1170.864752 217.807128 \n", + " 23 1492.226506 NaN 144.261943 \n", + " 24 NaN NaN 198.046465 \n", + " 25 NaN NaN 152.795139 \n", + " 26 NaN NaN 192.354502 \n", + " 27 1804.646340 NaN 213.883682 \n", + " 28 NaN NaN 237.955017 \n", + " 29 NaN NaN 140.091500 \n", + " .. ... ... ... 
\n", + " 970 -208.184350 4224.751927 149.278618 \n", + " 971 NaN 3814.404768 167.707496 \n", + " 972 NaN NaN 115.688179 \n", + " 973 1199.031857 3033.592607 120.703360 \n", + " 974 NaN -351.113110 106.143265 \n", + " 975 NaN NaN 211.109816 \n", + " 976 3975.213749 NaN 152.402312 \n", + " 977 NaN NaN 121.333484 \n", + " 978 NaN 3607.788296 176.680147 \n", + " 979 NaN NaN 164.924902 \n", + " 980 5431.606087 7005.735955 208.516910 \n", + " 981 -285.659902 NaN 146.865904 \n", + " 982 5584.625275 2016.204322 165.913603 \n", + " 983 NaN NaN 136.591286 \n", + " 984 2929.718397 11730.460834 153.875137 \n", + " 985 4618.823214 1218.336997 191.245195 \n", + " 986 NaN 128.922827 209.229013 \n", + " 987 7757.467609 NaN 201.908861 \n", + " 988 NaN NaN 210.724201 \n", + " 989 5115.758017 1620.240724 149.542233 \n", + " 990 1789.565202 1350.171000 193.810848 \n", + " 991 NaN NaN 147.307937 \n", + " 992 -105.978005 NaN 229.688010 \n", + " 993 NaN NaN 217.839779 \n", + " 994 -907.055933 3253.948843 128.459172 \n", + " 995 NaN NaN NaN \n", + " 996 3043.948542 NaN 114.508859 \n", + " 997 1459.810600 -4003.468566 193.149902 \n", + " 998 12151.265277 -3099.180177 150.590590 \n", + " 999 2792.671569 NaN 201.355736 \n", + " \n", + " [1000 rows x 12 columns]}" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sdv.sample('features', 1000)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "SDV has tools to evaluate the synthetic data generated and compare them with the original data. To see more about the evaluation tools, [see the documentation](https://sdv-dev.github.io/SDV/api/sdv.evaluation.html#sdv.evaluation.evaluate)." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tutorials/04_Working_with_Metadata.ipynb b/tutorials/04_Working_with_Metadata.ipynb new file mode 100644 index 000000000..5c79f8ebf --- /dev/null +++ b/tutorials/04_Working_with_Metadata.ipynb @@ -0,0 +1,447 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Working with Metadata\n", + "\n", + "In order to work with complex dataset structures you will need to pass additional information\n", + "about you data to SDV.\n", + "\n", + "This is done by using a Metadata.\n", + "\n", + "Let's have a quick look at how to do it.\n", + "\n", + "## Load demo data\n", + "\n", + "We will load the demo dataset included in SDV for the rest of the session." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from sdv import load_demo\n", + "\n", + "metadata, tables = load_demo(metadata=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A part from the tables dict, this is returning a Metadata object that contains all the information\n", + "about our dataset." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Metadata\n", + " root_path: /home/xals/Projects/MIT/SDV/tutorials\n", + " tables: ['users', 'sessions', 'transactions']\n", + " relationships:\n", + " sessions.user_id -> users.user_id\n", + " transactions.session_id -> sessions.session_id" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metadata" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This Metadata can also be represented by using a dict object:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'tables': {'users': {'primary_key': 'user_id',\n", + " 'fields': {'user_id': {'type': 'id', 'subtype': 'integer'},\n", + " 'country': {'type': 'categorical'},\n", + " 'gender': {'type': 'categorical'},\n", + " 'age': {'type': 'numerical', 'subtype': 'integer'}}},\n", + " 'sessions': {'primary_key': 'session_id',\n", + " 'fields': {'session_id': {'type': 'id', 'subtype': 'integer'},\n", + " 'user_id': {'ref': {'field': 'user_id', 'table': 'users'},\n", + " 'type': 'id',\n", + " 'subtype': 'integer'},\n", + " 'device': {'type': 'categorical'},\n", + " 'os': {'type': 'categorical'}}},\n", + " 'transactions': {'primary_key': 'transaction_id',\n", + " 'fields': {'transaction_id': {'type': 'id', 'subtype': 'integer'},\n", + " 'session_id': {'ref': {'field': 'session_id', 'table': 'sessions'},\n", + " 'type': 'id',\n", + " 'subtype': 'integer'},\n", + " 'timestamp': {'type': 'datetime', 'format': '%Y-%m-%d'},\n", + " 'amount': {'type': 'numerical', 'subtype': 'float'},\n", + " 'approved': {'type': 'boolean'}}}}}" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metadata.to_dict()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Creating a Metadata object from scratch\n", + "\n", + "In this section we will have a look at how to create a Metadata object from scratch.\n", + "\n", + "The simplest way to do it is by populating it passing the tables of your dataset together\n", + "with some additional information.\n", + "\n", + "Let's start by creating an empty metadata object." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "from sdv import Metadata\n", + "\n", + "new_meta = Metadata()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can start by adding the parent table from our dataset, `users`,\n", + "indicating that the primary key is the field called `user_id`." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "users_data = tables['users']\n", + "new_meta.add_table('users', data=users_data, primary_key='user_id')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, let's add the sessions table, indicating that:\n", + "- The primary key is the field `session_id`\n", + "- The `users` table is parent to this table\n", + "- The relationship between the `users` and `sessions` table is created by the field called `user_id`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "sessions_data = tables['sessions']\n", + "new_meta.add_table(\n", + " 'sessions',\n", + " data=sessions_data,\n", + " primary_key='session_id',\n", + " parent='users',\n", + " foreign_key='user_id'\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, let's add the transactions table.\n", + "\n", + "In this case, we will pass some additional information to indicate that\n", + "the `timestamp` field should be actually parsed and interpreted as a\n", + "datetime field." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "transactions_fields = {\n", + " 'timestamp': {\n", + " 'type': 'datetime',\n", + " 'format': '%Y-%m-%d'\n", + " }\n", + "}\n", + "transactions_data = tables['transactions']\n", + "new_meta.add_table(\n", + " 'transactions',\n", + " transactions_data,\n", + " fields_metadata=transactions_fields,\n", + " primary_key='transaction_id',\n", + " parent='sessions'\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's see what our Metadata looks like right now:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Metadata\n", + "\n", + "\n", + "\n", + "users\n", + "\n", + "users\n", + "\n", + "gender : categorical\n", + "age : numerical - integer\n", + "user_id : id - integer\n", + "country : categorical\n", + "\n", + "Primary key: user_id\n", + "\n", + "\n", + "\n", + "sessions\n", + "\n", + "sessions\n", + "\n", + "os : categorical\n", + "device : categorical\n", + "user_id : id - integer\n", + "session_id : id - integer\n", + "\n", + "Primary key: session_id\n", + "Foreign key (users): user_id\n", + "\n", + "\n", + "\n", + "users->sessions\n", + "\n", + "\n", + "   sessions.user_id -> users.user_id\n", + "\n", + "\n", + "\n", + "transactions\n", + "\n", + "transactions\n", + "\n", + "timestamp : datetime\n", + "amount : numerical - float\n", + "session_id : id - integer\n", + "approved : boolean\n", + "transaction_id : id - integer\n", + "\n", + "Primary key: transaction_id\n", + "Foreign key (sessions): session_id\n", + "\n", + "\n", + "\n", + "sessions->transactions\n", + "\n", + "\n", + "   transactions.session_id -> sessions.session_id\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "new_meta.visualize()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'tables': {'users': {'fields': {'gender': {'type': 'categorical'},\n", + " 'age': {'type': 'numerical', 'subtype': 'integer'},\n", + " 'user_id': {'type': 'id', 'subtype': 'integer'},\n", + " 'country': {'type': 'categorical'}},\n", + " 'primary_key': 'user_id'},\n", + " 'sessions': {'fields': {'os': {'type': 'categorical'},\n", + " 'device': {'type': 'categorical'},\n", + " 'user_id': {'type': 'id',\n", + " 'subtype': 'integer',\n", + " 'ref': {'table': 'users', 'field': 'user_id'}},\n", + " 'session_id': {'type': 'id', 'subtype': 'integer'}},\n", + " 'primary_key': 'session_id'},\n", + " 'transactions': {'fields': {'timestamp': {'type': 'datetime',\n", + " 'format': '%Y-%m-%d'},\n", + " 'amount': {'type': 'numerical', 'subtype': 'float'},\n", + " 
'session_id': {'type': 'id',\n", + " 'subtype': 'integer',\n", + " 'ref': {'table': 'sessions', 'field': 'session_id'}},\n", + " 'approved': {'type': 'boolean'},\n", + " 'transaction_id': {'type': 'id', 'subtype': 'integer'}},\n", + " 'primary_key': 'transaction_id'}}}" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "new_meta.to_dict()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Pretty similar to the original metadata, right?" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "new_meta.to_dict() == metadata.to_dict()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Saving the Metadata as a JSON file\n", + "\n", + "The Metadata object can also be saved as a JSON file, which we can load back later:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "new_meta.to_json('demo_metadata.json')" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "loaded = Metadata('demo_metadata.json')" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "loaded.to_dict() == new_meta.to_dict()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}
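As a possible next step, the metadata loaded from the JSON file can be used exactly like the original object, for example to fit an SDV instance on the demo tables; a hedged sketch reusing the `loaded` and `tables` objects from this tutorial:

```python
from sdv import SDV

# Fit SDV using the metadata restored from demo_metadata.json
sdv = SDV()
sdv.fit(loaded, tables=tables)

# Sample synthetic versions of all three tables
samples = sdv.sample_all()
```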