diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 12d63d205..02e92fd67 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -18,6 +18,7 @@ jobs: - name: Build run: | + sudo apt install pandoc python -m pip install --upgrade pip pip install -e .[dev] make docs diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 0f882e819..7357cc190 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -24,12 +24,12 @@ jobs: - if: matrix.os == 'ubuntu-latest' name: Install graphviz - Ubuntu run: | - sudo apt-get install graphviz + sudo apt-get install graphviz pandoc - if: matrix.os == 'macos-latest' name: Install graphviz - MacOS run: | - brew install graphviz + brew install graphviz pandoc - name: Install dependencies run: | diff --git a/.gitignore b/.gitignore index 241e9f48f..2e25954c9 100644 --- a/.gitignore +++ b/.gitignore @@ -106,3 +106,5 @@ ENV/ .*.swp sdv/data/ +tutorials/sdv.pkl +tutorials/demo_metadata.json diff --git a/.travis.yml b/.travis.yml index 8204e8dfc..8345682c4 100644 --- a/.travis.yml +++ b/.travis.yml @@ -9,7 +9,7 @@ python: # Command to install dependencies install: - sudo apt-get update - - sudo apt-get install graphviz + - sudo apt-get install graphviz pandoc - pip install -U tox-travis codecov after_success: codecov diff --git a/Makefile b/Makefile index 31f1d6d6f..5fc2f37f3 100644 --- a/Makefile +++ b/Makefile @@ -50,6 +50,7 @@ clean-pyc: ## remove Python file artifacts .PHONY: clean-docs clean-docs: ## remove previously built docs rm -f docs/api/*.rst + rm -rf docs/tutorials -$(MAKE) -C docs clean 2>/dev/null # this fails if sphinx is not yet installed .PHONY: clean-coverage @@ -110,7 +111,7 @@ test-readme: ## run the readme snippets .PHONY: test-tutorials test-tutorials: ## run the tutorial notebooks - jupyter nbconvert --execute --ExecutePreprocessor.timeout=600 examples/?.\ *.ipynb --stdout > /dev/null + jupyter nbconvert --execute --ExecutePreprocessor.timeout=600 tutorials/*.ipynb --stdout > /dev/null .PHONY: test test: test-unit test-readme test-tutorials ## test everything that needs test dependencies @@ -134,6 +135,7 @@ coverage: ## check code coverage quickly with the default Python .PHONY: docs docs: clean-docs ## generate Sphinx HTML documentation, including API docs + cp -r tutorials docs/tutorials sphinx-apidoc --separate --no-toc -o docs/api/ sdv $(MAKE) -C docs html diff --git a/docs/conf.py b/docs/conf.py index 1c7e3ea90..c99d64607 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -32,6 +32,7 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. extensions = [ 'm2r', + 'nbsphinx', 'sphinx.ext.autodoc', 'sphinx.ext.githubpages', 'sphinx.ext.viewcode', @@ -53,6 +54,9 @@ # The master toctree document. master_doc = 'index' +# Jupyter Notebooks +nbsphinx_execute = 'never' + # General information about the project. project = 'SDV' slug = 'sdv' diff --git a/docs/index.rst b/docs/index.rst index 4f136e234..b79a5add0 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -7,9 +7,11 @@ Overview .. toctree:: - :caption: Advanced Usage - :maxdepth: 3 + :caption: User Guides + :maxdepth: 2 + tutorials/02_Single_Table_Modeling + tutorials/03_Relational_Data_Modeling metadata .. 
toctree:: diff --git a/docs/metadata.rst b/docs/metadata.rst index c9f4d8607..d7f7dd933 100644 --- a/docs/metadata.rst +++ b/docs/metadata.rst @@ -1,5 +1,5 @@ -Metadata -======== +Working with Metadata +===================== In order to use **SDV** you will need a ``Metadata`` object alongside your data. diff --git a/docs/tutorials/01_Quickstart.ipynb b/docs/tutorials/01_Quickstart.ipynb new file mode 100644 index 000000000..5eb6cd8e8 --- /dev/null +++ b/docs/tutorials/01_Quickstart.ipynb @@ -0,0 +1,838 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Quickstart\n", + "\n", + "In this short tutorial we will guide you through a series of steps that will help you\n", + "get started using **SDV**.\n", + "\n", + "## 1. Model the dataset using SDV\n", + "\n", + "To model a multi-table, relational dataset, we follow two steps. In the first step, we will load\n", + "the data and configure the metadata. In the second step, we will use the SDV API to fit and\n", + "save a hierarchical model. We will cover these two steps in this section using an example dataset.\n", + "\n", + "### Step 1: Load example data\n", + "\n", + "**SDV** comes with a toy dataset to play with, which can be loaded using the `sdv.load_demo`\n", + "function:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from sdv import load_demo\n", + "\n", + "metadata, tables = load_demo(metadata=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This will return two objects:\n", + "\n", + "1. A `Metadata` object with all the information that **SDV** needs to know about the dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Metadata\n", + " root_path: /home/xals/Projects/MIT/SDV/tutorials\n", + " tables: ['users', 'sessions', 'transactions']\n", + " relationships:\n", + " sessions.user_id -> users.user_id\n", + " transactions.session_id -> sessions.session_id" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metadata" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "[SVG markup stripped during extraction. The diagram shows the three tables with their fields and keys: users (user_id : id - integer, primary key; country : categorical; gender : categorical; age : numerical - integer), sessions (session_id : id - integer, primary key; user_id : id - integer, foreign key to users; device : categorical; os : categorical) and transactions (transaction_id : id - integer, primary key; session_id : id - integer, foreign key to sessions; timestamp : datetime; amount : numerical - float; approved : boolean), linked by sessions.user_id -> users.user_id and transactions.session_id -> sessions.session_id.]\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": 
"execute_result" + } + ], + "source": [ + "metadata.visualize()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For more details about how to build the `Metadata` for your own dataset, please refer to the\n", + "[Metadata](https://sdv-dev.github.io/SDV/metadata.html) section of the documentation." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "2. A dictionary containing three `pandas.DataFrames` with the tables described in the\n", + "metadata object." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'users': user_id country gender age\n", + " 0 0 USA M 34\n", + " 1 1 UK F 23\n", + " 2 2 ES None 44\n", + " 3 3 UK M 22\n", + " 4 4 USA F 54\n", + " 5 5 DE M 57\n", + " 6 6 BG F 45\n", + " 7 7 ES None 41\n", + " 8 8 FR F 23\n", + " 9 9 UK None 30,\n", + " 'sessions': session_id user_id device os\n", + " 0 0 0 mobile android\n", + " 1 1 1 tablet ios\n", + " 2 2 1 tablet android\n", + " 3 3 2 mobile android\n", + " 4 4 4 mobile ios\n", + " 5 5 5 mobile android\n", + " 6 6 6 mobile ios\n", + " 7 7 6 tablet ios\n", + " 8 8 6 mobile ios\n", + " 9 9 8 tablet ios,\n", + " 'transactions': transaction_id session_id timestamp amount approved\n", + " 0 0 0 2019-01-01 12:34:32 100.0 True\n", + " 1 1 0 2019-01-01 12:42:21 55.3 True\n", + " 2 2 1 2019-01-07 17:23:11 79.5 True\n", + " 3 3 3 2019-01-10 11:08:57 112.1 False\n", + " 4 4 5 2019-01-10 21:54:08 110.0 False\n", + " 5 5 5 2019-01-11 11:21:20 76.3 True\n", + " 6 6 7 2019-01-22 14:44:10 89.5 True\n", + " 7 7 8 2019-01-23 10:14:09 132.1 False\n", + " 8 8 9 2019-01-27 16:09:17 68.0 True\n", + " 9 9 9 2019-01-29 12:10:48 99.9 True}" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tables" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. Fit a model using the SDV API.\n", + "\n", + "First, we build a hierarchical statistical model of the data using **SDV**. For this we will\n", + "create an instance of the `sdv.SDV` class and use its `fit` method.\n", + "\n", + "During this process, **SDV** will traverse across all the tables in your dataset following the\n", + "primary key-foreign key relationships and learn the probability distributions of the values in\n", + "the columns." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-07-09 20:57:19,919 - INFO - modeler - Modeling users\n", + "2020-07-09 20:57:19,920 - INFO - __init__ - Loading transformer CategoricalTransformer for field country\n", + "2020-07-09 20:57:19,920 - INFO - __init__ - Loading transformer CategoricalTransformer for field gender\n", + "2020-07-09 20:57:19,921 - INFO - __init__ - Loading transformer NumericalTransformer for field age\n", + "2020-07-09 20:57:19,933 - INFO - modeler - Modeling sessions\n", + "2020-07-09 20:57:19,934 - INFO - __init__ - Loading transformer CategoricalTransformer for field device\n", + "2020-07-09 20:57:19,934 - INFO - __init__ - Loading transformer CategoricalTransformer for field os\n", + "2020-07-09 20:57:19,944 - INFO - modeler - Modeling transactions\n", + "2020-07-09 20:57:19,944 - INFO - __init__ - Loading transformer DatetimeTransformer for field timestamp\n", + "2020-07-09 20:57:19,944 - INFO - __init__ - Loading transformer NumericalTransformer for field amount\n", + "2020-07-09 20:57:19,945 - INFO - __init__ - Loading transformer BooleanTransformer for field approved\n", + "2020-07-09 20:57:19,954 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:19,962 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/core/fromnumeric.py:58: FutureWarning: Series.nonzero() is deprecated and will be removed in a future version.Use Series.to_numpy().nonzero() instead\n", + " return bound(*args, **kwds)\n", + "/home/xals/Projects/MIT/SDV/sdv/models/copulas.py:83: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. 
If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n", + " self.model.covariance = np.array(values)\n", + "2020-07-09 20:57:19,968 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/pandas/core/frame.py:7143: RuntimeWarning: Degrees of freedom <= 0 for slice\n", + " baseCov = np.cov(mat.T)\n", + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/lib/function_base.py:2480: RuntimeWarning: divide by zero encountered in true_divide\n", + " c *= np.true_divide(1, fact)\n", + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/lib/function_base.py:2480: RuntimeWarning: invalid value encountered in multiply\n", + " c *= np.true_divide(1, fact)\n", + "2020-07-09 20:57:19,974 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:19,979 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:19,985 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:19,989 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:19,994 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,007 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,062 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,074 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,092 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,104 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,117 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,129 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,146 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,208 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,294 - INFO - modeler - Modeling Complete\n" + ] + } + ], + "source": [ + "from sdv import SDV\n", + "\n", + "sdv = SDV()\n", + "sdv.fit(metadata, tables)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. 
Sample data from the fitted model\n", + "\n", + "Once the modeling has finished you are ready to generate new synthetic data using the `sdv` instance that you have.\n", + "\n", + "For this, all you have to do is call the `sample_all` method from your instance passing the number of rows that you want to generate:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "sampled = sdv.sample_all(10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This will return a dictionary identical to the `tables` one that we passed to the SDV instance for learning, filled in with new synthetic data.\n", + "\n", + "**Note** that only the parent tables of your dataset will have the specified number of rows,\n", + "as the number of child rows that each row in the parent table has is also sampled following\n", + "the original distribution of your dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_idcountrygenderage
00UKNaN39
11UKF27
22UKNaN50
33USAF29
44UKNaN21
55ESNaN27
66UKF20
77USAF15
88DEF41
99ESF35
\n", + "
" + ], + "text/plain": [ + " user_id country gender age\n", + "0 0 UK NaN 39\n", + "1 1 UK F 27\n", + "2 2 UK NaN 50\n", + "3 3 USA F 29\n", + "4 4 UK NaN 21\n", + "5 5 ES NaN 27\n", + "6 6 UK F 20\n", + "7 7 USA F 15\n", + "8 8 DE F 41\n", + "9 9 ES F 35" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sampled['users']" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
session_iduser_iddeviceos
001mobileandroid
112tabletios
222tabletios
332mobileandroid
443mobileandroid
555mobileandroid
665mobileandroid
775tabletios
886tabletios
996mobileandroid
\n", + "
" + ], + "text/plain": [ + " session_id user_id device os\n", + "0 0 1 mobile android\n", + "1 1 2 tablet ios\n", + "2 2 2 tablet ios\n", + "3 3 2 mobile android\n", + "4 4 3 mobile android\n", + "5 5 5 mobile android\n", + "6 6 5 mobile android\n", + "7 7 5 tablet ios\n", + "8 8 6 tablet ios\n", + "9 9 6 mobile android" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sampled['sessions'].head(10)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
transaction_idsession_idtimestampamountapproved
0002019-01-13 01:11:06.08427596880.350553True
1102019-01-13 01:11:06.36870425682.116470True
2212019-01-25 15:05:07.73526297697.616590True
3312019-01-25 15:05:09.09895884896.505258True
4422019-01-25 15:05:09.12423142497.235452True
5522019-01-25 15:05:07.34731545696.615923True
6632019-01-25 15:05:09.19044198497.173195True
7732019-01-25 15:05:07.43791078496.995037True
8802019-01-13 01:11:06.56094003281.027625True
9912019-01-25 15:05:12.42074649697.340054True
\n", + "
" + ], + "text/plain": [ + " transaction_id session_id timestamp amount \\\n", + "0 0 0 2019-01-13 01:11:06.084275968 80.350553 \n", + "1 1 0 2019-01-13 01:11:06.368704256 82.116470 \n", + "2 2 1 2019-01-25 15:05:07.735262976 97.616590 \n", + "3 3 1 2019-01-25 15:05:09.098958848 96.505258 \n", + "4 4 2 2019-01-25 15:05:09.124231424 97.235452 \n", + "5 5 2 2019-01-25 15:05:07.347315456 96.615923 \n", + "6 6 3 2019-01-25 15:05:09.190441984 97.173195 \n", + "7 7 3 2019-01-25 15:05:07.437910784 96.995037 \n", + "8 8 0 2019-01-13 01:11:06.560940032 81.027625 \n", + "9 9 1 2019-01-25 15:05:12.420746496 97.340054 \n", + "\n", + " approved \n", + "0 True \n", + "1 True \n", + "2 True \n", + "3 True \n", + "4 True \n", + "5 True \n", + "6 True \n", + "7 True \n", + "8 True \n", + "9 True " + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sampled['transactions'].head(10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Saving and Loading your model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In some cases, you might want to save the fitted SDV instance to be able to generate synthetic data from\n", + "it later or on a different system.\n", + "\n", + "In order to do so, you can save your fitted `SDV` instance for later usage using the `save` method of your\n", + "instance." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "sdv.save('sdv.pkl')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The generated `pkl` file will not include any of the original data in it, so it can be\n", + "safely sent to where the synthetic data will be generated without any privacy concerns." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Later on, in order to sample data from the fitted model, we will first need to load it from its\n", + "`pkl` file." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "sdv = SDV.load('sdv.pkl')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After loading the instance, we can sample synthetic data using its `sample_all` method like before." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "sampled = sdv.sample_all(5)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/tutorials/02_Single_Table_Modeling.ipynb b/docs/tutorials/02_Single_Table_Modeling.ipynb new file mode 100644 index 000000000..fa536bd3d --- /dev/null +++ b/docs/tutorials/02_Single_Table_Modeling.ipynb @@ -0,0 +1,1217 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Single Table Modeling\n", + "\n", + "**SDV** has special support for modeling single table datasets using a variety of models.\n", + "\n", + "Currently, SDV implements:\n", + "\n", + "* GaussianCopula: A tool to model multivariate distributions using [copula functions](https://en.wikipedia.org/wiki/Copula_%28probability_theory%29). 
Based on our [Copulas Library](https://github.com/sdv-dev/Copulas).\n", + "* CTGAN: A GAN-based Deep Learning data synthesizer that can generate synthetic tabular data with high fidelity. Based on our [CTGAN Library](https://github.com/sdv-dev/CTGAN)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## GaussianCopula\n", + "\n", + "In this first part of the tutorial we will be using the GaussianCopula class to model the `users` table\n", + "from the toy dataset included in the **SDV** library.\n", + "\n", + "### 1. Load the Data" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from sdv import load_demo\n", + "\n", + "users = load_demo()['users']" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This will return a table with 4 fields:\n", + "\n", + "* `user_id`: A unique identifier of the user.\n", + "* `country`: A 2-letter code of the country of residence of the user.\n", + "* `gender`: A single-letter code, `M` or `F`, indicating the user's gender. Note that this demo simulates the case where some users did not indicate their gender, which resulted in empty values in some rows.\n", + "* `age`: The age of the user, in years." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_idcountrygenderage
00USAM34
11UKF23
22ESNone44
33UKM22
44USAF54
55DEM57
66BGF45
77ESNone41
88FRF23
99UKNone30
\n", + "
" + ], + "text/plain": [ + " user_id country gender age\n", + "0 0 USA M 34\n", + "1 1 UK F 23\n", + "2 2 ES None 44\n", + "3 3 UK M 22\n", + "4 4 USA F 54\n", + "5 5 DE M 57\n", + "6 6 BG F 45\n", + "7 7 ES None 41\n", + "8 8 FR F 23\n", + "9 9 UK None 30" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "users" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. Prepare the model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In order to properly model our data we will need to provide some additional information to our model,\n", + "so let's prepare this information in some variables." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, let's indicate that the `user_id` field in our table is the primary key, so we do not want our\n", + "model to attempt to learn it." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "primary_key = 'user_id'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will also want to anonymize the countries of residence of our users, to avoid disclosing such information.\n", + "Let's make a variable indicating that the `country` field needs to be anonymized using fake `country_codes`." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "anonymize_fileds = {\n", + " 'country': 'contry_code'\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The full list of categories supported corresponds to the `Faker` library\n", + "[provider names](https://faker.readthedocs.io/en/master/providers.html)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Once we have prepared the arguments for our model we are ready to import it, create an instance\n", + "and fit it to our data." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-07-09 21:18:32,974 - INFO - table - Loading transformer CategoricalTransformer for field country\n", + "2020-07-09 21:18:32,975 - INFO - table - Loading transformer CategoricalTransformer for field gender\n", + "2020-07-09 21:18:32,975 - INFO - table - Loading transformer NumericalTransformer for field age\n", + "2020-07-09 21:18:32,991 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n" + ] + } + ], + "source": [ + "from sdv.tabular import GaussianCopula\n", + "\n", + "model = GaussianCopula(\n", + " primary_key=primary_key,\n", + " anonymize_fileds=anonymize_fileds\n", + ")\n", + "model.fit(users)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Notice** how the model took care of transforming the different fields using the appropriate\n", + "Reversible Data Transforms to ensure that the data has a format that the GaussianMultivariate model\n", + "from the [copulas](https://github.com/sdv-dev/Copulas) library can handle." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. Sample data from the fitted model\n", + "\n", + "Once the modeling has finished you are ready to generate new synthetic data by calling the `sample` method\n", + "from our model." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "sampled = model.sample()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This will return a table identical to the one which the model was fitted on, but filled with new data\n", + "which resembles the original one." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_idcountrygenderage
00USAM38
11UKNaN23
22USAF34
33ESNaN47
44ESF29
55UKF39
66FRNaN40
77ESM38
88ESF32
99ESF36
\n", + "
" + ], + "text/plain": [ + " user_id country gender age\n", + "0 0 USA M 38\n", + "1 1 UK NaN 23\n", + "2 2 USA F 34\n", + "3 3 ES NaN 47\n", + "4 4 ES F 29\n", + "5 5 UK F 39\n", + "6 6 FR NaN 40\n", + "7 7 ES M 38\n", + "8 8 ES F 32\n", + "9 9 ES F 36" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sampled" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "scrolled": false + }, + "source": [ + "Notice, as well that the number of rows generated by default corresponds to the number of rows that\n", + "the original table had, but that this number can be changed by simply passing it:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_idcountrygenderage
00UKF48
11USANaN38
22USAM29
33BGM22
44USAM43
\n", + "
" + ], + "text/plain": [ + " user_id country gender age\n", + "0 0 UK F 48\n", + "1 1 USA NaN 38\n", + "2 2 USA M 29\n", + "3 3 BG M 22\n", + "4 4 USA M 43" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.sample(5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## CTGAN\n", + "\n", + "In this second part of the tutorial we will be using the CTGAN model to learn the data from the\n", + "demo dataset called `census`, which is based on the [UCI Adult Census Dataset]('https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data').\n", + "\n", + "### 1. Load the Data" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-07-09 21:18:33,085 - INFO - __init__ - Loading table census\n" + ] + } + ], + "source": [ + "from sdv import load_demo\n", + "\n", + "census = load_demo('census')['census']" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This will return a table with several rows of multiple data types:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ageworkclassfnlwgteducationeducation-nummarital-statusoccupationrelationshipracesexcapital-gaincapital-losshours-per-weeknative-countryincome
039State-gov77516Bachelors13Never-marriedAdm-clericalNot-in-familyWhiteMale2174040United-States<=50K
150Self-emp-not-inc83311Bachelors13Married-civ-spouseExec-managerialHusbandWhiteMale0013United-States<=50K
238Private215646HS-grad9DivorcedHandlers-cleanersNot-in-familyWhiteMale0040United-States<=50K
353Private23472111th7Married-civ-spouseHandlers-cleanersHusbandBlackMale0040United-States<=50K
428Private338409Bachelors13Married-civ-spouseProf-specialtyWifeBlackFemale0040Cuba<=50K
\n", + "
" + ], + "text/plain": [ + " age workclass fnlwgt education education-num \\\n", + "0 39 State-gov 77516 Bachelors 13 \n", + "1 50 Self-emp-not-inc 83311 Bachelors 13 \n", + "2 38 Private 215646 HS-grad 9 \n", + "3 53 Private 234721 11th 7 \n", + "4 28 Private 338409 Bachelors 13 \n", + "\n", + " marital-status occupation relationship race sex \\\n", + "0 Never-married Adm-clerical Not-in-family White Male \n", + "1 Married-civ-spouse Exec-managerial Husband White Male \n", + "2 Divorced Handlers-cleaners Not-in-family White Male \n", + "3 Married-civ-spouse Handlers-cleaners Husband Black Male \n", + "4 Married-civ-spouse Prof-specialty Wife Black Female \n", + "\n", + " capital-gain capital-loss hours-per-week native-country income \n", + "0 2174 0 40 United-States <=50K \n", + "1 0 0 13 United-States <=50K \n", + "2 0 0 40 United-States <=50K \n", + "3 0 0 40 United-States <=50K \n", + "4 0 0 40 Cuba <=50K " + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "census.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. Prepare the model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this case there is no primary key to setup and we will not be anonymizing anything, so the only\n", + "thing that we will pass to the CTGAN model is the number of epochs that we want it to perform when\n", + "it leanrs the data, which we will keep low to make this execution quick." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/sklearn/utils/deprecation.py:144: FutureWarning: The sklearn.utils.testing module is deprecated in version 0.22 and will be removed in version 0.24. The corresponding classes / functions should instead be imported from sklearn.utils. Anything that cannot be imported from sklearn.utils is now part of the private API.\n", + " warnings.warn(message, FutureWarning)\n" + ] + } + ], + "source": [ + "from sdv.tabular import CTGAN\n", + "\n", + "model = CTGAN(epochs=10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Once the instance is created, we can fit it to our data. Bear in mind that this process might take some\n", + "time to finish, especially on non-GPU enabled systems, so in this case we will be passing only a\n", + "subsample of the data to accelerate the process." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-07-09 21:18:33,488 - INFO - table - Loading transformer NumericalTransformer for field age\n", + "2020-07-09 21:18:33,489 - INFO - table - Loading transformer LabelEncodingTransformer for field workclass\n", + "2020-07-09 21:18:33,489 - INFO - table - Loading transformer NumericalTransformer for field fnlwgt\n", + "2020-07-09 21:18:33,490 - INFO - table - Loading transformer LabelEncodingTransformer for field education\n", + "2020-07-09 21:18:33,490 - INFO - table - Loading transformer NumericalTransformer for field education-num\n", + "2020-07-09 21:18:33,490 - INFO - table - Loading transformer LabelEncodingTransformer for field marital-status\n", + "2020-07-09 21:18:33,491 - INFO - table - Loading transformer LabelEncodingTransformer for field occupation\n", + "2020-07-09 21:18:33,491 - INFO - table - Loading transformer LabelEncodingTransformer for field relationship\n", + "2020-07-09 21:18:33,491 - INFO - table - Loading transformer LabelEncodingTransformer for field race\n", + "2020-07-09 21:18:33,492 - INFO - table - Loading transformer LabelEncodingTransformer for field sex\n", + "2020-07-09 21:18:33,492 - INFO - table - Loading transformer NumericalTransformer for field capital-gain\n", + "2020-07-09 21:18:33,493 - INFO - table - Loading transformer NumericalTransformer for field capital-loss\n", + "2020-07-09 21:18:33,493 - INFO - table - Loading transformer NumericalTransformer for field hours-per-week\n", + "2020-07-09 21:18:33,493 - INFO - table - Loading transformer LabelEncodingTransformer for field native-country\n", + "2020-07-09 21:18:33,494 - INFO - table - Loading transformer LabelEncodingTransformer for field income\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1, Loss G: 1.9512, Loss D: -0.0182\n", + "Epoch 2, Loss G: 1.9884, Loss D: -0.0663\n", + "Epoch 3, Loss G: 1.9710, Loss D: -0.1339\n", + "Epoch 4, Loss G: 1.8960, Loss D: -0.2061\n", + "Epoch 5, Loss G: 1.9155, Loss D: -0.3062\n", + "Epoch 6, Loss G: 1.9699, Loss D: -0.3906\n", + "Epoch 7, Loss G: 1.8614, Loss D: -0.5142\n", + "Epoch 8, Loss G: 1.8446, Loss D: -0.6448\n", + "Epoch 9, Loss G: 1.7619, Loss D: -0.7488\n", + "Epoch 10, Loss G: 1.6732, Loss D: -0.7961\n" + ] + } + ], + "source": [ + "model.fit(census.sample(1000))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. Sample data from the fitted model\n", + "\n", + "Once the modeling has finished you are ready to generate new synthetic data by calling the `sample` method\n", + "from our model just like we did with the GaussianCopula model." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "sampled = model.sample()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This will return a table identical to the one which the model was fitted on, but filled with new data\n", + "which resembles the original one." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ageworkclassfnlwgteducationeducation-nummarital-statusoccupationrelationshipracesexcapital-gaincapital-losshours-per-weeknative-countryincome
050Local-gov1697191st-4th9Widowed?HusbandWhiteMale114838Columbia<=50K
132?1524791st-4th9Never-marriedAdm-clericalWifeBlackMale-422021Jamaica>50K
222Private69617Bachelors0Separated?HusbandWhiteMale61138Guatemala<=50K
325?65285810th16Married-civ-spouseHandlers-cleanersNot-in-familyWhiteFemale152-2739Cuba<=50K
443Private301956Some-college8Married-civ-spouse?WifeWhiteMale-133-1239India<=50K
566Private401171Prof-school13SeparatedProtective-servUnmarriedBlackFemale-124-140Cuba<=50K
652Private278399Bachelors12Never-marriedProf-specialtyUnmarriedOtherMale122567-647Columbia<=50K
736Federal-gov229817HS-grad8Married-AF-spouseFarming-fishingNot-in-familyWhiteMale81938Portugal>50K
827Federal-gov306972Some-college8Never-marriedExec-managerialHusbandAsian-Pac-IslanderFemale42144339Japan>50K
928Local-gov4161611st-4th8DivorcedAdm-clericalUnmarriedWhiteFemale-349109061Guatemala>50K
\n", + "
" + ], + "text/plain": [ + " age workclass fnlwgt education education-num \\\n", + "0 50 Local-gov 169719 1st-4th 9 \n", + "1 32 ? 152479 1st-4th 9 \n", + "2 22 Private 69617 Bachelors 0 \n", + "3 25 ? 652858 10th 16 \n", + "4 43 Private 301956 Some-college 8 \n", + "5 66 Private 401171 Prof-school 13 \n", + "6 52 Private 278399 Bachelors 12 \n", + "7 36 Federal-gov 229817 HS-grad 8 \n", + "8 27 Federal-gov 306972 Some-college 8 \n", + "9 28 Local-gov 416161 1st-4th 8 \n", + "\n", + " marital-status occupation relationship \\\n", + "0 Widowed ? Husband \n", + "1 Never-married Adm-clerical Wife \n", + "2 Separated ? Husband \n", + "3 Married-civ-spouse Handlers-cleaners Not-in-family \n", + "4 Married-civ-spouse ? Wife \n", + "5 Separated Protective-serv Unmarried \n", + "6 Never-married Prof-specialty Unmarried \n", + "7 Married-AF-spouse Farming-fishing Not-in-family \n", + "8 Never-married Exec-managerial Husband \n", + "9 Divorced Adm-clerical Unmarried \n", + "\n", + " race sex capital-gain capital-loss hours-per-week \\\n", + "0 White Male 114 8 38 \n", + "1 Black Male -42 20 21 \n", + "2 White Male 6 11 38 \n", + "3 White Female 152 -27 39 \n", + "4 White Male -133 -12 39 \n", + "5 Black Female -124 -1 40 \n", + "6 Other Male 122567 -6 47 \n", + "7 White Male 8 19 38 \n", + "8 Asian-Pac-Islander Female 42144 3 39 \n", + "9 White Female -349 1090 61 \n", + "\n", + " native-country income \n", + "0 Columbia <=50K \n", + "1 Jamaica >50K \n", + "2 Guatemala <=50K \n", + "3 Cuba <=50K \n", + "4 India <=50K \n", + "5 Cuba <=50K \n", + "6 Columbia <=50K \n", + "7 Portugal >50K \n", + "8 Japan >50K \n", + "9 Guatemala >50K " + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sampled.head(10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 4. Evaluate how good the data is\n", + "\n", + "Finally, we will use the evaluation framework included in SDV to obtain a metric of how\n", + "similar the sampled data is to the original one.\n", + "\n", + "For this, we will simply import the `sdv.evaluation.evaluate` function and pass both\n", + "the synthetic and the real data to it." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "-144.971907591418" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import warnings\n", + "warnings.filterwarnings('ignore')\n", + "\n", + "from sdv.evaluation import evaluate\n", + "\n", + "evaluate(sampled, census)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/Demo - Walmart.ipynb b/docs/tutorials/03_Relational_Data_Modeling.ipynb similarity index 55% rename from examples/Demo - Walmart.ipynb rename to docs/tutorials/03_Relational_Data_Modeling.ipynb index 2e3303259..230061b00 100644 --- a/examples/Demo - Walmart.ipynb +++ b/docs/tutorials/03_Relational_Data_Modeling.ipynb @@ -4,7 +4,11 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Demo \"walmart\"\n", + "# Relational Data Modeling\n", + "\n", + "In this tutorial we will be showing how to model a real world multi-table dataset using SDV.\n", + "\n", + "## About the datset\n", "\n", "We have a store series, each of those have a size and a category and additional information in a given date: average temperature in the region, cost of fuel in the region, promotional data, the customer price index, the unemployment rate and whether the date is a special holiday.\n", "\n", @@ -91,7 +95,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": { "scrolled": true }, @@ -100,9 +104,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "2020-07-02 18:50:09,230 - INFO - metadata - Loading table stores\n", - "2020-07-02 18:50:09,237 - INFO - metadata - Loading table features\n", - "2020-07-02 18:50:09,255 - INFO - metadata - Loading table depts\n" + "2020-07-09 21:00:17,378 - INFO - __init__ - Loading table stores\n", + "2020-07-09 21:00:17,384 - INFO - __init__ - Loading table features\n", + "2020-07-09 21:00:17,402 - INFO - __init__ - Loading table depts\n" ] } ], @@ -128,7 +132,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -184,7 +188,7 @@ "\n", "stores->features\n", "\n", - "\n", + "\n", "   features.Store -> stores.Store\n", "\n", "\n", @@ -206,17 +210,17 @@ "\n", "stores->depts\n", "\n", - "\n", + "\n", "   depts.Store -> stores.Store\n", "\n", "\n", "\n" ], "text/plain": [ - "" + "" ] }, - "execution_count": 2, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -234,7 +238,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -252,7 +256,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": { "scrolled": false }, @@ -261,128 +265,130 @@ "name": "stderr", "output_type": "stream", "text": [ - "2020-07-02 18:50:09,639 - INFO - modeler - Modeling stores\n", - "2020-07-02 18:50:09,640 - INFO - metadata - Loading transformer CategoricalTransformer for field Type\n", - "2020-07-02 18:50:09,640 - INFO - metadata - Loading transformer NumericalTransformer for field Size\n", - "2020-07-02 18:50:09,653 - INFO - 
modeler - Modeling features\n", - "2020-07-02 18:50:09,653 - INFO - metadata - Loading transformer DatetimeTransformer for field Date\n", - "2020-07-02 18:50:09,654 - INFO - metadata - Loading transformer NumericalTransformer for field MarkDown1\n", - "2020-07-02 18:50:09,654 - INFO - metadata - Loading transformer BooleanTransformer for field IsHoliday\n", - "2020-07-02 18:50:09,654 - INFO - metadata - Loading transformer NumericalTransformer for field MarkDown4\n", - "2020-07-02 18:50:09,655 - INFO - metadata - Loading transformer NumericalTransformer for field MarkDown3\n", - "2020-07-02 18:50:09,656 - INFO - metadata - Loading transformer NumericalTransformer for field Fuel_Price\n", - "2020-07-02 18:50:09,656 - INFO - metadata - Loading transformer NumericalTransformer for field Unemployment\n", - "2020-07-02 18:50:09,657 - INFO - metadata - Loading transformer NumericalTransformer for field Temperature\n", - "2020-07-02 18:50:09,657 - INFO - metadata - Loading transformer NumericalTransformer for field MarkDown5\n", - "2020-07-02 18:50:09,657 - INFO - metadata - Loading transformer NumericalTransformer for field MarkDown2\n", - "2020-07-02 18:50:09,659 - INFO - metadata - Loading transformer NumericalTransformer for field CPI\n", - "2020-07-02 18:50:09,709 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:09,760 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:09,791 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:09,817 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:09,845 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:09,871 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:09,901 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:09,929 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:09,958 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:09,985 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,016 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,045 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,076 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,131 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,158 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,184 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,211 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,239 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,268 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,297 - INFO - gaussian - Fitting 
GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,325 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,352 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,378 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,410 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,440 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,467 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,494 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,519 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,547 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,572 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,602 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,628 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,656 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,686 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,713 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,740 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,768 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,799 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,827 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,858 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,886 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,912 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,939 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,967 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:10,996 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,028 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,058 - INFO - modeler - Modeling depts\n", - "2020-07-02 18:50:11,058 - INFO - metadata - Loading transformer DatetimeTransformer for field Date\n", - "2020-07-02 18:50:11,059 - INFO - metadata - Loading transformer NumericalTransformer for field Weekly_Sales\n", - "2020-07-02 18:50:11,059 - INFO - metadata - Loading transformer NumericalTransformer for field Dept\n", - "2020-07-02 18:50:11,059 - INFO - metadata - Loading 
transformer BooleanTransformer for field IsHoliday\n", - "2020-07-02 18:50:11,169 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,445 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,461 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,474 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,489 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,505 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,521 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,536 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,552 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,568 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,583 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,603 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n" + "2020-07-09 21:00:31,480 - INFO - modeler - Modeling stores\n", + "2020-07-09 21:00:31,481 - INFO - __init__ - Loading transformer CategoricalTransformer for field Type\n", + "2020-07-09 21:00:31,481 - INFO - __init__ - Loading transformer NumericalTransformer for field Size\n", + "2020-07-09 21:00:31,491 - INFO - modeler - Modeling features\n", + "2020-07-09 21:00:31,492 - INFO - __init__ - Loading transformer DatetimeTransformer for field Date\n", + "2020-07-09 21:00:31,493 - INFO - __init__ - Loading transformer NumericalTransformer for field MarkDown1\n", + "2020-07-09 21:00:31,493 - INFO - __init__ - Loading transformer BooleanTransformer for field IsHoliday\n", + "2020-07-09 21:00:31,493 - INFO - __init__ - Loading transformer NumericalTransformer for field MarkDown4\n", + "2020-07-09 21:00:31,494 - INFO - __init__ - Loading transformer NumericalTransformer for field MarkDown3\n", + "2020-07-09 21:00:31,494 - INFO - __init__ - Loading transformer NumericalTransformer for field Fuel_Price\n", + "2020-07-09 21:00:31,494 - INFO - __init__ - Loading transformer NumericalTransformer for field Unemployment\n", + "2020-07-09 21:00:31,495 - INFO - __init__ - Loading transformer NumericalTransformer for field Temperature\n", + "2020-07-09 21:00:31,495 - INFO - __init__ - Loading transformer NumericalTransformer for field MarkDown5\n", + "2020-07-09 21:00:31,495 - INFO - __init__ - Loading transformer NumericalTransformer for field MarkDown2\n", + "2020-07-09 21:00:31,495 - INFO - __init__ - Loading transformer NumericalTransformer for field CPI\n", + "2020-07-09 21:00:31,544 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,595 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "/home/xals/Projects/MIT/SDV/sdv/models/copulas.py:83: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. 
If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n", + " self.model.covariance = np.array(values)\n", + "2020-07-09 21:00:31,651 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,679 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,707 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,734 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,762 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,790 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,816 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,845 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,872 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,901 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,931 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,959 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,986 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,014 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,040 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,070 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,096 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,123 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,152 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,181 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,209 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,235 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,264 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,293 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,322 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,349 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,376 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,405 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,433 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 
21:00:32,463 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,492 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,521 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,552 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,583 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,612 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,644 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,674 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,704 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,732 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,762 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,791 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,821 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,852 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,882 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,917 - INFO - modeler - Modeling depts\n", + "2020-07-09 21:00:32,918 - INFO - __init__ - Loading transformer DatetimeTransformer for field Date\n", + "2020-07-09 21:00:32,918 - INFO - __init__ - Loading transformer NumericalTransformer for field Weekly_Sales\n", + "2020-07-09 21:00:32,919 - INFO - __init__ - Loading transformer NumericalTransformer for field Dept\n", + "2020-07-09 21:00:32,919 - INFO - __init__ - Loading transformer BooleanTransformer for field IsHoliday\n", + "2020-07-09 21:00:33,016 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,318 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,334 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,350 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,364 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,381 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,396 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,412 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,428 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2020-07-02 18:50:11,618 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,634 - INFO - gaussian - Fitting 
GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,652 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,667 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,683 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,698 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,713 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,728 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,741 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,754 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,769 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,782 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,800 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,814 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,829 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,844 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,859 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,874 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,887 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,900 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,914 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,928 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,940 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,955 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,968 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,979 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:11,991 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:12,003 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:12,017 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:12,032 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:12,045 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:12,058 - INFO - gaussian - Fitting 
GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:12,069 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:12,084 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 18:50:12,177 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/core/fromnumeric.py:56: FutureWarning: Series.nonzero() is deprecated and will be removed in a future version.Use Series.to_numpy().nonzero() instead\n", - " return getattr(obj, method)(*args, **kwds)\n", - "2020-07-02 18:50:12,380 - INFO - modeler - Modeling Complete\n" + "2020-07-09 21:00:33,447 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,464 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,479 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,495 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,511 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,526 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,541 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,554 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,567 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,582 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,596 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,609 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,624 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,639 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,652 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,667 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,682 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,696 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,709 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,724 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,740 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,753 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,766 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,779 - INFO - gaussian - Fitting 
GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,792 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,803 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,818 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,831 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,842 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,856 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,870 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,885 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,898 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,912 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,926 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,937 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,949 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:34,047 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/core/fromnumeric.py:58: FutureWarning: Series.nonzero() is deprecated and will be removed in a future version.Use Series.to_numpy().nonzero() instead\n", + " return bound(*args, **kwds)\n", + "2020-07-09 21:00:34,259 - INFO - modeler - Modeling Complete\n" ] } ], @@ -413,7 +419,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -422,7 +428,7 @@ "{'stores': 45, 'features': 8190, 'depts': 421570}" ] }, - "execution_count": 5, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -433,7 +439,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 7, "metadata": { "scrolled": false }, @@ -451,7 +457,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -484,32 +490,32 @@ " \n", " 0\n", " B\n", - " 85496\n", - " 3\n", + " 106106\n", + " 0\n", " \n", " \n", " 1\n", - " A\n", - " 178862\n", - " 4\n", + " B\n", + " 68071\n", + " 1\n", " \n", " \n", " 2\n", - " B\n", - " 69654\n", - " 5\n", + " C\n", + " 253603\n", + " 2\n", " \n", " \n", " 3\n", " A\n", - " 211981\n", - " 6\n", + " 223268\n", + " 3\n", " \n", " \n", " 4\n", - " A\n", - " 131188\n", - " 7\n", + " B\n", + " 166848\n", + " 4\n", " \n", " \n", "\n", @@ -517,14 +523,14 @@ ], "text/plain": [ " Type Size Store\n", - "0 B 85496 3\n", - "1 A 178862 4\n", - "2 B 69654 5\n", - "3 A 211981 6\n", - "4 A 131188 7" + "0 B 106106 0\n", + "1 B 68071 1\n", + "2 C 253603 2\n", + "3 A 223268 3\n", + "4 B 166848 4" ] }, - "execution_count": 10, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -535,7 +541,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -576,78 +582,78 
@@ " \n", " \n", " 0\n", - " 2012-04-23 03:09:52.638271488\n", - " NaN\n", - " 3\n", + " 2009-12-13 04:59:49.832576000\n", + " 222.268822\n", + " 0\n", " False\n", + " 758.900545\n", + " -1080.569435\n", + " 4.574966\n", + " 6.052795\n", + " 51.354606\n", + " -4016.012141\n", " NaN\n", - " -9803.940199\n", - " 3.561375\n", - " 8.838728\n", - " 67.162475\n", - " NaN\n", - " 2703.17729\n", - " 186.471991\n", + " 224.973722\n", " \n", " \n", " 1\n", - " 2011-04-19 08:45:12.429521664\n", - " 483.892955\n", - " 3\n", + " 2011-07-20 14:40:40.400244736\n", + " 6857.712028\n", + " 0\n", " False\n", - " 7504.524416\n", " NaN\n", - " 3.495118\n", - " 7.360667\n", - " 42.785730\n", - " 2772.597105\n", + " -551.045257\n", + " 3.018296\n", + " 6.720669\n", + " 80.723862\n", + " NaN\n", " NaN\n", - " 192.048268\n", + " 204.595072\n", " \n", " \n", " 2\n", - " 2011-01-30 22:13:59.841415680\n", - " NaN\n", - " 3\n", + " 2011-10-12 08:23:54.987658240\n", + " 5530.929425\n", + " 0\n", " False\n", " NaN\n", " NaN\n", - " 3.361946\n", - " 7.524812\n", - " 34.945770\n", + " 3.466358\n", + " 7.256570\n", + " 73.954938\n", " NaN\n", " NaN\n", - " 192.100673\n", + " 202.591941\n", " \n", " \n", " 3\n", - " 2011-10-08 08:16:00.977235968\n", - " 4661.175670\n", - " 3\n", + " 2012-11-26 09:22:59.377291776\n", + " NaN\n", + " 0\n", " False\n", - " 3392.028528\n", - " -5837.649003\n", - " 2.994273\n", - " 7.993152\n", - " 66.818180\n", " NaN\n", + " 3703.137191\n", + " 3.623725\n", + " 6.598615\n", + " 72.041454\n", " NaN\n", - " 197.921916\n", + " 3925.456611\n", + " 169.457078\n", " \n", " \n", " 4\n", - " 2011-09-29 23:18:43.912751616\n", + " 2013-05-13 15:22:12.213717760\n", " NaN\n", - " 3\n", + " 0\n", " False\n", + " 3773.942572\n", " NaN\n", + " 3.381795\n", + " 5.115002\n", + " 83.058322\n", + " 15018.944605\n", " NaN\n", - " 3.116659\n", - " 4.945393\n", - " 48.979036\n", - " NaN\n", - " 2957.77612\n", - " 192.677797\n", + " 182.801974\n", " \n", " \n", "\n", @@ -655,28 +661,28 @@ ], "text/plain": [ " Date MarkDown1 Store IsHoliday MarkDown4 \\\n", - "0 2012-04-23 03:09:52.638271488 NaN 3 False NaN \n", - "1 2011-04-19 08:45:12.429521664 483.892955 3 False 7504.524416 \n", - "2 2011-01-30 22:13:59.841415680 NaN 3 False NaN \n", - "3 2011-10-08 08:16:00.977235968 4661.175670 3 False 3392.028528 \n", - "4 2011-09-29 23:18:43.912751616 NaN 3 False NaN \n", + "0 2009-12-13 04:59:49.832576000 222.268822 0 False 758.900545 \n", + "1 2011-07-20 14:40:40.400244736 6857.712028 0 False NaN \n", + "2 2011-10-12 08:23:54.987658240 5530.929425 0 False NaN \n", + "3 2012-11-26 09:22:59.377291776 NaN 0 False NaN \n", + "4 2013-05-13 15:22:12.213717760 NaN 0 False 3773.942572 \n", "\n", - " MarkDown3 Fuel_Price Unemployment Temperature MarkDown5 \\\n", - "0 -9803.940199 3.561375 8.838728 67.162475 NaN \n", - "1 NaN 3.495118 7.360667 42.785730 2772.597105 \n", - "2 NaN 3.361946 7.524812 34.945770 NaN \n", - "3 -5837.649003 2.994273 7.993152 66.818180 NaN \n", - "4 NaN 3.116659 4.945393 48.979036 NaN \n", + " MarkDown3 Fuel_Price Unemployment Temperature MarkDown5 \\\n", + "0 -1080.569435 4.574966 6.052795 51.354606 -4016.012141 \n", + "1 -551.045257 3.018296 6.720669 80.723862 NaN \n", + "2 NaN 3.466358 7.256570 73.954938 NaN \n", + "3 3703.137191 3.623725 6.598615 72.041454 NaN \n", + "4 NaN 3.381795 5.115002 83.058322 15018.944605 \n", "\n", - " MarkDown2 CPI \n", - "0 2703.17729 186.471991 \n", - "1 NaN 192.048268 \n", - "2 NaN 192.100673 \n", - "3 NaN 197.921916 \n", - "4 2957.77612 192.677797 " + " MarkDown2 CPI \n", + 
"0 NaN 224.973722 \n", + "1 NaN 204.595072 \n", + "2 NaN 202.591941 \n", + "3 3925.456611 169.457078 \n", + "4 NaN 182.801974 " ] }, - "execution_count": 11, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -687,7 +693,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -721,42 +727,42 @@ " \n", " \n", " 0\n", - " 2011-04-22 15:24:18.057608704\n", - " 11196.989134\n", - " 3\n", - " 38\n", + " 2012-04-19 04:40:04.370382848\n", + " 7523.202470\n", + " 0\n", + " 36\n", " False\n", " \n", " \n", " 1\n", - " 2010-02-06 19:21:02.431538688\n", - " -14038.529503\n", - " 3\n", - " 29\n", + " 2012-10-24 12:49:45.630176512\n", + " 13675.425498\n", + " 0\n", + " 17\n", " False\n", " \n", " \n", " 2\n", - " 2012-06-04 16:23:09.227934976\n", - " -6519.738485\n", - " 3\n", - " 46\n", + " 2012-06-05 05:02:35.551761408\n", + " 2402.017324\n", + " 0\n", + " 31\n", " False\n", " \n", " \n", " 3\n", - " 2011-08-09 17:18:54.910250752\n", - " 23194.918038\n", - " 3\n", - " 45\n", + " 2012-05-31 22:15:14.903856896\n", + " -12761.073176\n", + " 0\n", + " 81\n", " False\n", " \n", " \n", " 4\n", - " 2010-09-01 23:10:54.986872576\n", - " 16761.426407\n", - " 3\n", - " 29\n", + " 2011-09-07 15:34:43.239835648\n", + " 3642.817612\n", + " 0\n", + " 58\n", " False\n", " \n", " \n", @@ -765,14 +771,14 @@ ], "text/plain": [ " Date Weekly_Sales Store Dept IsHoliday\n", - "0 2011-04-22 15:24:18.057608704 11196.989134 3 38 False\n", - "1 2010-02-06 19:21:02.431538688 -14038.529503 3 29 False\n", - "2 2012-06-04 16:23:09.227934976 -6519.738485 3 46 False\n", - "3 2011-08-09 17:18:54.910250752 23194.918038 3 45 False\n", - "4 2010-09-01 23:10:54.986872576 16761.426407 3 29 False" + "0 2012-04-19 04:40:04.370382848 7523.202470 0 36 False\n", + "1 2012-10-24 12:49:45.630176512 13675.425498 0 17 False\n", + "2 2012-06-05 05:02:35.551761408 2402.017324 0 31 False\n", + "3 2012-05-31 22:15:14.903856896 -12761.073176 0 81 False\n", + "4 2011-09-07 15:34:43.239835648 3642.817612 0 58 False" ] }, - "execution_count": 12, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -792,7 +798,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 11, "metadata": { "scrolled": false }, @@ -801,198 +807,198 @@ "data": { "text/plain": [ "{'features': Date MarkDown1 Store IsHoliday \\\n", - " 0 2012-03-09 19:29:04.121987584 4547.265066 48 False \n", - " 1 2012-01-12 12:05:27.772331264 NaN 51 False \n", - " 2 2012-10-09 23:52:58.454752256 13291.983302 51 False \n", - " 3 2011-07-26 23:58:33.986580992 NaN 49 False \n", - " 4 2012-01-17 02:27:51.821400832 6865.756770 50 False \n", - " 5 2009-12-27 02:32:16.625306880 NaN 51 False \n", - " 6 2011-07-12 23:01:03.955540224 NaN 48 False \n", - " 7 2011-09-25 10:16:19.542280704 7114.339579 48 False \n", - " 8 2011-05-06 02:00:24.919497728 NaN 48 False \n", - " 9 2013-10-23 10:17:59.945650432 2267.991754 50 False \n", - " 10 2011-05-18 08:31:22.419635968 NaN 48 False \n", - " 11 2010-07-27 18:16:54.600127744 NaN 49 False \n", - " 12 2010-11-04 00:37:47.996870656 NaN 51 False \n", - " 13 2012-03-23 15:04:51.172607488 10192.668685 48 False \n", - " 14 2007-11-14 22:28:51.711726848 NaN 48 False \n", - " 15 2013-02-07 20:55:46.684175872 9146.998153 49 False \n", - " 16 2011-07-18 00:03:42.998976512 NaN 50 False \n", - " 17 2010-10-02 21:34:09.917561088 NaN 49 False \n", - " 18 2010-03-12 18:25:31.509475328 NaN 51 False \n", - " 19 2013-06-24 02:50:15.027540736 2943.979267 
50 False \n", - " 20 2010-09-24 15:42:39.457930752 NaN 48 False \n", - " 21 2012-03-27 19:00:12.877346816 7750.605869 51 False \n", - " 22 2010-05-26 23:28:27.091251968 NaN 48 False \n", - " 23 2011-05-01 23:07:37.680983296 NaN 52 False \n", - " 24 2012-10-07 16:56:39.076689152 11838.898938 51 False \n", - " 25 2011-02-19 21:38:01.903840000 NaN 52 False \n", - " 26 2012-01-18 06:02:24.764590592 7428.686729 52 False \n", - " 27 2012-09-28 01:36:19.603453184 4962.148181 52 False \n", - " 28 2011-07-28 19:21:49.952510208 12773.901269 48 False \n", - " 29 2011-09-04 16:53:45.561477888 NaN 50 False \n", + " 0 2010-07-07 18:33:30.490365440 NaN 46 True \n", + " 1 2010-05-31 05:12:43.555245568 NaN 49 False \n", + " 2 2009-11-13 00:11:26.645775872 NaN 46 False \n", + " 3 2011-06-26 13:37:23.832223232 8480.251096 48 False \n", + " 4 2013-09-30 06:14:18.804104192 4586.374804 48 False \n", + " 5 2011-03-18 21:50:55.521650432 NaN 45 False \n", + " 6 2012-03-22 14:59:09.189622016 9747.212668 46 False \n", + " 7 2011-04-05 16:45:11.282816000 NaN 48 False \n", + " 8 2013-05-05 14:23:57.431958784 5430.557501 46 False \n", + " 9 2011-04-07 02:38:10.873993984 6735.454863 49 False \n", + " 10 2012-07-09 15:59:11.863023616 11818.737600 48 False \n", + " 11 2012-03-10 03:57:22.968428544 4603.635257 48 False \n", + " 12 2011-10-01 03:51:05.280273920 14596.976258 49 False \n", + " 13 2012-06-24 23:45:06.501742080 5816.163902 49 False \n", + " 14 2011-08-06 14:43:48.366935040 NaN 47 False \n", + " 15 2012-06-14 01:06:45.771771648 NaN 45 False \n", + " 16 2011-11-22 08:57:27.766619392 NaN 45 False \n", + " 17 2012-10-06 19:10:00.039407360 5887.384577 48 False \n", + " 18 2013-06-08 07:46:59.612567808 -5770.051172 48 False \n", + " 19 2012-05-27 11:08:47.184699648 6908.122060 46 False \n", + " 20 2010-10-19 21:42:09.919441408 NaN 46 False \n", + " 21 2013-01-01 17:57:15.273558784 6221.151492 48 False \n", + " 22 2012-07-13 14:40:51.181466112 1378.006687 47 False \n", + " 23 2012-02-13 18:37:30.161656320 NaN 48 False \n", + " 24 2010-11-27 09:56:34.540603392 NaN 46 False \n", + " 25 2011-05-21 03:49:37.053693440 NaN 49 False \n", + " 26 2010-06-12 20:31:32.339124224 NaN 49 True \n", + " 27 2011-08-12 21:10:37.406385152 NaN 47 False \n", + " 28 2012-02-07 03:31:37.558127360 NaN 47 False \n", + " 29 2011-05-28 22:39:00.441712128 NaN 48 False \n", " .. ... ... ... ... 
\n", - " 970 2011-06-14 18:48:47.646437632 7563.489728 48 False \n", - " 971 2011-05-11 09:27:03.321238016 3763.579757 50 False \n", - " 972 2012-10-27 08:44:16.987041024 -2107.559543 49 False \n", - " 973 2011-03-26 17:36:28.323915264 NaN 48 False \n", - " 974 2012-03-30 09:18:45.289797888 NaN 52 False \n", - " 975 2012-12-25 00:57:24.386340608 3483.807111 51 False \n", - " 976 2010-09-17 18:59:15.931293440 NaN 52 False \n", - " 977 2012-04-18 06:40:20.102041344 2869.173605 50 False \n", - " 978 2010-01-18 10:21:26.293353216 NaN 49 False \n", - " 979 2010-08-13 00:08:32.324416256 NaN 50 False \n", - " 980 2012-01-17 04:04:26.614297088 NaN 50 False \n", - " 981 2011-06-23 01:40:48.209999104 NaN 48 False \n", - " 982 2012-05-15 04:01:02.762140160 3935.159830 50 False \n", - " 983 2011-03-03 10:20:57.530126848 5704.369183 52 False \n", - " 984 2011-09-11 11:54:13.537061888 -3939.448245 49 False \n", - " 985 2012-07-01 08:38:12.122587648 NaN 49 False \n", - " 986 2013-01-15 14:24:51.253277952 7802.046675 52 False \n", - " 987 2011-03-16 12:56:22.305051648 NaN 49 False \n", - " 988 2012-02-09 21:21:51.383019776 3563.474259 50 False \n", - " 989 2011-04-28 07:35:57.312682752 NaN 49 False \n", - " 990 2012-07-06 11:32:42.706951424 4134.684319 48 False \n", - " 991 2012-09-10 11:27:22.375229184 2670.997722 48 False \n", - " 992 2010-04-19 00:59:34.282525696 NaN 51 False \n", - " 993 2012-05-21 02:19:56.773740288 725.161402 51 False \n", - " 994 2012-08-25 02:46:15.815232768 2736.817373 52 False \n", - " 995 2011-11-27 15:29:59.302161664 NaN 48 False \n", - " 996 2012-11-19 23:22:47.651802880 NaN 49 False \n", - " 997 2012-03-06 18:23:43.093342208 NaN 51 True \n", - " 998 2011-11-17 03:00:20.669487872 NaN 50 False \n", - " 999 2010-10-18 08:04:04.443984128 NaN 51 False \n", + " 970 2012-10-01 09:27:33.224137984 7367.091411 49 False \n", + " 971 2011-08-05 19:20:26.151251200 NaN 49 False \n", + " 972 2010-05-02 10:25:48.294526720 NaN 49 False \n", + " 973 2014-07-30 06:33:25.027884800 2869.286369 46 False \n", + " 974 2012-06-10 23:50:28.918937344 NaN 45 False \n", + " 975 2011-11-20 17:24:25.076207104 NaN 48 False \n", + " 976 2011-07-16 19:31:09.475053312 8465.620591 48 False \n", + " 977 2010-12-17 11:17:29.377529600 NaN 45 False \n", + " 978 2011-05-01 22:00:24.413698816 NaN 48 True \n", + " 979 2009-01-04 14:06:28.700405504 NaN 48 False \n", + " 980 2013-01-26 23:44:47.124111360 7711.715123 47 False \n", + " 981 2012-02-05 15:58:21.306519040 4771.829938 49 False \n", + " 982 2012-12-25 20:25:33.794860544 9052.115357 45 False \n", + " 983 2011-03-11 05:15:33.422215424 NaN 48 False \n", + " 984 2012-04-10 15:14:42.299491072 11469.698722 49 False \n", + " 985 2012-09-04 11:56:04.719035648 11985.730540 47 False \n", + " 986 2010-08-12 06:13:43.472920320 NaN 47 False \n", + " 987 2011-11-28 00:46:09.219015936 NaN 47 False \n", + " 988 2010-11-01 19:46:13.442969344 NaN 47 False \n", + " 989 2012-09-03 02:46:56.811711232 13448.706455 46 False \n", + " 990 2012-06-21 20:15:00.600609536 8515.644049 45 False \n", + " 991 2009-11-09 11:32:20.317559808 NaN 47 False \n", + " 992 2012-12-24 16:57:40.573141760 8785.735128 48 False \n", + " 993 2010-10-26 22:04:15.916646656 NaN 45 False \n", + " 994 2013-07-27 21:48:37.233461248 2218.444188 46 False \n", + " 995 2012-11-22 15:48:58.070053120 NaN 47 False \n", + " 996 2012-11-19 17:12:55.679613440 11318.264204 45 False \n", + " 997 2013-10-11 21:09:49.021492480 478.609430 47 False \n", + " 998 2013-01-23 23:17:20.493887232 14227.944357 49 False \n", + " 999 
2012-06-23 00:01:59.594749696 9121.193148 46 False \n", " \n", " MarkDown4 MarkDown3 Fuel_Price Unemployment Temperature \\\n", - " 0 2262.924323 4314.258484 2.944899 12.207513 44.866684 \n", - " 1 648.327530 NaN 3.415876 8.296209 79.755000 \n", - " 2 7464.648400 4818.694540 3.438865 10.375869 53.011465 \n", - " 3 7722.936675 NaN 3.238451 8.340640 81.331769 \n", - " 4 385.720439 2933.746873 3.352123 4.517977 33.216711 \n", - " 5 NaN NaN 2.595952 8.423675 58.153432 \n", - " 6 NaN NaN 3.337125 7.289977 68.198691 \n", - " 7 1376.144152 -1375.405701 3.586827 4.433586 57.457667 \n", - " 8 NaN 885.964999 3.132711 7.941049 82.624443 \n", - " 9 4380.021000 4861.726545 4.617105 7.636990 64.616449 \n", - " 10 NaN NaN 3.490697 9.699187 113.890285 \n", - " 11 NaN NaN 3.337917 10.428851 76.354063 \n", - " 12 NaN NaN 3.016818 10.044523 52.925513 \n", - " 13 NaN 1505.002284 3.244111 7.329469 34.127333 \n", - " 14 NaN NaN 2.497412 12.247448 100.350997 \n", - " 15 7982.863954 7539.825400 3.909380 5.689230 65.945165 \n", - " 16 NaN NaN 2.895094 7.718521 54.750702 \n", - " 17 NaN NaN 3.722465 7.339952 56.152632 \n", - " 18 NaN NaN 2.665366 8.045471 78.257558 \n", - " 19 -3310.004199 -808.378062 4.379650 5.320308 39.604265 \n", - " 20 NaN 4604.622951 3.058964 8.584222 44.177016 \n", - " 21 3790.217545 3053.642374 3.091990 5.069122 25.777591 \n", - " 22 NaN NaN 2.931042 9.062068 41.649615 \n", - " 23 NaN NaN 3.721348 6.586825 67.785117 \n", - " 24 4610.034450 -617.120861 4.010528 8.686732 35.445785 \n", - " 25 NaN NaN 3.249523 8.450699 38.639637 \n", - " 26 4005.965014 -28.792693 3.808513 9.188265 57.584018 \n", - " 27 6433.867544 -3996.702966 2.976900 8.991618 6.179927 \n", - " 28 NaN NaN 3.208463 11.875517 50.400717 \n", - " 29 NaN NaN 3.455848 7.377382 56.480212 \n", + " 0 NaN NaN 3.246268 6.712596 72.778272 \n", + " 1 NaN NaN 2.961909 8.775273 31.911285 \n", + " 2 NaN NaN 2.474694 8.693548 36.996590 \n", + " 3 2990.430314 450.886030 3.180972 7.482533 82.689207 \n", + " 4 1300.901307 1038.142954 3.979948 3.842990 61.626723 \n", + " 5 NaN NaN 3.190374 9.185961 43.775213 \n", + " 6 5068.165070 4033.503158 3.436991 10.002234 44.980268 \n", + " 7 2639.207053 NaN 3.126477 8.921070 60.901225 \n", + " 8 -1515.448599 1859.931318 4.101750 10.478649 81.003354 \n", + " 9 NaN -4013.843645 3.200662 8.549170 46.332113 \n", + " 10 6983.561124 9246.461293 3.637923 6.940256 66.208533 \n", + " 11 NaN 1872.845519 3.535785 5.360209 76.842613 \n", + " 12 5967.615987 -2084.078449 3.314418 NaN 73.216988 \n", + " 13 4181.309429 -7382.929438 3.170917 7.209879 27.887139 \n", + " 14 NaN NaN 4.084598 10.976241 75.512964 \n", + " 15 6522.458652 -2969.500556 3.592465 6.778646 45.304799 \n", + " 16 NaN NaN 3.691225 9.169614 61.590909 \n", + " 17 2195.064778 10031.652043 3.790657 5.655427 38.507717 \n", + " 18 -2522.448316 4530.988378 3.349566 4.743290 36.629555 \n", + " 19 NaN 3416.997827 3.641454 8.177166 89.916452 \n", + " 20 NaN NaN 3.075063 5.229354 61.706563 \n", + " 21 4647.376520 3247.164388 4.162684 5.033966 37.962542 \n", + " 22 NaN NaN 3.384827 4.921576 54.039707 \n", + " 23 NaN 7039.242263 3.187745 9.768689 74.453345 \n", + " 24 NaN NaN 3.045024 6.954771 63.069450 \n", + " 25 NaN NaN 4.164028 8.876290 71.070532 \n", + " 26 NaN NaN 2.719009 7.230035 57.364249 \n", + " 27 NaN NaN 3.257133 9.345917 68.652152 \n", + " 28 NaN -4190.690106 3.548902 9.494165 89.145578 \n", + " 29 NaN NaN 3.316924 6.697705 77.375684 \n", " .. ... ... ... ... ... 
\n", - " 970 7130.653517 613.434417 3.798558 8.849840 45.081481 \n", - " 971 -950.449794 8369.741392 3.691277 8.687173 21.625873 \n", - " 972 1423.082811 3915.095969 3.716980 7.812599 38.891102 \n", - " 973 NaN NaN 3.277952 7.881991 74.018427 \n", - " 974 NaN NaN 3.444190 4.147388 79.128181 \n", - " 975 5645.978692 3893.099917 3.821702 8.410268 87.162204 \n", - " 976 NaN NaN 2.640021 7.458197 35.588448 \n", - " 977 -2627.096894 -142.014936 3.320107 6.970687 49.700872 \n", - " 978 NaN NaN 2.710125 9.927722 83.348488 \n", - " 979 NaN NaN 3.139506 8.348364 46.708000 \n", - " 980 NaN NaN 3.578857 7.555184 25.595379 \n", - " 981 NaN -1147.726465 3.685507 9.407928 91.745352 \n", - " 982 375.025287 11130.134824 3.467141 6.285321 38.179293 \n", - " 983 NaN NaN 2.675670 10.258619 50.599203 \n", - " 984 NaN NaN 3.375990 9.175979 44.036130 \n", - " 985 NaN 649.171815 3.370836 6.948002 22.863602 \n", - " 986 1944.149217 896.663662 3.971550 6.242637 97.107161 \n", - " 987 NaN NaN 3.524965 7.059279 65.469921 \n", - " 988 2550.079436 121.672409 2.898533 5.929999 58.612442 \n", - " 989 NaN NaN 3.237647 6.992924 72.225708 \n", - " 990 NaN NaN 4.075160 8.392256 80.837964 \n", - " 991 -1543.390402 5264.766267 3.157407 7.017727 81.515109 \n", - " 992 NaN NaN 3.293817 8.719239 65.735676 \n", - " 993 932.568195 1845.259380 3.157982 7.494639 65.103553 \n", - " 994 -2905.284812 -1268.636611 3.776062 6.736775 46.083801 \n", - " 995 NaN NaN 3.424730 NaN 79.231321 \n", - " 996 9027.387381 NaN 3.960882 8.911136 26.948635 \n", - " 997 NaN NaN 3.764028 8.049413 111.723841 \n", - " 998 NaN 6261.182193 3.609560 6.024725 38.411214 \n", - " 999 NaN NaN 3.051021 7.855145 75.446009 \n", + " 970 2712.006496 -14.780688 3.618468 8.148930 30.483065 \n", + " 971 NaN NaN 3.012630 6.326113 47.975596 \n", + " 972 NaN NaN 2.841184 10.269877 81.946843 \n", + " 973 1937.104665 -1511.898278 4.759412 7.778287 30.662510 \n", + " 974 5253.689129 329.011494 3.343722 7.398721 21.129117 \n", + " 975 NaN 4930.780630 3.506178 9.701808 110.523151 \n", + " 976 5293.404197 -3121.976568 2.800094 5.843644 32.442600 \n", + " 977 9805.271519 NaN 3.136993 7.623560 72.572211 \n", + " 978 NaN NaN 2.714601 7.234621 52.291714 \n", + " 979 NaN NaN 2.043747 8.744450 46.628272 \n", + " 980 5415.328424 -4809.140910 3.287617 8.354515 61.384200 \n", + " 981 NaN NaN 3.337297 6.142083 36.829644 \n", + " 982 837.730915 4072.518501 3.894026 12.494251 94.673851 \n", + " 983 NaN NaN 3.420209 7.090230 47.694700 \n", + " 984 5134.594286 357.279124 3.325391 4.539604 37.460031 \n", + " 985 NaN 4056.157461 3.685698 7.230368 43.096815 \n", + " 986 3953.583585 805.676045 3.044291 7.884751 34.438372 \n", + " 987 4511.412982 NaN 3.394697 6.834802 36.636557 \n", + " 988 NaN NaN 3.154329 8.363893 49.751330 \n", + " 989 NaN 2084.191209 3.615679 9.596049 75.533229 \n", + " 990 4457.772828 218.985398 3.440906 8.526339 61.864859 \n", + " 991 NaN NaN 3.247303 7.423321 59.019784 \n", + " 992 NaN NaN 3.458168 6.459322 48.584436 \n", + " 993 NaN NaN 2.608670 10.259658 82.758016 \n", + " 994 719.333117 6631.998410 4.368128 7.588005 54.580465 \n", + " 995 3003.501922 NaN 3.800082 NaN 100.143258 \n", + " 996 NaN -681.708699 4.140327 10.717719 46.213930 \n", + " 997 735.529994 3856.959812 3.921162 7.439993 59.473702 \n", + " 998 2557.407176 2865.135441 3.971143 9.540309 67.952407 \n", + " 999 671.086509 2328.798014 3.487042 5.988222 63.846379 \n", " \n", - " MarkDown5 MarkDown2 CPI \n", - " 0 7218.910186 NaN 154.437715 \n", - " 1 4928.572337 1860.474257 126.006100 \n", - " 2 4381.712904 
6491.080062 182.505905 \n", - " 3 NaN 2075.661361 170.252793 \n", - " 4 -1363.927516 6383.011111 181.850615 \n", - " 5 NaN NaN 250.558631 \n", - " 6 NaN NaN 213.846371 \n", - " 7 1644.489130 4966.726612 189.105468 \n", - " 8 NaN 4689.948004 152.931090 \n", - " 9 3291.191563 2344.032770 152.780297 \n", - " 10 NaN NaN 195.602913 \n", - " 11 NaN NaN 176.381561 \n", - " 12 NaN NaN 147.881741 \n", - " 13 1668.882903 NaN 155.004680 \n", - " 14 NaN NaN 193.442985 \n", - " 15 1437.517079 4874.558939 152.281926 \n", - " 16 8439.806431 NaN 130.530240 \n", - " 17 NaN NaN 201.486725 \n", - " 18 NaN NaN 188.760741 \n", - " 19 -6907.068884 3246.451314 108.120768 \n", - " 20 NaN NaN 148.788886 \n", - " 21 5724.976324 4496.854628 190.036218 \n", - " 22 NaN NaN 181.746008 \n", - " 23 NaN NaN 200.412670 \n", - " 24 6637.078664 -1388.022496 87.164451 \n", - " 25 NaN NaN 195.622231 \n", - " 26 2922.749710 NaN 224.233432 \n", - " 27 -3042.801607 NaN 131.909287 \n", - " 28 7063.102432 NaN 146.156666 \n", - " 29 NaN NaN 176.694595 \n", - " .. ... ... ... \n", - " 970 2882.447318 NaN 148.323409 \n", - " 971 1918.923930 NaN 141.723698 \n", - " 972 151.136535 1221.530133 155.532651 \n", - " 973 NaN NaN 213.405340 \n", - " 974 NaN NaN 186.636286 \n", - " 975 2598.731935 -846.566855 190.957987 \n", - " 976 NaN NaN 153.917260 \n", - " 977 8372.618269 1768.408869 249.068471 \n", - " 978 NaN NaN 177.318454 \n", - " 979 NaN NaN 208.719040 \n", - " 980 NaN 754.021358 108.580497 \n", - " 981 NaN NaN 159.526588 \n", - " 982 7349.003146 -732.495611 173.559001 \n", - " 983 NaN NaN 140.910196 \n", - " 984 401.319810 NaN 175.340840 \n", - " 985 NaN NaN 171.683362 \n", - " 986 -465.292159 -744.114065 177.553651 \n", - " 987 NaN NaN 144.344477 \n", - " 988 567.400188 NaN 213.833680 \n", - " 989 NaN NaN 157.113353 \n", - " 990 NaN -232.801449 132.592027 \n", - " 991 -72.622621 NaN 207.754643 \n", - " 992 NaN NaN 132.575018 \n", - " 993 4955.778653 4683.646548 160.260401 \n", - " 994 2898.089751 NaN 129.801694 \n", - " 995 NaN NaN NaN \n", - " 996 NaN 608.026190 115.452177 \n", - " 997 NaN NaN 206.323679 \n", - " 998 NaN 7020.489423 178.966240 \n", - " 999 NaN NaN 152.636567 \n", + " MarkDown5 MarkDown2 CPI \n", + " 0 NaN NaN 172.666543 \n", + " 1 NaN NaN 119.214662 \n", + " 2 NaN NaN 143.990866 \n", + " 3 4354.651570 5019.468642 180.631779 \n", + " 4 -4772.570563 2947.674138 173.813790 \n", + " 5 NaN NaN 177.052581 \n", + " 6 7083.723804 3513.601560 185.921725 \n", + " 7 NaN -180.888893 153.966086 \n", + " 8 6265.833733 2401.076738 170.559516 \n", + " 9 3759.801559 4729.764492 200.039874 \n", + " 10 6567.512481 -1022.527044 150.002644 \n", + " 11 7765.142655 NaN 221.169299 \n", + " 12 NaN 3308.681495 NaN \n", + " 13 2826.442468 8385.400703 150.777231 \n", + " 14 NaN NaN 211.070762 \n", + " 15 6642.564583 NaN 154.828636 \n", + " 16 NaN NaN 138.236111 \n", + " 17 4118.446321 NaN 138.853628 \n", + " 18 1115.215425 -1526.587205 219.621502 \n", + " 19 4994.790024 NaN 197.086249 \n", + " 20 NaN NaN 143.384357 \n", + " 21 9934.092076 6538.282816 208.758478 \n", + " 22 396.251330 1170.864752 217.807128 \n", + " 23 1492.226506 NaN 144.261943 \n", + " 24 NaN NaN 198.046465 \n", + " 25 NaN NaN 152.795139 \n", + " 26 NaN NaN 192.354502 \n", + " 27 1804.646340 NaN 213.883682 \n", + " 28 NaN NaN 237.955017 \n", + " 29 NaN NaN 140.091500 \n", + " .. ... ... ... 
\n", + " 970 -208.184350 4224.751927 149.278618 \n", + " 971 NaN 3814.404768 167.707496 \n", + " 972 NaN NaN 115.688179 \n", + " 973 1199.031857 3033.592607 120.703360 \n", + " 974 NaN -351.113110 106.143265 \n", + " 975 NaN NaN 211.109816 \n", + " 976 3975.213749 NaN 152.402312 \n", + " 977 NaN NaN 121.333484 \n", + " 978 NaN 3607.788296 176.680147 \n", + " 979 NaN NaN 164.924902 \n", + " 980 5431.606087 7005.735955 208.516910 \n", + " 981 -285.659902 NaN 146.865904 \n", + " 982 5584.625275 2016.204322 165.913603 \n", + " 983 NaN NaN 136.591286 \n", + " 984 2929.718397 11730.460834 153.875137 \n", + " 985 4618.823214 1218.336997 191.245195 \n", + " 986 NaN 128.922827 209.229013 \n", + " 987 7757.467609 NaN 201.908861 \n", + " 988 NaN NaN 210.724201 \n", + " 989 5115.758017 1620.240724 149.542233 \n", + " 990 1789.565202 1350.171000 193.810848 \n", + " 991 NaN NaN 147.307937 \n", + " 992 -105.978005 NaN 229.688010 \n", + " 993 NaN NaN 217.839779 \n", + " 994 -907.055933 3253.948843 128.459172 \n", + " 995 NaN NaN NaN \n", + " 996 3043.948542 NaN 114.508859 \n", + " 997 1459.810600 -4003.468566 193.149902 \n", + " 998 12151.265277 -3099.180177 150.590590 \n", + " 999 2792.671569 NaN 201.355736 \n", " \n", " [1000 rows x 12 columns]}" ] }, - "execution_count": 13, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } diff --git a/examples/5. Generate Metadata from Dataframes.ipynb b/docs/tutorials/04_Working_with_Metadata.ipynb similarity index 52% rename from examples/5. Generate Metadata from Dataframes.ipynb rename to docs/tutorials/04_Working_with_Metadata.ipynb index 91b001036..5c79f8ebf 100644 --- a/examples/5. Generate Metadata from Dataframes.ipynb +++ b/docs/tutorials/04_Working_with_Metadata.ipynb @@ -1,21 +1,72 @@ { "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Working with Metadata\n", + "\n", + "In order to work with complex dataset structures you will need to pass additional information\n", + "about you data to SDV.\n", + "\n", + "This is done by using a Metadata.\n", + "\n", + "Let's have a quick look at how to do it.\n", + "\n", + "## Load demo data\n", + "\n", + "We will load the demo dataset included in SDV for the rest of the session." + ] + }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ - "from sdv import load_demo" + "from sdv import load_demo\n", + "\n", + "metadata, tables = load_demo(metadata=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A part from the tables dict, this is returning a Metadata object that contains all the information\n", + "about our dataset." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "Metadata\n", + " root_path: /home/xals/Projects/MIT/SDV/tutorials\n", + " tables: ['users', 'sessions', 'transactions']\n", + " relationships:\n", + " sessions.user_id -> users.user_id\n", + " transactions.session_id -> sessions.session_id" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "metadata, tables = load_demo(metadata=True)" + "metadata" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This Metadata can also be represented by using a dict object:" ] }, { @@ -58,60 +109,22 @@ ] }, { - "cell_type": "code", - "execution_count": 4, + "cell_type": "markdown", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'users': user_id country gender age\n", - " 0 0 USA M 34\n", - " 1 1 UK F 23\n", - " 2 2 ES None 44\n", - " 3 3 UK M 22\n", - " 4 4 USA F 54\n", - " 5 5 DE M 57\n", - " 6 6 BG F 45\n", - " 7 7 ES None 41\n", - " 8 8 FR F 23\n", - " 9 9 UK None 30,\n", - " 'sessions': session_id user_id device os\n", - " 0 0 0 mobile android\n", - " 1 1 1 tablet ios\n", - " 2 2 1 tablet android\n", - " 3 3 2 mobile android\n", - " 4 4 4 mobile ios\n", - " 5 5 5 mobile android\n", - " 6 6 6 mobile ios\n", - " 7 7 6 tablet ios\n", - " 8 8 6 mobile ios\n", - " 9 9 8 tablet ios,\n", - " 'transactions': transaction_id session_id timestamp amount approved\n", - " 0 0 0 2019-01-01 12:34:32 100.0 True\n", - " 1 1 0 2019-01-01 12:42:21 55.3 True\n", - " 2 2 1 2019-01-07 17:23:11 79.5 True\n", - " 3 3 3 2019-01-10 11:08:57 112.1 False\n", - " 4 4 5 2019-01-10 21:54:08 110.0 False\n", - " 5 5 5 2019-01-11 11:21:20 76.3 True\n", - " 6 6 7 2019-01-22 14:44:10 89.5 True\n", - " 7 7 8 2019-01-23 10:14:09 132.1 False\n", - " 8 8 9 2019-01-27 16:09:17 68.0 True\n", - " 9 9 9 2019-01-29 12:10:48 99.9 True}" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], "source": [ - "tables" + "## Creating a Metadata object from scratch\n", + "\n", + "In this section we will have a look at how to create a Metadata object from scratch.\n", + "\n", + "The simplest way to do it is by populating it passing the tables of your dataset together\n", + "with some additional information.\n", + "\n", + "Let's start by creating an empty metadata object." ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -121,141 +134,92 @@ ] }, { - "cell_type": "code", - "execution_count": 6, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "new_meta.add_table('users', data=tables['users'], primary_key='user_id')" + "Now we can start by adding the parent table from our dataset, `users`,\n", + "indicating that the primary key is the field called `user_id`." 
] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ - "new_meta.add_table('sessions', data=tables['sessions'], primary_key='session_id',\n", - " parent='users', foreign_key='user_id')" + "users_data = tables['users']\n", + "new_meta.add_table('users', data=users_data, primary_key='user_id')" ] }, { - "cell_type": "code", - "execution_count": 8, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "transactions_fields = {\n", - " 'timestamp': {\n", - " 'type': 'datetime',\n", - " 'format': '%Y-%m-%d'\n", - " }\n", - "}\n", - "new_meta.add_table('transactions', tables['transactions'], fields_metadata=transactions_fields,\n", - " primary_key='transaction_id', parent='sessions')" + "Next, let's add the sessions table, indicating that:\n", + "- The primary key is the field `session_id`\n", + "- The `users` table is parent to this table\n", + "- The relationship between the `users` and `sessions` table is created by the field called `user_id`." ] }, { "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'tables': {'users': {'fields': {'gender': {'type': 'categorical'},\n", - " 'user_id': {'type': 'id', 'subtype': 'integer'},\n", - " 'country': {'type': 'categorical'},\n", - " 'age': {'type': 'numerical', 'subtype': 'integer'}},\n", - " 'primary_key': 'user_id'},\n", - " 'sessions': {'fields': {'session_id': {'type': 'id', 'subtype': 'integer'},\n", - " 'user_id': {'type': 'id',\n", - " 'subtype': 'integer',\n", - " 'ref': {'table': 'users', 'field': 'user_id'}},\n", - " 'os': {'type': 'categorical'},\n", - " 'device': {'type': 'categorical'}},\n", - " 'primary_key': 'session_id'},\n", - " 'transactions': {'fields': {'timestamp': {'type': 'datetime',\n", - " 'format': '%Y-%m-%d'},\n", - " 'session_id': {'type': 'id',\n", - " 'subtype': 'integer',\n", - " 'ref': {'table': 'sessions', 'field': 'session_id'}},\n", - " 'transaction_id': {'type': 'id', 'subtype': 'integer'},\n", - " 'approved': {'type': 'boolean'},\n", - " 'amount': {'type': 'numerical', 'subtype': 'float'}},\n", - " 'primary_key': 'transaction_id'}}}" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "new_meta.to_dict()" - ] - }, - { - "cell_type": "code", - "execution_count": 10, + "execution_count": 6, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "new_meta.to_dict() == metadata.to_dict()" + "sessions_data = tables['sessions']\n", + "new_meta.add_table(\n", + " 'sessions',\n", + " data=sessions_data,\n", + " primary_key='session_id',\n", + " parent='users',\n", + " foreign_key='user_id'\n", + ")" ] }, { - "cell_type": "code", - "execution_count": 11, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "new_meta.to_json('demo_metadata.json')" + "Finally, let's add the transactions table.\n", + "\n", + "In this case, we will pass some additional information to indicate that\n", + "the `timestamp` field should be actually parsed and interpreted as a\n", + "datetime field." 
] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ - "loaded = Metadata('demo_metadata.json')" + "transactions_fields = {\n", + " 'timestamp': {\n", + " 'type': 'datetime',\n", + " 'format': '%Y-%m-%d'\n", + " }\n", + "}\n", + "transactions_data = tables['transactions']\n", + "new_meta.add_table(\n", + " 'transactions',\n", + " transactions_data,\n", + " fields_metadata=transactions_fields,\n", + " primary_key='transaction_id',\n", + " parent='sessions'\n", + ")" ] }, { - "cell_type": "code", - "execution_count": 13, + "cell_type": "markdown", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], "source": [ - "loaded.to_dict() == new_meta.to_dict()" + "Let's see what our Metadata looks like right now:" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -278,10 +242,10 @@ "\n", "users\n", "\n", - "user_id : id - integer\n", - "country : categorical\n", - "gender : categorical\n", - "age : numerical - integer\n", + "gender : categorical\n", + "age : numerical - integer\n", + "user_id : id - integer\n", + "country : categorical\n", "\n", "Primary key: user_id\n", "\n", @@ -291,10 +255,10 @@ "\n", "sessions\n", "\n", - "session_id : id - integer\n", - "user_id : id - integer\n", - "device : categorical\n", - "os : categorical\n", + "os : categorical\n", + "device : categorical\n", + "user_id : id - integer\n", + "session_id : id - integer\n", "\n", "Primary key: session_id\n", "Foreign key (users): user_id\n", @@ -303,7 +267,7 @@ "\n", "users->sessions\n", "\n", - "\n", + "\n", "   sessions.user_id -> users.user_id\n", "\n", "\n", @@ -312,11 +276,11 @@ "\n", "transactions\n", "\n", - "transaction_id : id - integer\n", - "session_id : id - integer\n", - "timestamp : datetime\n", - "amount : numerical - float\n", - "approved : boolean\n", + "timestamp : datetime\n", + "amount : numerical - float\n", + "session_id : id - integer\n", + "approved : boolean\n", + "transaction_id : id - integer\n", "\n", "Primary key: transaction_id\n", "Foreign key (sessions): session_id\n", @@ -325,114 +289,137 @@ "\n", "sessions->transactions\n", "\n", - "\n", + "\n", "   transactions.session_id -> sessions.session_id\n", "\n", "\n", "\n" ], "text/plain": [ - "" + "" ] }, - "execution_count": 14, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "metadata.visualize()" + "new_meta.visualize()" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'tables': {'users': {'fields': {'gender': {'type': 'categorical'},\n", + " 'age': {'type': 'numerical', 'subtype': 'integer'},\n", + " 'user_id': {'type': 'id', 'subtype': 'integer'},\n", + " 'country': {'type': 'categorical'}},\n", + " 'primary_key': 'user_id'},\n", + " 'sessions': {'fields': {'os': {'type': 'categorical'},\n", + " 'device': {'type': 'categorical'},\n", + " 'user_id': {'type': 'id',\n", + " 'subtype': 'integer',\n", + " 'ref': {'table': 'users', 'field': 'user_id'}},\n", + " 'session_id': {'type': 'id', 'subtype': 'integer'}},\n", + " 'primary_key': 'session_id'},\n", + " 'transactions': {'fields': {'timestamp': {'type': 'datetime',\n", + " 'format': '%Y-%m-%d'},\n", + " 'amount': {'type': 'numerical', 'subtype': 'float'},\n", + " 'session_id': {'type': 'id',\n", + " 'subtype': 
'integer',\n", + " 'ref': {'table': 'sessions', 'field': 'session_id'}},\n", + " 'approved': {'type': 'boolean'},\n", + " 'transaction_id': {'type': 'id', 'subtype': 'integer'}},\n", + " 'primary_key': 'transaction_id'}}}" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "new_meta.to_dict()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Pretty similar to the original metadata, right?" + ] + }, + { + "cell_type": "code", + "execution_count": 10, "metadata": {}, "outputs": [ { "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "Metadata\n", - "\n", - "\n", - "\n", - "users\n", - "\n", - "users\n", - "\n", - "gender : categorical\n", - "user_id : id - integer\n", - "country : categorical\n", - "age : numerical - integer\n", - "\n", - "Primary key: user_id\n", - "\n", - "\n", - "\n", - "sessions\n", - "\n", - "sessions\n", - "\n", - "session_id : id - integer\n", - "user_id : id - integer\n", - "os : categorical\n", - "device : categorical\n", - "\n", - "Primary key: session_id\n", - "Foreign key (users): user_id\n", - "\n", - "\n", - "\n", - "users->sessions\n", - "\n", - "\n", - "   sessions.user_id -> users.user_id\n", - "\n", - "\n", - "\n", - "transactions\n", - "\n", - "transactions\n", - "\n", - "timestamp : datetime\n", - "session_id : id - integer\n", - "transaction_id : id - integer\n", - "approved : boolean\n", - "amount : numerical - float\n", - "\n", - "Primary key: transaction_id\n", - "Foreign key (sessions): session_id\n", - "\n", - "\n", - "\n", - "sessions->transactions\n", - "\n", - "\n", - "   transactions.session_id -> sessions.session_id\n", - "\n", - "\n", - "\n" - ], "text/plain": [ - "" + "True" ] }, - "execution_count": 15, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "new_meta.visualize()" + "new_meta.to_dict() == metadata.to_dict()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Saving the Metadata as a JSON file\n", + "\n", + "The Metadata object can also be saved as a JSON file, which later on we can load:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "new_meta.to_json('demo_metadata.json')" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "loaded = Metadata('demo_metadata.json')" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "loaded.to_dict() == new_meta.to_dict()" ] } ], diff --git a/examples/demo_metadata.json b/docs/tutorials/demo_metadata.json similarity index 100% rename from examples/demo_metadata.json rename to docs/tutorials/demo_metadata.json index 1becd1764..f59aa96da 100644 --- a/examples/demo_metadata.json +++ b/docs/tutorials/demo_metadata.json @@ -2,31 +2,30 @@ "tables": { "users": { "fields": { - "country": { + "gender": { "type": "categorical" }, + "age": { + "type": "numerical", + "subtype": "integer" + }, "user_id": { "type": "id", "subtype": "integer" }, - "gender": { + "country": { "type": "categorical" - }, - "age": { - "type": "numerical", - "subtype": "integer" } }, "primary_key": "user_id" }, "sessions": { "fields": { - "device": { + "os": { "type": "categorical" }, - "session_id": { - "type": "id", - "subtype": 
"integer" + "device": { + "type": "categorical" }, "user_id": { "type": "id", @@ -36,8 +35,9 @@ "field": "user_id" } }, - "os": { - "type": "categorical" + "session_id": { + "type": "id", + "subtype": "integer" } }, "primary_key": "session_id" @@ -60,12 +60,12 @@ "field": "session_id" } }, + "approved": { + "type": "boolean" + }, "transaction_id": { "type": "id", "subtype": "integer" - }, - "approved": { - "type": "boolean" } }, "primary_key": "transaction_id" diff --git a/examples/0. Quickstart - README.ipynb b/examples/0. Quickstart - README.ipynb deleted file mode 100644 index f6f2fecdc..000000000 --- a/examples/0. Quickstart - README.ipynb +++ /dev/null @@ -1,414 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "from sdv import load_demo" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "metadata, tables = load_demo(metadata=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'tables': {'users': {'primary_key': 'user_id',\n", - " 'fields': {'user_id': {'type': 'id', 'subtype': 'integer'},\n", - " 'country': {'type': 'categorical'},\n", - " 'gender': {'type': 'categorical'},\n", - " 'age': {'type': 'numerical', 'subtype': 'integer'}}},\n", - " 'sessions': {'primary_key': 'session_id',\n", - " 'fields': {'session_id': {'type': 'id', 'subtype': 'integer'},\n", - " 'user_id': {'ref': {'field': 'user_id', 'table': 'users'},\n", - " 'type': 'id',\n", - " 'subtype': 'integer'},\n", - " 'device': {'type': 'categorical'},\n", - " 'os': {'type': 'categorical'}}},\n", - " 'transactions': {'primary_key': 'transaction_id',\n", - " 'fields': {'transaction_id': {'type': 'id', 'subtype': 'integer'},\n", - " 'session_id': {'ref': {'field': 'session_id', 'table': 'sessions'},\n", - " 'type': 'id',\n", - " 'subtype': 'integer'},\n", - " 'timestamp': {'type': 'datetime', 'format': '%Y-%m-%d'},\n", - " 'amount': {'type': 'numerical', 'subtype': 'float'},\n", - " 'approved': {'type': 'boolean'}}}}}" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "metadata.to_dict()" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'users': user_id country gender age\n", - " 0 0 USA M 34\n", - " 1 1 UK F 23\n", - " 2 2 ES None 44\n", - " 3 3 UK M 22\n", - " 4 4 USA F 54\n", - " 5 5 DE M 57\n", - " 6 6 BG F 45\n", - " 7 7 ES None 41\n", - " 8 8 FR F 23\n", - " 9 9 UK None 30,\n", - " 'sessions': session_id user_id device os\n", - " 0 0 0 mobile android\n", - " 1 1 1 tablet ios\n", - " 2 2 1 tablet android\n", - " 3 3 2 mobile android\n", - " 4 4 4 mobile ios\n", - " 5 5 5 mobile android\n", - " 6 6 6 mobile ios\n", - " 7 7 6 tablet ios\n", - " 8 8 6 mobile ios\n", - " 9 9 8 tablet ios,\n", - " 'transactions': transaction_id session_id timestamp amount approved\n", - " 0 0 0 2019-01-01 12:34:32 100.0 True\n", - " 1 1 0 2019-01-01 12:42:21 55.3 True\n", - " 2 2 1 2019-01-07 17:23:11 79.5 True\n", - " 3 3 3 2019-01-10 11:08:57 112.1 False\n", - " 4 4 5 2019-01-10 21:54:08 110.0 False\n", - " 5 5 5 2019-01-11 11:21:20 76.3 True\n", - " 6 6 7 2019-01-22 14:44:10 89.5 True\n", - " 7 7 8 2019-01-23 10:14:09 132.1 False\n", - " 8 8 9 2019-01-27 16:09:17 68.0 True\n", - " 9 9 9 2019-01-29 12:10:48 99.9 True}" - ] - }, - "execution_count": 4, - "metadata": {}, - 
"output_type": "execute_result" - } - ], - "source": [ - "tables" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "Metadata\n", - "\n", - "\n", - "\n", - "users\n", - "\n", - "users\n", - "\n", - "user_id : id - integer\n", - "country : categorical\n", - "gender : categorical\n", - "age : numerical - integer\n", - "\n", - "Primary key: user_id\n", - "\n", - "\n", - "\n", - "sessions\n", - "\n", - "sessions\n", - "\n", - "session_id : id - integer\n", - "user_id : id - integer\n", - "device : categorical\n", - "os : categorical\n", - "\n", - "Primary key: session_id\n", - "Foreign key (users): user_id\n", - "\n", - "\n", - "\n", - "users->sessions\n", - "\n", - "\n", - "   sessions.user_id -> users.user_id\n", - "\n", - "\n", - "\n", - "transactions\n", - "\n", - "transactions\n", - "\n", - "transaction_id : id - integer\n", - "session_id : id - integer\n", - "timestamp : datetime\n", - "amount : numerical - float\n", - "approved : boolean\n", - "\n", - "Primary key: transaction_id\n", - "Foreign key (sessions): session_id\n", - "\n", - "\n", - "\n", - "sessions->transactions\n", - "\n", - "\n", - "   transactions.session_id -> sessions.session_id\n", - "\n", - "\n", - "\n" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "metadata.visualize()" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2020-07-02 14:24:31,338 - INFO - modeler - Modeling users\n", - "2020-07-02 14:24:31,339 - INFO - metadata - Loading transformer CategoricalTransformer for field country\n", - "2020-07-02 14:24:31,340 - INFO - metadata - Loading transformer CategoricalTransformer for field gender\n", - "2020-07-02 14:24:31,341 - INFO - metadata - Loading transformer NumericalTransformer for field age\n", - "2020-07-02 14:24:31,356 - INFO - modeler - Modeling sessions\n", - "2020-07-02 14:24:31,357 - INFO - metadata - Loading transformer CategoricalTransformer for field device\n", - "2020-07-02 14:24:31,357 - INFO - metadata - Loading transformer CategoricalTransformer for field os\n", - "2020-07-02 14:24:31,371 - INFO - modeler - Modeling transactions\n", - "2020-07-02 14:24:31,372 - INFO - metadata - Loading transformer DatetimeTransformer for field timestamp\n", - "2020-07-02 14:24:31,373 - INFO - metadata - Loading transformer NumericalTransformer for field amount\n", - "2020-07-02 14:24:31,373 - INFO - metadata - Loading transformer BooleanTransformer for field approved\n", - "2020-07-02 14:24:31,386 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:24:31,396 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/core/fromnumeric.py:56: FutureWarning: Series.nonzero() is deprecated and will be removed in a future version.Use Series.to_numpy().nonzero() instead\n", - " return getattr(obj, method)(*args, **kwds)\n", - "2020-07-02 14:24:31,404 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/pandas/core/frame.py:7143: RuntimeWarning: Degrees of freedom <= 0 for slice\n", - " baseCov = np.cov(mat.T)\n", - 
"/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/lib/function_base.py:2451: RuntimeWarning: divide by zero encountered in true_divide\n", - " c *= np.true_divide(1, fact)\n", - "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/lib/function_base.py:2451: RuntimeWarning: invalid value encountered in multiply\n", - " c *= np.true_divide(1, fact)\n", - "2020-07-02 14:24:31,413 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:24:31,418 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:24:31,425 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:24:31,431 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:24:31,435 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:24:31,447 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:24:31,471 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:24:31,485 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:24:31,502 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:24:31,516 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:24:31,530 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:24:31,544 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:24:31,563 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:24:31,629 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:24:31,734 - INFO - modeler - Modeling Complete\n" - ] - } - ], - "source": [ - "from sdv import SDV\n", - "\n", - "sdv = SDV()\n", - "sdv.fit(metadata, tables)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "data": { - "text/plain": [ - "{'users': user_id country gender age\n", - " 0 0 ES F 50\n", - " 1 1 FR NaN 61\n", - " 2 2 UK F 19\n", - " 3 3 USA F 37\n", - " 4 4 UK M 14\n", - " 5 5 USA M 20\n", - " 6 6 FR F 41\n", - " 7 7 UK NaN 15\n", - " 8 8 USA F 17\n", - " 9 9 UK M 27,\n", - " 'sessions': session_id user_id device os\n", - " 0 0 5 mobile android\n", - " 1 1 3 tablet ios\n", - " 2 2 5 mobile ios\n", - " 3 3 2 mobile ios\n", - " 4 4 1 mobile ios\n", - " 5 5 8 mobile ios\n", - " 6 6 3 mobile ios\n", - " 7 7 2 tablet ios\n", - " 8 8 8 tablet ios,\n", - " 'transactions': transaction_id session_id timestamp amount \\\n", - " 0 0 1 2019-01-05 19:44:17.109082880 108.520480 \n", - " 1 1 0 2019-01-05 19:44:17.109082880 108.520446 \n", - " 2 2 4 2019-01-05 19:44:17.109082880 108.520456 \n", - " 3 3 4 2019-01-05 19:44:17.109082880 108.520463 \n", - " 4 4 0 2019-01-05 19:44:17.109082880 108.520451 \n", - " 5 5 4 2019-01-14 10:59:23.654829312 -293.744075 \n", - " 6 6 0 2019-01-08 21:06:22.124677120 -157.555578 \n", - " 7 7 4 2019-01-05 19:44:17.109082880 108.520405 \n", - " 8 8 4 2019-01-20 05:49:39.108628736 1436.388992 \n", - " 9 9 0 2019-01-13 01:21:03.496772608 -688.014399 \n", - " 10 10 5 2019-01-15 19:14:13.132234496 
92.573682 \n", - " 11 11 0 2019-01-15 19:13:47.834800128 88.733610 \n", - " 12 12 6 2019-01-15 19:13:25.752608512 87.612329 \n", - " 13 13 6 2019-01-15 19:14:06.131340544 91.149239 \n", - " 14 14 4 2019-01-05 19:44:17.109082880 108.520315 \n", - " 15 15 4 2019-01-10 14:54:25.195068416 -405.271997 \n", - " 16 16 0 2019-01-06 11:03:17.424094208 321.267007 \n", - " 17 17 1 2019-01-15 19:13:30.394386944 89.473644 \n", - " 18 18 5 2019-01-15 19:13:23.178177536 87.084668 \n", - " 19 19 5 2019-01-15 19:13:53.814526976 87.328847 \n", - " 20 20 6 2019-01-15 19:13:47.778595840 87.698023 \n", - " 21 21 0 2019-01-05 19:44:17.109082880 108.520444 \n", - " 22 22 4 2019-01-03 13:01:55.231368192 -171.151299 \n", - " 23 23 4 2019-01-10 03:03:38.631104256 -711.512501 \n", - " 24 24 5 2019-01-15 19:13:39.923056128 92.779222 \n", - " 25 25 5 2019-01-15 19:13:54.239550976 87.046705 \n", - " 26 26 1 2019-01-15 19:13:51.475357952 89.768894 \n", - " 27 27 6 2019-01-15 19:13:38.378663168 85.864682 \n", - " 28 28 7 2019-01-20 22:53:01.679367168 84.550410 \n", - " 29 29 7 2019-01-20 22:53:01.672930560 84.739006 \n", - " 30 30 1 2019-01-20 22:53:01.444472832 84.790641 \n", - " 31 31 8 2019-01-20 22:53:01.404470784 84.595276 \n", - " 32 32 4 2019-01-05 19:44:17.109082880 108.520411 \n", - " 33 33 4 2019-01-15 17:22:02.305563904 -602.769803 \n", - " 34 34 1 2019-01-15 04:01:37.883063808 -493.475733 \n", - " 35 35 0 2019-01-15 19:13:20.536071936 82.887047 \n", - " 36 36 5 2019-01-15 19:13:50.685021696 90.577240 \n", - " 37 37 6 2019-01-15 19:14:05.128847616 91.543630 \n", - " 38 38 6 2019-01-15 19:13:10.650768896 90.478404 \n", - " 39 39 7 2019-01-20 22:53:01.680873472 84.876897 \n", - " 40 40 7 2019-01-20 22:53:01.735595520 85.080808 \n", - " 41 41 0 2019-01-20 22:53:01.332932352 84.584933 \n", - " 42 42 8 2019-01-20 22:53:01.451664640 84.779657 \n", - " \n", - " approved \n", - " 0 False \n", - " 1 False \n", - " 2 False \n", - " 3 False \n", - " 4 False \n", - " 5 True \n", - " 6 True \n", - " 7 False \n", - " 8 True \n", - " 9 True \n", - " 10 True \n", - " 11 True \n", - " 12 True \n", - " 13 True \n", - " 14 False \n", - " 15 True \n", - " 16 True \n", - " 17 True \n", - " 18 True \n", - " 19 True \n", - " 20 True \n", - " 21 False \n", - " 22 True \n", - " 23 True \n", - " 24 True \n", - " 25 True \n", - " 26 True \n", - " 27 True \n", - " 28 True \n", - " 29 True \n", - " 30 True \n", - " 31 True \n", - " 32 False \n", - " 33 True \n", - " 34 True \n", - " 35 True \n", - " 36 True \n", - " 37 True \n", - " 38 True \n", - " 39 True \n", - " 40 True \n", - " 41 True \n", - " 42 True }" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "sdv.sample_all(10)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/1. Quickstart - Single Table - In Memory.ipynb b/examples/1. Quickstart - Single Table - In Memory.ipynb deleted file mode 100644 index b9f48bbd2..000000000 --- a/examples/1. 
Quickstart - Single Table - In Memory.ipynb +++ /dev/null @@ -1,430 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import pandas as pd\n", - "import numpy as np\n", - "\n", - "data = pd.DataFrame({\n", - " 'index': [1, 2, 3, 4, 5, 6, 7, 8],\n", - " 'integer': [1, None, 1, 2, 1, 2, 3, 2],\n", - " 'float': [0.1, None, 0.1, 0.2, 0.1, 0.2, 0.3, 0.1],\n", - " 'categorical': ['a', 'b', 'a', 'b', 'a', None, 'c', None],\n", - " 'bool': [False, True, False, True, False, False, False, None],\n", - " 'nullable': [1, None, 3, None, 5, None, 7, None],\n", - " 'datetime': [\n", - " '2010-01-01', '2010-02-01', '2010-01-01', '2010-02-01',\n", - " '2010-01-01', '2010-02-01', '2010-03-01', None\n", - " ]\n", - "})\n", - "data['datetime'] = pd.to_datetime(data['datetime'])\n", - "\n", - "tables = {\n", - " 'data': data\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
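The notebook deleted here boils the whole single-table workflow down to a handful of calls; as a minimal standalone sketch using only the API shown in these cells (column names and values are illustrative, not the notebook's exact data):

    import pandas as pd
    from sdv import SDV

    # A tiny single-table dataset with an explicit primary key column.
    data = pd.DataFrame({
        'index': [1, 2, 3, 4],
        'value': [0.1, 0.2, 0.1, 0.3],
    })

    # Dict-style metadata, as in the notebook above: one table, typed fields.
    metadata = {
        'tables': [{
            'name': 'data',
            'primary_key': 'index',
            'fields': [
                {'name': 'index', 'type': 'id'},
                {'name': 'value', 'type': 'numerical', 'subtype': 'float'},
            ],
        }]
    }

    sdv = SDV()
    sdv.fit(metadata, tables={'data': data})

    # sample_all returns a dict mapping table names to synthetic DataFrames.
    samples = sdv.sample_all()
    print(samples['data'].head())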
" - ], - "text/plain": [ - " index integer float categorical bool nullable datetime\n", - "0 1 1.0 0.1 a False 1.0 2010-01-01\n", - "1 2 NaN NaN b True NaN 2010-02-01\n", - "2 3 1.0 0.1 a False 3.0 2010-01-01\n", - "3 4 2.0 0.2 b True NaN 2010-02-01\n", - "4 5 1.0 0.1 a False 5.0 2010-01-01\n", - "5 6 2.0 0.2 None False NaN 2010-02-01\n", - "6 7 3.0 0.3 c False 7.0 2010-03-01\n", - "7 8 2.0 0.1 None None NaN NaT" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "data" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "metadata = {\n", - " \"tables\": [\n", - " {\n", - " \"fields\": [\n", - " {\n", - " \"name\": \"index\",\n", - " \"type\": \"id\"\n", - " },\n", - " {\n", - " \"name\": \"integer\",\n", - " \"type\": \"numerical\",\n", - " \"subtype\": \"integer\",\n", - " },\n", - " {\n", - " \"name\": \"float\",\n", - " \"type\": \"numerical\",\n", - " \"subtype\": \"float\",\n", - " },\n", - " {\n", - " \"name\": \"categorical\",\n", - " \"type\": \"categorical\",\n", - " \"pii\": False,\n", - " \"pii_category\": \"email\"\n", - " },\n", - " {\n", - " \"name\": \"bool\",\n", - " \"type\": \"boolean\",\n", - " },\n", - " {\n", - " \"name\": \"nullable\",\n", - " \"type\": \"numerical\",\n", - " \"subtype\": \"float\",\n", - " },\n", - " {\n", - " \"name\": \"datetime\",\n", - " \"type\": \"datetime\",\n", - " \"format\": \"%Y-%m-%d\"\n", - " },\n", - " ],\n", - " \"name\": \"data\",\n", - " \"primary_key\": \"index\"\n", - " }\n", - " ]\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2020-07-02 14:25:11,511 - INFO - modeler - Modeling data\n", - "2020-07-02 14:25:11,512 - INFO - metadata - Loading transformer NumericalTransformer for field integer\n", - "2020-07-02 14:25:11,512 - INFO - metadata - Loading transformer NumericalTransformer for field float\n", - "2020-07-02 14:25:11,513 - INFO - metadata - Loading transformer CategoricalTransformer for field categorical\n", - "2020-07-02 14:25:11,513 - INFO - metadata - Loading transformer BooleanTransformer for field bool\n", - "2020-07-02 14:25:11,513 - INFO - metadata - Loading transformer NumericalTransformer for field nullable\n", - "2020-07-02 14:25:11,514 - INFO - metadata - Loading transformer DatetimeTransformer for field datetime\n", - "2020-07-02 14:25:11,551 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:11,564 - INFO - modeler - Modeling Complete\n" - ] - } - ], - "source": [ - "from sdv import SDV\n", - "\n", - "sdv = SDV()\n", - "sdv.fit(metadata, tables={'data': data})" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "data": { - "text/html": [ - "
" - ], - "text/plain": [ - " index integer float categorical bool nullable \\\n", - "0 0 NaN NaN b False NaN \n", - "1 1 NaN NaN NaN True NaN \n", - "2 2 3.0 0.304326 c False NaN \n", - "3 3 2.0 0.174580 a True NaN \n", - "4 4 3.0 0.208637 a NaN NaN \n", - "5 5 1.0 0.026796 b NaN NaN \n", - "6 6 2.0 0.166949 NaN False NaN \n", - "7 7 1.0 0.086972 b False NaN \n", - "\n", - " datetime \n", - "0 2010-01-21 11:39:29.986688000 \n", - "1 2010-02-17 13:51:34.408565760 \n", - "2 2010-02-25 12:07:19.982103552 \n", - "3 2010-01-21 04:04:59.336169472 \n", - "4 NaT \n", - "5 NaT \n", - "6 2010-01-28 23:23:34.873413888 \n", - "7 2010-01-08 09:44:47.101891840 " - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "samples = sdv.sample_all()\n", - "samples['data']" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/2. Quickstart - Single Table - Census.ipynb b/examples/2. Quickstart - Single Table - Census.ipynb deleted file mode 100644 index 6857b7827..000000000 --- a/examples/2. Quickstart - Single Table - Census.ipynb +++ /dev/null @@ -1,562 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import warnings\n", - "warnings.filterwarnings('ignore')\n", - "\n", - "import pandas as pd" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data'\n", - "columns = [\n", - " 'age',\n", - " 'workclass',\n", - " 'fnlwgt',\n", - " 'education',\n", - " 'education-num',\n", - " 'marital-status',\n", - " 'occupation',\n", - " 'relationship',\n", - " 'race',\n", - " 'sex',\n", - " 'capital-gain',\n", - " 'capital-loss',\n", - " 'hours-per-week',\n", - " 'native-country',\n", - " 'income'\n", - "]" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
" - ], - "text/plain": [ - " age workclass fnlwgt education education-num \\\n", - "0 39 State-gov 77516 Bachelors 13 \n", - "1 50 Self-emp-not-inc 83311 Bachelors 13 \n", - "2 38 Private 215646 HS-grad 9 \n", - "3 53 Private 234721 11th 7 \n", - "4 28 Private 338409 Bachelors 13 \n", - "\n", - " marital-status occupation relationship race sex \\\n", - "0 Never-married Adm-clerical Not-in-family White Male \n", - "1 Married-civ-spouse Exec-managerial Husband White Male \n", - "2 Divorced Handlers-cleaners Not-in-family White Male \n", - "3 Married-civ-spouse Handlers-cleaners Husband Black Male \n", - "4 Married-civ-spouse Prof-specialty Wife Black Female \n", - "\n", - " capital-gain capital-loss hours-per-week native-country income \n", - "0 2174 0 40 United-States <=50K \n", - "1 0 0 13 United-States <=50K \n", - "2 0 0 40 United-States <=50K \n", - "3 0 0 40 United-States <=50K \n", - "4 0 0 40 Cuba <=50K " - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df = pd.read_csv(url, names=columns)\n", - "df.head()" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "tables = {\n", - " 'census': df\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "metadata = {\n", - " \"tables\": [\n", - " {\n", - " \"fields\": [\n", - " {\n", - " \"name\": \"age\",\n", - " \"type\": \"numerical\",\n", - " \"subtype\": \"integer\",\n", - " },\n", - " {\n", - " \"name\": \"workclass\",\n", - " \"type\": \"categorical\",\n", - " },\n", - " {\n", - " \"name\": \"fnlwgt\",\n", - " \"type\": \"numerical\",\n", - " \"subtype\": \"integer\",\n", - " },\n", - " {\n", - " \"name\": \"education\",\n", - " \"type\": \"categorical\",\n", - " },\n", - " {\n", - " \"name\": \"education-num\",\n", - " \"type\": \"numerical\",\n", - " \"subtype\": \"integer\",\n", - " },\n", - " {\n", - " \"name\": \"marital-status\",\n", - " \"type\": \"categorical\",\n", - " },\n", - " {\n", - " \"name\": \"occupation\",\n", - " \"type\": \"categorical\",\n", - " },\n", - " {\n", - " \"name\": \"relationship\",\n", - " \"type\": \"categorical\",\n", - " },\n", - " {\n", - " \"name\": \"race\",\n", - " \"type\": \"categorical\",\n", - " },\n", - " {\n", - " \"name\": \"sex\",\n", - " \"type\": \"categorical\",\n", - " },\n", - " {\n", - " \"name\": \"capital-gain\",\n", - " \"type\": \"numerical\",\n", - " \"subtype\": \"integer\",\n", - " },\n", - " {\n", - " \"name\": \"capital-loss\",\n", - " \"type\": \"numerical\",\n", - " \"subtype\": \"integer\",\n", - " },\n", - " {\n", - " \"name\": \"hours-per-week\",\n", - " \"type\": \"numerical\",\n", - " \"subtype\": \"integer\",\n", - " },\n", - " {\n", - " \"name\": \"native-country\",\n", - " \"type\": \"categorical\",\n", - " },\n", - " {\n", - " \"name\": \"income\",\n", - " \"type\": \"categorical\",\n", - " }\n", - " ],\n", - " \"name\": \"census\",\n", - " }\n", - " ]\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2020-07-02 14:25:31,656 - INFO - modeler - Modeling census\n", - "2020-07-02 14:25:31,657 - INFO - metadata - Loading transformer NumericalTransformer for field age\n", - "2020-07-02 14:25:31,657 - INFO - metadata - Loading transformer CategoricalTransformer for field workclass\n", - "2020-07-02 14:25:31,658 - INFO - metadata - Loading transformer 
NumericalTransformer for field fnlwgt\n", - "2020-07-02 14:25:31,658 - INFO - metadata - Loading transformer CategoricalTransformer for field education\n", - "2020-07-02 14:25:31,658 - INFO - metadata - Loading transformer NumericalTransformer for field education-num\n", - "2020-07-02 14:25:31,659 - INFO - metadata - Loading transformer CategoricalTransformer for field marital-status\n", - "2020-07-02 14:25:31,659 - INFO - metadata - Loading transformer CategoricalTransformer for field occupation\n", - "2020-07-02 14:25:31,659 - INFO - metadata - Loading transformer CategoricalTransformer for field relationship\n", - "2020-07-02 14:25:31,660 - INFO - metadata - Loading transformer CategoricalTransformer for field race\n", - "2020-07-02 14:25:31,660 - INFO - metadata - Loading transformer CategoricalTransformer for field sex\n", - "2020-07-02 14:25:31,661 - INFO - metadata - Loading transformer NumericalTransformer for field capital-gain\n", - "2020-07-02 14:25:31,661 - INFO - metadata - Loading transformer NumericalTransformer for field capital-loss\n", - "2020-07-02 14:25:31,662 - INFO - metadata - Loading transformer NumericalTransformer for field hours-per-week\n", - "2020-07-02 14:25:31,662 - INFO - metadata - Loading transformer CategoricalTransformer for field native-country\n", - "2020-07-02 14:25:31,663 - INFO - metadata - Loading transformer CategoricalTransformer for field income\n", - "2020-07-02 14:25:31,831 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:31,928 - INFO - modeler - Modeling Complete\n" - ] - } - ], - "source": [ - "from sdv import SDV\n", - "\n", - "sdv = SDV()\n", - "sdv.fit(metadata, tables)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
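A few cells further down, this notebook scores the synthetic census table with `sdv.evaluation.evaluate`; condensed into a standalone sketch (assuming `sdv` has already been fitted on the `census` table as above):

    from sdv.evaluation import evaluate

    # Draw as many synthetic rows as the real table has, then compare the two.
    # evaluate returns a single float; treat it as a relative score for
    # comparing models rather than an absolute quality measure.
    samples = sdv.sample_all(len(tables['census']))
    score = evaluate(samples, real=tables, metadata=sdv.metadata)
    print(score)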
" - ], - "text/plain": [ - " age workclass fnlwgt education education-num marital-status \\\n", - "0 35 Private 468185 HS-grad 9 Married-civ-spouse \n", - "1 47 Private 104447 HS-grad 7 Married-civ-spouse \n", - "2 20 Private 231391 HS-grad 11 Never-married \n", - "3 35 Private 223275 Masters 12 Married-civ-spouse \n", - "4 26 Private -11408 HS-grad 8 Married-civ-spouse \n", - "\n", - " occupation relationship race sex capital-gain \\\n", - "0 Sales Husband White Male 2039 \n", - "1 Adm-clerical Husband White Male -703 \n", - "2 Exec-managerial Not-in-family White Male 654 \n", - "3 Sales Not-in-family White Male -392 \n", - "4 Machine-op-inspct Husband White Male 5799 \n", - "\n", - " capital-loss hours-per-week native-country income \n", - "0 436 48 United-States <=50K \n", - "1 81 53 United-States <=50K \n", - "2 -8 57 United-States <=50K \n", - "3 -178 71 United-States <=50K \n", - "4 13 47 United-States <=50K " - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "sampled = sdv.sample('census', num_rows=len(df))\n", - "sampled['census'].head()" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "data": { - "text/plain": [ - "-43.151506729008716" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from sdv.evaluation import evaluate\n", - "\n", - "samples = sdv.sample_all(len(tables['census']))\n", - "\n", - "evaluate(samples, real=tables, metadata=sdv.metadata)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/3. Quickstart - Multitable - Files.ipynb b/examples/3. Quickstart - Multitable - Files.ipynb deleted file mode 100644 index d00a5989e..000000000 --- a/examples/3. 
Quickstart - Multitable - Files.ipynb +++ /dev/null @@ -1,653 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2020-07-02 14:25:44,066 - INFO - modeler - Modeling customers\n", - "2020-07-02 14:25:44,067 - INFO - metadata - Loading table customers\n", - "2020-07-02 14:25:44,074 - INFO - metadata - Loading transformer CategoricalTransformer for field cust_postal_code\n", - "2020-07-02 14:25:44,074 - INFO - metadata - Loading transformer NumericalTransformer for field phone_number1\n", - "2020-07-02 14:25:44,075 - INFO - metadata - Loading transformer NumericalTransformer for field credit_limit\n", - "2020-07-02 14:25:44,076 - INFO - metadata - Loading transformer CategoricalTransformer for field country\n", - "2020-07-02 14:25:44,094 - INFO - modeler - Modeling orders\n", - "2020-07-02 14:25:44,095 - INFO - metadata - Loading table orders\n", - "2020-07-02 14:25:44,098 - INFO - metadata - Loading transformer NumericalTransformer for field order_total\n", - "2020-07-02 14:25:44,101 - INFO - modeler - Modeling order_items\n", - "2020-07-02 14:25:44,102 - INFO - metadata - Loading table order_items\n", - "2020-07-02 14:25:44,106 - INFO - metadata - Loading transformer CategoricalTransformer for field product_id\n", - "2020-07-02 14:25:44,107 - INFO - metadata - Loading transformer NumericalTransformer for field unit_price\n", - "2020-07-02 14:25:44,108 - INFO - metadata - Loading transformer NumericalTransformer for field quantity\n", - "2020-07-02 14:25:44,120 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:44,131 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:44,138 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:44,147 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:44,155 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:44,164 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:44,171 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:44,177 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:44,183 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:44,189 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:44,196 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:44,210 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:44,231 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:44,258 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:44,281 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:44,303 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:44,370 - INFO - gaussian - Fitting 
GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:25:44,496 - INFO - modeler - Modeling Complete\n" - ] - } - ], - "source": [ - "from sdv import SDV\n", - "\n", - "sdv = SDV()\n", - "sdv.fit('quickstart/metadata.json')" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2020-07-02 14:25:46,641 - INFO - metadata - Loading table customers\n", - "2020-07-02 14:25:46,646 - INFO - metadata - Loading table orders\n", - "2020-07-02 14:25:46,649 - INFO - metadata - Loading table order_items\n" - ] - } - ], - "source": [ - "real = sdv.metadata.load_tables()\n", - "\n", - "samples = sdv.sample_all(len(real['customers']), reset_primary_keys=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/html": [ - "
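Unlike the in-memory examples, this notebook fits straight from a metadata JSON path, which lets SDV load both the metadata and the CSV files it references; a condensed sketch of that flow, reusing the notebook's own calls:

    from sdv import SDV

    sdv = SDV()
    # Fitting from a path loads the metadata plus the tables it points at.
    sdv.fit('quickstart/metadata.json')

    real = sdv.metadata.load_tables()
    # reset_primary_keys=True restarts the synthetic id counters, which is why
    # the sampled customer_id values in this notebook begin again at 0.
    samples = sdv.sample_all(len(real['customers']), reset_primary_keys=True)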
" - ], - "text/plain": [ - " customer_id cust_postal_code phone_number1 credit_limit country\n", - "0 0 63145 7811758514 510 US\n", - "1 1 20166 6720163478 317 SPAIN\n", - "2 2 11371 7447861719 1371 SPAIN\n", - "3 3 63145 6835026704 1024 US\n", - "4 4 11371 5371158268 739 FRANCE" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "samples['customers'].head()" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
" - ], - "text/plain": [ - " customer_id cust_postal_code phone_number1 credit_limit country\n", - "0 50 11371 6175553295 1000 UK\n", - "1 4 63145 8605551835 500 US\n", - "2 97338810 6096 7035552143 1000 CANADA\n", - "3 630407 63145 7035552143 2000 US\n", - "4 826362 11371 6175553295 1000 UK" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "real['customers'].head()" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
" - ], - "text/plain": [ - " order_id customer_id order_total\n", - "0 0 0 1784\n", - "1 1 1 1864\n", - "2 2 4 2677\n", - "3 3 1 1485\n", - "4 4 1 2435" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "samples['orders'].head()" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
" - ], - "text/plain": [ - " order_id customer_id order_total\n", - "0 1 50 2310\n", - "1 2 4 1507\n", - "2 10 97338810 730\n", - "3 6 55996144 730\n", - "4 3 55996144 939" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "real['orders'].head()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
" - ], - "text/plain": [ - " order_item_id order_id product_id unit_price quantity\n", - "0 0 2 6 208 7\n", - "1 1 7 10 116 2\n", - "2 2 0 10 57 6\n", - "3 3 3 6 -24 1\n", - "4 4 3 6 38 4" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "samples['order_items'].head()" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
" - ], - "text/plain": [ - " order_item_id order_id product_id unit_price quantity\n", - "0 100 10 7 52 8\n", - "1 101 8 6 125 4\n", - "2 102 1 6 125 4\n", - "3 103 4 9 125 4\n", - "4 104 1 9 113 4" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "real['order_items'].head()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/4. Anonymization.ipynb b/examples/4. Anonymization.ipynb deleted file mode 100644 index 5ff4caa51..000000000 --- a/examples/4. Anonymization.ipynb +++ /dev/null @@ -1,398 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import pandas as pd\n", - "\n", - "data = pd.DataFrame([\n", - " {\n", - " 'index': 1,\n", - " 'name': 'Bill',\n", - " 'credit_card_number': '1111222233334444'\n", - " },\n", - " {\n", - " 'index': 2,\n", - " 'name': 'Jeff',\n", - " 'credit_card_number': '0000000000000000'\n", - " },\n", - " {\n", - " 'index': 3,\n", - " 'name': 'Bill',\n", - " 'credit_card_number': '9999999999999999'\n", - " },\n", - " {\n", - " 'index': 4,\n", - " 'name': 'Joe',\n", - " 'credit_card_number': '8888888888888888'\n", - " },\n", - "])" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
" - ], - "text/plain": [ - " credit_card_number index name\n", - "0 1111222233334444 1 Bill\n", - "1 0000000000000000 2 Jeff\n", - "2 9999999999999999 3 Bill\n", - "3 8888888888888888 4 Joe" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "data" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "metadata = {\n", - " \"tables\": [\n", - " {\n", - " \"fields\": [\n", - " {\n", - " \"name\": \"index\",\n", - " \"type\": \"id\"\n", - " },\n", - " {\n", - " \"name\": \"name\",\n", - " \"type\": \"categorical\",\n", - " \"pii\": True,\n", - " \"pii_category\": \"first_name\"\n", - " },\n", - " {\n", - " \"name\": \"credit_card_number\",\n", - " \"type\": \"categorical\",\n", - " \"pii\": True,\n", - " \"pii_category\": [\n", - " \"credit_card_number\",\n", - " \"visa\"\n", - " ]\n", - " }\n", - " ],\n", - " \"name\": \"anonymized\",\n", - " \"primary_key\": \"index\",\n", - " },\n", - " {\n", - " \"fields\": [\n", - " {\n", - " \"name\": \"index\",\n", - " \"type\": \"id\"\n", - " },\n", - " {\n", - " \"name\": \"name\",\n", - " \"type\": \"categorical\"\n", - " },\n", - " {\n", - " \"name\": \"credit_card_number\",\n", - " \"type\": \"categorical\"\n", - " }\n", - " ],\n", - " \"name\": \"normal\",\n", - " \"primary_key\": \"index\",\n", - " }\n", - " ]\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "tables = {\n", - " 'anonymized': data,\n", - " 'normal': data.copy()\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2020-07-02 14:26:26,044 - INFO - modeler - Modeling anonymized\n", - "2020-07-02 14:26:26,044 - INFO - metadata - Loading transformer CategoricalTransformer for field name\n", - "2020-07-02 14:26:26,045 - INFO - metadata - Loading transformer CategoricalTransformer for field credit_card_number\n", - "2020-07-02 14:26:26,087 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:26:26,092 - INFO - modeler - Modeling normal\n", - "2020-07-02 14:26:26,092 - INFO - metadata - Loading transformer CategoricalTransformer for field name\n", - "2020-07-02 14:26:26,093 - INFO - metadata - Loading transformer CategoricalTransformer for field credit_card_number\n", - "2020-07-02 14:26:26,109 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", - "2020-07-02 14:26:26,113 - INFO - modeler - Modeling Complete\n" - ] - } - ], - "source": [ - "from sdv import SDV\n", - "\n", - "sdv = SDV()\n", - "sdv.fit(metadata, tables)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "scrolled": false - }, - "outputs": [], - "source": [ - "sampled = sdv.sample_all()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "data": { - "text/html": [ - "
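The anonymization in this notebook is driven entirely by the `pii` flags in its metadata: the `anonymized` table gets generated stand-in values while the structurally identical `normal` table reuses the real ones. A sketch of the two relevant field definitions, with the `pii_category` values appearing to map to Faker provider names (an assumption, not confirmed by this diff):

    # Anonymized field: sampled values are replaced with generated first names.
    name_field = {
        'name': 'name',
        'type': 'categorical',
        'pii': True,
        'pii_category': 'first_name',
    }

    # A list form selects a provider plus arguments; here, visa-style numbers.
    ccn_field = {
        'name': 'credit_card_number',
        'type': 'categorical',
        'pii': True,
        'pii_category': ['credit_card_number', 'visa'],
    }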
" - ], - "text/plain": [ - " index name credit_card_number\n", - "0 0 Pamela 4801288395665668\n", - "1 1 Kimberly 4801288395665668\n", - "2 2 Pamela 4801288395665668\n", - "3 3 Kimberly 4592405566223480" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "sampled['anonymized']" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "data": { - "text/html": [ - "
" - ], - "text/plain": [ - " index name credit_card_number\n", - "0 0 Joe 8888888888888888\n", - "1 1 Joe 0000000000000000\n", - "2 2 Joe 8888888888888888\n", - "3 3 Bill 8888888888888888" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "sampled['normal']" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/6. Metadata Validation.ipynb b/examples/6. Metadata Validation.ipynb deleted file mode 100644 index 70ab58913..000000000 --- a/examples/6. Metadata Validation.ipynb +++ /dev/null @@ -1,239 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "from sdv import load_demo" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "metadata, tables = load_demo(metadata=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'tables': {'users': {'primary_key': 'user_id',\n", - " 'fields': {'user_id': {'type': 'id', 'subtype': 'integer'},\n", - " 'country': {'type': 'categorical'},\n", - " 'gender': {'type': 'categorical'},\n", - " 'age': {'type': 'numerical', 'subtype': 'integer'}}},\n", - " 'sessions': {'primary_key': 'session_id',\n", - " 'fields': {'session_id': {'type': 'id', 'subtype': 'integer'},\n", - " 'user_id': {'ref': {'field': 'user_id', 'table': 'users'},\n", - " 'type': 'id',\n", - " 'subtype': 'integer'},\n", - " 'device': {'type': 'categorical'},\n", - " 'os': {'type': 'categorical'}}},\n", - " 'transactions': {'primary_key': 'transaction_id',\n", - " 'fields': {'transaction_id': {'type': 'id', 'subtype': 'integer'},\n", - " 'session_id': {'ref': {'field': 'session_id', 'table': 'sessions'},\n", - " 'type': 'id',\n", - " 'subtype': 'integer'},\n", - " 'timestamp': {'type': 'datetime', 'format': '%Y-%m-%d'},\n", - " 'amount': {'type': 'numerical', 'subtype': 'float'},\n", - " 'approved': {'type': 'boolean'}}}}}" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "metadata.to_dict()" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'users': user_id country gender age\n", - " 0 0 USA M 34\n", - " 1 1 UK F 23\n", - " 2 2 ES None 44\n", - " 3 3 UK M 22\n", - " 4 4 USA F 54\n", - " 5 5 DE M 57\n", - " 6 6 BG F 45\n", - " 7 7 ES None 41\n", - " 8 8 FR F 23\n", - " 9 9 UK None 30,\n", - " 'sessions': session_id user_id device os\n", - " 0 0 0 mobile android\n", - " 1 1 1 tablet ios\n", - " 2 2 1 tablet android\n", - " 3 3 2 mobile android\n", - " 4 4 4 mobile ios\n", - " 5 5 5 mobile android\n", - " 6 6 6 mobile ios\n", - " 7 7 6 tablet ios\n", - " 8 8 6 mobile ios\n", - " 9 9 8 tablet ios,\n", - " 'transactions': transaction_id session_id timestamp amount approved\n", - " 0 0 0 2019-01-01 12:34:32 100.0 True\n", - " 1 1 0 2019-01-01 12:42:21 55.3 True\n", - " 2 2 1 2019-01-07 17:23:11 79.5 True\n", - " 3 3 3 2019-01-10 11:08:57 112.1 False\n", - " 4 4 5 2019-01-10 21:54:08 110.0 False\n", - 
" 5 5 5 2019-01-11 11:21:20 76.3 True\n", - " 6 6 7 2019-01-22 14:44:10 89.5 True\n", - " 7 7 8 2019-01-23 10:14:09 132.1 False\n", - " 8 8 9 2019-01-27 16:09:17 68.0 True\n", - " 9 9 9 2019-01-29 12:10:48 99.9 True}" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tables" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "metadata.validate()" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "metadata.validate(tables)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "metadata._metadata['tables']['users']['primary_key'] = 'country'" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Error: id field `user_id` is neither a primary or a foreign key\n" - ] - } - ], - "source": [ - "from sdv.metadata import MetadataError\n", - "\n", - "try:\n", - " metadata.validate()\n", - "except MetadataError as me:\n", - " print('Error: {}'.format(me))" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "metadata._metadata['tables']['users']['primary_key'] = 'user_id'" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "metadata.validate()" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "metadata._metadata['tables']['users']['fields']['gender']['type'] = 'numerical'" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "metadata.validate()" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Error: Invalid values found in column `gender` of table `users`: `could not convert string to float: 'M'`\n" - ] - } - ], - "source": [ - "try:\n", - " metadata.validate(tables)\n", - "except MetadataError as me:\n", - " print('Error: {}'.format(me))" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/airbnb-recruiting-new-user-bookings/Aibnb Dataset Subsampling.ipynb b/examples/airbnb-recruiting-new-user-bookings/Aibnb Dataset Subsampling.ipynb deleted file mode 100644 index d4b39df8b..000000000 --- a/examples/airbnb-recruiting-new-user-bookings/Aibnb Dataset Subsampling.ipynb +++ /dev/null @@ -1,206 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Airbnb Dataset Subsampling\n", - "\n", - "This notebooks shows how to generate the `users_sample.csv` and `sessions_sample.csv` files.\n", - "Before running this notebook, make sure to having downloaded the `train_users_2.csv` and `sessions.csv`\n", - "files from https://www.kaggle.com/c/airbnb-recruiting-new-user-bookings/data" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import 
pandas as pd" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "users = pd.read_csv('train_users_2.csv')\n", - "sessions = pd.read_csv('sessions.csv')" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(213451, 16)" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "users.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(10567737, 6)" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "sessions.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "users_sample = users.sample(1000).reset_index(drop=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "sessions_sample = sessions[\n", - " sessions['user_id'].isin(users_sample['id'])\n", - "].reset_index(drop=True).copy()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(24988, 6)" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "sessions_sample.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "user_ids = users_sample['id'].reset_index().set_index('id')['index']" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "id\n", - "c2cga16ny5 0\n", - "08yitv7nz0 1\n", - "3oobewtnls 2\n", - "owd2csvj6u 3\n", - "i66dlbdyrt 4\n", - "Name: index, dtype: int64" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "user_ids.head()" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "users_sample['id'] = users_sample['id'].map(user_ids)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "sessions_sample['user_id'] = sessions_sample['user_id'].map(user_ids)" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "users_sample.to_csv('users_sample.csv', index=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "sessions_sample.to_csv('sessions_sample.csv', index=False)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.8" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/airbnb-recruiting-new-user-bookings/Aibnb Example.ipynb b/examples/airbnb-recruiting-new-user-bookings/Aibnb Example.ipynb deleted file mode 100644 index e4b43f916..000000000 --- a/examples/airbnb-recruiting-new-user-bookings/Aibnb Example.ipynb +++ /dev/null @@ -1,730 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Airbnb Dataset Synthesis\n", - "\n", - 
"This notebook shows a usage example over a sample of the Airbnb New User Bookings dataset\n", - "available on Kaggle: https://www.kaggle.com/c/airbnb-recruiting-new-user-bookings\n", - "\n", - "Before running this notebook, make sure to having downloaded the data and run the\n", - "`Airbnb Dataset Sampling` notebook which can be found in the same folder as this one." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2019-11-03 16:07:17,764 - INFO - modeler - Modeling users\n", - "2019-11-03 16:07:17,764 - INFO - metadata - Loading table users\n", - "2019-11-03 16:07:17,781 - INFO - metadata - Loading transformer DatetimeTransformer for field date_account_created\n", - "2019-11-03 16:07:17,782 - INFO - metadata - Loading transformer DatetimeTransformer for field timestamp_first_active\n", - "2019-11-03 16:07:17,782 - INFO - metadata - Loading transformer DatetimeTransformer for field date_first_booking\n", - "2019-11-03 16:07:17,783 - INFO - metadata - Loading transformer CategoricalTransformer for field gender\n", - "2019-11-03 16:07:17,783 - INFO - metadata - Loading transformer NumericalTransformer for field age\n", - "2019-11-03 16:07:17,783 - INFO - metadata - Loading transformer CategoricalTransformer for field signup_method\n", - "2019-11-03 16:07:17,783 - INFO - metadata - Loading transformer CategoricalTransformer for field signup_flow\n", - "2019-11-03 16:07:17,784 - INFO - metadata - Loading transformer CategoricalTransformer for field language\n", - "2019-11-03 16:07:17,784 - INFO - metadata - Loading transformer CategoricalTransformer for field affiliate_channel\n", - "2019-11-03 16:07:17,784 - INFO - metadata - Loading transformer CategoricalTransformer for field affiliate_provider\n", - "2019-11-03 16:07:17,784 - INFO - metadata - Loading transformer CategoricalTransformer for field first_affiliate_tracked\n", - "2019-11-03 16:07:17,785 - INFO - metadata - Loading transformer CategoricalTransformer for field signup_app\n", - "2019-11-03 16:07:17,785 - INFO - metadata - Loading transformer CategoricalTransformer for field first_device_type\n", - "2019-11-03 16:07:17,785 - INFO - metadata - Loading transformer CategoricalTransformer for field first_browser\n", - "2019-11-03 16:07:17,785 - INFO - metadata - Loading transformer CategoricalTransformer for field country_destination\n", - "2019-11-03 16:07:17,852 - INFO - modeler - Modeling sessions\n", - "2019-11-03 16:07:17,853 - INFO - metadata - Loading table sessions\n", - "2019-11-03 16:07:17,872 - INFO - metadata - Loading transformer CategoricalTransformer for field action\n", - "2019-11-03 16:07:17,873 - INFO - metadata - Loading transformer CategoricalTransformer for field action_type\n", - "2019-11-03 16:07:17,873 - INFO - metadata - Loading transformer CategoricalTransformer for field action_detail\n", - "2019-11-03 16:07:17,873 - INFO - metadata - Loading transformer CategoricalTransformer for field device_type\n", - "2019-11-03 16:07:17,874 - INFO - metadata - Loading transformer NumericalTransformer for field secs_elapsed\n", - "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/pandas/core/frame.py:7143: RuntimeWarning: Degrees of freedom <= 0 for slice\n", - " baseCov = np.cov(mat.T)\n", - "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/lib/function_base.py:2451: RuntimeWarning: divide by zero encountered in true_divide\n", - " c *= np.true_divide(1, fact)\n", - 
"/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/lib/function_base.py:2451: RuntimeWarning: invalid value encountered in multiply\n", - " c *= np.true_divide(1, fact)\n", - "2019-11-03 16:07:22,394 - INFO - modeler - Modeling Complete\n" - ] - } - ], - "source": [ - "from sdv import SDV\n", - "\n", - "sdv = SDV()\n", - "sdv.fit('metadata.json')" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2019-11-03 16:07:22,443 - INFO - metadata - Loading table users\n", - "2019-11-03 16:07:22,454 - INFO - metadata - Loading table sessions\n" - ] - } - ], - "source": [ - "real = sdv.metadata.get_tables()" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
iddate_account_createdtimestamp_first_activedate_first_bookinggenderagesignup_methodsignup_flowlanguageaffiliate_channelaffiliate_providerfirst_affiliate_trackedsignup_appfirst_device_typefirst_browsercountry_destination
7907902011-04-302011-04-30 17:58:152012-03-22FEMALE36.0basic2ensem-non-brandgoogleuntrackedWebMac DesktopFirefoxUS
5175172011-09-202011-09-20 08:14:062011-09-22-unknown-NaNbasic2ensem-brandgoogleomgWebMac DesktopChromeUS
9029022014-05-172014-05-17 01:03:00NaTMALE48.0basic24endirectdirectuntrackedMowebiPhoneMobile SafariNDF
70702013-10-212013-10-21 21:50:15NaT-unknown-NaNbasic0endirectdirectlinkedWebMac DesktopFirefoxNDF
6086082013-04-012013-04-01 20:05:132013-04-02-unknown-NaNbasic24endirectdirectlinkedMowebMac DesktopSafariother
\n", - "
" - ], - "text/plain": [ - " id date_account_created timestamp_first_active date_first_booking \\\n", - "790 790 2011-04-30 2011-04-30 17:58:15 2012-03-22 \n", - "517 517 2011-09-20 2011-09-20 08:14:06 2011-09-22 \n", - "902 902 2014-05-17 2014-05-17 01:03:00 NaT \n", - "70 70 2013-10-21 2013-10-21 21:50:15 NaT \n", - "608 608 2013-04-01 2013-04-01 20:05:13 2013-04-02 \n", - "\n", - " gender age signup_method signup_flow language affiliate_channel \\\n", - "790 FEMALE 36.0 basic 2 en sem-non-brand \n", - "517 -unknown- NaN basic 2 en sem-brand \n", - "902 MALE 48.0 basic 24 en direct \n", - "70 -unknown- NaN basic 0 en direct \n", - "608 -unknown- NaN basic 24 en direct \n", - "\n", - " affiliate_provider first_affiliate_tracked signup_app first_device_type \\\n", - "790 google untracked Web Mac Desktop \n", - "517 google omg Web Mac Desktop \n", - "902 direct untracked Moweb iPhone \n", - "70 direct linked Web Mac Desktop \n", - "608 direct linked Moweb Mac Desktop \n", - "\n", - " first_browser country_destination \n", - "790 Firefox US \n", - "517 Chrome US \n", - "902 Mobile Safari NDF \n", - "70 Firefox NDF \n", - "608 Safari other " - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "real['users'].sample(5)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_idactionaction_typeaction_detaildevice_typesecs_elapsed
6345101search_resultsclickview_search_resultsMac Desktop1592.0
541772lookupNaNNaNMac Desktop651.0
22406823showviewp3Android Phone1198.0
6279101search_resultsclickview_search_resultsMac Desktop3389.0
6561883ajax_refresh_subtotalclickchange_trip_characteristicsMac Desktop268.0
\n", - "
" - ], - "text/plain": [ - " user_id action action_type \\\n", - "6345 101 search_results click \n", - "5417 72 lookup NaN \n", - "22406 823 show view \n", - "6279 101 search_results click \n", - "6561 883 ajax_refresh_subtotal click \n", - "\n", - " action_detail device_type secs_elapsed \n", - "6345 view_search_results Mac Desktop 1592.0 \n", - "5417 NaN Mac Desktop 651.0 \n", - "22406 p3 Android Phone 1198.0 \n", - "6279 view_search_results Mac Desktop 3389.0 \n", - "6561 change_trip_characteristics Mac Desktop 268.0 " - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "real['sessions'].sample(5)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "length = len(real['users'])" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "1000" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "length" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "samples = sdv.sample_all(length, reset_primary_keys=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
iddate_account_createdtimestamp_first_activedate_first_bookinggenderagesignup_methodsignup_flowlanguageaffiliate_channelaffiliate_providerfirst_affiliate_trackedsignup_appfirst_device_typefirst_browsercountry_destination
8578572014-03-16 21:04:00.3134963202014-03-15 05:39:43.249984000NaT-unknown-45.0facebook0endirectdirectuntrackedWebiPhoneIENDF
4894892014-10-03 17:46:33.6981690882014-10-14 06:11:35.937526784NaTFEMALE56.0basic0endirectdirectuntrackedWebWindows Desktop-unknown-NDF
7807802014-03-02 14:25:57.2513733122014-03-05 14:28:40.767416832NaT-unknown-NaNbasic0endirectdirectlinkedWebWindows DesktopSafariNDF
5965962014-01-12 05:40:57.1618867202014-01-18 03:57:01.785481984NaTMALE51.0basic0ensem-brandgooglelinkedWebMac DesktopSafariNDF
7767762013-05-28 14:15:41.3111449602013-05-28 06:13:45.2424908802013-08-07 07:45:31.016155904FEMALE46.0basic0endirectdirectuntrackedWebMac DesktopChromeNDF
\n", - "
" - ], - "text/plain": [ - " id date_account_created timestamp_first_active \\\n", - "857 857 2014-03-16 21:04:00.313496320 2014-03-15 05:39:43.249984000 \n", - "489 489 2014-10-03 17:46:33.698169088 2014-10-14 06:11:35.937526784 \n", - "780 780 2014-03-02 14:25:57.251373312 2014-03-05 14:28:40.767416832 \n", - "596 596 2014-01-12 05:40:57.161886720 2014-01-18 03:57:01.785481984 \n", - "776 776 2013-05-28 14:15:41.311144960 2013-05-28 06:13:45.242490880 \n", - "\n", - " date_first_booking gender age signup_method signup_flow \\\n", - "857 NaT -unknown- 45.0 facebook 0 \n", - "489 NaT FEMALE 56.0 basic 0 \n", - "780 NaT -unknown- NaN basic 0 \n", - "596 NaT MALE 51.0 basic 0 \n", - "776 2013-08-07 07:45:31.016155904 FEMALE 46.0 basic 0 \n", - "\n", - " language affiliate_channel affiliate_provider first_affiliate_tracked \\\n", - "857 en direct direct untracked \n", - "489 en direct direct untracked \n", - "780 en direct direct linked \n", - "596 en sem-brand google linked \n", - "776 en direct direct untracked \n", - "\n", - " signup_app first_device_type first_browser country_destination \n", - "857 Web iPhone IE NDF \n", - "489 Web Windows Desktop -unknown- NDF \n", - "780 Web Windows Desktop Safari NDF \n", - "596 Web Mac Desktop Safari NDF \n", - "776 Web Mac Desktop Chrome NDF " - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "samples['users'].sample(5)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_idactionaction_typeaction_detaildevice_typesecs_elapsed
7985222ajax_refresh_subtotalNaNlisting_reviewsWindows Desktop6730.0
23883629ajax_refresh_subtotaldata-unknown-Windows Desktop4842.0
15908418termsclickview_search_resultsMac Desktop111772.0
8900250personalizedataNaNMac Desktop20486.0
26026691search-unknown--unknown-Mac Desktop149526.0
\n", - "
" - ], - "text/plain": [ - " user_id action action_type action_detail \\\n", - "7985 222 ajax_refresh_subtotal NaN listing_reviews \n", - "23883 629 ajax_refresh_subtotal data -unknown- \n", - "15908 418 terms click view_search_results \n", - "8900 250 personalize data NaN \n", - "26026 691 search -unknown- -unknown- \n", - "\n", - " device_type secs_elapsed \n", - "7985 Windows Desktop 6730.0 \n", - "23883 Windows Desktop 4842.0 \n", - "15908 Mac Desktop 111772.0 \n", - "8900 Mac Desktop 20486.0 \n", - "26026 Mac Desktop 149526.0 " - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "samples['sessions'].sample(5)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.8" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/airbnb-recruiting-new-user-bookings/metadata.json b/examples/airbnb-recruiting-new-user-bookings/metadata.json deleted file mode 100644 index 4c03506bf..000000000 --- a/examples/airbnb-recruiting-new-user-bookings/metadata.json +++ /dev/null @@ -1,116 +0,0 @@ -{ - "tables": [ - { - "headers": true, - "name": "users", - "path": "users_sample.csv", - "primary_key": "id", - "fields": [ - { - "name": "id", - "type": "id" - }, - { - "name": "date_account_created", - "type": "datetime", - "format": "%Y-%m-%d" - }, - { - "name": "timestamp_first_active", - "type": "datetime", - "format": "%Y%m%d%H%M%S" - }, - { - "name": "date_first_booking", - "type": "datetime", - "format": "%Y-%m-%d" - }, - { - "name": "gender", - "type": "categorical" - }, - { - "name": "age", - "type": "numerical", - "subtype": "integer" - }, - { - "name": "signup_method", - "type": "categorical" - }, - { - "name": "signup_flow", - "type": "categorical" - }, - { - "name": "language", - "type": "categorical" - }, - { - "name": "affiliate_channel", - "type": "categorical" - }, - { - "name": "affiliate_provider", - "type": "categorical" - }, - { - "name": "first_affiliate_tracked", - "type": "categorical" - }, - { - "name": "signup_app", - "type": "categorical" - }, - { - "name": "first_device_type", - "type": "categorical" - }, - { - "name": "first_browser", - "type": "categorical" - }, - { - "name": "country_destination", - "type": "categorical" - } - ] - }, - { - "headers": true, - "name": "sessions", - "path": "sessions_sample.csv", - "fields": [ - { - "name": "user_id", - "ref": { - "field": "id", - "table": "users" - }, - "type": "id" - }, - { - "name": "action", - "type": "categorical" - }, - { - "name": "action_type", - "type": "categorical" - }, - { - "name": "action_detail", - "type": "categorical" - }, - { - "name": "device_type", - "type": "categorical" - }, - { - "name": "secs_elapsed", - "type": "numerical", - "subtype": "integer" - } - ] - } - ] -} diff --git a/examples/quickstart/customers.csv b/examples/quickstart/customers.csv deleted file mode 100644 index 14922885f..000000000 --- a/examples/quickstart/customers.csv +++ /dev/null @@ -1,8 +0,0 @@ -customer_id,cust_postal_code,phone_number1,credit_limit,country -50,11371,6175553295,1000,UK -4,63145,8605551835,500,US -97338810,6096,7035552143,1000,CANADA -630407,63145,7035552143,2000,US -826362,11371,6175553295,1000,UK 
-55996144,20166,4045553285,1000,FRANCE -598112,63145,6175553295,1000,SPAIN diff --git a/examples/quickstart/metadata.json b/examples/quickstart/metadata.json deleted file mode 100644 index 119d30c76..000000000 --- a/examples/quickstart/metadata.json +++ /dev/null @@ -1,96 +0,0 @@ -{ - "tables": [ - { - "fields": [ - { - "name": "customer_id", - "type": "id" - }, - { - "name": "cust_postal_code", - "type": "categorical" - }, - { - "name": "phone_number1", - "type": "numerical", - "subtype": "integer" - }, - { - "name": "credit_limit", - "type": "numerical", - "subtype": "integer" - }, - { - "name": "country", - "type": "categorical" - } - ], - "headers": true, - "name": "customers", - "path": "customers.csv", - "primary_key": "customer_id", - "use": true - }, - { - "fields": [ - { - "name": "order_id", - "type": "id" - }, - { - "name": "customer_id", - "ref": { - "field": "customer_id", - "table": "customers" - }, - "type": "id" - }, - { - "name": "order_total", - "type": "numerical", - "subtype": "integer" - } - ], - "headers": true, - "name": "orders", - "path": "orders.csv", - "primary_key": "order_id", - "use": true - }, - { - "fields": [ - { - "name": "order_item_id", - "type": "id" - }, - { - "name": "order_id", - "ref": { - "field": "order_id", - "table": "orders" - }, - "type": "id" - }, - { - "name": "product_id", - "type": "categorical" - }, - { - "name": "unit_price", - "type": "numerical", - "subtype": "integer" - }, - { - "name": "quantity", - "type": "numerical", - "subtype": "integer" - } - ], - "headers": true, - "name": "order_items", - "path": "order_items.csv", - "primary_key": "order_item_id", - "use": true - } - ] -} diff --git a/examples/quickstart/order_items.csv b/examples/quickstart/order_items.csv deleted file mode 100644 index 7f47c3d69..000000000 --- a/examples/quickstart/order_items.csv +++ /dev/null @@ -1,50 +0,0 @@ -order_item_id,order_id,product_id,unit_price,quantity -100,10,7,52,8 -101,8,6,125,4 -102,1,6,125,4 -103,4,9,125,4 -104,1,9,113,4 -105,9,10,87,2 -106,10,6,39,4 -107,1,6,50,4 -108,2,3,31,2 -109,4,6,37,3 -110,10,10,50,4 -111,1,10,50,2 -112,6,14,50,1 -113,10,1,161,5 -114,2,2,73,4 -115,1,6,50,2 -116,5,6,63,2 -117,1,6,115,4 -118,4,10,50,4 -119,5,9,50,2 -120,10,15,37,0 -121,7,6,113,7 -122,1,6,125,2 -123,10,6,113,7 -124,8,10,125,4 -125,5,6,125,2 -126,1,6,125,2 -127,8,10,69,2 -128,8,10,50,2 -129,6,8,152,3 -130,5,2,103,4 -131,3,11,30,-1 -132,9,6,97,2 -133,7,1,165,5 -134,8,6,125,2 -135,8,6,50,2 -136,1,9,125,4 -137,6,8,152,3 -138,1,10,50,2 -139,3,6,37,3 -140,6,15,111,1 -141,9,10,125,2 -142,10,2,147,4 -143,3,6,58,4 -144,1,6,102,4 -145,9,10,50,4 -146,6,15,50,1 -147,7,3,30,-1 -148,1,6,125,4 diff --git a/examples/quickstart/orders.csv b/examples/quickstart/orders.csv deleted file mode 100644 index 08371dbb8..000000000 --- a/examples/quickstart/orders.csv +++ /dev/null @@ -1,11 +0,0 @@ -order_id,customer_id,order_total -1,50,2310 -2,4,1507 -10,97338810,730 -6,55996144,730 -3,55996144,939 -4,50,2380 -5,97338810,1570 -7,50,730 -8,97338810,2336 -9,4,743 diff --git a/sdv/metadata.py b/sdv/metadata/__init__.py similarity index 90% rename from sdv/metadata.py rename to sdv/metadata/__init__.py index f19ed985c..d9b3ad368 100644 --- a/sdv/metadata.py +++ b/sdv/metadata/__init__.py @@ -4,11 +4,21 @@ import os from collections import defaultdict -import graphviz import numpy as np import pandas as pd from rdt import HyperTransformer, transformers +from sdv.metadata import visualization +from sdv.metadata.errors import MetadataError +from sdv.metadata.table import Table + 
+__all__ = [
+    'Metadata',
+    'MetadataError',
+    'Table',
+    'visualization'
+]
+
 LOGGER = logging.getLogger(__name__)
@@ -51,10 +61,6 @@ def _load_csv(root_path, table_meta):
     return data
 
 
-class MetadataError(Exception):
-    pass
-
-
 class Metadata:
     """Dataset Metadata.
 
@@ -965,119 +971,7 @@ def to_json(self, path):
         with open(path, 'w') as out_file:
             json.dump(self._metadata, out_file, indent=4)
 
-    @staticmethod
-    def _get_graphviz_extension(path):
-        if path:
-            path_splitted = path.split('.')
-            if len(path_splitted) == 1:
-                raise ValueError('Path without graphviz extansion.')
-
-            graphviz_extension = path_splitted[-1]
-
-            if graphviz_extension not in graphviz.backend.FORMATS:
-                raise ValueError(
-                    '"{}" not a valid graphviz extension format.'.format(graphviz_extension)
-                )
-
-            return '.'.join(path_splitted[:-1]), graphviz_extension
-
-        return None, None
-
-    def _visualize_add_nodes(self, plot):
-        """Add nodes into a `graphviz.Digraph`.
-
-        Each node represent a metadata table.
-
-        Args:
-            plot (graphviz.Digraph)
-        """
-        for table in self.get_tables():
-            # Append table fields
-            fields = []
-
-            for name, value in self.get_fields(table).items():
-                if value.get('subtype') is not None:
-                    fields.append('{} : {} - {}'.format(name, value['type'], value['subtype']))
-
-                else:
-                    fields.append('{} : {}'.format(name, value['type']))
-
-            fields = r'\l'.join(fields)
-
-            # Append table extra information
-            extras = []
-
-            primary_key = self.get_primary_key(table)
-            if primary_key is not None:
-                extras.append('Primary key: {}'.format(primary_key))
-
-            parents = self.get_parents(table)
-            for parent in parents:
-                foreign_key = self.get_foreign_key(parent, table)
-                extras.append('Foreign key ({}): {}'.format(parent, foreign_key))
-
-            path = self.get_table_meta(table).get('path')
-            if path is not None:
-                extras.append('Data path: {}'.format(path))
-
-            extras = r'\l'.join(extras)
-
-            # Add table node
-            title = r'{%s|%s\l|%s\l}' % (table, fields, extras)
-            plot.node(table, label=title)
-
-    def _visualize_add_edges(self, plot):
-        """Add edges into a `graphviz.Digraph`.
-
-        Each edge represents a relationship between two metadata tables.
-
-        Args:
-            plot (graphviz.Digraph)
-        """
-        for table in self.get_tables():
-            for parent in list(self.get_parents(table)):
-                plot.edge(
-                    parent,
-                    table,
-                    label='   {}.{} -> {}.{}'.format(
-                        table, self.get_foreign_key(parent, table),
-                        parent, self.get_primary_key(parent)
-                    ),
-                    arrowhead='crow'
-                )
-
-    def visualize(self, path=None):
-        """Plot metadata usign graphviz.
-
-        Try to generate a plot using graphviz.
-        If a ``path`` is provided save the output into a file.
-
-        Args:
-            path (str):
-                Output file path to save the plot, it requires a graphviz
-                supported extension. If ``None`` do not save the plot.
-                Defaults to ``None``.
-        """
-        filename, graphviz_extension = self._get_graphviz_extension(path)
-        plot = graphviz.Digraph(
-            'Metadata',
-            format=graphviz_extension,
-            node_attr={
-                "shape": "Mrecord",
-                "fillcolor": "lightgoldenrod1",
-                "style": "filled"
-            },
-        )
-
-        self._visualize_add_nodes(plot)
-        self._visualize_add_edges(plot)
-
-        if filename:
-            plot.render(filename=filename, cleanup=True, format=graphviz_extension)
-        else:
-            return plot
-
-    def __str__(self):
+    def __repr__(self):
         tables = self.get_tables()
         relationships = [
             '    {}.{} -> {}.{}'.format(
@@ -1099,3 +993,18 @@ def __str__(self):
             tables,
             '\n'.join(relationships)
         )
+
+    def visualize(self, path=None):
+        """Plot metadata using graphviz.
+
+        Generate a plot using graphviz.
+        If a ``path`` is provided, the output is saved into a file.
+
+        Args:
+            path (str):
+                Output file path to save the plot. It requires a graphviz-supported
+                extension. If ``None``, do not save the plot and
+                just return the ``graphviz.Digraph`` object.
+                Defaults to ``None``.
+        """
+        return visualization.visualize(self, path)
diff --git a/sdv/metadata/errors.py b/sdv/metadata/errors.py
new file mode 100644
index 000000000..8e1bada44
--- /dev/null
+++ b/sdv/metadata/errors.py
@@ -0,0 +1,2 @@
+class MetadataError(Exception):
+    pass
diff --git a/sdv/metadata/table.py b/sdv/metadata/table.py
new file mode 100644
index 000000000..598a4a725
--- /dev/null
+++ b/sdv/metadata/table.py
@@ -0,0 +1,399 @@
+import copy
+import json
+import logging
+
+import numpy as np
+import pandas as pd
+import rdt
+from faker import Faker
+
+from sdv.metadata.errors import MetadataError
+
+LOGGER = logging.getLogger(__name__)
+
+
+class Table:
+    """Table Metadata.
+
+    The ``Table`` class provides a unified layer of abstraction over the metadata
+    of a single table, which includes both the necessary details to load the data
+    from the filesystem and the information needed to parse and transform it to
+    numerical data.
+    """
+
+    _metadata = None
+    _hyper_transformer = None
+    _anonymization_mappings = None
+
+    _FIELD_TEMPLATES = {
+        'i': {
+            'type': 'numerical',
+            'subtype': 'integer',
+        },
+        'f': {
+            'type': 'numerical',
+            'subtype': 'float',
+        },
+        'O': {
+            'type': 'categorical',
+        },
+        'b': {
+            'type': 'boolean',
+        },
+        'M': {
+            'type': 'datetime',
+        }
+    }
+    _DTYPES = {
+        ('categorical', None): 'object',
+        ('boolean', None): 'bool',
+        ('numerical', None): 'float',
+        ('numerical', 'float'): 'float',
+        ('numerical', 'integer'): 'int',
+        ('datetime', None): 'datetime64',
+        ('id', None): 'int',
+        ('id', 'integer'): 'int',
+        ('id', 'string'): 'str'
+    }
+
+    def _get_faker(self, category):
+        """Return the faker object to anonymize data.
+
+        Args:
+            category (str or tuple):
+                Fake category to use. If a tuple is passed, the first element is
+                the category and the rest are additional arguments for the Faker.
+
+        Returns:
+            function:
+                Faker function to generate new fake data instances.
+
+        Raises:
+            ValueError:
+                A ``ValueError`` is raised if the requested faker category does not exist.
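+
+        Example:
+            An illustrative sketch, assuming the installed ``Faker`` provides a
+            ``name`` method: ``self._get_faker('name')()`` returns a new fake
+            name, while ``self._get_faker(('credit_card_number', 'visa'))()``
+            calls the faker method with ``'visa'`` as a positional argument.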
+ """ + if isinstance(category, (tuple, list)): + category, *args = category + else: + args = tuple() + + try: + faker_method = getattr(Faker(), category) + + if not args: + return faker_method + + def faker(): + return faker_method(*args) + + return faker + + except AttributeError: + raise ValueError('Category "{}" couldn\'t be found on faker'.format(category)) + + def __init__(self, metadata=None, field_names=None, primary_key=None, + field_types=None, anonymize_fields=None, constraints=None, + transformer_templates=None): + # TODO: Validate that the given metadata is a valid dict + self._metadata = metadata or dict() + self._hyper_transformer = None + + self._field_names = field_names + self._primary_key = primary_key + self._field_types = field_types or {} + self._anonymize_fields = anonymize_fields + self._constraints = constraints + + self._transformer_templates = transformer_templates or {} + + self._fakers = { + name: self._get_faker(category) + for name, category in (anonymize_fields or {}).items() + } + + def _get_field_dtype(self, field_name, field_metadata): + field_type = field_metadata['type'] + field_subtype = field_metadata.get('subtype') + dtype = self._DTYPES.get((field_type, field_subtype)) + if not dtype: + raise MetadataError( + 'Invalid type and subtype combination for field {}: ({}, {})'.format( + field_name, field_type, field_subtype) + ) + + return dtype + + def get_fields(self): + """Get fields metadata. + + Returns: + dict: + Dictionary of fields metadata for this table. + """ + return copy.deepcopy(self._metadata['fields']) + + def get_dtypes(self, ids=False): + """Get a ``dict`` with the ``dtypes`` for each field of the table. + + Args: + ids (bool): + Whether or not include the id fields. Defaults to ``False``. + + Returns: + dict: + Dictionary that contains the field names and data types. + + Raises: + ValueError: + If a field has an invalid type or subtype. + """ + dtypes = dict() + for name, field_meta in self._metadata['fields'].items(): + field_type = field_meta['type'] + + if ids or (field_type != 'id'): + dtypes[name] = self._get_field_dtype(name, field_meta) + + return dtypes + + def _build_fields_metadata(self, data): + """Build all the fields metadata. + + Args: + data (pandas.DataFrame): + Data to be analyzed. + + Returns: + dict: + Dict of valid fields. + + Raises: + ValueError: + If a column from the data analyzed is an unsupported data type + """ + field_names = self._field_names or data.columns + + fields_metadata = dict() + for field_name in field_names: + if field_name not in data: + raise ValueError('Field {} not found in given data'.format(field_name)) + + field_meta = self._field_types.get(field_name) + if field_meta: + # Validate the given meta + self._get_field_dtype(field_name, field_meta) + else: + dtype = data[field_name].dtype + field_template = self._FIELD_TEMPLATES.get(dtype.kind) + if field_template is None: + raise ValueError('Unsupported dtype {} in column {}'.format(dtype, field_name)) + + field_meta = copy.deepcopy(field_template) + + fields_metadata[field_name] = field_meta + + return fields_metadata + + def _get_transformers(self, dtypes): + """Create the transformer instances needed to process the given dtypes. + + Args: + dtypes (dict): + mapping of field names and dtypes. + + Returns: + dict: + mapping of field names and transformer instances. 
+ """ + transformer_templates = { + 'i': rdt.transformers.NumericalTransformer(dtype=int), + 'f': rdt.transformers.NumericalTransformer(dtype=float), + 'O': rdt.transformers.CategoricalTransformer, + 'b': rdt.transformers.BooleanTransformer, + 'M': rdt.transformers.DatetimeTransformer, + } + transformer_templates.update(self._transformer_templates) + + transformers = dict() + for name, dtype in dtypes.items(): + transformer_template = transformer_templates[np.dtype(dtype).kind] + if isinstance(transformer_template, type): + transformer = transformer_template() + else: + transformer = copy.deepcopy(transformer_template) + + LOGGER.info('Loading transformer %s for field %s', + transformer.__class__.__name__, name) + transformers[name] = transformer + + return transformers + + def _fit_hyper_transformer(self, data): + """Create and return a new ``rdt.HyperTransformer`` instance. + + First get the ``dtypes`` and then use them to build a transformer dictionary + to be used by the ``HyperTransformer``. + + Returns: + rdt.HyperTransformer + """ + dtypes = self.get_dtypes(ids=False) + transformers_dict = self._get_transformers(dtypes) + self._hyper_transformer = rdt.HyperTransformer(transformers=transformers_dict) + self._hyper_transformer.fit(data[list(dtypes.keys())]) + + @staticmethod + def _get_key_subtype(field_meta): + """Get the appropriate key subtype.""" + field_type = field_meta['type'] + + if field_type == 'categorical': + field_subtype = 'string' + + elif field_type in ('numerical', 'id'): + field_subtype = field_meta['subtype'] + if field_subtype not in ('integer', 'string'): + raise ValueError( + 'Invalid field "subtype" for key field: "{}"'.format(field_subtype) + ) + + else: + raise ValueError( + 'Invalid field "type" for key field: "{}"'.format(field_type) + ) + + return field_subtype + + def set_primary_key(self, field_name): + """Set the primary key of this table. + + The field must exist and either be an integer or categorical field. + + Args: + field_name (str): + Name of the field to be used as the new primary key. + + Raises: + ValueError: + If the table or the field do not exist or if the field has an + invalid type or subtype. + """ + if field_name is not None: + if field_name not in self._metadata['fields']: + raise ValueError('Field "{}" does not exist in this table'.format(field_name)) + + field_metadata = self._metadata['fields'][field_name] + field_subtype = self._get_key_subtype(field_metadata) + + field_metadata.update({ + 'type': 'id', + 'subtype': field_subtype + }) + + self._primary_key = field_name + + def _make_anonymization_mappings(self, data): + mappings = {} + for name, faker in self._fakers.items(): + uniques = data[name].unique() + fake_values = [faker() for x in range(len(uniques))] + mappings[name] = dict(zip(uniques, fake_values)) + + self._anonymization_mappings = mappings + + def _anonymize(self, data): + if self._anonymization_mappings: + data = data.copy() + for name, mapping in self._anonymization_mappings.items(): + data[name] = data[name].map(mapping) + + return data + + def fit(self, data): + """Fit this metadata to the given data. + + Args: + data (pandas.DataFrame): + Table to be analyzed. 
+ """ + self._field_names = self._field_names or list(data.columns) + self._metadata['fields'] = self._build_fields_metadata(data) + self.set_primary_key(self._primary_key) + + if self._fakers: + self._make_anonymization_mappings(data) + data = self._anonymize(data) + + # TODO: Treat/Learn constraints + self._fit_hyper_transformer(data) + + def transform(self, data): + """Transform the given data. + + Args: + data (pandas.DataFrame): + Table data. + + Returns: + pandas.DataFrame: + Transformed data. + """ + data = self._anonymize(data[self._field_names]) + return self._hyper_transformer.transform(data) + + def reverse_transform(self, data): + """Reverse the transformed data to the original format. + + Args: + data (pandas.DataFrame): + Data to be reverse transformed. + + Returns: + pandas.DataFrame + """ + reversed_data = self._hyper_transformer.reverse_transform(data) + + fields = self._metadata['fields'] + for name, dtype in self.get_dtypes(ids=True).items(): + field_type = fields[name]['type'] + if field_type == 'id': + field_data = pd.Series(np.arange(len(reversed_data))) + else: + field_data = reversed_data[name] + + reversed_data[name] = field_data.dropna().astype(dtype) + + return reversed_data[self._field_names] + + # ###################### # + # Metadata Serialization # + # ###################### # + + def to_dict(self): + """Get a dict representation of this metadata. + + Returns: + dict: + dict representation of this metadata. + """ + return copy.deepcopy(self._metadata) + + def to_json(self, path): + """Dump this metadata into a JSON file. + + Args: + path (str): + Path of the JSON file where this metadata will be stored. + """ + with open(path, 'w') as out_file: + json.dump(self._metadata, out_file, indent=4) + + @classmethod + def from_json(cls, path): + """Load a Table from a JSON + + Args: + path (str): + Path of the JSON file to load + """ + with open(path, 'r') as in_file: + return cls(json.load(in_file)) diff --git a/sdv/metadata/visualization.py b/sdv/metadata/visualization.py new file mode 100644 index 000000000..b14fb160a --- /dev/null +++ b/sdv/metadata/visualization.py @@ -0,0 +1,121 @@ +import graphviz + + +def _get_graphviz_extension(path): + if path: + path_splitted = path.split('.') + if len(path_splitted) == 1: + raise ValueError('Path without graphviz extansion.') + + graphviz_extension = path_splitted[-1] + + if graphviz_extension not in graphviz.backend.FORMATS: + raise ValueError( + '"{}" not a valid graphviz extension format.'.format(graphviz_extension) + ) + + return '.'.join(path_splitted[:-1]), graphviz_extension + + return None, None + + +def _add_nodes(metadata, digraph): + """Add nodes into a `graphviz.Digraph`. + + Each node represent a metadata table. + + Args: + metadata (Metadata): + Metadata object to plot. 
+        digraph (graphviz.Digraph):
+            ``graphviz.Digraph`` being built.
+    """
+    for table in metadata.get_tables():
+        # Append table fields
+        fields = []
+
+        for name, value in metadata.get_fields(table).items():
+            if value.get('subtype') is not None:
+                fields.append('{} : {} - {}'.format(name, value['type'], value['subtype']))
+
+            else:
+                fields.append('{} : {}'.format(name, value['type']))
+
+        fields = r'\l'.join(fields)
+
+        # Append table extra information
+        extras = []
+
+        primary_key = metadata.get_primary_key(table)
+        if primary_key is not None:
+            extras.append('Primary key: {}'.format(primary_key))
+
+        parents = metadata.get_parents(table)
+        for parent in parents:
+            foreign_key = metadata.get_foreign_key(parent, table)
+            extras.append('Foreign key ({}): {}'.format(parent, foreign_key))
+
+        path = metadata.get_table_meta(table).get('path')
+        if path is not None:
+            extras.append('Data path: {}'.format(path))
+
+        extras = r'\l'.join(extras)
+
+        # Add table node
+        title = r'{%s|%s\l|%s\l}' % (table, fields, extras)
+        digraph.node(table, label=title)
+
+
+def _add_edges(metadata, digraph):
+    """Add edges into a `graphviz.Digraph`.
+
+    Each edge represents a relationship between two metadata tables.
+
+    Args:
+        metadata (Metadata):
+            Metadata object to plot.
+        digraph (graphviz.Digraph):
+            ``graphviz.Digraph`` being built.
+    """
+    for table in metadata.get_tables():
+        for parent in list(metadata.get_parents(table)):
+            digraph.edge(
+                parent,
+                table,
+                label='   {}.{} -> {}.{}'.format(
+                    table, metadata.get_foreign_key(parent, table),
+                    parent, metadata.get_primary_key(parent)
+                ),
+                arrowhead='oinv'
+            )
+
+
+def visualize(metadata, path=None):
+    """Plot metadata using graphviz.
+
+    Generate a plot using graphviz.
+    If a ``path`` is provided, the output is saved into a file.
+
+    Args:
+        metadata (Metadata):
+            Metadata object to plot.
+        path (str):
+            Output file path to save the plot. It requires a graphviz-supported
+            extension. If ``None``, do not save the plot.
+            Defaults to ``None``.
+    """
+    filename, graphviz_extension = _get_graphviz_extension(path)
+    digraph = graphviz.Digraph(
+        'Metadata',
+        format=graphviz_extension,
+        node_attr={
+            "shape": "Mrecord",
+            "fillcolor": "lightgoldenrod1",
+            "style": "filled"
+        },
+    )
+
+    _add_nodes(metadata, digraph)
+    _add_edges(metadata, digraph)
+
+    if filename:
+        digraph.render(filename=filename, cleanup=True, format=graphviz_extension)
+    else:
+        return digraph
diff --git a/sdv/tabular/__init__.py b/sdv/tabular/__init__.py
new file mode 100644
index 000000000..dae118d1d
--- /dev/null
+++ b/sdv/tabular/__init__.py
@@ -0,0 +1,7 @@
+from sdv.tabular.copulas import GaussianCopula
+from sdv.tabular.ctgan import CTGAN
+
+__all__ = [
+    'GaussianCopula',
+    'CTGAN'
+]
diff --git a/sdv/tabular/base.py b/sdv/tabular/base.py
new file mode 100644
index 000000000..f21be7123
--- /dev/null
+++ b/sdv/tabular/base.py
@@ -0,0 +1,213 @@
+import pickle
+
+from sdv.metadata import Table
+
+
+class BaseTabularModel():
+    """Base class for all the tabular models.
+
+    The ``BaseTabularModel`` class defines the common API that all the
+    TabularModels need to implement, as well as common functionality.
+    """
+
+    TRANSFORMER_TEMPLATES = None
+
+    _metadata = None
+
+    def __init__(self, field_names=None, primary_key=None, field_types=None,
+                 anonymize_fields=None, table_metadata=None, *args, **kwargs):
+        """Initialize a Tabular Model.
+
+        Args:
+            field_names (list[str]):
+                List of names of the fields that need to be modeled
+                and included in the generated output data. Any additional
+                fields found in the data will be ignored and will not be
+                included in the generated output, except if they have
+                been added as primary keys or fields to anonymize.
+                If ``None``, all the fields found in the data are used.
+            primary_key (str, list[str] or dict[str, dict]):
+                Specification about which field or fields are the
+                primary key of the table and information about how
+                to generate them.
+            field_types (dict[str, dict]):
+                Dictionary specifying the data types and subtypes
+                of the fields that will be modeled. Field type and subtype
+                combinations must be compatible with the SDV Metadata Schema.
+            anonymize_fields (dict[str, str]):
+                Dict specifying which fields to anonymize and what faker
+                category they belong to.
+            table_metadata (dict or metadata.Table):
+                Table metadata instance or dict representation.
+                If given alongside any other metadata-related arguments, an
+                exception will be raised.
+                If not given at all, it will be built using the other
+                arguments or learned from the data.
+            *args, **kwargs:
+                Subclasses will add any arguments or keyword arguments needed.
+        """
+        if table_metadata is not None:
+            if isinstance(table_metadata, dict):
+                table_metadata = Table(table_metadata)
+
+            # Plain values do not carry their argument name, so pair each
+            # value with its name explicitly for the error message.
+            other_args = {
+                'field_names': field_names,
+                'primary_key': primary_key,
+                'field_types': field_types,
+                'anonymize_fields': anonymize_fields,
+            }
+            for arg_name, arg in other_args.items():
+                if arg:
+                    raise ValueError(
+                        'If table_metadata is given {} must be None'.format(arg_name))
+
+            self._metadata = table_metadata
+
+        else:
+            self._field_names = field_names
+            self._primary_key = primary_key
+            self._field_types = field_types
+            self._anonymize_fields = anonymize_fields
+
+    def _fit_metadata(self, data):
+        """Generate a new Table metadata and fit it to the data.
+
+        The information provided will be used to create the Table instance
+        and the rest of the information will be learned from the given
+        data.
+
+        Args:
+            data (pandas.DataFrame):
+                Data to learn from.
+        """
+        metadata = Table(
+            field_names=self._field_names,
+            primary_key=self._primary_key,
+            field_types=self._field_types,
+            anonymize_fields=self._anonymize_fields,
+            transformer_templates=self.TRANSFORMER_TEMPLATES,
+        )
+        metadata.fit(data)
+
+        self._metadata = metadata
+
+    def fit(self, data):
+        """Fit this model to the data.
+
+        If the table metadata has not been given, learn it from the data.
+
+        Args:
+            data (pandas.DataFrame or str):
+                Data to fit the model to. It can be passed as a
+                ``pandas.DataFrame`` or as a ``str``.
+                If a ``str`` is passed, it is assumed to be
+                the path to a CSV file which can be loaded using
+                ``pandas.read_csv``.
+        """
+        if self._metadata is None:
+            self._fit_metadata(data)
+
+        self._size = len(data)
+
+        transformed = self._metadata.transform(data)
+        self._fit(transformed)
+
+    def get_metadata(self):
+        """Get metadata about the table.
+
+        This will return an ``sdv.metadata.Table`` object containing
+        the information about the data that this model has learned.
+
+        This Table metadata will contain some common information,
+        such as field names and data types, as well as additional
+        information that each subclass might add, such as the
+        observed data field distributions and their parameters.
+
+        Returns:
+            sdv.metadata.Table:
+                Table metadata.
+        """
+        return self._metadata
+
+    def sample(self, size=None, values=None):
+        """Sample rows from this table.
+
+        Args:
+            size (int):
+                Number of rows to sample. If not given, the model
+                will generate as many rows as there were in the
+                data passed to the ``fit`` method.
+            values (dict): <- FUTURE
+                Fixed values to use for knowledge-based sampling.
+ In case the model does not support knowledge-based + sampling, a discard+resample strategy will be used. + + Returns: + pandas.DataFrame: + Sampled data. + """ + size = size or self._size + sampled = self._sample(size) + return self._metadata.reverse_transform(sampled) + + def get_parameters(self): + """Get the parameters learned from the data. + + The result is a flat dict (single level) which contains + all the necessary parameters to be able to reproduce + this model. + + Subclasses which are not parametric, such as DeepLearning + based models, raise a NonParametricError indicating that + this method is not supported for their implementation. + + Returns: + parameters (dict): + flat dict (single level) which contains all the + necessary parameters to be able to reproduce + this model. + + Raises: + NonParametricError: + If the model is not parametric or cannot be described + using a simple dictionary. + """ + raise NotImplementedError() + + @classmethod + def from_parameters(cls): + """Regenerate a previously learned model from its parameters. + + Subclasses which are not parametric, such as DeepLearning + based models, raise a NonParametricError indicating that + this method is not supported for their implementation. + + Returns: + BaseTabularModel: + New instance with its parameters set. + + Raises: + NonParametricError: + If the model is not parametric or cannot be described + using a simple dictionary. + """ + raise NotImplementedError() + + def save(self, path): + """Save this model instance to the given path using pickle. + + Args: + path (str): + Path where the SDV instance will be serialized. + """ + with open(path, 'wb') as output: + pickle.dump(self, output) + + @classmethod + def load(cls, path): + """Load a TabularModel instance from a given path. + + Args: + path (str): + Path from which to load the instance. + + Returns: + TabularModel: + The loaded tabular model. + """ + with open(path, 'rb') as f: + return pickle.load(f) diff --git a/sdv/tabular/copulas.py b/sdv/tabular/copulas.py new file mode 100644 index 000000000..7ae00120f --- /dev/null +++ b/sdv/tabular/copulas.py @@ -0,0 +1,246 @@ +import copulas +import numpy as np +import rdt + +from sdv.tabular.base import BaseTabularModel +from sdv.tabular.utils import ( + check_matrix_symmetric_positive_definite, flatten_dict, make_positive_definite, square_matrix, + unflatten_dict) + + +class GaussianCopula(BaseTabularModel): + """Model wrapping ``copulas.multivariate.GaussianMultivariate`` copula. + + Args: + distribution (copulas.univariate.Univariate or str): + Copulas univariate distribution to use. + categorical_transformer (str): + Type of transformer to use for the categorical variables, to choose + from ``one_hot_encoding``, ``label_encoding``, ``categorical`` and + ``categorical_fuzzy``. 
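+
+    Example:
+        A minimal usage sketch (illustrative only), assuming ``data`` is a
+        ``pandas.DataFrame`` containing a single table:
+
+        >>> from sdv.tabular import GaussianCopula
+        >>> model = GaussianCopula()
+        >>> model.fit(data)
+        >>> sampled = model.sample(100)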
+ """ + + DEFAULT_DISTRIBUTION = copulas.univariate.Univariate + _distribution = None + _categorical_transformer = None + _model = None + + HYPERPARAMETERS = { + 'distribution': { + 'type': 'str or copulas.univariate.Univariate', + 'default': 'copulas.univariate.Univariate', + 'description': 'Univariate distribution to use to model each column', + 'choices': [ + 'copulas.univariate.Univariate', + 'copulas.univariate.GaussianUnivariate', + 'copulas.univariate.GammaUnivariate', + 'copulas.univariate.BetaUnivariate', + 'copulas.univariate.StudentTUnivariate', + 'copulas.univariate.GaussianKDE', + 'copulas.univariate.TruncatedGaussian', + ] + }, + 'categorical_transformer': { + 'type': 'str', + 'default': 'categoircal_fuzzy', + 'description': 'Type of transformer to use for the categorical variables', + 'choices': [ + 'categorical', + 'categorical_fuzzy', + 'one_hot_encoding', + 'label_encoding' + ] + } + } + DEFAULT_TRANSFORMER = 'one_hot_encoding' + CATEGORICAL_TRANSFORMERS = { + 'categorical': rdt.transformers.CategoricalTransformer(fuzzy=False), + 'categorical_fuzzy': rdt.transformers.CategoricalTransformer(fuzzy=True), + 'one_hot_encoding': rdt.transformers.OneHotEncodingTransformer, + 'label_encoding': rdt.transformers.LabelEncodingTransformer, + } + TRANSFORMER_TEMPLATES = { + 'O': rdt.transformers.OneHotEncodingTransformer + } + + def __init__(self, distribution=None, categorical_transformer=None, *args, **kwargs): + super().__init__(*args, **kwargs) + + if self._metadata is not None and 'model_kwargs' in self._metadata._metadata: + model_kwargs = self._metadata._metadata['model_kwargs'] + if distribution is None: + distribution = model_kwargs['distribution'] + + if categorical_transformer is None: + categorical_transformer = model_kwargs['categorical_transformer'] + + self._size = model_kwargs['_size'] + + self._distribution = distribution or self.DEFAULT_DISTRIBUTION + + categorical_transformer = categorical_transformer or self.DEFAULT_TRANSFORMER + self._categorical_transformer = categorical_transformer + + self.TRANSFORMER_TEMPLATES['O'] = self.CATEGORICAL_TRANSFORMERS[categorical_transformer] + + def _update_metadata(self): + parameters = self._model.to_dict() + univariates = parameters['univariates'] + columns = parameters['columns'] + + distributions = {} + for column, univariate in zip(columns, univariates): + distributions[column] = univariate['type'] + + self._metadata._metadata['model_kwargs'] = { + 'distribution': distributions, + 'categorical_transformer': self._categorical_transformer, + '_size': self._size + } + + def _fit(self, data): + """Fit the model to the table. + + Args: + table_data (pandas.DataFrame): + Data to be fitted. + """ + self._model = copulas.multivariate.GaussianMultivariate(distribution=self._distribution) + self._model.fit(data) + self._update_metadata() + + def _sample(self, size): + """Sample ``size`` rows from the model. + + Args: + size (int): + Amount of rows to sample. + + Returns: + pandas.DataFrame: + Sampled data. + """ + return self._model.sample(size) + + def get_parameters(self, flatten=False): + """Get copula model parameters. + + Compute model ``covariance`` and ``distribution.std`` + before it returns the flatten dict. + + Args: + flatten (bool): + Whether to flatten the parameters or not before + returning them. + + Returns: + dict: + Copula parameters. 
+ """ + if not flatten: + return self._model.to_dict() + + values = list() + triangle = np.tril(self._model.covariance) + + for index, row in enumerate(triangle.tolist()): + values.append(row[:index + 1]) + + self._model.covariance = np.array(values) + params = self._model.to_dict() + univariates = dict() + for name, univariate in zip(params.pop('columns'), params['univariates']): + univariates[name] = univariate + if 'scale' in univariate: + scale = univariate['scale'] + if scale == 0: + scale = copulas.EPSILON + + univariate['scale'] = np.log(scale) + + params['univariates'] = univariates + + return flatten_dict(params) + + def _prepare_sampled_covariance(self, covariance): + """Prepare a covariance matrix. + + Args: + covariance (list): + covariance after unflattening model parameters. + + Result: + list[list]: + symmetric Positive semi-definite matrix. + """ + covariance = np.array(square_matrix(covariance)) + covariance = (covariance + covariance.T - (np.identity(covariance.shape[0]) * covariance)) + + if not check_matrix_symmetric_positive_definite(covariance): + covariance = make_positive_definite(covariance) + + return covariance.tolist() + + def _unflatten_gaussian_copula(self, model_parameters): + """Prepare unflattened model params to recreate Gaussian Multivariate instance. + + The preparations consist basically in: + + - Transform sampled negative standard deviations from distributions into positive + numbers + + - Ensure the covariance matrix is a valid symmetric positive-semidefinite matrix. + + - Add string parameters kept inside the class (as they can't be modelled), + like ``distribution_type``. + + Args: + model_parameters (dict): + Sampled and reestructured model parameters. + + Returns: + dict: + Model parameters ready to recreate the model. + """ + + univariate_kwargs = { + 'type': model_parameters['distribution'] + } + + columns = list() + univariates = list() + for column, univariate in model_parameters['univariates'].items(): + columns.append(column) + univariate.update(univariate_kwargs) + if 'scale' in univariate: + univariate['scale'] = np.exp(univariate['scale']) + + univariates.append(univariate) + + model_parameters['univariates'] = univariates + model_parameters['columns'] = columns + + covariance = model_parameters.get('covariance') + model_parameters['covariance'] = self._prepare_sampled_covariance(covariance) + + return model_parameters + + def set_parameters(self, parameters, unflatten=False): + """Set copula model parameters. + + Add additional keys after unflatte the parameters + in order to set expected parameters for the copula. + + Args: + dict: + Copula flatten parameters. + unflatten (bool): + Whether the parameters need to be unflattened or not. + """ + if unflatten: + parameters = unflatten_dict(parameters) + parameters.setdefault('distribution', self._distribution) + + parameters = self._unflatten_gaussian_copula(parameters) + + self._model = copulas.multivariate.GaussianMultivariate.from_dict(parameters) diff --git a/sdv/tabular/ctgan.py b/sdv/tabular/ctgan.py new file mode 100644 index 000000000..eb1f87d33 --- /dev/null +++ b/sdv/tabular/ctgan.py @@ -0,0 +1,142 @@ +import rdt + +from sdv.tabular.base import BaseTabularModel + + +class CTGAN(BaseTabularModel): + """Model wrapping ``CTGANSynthesizer`` copula. 
+
+    Args:
+        TBD
+    """
+
+    _CTGAN_CLASS = None
+    _model = None
+
+    HYPERPARAMETERS = {
+        'TBD'
+    }
+    TRANSFORMER_TEMPLATES = {
+        'O': rdt.transformers.LabelEncodingTransformer
+    }
+
+    def __init__(self, field_names=None, primary_key=None, field_types=None,
+                 anonymize_fields=None, constraints=None, table_metadata=None,
+                 epochs=300, log_frequency=True, embedding_dim=128, gen_dim=(256, 256),
+                 dis_dim=(256, 256), l2scale=1e-6, batch_size=500):
+        """Initialize this CTGAN model.
+
+        Args:
+            field_names (list[str]):
+                List of names of the fields that need to be modeled
+                and included in the generated output data. Any additional
+                fields found in the data will be ignored and will not be
+                included in the generated output, except if they have
+                been added as primary keys or fields to anonymize.
+                If ``None``, all the fields found in the data are used.
+            primary_key (str, list[str] or dict[str, dict]):
+                Specification about which field or fields are the
+                primary key of the table and information about how
+                to generate them.
+            field_types (dict[str, dict]):
+                Dictionary specifying the data types and subtypes
+                of the fields that will be modeled. Field type and subtype
+                combinations must be compatible with the SDV Metadata Schema.
+            anonymize_fields (dict[str, str or tuple]):
+                Dict specifying which fields to anonymize and what faker
+                category they belong to. If arguments for the faker need to be
+                passed to fine-tune the value generation, a tuple can be passed,
+                where the first element is the category and the rest are additional
+                positional arguments for the Faker.
+            constraints (list[dict]):
+                List of dicts specifying field and inter-field constraints.
+                TODO: Format TBD
+            table_metadata (dict or metadata.Table):
+                Table metadata instance or dict representation.
+                If given alongside any other metadata-related arguments, an
+                exception will be raised.
+                If not given at all, it will be built using the other
+                arguments or learned from the data.
+            epochs (int):
+                Number of training epochs. Defaults to 300.
+            log_frequency (boolean):
+                Whether to use log frequency of categorical levels in conditional
+                sampling. Defaults to ``True``.
+            embedding_dim (int):
+                Size of the random sample passed to the Generator. Defaults to 128.
+            gen_dim (tuple or list of ints):
+                Size of the output samples for each one of the Residuals. A Residual
+                Layer will be created for each one of the values provided.
+                Defaults to (256, 256).
+            dis_dim (tuple or list of ints):
+                Size of the output samples for each one of the Discriminator Layers. A Linear
+                Layer will be created for each one of the values provided. Defaults to (256, 256).
+            l2scale (float):
+                Weight decay for the Adam optimizer. Defaults to 1e-6.
+            batch_size (int):
+                Number of data samples to process in each step. Defaults to 500.
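+
+        Example:
+            A minimal sketch, assuming ``data`` is a ``pandas.DataFrame`` and the
+            optional ``ctgan`` dependency is installed:
+
+            >>> from sdv.tabular import CTGAN
+            >>> model = CTGAN(epochs=10)
+            >>> model.fit(data)
+            >>> sampled = model.sample(100)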
+ """ + super().__init__( + field_names=field_names, + primary_key=primary_key, + field_types=field_types, + anonymize_fields=anonymize_fields, + constraints=constraints, + table_metadata=table_metadata + ) + try: + from ctgan import CTGANSynthesizer # Lazy import to make dependency optional + + self._CTGAN_CLASS = CTGANSynthesizer + except ImportError as ie: + ie.msg += ( + '\n\nIt seems like `ctgan` is not installed.\n' + 'Please install it using:\n\n pip install ctgan' + ) + raise + + self._embedding_dim = embedding_dim + self._gen_dim = gen_dim + self._dis_dim = dis_dim + self._l2scale = l2scale + self._batch_size = batch_size + self._epochs = epochs + self._log_frequency = log_frequency + + def _fit(self, data): + """Fit the model to the table. + + Args: + data (pandas.DataFrame): + Data to be learned. + """ + self._model = self._CTGAN_CLASS( + embedding_dim=self._embedding_dim, + gen_dim=self._gen_dim, + dis_dim=self._dis_dim, + l2scale=self._l2scale, + batch_size=self._batch_size, + ) + categoricals = [ + field + for field, meta in self._metadata.get_fields().items() + if meta['type'] == 'categorical' + ] + self._model.fit( + data, + epochs=self._epochs, + discrete_columns=categoricals, + log_frequency=self._log_frequency, + ) + + def _sample(self, size): + """Sample ``size`` rows from the model. + + Args: + size (int): + Amount of rows to sample. + + Returns: + pandas.DataFrame: + Sampled data. + """ + return self._model.sample(size) diff --git a/sdv/tabular/utils.py b/sdv/tabular/utils.py new file mode 100644 index 000000000..241a20048 --- /dev/null +++ b/sdv/tabular/utils.py @@ -0,0 +1,219 @@ +import numpy as np + +IGNORED_DICT_KEYS = ['fitted', 'distribution', 'type'] + + +def flatten_array(nested, prefix=''): + """Flatten an array as a dict. + + Args: + nested (list, numpy.array): + Iterable to flatten. + prefix (str): + Name to append to the array indices. Defaults to ``''``. + + Returns: + dict: + Flattened array. + """ + result = dict() + for index in range(len(nested)): + prefix_key = '__'.join([prefix, str(index)]) if len(prefix) else str(index) + + value = nested[index] + if isinstance(value, (list, np.ndarray)): + result.update(flatten_array(value, prefix=prefix_key)) + + elif isinstance(value, dict): + result.update(flatten_dict(value, prefix=prefix_key)) + + else: + result[prefix_key] = value + + return result + + +def flatten_dict(nested, prefix=''): + """Flatten a dictionary. + + This method returns a flatten version of a dictionary, concatenating key names with + double underscores. + + Args: + nested (dict): + Original dictionary to flatten. + prefix (str): + Prefix to append to key name. Defaults to ``''``. + + Returns: + dict: + Flattened dictionary. + """ + result = dict() + + for key, value in nested.items(): + prefix_key = '__'.join([prefix, str(key)]) if len(prefix) else key + + if key in IGNORED_DICT_KEYS and not isinstance(value, (dict, list)): + continue + + elif isinstance(value, dict): + result.update(flatten_dict(value, prefix_key)) + + elif isinstance(value, (np.ndarray, list)): + result.update(flatten_array(value, prefix_key)) + + else: + result[prefix_key] = value + + return result + + +def _key_order(key_value): + parts = list() + for part in key_value[0].split('__'): + if part.isdigit(): + part = int(part) + + parts.append(part) + + return parts + + +def unflatten_dict(flat): + """Transform a flattened dict into its original form. + + Args: + flat (dict): + Flattened dict. 
+
+    Returns:
+        dict:
+            Nested dict (if applicable).
+    """
+    unflattened = dict()
+
+    for key, value in sorted(flat.items(), key=_key_order):
+        if '__' in key:
+            key, subkey = key.split('__', 1)
+            subkey, name = subkey.rsplit('__', 1)
+
+            if name.isdigit():
+                column_index = int(name)
+                row_index = int(subkey)
+
+                array = unflattened.setdefault(key, list())
+
+                if len(array) == row_index:
+                    row = list()
+                    array.append(row)
+                elif len(array) == row_index + 1:
+                    row = array[row_index]
+                else:
+                    # This should never happen
+                    raise ValueError('There was an error unflattening the extension.')
+
+                if len(row) == column_index:
+                    row.append(value)
+                else:
+                    # This should never happen
+                    raise ValueError('There was an error unflattening the extension.')
+
+            else:
+                subdict = unflattened.setdefault(key, dict())
+                if subkey.isdigit():
+                    subkey = int(subkey)
+
+                inner = subdict.setdefault(subkey, dict())
+                inner[name] = value
+
+        else:
+            unflattened[key] = value
+
+    return unflattened
+
+
+def impute(data):
+    """Fill null values with the column mean (numerical) or mode (otherwise)."""
+    for column in data:
+        column_data = data[column]
+        if column_data.dtype in (np.int, np.float):
+            fill_value = column_data.mean()
+        else:
+            fill_value = column_data.mode()[0]
+
+        data[column] = data[column].fillna(fill_value)
+
+    return data
+
+
+def square_matrix(triangular_matrix):
+    """Fill with zeros a triangular matrix to reshape it to a square one.
+
+    Args:
+        triangular_matrix (list [list [float]]):
+            Lower triangular matrix as a list of lists of floats.
+
+    Returns:
+        list:
+            Square matrix.
+    """
+    length = len(triangular_matrix)
+    zero = [0.0]
+
+    for item in triangular_matrix:
+        item.extend(zero * (length - len(item)))
+
+    return triangular_matrix
+
+
+def check_matrix_symmetric_positive_definite(matrix):
+    """Check if a matrix is symmetric positive-definite.
+
+    Args:
+        matrix (list or numpy.ndarray):
+            Matrix to evaluate.
+
+    Returns:
+        bool
+    """
+    try:
+        if len(matrix.shape) != 2 or matrix.shape[0] != matrix.shape[1]:
+            # Not 2-dimensional or square, so not symmetric.
+            return False
+
+        np.linalg.cholesky(matrix)
+        return True
+
+    except np.linalg.LinAlgError:
+        return False
+
+
+def make_positive_definite(matrix):
+    """Find the nearest positive-definite matrix to input.
+
+    Args:
+        matrix (numpy.ndarray):
+            Matrix to transform.
+
+    Returns:
+        numpy.ndarray:
+            Closest symmetric positive-definite matrix.
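+
+    Example:
+        A usage sketch; the exact output values depend on the input, so only
+        the recovered property is checked here:
+
+        >>> import numpy as np
+        >>> matrix = np.array([[1.0, 2.0], [2.0, 1.0]])  # symmetric but indefinite
+        >>> fixed = make_positive_definite(matrix)
+        >>> check_matrix_symmetric_positive_definite(fixed)
+        True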
+ """ + symetric_matrix = (matrix + matrix.T) / 2 + _, s, V = np.linalg.svd(symetric_matrix) + symmetric_polar = np.dot(V.T, np.dot(np.diag(s), V)) + A2 = (symetric_matrix + symmetric_polar) / 2 + A3 = (A2 + A2.T) / 2 + + if check_matrix_symmetric_positive_definite(A3): + return A3 + + spacing = np.spacing(np.linalg.norm(matrix)) + identity = np.eye(matrix.shape[0]) + iterations = 1 + while not check_matrix_symmetric_positive_definite(A3): + min_eigenvals = np.min(np.real(np.linalg.eigvals(A3))) + A3 += identity * (-min_eigenvals * iterations**2 + spacing) + iterations += 1 + + return A3 diff --git a/setup.py b/setup.py index c8a4a56b5..ad3d04e62 100644 --- a/setup.py +++ b/setup.py @@ -13,12 +13,17 @@ install_requires = [ 'exrex>=0.9.4,<0.11', - 'numpy>=1.15.4,<1.17', + 'numpy>=1.15.4,<2', 'pandas>=0.23.4,<0.25', - 'copulas>=0.3,<0.4', - 'rdt>=0.2.1,<0.3', + 'copulas>=0.3.1,<0.4', + 'rdt>=0.2.3,<0.3', 'graphviz>=0.13.2', - 'sdmetrics>=0.0.1,<0.0.2' + 'sdmetrics>=0.0.2.dev0,<0.0.3', + 'scikit-learn<0.23,>=0.21', +] + +ctgan_requires = [ + 'ctgan>=0.2.2.dev0,<0.3', ] setup_requires = [ @@ -40,6 +45,7 @@ # docs 'm2r>=0.2.0,<0.3', + 'nbsphinx>=0.5.0,<0.7', 'Sphinx>=1.7.1,<3', 'sphinx_rtd_theme>=0.2.4,<0.5', 'autodocsumm>=0.1.10', @@ -76,8 +82,8 @@ ], description='Automated Generative Modeling and Sampling', extras_require={ - 'test': tests_require, - 'dev': development_requires + tests_require, + 'test': tests_require + ctgan_requires, + 'dev': development_requires + tests_require + ctgan_requires, }, install_package_data=True, install_requires=install_requires, diff --git a/tests/integration/tabular/test_copulas.py b/tests/integration/tabular/test_copulas.py new file mode 100644 index 000000000..0a34a0e15 --- /dev/null +++ b/tests/integration/tabular/test_copulas.py @@ -0,0 +1,56 @@ +from sdv.demo import load_demo +from sdv.tabular.copulas import GaussianCopula + + +def test_gaussian_copula(): + users = load_demo(metadata=False)['users'] + + field_types = { + 'age': { + 'type': 'numerical', + 'subtype': 'integer', + }, + 'country': { + 'type': 'categorical' + } + } + anonymize_fields = { + 'country': 'country_code' + } + + gc = GaussianCopula( + field_names=['user_id', 'country', 'gender', 'age'], + field_types=field_types, + primary_key='user_id', + anonymize_fields=anonymize_fields, + categorical_transformer='one_hot_encoding', + ) + gc.fit(users) + + parameters = gc.get_parameters() + new_gc = GaussianCopula( + table_metadata=gc.get_metadata(), + categorical_transformer='one_hot_encoding', + ) + new_gc.set_parameters(parameters) + + sampled = new_gc.sample() + + # test shape is right + assert sampled.shape == users.shape + + # test user_id has been generated as an ID field + assert list(sampled['user_id']) == list(range(0, len(users))) + + # country codes have been replaced with new ones + assert set(sampled.country.unique()) != set(users.country.unique()) + + metadata = gc.get_metadata().to_dict() + assert metadata['fields'] == { + 'user_id': {'type': 'id', 'subtype': 'integer'}, + 'country': {'type': 'categorical'}, + 'gender': {'type': 'categorical'}, + 'age': {'type': 'numerical', 'subtype': 'integer'} + } + + assert 'model_kwargs' in metadata diff --git a/tests/integration/tabular/test_ctgan.py b/tests/integration/tabular/test_ctgan.py new file mode 100644 index 000000000..8cc8d4ad6 --- /dev/null +++ b/tests/integration/tabular/test_ctgan.py @@ -0,0 +1,29 @@ +from sdv.demo import load_demo +from sdv.tabular.ctgan import CTGAN + + +def test_ctgan(): + users = 
load_demo(metadata=False)['users'] + + gc = CTGAN( + primary_key='user_id', + epochs=1 + ) + gc.fit(users) + + sampled = gc.sample() + + # test shape is right + assert sampled.shape == users.shape + + # test user_id has been generated as an ID field + assert list(sampled['user_id']) == list(range(0, len(users))) + + assert gc.get_metadata().to_dict() == { + 'fields': { + 'user_id': {'type': 'id', 'subtype': 'integer'}, + 'country': {'type': 'categorical'}, + 'gender': {'type': 'categorical'}, + 'age': {'type': 'numerical', 'subtype': 'integer'} + } + } diff --git a/tests/test_metadata.py b/tests/metadata/test___init__.py similarity index 87% rename from tests/test_metadata.py rename to tests/metadata/test___init__.py index 812ac7bcd..9968597ae 100644 --- a/tests/test_metadata.py +++ b/tests/metadata/test___init__.py @@ -1,7 +1,6 @@ from unittest import TestCase -from unittest.mock import MagicMock, Mock, call, patch +from unittest.mock import Mock, call, patch -import graphviz import pandas as pd import pytest @@ -963,121 +962,3 @@ def test_add_field(self): assert metadata._metadata == expected_metadata metadata._check_field.assert_called_once_with('a_table', 'a_field', exists=False) - - def test__get_graphviz_extension_path_without_extension(self): - """Raises a ValueError when the path doesn't contains an extension.""" - with pytest.raises(ValueError): - Metadata._get_graphviz_extension('/some/path') - - def test__get_graphviz_extension_invalid_extension(self): - """Raises a ValueError when the path contains an invalid extension.""" - with pytest.raises(ValueError): - Metadata._get_graphviz_extension('/some/path.foo') - - def test__get_graphviz_extension_none(self): - """Get graphviz with path equals to None.""" - # Run - result = Metadata._get_graphviz_extension(None) - - # Asserts - assert result == (None, None) - - def test__get_graphviz_extension_valid(self): - """Get a valid graphviz extension.""" - # Run - result = Metadata._get_graphviz_extension('/some/path.png') - - # Asserts - assert result == ('/some/path', 'png') - - def test__visualize_add_nodes(self): - """Add nodes into a graphviz digraph.""" - # Setup - metadata = MagicMock(spec_set=Metadata) - minimock = Mock() - - # pass tests in python3.5 - minimock.items.return_value = ( - ('a_field', {'type': 'numerical', 'subtype': 'integer'}), - ('b_field', {'type': 'id'}), - ('c_field', {'type': 'id', 'ref': {'table': 'other', 'field': 'pk_field'}}) - ) - - metadata.get_tables.return_value = ['demo'] - metadata.get_fields.return_value = minimock - - metadata.get_primary_key.return_value = 'b_field' - metadata.get_parents.return_value = set(['other']) - metadata.get_foreign_key.return_value = 'c_field' - - metadata.get_table_meta.return_value = {'path': None} - - plot = Mock() - - # Run - Metadata._visualize_add_nodes(metadata, plot) - - # Asserts - expected_node_label = r"{demo|a_field : numerical - integer\lb_field : id\l" \ - r"c_field : id\l|Primary key: b_field\l" \ - r"Foreign key (other): c_field\l}" - - metadata.get_fields.assert_called_once_with('demo') - metadata.get_primary_key.assert_called_once_with('demo') - metadata.get_parents.assert_called_once_with('demo') - metadata.get_table_meta.assert_called_once_with('demo') - metadata.get_foreign_key.assert_called_once_with('other', 'demo') - metadata.get_table_meta.assert_called_once_with('demo') - - plot.node.assert_called_once_with('demo', label=expected_node_label) - - def test__visualize_add_edges(self): - """Add edges into a graphviz digraph.""" - # Setup - metadata = 
MagicMock(spec_set=Metadata)
-
-        metadata.get_tables.return_value = ['demo', 'other']
-        metadata.get_parents.side_effect = [set(['other']), set()]
-
-        metadata.get_foreign_key.return_value = 'fk'
-        metadata.get_primary_key.return_value = 'pk'
-
-        plot = Mock()
-
-        # Run
-        Metadata._visualize_add_edges(metadata, plot)
-
-        # Asserts
-        expected_edge_label = '   {}.{} -> {}.{}'.format('demo', 'fk', 'other', 'pk')
-
-        metadata.get_tables.assert_called_once_with()
-        metadata.get_foreign_key.assert_called_once_with('other', 'demo')
-        metadata.get_primary_key.assert_called_once_with('other')
-        assert metadata.get_parents.call_args_list == [call('demo'), call('other')]
-
-        plot.edge.assert_called_once_with(
-            'other',
-            'demo',
-            label=expected_edge_label,
-            arrowhead='crow'
-        )
-
-    @patch('sdv.metadata.graphviz')
-    def test_visualize(self, graphviz_mock):
-        """Metadata visualize digraph"""
-        # Setup
-        plot = Mock(spec_set=graphviz.Digraph)
-        graphviz_mock.Digraph.return_value = plot
-
-        metadata = MagicMock(spec_set=Metadata)
-        metadata._get_graphviz_extension.return_value = ('output', 'png')
-
-        # Run
-        Metadata.visualize(metadata, path='output.png')
-
-        # Asserts
-        metadata._get_graphviz_extension.assert_called_once_with('output.png')
-        metadata._visualize_add_nodes.assert_called_once_with(plot)
-        metadata._visualize_add_edges.assert_called_once_with(plot)
-
-        plot.render.assert_called_once_with(filename='output', cleanup=True, format='png')
diff --git a/tests/metadata/test_visualization.py b/tests/metadata/test_visualization.py
new file mode 100644
index 000000000..b9f588593
--- /dev/null
+++ b/tests/metadata/test_visualization.py
@@ -0,0 +1,132 @@
+from unittest.mock import MagicMock, Mock, call, patch
+
+import graphviz
+import pytest
+
+from sdv.metadata import Metadata, visualization
+
+
+def test__get_graphviz_extension_path_without_extension():
+    """Raises a ValueError when the path doesn't contain an extension."""
+    with pytest.raises(ValueError):
+        visualization._get_graphviz_extension('/some/path')
+
+
+def test__get_graphviz_extension_invalid_extension():
+    """Raises a ValueError when the path contains an invalid extension."""
+    with pytest.raises(ValueError):
+        visualization._get_graphviz_extension('/some/path.foo')
+
+
+def test__get_graphviz_extension_none():
+    """Get graphviz with path equal to None."""
+    # Run
+    result = visualization._get_graphviz_extension(None)
+
+    # Asserts
+    assert result == (None, None)
+
+
+def test__get_graphviz_extension_valid():
+    """Get a valid graphviz extension."""
+    # Run
+    result = visualization._get_graphviz_extension('/some/path.png')
+
+    # Asserts
+    assert result == ('/some/path', 'png')
+
+
+def test__add_nodes():
+    """Add nodes into a graphviz digraph."""
+    # Setup
+    metadata = MagicMock(spec_set=Metadata)
+    minimock = Mock()
+
+    # pass tests in python3.5
+    minimock.items.return_value = (
+        ('a_field', {'type': 'numerical', 'subtype': 'integer'}),
+        ('b_field', {'type': 'id'}),
+        ('c_field', {'type': 'id', 'ref': {'table': 'other', 'field': 'pk_field'}})
+    )
+
+    metadata.get_tables.return_value = ['demo']
+    metadata.get_fields.return_value = minimock
+
+    metadata.get_primary_key.return_value = 'b_field'
+    metadata.get_parents.return_value = set(['other'])
+    metadata.get_foreign_key.return_value = 'c_field'
+
+    metadata.get_table_meta.return_value = {'path': None}
+
+    plot = Mock()
+
+    # Run
+    visualization._add_nodes(metadata, plot)
+
+    # Asserts
+    expected_node_label = r"{demo|a_field : numerical - integer\lb_field : id\l" \
id\l|Primary key: b_field\l" \ + r"Foreign key (other): c_field\l}" + + metadata.get_fields.assert_called_once_with('demo') + metadata.get_primary_key.assert_called_once_with('demo') + metadata.get_parents.assert_called_once_with('demo') + metadata.get_table_meta.assert_called_once_with('demo') + metadata.get_foreign_key.assert_called_once_with('other', 'demo') + metadata.get_table_meta.assert_called_once_with('demo') + + plot.node.assert_called_once_with('demo', label=expected_node_label) + + +def test__add_edges(): + """Add edges into a graphviz digraph.""" + # Setup + metadata = MagicMock(spec_set=Metadata) + + metadata.get_tables.return_value = ['demo', 'other'] + metadata.get_parents.side_effect = [set(['other']), set()] + + metadata.get_foreign_key.return_value = 'fk' + metadata.get_primary_key.return_value = 'pk' + + plot = Mock() + + # Run + visualization._add_edges(metadata, plot) + + # Asserts + expected_edge_label = ' {}.{} -> {}.{}'.format('demo', 'fk', 'other', 'pk') + + metadata.get_tables.assert_called_once_with() + metadata.get_foreign_key.assert_called_once_with('other', 'demo') + metadata.get_primary_key.assert_called_once_with('other') + assert metadata.get_parents.call_args_list == [call('demo'), call('other')] + + plot.edge.assert_called_once_with( + 'other', + 'demo', + label=expected_edge_label, + arrowhead='oinv' + ) + + +@patch('sdv.metadata.visualization.graphviz.Digraph', spec_set=graphviz.Digraph) +@patch('sdv.metadata.visualization._add_nodes') +@patch('sdv.metadata.visualization._add_edges') +def test_visualize(add_nodes_mock, add_edges_mock, digraph_mock): + """Metadata visualize digraph""" + # Setup + # plot = Mock(spec_set=graphviz.Digraph) + # graphviz_mock.Digraph.return_value = plot + + metadata = MagicMock(spec_set=Metadata) + + # Run + visualization.visualize(metadata, path='output.png') + + # Asserts + digraph = digraph_mock.return_value + add_nodes_mock.assert_called_once_with(metadata, digraph) + add_edges_mock.assert_called_once_with(metadata, digraph) + + digraph.render.assert_called_once_with(filename='output', cleanup=True, format='png') diff --git a/tutorials/01_Quickstart.ipynb b/tutorials/01_Quickstart.ipynb new file mode 100644 index 000000000..5eb6cd8e8 --- /dev/null +++ b/tutorials/01_Quickstart.ipynb @@ -0,0 +1,838 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Quickstart\n", + "\n", + "In this short tutorial we will guide you through a series of steps that will help you\n", + "getting started using **SDV**.\n", + "\n", + "## 1. Model the dataset using SDV\n", + "\n", + "To model a multi table, relational dataset, we follow two steps. In the first step, we will load\n", + "the data and configures the meta data. In the second step, we will use the sdv API to fit and\n", + "save a hierarchical model. We will cover these two steps in this section using an example dataset.\n", + "\n", + "### Step 1: Load example data\n", + "\n", + "**SDV** comes with a toy dataset to play with, which can be loaded using the `sdv.load_demo`\n", + "function:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from sdv import load_demo\n", + "\n", + "metadata, tables = load_demo(metadata=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This will return two objects:\n", + "\n", + "1. A `Metadata` object with all the information that **SDV** needs to know about the dataset." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Metadata\n", + " root_path: /home/xals/Projects/MIT/SDV/tutorials\n", + " tables: ['users', 'sessions', 'transactions']\n", + " relationships:\n", + " sessions.user_id -> users.user_id\n", + " transactions.session_id -> sessions.session_id" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metadata" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Metadata\n", + "\n", + "\n", + "\n", + "users\n", + "\n", + "users\n", + "\n", + "user_id : id - integer\n", + "country : categorical\n", + "gender : categorical\n", + "age : numerical - integer\n", + "\n", + "Primary key: user_id\n", + "\n", + "\n", + "\n", + "sessions\n", + "\n", + "sessions\n", + "\n", + "session_id : id - integer\n", + "user_id : id - integer\n", + "device : categorical\n", + "os : categorical\n", + "\n", + "Primary key: session_id\n", + "Foreign key (users): user_id\n", + "\n", + "\n", + "\n", + "users->sessions\n", + "\n", + "\n", + "   sessions.user_id -> users.user_id\n", + "\n", + "\n", + "\n", + "transactions\n", + "\n", + "transactions\n", + "\n", + "transaction_id : id - integer\n", + "session_id : id - integer\n", + "timestamp : datetime\n", + "amount : numerical - float\n", + "approved : boolean\n", + "\n", + "Primary key: transaction_id\n", + "Foreign key (sessions): session_id\n", + "\n", + "\n", + "\n", + "sessions->transactions\n", + "\n", + "\n", + "   transactions.session_id -> sessions.session_id\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metadata.visualize()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For more details about how to build the `Metadata` for your own dataset, please refer to the\n", + "[Metadata](https://sdv-dev.github.io/SDV/metadata.html) section of the documentation." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "2. A dictionary containing three `pandas.DataFrames` with the tables described in the\n", + "metadata object." 
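The demo returns a ready-made `Metadata`, but one can also be assembled by hand for a custom dataset. A minimal sketch of that direction; the `add_table` call and its argument names are an assumption about this version's Metadata API (this changeset only exercises `add_field`), so treat them as illustrative rather than definitive:

    from sdv import Metadata

    metadata = Metadata()
    # Hypothetical: register the parent table and its primary key.
    metadata.add_table('users', data=tables['users'], primary_key='user_id')
    # Hypothetical: register a child table linked to its parent through a foreign key.
    metadata.add_table('sessions', data=tables['sessions'], primary_key='session_id',
                       parent='users', foreign_key='user_id')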
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'users': user_id country gender age\n", + " 0 0 USA M 34\n", + " 1 1 UK F 23\n", + " 2 2 ES None 44\n", + " 3 3 UK M 22\n", + " 4 4 USA F 54\n", + " 5 5 DE M 57\n", + " 6 6 BG F 45\n", + " 7 7 ES None 41\n", + " 8 8 FR F 23\n", + " 9 9 UK None 30,\n", + " 'sessions': session_id user_id device os\n", + " 0 0 0 mobile android\n", + " 1 1 1 tablet ios\n", + " 2 2 1 tablet android\n", + " 3 3 2 mobile android\n", + " 4 4 4 mobile ios\n", + " 5 5 5 mobile android\n", + " 6 6 6 mobile ios\n", + " 7 7 6 tablet ios\n", + " 8 8 6 mobile ios\n", + " 9 9 8 tablet ios,\n", + " 'transactions': transaction_id session_id timestamp amount approved\n", + " 0 0 0 2019-01-01 12:34:32 100.0 True\n", + " 1 1 0 2019-01-01 12:42:21 55.3 True\n", + " 2 2 1 2019-01-07 17:23:11 79.5 True\n", + " 3 3 3 2019-01-10 11:08:57 112.1 False\n", + " 4 4 5 2019-01-10 21:54:08 110.0 False\n", + " 5 5 5 2019-01-11 11:21:20 76.3 True\n", + " 6 6 7 2019-01-22 14:44:10 89.5 True\n", + " 7 7 8 2019-01-23 10:14:09 132.1 False\n", + " 8 8 9 2019-01-27 16:09:17 68.0 True\n", + " 9 9 9 2019-01-29 12:10:48 99.9 True}" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tables" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. Fit a model using the SDV API.\n", + "\n", + "First, we build a hierarchical statistical model of the data using **SDV**. For this we will\n", + "create an instance of the `sdv.SDV` class and use its `fit` method.\n", + "\n", + "During this process, **SDV** will traverse across all the tables in your dataset following the\n", + "primary key-foreign key relationships and learn the probability distributions of the values in\n", + "the columns." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-07-09 20:57:19,919 - INFO - modeler - Modeling users\n", + "2020-07-09 20:57:19,920 - INFO - __init__ - Loading transformer CategoricalTransformer for field country\n", + "2020-07-09 20:57:19,920 - INFO - __init__ - Loading transformer CategoricalTransformer for field gender\n", + "2020-07-09 20:57:19,921 - INFO - __init__ - Loading transformer NumericalTransformer for field age\n", + "2020-07-09 20:57:19,933 - INFO - modeler - Modeling sessions\n", + "2020-07-09 20:57:19,934 - INFO - __init__ - Loading transformer CategoricalTransformer for field device\n", + "2020-07-09 20:57:19,934 - INFO - __init__ - Loading transformer CategoricalTransformer for field os\n", + "2020-07-09 20:57:19,944 - INFO - modeler - Modeling transactions\n", + "2020-07-09 20:57:19,944 - INFO - __init__ - Loading transformer DatetimeTransformer for field timestamp\n", + "2020-07-09 20:57:19,944 - INFO - __init__ - Loading transformer NumericalTransformer for field amount\n", + "2020-07-09 20:57:19,945 - INFO - __init__ - Loading transformer BooleanTransformer for field approved\n", + "2020-07-09 20:57:19,954 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:19,962 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/core/fromnumeric.py:58: FutureWarning: Series.nonzero() is deprecated and will be removed in a future version.Use Series.to_numpy().nonzero() instead\n", + " return bound(*args, **kwds)\n", + "/home/xals/Projects/MIT/SDV/sdv/models/copulas.py:83: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. 
If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n", + " self.model.covariance = np.array(values)\n", + "2020-07-09 20:57:19,968 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/pandas/core/frame.py:7143: RuntimeWarning: Degrees of freedom <= 0 for slice\n", + " baseCov = np.cov(mat.T)\n", + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/lib/function_base.py:2480: RuntimeWarning: divide by zero encountered in true_divide\n", + " c *= np.true_divide(1, fact)\n", + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/lib/function_base.py:2480: RuntimeWarning: invalid value encountered in multiply\n", + " c *= np.true_divide(1, fact)\n", + "2020-07-09 20:57:19,974 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:19,979 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:19,985 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:19,989 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:19,994 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,007 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,062 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,074 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,092 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,104 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,117 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,129 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,146 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,208 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 20:57:20,294 - INFO - modeler - Modeling Complete\n" + ] + } + ], + "source": [ + "from sdv import SDV\n", + "\n", + "sdv = SDV()\n", + "sdv.fit(metadata, tables)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. 
Sample data from the fitted model\n", + "\n", + "Once the modeling has finished you are ready to generate new synthetic data using the `sdv` instance that you have.\n", + "\n", + "For this, all you have to do is call the `sample_all` method from your instance passing the number of rows that you want to generate:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "sampled = sdv.sample_all(10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This will return a dictionary identical to the `tables` one that we passed to the SDV instance for learning, filled in with new synthetic data.\n", + "\n", + "**Note** that only the parent tables of your dataset will have the specified number of rows,\n", + "as the number of child rows that each row in the parent table has is also sampled following\n", + "the original distribution of your dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_idcountrygenderage
00UKNaN39
11UKF27
22UKNaN50
33USAF29
44UKNaN21
55ESNaN27
66UKF20
77USAF15
88DEF41
99ESF35
\n", + "
" + ], + "text/plain": [ + " user_id country gender age\n", + "0 0 UK NaN 39\n", + "1 1 UK F 27\n", + "2 2 UK NaN 50\n", + "3 3 USA F 29\n", + "4 4 UK NaN 21\n", + "5 5 ES NaN 27\n", + "6 6 UK F 20\n", + "7 7 USA F 15\n", + "8 8 DE F 41\n", + "9 9 ES F 35" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sampled['users']" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
session_iduser_iddeviceos
001mobileandroid
112tabletios
222tabletios
332mobileandroid
443mobileandroid
555mobileandroid
665mobileandroid
775tabletios
886tabletios
996mobileandroid
\n", + "
" + ], + "text/plain": [ + " session_id user_id device os\n", + "0 0 1 mobile android\n", + "1 1 2 tablet ios\n", + "2 2 2 tablet ios\n", + "3 3 2 mobile android\n", + "4 4 3 mobile android\n", + "5 5 5 mobile android\n", + "6 6 5 mobile android\n", + "7 7 5 tablet ios\n", + "8 8 6 tablet ios\n", + "9 9 6 mobile android" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sampled['sessions'].head(10)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
transaction_idsession_idtimestampamountapproved
0002019-01-13 01:11:06.08427596880.350553True
1102019-01-13 01:11:06.36870425682.116470True
2212019-01-25 15:05:07.73526297697.616590True
3312019-01-25 15:05:09.09895884896.505258True
4422019-01-25 15:05:09.12423142497.235452True
5522019-01-25 15:05:07.34731545696.615923True
6632019-01-25 15:05:09.19044198497.173195True
7732019-01-25 15:05:07.43791078496.995037True
8802019-01-13 01:11:06.56094003281.027625True
9912019-01-25 15:05:12.42074649697.340054True
\n", + "
" + ], + "text/plain": [ + " transaction_id session_id timestamp amount \\\n", + "0 0 0 2019-01-13 01:11:06.084275968 80.350553 \n", + "1 1 0 2019-01-13 01:11:06.368704256 82.116470 \n", + "2 2 1 2019-01-25 15:05:07.735262976 97.616590 \n", + "3 3 1 2019-01-25 15:05:09.098958848 96.505258 \n", + "4 4 2 2019-01-25 15:05:09.124231424 97.235452 \n", + "5 5 2 2019-01-25 15:05:07.347315456 96.615923 \n", + "6 6 3 2019-01-25 15:05:09.190441984 97.173195 \n", + "7 7 3 2019-01-25 15:05:07.437910784 96.995037 \n", + "8 8 0 2019-01-13 01:11:06.560940032 81.027625 \n", + "9 9 1 2019-01-25 15:05:12.420746496 97.340054 \n", + "\n", + " approved \n", + "0 True \n", + "1 True \n", + "2 True \n", + "3 True \n", + "4 True \n", + "5 True \n", + "6 True \n", + "7 True \n", + "8 True \n", + "9 True " + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sampled['transactions'].head(10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Saving and Loading your model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In some cases, you might want to save the fitted SDV instance to be able to generate synthetic data from\n", + "it later or on a different system.\n", + "\n", + "In order to do so, you can save your fitted `SDV` instance for later usage using the `save` method of your\n", + "instance." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "sdv.save('sdv.pkl')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The generated `pkl` file will not include any of the original data in it, so it can be\n", + "safely sent to where the synthetic data will be generated without any privacy concerns." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Later on, in order to sample data from the fitted model, we will first need to load it from its\n", + "`pkl` file." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "sdv = SDV.load('sdv.pkl')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After loading the instance, we can sample synthetic data using its `sample_all` method like before." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "sampled = sdv.sample_all(5)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/tutorials/02_Single_Table_Modeling.ipynb b/tutorials/02_Single_Table_Modeling.ipynb new file mode 100644 index 000000000..fa536bd3d --- /dev/null +++ b/tutorials/02_Single_Table_Modeling.ipynb @@ -0,0 +1,1217 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Single Table Modeling\n", + "\n", + "**SDV** has special support for modeling single table datasets using a variety of models.\n", + "\n", + "Currently, SDV implements:\n", + "\n", + "* GaussianCopula: A tool to model multivariate distributions using [copula functions](https://en.wikipedia.org/wiki/Copula_%28probability_theory%29). 
Based on our [Copulas Library](https://github.com/sdv-dev/Copulas).\n", + "* CTGAN: A GAN-based Deep Learning data synthesizer that can generate synthetic tabular data with high fidelity. Based on our [CTGAN Library](https://github.com/sdv-dev/CTGAN)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## GaussianCopula\n", + "\n", + "In this first part of the tutorial we will be using the GaussianCopula class to model the `users` table\n", + "from the toy dataset included in the **SDV** library.\n", + "\n", + "### 1. Load the Data" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from sdv import load_demo\n", + "\n", + "users = load_demo()['users']" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This will return a table with 4 fields:\n", + "\n", + "* `user_id`: A unique identifier of the user.\n", + "* `country`: A 2 letter code of the country of residence of the user.\n", + "* `gender`: A single letter code, `M` or `F`, indicating the user gender. Note that this demo simulates the case where some users did not indicate the gender, which resulted in empty data values in some rows.\n", + "* `age`: The age of the user, in years." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_idcountrygenderage
00USAM34
11UKF23
22ESNone44
33UKM22
44USAF54
55DEM57
66BGF45
77ESNone41
88FRF23
99UKNone30
\n", + "
" + ], + "text/plain": [ + " user_id country gender age\n", + "0 0 USA M 34\n", + "1 1 UK F 23\n", + "2 2 ES None 44\n", + "3 3 UK M 22\n", + "4 4 USA F 54\n", + "5 5 DE M 57\n", + "6 6 BG F 45\n", + "7 7 ES None 41\n", + "8 8 FR F 23\n", + "9 9 UK None 30" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "users" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. Prepare the model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In order to properly model our data we will need to provide some additional information to our model,\n", + "so let's prepare this information in some variables." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, let's indicate that the `user_id` field in our table is the primary key, so we do not want our\n", + "model to attempt to learn it." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "primary_key = 'user_id'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will also want to anonymize the countries of residence of our users, to avoid disclosing such information.\n", + "Let's make a variable indicating that the `country` field needs to be anonymized using fake `country_codes`." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "anonymize_fileds = {\n", + " 'country': 'contry_code'\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The full list of categories supported corresponds to the `Faker` library\n", + "[provider names](https://faker.readthedocs.io/en/master/providers.html)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Once we have prepared the arguments for our model we are ready to import it, create an instance\n", + "and fit it to our data." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-07-09 21:18:32,974 - INFO - table - Loading transformer CategoricalTransformer for field country\n", + "2020-07-09 21:18:32,975 - INFO - table - Loading transformer CategoricalTransformer for field gender\n", + "2020-07-09 21:18:32,975 - INFO - table - Loading transformer NumericalTransformer for field age\n", + "2020-07-09 21:18:32,991 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n" + ] + } + ], + "source": [ + "from sdv.tabular import GaussianCopula\n", + "\n", + "model = GaussianCopula(\n", + " primary_key=primary_key,\n", + " anonymize_fileds=anonymize_fileds\n", + ")\n", + "model.fit(users)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Notice** how the model took care of transforming the different fields using the appropriate\n", + "Reversible Data Transforms to ensure that the data has a format that the GaussianMultivariate model\n", + "from the [copulas](https://github.com/sdv-dev/Copulas) library can handle." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. Sample data from the fitted model\n", + "\n", + "Once the modeling has finished you are ready to generate new synthetic data by calling the `sample` method\n", + "from our model." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "sampled = model.sample()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This will return a table identical to the one which the model was fitted on, but filled with new data\n", + "which resembles the original one." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_idcountrygenderage
00USAM38
11UKNaN23
22USAF34
33ESNaN47
44ESF29
55UKF39
66FRNaN40
77ESM38
88ESF32
99ESF36
\n", + "
" + ], + "text/plain": [ + " user_id country gender age\n", + "0 0 USA M 38\n", + "1 1 UK NaN 23\n", + "2 2 USA F 34\n", + "3 3 ES NaN 47\n", + "4 4 ES F 29\n", + "5 5 UK F 39\n", + "6 6 FR NaN 40\n", + "7 7 ES M 38\n", + "8 8 ES F 32\n", + "9 9 ES F 36" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sampled" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "scrolled": false + }, + "source": [ + "Notice, as well that the number of rows generated by default corresponds to the number of rows that\n", + "the original table had, but that this number can be changed by simply passing it:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_idcountrygenderage
00UKF48
11USANaN38
22USAM29
33BGM22
44USAM43
\n", + "
" + ], + "text/plain": [ + " user_id country gender age\n", + "0 0 UK F 48\n", + "1 1 USA NaN 38\n", + "2 2 USA M 29\n", + "3 3 BG M 22\n", + "4 4 USA M 43" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.sample(5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## CTGAN\n", + "\n", + "In this second part of the tutorial we will be using the CTGAN model to learn the data from the\n", + "demo dataset called `census`, which is based on the [UCI Adult Census Dataset]('https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data').\n", + "\n", + "### 1. Load the Data" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-07-09 21:18:33,085 - INFO - __init__ - Loading table census\n" + ] + } + ], + "source": [ + "from sdv import load_demo\n", + "\n", + "census = load_demo('census')['census']" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This will return a table with several rows of multiple data types:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ageworkclassfnlwgteducationeducation-nummarital-statusoccupationrelationshipracesexcapital-gaincapital-losshours-per-weeknative-countryincome
039State-gov77516Bachelors13Never-marriedAdm-clericalNot-in-familyWhiteMale2174040United-States<=50K
150Self-emp-not-inc83311Bachelors13Married-civ-spouseExec-managerialHusbandWhiteMale0013United-States<=50K
238Private215646HS-grad9DivorcedHandlers-cleanersNot-in-familyWhiteMale0040United-States<=50K
353Private23472111th7Married-civ-spouseHandlers-cleanersHusbandBlackMale0040United-States<=50K
428Private338409Bachelors13Married-civ-spouseProf-specialtyWifeBlackFemale0040Cuba<=50K
\n", + "
" + ], + "text/plain": [ + " age workclass fnlwgt education education-num \\\n", + "0 39 State-gov 77516 Bachelors 13 \n", + "1 50 Self-emp-not-inc 83311 Bachelors 13 \n", + "2 38 Private 215646 HS-grad 9 \n", + "3 53 Private 234721 11th 7 \n", + "4 28 Private 338409 Bachelors 13 \n", + "\n", + " marital-status occupation relationship race sex \\\n", + "0 Never-married Adm-clerical Not-in-family White Male \n", + "1 Married-civ-spouse Exec-managerial Husband White Male \n", + "2 Divorced Handlers-cleaners Not-in-family White Male \n", + "3 Married-civ-spouse Handlers-cleaners Husband Black Male \n", + "4 Married-civ-spouse Prof-specialty Wife Black Female \n", + "\n", + " capital-gain capital-loss hours-per-week native-country income \n", + "0 2174 0 40 United-States <=50K \n", + "1 0 0 13 United-States <=50K \n", + "2 0 0 40 United-States <=50K \n", + "3 0 0 40 United-States <=50K \n", + "4 0 0 40 Cuba <=50K " + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "census.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. Prepare the model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this case there is no primary key to setup and we will not be anonymizing anything, so the only\n", + "thing that we will pass to the CTGAN model is the number of epochs that we want it to perform when\n", + "it leanrs the data, which we will keep low to make this execution quick." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/sklearn/utils/deprecation.py:144: FutureWarning: The sklearn.utils.testing module is deprecated in version 0.22 and will be removed in version 0.24. The corresponding classes / functions should instead be imported from sklearn.utils. Anything that cannot be imported from sklearn.utils is now part of the private API.\n", + " warnings.warn(message, FutureWarning)\n" + ] + } + ], + "source": [ + "from sdv.tabular import CTGAN\n", + "\n", + "model = CTGAN(epochs=10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Once the instance is created, we can fit it to our data. Bear in mind that this process might take some\n", + "time to finish, especially on non-GPU enabled systems, so in this case we will be passing only a\n", + "subsample of the data to accelerate the process." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-07-09 21:18:33,488 - INFO - table - Loading transformer NumericalTransformer for field age\n", + "2020-07-09 21:18:33,489 - INFO - table - Loading transformer LabelEncodingTransformer for field workclass\n", + "2020-07-09 21:18:33,489 - INFO - table - Loading transformer NumericalTransformer for field fnlwgt\n", + "2020-07-09 21:18:33,490 - INFO - table - Loading transformer LabelEncodingTransformer for field education\n", + "2020-07-09 21:18:33,490 - INFO - table - Loading transformer NumericalTransformer for field education-num\n", + "2020-07-09 21:18:33,490 - INFO - table - Loading transformer LabelEncodingTransformer for field marital-status\n", + "2020-07-09 21:18:33,491 - INFO - table - Loading transformer LabelEncodingTransformer for field occupation\n", + "2020-07-09 21:18:33,491 - INFO - table - Loading transformer LabelEncodingTransformer for field relationship\n", + "2020-07-09 21:18:33,491 - INFO - table - Loading transformer LabelEncodingTransformer for field race\n", + "2020-07-09 21:18:33,492 - INFO - table - Loading transformer LabelEncodingTransformer for field sex\n", + "2020-07-09 21:18:33,492 - INFO - table - Loading transformer NumericalTransformer for field capital-gain\n", + "2020-07-09 21:18:33,493 - INFO - table - Loading transformer NumericalTransformer for field capital-loss\n", + "2020-07-09 21:18:33,493 - INFO - table - Loading transformer NumericalTransformer for field hours-per-week\n", + "2020-07-09 21:18:33,493 - INFO - table - Loading transformer LabelEncodingTransformer for field native-country\n", + "2020-07-09 21:18:33,494 - INFO - table - Loading transformer LabelEncodingTransformer for field income\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1, Loss G: 1.9512, Loss D: -0.0182\n", + "Epoch 2, Loss G: 1.9884, Loss D: -0.0663\n", + "Epoch 3, Loss G: 1.9710, Loss D: -0.1339\n", + "Epoch 4, Loss G: 1.8960, Loss D: -0.2061\n", + "Epoch 5, Loss G: 1.9155, Loss D: -0.3062\n", + "Epoch 6, Loss G: 1.9699, Loss D: -0.3906\n", + "Epoch 7, Loss G: 1.8614, Loss D: -0.5142\n", + "Epoch 8, Loss G: 1.8446, Loss D: -0.6448\n", + "Epoch 9, Loss G: 1.7619, Loss D: -0.7488\n", + "Epoch 10, Loss G: 1.6732, Loss D: -0.7961\n" + ] + } + ], + "source": [ + "model.fit(census.sample(1000))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. Sample data from the fitted model\n", + "\n", + "Once the modeling has finished you are ready to generate new synthetic data by calling the `sample` method\n", + "from our model just like we did with the GaussianCopula model." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "sampled = model.sample()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This will return a table identical to the one which the model was fitted on, but filled with new data\n", + "which resembles the original one." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ageworkclassfnlwgteducationeducation-nummarital-statusoccupationrelationshipracesexcapital-gaincapital-losshours-per-weeknative-countryincome
050Local-gov1697191st-4th9Widowed?HusbandWhiteMale114838Columbia<=50K
132?1524791st-4th9Never-marriedAdm-clericalWifeBlackMale-422021Jamaica>50K
222Private69617Bachelors0Separated?HusbandWhiteMale61138Guatemala<=50K
325?65285810th16Married-civ-spouseHandlers-cleanersNot-in-familyWhiteFemale152-2739Cuba<=50K
443Private301956Some-college8Married-civ-spouse?WifeWhiteMale-133-1239India<=50K
566Private401171Prof-school13SeparatedProtective-servUnmarriedBlackFemale-124-140Cuba<=50K
652Private278399Bachelors12Never-marriedProf-specialtyUnmarriedOtherMale122567-647Columbia<=50K
736Federal-gov229817HS-grad8Married-AF-spouseFarming-fishingNot-in-familyWhiteMale81938Portugal>50K
827Federal-gov306972Some-college8Never-marriedExec-managerialHusbandAsian-Pac-IslanderFemale42144339Japan>50K
928Local-gov4161611st-4th8DivorcedAdm-clericalUnmarriedWhiteFemale-349109061Guatemala>50K
\n", + "
" + ], + "text/plain": [ + " age workclass fnlwgt education education-num \\\n", + "0 50 Local-gov 169719 1st-4th 9 \n", + "1 32 ? 152479 1st-4th 9 \n", + "2 22 Private 69617 Bachelors 0 \n", + "3 25 ? 652858 10th 16 \n", + "4 43 Private 301956 Some-college 8 \n", + "5 66 Private 401171 Prof-school 13 \n", + "6 52 Private 278399 Bachelors 12 \n", + "7 36 Federal-gov 229817 HS-grad 8 \n", + "8 27 Federal-gov 306972 Some-college 8 \n", + "9 28 Local-gov 416161 1st-4th 8 \n", + "\n", + " marital-status occupation relationship \\\n", + "0 Widowed ? Husband \n", + "1 Never-married Adm-clerical Wife \n", + "2 Separated ? Husband \n", + "3 Married-civ-spouse Handlers-cleaners Not-in-family \n", + "4 Married-civ-spouse ? Wife \n", + "5 Separated Protective-serv Unmarried \n", + "6 Never-married Prof-specialty Unmarried \n", + "7 Married-AF-spouse Farming-fishing Not-in-family \n", + "8 Never-married Exec-managerial Husband \n", + "9 Divorced Adm-clerical Unmarried \n", + "\n", + " race sex capital-gain capital-loss hours-per-week \\\n", + "0 White Male 114 8 38 \n", + "1 Black Male -42 20 21 \n", + "2 White Male 6 11 38 \n", + "3 White Female 152 -27 39 \n", + "4 White Male -133 -12 39 \n", + "5 Black Female -124 -1 40 \n", + "6 Other Male 122567 -6 47 \n", + "7 White Male 8 19 38 \n", + "8 Asian-Pac-Islander Female 42144 3 39 \n", + "9 White Female -349 1090 61 \n", + "\n", + " native-country income \n", + "0 Columbia <=50K \n", + "1 Jamaica >50K \n", + "2 Guatemala <=50K \n", + "3 Cuba <=50K \n", + "4 India <=50K \n", + "5 Cuba <=50K \n", + "6 Columbia <=50K \n", + "7 Portugal >50K \n", + "8 Japan >50K \n", + "9 Guatemala >50K " + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sampled.head(10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 4. Evaluate how good the data is\n", + "\n", + "Finally, we will use the evaluation framework included in SDV to obtain a metric of how\n", + "similar the sampled data is to the original one.\n", + "\n", + "For this, we will simply import the `sdv.evaluation.evaluate` function and pass both\n", + "the synthetic and the real data to it." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "-144.971907591418" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import warnings\n", + "warnings.filterwarnings('ignore')\n", + "\n", + "from sdv.evaluation import evaluate\n", + "\n", + "evaluate(sampled, census)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}
diff --git a/tutorials/03_Relational_Data_Modeling.ipynb b/tutorials/03_Relational_Data_Modeling.ipynb new file mode 100644 index 000000000..230061b00 --- /dev/null +++ b/tutorials/03_Relational_Data_Modeling.ipynb @@ -0,0 +1,1039 @@
+{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Relational Data Modeling\n", + "\n", + "In this tutorial we will be showing how to model a real world multi-table dataset using SDV.\n", + "\n", + "## About the dataset\n", + "\n", + "We have a series of stores, each with a size and a category, plus additional information for each date: the average temperature in the region, the cost of fuel in the region, promotional data, the consumer price index, the unemployment rate and whether the date is a special holiday.\n", + "\n", + "From those stores we obtained historical training data \n", + "between 2010-02-05 and 2012-11-01. This historical data includes the sales of each department on a specific date.\n", + "In this notebook, we will show you step by step how to download the \"Walmart\" dataset, explain its structure and sample the data.\n", + "\n", + "In this demonstration we will show how SDV can be used to generate synthetic data, which can later be used to train machine learning models.\n", + "\n", + "*The dataset used in this example can be found on [Kaggle](https://www.kaggle.com/c/walmart-recruiting-store-sales-forecasting/data), but we will show how to download it through SDV.*" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Data model summary" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
**stores**
\n", + "\n", + "| Field | Type | Subtype | Additional Properties |\n", + "|-------|-------------|---------|-----------------------|\n", + "| Store | id | integer | Primary key |\n", + "| Size | numerical | integer | |\n", + "| Type | categorical | | |\n", + "\n", + "Contains information about the 45 stores, indicating the type and size of store." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
**features**
\n", + "\n", + "| Fields | Type | Subtype | Additional Properties |\n", + "|--------------|-----------|---------|-----------------------------|\n", + "| Store | id | integer | foreign key (stores.Store) |\n", + "| Date | datetime | | format: \"%Y-%m-%d\" |\n", + "| IsHoliday | boolean | | |\n", + "| Fuel_Price | numerical | float | |\n", + "| Unemployment | numerical | float | |\n", + "| Temperature | numerical | float | |\n", + "| CPI | numerical | float | |\n", + "| MarkDown1 | numerical | float | |\n", + "| MarkDown2 | numerical | float | |\n", + "| MarkDown3 | numerical | float | |\n", + "| MarkDown4 | numerical | float | |\n", + "| MarkDown5 | numerical | float | |\n", + "\n", + "Contains historical training data, which covers to 2010-02-05 to 2012-11-01." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
**depts**
\n", + "\n", + "| Fields | Type | Subtype | Additional Properties |\n", + "|--------------|-----------|---------|------------------------------|\n", + "| Store | id | integer | foreign key (stores.Stores) |\n", + "| Date | datetime | | format: \"%Y-%m-%d\" |\n", + "| Weekly_Sales | numerical | float | |\n", + "| Dept | numerical | integer | |\n", + "| IsHoliday | boolean | | |\n", + "\n", + "Contains additional data related to the store, department, and regional activity for the given dates." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1. Load data\n", + "\n", + "Let's start downloading the data set. In this case, we will download the data set *walmart*. We will use the SDV function `load_demo`, we can specify the name of the dataset we want to use and if we want its Metadata object or not. To know more about the demo data [see the documentation](https://sdv-dev.github.io/SDV/api/sdv.demo.html)." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-07-09 21:00:17,378 - INFO - __init__ - Loading table stores\n", + "2020-07-09 21:00:17,384 - INFO - __init__ - Loading table features\n", + "2020-07-09 21:00:17,402 - INFO - __init__ - Loading table depts\n" + ] + } + ], + "source": [ + "from sdv import load_demo\n", + "\n", + "metadata, tables = load_demo(dataset_name='walmart', metadata=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Our dataset is downloaded from an [Amazon S3 bucket](http://sdv-datasets.s3.amazonaws.com/index.html) that contains all available data sets of the `load_demo` method." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can now visualize the metadata structure" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Metadata\n", + "\n", + "\n", + "\n", + "stores\n", + "\n", + "stores\n", + "\n", + "Type : categorical\n", + "Size : numerical - integer\n", + "Store : id - integer\n", + "\n", + "Primary key: Store\n", + "Data path: stores.csv\n", + "\n", + "\n", + "\n", + "features\n", + "\n", + "features\n", + "\n", + "Date : datetime\n", + "MarkDown1 : numerical - float\n", + "Store : id - integer\n", + "IsHoliday : boolean\n", + "MarkDown4 : numerical - float\n", + "MarkDown3 : numerical - float\n", + "Fuel_Price : numerical - float\n", + "Unemployment : numerical - float\n", + "Temperature : numerical - float\n", + "MarkDown5 : numerical - float\n", + "MarkDown2 : numerical - float\n", + "CPI : numerical - float\n", + "\n", + "Foreign key (stores): Store\n", + "Data path: features.csv\n", + "\n", + "\n", + "\n", + "stores->features\n", + "\n", + "\n", + "   features.Store -> stores.Store\n", + "\n", + "\n", + "\n", + "depts\n", + "\n", + "depts\n", + "\n", + "Date : datetime\n", + "Weekly_Sales : numerical - float\n", + "Store : id - integer\n", + "Dept : numerical - integer\n", + "IsHoliday : boolean\n", + "\n", + "Foreign key (stores): Store\n", + "Data path: train.csv\n", + "\n", + "\n", + "\n", + "stores->depts\n", + "\n", + "\n", + "   depts.Store -> stores.Store\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metadata.visualize()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + 
"source": [ + "And also validate that the metadata is correctly defined for our data" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "metadata.validate(tables)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. Create an instance of SDV and train the instance\n", + "\n", + "Once we download it, we have to create an SDV instance. With that instance, we have to analyze the loaded tables to generate a statistical model from the data. In this case, the process of adjusting the model is quickly because the dataset is small. However, with larger datasets it can be a slow process." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-07-09 21:00:31,480 - INFO - modeler - Modeling stores\n", + "2020-07-09 21:00:31,481 - INFO - __init__ - Loading transformer CategoricalTransformer for field Type\n", + "2020-07-09 21:00:31,481 - INFO - __init__ - Loading transformer NumericalTransformer for field Size\n", + "2020-07-09 21:00:31,491 - INFO - modeler - Modeling features\n", + "2020-07-09 21:00:31,492 - INFO - __init__ - Loading transformer DatetimeTransformer for field Date\n", + "2020-07-09 21:00:31,493 - INFO - __init__ - Loading transformer NumericalTransformer for field MarkDown1\n", + "2020-07-09 21:00:31,493 - INFO - __init__ - Loading transformer BooleanTransformer for field IsHoliday\n", + "2020-07-09 21:00:31,493 - INFO - __init__ - Loading transformer NumericalTransformer for field MarkDown4\n", + "2020-07-09 21:00:31,494 - INFO - __init__ - Loading transformer NumericalTransformer for field MarkDown3\n", + "2020-07-09 21:00:31,494 - INFO - __init__ - Loading transformer NumericalTransformer for field Fuel_Price\n", + "2020-07-09 21:00:31,494 - INFO - __init__ - Loading transformer NumericalTransformer for field Unemployment\n", + "2020-07-09 21:00:31,495 - INFO - __init__ - Loading transformer NumericalTransformer for field Temperature\n", + "2020-07-09 21:00:31,495 - INFO - __init__ - Loading transformer NumericalTransformer for field MarkDown5\n", + "2020-07-09 21:00:31,495 - INFO - __init__ - Loading transformer NumericalTransformer for field MarkDown2\n", + "2020-07-09 21:00:31,495 - INFO - __init__ - Loading transformer NumericalTransformer for field CPI\n", + "2020-07-09 21:00:31,544 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,595 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "/home/xals/Projects/MIT/SDV/sdv/models/copulas.py:83: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. 
If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n", + " self.model.covariance = np.array(values)\n", + "2020-07-09 21:00:31,651 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,679 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,707 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,734 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,762 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,790 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,816 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,845 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,872 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,901 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,931 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,959 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:31,986 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,014 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,040 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,070 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,096 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,123 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,152 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,181 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,209 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,235 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,264 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,293 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,322 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,349 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,376 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,405 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,433 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 
21:00:32,463 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,492 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,521 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,552 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,583 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,612 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,644 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,674 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,704 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,732 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,762 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,791 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,821 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,852 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,882 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:32,917 - INFO - modeler - Modeling depts\n", + "2020-07-09 21:00:32,918 - INFO - __init__ - Loading transformer DatetimeTransformer for field Date\n", + "2020-07-09 21:00:32,918 - INFO - __init__ - Loading transformer NumericalTransformer for field Weekly_Sales\n", + "2020-07-09 21:00:32,919 - INFO - __init__ - Loading transformer NumericalTransformer for field Dept\n", + "2020-07-09 21:00:32,919 - INFO - __init__ - Loading transformer BooleanTransformer for field IsHoliday\n", + "2020-07-09 21:00:33,016 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,318 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,334 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,350 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,364 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,381 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,396 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,412 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,428 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-07-09 21:00:33,447 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,464 - INFO - gaussian - Fitting 
GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,479 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,495 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,511 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,526 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,541 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,554 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,567 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,582 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,596 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,609 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,624 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,639 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,652 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,667 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,682 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,696 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,709 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,724 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,740 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,753 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,766 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,779 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,792 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,803 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,818 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,831 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,842 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,856 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,870 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,885 - INFO - gaussian - Fitting 
GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,898 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,912 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,926 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,937 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:33,949 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "2020-07-09 21:00:34,047 - INFO - gaussian - Fitting GaussianMultivariate(distribution=\"GaussianUnivariate\")\n", + "/home/xals/.virtualenvs/SDV/lib/python3.6/site-packages/numpy/core/fromnumeric.py:58: FutureWarning: Series.nonzero() is deprecated and will be removed in a future version.Use Series.to_numpy().nonzero() instead\n", + "  return bound(*args, **kwds)\n", + "2020-07-09 21:00:34,259 - INFO - modeler - Modeling Complete\n" + ] + } + ], + "source": [ + "from sdv import SDV\n", + "\n", + "sdv = SDV()\n", + "sdv.fit(metadata, tables=tables)" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note: We may not want to train the model every time we want to generate new synthetic data. We can [save](https://sdv-dev.github.io/SDV/api/sdv.sdv.html#sdv.sdv.SDV.save) the SDV instance to [load](https://sdv-dev.github.io/SDV/api/sdv.sdv.html#sdv.sdv.SDV.load) it later." + ] + },
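+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As a minimal sketch of that workflow (this cell was not part of the original run; it assumes the `SDV.save` and `SDV.load` methods linked above, and the file name is illustrative):" + ] + },
+ { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Persist the fitted SDV instance to disk\n", + "sdv.save('sdv.pkl')\n", + "\n", + "# Later on, restore it without fitting it again\n", + "from sdv import SDV\n", + "\n", + "sdv = SDV.load('sdv.pkl')" + ] + },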
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. Generate synthetic data\n", + "\n", + "Once the instance is trained, we are ready to generate the synthetic data.\n", + "\n", + "The easiest way to generate synthetic data for the entire dataset is to call the `sample_all` method. By default, this method generates only 5 rows, but we can specify the number of rows to generate with the `num_rows` argument. To learn more about the available arguments, see [sample_all](https://sdv-dev.github.io/SDV/api/sdv.sampler.html#sdv.sampler.Sampler.sample_all)." + ] + },
+ { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'stores': 45, 'features': 8190, 'depts': 421570}" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sdv.modeler.table_sizes" + ] + },
+ { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "samples = sdv.sample_all()" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This returns a dictionary with a `pandas.DataFrame` for each table." + ] + },
+ { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "[HTML rendering of samples['stores'].head() stripped during extraction; the equivalent text/plain output is kept below]\n
" + ], + "text/plain": [ + " Type Size Store\n", + "0 B 106106 0\n", + "1 B 68071 1\n", + "2 C 253603 2\n", + "3 A 223268 3\n", + "4 B 166848 4" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "samples['stores'].head()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
DateMarkDown1StoreIsHolidayMarkDown4MarkDown3Fuel_PriceUnemploymentTemperatureMarkDown5MarkDown2CPI
02009-12-13 04:59:49.832576000222.2688220False758.900545-1080.5694354.5749666.05279551.354606-4016.012141NaN224.973722
12011-07-20 14:40:40.4002447366857.7120280FalseNaN-551.0452573.0182966.72066980.723862NaNNaN204.595072
22011-10-12 08:23:54.9876582405530.9294250FalseNaNNaN3.4663587.25657073.954938NaNNaN202.591941
32012-11-26 09:22:59.377291776NaN0FalseNaN3703.1371913.6237256.59861572.041454NaN3925.456611169.457078
42013-05-13 15:22:12.213717760NaN0False3773.942572NaN3.3817955.11500283.05832215018.944605NaN182.801974
\n", + "
" + ], + "text/plain": [ + " Date MarkDown1 Store IsHoliday MarkDown4 \\\n", + "0 2009-12-13 04:59:49.832576000 222.268822 0 False 758.900545 \n", + "1 2011-07-20 14:40:40.400244736 6857.712028 0 False NaN \n", + "2 2011-10-12 08:23:54.987658240 5530.929425 0 False NaN \n", + "3 2012-11-26 09:22:59.377291776 NaN 0 False NaN \n", + "4 2013-05-13 15:22:12.213717760 NaN 0 False 3773.942572 \n", + "\n", + " MarkDown3 Fuel_Price Unemployment Temperature MarkDown5 \\\n", + "0 -1080.569435 4.574966 6.052795 51.354606 -4016.012141 \n", + "1 -551.045257 3.018296 6.720669 80.723862 NaN \n", + "2 NaN 3.466358 7.256570 73.954938 NaN \n", + "3 3703.137191 3.623725 6.598615 72.041454 NaN \n", + "4 NaN 3.381795 5.115002 83.058322 15018.944605 \n", + "\n", + " MarkDown2 CPI \n", + "0 NaN 224.973722 \n", + "1 NaN 204.595072 \n", + "2 NaN 202.591941 \n", + "3 3925.456611 169.457078 \n", + "4 NaN 182.801974 " + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "samples['features'].head()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
DateWeekly_SalesStoreDeptIsHoliday
02012-04-19 04:40:04.3703828487523.202470036False
12012-10-24 12:49:45.63017651213675.425498017False
22012-06-05 05:02:35.5517614082402.017324031False
32012-05-31 22:15:14.903856896-12761.073176081False
42011-09-07 15:34:43.2398356483642.817612058False
\n", + "
" + ], + "text/plain": [ + " Date Weekly_Sales Store Dept IsHoliday\n", + "0 2012-04-19 04:40:04.370382848 7523.202470 0 36 False\n", + "1 2012-10-24 12:49:45.630176512 13675.425498 0 17 False\n", + "2 2012-06-05 05:02:35.551761408 2402.017324 0 31 False\n", + "3 2012-05-31 22:15:14.903856896 -12761.073176 0 81 False\n", + "4 2011-09-07 15:34:43.239835648 3642.817612 0 58 False" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "samples['depts'].head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We may not want to generate data for all tables in the dataset, rather for just one table. This is possible with SDV using the `sample` method. To use it we only need to specify the name of the table we want to synthesize and the row numbers to generate. In this case, the \"walmart\" data set has 3 tables: stores, features and depts.\n", + "\n", + "In the following example, we will generate 1000 rows of the \"features\" table." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'features': Date MarkDown1 Store IsHoliday \\\n", + " 0 2010-07-07 18:33:30.490365440 NaN 46 True \n", + " 1 2010-05-31 05:12:43.555245568 NaN 49 False \n", + " 2 2009-11-13 00:11:26.645775872 NaN 46 False \n", + " 3 2011-06-26 13:37:23.832223232 8480.251096 48 False \n", + " 4 2013-09-30 06:14:18.804104192 4586.374804 48 False \n", + " 5 2011-03-18 21:50:55.521650432 NaN 45 False \n", + " 6 2012-03-22 14:59:09.189622016 9747.212668 46 False \n", + " 7 2011-04-05 16:45:11.282816000 NaN 48 False \n", + " 8 2013-05-05 14:23:57.431958784 5430.557501 46 False \n", + " 9 2011-04-07 02:38:10.873993984 6735.454863 49 False \n", + " 10 2012-07-09 15:59:11.863023616 11818.737600 48 False \n", + " 11 2012-03-10 03:57:22.968428544 4603.635257 48 False \n", + " 12 2011-10-01 03:51:05.280273920 14596.976258 49 False \n", + " 13 2012-06-24 23:45:06.501742080 5816.163902 49 False \n", + " 14 2011-08-06 14:43:48.366935040 NaN 47 False \n", + " 15 2012-06-14 01:06:45.771771648 NaN 45 False \n", + " 16 2011-11-22 08:57:27.766619392 NaN 45 False \n", + " 17 2012-10-06 19:10:00.039407360 5887.384577 48 False \n", + " 18 2013-06-08 07:46:59.612567808 -5770.051172 48 False \n", + " 19 2012-05-27 11:08:47.184699648 6908.122060 46 False \n", + " 20 2010-10-19 21:42:09.919441408 NaN 46 False \n", + " 21 2013-01-01 17:57:15.273558784 6221.151492 48 False \n", + " 22 2012-07-13 14:40:51.181466112 1378.006687 47 False \n", + " 23 2012-02-13 18:37:30.161656320 NaN 48 False \n", + " 24 2010-11-27 09:56:34.540603392 NaN 46 False \n", + " 25 2011-05-21 03:49:37.053693440 NaN 49 False \n", + " 26 2010-06-12 20:31:32.339124224 NaN 49 True \n", + " 27 2011-08-12 21:10:37.406385152 NaN 47 False \n", + " 28 2012-02-07 03:31:37.558127360 NaN 47 False \n", + " 29 2011-05-28 22:39:00.441712128 NaN 48 False \n", + " .. ... ... ... ... 
\n", + " 970 2012-10-01 09:27:33.224137984 7367.091411 49 False \n", + " 971 2011-08-05 19:20:26.151251200 NaN 49 False \n", + " 972 2010-05-02 10:25:48.294526720 NaN 49 False \n", + " 973 2014-07-30 06:33:25.027884800 2869.286369 46 False \n", + " 974 2012-06-10 23:50:28.918937344 NaN 45 False \n", + " 975 2011-11-20 17:24:25.076207104 NaN 48 False \n", + " 976 2011-07-16 19:31:09.475053312 8465.620591 48 False \n", + " 977 2010-12-17 11:17:29.377529600 NaN 45 False \n", + " 978 2011-05-01 22:00:24.413698816 NaN 48 True \n", + " 979 2009-01-04 14:06:28.700405504 NaN 48 False \n", + " 980 2013-01-26 23:44:47.124111360 7711.715123 47 False \n", + " 981 2012-02-05 15:58:21.306519040 4771.829938 49 False \n", + " 982 2012-12-25 20:25:33.794860544 9052.115357 45 False \n", + " 983 2011-03-11 05:15:33.422215424 NaN 48 False \n", + " 984 2012-04-10 15:14:42.299491072 11469.698722 49 False \n", + " 985 2012-09-04 11:56:04.719035648 11985.730540 47 False \n", + " 986 2010-08-12 06:13:43.472920320 NaN 47 False \n", + " 987 2011-11-28 00:46:09.219015936 NaN 47 False \n", + " 988 2010-11-01 19:46:13.442969344 NaN 47 False \n", + " 989 2012-09-03 02:46:56.811711232 13448.706455 46 False \n", + " 990 2012-06-21 20:15:00.600609536 8515.644049 45 False \n", + " 991 2009-11-09 11:32:20.317559808 NaN 47 False \n", + " 992 2012-12-24 16:57:40.573141760 8785.735128 48 False \n", + " 993 2010-10-26 22:04:15.916646656 NaN 45 False \n", + " 994 2013-07-27 21:48:37.233461248 2218.444188 46 False \n", + " 995 2012-11-22 15:48:58.070053120 NaN 47 False \n", + " 996 2012-11-19 17:12:55.679613440 11318.264204 45 False \n", + " 997 2013-10-11 21:09:49.021492480 478.609430 47 False \n", + " 998 2013-01-23 23:17:20.493887232 14227.944357 49 False \n", + " 999 2012-06-23 00:01:59.594749696 9121.193148 46 False \n", + " \n", + " MarkDown4 MarkDown3 Fuel_Price Unemployment Temperature \\\n", + " 0 NaN NaN 3.246268 6.712596 72.778272 \n", + " 1 NaN NaN 2.961909 8.775273 31.911285 \n", + " 2 NaN NaN 2.474694 8.693548 36.996590 \n", + " 3 2990.430314 450.886030 3.180972 7.482533 82.689207 \n", + " 4 1300.901307 1038.142954 3.979948 3.842990 61.626723 \n", + " 5 NaN NaN 3.190374 9.185961 43.775213 \n", + " 6 5068.165070 4033.503158 3.436991 10.002234 44.980268 \n", + " 7 2639.207053 NaN 3.126477 8.921070 60.901225 \n", + " 8 -1515.448599 1859.931318 4.101750 10.478649 81.003354 \n", + " 9 NaN -4013.843645 3.200662 8.549170 46.332113 \n", + " 10 6983.561124 9246.461293 3.637923 6.940256 66.208533 \n", + " 11 NaN 1872.845519 3.535785 5.360209 76.842613 \n", + " 12 5967.615987 -2084.078449 3.314418 NaN 73.216988 \n", + " 13 4181.309429 -7382.929438 3.170917 7.209879 27.887139 \n", + " 14 NaN NaN 4.084598 10.976241 75.512964 \n", + " 15 6522.458652 -2969.500556 3.592465 6.778646 45.304799 \n", + " 16 NaN NaN 3.691225 9.169614 61.590909 \n", + " 17 2195.064778 10031.652043 3.790657 5.655427 38.507717 \n", + " 18 -2522.448316 4530.988378 3.349566 4.743290 36.629555 \n", + " 19 NaN 3416.997827 3.641454 8.177166 89.916452 \n", + " 20 NaN NaN 3.075063 5.229354 61.706563 \n", + " 21 4647.376520 3247.164388 4.162684 5.033966 37.962542 \n", + " 22 NaN NaN 3.384827 4.921576 54.039707 \n", + " 23 NaN 7039.242263 3.187745 9.768689 74.453345 \n", + " 24 NaN NaN 3.045024 6.954771 63.069450 \n", + " 25 NaN NaN 4.164028 8.876290 71.070532 \n", + " 26 NaN NaN 2.719009 7.230035 57.364249 \n", + " 27 NaN NaN 3.257133 9.345917 68.652152 \n", + " 28 NaN -4190.690106 3.548902 9.494165 89.145578 \n", + " 29 NaN NaN 3.316924 6.697705 77.375684 \n", + " 
.. ... ... ... ... ... \n", + " 970 2712.006496 -14.780688 3.618468 8.148930 30.483065 \n", + " 971 NaN NaN 3.012630 6.326113 47.975596 \n", + " 972 NaN NaN 2.841184 10.269877 81.946843 \n", + " 973 1937.104665 -1511.898278 4.759412 7.778287 30.662510 \n", + " 974 5253.689129 329.011494 3.343722 7.398721 21.129117 \n", + " 975 NaN 4930.780630 3.506178 9.701808 110.523151 \n", + " 976 5293.404197 -3121.976568 2.800094 5.843644 32.442600 \n", + " 977 9805.271519 NaN 3.136993 7.623560 72.572211 \n", + " 978 NaN NaN 2.714601 7.234621 52.291714 \n", + " 979 NaN NaN 2.043747 8.744450 46.628272 \n", + " 980 5415.328424 -4809.140910 3.287617 8.354515 61.384200 \n", + " 981 NaN NaN 3.337297 6.142083 36.829644 \n", + " 982 837.730915 4072.518501 3.894026 12.494251 94.673851 \n", + " 983 NaN NaN 3.420209 7.090230 47.694700 \n", + " 984 5134.594286 357.279124 3.325391 4.539604 37.460031 \n", + " 985 NaN 4056.157461 3.685698 7.230368 43.096815 \n", + " 986 3953.583585 805.676045 3.044291 7.884751 34.438372 \n", + " 987 4511.412982 NaN 3.394697 6.834802 36.636557 \n", + " 988 NaN NaN 3.154329 8.363893 49.751330 \n", + " 989 NaN 2084.191209 3.615679 9.596049 75.533229 \n", + " 990 4457.772828 218.985398 3.440906 8.526339 61.864859 \n", + " 991 NaN NaN 3.247303 7.423321 59.019784 \n", + " 992 NaN NaN 3.458168 6.459322 48.584436 \n", + " 993 NaN NaN 2.608670 10.259658 82.758016 \n", + " 994 719.333117 6631.998410 4.368128 7.588005 54.580465 \n", + " 995 3003.501922 NaN 3.800082 NaN 100.143258 \n", + " 996 NaN -681.708699 4.140327 10.717719 46.213930 \n", + " 997 735.529994 3856.959812 3.921162 7.439993 59.473702 \n", + " 998 2557.407176 2865.135441 3.971143 9.540309 67.952407 \n", + " 999 671.086509 2328.798014 3.487042 5.988222 63.846379 \n", + " \n", + " MarkDown5 MarkDown2 CPI \n", + " 0 NaN NaN 172.666543 \n", + " 1 NaN NaN 119.214662 \n", + " 2 NaN NaN 143.990866 \n", + " 3 4354.651570 5019.468642 180.631779 \n", + " 4 -4772.570563 2947.674138 173.813790 \n", + " 5 NaN NaN 177.052581 \n", + " 6 7083.723804 3513.601560 185.921725 \n", + " 7 NaN -180.888893 153.966086 \n", + " 8 6265.833733 2401.076738 170.559516 \n", + " 9 3759.801559 4729.764492 200.039874 \n", + " 10 6567.512481 -1022.527044 150.002644 \n", + " 11 7765.142655 NaN 221.169299 \n", + " 12 NaN 3308.681495 NaN \n", + " 13 2826.442468 8385.400703 150.777231 \n", + " 14 NaN NaN 211.070762 \n", + " 15 6642.564583 NaN 154.828636 \n", + " 16 NaN NaN 138.236111 \n", + " 17 4118.446321 NaN 138.853628 \n", + " 18 1115.215425 -1526.587205 219.621502 \n", + " 19 4994.790024 NaN 197.086249 \n", + " 20 NaN NaN 143.384357 \n", + " 21 9934.092076 6538.282816 208.758478 \n", + " 22 396.251330 1170.864752 217.807128 \n", + " 23 1492.226506 NaN 144.261943 \n", + " 24 NaN NaN 198.046465 \n", + " 25 NaN NaN 152.795139 \n", + " 26 NaN NaN 192.354502 \n", + " 27 1804.646340 NaN 213.883682 \n", + " 28 NaN NaN 237.955017 \n", + " 29 NaN NaN 140.091500 \n", + " .. ... ... ... 
\n", + " 970 -208.184350 4224.751927 149.278618 \n", + " 971 NaN 3814.404768 167.707496 \n", + " 972 NaN NaN 115.688179 \n", + " 973 1199.031857 3033.592607 120.703360 \n", + " 974 NaN -351.113110 106.143265 \n", + " 975 NaN NaN 211.109816 \n", + " 976 3975.213749 NaN 152.402312 \n", + " 977 NaN NaN 121.333484 \n", + " 978 NaN 3607.788296 176.680147 \n", + " 979 NaN NaN 164.924902 \n", + " 980 5431.606087 7005.735955 208.516910 \n", + " 981 -285.659902 NaN 146.865904 \n", + " 982 5584.625275 2016.204322 165.913603 \n", + " 983 NaN NaN 136.591286 \n", + " 984 2929.718397 11730.460834 153.875137 \n", + " 985 4618.823214 1218.336997 191.245195 \n", + " 986 NaN 128.922827 209.229013 \n", + " 987 7757.467609 NaN 201.908861 \n", + " 988 NaN NaN 210.724201 \n", + " 989 5115.758017 1620.240724 149.542233 \n", + " 990 1789.565202 1350.171000 193.810848 \n", + " 991 NaN NaN 147.307937 \n", + " 992 -105.978005 NaN 229.688010 \n", + " 993 NaN NaN 217.839779 \n", + " 994 -907.055933 3253.948843 128.459172 \n", + " 995 NaN NaN NaN \n", + " 996 3043.948542 NaN 114.508859 \n", + " 997 1459.810600 -4003.468566 193.149902 \n", + " 998 12151.265277 -3099.180177 150.590590 \n", + " 999 2792.671569 NaN 201.355736 \n", + " \n", + " [1000 rows x 12 columns]}" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sdv.sample('features', 1000)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "SDV has tools to evaluate the synthetic data generated and compare them with the original data. To see more about the evaluation tools, [see the documentation](https://sdv-dev.github.io/SDV/api/sdv.evaluation.html#sdv.evaluation.evaluate)." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tutorials/04_Working_with_Metadata.ipynb b/tutorials/04_Working_with_Metadata.ipynb new file mode 100644 index 000000000..5c79f8ebf --- /dev/null +++ b/tutorials/04_Working_with_Metadata.ipynb @@ -0,0 +1,447 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Working with Metadata\n", + "\n", + "In order to work with complex dataset structures you will need to pass additional information\n", + "about you data to SDV.\n", + "\n", + "This is done by using a Metadata.\n", + "\n", + "Let's have a quick look at how to do it.\n", + "\n", + "## Load demo data\n", + "\n", + "We will load the demo dataset included in SDV for the rest of the session." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from sdv import load_demo\n", + "\n", + "metadata, tables = load_demo(metadata=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A part from the tables dict, this is returning a Metadata object that contains all the information\n", + "about our dataset." 
+ ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}
diff --git a/tutorials/04_Working_with_Metadata.ipynb b/tutorials/04_Working_with_Metadata.ipynb new file mode 100644 index 000000000..5c79f8ebf --- /dev/null +++ b/tutorials/04_Working_with_Metadata.ipynb @@ -0,0 +1,447 @@
+{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Working with Metadata\n", + "\n", + "In order to work with complex dataset structures you will need to pass additional information\n", + "about your data to SDV.\n", + "\n", + "This is done by using a Metadata object.\n", + "\n", + "Let's have a quick look at how to do it.\n", + "\n", + "## Load demo data\n", + "\n", + "We will load the demo dataset included in SDV for the rest of the session." + ] + },
+ { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from sdv import load_demo\n", + "\n", + "metadata, tables = load_demo(metadata=True)" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Apart from the tables dict, this returns a Metadata object that contains all the information\n", + "about our dataset." + ] + },
+ { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Metadata\n", + "  root_path: /home/xals/Projects/MIT/SDV/tutorials\n", + "  tables: ['users', 'sessions', 'transactions']\n", + "  relationships:\n", + "    sessions.user_id -> users.user_id\n", + "    transactions.session_id -> sessions.session_id" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metadata" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This Metadata can also be represented by using a dict object:" + ] + },
+ { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'tables': {'users': {'primary_key': 'user_id',\n", + "   'fields': {'user_id': {'type': 'id', 'subtype': 'integer'},\n", + "    'country': {'type': 'categorical'},\n", + "    'gender': {'type': 'categorical'},\n", + "    'age': {'type': 'numerical', 'subtype': 'integer'}}},\n", + "  'sessions': {'primary_key': 'session_id',\n", + "   'fields': {'session_id': {'type': 'id', 'subtype': 'integer'},\n", + "    'user_id': {'ref': {'field': 'user_id', 'table': 'users'},\n", + "     'type': 'id',\n", + "     'subtype': 'integer'},\n", + "    'device': {'type': 'categorical'},\n", + "    'os': {'type': 'categorical'}}},\n", + "  'transactions': {'primary_key': 'transaction_id',\n", + "   'fields': {'transaction_id': {'type': 'id', 'subtype': 'integer'},\n", + "    'session_id': {'ref': {'field': 'session_id', 'table': 'sessions'},\n", + "     'type': 'id',\n", + "     'subtype': 'integer'},\n", + "    'timestamp': {'type': 'datetime', 'format': '%Y-%m-%d'},\n", + "    'amount': {'type': 'numerical', 'subtype': 'float'},\n", + "    'approved': {'type': 'boolean'}}}}}" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metadata.to_dict()" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Creating a Metadata object from scratch\n", + "\n", + "In this section we will have a look at how to create a Metadata object from scratch.\n", + "\n", + "The simplest way to do it is by populating it, passing in the tables of your dataset together\n", + "with some additional information.\n", + "\n", + "Let's start by creating an empty metadata object." + ] + },
+ { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "from sdv import Metadata\n", + "\n", + "new_meta = Metadata()" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can start by adding the parent table from our dataset, `users`,\n", + "indicating that the primary key is the field called `user_id`." + ] + },
+ { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "users_data = tables['users']\n", + "new_meta.add_table('users', data=users_data, primary_key='user_id')" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, let's add the sessions table, indicating that:\n", + "- The primary key is the field `session_id`\n", + "- The `users` table is the parent of this table\n", + "- The relationship between the `users` and `sessions` tables is created by the field called `user_id`."
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "sessions_data = tables['sessions']\n", + "new_meta.add_table(\n", + "    'sessions',\n", + "    data=sessions_data,\n", + "    primary_key='session_id',\n", + "    parent='users',\n", + "    foreign_key='user_id'\n", + ")" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, let's add the transactions table.\n", + "\n", + "In this case, we will pass some additional information to indicate that\n", + "the `timestamp` field should actually be parsed and interpreted as a\n", + "datetime field." + ] + },
+ { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "transactions_fields = {\n", + "    'timestamp': {\n", + "        'type': 'datetime',\n", + "        'format': '%Y-%m-%d'\n", + "    }\n", + "}\n", + "transactions_data = tables['transactions']\n", + "new_meta.add_table(\n", + "    'transactions',\n", + "    transactions_data,\n", + "    fields_metadata=transactions_fields,\n", + "    primary_key='transaction_id',\n", + "    parent='sessions'\n", + ")" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's see what our Metadata looks like right now:" + ] + },
+ { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "[Graphviz SVG stripped during extraction: the diagram shows the users, sessions and transactions tables with the relationships sessions.user_id -> users.user_id and transactions.session_id -> sessions.session_id]\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "new_meta.visualize()" + ] + },
+ { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'tables': {'users': {'fields': {'gender': {'type': 'categorical'},\n", + "    'age': {'type': 'numerical', 'subtype': 'integer'},\n", + "    'user_id': {'type': 'id', 'subtype': 'integer'},\n", + "    'country': {'type': 'categorical'}},\n", + "   'primary_key': 'user_id'},\n", + "  'sessions': {'fields': {'os': {'type': 'categorical'},\n", + "    'device': {'type': 'categorical'},\n", + "    'user_id': {'type': 'id',\n", + "     'subtype': 'integer',\n", + "     'ref': {'table': 'users', 'field': 'user_id'}},\n", + "    'session_id': {'type': 'id', 'subtype': 'integer'}},\n", + "   'primary_key': 'session_id'},\n", + "  'transactions': {'fields': {'timestamp': {'type': 'datetime',\n", + "     'format': '%Y-%m-%d'},\n", + "    'amount': {'type': 'numerical', 'subtype': 'float'},\n", + "
'session_id': {'type': 'id',\n", + "     'subtype': 'integer',\n", + "     'ref': {'table': 'sessions', 'field': 'session_id'}},\n", + "    'approved': {'type': 'boolean'},\n", + "    'transaction_id': {'type': 'id', 'subtype': 'integer'}},\n", + "   'primary_key': 'transaction_id'}}}" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "new_meta.to_dict()" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Pretty similar to the original metadata, right?" + ] + },
+ { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "new_meta.to_dict() == metadata.to_dict()" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Saving the Metadata as a JSON file\n", + "\n", + "The Metadata object can also be saved as a JSON file, which we can load back later:" + ] + },
+ { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "new_meta.to_json('demo_metadata.json')" + ] + },
+ { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "loaded = Metadata('demo_metadata.json')" + ] + },
+ { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "loaded.to_dict() == new_meta.to_dict()" + ] + }
+ ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}