-
Notifications
You must be signed in to change notification settings - Fork 61
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Risklist module for production #631
Changes from 22 commits
a563694
9750c3e
360f8f9
999a46f
372d9c8
16645bc
914ad76
c92bd8b
d3c3ba9
0e92fb0
dbd4578
f7d49e5
dee930f
1769b00
52c9ff0
173167a
f6b2d02
acffa67
43c1919
7dfb7e1
cc9fe4a
5951565
537f6c8
b429540
0045aa5
9dc3697
da870d5
6768ee5
7d6a420
b8fe6d8
8207fcd
45c9d68
a665e7e
ead882b
de85f10
ad860cd
40466d5
83c7385
f97089b
6f0af1c
42bccaa
3ec377f
1c4da24
35bd978
9f5a099
ba84822
83e0f66
d6f14f5
d76359b
9698500
a8a29f1
694edcc
583e9bd
815a258
5e183fe
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
name: Python package | ||
|
||
on: [push] | ||
|
||
jobs: | ||
build: | ||
|
||
runs-on: ubuntu-latest | ||
services: | ||
# Label used to access the service container | ||
postgres: | ||
# Docker Hub image | ||
image: postgres | ||
# Provide the password for postgres | ||
env: | ||
POSTGRES_PASSWORD: postgres | ||
# Set health checks to wait until postgres has started | ||
options: >- | ||
--health-cmd pg_isready | ||
--health-interval 10s | ||
--health-timeout 5s | ||
--health-retries 5 | ||
strategy: | ||
matrix: | ||
python-version: [3.6, 3.7] | ||
|
||
steps: | ||
- uses: actions/checkout@v2 | ||
- name: Set up Python ${{ matrix.python-version }} | ||
uses: actions/setup-python@v2 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
- name: Install system dependencies | ||
run: | | ||
sudo apt-get update | ||
sudo apt-get install libblas-dev liblapack-dev libatlas-base-dev gfortran | ||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install flake8 pytest | ||
pip install -r requirement/include/build.txt | ||
pip install -r requirement/include/test-management.txt | ||
- name: Test with tox | ||
run: | | ||
tox |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Risklist | ||
|
||
If you would like to generate a list of predictions on already-trained Triage model with new data, you can use the 'Risklist' module. | ||
|
||
## Examples | ||
Both examples assume you have already run a Triage Experiment in the past, and know these two pieces of information: | ||
1. A `model_id` from a Triage model that you want to use to generate predictions | ||
2. An `as_of_date` to generate your predictions on. | ||
|
||
### CLI | ||
`triage risklist <model_id> <as_of_date>` | ||
|
||
Example: | ||
`triage risklist 46 2019-05-06` | ||
|
||
The risklist will assume the current path to be the 'project path' to find models and write matrices, but this can be overridden by sending the `--project-path` option. | ||
|
||
### Python | ||
|
||
The `generate_risk_list` function from the `triage.risklist` module can be used similarly to the CLI, with the addition of the database engine and project storage as inputs. | ||
``` | ||
from triage.risklist generate generate_risk_list | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Seems like this should be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep |
||
from triage.catwalk.component.storage import ProjectStorage | ||
from triage import create_engine | ||
|
||
generate_risk_list( | ||
db_engine=create_engine(<your-db-info>), | ||
project_storage=ProjectStorage('/home/you/triage/project2') | ||
model_id=46, | ||
as_of_date='2019-05-06' | ||
) | ||
``` | ||
|
||
## Output | ||
The Risklist is stored similarly to the matrices created during an Experiment: | ||
- Raw Matrix saved to the matrices directory in project storage | ||
- Predictions saved in a table (production.list_predictions) | ||
- Prediction metadata (tiebreaking, random seed) saved in a table (production.prediction_metadata) | ||
|
||
## Notes | ||
- The cohort and features for the Risklist are all inferred from the Experiment that trained the given model_id (as defined by the experiment_models table). | ||
- The feature list ensures that imputation flag columns are present for any columns that either needed to be imputed in the training process, or that needed to be imputed in the risklist dataset. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
from triage.risklist import generate_risk_list, train_matrix_info_from_model_id, experiment_config_from_model_id | ||
from triage.validation_primitives import table_should_have_data | ||
|
||
|
||
def test_risklist_should_write_predictions(finished_experiment): | ||
# given a model id and as-of-date <= today | ||
# and the model id is trained and is linked to an experiment with feature and cohort config | ||
# generate records in listpredictions | ||
# the # of records should equal the size of the cohort for that date | ||
model_id = 1 | ||
as_of_date = '2014-01-01' | ||
generate_risk_list( | ||
db_engine=finished_experiment.db_engine, | ||
project_storage=finished_experiment.project_storage, | ||
model_id=model_id, | ||
as_of_date=as_of_date) | ||
table_should_have_data( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This simple assertion was a good start but we should go further. Does it make sense to make assertions about the size of the table? How about the contents? We'd expect all of these rows to have the same date/model id and stuff like that, right? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, if I recall correctly the model_metadata.matrices table will also get a row since we used the MatrixBuilder, we should make sure that row looks reasonable. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
db_engine=finished_experiment.db_engine, | ||
table_name="production.list_predictions", | ||
) | ||
|
||
|
||
def test_risklist_should_be_same_shape_as_cohort(finished_experiment): | ||
model_id = 1 | ||
as_of_date = '2014-01-01' | ||
generate_risk_list( | ||
db_engine=finished_experiment.db_engine, | ||
project_storage=finished_experiment.project_storage, | ||
model_id=model_id, | ||
as_of_date=as_of_date) | ||
|
||
num_records_matching_cohort = finished_experiment.db_engine.execute( | ||
f'''select count(*) | ||
from production.list_predictions | ||
join production.cohort_{finished_experiment.config['cohort_config']['name']} using (entity_id, as_of_date) | ||
''' | ||
).first()[0] | ||
|
||
num_records = finished_experiment.db_engine.execute( | ||
'select count(*) from production.list_predictions' | ||
).first()[0] | ||
assert num_records_matching_cohort == num_records | ||
|
||
|
||
def test_risklist_matrix_record_is_populated(finished_experiment): | ||
model_id = 1 | ||
as_of_date = '2014-01-01' | ||
generate_risk_list( | ||
db_engine=finished_experiment.db_engine, | ||
project_storage=finished_experiment.project_storage, | ||
model_id=model_id, | ||
as_of_date=as_of_date) | ||
|
||
matrix_records = list(finished_experiment.db_engine.execute( | ||
"select * from triage_metadata.matrices where matrix_type = 'production'" | ||
)) | ||
assert len(matrix_records) == 1 | ||
|
||
|
||
def test_experiment_config_from_model_id(finished_experiment): | ||
model_id = 1 | ||
experiment_config = experiment_config_from_model_id(finished_experiment.db_engine, model_id) | ||
assert experiment_config == finished_experiment.config | ||
|
||
|
||
def test_train_matrix_info_from_model_id(finished_experiment): | ||
model_id = 1 | ||
(train_matrix_uuid, matrix_metadata) = train_matrix_info_from_model_id(finished_experiment.db_engine, model_id) | ||
assert train_matrix_uuid | ||
assert matrix_metadata |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,7 @@ | |
MultiCoreExperiment, | ||
SingleThreadedExperiment, | ||
) | ||
from triage.risklist import generate_risk_list | ||
from triage.component.postmodeling.crosstabs import CrosstabsConfigLoader, run_crosstabs | ||
from triage.util.db import create_engine | ||
|
||
|
@@ -399,6 +400,37 @@ def __call__(self, args): | |
run_crosstabs(db_engine, config) | ||
|
||
|
||
@Triage.register | ||
class Risklist(Command): | ||
"""Generate a list of risk scores from an already-trained model and new data""" | ||
|
||
def __init__(self, parser): | ||
parser.add_argument( | ||
"model_id", | ||
type=natural_number, | ||
help="The model_id of an existing trained model in the models table", | ||
) | ||
parser.add_argument( | ||
"as_of_date", | ||
type=valid_date, | ||
help="The date as of which to run features. Format YYYY-MM-DD", | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this be able to handle timestamps rather than just dates (I think the only place we currently assume day-level resolution at the moment is timechop)? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Makes sense. It looks like some of the same util code that is pinning timechop to day-level resolution (triage.util.conf.dt_from_str) is also used in risklist. So we could make dt_from_str attempt to read ISO datetime strings as well, which would help in both places. And replacing the valid_date helper defined here in the CLI module with the improved dt_from_str would help too, since the two functions really do the same thing. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That makes sense to me -- is it worth doing that here or as a separate PR? |
||
parser.add_argument( | ||
"--project-path", | ||
default=os.getcwd(), | ||
help="path to store matrices and trained models", | ||
) | ||
|
||
def __call__(self, args): | ||
db_engine = create_engine(self.root.db_url) | ||
|
||
generate_risk_list( | ||
db_engine, | ||
ProjectStorage(args.project_path), | ||
args.model_id, | ||
args.as_of_date | ||
) | ||
|
||
@Triage.register | ||
class Db(Command): | ||
"""Manage experiment database""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd update the naming from Risklist to Predictlist here too. The name and probably also the filename since the filename will show up in the documentation URL.