
Risklist module for production #631

Merged
merged 55 commits on Aug 27, 2021

Changes from 22 commits
a563694
listmaking WIP
thcrock Feb 20, 2019
9750c3e
forgot migration
thcrock Feb 20, 2019
360f8f9
WIP
tweddielin Feb 21, 2019
999a46f
alembic add label_value to list_predictions table
tweddielin Feb 26, 2019
372d9c8
add docstrings
tweddielin Feb 28, 2019
16645bc
move risklist a layer above
tweddielin Mar 13, 2019
914ad76
create risklist module
tweddielin Mar 13, 2019
c92bd8b
__init__.py
tweddielin Mar 13, 2019
d3c3ba9
fix alembic reversion and replace metta.generate_uuid with filename_f…
tweddielin Mar 13, 2019
0e92fb0
Fix down revision of production schema migration
thcrock Apr 12, 2019
dbd4578
Fix alembic revisions
thcrock Jan 6, 2021
f7d49e5
Enable github checks on this branch too
thcrock Jan 6, 2021
dee930f
Closer to getting tests to run
thcrock Jan 6, 2021
1769b00
Add CLI for risklist
thcrock Jan 8, 2021
52c9ff0
Risklist docs stub
thcrock Jan 8, 2021
173167a
Break up data gathering into experiment and matrix, use pytest fixtur…
thcrock Jan 9, 2021
f6b2d02
Modify schema for list prediction metadata
thcrock Jan 9, 2021
acffa67
fix conflicts and add helper functions for getting imputed features
tweddielin Jan 9, 2021
43c1919
Handle other imputation flag cases, fix tracking indentation error
thcrock Jan 10, 2021
7dfb7e1
Add more tests, fill out doc page
thcrock Jan 11, 2021
cc9fe4a
Fix exception name typo
thcrock Jan 11, 2021
5951565
use timechop and planner to create matrix_metadata for production
tweddielin Jan 15, 2021
537f6c8
retrain and predict forward
tweddielin Apr 15, 2021
b429540
rename to retrain_definition
tweddielin Apr 15, 2021
0045aa5
reusing random seeds from existing models
shaycrk May 8, 2021
9dc3697
fix tests (write experiment to test db)
shaycrk May 10, 2021
da870d5
unit test for reusing model random seeds
shaycrk May 10, 2021
6768ee5
add docstring
shaycrk May 10, 2021
7d6a420
only store random seed in experiment runs
shaycrk May 20, 2021
b8fe6d8
DB migration to remove random seed from experiments table
shaycrk May 20, 2021
8207fcd
debugging
shaycrk May 20, 2021
45c9d68
debug model trainer tests
shaycrk May 21, 2021
a665e7e
debug catwalk utils tests
shaycrk May 21, 2021
ead882b
debug catwalk integration test
shaycrk May 21, 2021
de85f10
use public method
tweddielin May 30, 2021
ad860cd
Merge remote-tracking branch 'origin/kit_rand_seed' into list_making
tweddielin May 31, 2021
40466d5
alembic merge
tweddielin May 31, 2021
83c7385
reuse random seed
tweddielin May 31, 2021
f97089b
use timechop for getting retrain information
tweddielin Jun 30, 2021
6f0af1c
create retrain model hash in retrain level instead of model_trainer l…
tweddielin Jun 30, 2021
42bccaa
move util functions to utils
tweddielin Jun 30, 2021
3ec377f
fix cli and docs
tweddielin Jul 1, 2021
1c4da24
update docs
tweddielin Jul 1, 2021
35bd978
use reconstructed feature dict
tweddielin Jul 1, 2021
9f5a099
add RetrainModel and Retrain
tweddielin Jul 29, 2021
ba84822
remove break point
tweddielin Jul 29, 2021
83e0f66
change experiment_runs to triage_runs
tweddielin Aug 21, 2021
d6f14f5
get retrain_config
tweddielin Aug 22, 2021
d76359b
explicitly include run_type in joins to triage_runs
shaycrk Aug 26, 2021
9698500
DB migration updates
shaycrk Aug 26, 2021
a8a29f1
update argument name in docs
shaycrk Aug 26, 2021
694edcc
ensure correct temporal config is used for predicting forward
shaycrk Aug 27, 2021
583e9bd
debug
shaycrk Aug 27, 2021
815a258
debug
shaycrk Aug 27, 2021
5e183fe
Merge branch 'master' into list_making
shaycrk Aug 27, 2021
45 changes: 45 additions & 0 deletions .github/workflows/test.yaml
@@ -0,0 +1,45 @@
name: Python package

on: [push]

jobs:
  build:

    runs-on: ubuntu-latest
    services:
      # Label used to access the service container
      postgres:
        # Docker Hub image
        image: postgres
        # Provide the password for postgres
        env:
          POSTGRES_PASSWORD: postgres
        # Set health checks to wait until postgres has started
        options: >-
          --health-cmd pg_isready
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
    strategy:
      matrix:
        python-version: [3.6, 3.7]

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install system dependencies
      run: |
        sudo apt-get update
        sudo apt-get install libblas-dev liblapack-dev libatlas-base-dev gfortran
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install flake8 pytest
        pip install -r requirement/include/build.txt
        pip install -r requirement/include/test-management.txt
    - name: Test with tox
      run: |
        tox
1 change: 1 addition & 0 deletions docs/mkdocs.yml
@@ -120,6 +120,7 @@ nav:
- Using Postmodeling: postmodeling/index.md
- Postmodeling & Crosstabs Configuration: postmodeling/postmodeling-config.md
- Model governance: dirtyduck/ml_governance.md
- Risklist: risklist/index.md
Contributor: I'd update the naming from Risklist to Predictlist here too. The name and probably also the filename, since the filename will show up in the documentation URL.

- Scaling up: dirtyduck/aws_batch.md
- API Reference:
- Audition:
42 changes: 42 additions & 0 deletions docs/sources/risklist/index.md
@@ -0,0 +1,42 @@
# Risklist

If you would like to generate a list of predictions from an already-trained Triage model on new data, you can use the 'Risklist' module.

## Examples
Both examples assume you have already run a Triage Experiment in the past, and know these two pieces of information:
1. A `model_id` from a Triage model that you want to use to generate predictions
2. An `as_of_date` to generate your predictions on.

### CLI
`triage risklist <model_id> <as_of_date>`

Example:
`triage risklist 46 2019-05-06`

The risklist command assumes the current working directory to be the 'project path', where it finds models and writes matrices, but this can be overridden with the `--project-path` option.
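For example (the project path here is just an illustration):

`triage risklist 46 2019-05-06 --project-path /path/to/triage/project`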

### Python

The `generate_risk_list` function from the `triage.risklist` module can be used similarly to the CLI, with the addition of the database engine and project storage as inputs.
```
from triage.risklist import generate_risk_list
from triage.component.catwalk.storage import ProjectStorage
from triage import create_engine

generate_risk_list(
    db_engine=create_engine(<your-db-info>),
    project_storage=ProjectStorage('/home/you/triage/project2'),
    model_id=46,
    as_of_date='2019-05-06'
)
```

Contributor: Seems like this should be `from triage.risklist import generate_risk_list`?

Contributor: Yep

## Output
The Risklist is stored similarly to the matrices created during an Experiment:
- Raw Matrix saved to the matrices directory in project storage
- Predictions saved in a table (production.list_predictions)
- Prediction metadata (tiebreaking, random seed) saved in a table (production.prediction_metadata)
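A minimal sketch of checking the stored predictions from Python (the selected column names are assumptions; check the production schema for the exact layout):

```
from triage import create_engine

engine = create_engine('<your-db-info>')
rows = engine.execute(
    '''
    select model_id, entity_id, as_of_date, score
    from production.list_predictions
    order by score desc
    limit 10
    '''
)
for row in rows:
    # each row is one entity's score for the requested model and date
    print(dict(row))
```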

## Notes
- The cohort and features for the Risklist are all inferred from the Experiment that trained the given model_id (as defined by the experiment_models table).
- The feature list ensures that imputation flag columns are present for any columns that either needed to be imputed in the training process, or that needed to be imputed in the risklist dataset.
52 changes: 52 additions & 0 deletions src/tests/collate_tests/test_collate.py
@@ -4,6 +4,7 @@
Unit tests for `collate` module.

"""
import pytest
from triage.component.collate import Aggregate, Aggregation, Categorical

def test_aggregate():
@@ -191,3 +192,54 @@ def test_distinct():
            ),
        )
    ) == ["count(distinct (x,y)) FILTER (WHERE date < '2012-01-01')"]


def test_Aggregation_colname_aggregate_lookup():
    n = Aggregate("x", "sum", {})
    d = Aggregate("1", "count", {})
    m = Aggregate("y", "avg", {})
    aggregation = Aggregation(
        [n, d, m],
        groups=['entity_id'],
        from_obj="source",
        prefix="mysource",
        state_table="tbl"
    )
    assert aggregation.colname_aggregate_lookup == {
        'mysource_entity_id_x_sum': 'sum',
        'mysource_entity_id_1_count': 'count',
        'mysource_entity_id_y_avg': 'avg'
    }

def test_Aggregation_colname_agg_function():
    n = Aggregate("x", "sum", {})
    d = Aggregate("1", "count", {})
    m = Aggregate("y", "stddev_samp", {})
    aggregation = Aggregation(
        [n, d, m],
        groups=['entity_id'],
        from_obj="source",
        prefix="mysource",
        state_table="tbl"
    )

    assert aggregation.colname_agg_function('mysource_entity_id_x_sum') == 'sum'
    assert aggregation.colname_agg_function('mysource_entity_id_y_stddev_samp') == 'stddev_samp'


def test_Aggregation_imputation_flag_base():
    n = Aggregate("x", ["sum", "count"], {})
    m = Aggregate("y", "stddev_samp", {})
    aggregation = Aggregation(
        [n, m],
        groups=['entity_id'],
        from_obj="source",
        prefix="mysource",
        state_table="tbl"
    )

    assert aggregation.imputation_flag_base('mysource_entity_id_x_sum') == 'mysource_entity_id_x'
    assert aggregation.imputation_flag_base('mysource_entity_id_x_count') == 'mysource_entity_id_x'
    assert aggregation.imputation_flag_base('mysource_entity_id_y_stddev_samp') == 'mysource_entity_id_y_stddev_samp'
    with pytest.raises(KeyError):
        aggregation.imputation_flag_base('mysource_entity_id_x_stddev_samp')
11 changes: 11 additions & 0 deletions src/tests/test_cli.py
@@ -2,6 +2,7 @@
import triage.cli as cli
from unittest.mock import Mock, patch
import os
import datetime


# we do not need a real database URL, just one SQLAlchemy thinks looks like a real one
@@ -56,3 +57,13 @@ def test_featuretest():
    try_command('featuretest', 'example/config/experiment.yaml', '2017-06-06')
    featuremock.assert_called_once()
    cohortmock.assert_called_once()


def test_cli_risklist():
    with patch('triage.cli.generate_risk_list', autospec=True) as mock:
        try_command('risklist', '40', '2019-06-04')
        mock.assert_called_once()
        assert mock.call_args[0][0].url
        assert mock.call_args[0][1].project_path
        assert mock.call_args[0][2] == 40
        assert mock.call_args[0][3] == datetime.datetime(2019, 6, 4)
70 changes: 70 additions & 0 deletions src/tests/test_risklist.py
@@ -0,0 +1,70 @@
from triage.risklist import generate_risk_list, train_matrix_info_from_model_id, experiment_config_from_model_id
from triage.validation_primitives import table_should_have_data


def test_risklist_should_write_predictions(finished_experiment):
    # given a model id and as-of-date <= today
    # and the model id is trained and is linked to an experiment with feature and cohort config
    # generate records in list_predictions
    # the # of records should equal the size of the cohort for that date
    model_id = 1
    as_of_date = '2014-01-01'
    generate_risk_list(
        db_engine=finished_experiment.db_engine,
        project_storage=finished_experiment.project_storage,
        model_id=model_id,
        as_of_date=as_of_date)
    table_should_have_data(
        db_engine=finished_experiment.db_engine,
        table_name="production.list_predictions",
    )

Contributor: This simple assertion was a good start but we should go further. Does it make sense to make assertions about the size of the table? How about the contents? We'd expect all of these rows to have the same date/model id and stuff like that, right?

Contributor: Also, if I recall correctly the model_metadata.matrices table will also get a row since we used the MatrixBuilder, we should make sure that row looks reasonable.
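A sketch of the extra content assertions the reviewer asks for (the `model_id` and `as_of_date` columns in production.list_predictions are assumptions here, not confirmed by this diff):

```
def test_risklist_predictions_have_expected_contents(finished_experiment):
    # Hypothetical follow-up test: every prediction row should carry the
    # model_id and as_of_date that generate_risk_list was called with.
    model_id = 1
    as_of_date = '2014-01-01'
    generate_risk_list(
        db_engine=finished_experiment.db_engine,
        project_storage=finished_experiment.project_storage,
        model_id=model_id,
        as_of_date=as_of_date)

    rows = list(finished_experiment.db_engine.execute(
        'select distinct model_id, as_of_date from production.list_predictions'
    ))
    # exactly one (model_id, as_of_date) pair should exist, and it should match
    assert len(rows) == 1
    assert rows[0][0] == model_id
    assert str(rows[0][1]).startswith(as_of_date)
```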


def test_risklist_should_be_same_shape_as_cohort(finished_experiment):
    model_id = 1
    as_of_date = '2014-01-01'
    generate_risk_list(
        db_engine=finished_experiment.db_engine,
        project_storage=finished_experiment.project_storage,
        model_id=model_id,
        as_of_date=as_of_date)

    num_records_matching_cohort = finished_experiment.db_engine.execute(
        f'''select count(*)
        from production.list_predictions
        join production.cohort_{finished_experiment.config['cohort_config']['name']} using (entity_id, as_of_date)
        '''
    ).first()[0]

    num_records = finished_experiment.db_engine.execute(
        'select count(*) from production.list_predictions'
    ).first()[0]
    assert num_records_matching_cohort == num_records


def test_risklist_matrix_record_is_populated(finished_experiment):
    model_id = 1
    as_of_date = '2014-01-01'
    generate_risk_list(
        db_engine=finished_experiment.db_engine,
        project_storage=finished_experiment.project_storage,
        model_id=model_id,
        as_of_date=as_of_date)

    matrix_records = list(finished_experiment.db_engine.execute(
        "select * from triage_metadata.matrices where matrix_type = 'production'"
    ))
    assert len(matrix_records) == 1


def test_experiment_config_from_model_id(finished_experiment):
    model_id = 1
    experiment_config = experiment_config_from_model_id(finished_experiment.db_engine, model_id)
    assert experiment_config == finished_experiment.config


def test_train_matrix_info_from_model_id(finished_experiment):
    model_id = 1
    (train_matrix_uuid, matrix_metadata) = train_matrix_info_from_model_id(finished_experiment.db_engine, model_id)
    assert train_matrix_uuid
    assert matrix_metadata
32 changes: 32 additions & 0 deletions src/triage/cli.py
@@ -20,6 +20,7 @@
    MultiCoreExperiment,
    SingleThreadedExperiment,
)
from triage.risklist import generate_risk_list
from triage.component.postmodeling.crosstabs import CrosstabsConfigLoader, run_crosstabs
from triage.util.db import create_engine
@@ -399,6 +400,37 @@ def __call__(self, args):
        run_crosstabs(db_engine, config)


@Triage.register
class Risklist(Command):
    """Generate a list of risk scores from an already-trained model and new data"""

    def __init__(self, parser):
        parser.add_argument(
            "model_id",
            type=natural_number,
            help="The model_id of an existing trained model in the models table",
        )
        parser.add_argument(
            "as_of_date",
            type=valid_date,
            help="The date as of which to run features. Format YYYY-MM-DD",
        )
        parser.add_argument(
            "--project-path",
            default=os.getcwd(),
            help="path to store matrices and trained models",
        )

    def __call__(self, args):
        db_engine = create_engine(self.root.db_url)

        generate_risk_list(
            db_engine,
            ProjectStorage(args.project_path),
            args.model_id,
            args.as_of_date
        )

Contributor: Should this be able to handle timestamps rather than just dates (I think the only place we currently assume day-level resolution at the moment is timechop)?

Contributor: Makes sense. It looks like some of the same util code that is pinning timechop to day-level resolution (triage.util.conf.dt_from_str) is also used in risklist. So we could make dt_from_str attempt to read ISO datetime strings as well, which would help in both places.

And replacing the valid_date helper defined here in the CLI module with the improved dt_from_str would help too, since the two functions really do the same thing.

Contributor: That makes sense to me -- is it worth doing that here or as a separate PR?
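A minimal sketch of the broadened parser the reviewers describe (standard library only; the exact set of accepted formats is an illustration of the suggestion, not what this PR implements):

```
from datetime import datetime

def dt_from_str(dt_str):
    """Parse a day-level date or an ISO datetime string.

    Sketch of the reviewers' suggestion for triage.util.conf.dt_from_str;
    the shipped implementation may differ.
    """
    if isinstance(dt_str, datetime):
        return dt_str
    for fmt in ("%Y-%m-%d", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%d %H:%M:%S"):
        try:
            return datetime.strptime(dt_str, fmt)
        except ValueError:
            continue
    raise ValueError(f"Could not parse a date or datetime from {dt_str!r}")
```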

@Triage.register
class Db(Command):
    """Manage experiment database"""
54 changes: 31 additions & 23 deletions src/triage/component/architect/builders.py
@@ -1,5 +1,4 @@
import io
import json

import verboselogs, logging
logger = verboselogs.VerboseLogger(__name__)
@@ -32,6 +31,7 @@ def __init__(
        self.replace = replace
        self.include_missing_labels_in_train_as = include_missing_labels_in_train_as
        self.run_id = run_id
        self.includes_labels = 'labels_table_name' in self.db_config

    @property
    def sessionmaker(self):
@@ -131,7 +131,7 @@ def make_entity_date_table(
        """

        as_of_time_strings = [str(as_of_time) for as_of_time in as_of_times]
        if matrix_type == "test" or matrix_type == "production" or self.include_missing_labels_in_train_as is not None:
            indices_query = self._all_valid_entity_dates_query(
                as_of_time_strings=as_of_time_strings, state=state
            )
@@ -232,14 +232,15 @@ def build_matrix(
            if self.run_id:
                errored_matrix(self.run_id, self.db_engine)
            return
        if not table_has_data(
            f"{self.db_config['labels_schema_name']}.{self.db_config['labels_table_name']}",
            self.db_engine,
        ):
            logger.warning("labels table is not populated, cannot build matrix")
            if self.run_id:
                errored_matrix(self.run_id, self.db_engine)
            return

        if self.includes_labels:
            if not table_has_data(
                f"{self.db_config['labels_schema_name']}.{self.db_config['labels_table_name']}",
                self.db_engine,
            ):
                logger.warning("labels table is not populated, cannot build matrix")
                if self.run_id:
                    errored_matrix(self.run_id, self.db_engine)
                return

        matrix_store = self.matrix_storage_engine.get_store(matrix_uuid)
        if not self.replace and matrix_store.exists:
@@ -261,7 +262,7 @@
                matrix_metadata["state"],
                matrix_type,
                matrix_uuid,
                matrix_metadata["label_timespan"],
                matrix_metadata.get("label_timespan", None),
            )
        except ValueError as e:
            logger.exception(
@@ -277,19 +278,26 @@
            as_of_times, feature_dictionary, entity_date_table_name, matrix_uuid
        )
        logger.debug(f"Feature data extracted for matrix {matrix_uuid}")
        logger.spam(
            "Extracting label data from database into file for matrix {matrix_uuid}",
        )
        labels_df = self.load_labels_data(
            label_name,
            label_type,
            entity_date_table_name,
            matrix_uuid,
            matrix_metadata["label_timespan"],
        )
        dataframes.insert(0, labels_df)

        logger.debug(f"Label data extracted for matrix {matrix_uuid}")
        # dataframes add label_name

        if self.includes_labels:
            logger.spam(
                f"Extracting label data from database into file for matrix {matrix_uuid}",
            )
            labels_df = self.load_labels_data(
                label_name,
                label_type,
                entity_date_table_name,
                matrix_uuid,
                matrix_metadata["label_timespan"],
            )
            dataframes.insert(0, labels_df)
            logger.debug(f"Label data extracted for matrix {matrix_uuid}")
        else:
            labels_df = pd.DataFrame(index=dataframes[0].index, columns=[label_name])
            dataframes.insert(0, labels_df)

        # stitch together the csvs
        logger.spam(f"Merging feature files for matrix {matrix_uuid}")
        output = self.merge_feature_csvs(dataframes, matrix_uuid)