Enable filtering features and metrics in the ModelCardGenerator (#256)
* Add pytest and pyenv files to gitignore

* Surface feature and metric include/exclude options
codesue authored Jul 19, 2023
1 parent 8a276a7 commit f872831
Showing 5 changed files with 195 additions and 48 deletions.
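
For context, the new options are plain constructor arguments on the component. Below is a minimal usage sketch, with channels built from bare artifacts for illustration; in a real pipeline they would come from StatisticsGen and Evaluator outputs.

from tfx.types import channel_utils, standard_artifacts

from tfx_addons.model_card_generator.component import ModelCardGenerator

# Placeholder channels for illustration only.
statistics = channel_utils.as_channel(
    [standard_artifacts.ExampleStatistics()])
evaluation = channel_utils.as_channel(
    [standard_artifacts.ModelEvaluation()])

model_card_gen = ModelCardGenerator(
    statistics=statistics,
    evaluation=evaluation,
    features_include=['feature_name1'],  # keep only this feature's statistics
    metrics_exclude=['average_loss'],  # drop this metric from the card
)
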
8 changes: 7 additions & 1 deletion .gitignore
@@ -10,6 +10,9 @@ __pycache__/
# C extensions
**/*.so

# Unit test
.pytest_cache/

# Distribution / packaging
.Python
# build/ # build/ contains required files for building tfx packages.
@@ -37,6 +40,9 @@ env/*
**/env
**/venv

#Editor
# pyenv
.python-version

# Editor
.idea/*
.vscode/*
60 changes: 52 additions & 8 deletions tfx_addons/model_card_generator/component.py
@@ -40,7 +40,15 @@ class ModelCardGeneratorSpec(component_spec.ComponentSpec):
# See below link for details.
# https://github.com/tensorflow/tfx/blob/4ff5e97b09540ff8a858076a163ecdf209716324/tfx/types/component_spec.py#L308
'template_io':
component_spec.ExecutionParameter(type=List[Any], optional=True)
component_spec.ExecutionParameter(type=List[Any], optional=True),
'features_include':
component_spec.ExecutionParameter(type=List[str], optional=True),
'features_exclude':
component_spec.ExecutionParameter(type=List[str], optional=True),
'metrics_include':
component_spec.ExecutionParameter(type=List[str], optional=True),
'metrics_exclude':
component_spec.ExecutionParameter(type=List[str], optional=True),
}
INPUTS = {
standard_component_specs.STATISTICS_KEY:
@@ -94,12 +102,18 @@ class ModelCardGenerator(BaseComponent):
SPEC_CLASS = ModelCardGeneratorSpec
EXECUTOR_SPEC = executor_spec.ExecutorClassSpec(executor.Executor)

def __init__(self,
evaluation: Optional[types.Channel] = None,
statistics: Optional[types.Channel] = None,
pushed_model: Optional[types.Channel] = None,
json: Optional[str] = None,
template_io: Optional[List[Tuple[str, str]]] = None):
def __init__(
self,
evaluation: Optional[types.Channel] = None,
statistics: Optional[types.Channel] = None,
pushed_model: Optional[types.Channel] = None,
json: Optional[str] = None,
template_io: Optional[List[Tuple[str, str]]] = None,
features_include: Optional[List[str]] = None,
features_exclude: Optional[List[str]] = None,
metrics_include: Optional[List[str]] = None,
metrics_exclude: Optional[List[str]] = None,
):
"""Generate a model card for a TFX pipeline.
This executes a Model Card Toolkit workflow, producing a `ModelCard`
@@ -136,12 +150,42 @@ def __init__(self,
`ModelCardToolkit`'s default HTML template
(`default_template.html.jinja`) and file name (`model_card.html`) are
used.
features_include: The feature paths to include for the dataset statistics.
By default, all features are included. Mutually exclusive with
features_exclude.
features_exclude: The feature paths to exclude for the dataset statistics.
By default, all features are included. Mutually exclusive with
features_include.
metrics_include: The list of metric names to include in the model card. By
default, all metrics are included. Mutually exclusive with
metrics_exclude.
metrics_exclude: The list of metric names to exclude in the model card. By
default, no metrics are excluded. Mutually exclusive with
metrics_include.
Raises:
ValueError:
- When both `features_include` and `features_exclude` are specified.
- When both `metrics_include` and `metrics_exclude` are specified.
"""
if features_include and features_exclude:
raise ValueError(
'Only one of features_include or features_exclude may be specified.')

if metrics_include and metrics_exclude:
raise ValueError(
'Only one of metrics_include or metrics_exclude may be specified.')

spec = ModelCardGeneratorSpec(
evaluation=evaluation,
statistics=statistics,
pushed_model=pushed_model,
model_card=types.Channel(type=artifact.ModelCard),
json=json,
template_io=template_io)
template_io=template_io,
features_include=features_include,
features_exclude=features_exclude,
metrics_include=metrics_include,
metrics_exclude=metrics_exclude,
)
super(ModelCardGenerator, self).__init__(spec=spec)
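
The new include/exclude pairs are validated eagerly, so a misconfigured component fails at construction time rather than mid-pipeline. A quick sketch of the behavior defined above:

from tfx_addons.model_card_generator.component import ModelCardGenerator

# Setting both members of a pair raises immediately.
try:
  ModelCardGenerator(metrics_include=['accuracy'], metrics_exclude=['loss'])
except ValueError as err:
  print(err)  # Only one of metrics_include or metrics_exclude may be specified.
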
57 changes: 50 additions & 7 deletions tfx_addons/model_card_generator/component_test.py
@@ -25,7 +25,7 @@

class ComponentTest(absltest.TestCase):
def test_component_construction(self):
this_component = ModelCardGenerator(
model_card_gen = ModelCardGenerator(
statistics=channel_utils.as_channel(
[standard_artifacts.ExampleStatistics()]),
evaluation=channel_utils.as_channel(
@@ -40,10 +40,11 @@ def test_component_construction(self):
}
}}),
template_io=[('path/to/html/template', 'mc.html'),
('path/to/md/template', 'mc.md')])
('path/to/md/template', 'mc.md')],
)

with self.subTest('outputs'):
self.assertEqual(this_component.outputs['model_card'].type_name,
self.assertEqual(model_card_gen.outputs['model_card'].type_name,
artifact.ModelCard.TYPE_NAME)

with self.subTest('exec_properties'):
@@ -59,15 +60,57 @@
}
}),
'template_io': [('path/to/html/template', 'mc.html'),
('path/to/md/template', 'mc.md')]
}, this_component.exec_properties)
('path/to/md/template', 'mc.md')],
'features_include':
None,
'features_exclude':
None,
'metrics_include':
None,
'metrics_exclude':
None,
}, model_card_gen.exec_properties)

def test_empty_component_construction(self):
this_component = ModelCardGenerator()
model_card_gen = ModelCardGenerator()
with self.subTest('outputs'):
self.assertEqual(this_component.outputs['model_card'].type_name,
self.assertEqual(model_card_gen.outputs['model_card'].type_name,
artifact.ModelCard.TYPE_NAME)

def test_component_construction_with_filtered_features(self):
model_card_gen = ModelCardGenerator(features_include=['feature_name1'])
self.assertEqual(model_card_gen.exec_properties['features_include'],
['feature_name1'])

model_card_gen_features_exclude = ModelCardGenerator(
features_exclude=['feature_name2'], )
self.assertEqual(
model_card_gen_features_exclude.exec_properties['features_exclude'],
['feature_name2'])

with self.assertRaises(ValueError):
ModelCardGenerator(
features_include=['feature_name1'],
features_exclude=['feature_name2'],
)

def test_component_construction_with_filtered_metrics(self):
model_card_gen = ModelCardGenerator(metrics_include=['accuracy'])
self.assertEqual(model_card_gen.exec_properties['metrics_include'],
['accuracy'])

model_card_gen_metrics_exclude = ModelCardGenerator(
metrics_exclude=['loss'], )
self.assertEqual(
model_card_gen_metrics_exclude.exec_properties['metrics_exclude'],
['loss'])

with self.assertRaises(ValueError):
ModelCardGenerator(
metrics_include=['accuracy'],
metrics_exclude=['loss'],
)


if __name__ == '__main__':
absltest.main()
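
As the updated assertions show, filters that are not set still surface in the component's exec_properties, as None; the executor later treats any falsy value as "no filtering". A quick sketch mirroring the tests above:

from tfx_addons.model_card_generator.component import ModelCardGenerator

model_card_gen = ModelCardGenerator()
# Unset optional parameters are present in exec_properties with value None.
assert model_card_gen.exec_properties['features_include'] is None
assert model_card_gen.exec_properties['metrics_exclude'] is None
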
45 changes: 36 additions & 9 deletions tfx_addons/model_card_generator/executor.py
@@ -33,23 +33,35 @@ class Executor(BaseExecutor):
"""Executor for Model Card TFX component."""
def _tfma_source(
self,
input_dict: Dict[str, List[types.Artifact]]) -> Optional[src.TfmaSource]:
input_dict: Dict[str, List[types.Artifact]],
exec_properties: Dict[str, Any],
) -> Optional[src.TfmaSource]:
"""See base class."""
if not input_dict.get(standard_component_specs.EVALUATION_KEY):
return None
else:
return src.TfmaSource(model_evaluation_artifacts=input_dict[
standard_component_specs.EVALUATION_KEY])
return src.TfmaSource(
model_evaluation_artifacts=input_dict[
standard_component_specs.EVALUATION_KEY],
metrics_include=exec_properties.get('metrics_include', []),
metrics_exclude=exec_properties.get('metrics_exclude', []),
)

def _tfdv_source(
self,
input_dict: Dict[str, List[types.Artifact]]) -> Optional[src.TfdvSource]:
input_dict: Dict[str, List[types.Artifact]],
exec_properties: Dict[str, Any],
) -> Optional[src.TfdvSource]:
"""See base class."""
if not input_dict.get(standard_component_specs.STATISTICS_KEY):
return None
else:
return src.TfdvSource(example_statistics_artifacts=input_dict[
standard_component_specs.STATISTICS_KEY])
return src.TfdvSource(
example_statistics_artifacts=input_dict[
standard_component_specs.STATISTICS_KEY],
features_include=exec_properties.get('features_include', []),
features_exclude=exec_properties.get('features_exclude', []),
)

def _model_source(
self,
@@ -100,13 +112,28 @@ def Do(self, input_dict: Dict[str, List[types.Artifact]],
`ModelCardToolkit`'s default HTML template
(`default_template.html.jinja`) and file name (`model_card.html`)
are used.
- features_include: The feature paths to include for the dataset
statistics.
By default, all features are included. Mutually exclusive with
features_exclude.
- features_exclude: The feature paths to exclude for the dataset
statistics.
By default, all features are included. Mutually exclusive with
features_include.
- metrics_include: The list of metric names to include in the model
card. By default, all metrics are included. Mutually exclusive with
metrics_exclude.
- metrics_exclude: The list of metric names to exclude in the model
card. By default, no metrics are excluded. Mutually exclusive with
metrics_include.
"""

# Initialize ModelCardToolkit
mct = core.ModelCardToolkit(source=src.Source(
tfma=self._tfma_source(input_dict),
tfdv=self._tfdv_source(input_dict),
model=self._model_source(input_dict)),
tfma=self._tfma_source(input_dict, exec_properties),
tfdv=self._tfdv_source(input_dict, exec_properties),
model=self._model_source(input_dict),
),
output_dir=artifact_utils.get_single_instance(
output_dict['model_card']).uri)
template_io = exec_properties.get('template_io') or [
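
The executor simply forwards the filters to Model Card Toolkit's source classes, which take the same keyword arguments. For reference, a sketch of the equivalent direct construction; the import path and the artifact lists are assumptions, not shown in this diff:

from model_card_toolkit.utils import source as src  # import path is an assumption

source = src.Source(
    tfma=src.TfmaSource(
        model_evaluation_artifacts=evaluation_artifacts,  # placeholder list
        metrics_exclude=['average_loss'],
    ),
    tfdv=src.TfdvSource(
        example_statistics_artifacts=statistics_artifacts,  # placeholder list
        features_include=['feature_name2'],
    ),
)
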
73 changes: 50 additions & 23 deletions tfx_addons/model_card_generator/executor_test.py
@@ -109,6 +109,8 @@ def test_do(self, eval_artifacts: bool, example_stats_artifacts: bool,
}
exec_properties['template_io'] = [(self.template_file.full_path,
'my_cool_model_card.html')]
exec_properties['features_include'] = ['feature_name2']
exec_properties['metrics_exclude'] = ['average_loss']

# Call MCT Executor
self.mct_executor.Do(input_dict=input_dict,
@@ -149,15 +151,24 @@

if eval_artifacts:
with self.subTest(name='eval_artifacts'):
self.assertCountEqual(
model_card_proto.quantitative_analysis.performance_metrics, [
model_card_pb2.PerformanceMetric(
type='post_export_metrics/example_count', value='2.0'),
model_card_pb2.PerformanceMetric(type='average_loss',
value='0.5')
])
self.assertLen(
model_card_proto.quantitative_analysis.graphics.collection, 2)
if exec_props:
self.assertCountEqual(
model_card_proto.quantitative_analysis.performance_metrics, [
model_card_pb2.PerformanceMetric(
type='post_export_metrics/example_count', value='2.0'),
])
self.assertLen(
model_card_proto.quantitative_analysis.graphics.collection, 1)
else:
self.assertCountEqual(
model_card_proto.quantitative_analysis.performance_metrics, [
model_card_pb2.PerformanceMetric(
type='post_export_metrics/example_count', value='2.0'),
model_card_pb2.PerformanceMetric(type='average_loss',
value='0.5')
])
self.assertLen(
model_card_proto.quantitative_analysis.graphics.collection, 2)

if example_stats_artifacts:
with self.subTest(name='example_stats_artifacts.data'):
@@ -171,20 +182,36 @@
)
graphic.image = bytes(
) # ignore graphic.image for below assertions
self.assertIn(
model_card_pb2.Dataset(
name=self.train_dataset_name,
graphics=model_card_pb2.GraphicsCollection(collection=[
model_card_pb2.Graphic(name='counts | feature_name1',
image='')
])), model_card_proto.model_parameters.data)
self.assertIn(
model_card_pb2.Dataset(
name=self.eval_dataset_name,
graphics=model_card_pb2.GraphicsCollection(collection=[
model_card_pb2.Graphic(name='counts | feature_name2',
image='')
])), model_card_proto.model_parameters.data)
if exec_props:
self.assertNotIn(
model_card_pb2.Dataset(
name=self.train_dataset_name,
graphics=model_card_pb2.GraphicsCollection(collection=[
model_card_pb2.Graphic(name='counts | feature_name1',
image='')
])), model_card_proto.model_parameters.data)
self.assertIn(
model_card_pb2.Dataset(
name=self.eval_dataset_name,
graphics=model_card_pb2.GraphicsCollection(collection=[
model_card_pb2.Graphic(name='counts | feature_name2',
image='')
])), model_card_proto.model_parameters.data)
else:
self.assertIn(
model_card_pb2.Dataset(
name=self.train_dataset_name,
graphics=model_card_pb2.GraphicsCollection(collection=[
model_card_pb2.Graphic(name='counts | feature_name1',
image='')
])), model_card_proto.model_parameters.data)
self.assertIn(
model_card_pb2.Dataset(
name=self.eval_dataset_name,
graphics=model_card_pb2.GraphicsCollection(collection=[
model_card_pb2.Graphic(name='counts | feature_name2',
image='')
])), model_card_proto.model_parameters.data)

if pushed_model_artifact:
with self.subTest(name='pushed_model_artifact'):
