-
Notifications
You must be signed in to change notification settings - Fork 61
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Risklist module for production (#631)
* listmaking WIP * forgot migraton * WIP * alembic add label_value to list_predictions table * add docstrings * move risklist a layer above * create risklist module * __init__lpy * fix alembic reversion and replace metta.generate_uuid with filename_friendly_hash * Fix down revision of production schema migration * Enable github checks on this branch too * Closer to getting tests to run * Add CLI for risklist * Risklist docs stub * Break up data gathering into experiment and matrix, use pytest fixtures to speed up subsequent tests * Modify schema for list prediction metadata * fix conflicts and add helper functions for getting imputed features * Handle other imputation flag cases, fix tracking indentation error * Add more tests, fill out doc page * Fix exception name typo * use timechop and planner to create matrix_metadata for production * retrain and predict forward * rename to retrain_definition * reusing random seeds from existing models * fix tests (write experiment to test db) * unit test for reusing model random seeds * add docstring * only store random seed in experiment runs * DB migration to remove random seed from experiments table * debugging * debug model trainer tests * debug catwalk utils tests * debug catwalk integration test * use public method * alembic merge * reuse random seed * use timechop for getting retrain information * create retrain model hash in retrain level instead of model_trainer level * move util functions to utils * fix cli and docs * update docs * use reconstructed feature dict * add RetrainModel and Retrain * remove break point * change experiment_runs to triage_runs * get retrain_config * explicitly include run_type in joins to triage_runs * DB migration updates * update argument name in docs * ensure correct temporal config is used for predicting forward * debug * debug Co-authored-by: Tristan Crockett <tristan.h.crockett@gmail.com> Co-authored-by: Kit Rodolfa <shaycrk@gmail.com>
- Loading branch information
1 parent
a994f3e
commit 537813a
Showing
31 changed files
with
1,682 additions
and
143 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# Retrain and Predict | ||
Use an existing model group to retrain a new model on all the data up to the current date and then predict forward into the future. | ||
|
||
## Examples | ||
Both examples assume you have already run a Triage Experiment in the past, and know these two pieces of information: | ||
1. A `model_group_id` from a Triage model group that you want to use to retrain a model and generate prediction | ||
2. A `prediction_date` to generate your predictions on. | ||
|
||
### CLI | ||
`triage retrainpredict <model_group_id> <prediction_date>` | ||
|
||
Example: | ||
`triage retrainpredict 30 2021-04-04` | ||
|
||
The `retrainpredict` will assume the current path to be the 'project path' to train models and write matrices, but this can be overridden by sending the `--project-path` option | ||
|
||
### Python | ||
The `Retrainer` class from `triage.predictlist` module can be used to retrain a model and predict forward. | ||
|
||
```python | ||
from triage.predictlist import Retrainer | ||
from triage import create_engine | ||
|
||
retrainer = Retrainer( | ||
db_engine=create_engine(<your-db-info>), | ||
project_path='/home/you/triage/project2' | ||
model_group_id=36, | ||
) | ||
retrainer.retrain(prediction_date='2021-04-04') | ||
retrainer.predict(prediction_date='2021-04-04') | ||
|
||
``` | ||
|
||
## Output | ||
The retrained model is sotred similariy to the matrices created during an Experiment: | ||
- Raw Matrix saved to the matrices directory in project storage | ||
- Raw Model saved to the trained_model directory in project storage | ||
- Retrained Model info saved in a table (triage_metadata.models) where model_comment = 'retrain_2021-04-04 21:19:09.975112' | ||
- Predictions saved in a table (triage_production.predictions) | ||
- Prediction metadata (tiebreaking, random seed) saved in a table (triage_produciton.prediction_metadata) | ||
|
||
|
||
# Predictlist | ||
If you would like to generate a list of predictions on already-trained Triage model with new data, you can use the 'Predictlist' module. | ||
|
||
# Predict Foward with Existed Model | ||
Use an existing model object to generate predictions on new data. | ||
|
||
## Examples | ||
Both examples assume you have already run a Triage Experiment in the past, and know these two pieces of information: | ||
1. A `model_id` from a Triage model that you want to use to generate predictions | ||
2. An `as_of_date` to generate your predictions on. | ||
|
||
### CLI | ||
`triage predictlist <model_id> <as_of_date>` | ||
|
||
Example: | ||
`triage predictlist 46 2019-05-06` | ||
|
||
The predictlist will assume the current path to be the 'project path' to find models and write matrices, but this can be overridden by sending the `--project-path` option. | ||
|
||
### Python | ||
|
||
The `predict_forward_with_existed_model` function from the `triage.predictlist` module can be used similarly to the CLI, with the addition of the database engine and project storage as inputs. | ||
``` | ||
from triage.predictlist import generate predict_forward_with_existed_model | ||
from triage import create_engine | ||
predict_forward_with_existed_model( | ||
db_engine=create_engine(<your-db-info>), | ||
project_path='/home/you/triage/project2' | ||
model_id=46, | ||
as_of_date='2019-05-06' | ||
) | ||
``` | ||
|
||
## Output | ||
The Predictlist is stored similarly to the matrices created during an Experiment: | ||
- Raw Matrix saved to the matrices directory in project storage | ||
- Predictions saved in a table (triage_production.predictions) | ||
- Prediction metadata (tiebreaking, random seed) saved in a table (triage_production.prediction_metadata) | ||
|
||
## Notes | ||
- The cohort and features for the Predictlist are all inferred from the Experiment that trained the given model_id (as defined by the experiment_models table). | ||
- The feature list ensures that imputation flag columns are present for any columns that either needed to be imputed in the training process, or that needed to be imputed in the predictlist dataset. | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
from triage.predictlist import Retrainer, predict_forward_with_existed_model, train_matrix_info_from_model_id, experiment_config_from_model_id | ||
from triage.validation_primitives import table_should_have_data | ||
|
||
|
||
def test_predict_forward_with_existed_model_should_write_predictions(finished_experiment): | ||
# given a model id and as-of-date <= today | ||
# and the model id is trained and is linked to an experiment with feature and cohort config | ||
# generate records in triage_production.predictions | ||
# the # of records should equal the size of the cohort for that date | ||
model_id = 1 | ||
as_of_date = '2014-01-01' | ||
predict_forward_with_existed_model( | ||
db_engine=finished_experiment.db_engine, | ||
project_path=finished_experiment.project_storage.project_path, | ||
model_id=model_id, | ||
as_of_date=as_of_date | ||
) | ||
table_should_have_data( | ||
db_engine=finished_experiment.db_engine, | ||
table_name="triage_production.predictions", | ||
) | ||
|
||
|
||
def test_predict_forward_with_existed_model_should_be_same_shape_as_cohort(finished_experiment): | ||
model_id = 1 | ||
as_of_date = '2014-01-01' | ||
predict_forward_with_existed_model( | ||
db_engine=finished_experiment.db_engine, | ||
project_path=finished_experiment.project_storage.project_path, | ||
model_id=model_id, | ||
as_of_date=as_of_date) | ||
|
||
num_records_matching_cohort = finished_experiment.db_engine.execute( | ||
f'''select count(*) | ||
from triage_production.predictions | ||
join triage_production.cohort_{finished_experiment.config['cohort_config']['name']} using (entity_id, as_of_date) | ||
''' | ||
).first()[0] | ||
|
||
num_records = finished_experiment.db_engine.execute( | ||
'select count(*) from triage_production.predictions' | ||
).first()[0] | ||
assert num_records_matching_cohort == num_records | ||
|
||
|
||
def test_predict_forward_with_existed_model_matrix_record_is_populated(finished_experiment): | ||
model_id = 1 | ||
as_of_date = '2014-01-01' | ||
predict_forward_with_existed_model( | ||
db_engine=finished_experiment.db_engine, | ||
project_path=finished_experiment.project_storage.project_path, | ||
model_id=model_id, | ||
as_of_date=as_of_date) | ||
|
||
matrix_records = list(finished_experiment.db_engine.execute( | ||
"select * from triage_metadata.matrices where matrix_type = 'production'" | ||
)) | ||
assert len(matrix_records) == 1 | ||
|
||
|
||
def test_experiment_config_from_model_id(finished_experiment): | ||
model_id = 1 | ||
experiment_config = experiment_config_from_model_id(finished_experiment.db_engine, model_id) | ||
assert experiment_config == finished_experiment.config | ||
|
||
|
||
def test_train_matrix_info_from_model_id(finished_experiment): | ||
model_id = 1 | ||
(train_matrix_uuid, matrix_metadata) = train_matrix_info_from_model_id(finished_experiment.db_engine, model_id) | ||
assert train_matrix_uuid | ||
assert matrix_metadata | ||
|
||
|
||
def test_retrain_should_write_model(finished_experiment): | ||
# given a model id and prediction_date | ||
# and the model id is trained and is linked to an experiment with feature and cohort config | ||
# create matrix for retraining a model | ||
# generate records in production models | ||
# retrain_model_hash should be the same with model_hash in triage_metadata.models | ||
model_group_id = 1 | ||
prediction_date = '2014-03-01' | ||
|
||
retrainer = Retrainer( | ||
db_engine=finished_experiment.db_engine, | ||
project_path=finished_experiment.project_storage.project_path, | ||
model_group_id=model_group_id, | ||
) | ||
retrain_info = retrainer.retrain(prediction_date) | ||
model_comment = retrain_info['retrain_model_comment'] | ||
|
||
records = [ | ||
row | ||
for row in finished_experiment.db_engine.execute( | ||
f"select model_hash from triage_metadata.models where model_comment = '{model_comment}'" | ||
) | ||
] | ||
assert len(records) == 1 | ||
assert retrainer.retrain_model_hash == records[0][0] | ||
|
||
retrainer.predict(prediction_date) | ||
|
||
table_should_have_data( | ||
db_engine=finished_experiment.db_engine, | ||
table_name="triage_production.predictions", | ||
) | ||
|
||
matrix_records = list(finished_experiment.db_engine.execute( | ||
f"select * from triage_metadata.matrices where matrix_uuid = '{retrainer.predict_matrix_uuid}'" | ||
)) | ||
assert len(matrix_records) == 1 |
Oops, something went wrong.