diff --git a/.gitignore b/.gitignore
index 5b8e7688..605474a2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,6 +10,9 @@ __pycache__/
 # C extensions
 **/*.so
 
+# Unit test
+.pytest_cache/
+
 # Distribution / packaging
 .Python
 # build/ # build/ contains required files for building tfx packages.
@@ -37,6 +40,9 @@ env/*
 **/env
 **/venv
 
-#Editor
+# pyenv
+.python-version
+
+# Editor
 .idea/*
 .vscode/*
diff --git a/tfx_addons/model_card_generator/component.py b/tfx_addons/model_card_generator/component.py
index 24e917ab..73491e3a 100644
--- a/tfx_addons/model_card_generator/component.py
+++ b/tfx_addons/model_card_generator/component.py
@@ -40,7 +40,15 @@ class ModelCardGeneratorSpec(component_spec.ComponentSpec):
       # See below link for details.
       # https://github.com/tensorflow/tfx/blob/4ff5e97b09540ff8a858076a163ecdf209716324/tfx/types/component_spec.py#L308
       'template_io':
-          component_spec.ExecutionParameter(type=List[Any], optional=True)
+          component_spec.ExecutionParameter(type=List[Any], optional=True),
+      'features_include':
+          component_spec.ExecutionParameter(type=List[str], optional=True),
+      'features_exclude':
+          component_spec.ExecutionParameter(type=List[str], optional=True),
+      'metrics_include':
+          component_spec.ExecutionParameter(type=List[str], optional=True),
+      'metrics_exclude':
+          component_spec.ExecutionParameter(type=List[str], optional=True),
   }
   INPUTS = {
       standard_component_specs.STATISTICS_KEY:
@@ -94,12 +102,18 @@ class ModelCardGenerator(BaseComponent):
   SPEC_CLASS = ModelCardGeneratorSpec
   EXECUTOR_SPEC = executor_spec.ExecutorClassSpec(executor.Executor)
 
-  def __init__(self,
-               evaluation: Optional[types.Channel] = None,
-               statistics: Optional[types.Channel] = None,
-               pushed_model: Optional[types.Channel] = None,
-               json: Optional[str] = None,
-               template_io: Optional[List[Tuple[str, str]]] = None):
+  def __init__(
+      self,
+      evaluation: Optional[types.Channel] = None,
+      statistics: Optional[types.Channel] = None,
+      pushed_model: Optional[types.Channel] = None,
+      json: Optional[str] = None,
+      template_io: Optional[List[Tuple[str, str]]] = None,
+      features_include: Optional[List[str]] = None,
+      features_exclude: Optional[List[str]] = None,
+      metrics_include: Optional[List[str]] = None,
+      metrics_exclude: Optional[List[str]] = None,
+  ):
     """Generate a model card for a TFX pipeline.
 
     This executes a Model Card Toolkit workflow, producing a `ModelCard`
@@ -136,12 +150,42 @@ def __init__(self,
         `ModelCardToolkit`'s default HTML template
         (`default_template.html.jinja`) and file name (`model_card.html`) are
         used.
+      features_include: The feature paths to include for the dataset statistics.
+        By default, all features are included. Mutually exclusive with
+        features_exclude.
+      features_exclude: The feature paths to exclude for the dataset statistics.
+        By default, all features are included. Mutually exclusive with
+        features_include.
+      metrics_include: The list of metric names to include in the model card. By
+        default, all metrics are included. Mutually exclusive with
+        metrics_exclude.
+      metrics_exclude: The list of metric names to exclude in the model card. By
+        default, no metrics are excluded. Mutually exclusive with
+        metrics_include.
+
+    Raises:
+      ValueError:
+        - When both `features_include` and `features_exclude` are specified.
+        - When both `metrics_include` and `metrics_exclude` are specified.
     """
+    if features_include and features_exclude:
+      raise ValueError(
+          'Only one of features_include or features_exclude may be specified.')
+
+    if metrics_include and metrics_exclude:
+      raise ValueError(
+          'Only one of metrics_include or metrics_exclude may be specified.')
+
     spec = ModelCardGeneratorSpec(
         evaluation=evaluation,
         statistics=statistics,
         pushed_model=pushed_model,
         model_card=types.Channel(type=artifact.ModelCard),
         json=json,
-        template_io=template_io)
+        template_io=template_io,
+        features_include=features_include,
+        features_exclude=features_exclude,
+        metrics_include=metrics_include,
+        metrics_exclude=metrics_exclude,
+    )
     super(ModelCardGenerator, self).__init__(spec=spec)
diff --git a/tfx_addons/model_card_generator/component_test.py b/tfx_addons/model_card_generator/component_test.py
index da3aeb34..a1a2b477 100644
--- a/tfx_addons/model_card_generator/component_test.py
+++ b/tfx_addons/model_card_generator/component_test.py
@@ -25,7 +25,7 @@ class ComponentTest(absltest.TestCase):
 
   def test_component_construction(self):
-    this_component = ModelCardGenerator(
+    model_card_gen = ModelCardGenerator(
         statistics=channel_utils.as_channel(
             [standard_artifacts.ExampleStatistics()]),
         evaluation=channel_utils.as_channel(
@@ -40,10 +40,11 @@ def test_component_construction(self):
             }
         }}),
         template_io=[('path/to/html/template', 'mc.html'),
-                     ('path/to/md/template', 'mc.md')])
+                     ('path/to/md/template', 'mc.md')],
+    )
 
     with self.subTest('outputs'):
-      self.assertEqual(this_component.outputs['model_card'].type_name,
+      self.assertEqual(model_card_gen.outputs['model_card'].type_name,
                        artifact.ModelCard.TYPE_NAME)
 
     with self.subTest('exec_properties'):
@@ -59,15 +60,57 @@ def test_component_construction(self):
                 }
             }),
         'template_io': [('path/to/html/template', 'mc.html'),
-                        ('path/to/md/template', 'mc.md')]
-    }, this_component.exec_properties)
+                        ('path/to/md/template', 'mc.md')],
+        'features_include':
+            None,
+        'features_exclude':
+            None,
+        'metrics_include':
+            None,
+        'metrics_exclude':
+            None,
+    }, model_card_gen.exec_properties)
 
   def test_empty_component_construction(self):
-    this_component = ModelCardGenerator()
+    model_card_gen = ModelCardGenerator()
 
     with self.subTest('outputs'):
-      self.assertEqual(this_component.outputs['model_card'].type_name,
+      self.assertEqual(model_card_gen.outputs['model_card'].type_name,
                        artifact.ModelCard.TYPE_NAME)
 
+  def test_component_construction_with_filtered_features(self):
+    model_card_gen = ModelCardGenerator(features_include=['feature_name1'])
+    self.assertEqual(model_card_gen.exec_properties['features_include'],
+                     ['feature_name1'])
+
+    model_card_gen_features_exclude = ModelCardGenerator(
+        features_exclude=['feature_name2'], )
+    self.assertEqual(
+        model_card_gen_features_exclude.exec_properties['features_exclude'],
+        ['feature_name2'])
+
+    with self.assertRaises(ValueError):
+      ModelCardGenerator(
+          features_include=['feature_name1'],
+          features_exclude=['feature_name2'],
+      )
+
+  def test_component_construction_with_filtered_metrics(self):
+    model_card_gen = ModelCardGenerator(metrics_include=['accuracy'])
+    self.assertEqual(model_card_gen.exec_properties['metrics_include'],
+                     ['accuracy'])
+
+    model_card_gen_metrics_exclude = ModelCardGenerator(
+        metrics_exclude=['loss'], )
+    self.assertEqual(
+        model_card_gen_metrics_exclude.exec_properties['metrics_exclude'],
+        ['loss'])
+
+    with self.assertRaises(ValueError):
+      ModelCardGenerator(
+          metrics_include=['accuracy'],
+          metrics_exclude=['loss'],
+      )
+
 
 if __name__ == '__main__':
   absltest.main()
diff --git a/tfx_addons/model_card_generator/executor.py b/tfx_addons/model_card_generator/executor.py
index e946ee55..1e980a0b 100644
--- a/tfx_addons/model_card_generator/executor.py
+++ b/tfx_addons/model_card_generator/executor.py
@@ -33,23 +33,35 @@ class Executor(BaseExecutor):
   """Executor for Model Card TFX component."""
   def _tfma_source(
       self,
-      input_dict: Dict[str, List[types.Artifact]]) -> Optional[src.TfmaSource]:
+      input_dict: Dict[str, List[types.Artifact]],
+      exec_properties: Dict[str, Any],
+  ) -> Optional[src.TfmaSource]:
     """See base class."""
     if not input_dict.get(standard_component_specs.EVALUATION_KEY):
       return None
     else:
-      return src.TfmaSource(model_evaluation_artifacts=input_dict[
-          standard_component_specs.EVALUATION_KEY])
+      return src.TfmaSource(
+          model_evaluation_artifacts=input_dict[
+              standard_component_specs.EVALUATION_KEY],
+          metrics_include=exec_properties.get('metrics_include', []),
+          metrics_exclude=exec_properties.get('metrics_exclude', []),
+      )
 
   def _tfdv_source(
       self,
-      input_dict: Dict[str, List[types.Artifact]]) -> Optional[src.TfdvSource]:
+      input_dict: Dict[str, List[types.Artifact]],
+      exec_properties: Dict[str, Any],
+  ) -> Optional[src.TfdvSource]:
     """See base class."""
     if not input_dict.get(standard_component_specs.STATISTICS_KEY):
       return None
     else:
-      return src.TfdvSource(example_statistics_artifacts=input_dict[
-          standard_component_specs.STATISTICS_KEY])
+      return src.TfdvSource(
+          example_statistics_artifacts=input_dict[
+              standard_component_specs.STATISTICS_KEY],
+          features_include=exec_properties.get('features_include', []),
+          features_exclude=exec_properties.get('features_exclude', []),
+      )
 
   def _model_source(
       self,
@@ -100,13 +112,28 @@ def Do(self, input_dict: Dict[str, List[types.Artifact]],
           `ModelCardToolkit`'s default HTML template
           (`default_template.html.jinja`) and file name (`model_card.html`) are
           used.
+        - features_include: The feature paths to include for the dataset
+          statistics.
+          By default, all features are included. Mutually exclusive with
+          features_exclude.
+        - features_exclude: The feature paths to exclude for the dataset
+          statistics.
+          By default, all features are included. Mutually exclusive with
+          features_include.
+        - metrics_include: The list of metric names to include in the model
+          card. By default, all metrics are included. Mutually exclusive with
+          metrics_exclude.
+        - metrics_exclude: The list of metric names to exclude in the model
+          card. By default, no metrics are excluded. Mutually exclusive with
+          metrics_include.
     """
 
     # Initialize ModelCardToolkit
     mct = core.ModelCardToolkit(source=src.Source(
-        tfma=self._tfma_source(input_dict),
-        tfdv=self._tfdv_source(input_dict),
-        model=self._model_source(input_dict)),
+        tfma=self._tfma_source(input_dict, exec_properties),
+        tfdv=self._tfdv_source(input_dict, exec_properties),
+        model=self._model_source(input_dict),
+    ),
                                 output_dir=artifact_utils.get_single_instance(
                                     output_dict['model_card']).uri)
     template_io = exec_properties.get('template_io') or [
diff --git a/tfx_addons/model_card_generator/executor_test.py b/tfx_addons/model_card_generator/executor_test.py
index 9f046cb8..5c960b64 100644
--- a/tfx_addons/model_card_generator/executor_test.py
+++ b/tfx_addons/model_card_generator/executor_test.py
@@ -109,6 +109,8 @@ def test_do(self, eval_artifacts: bool, example_stats_artifacts: bool,
       }
       exec_properties['template_io'] = [(self.template_file.full_path,
                                          'my_cool_model_card.html')]
+      exec_properties['features_include'] = ['feature_name2']
+      exec_properties['metrics_exclude'] = ['average_loss']
 
     # Call MCT Executor
     self.mct_executor.Do(input_dict=input_dict,
@@ -149,15 +151,24 @@ def test_do(self, eval_artifacts: bool, example_stats_artifacts: bool,
 
     if eval_artifacts:
       with self.subTest(name='eval_artifacts'):
-        self.assertCountEqual(
-            model_card_proto.quantitative_analysis.performance_metrics, [
-                model_card_pb2.PerformanceMetric(
-                    type='post_export_metrics/example_count', value='2.0'),
-                model_card_pb2.PerformanceMetric(type='average_loss',
-                                                 value='0.5')
-            ])
-        self.assertLen(
-            model_card_proto.quantitative_analysis.graphics.collection, 2)
+        if exec_props:
+          self.assertCountEqual(
+              model_card_proto.quantitative_analysis.performance_metrics, [
+                  model_card_pb2.PerformanceMetric(
+                      type='post_export_metrics/example_count', value='2.0'),
+              ])
+          self.assertLen(
+              model_card_proto.quantitative_analysis.graphics.collection, 1)
+        else:
+          self.assertCountEqual(
+              model_card_proto.quantitative_analysis.performance_metrics, [
+                  model_card_pb2.PerformanceMetric(
+                      type='post_export_metrics/example_count', value='2.0'),
+                  model_card_pb2.PerformanceMetric(type='average_loss',
+                                                   value='0.5')
+              ])
+          self.assertLen(
+              model_card_proto.quantitative_analysis.graphics.collection, 2)
 
     if example_stats_artifacts:
       with self.subTest(name='example_stats_artifacts.data'):
@@ -171,20 +182,36 @@ def test_do(self, eval_artifacts: bool, example_stats_artifacts: bool,
           )
           graphic.image = bytes(
           )  # ignore graphic.image for below assertions
-        self.assertIn(
-            model_card_pb2.Dataset(
-                name=self.train_dataset_name,
-                graphics=model_card_pb2.GraphicsCollection(collection=[
-                    model_card_pb2.Graphic(name='counts | feature_name1',
-                                           image='')
-                ])), model_card_proto.model_parameters.data)
-        self.assertIn(
-            model_card_pb2.Dataset(
-                name=self.eval_dataset_name,
-                graphics=model_card_pb2.GraphicsCollection(collection=[
-                    model_card_pb2.Graphic(name='counts | feature_name2',
-                                           image='')
-                ])), model_card_proto.model_parameters.data)
+        if exec_props:
+          self.assertNotIn(
+              model_card_pb2.Dataset(
+                  name=self.train_dataset_name,
+                  graphics=model_card_pb2.GraphicsCollection(collection=[
+                      model_card_pb2.Graphic(name='counts | feature_name1',
+                                             image='')
+                  ])), model_card_proto.model_parameters.data)
+          self.assertIn(
+              model_card_pb2.Dataset(
+                  name=self.eval_dataset_name,
+                  graphics=model_card_pb2.GraphicsCollection(collection=[
+                      model_card_pb2.Graphic(name='counts | feature_name2',
+                                             image='')
+                  ])), model_card_proto.model_parameters.data)
+        else:
+          self.assertIn(
+              model_card_pb2.Dataset(
+                  name=self.train_dataset_name,
+                  graphics=model_card_pb2.GraphicsCollection(collection=[
+                      model_card_pb2.Graphic(name='counts | feature_name1',
+                                             image='')
+                  ])), model_card_proto.model_parameters.data)
+          self.assertIn(
+              model_card_pb2.Dataset(
+                  name=self.eval_dataset_name,
+                  graphics=model_card_pb2.GraphicsCollection(collection=[
+                      model_card_pb2.Graphic(name='counts | feature_name2',
+                                             image='')
+                  ])), model_card_proto.model_parameters.data)
 
     if pushed_model_artifact:
       with self.subTest(name='pushed_model_artifact'):
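
Usage sketch (reviewer note, not part of the patch): a minimal example of how the new filter parameters might be wired into a TFX pipeline, assuming standard StatisticsGen, Evaluator, and Pusher components upstream. The component variables and the feature/metric names below are illustrative placeholders, not values taken from this change.

from tfx_addons.model_card_generator.component import ModelCardGenerator


def make_model_card_gen(statistics_gen, evaluator, pusher):
  """Builds a ModelCardGenerator using the new filtering parameters.

  The arguments are hypothetical upstream components; only their standard
  output channels are assumed here.
  """
  return ModelCardGenerator(
      statistics=statistics_gen.outputs['statistics'],
      evaluation=evaluator.outputs['evaluation'],
      pushed_model=pusher.outputs['pushed_model'],
      # Keep only this feature's statistics in the generated model card.
      features_include=['feature_name2'],
      # Drop this metric from the quantitative analysis; all others are kept.
      metrics_exclude=['average_loss'],
  )

# Each include/exclude pair is mutually exclusive; setting both sides raises
# ValueError at construction time, e.g.:
#   ModelCardGenerator(features_include=['a'], features_exclude=['b'])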