Skip to content

Commit

Permalink
support remote storage in modelcardgenerator (#262)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeongukjae authored Aug 25, 2023
1 parent b2151e5 commit 96d6b63
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions tfx_addons/model_card_generator/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@
ModelCardGenerator.
"""

import tempfile
from typing import Any, Dict, List, Optional

from model_card_toolkit import core
from model_card_toolkit.utils import source as src
from tfx import types
from tfx.dsl.components.base.base_executor import BaseExecutor
from tfx.types import artifact_utils, standard_component_specs
from tfx.utils import io_utils

_DEFAULT_MODEL_CARD_FILE_NAME = 'model_card.html'

Expand Down Expand Up @@ -128,14 +130,16 @@ def Do(self, input_dict: Dict[str, List[types.Artifact]],
metrics_include.
"""

# Make local temp directory to support remote storage (gcs, s3, ...)
temp_output_dir = tempfile.mkdtemp()

# Initialize ModelCardToolkit
mct = core.ModelCardToolkit(source=src.Source(
tfma=self._tfma_source(input_dict, exec_properties),
tfdv=self._tfdv_source(input_dict, exec_properties),
model=self._model_source(input_dict),
),
output_dir=artifact_utils.get_single_instance(
output_dict['model_card']).uri)
output_dir=temp_output_dir)
template_io = exec_properties.get('template_io') or [
(mct.default_template, _DEFAULT_MODEL_CARD_FILE_NAME)
]
Expand All @@ -144,3 +148,7 @@ def Do(self, input_dict: Dict[str, List[types.Artifact]],
mct.scaffold_assets(json=exec_properties.get('json'))
for template_path, output_file in template_io:
mct.export_format(template_path=template_path, output_file=output_file)

# Copy all files to output_dir
output_dir = artifact_utils.get_single_uri(output_dict["model_card"])
io_utils.copy_dir(src=temp_output_dir, dst=output_dir)

0 comments on commit 96d6b63

Please sign in to comment.