diff --git a/models/readme.md b/models/readme.md index 2917b57..c726f8e 100644 --- a/models/readme.md +++ b/models/readme.md @@ -114,7 +114,7 @@ Example: # labels: author::google test_group::daily,monthly ``` -Labels are saved in your cache directory and can later be retrieved using the function `turnkey.common.labels.load_from_cache()`, which receives the `cache_dir` and `build_name` as inputs and returns the labels as a dictionary. +Labels are saved in your cache directory in the `turnkey_stats.yaml` file under the "labels" key. ### Parameters diff --git a/src/turnkeyml/analyze/script.py b/src/turnkeyml/analyze/script.py index 27c13ef..3c213bb 100644 --- a/src/turnkeyml/analyze/script.py +++ b/src/turnkeyml/analyze/script.py @@ -138,11 +138,6 @@ def explore_invocation( inputs[all_args[i]] = args[i] invocation_info.inputs = inputs - # Save model labels - if model_info.model_type != build.ModelType.ONNX_FILE: - tracer_args.labels["class"] = [f"{type(model_info.model).__name__}"] - labels.save_to_cache(tracer_args.cache_dir, build_name, tracer_args.labels) - # If the user has not provided a specific runtime, select the runtime # based on the device provided. if tracer_args.runtime is None: @@ -182,13 +177,16 @@ def explore_invocation( fs.Keys.PARAMETERS, model_info.params, ) + if model_info.model_type != build.ModelType.ONNX_FILE: + stats.save_stat(fs.Keys.CLASS, type(model_info.model).__name__) if fs.Keys.AUTHOR in tracer_args.labels: stats.save_stat(fs.Keys.AUTHOR, tracer_args.labels[fs.Keys.AUTHOR][0]) - if fs.Keys.CLASS in tracer_args.labels: - stats.save_stat(fs.Keys.CLASS, tracer_args.labels[fs.Keys.CLASS][0]) if fs.Keys.TASK in tracer_args.labels: stats.save_stat(fs.Keys.TASK, tracer_args.labels[fs.Keys.TASK][0]) + # Save all of the lables in one place + stats.save_stat(fs.Keys.LABELS, tracer_args.labels) + # If the input script is a built-in TurnkeyML model, make a note of # which one if os.path.abspath(fs.MODELS_DIR) in os.path.abspath(tracer_args.input): diff --git a/src/turnkeyml/common/filesystem.py b/src/turnkeyml/common/filesystem.py index 082ecef..2767d15 100644 --- a/src/turnkeyml/common/filesystem.py +++ b/src/turnkeyml/common/filesystem.py @@ -333,6 +333,8 @@ class Keys: MODEL_NAME = "model_name" # References the per-build stats section BUILDS = "builds" + # Catch-all for storing a file's labels + LABELS = "labels" # Author of the model AUTHOR = "author" # Class type of the model diff --git a/src/turnkeyml/common/labels.py b/src/turnkeyml/common/labels.py index 812962e..be3d2e1 100644 --- a/src/turnkeyml/common/labels.py +++ b/src/turnkeyml/common/labels.py @@ -1,4 +1,3 @@ -import os from typing import Dict, List import turnkeyml.common.printing as printing @@ -44,36 +43,6 @@ def load_from_file(file_path: str) -> Dict[str, List[str]]: return {} -def load_from_cache(cache_dir: str, build_name: str) -> Dict[str, List[str]]: - """ - Loads labels from the cache directory - """ - # Open file - file_path = os.path.join(cache_dir, "labels", f"{build_name}.txt") - with open(file_path, encoding="utf-8") as f: - first_line = f.readline() - - # Return label dict - label_list = first_line.replace("\n", "").split(" ") - return to_dict(label_list) - - -def save_to_cache(cache_dir: str, build_name: str, label_dict: Dict[str, List[str]]): - """ - Save labels as a stand-alone file as part of the cache directory - """ - labels_list = [f"{k}::{','.join(label_dict[k])}" for k in label_dict.keys()] - - # Create labels folder if it doesn't exist - labels_dir = os.path.join(cache_dir, "labels") - os.makedirs(labels_dir, exist_ok=True) - - # Save labels to cache - file_path = os.path.join(labels_dir, f"{build_name}.txt") - with open(file_path, "w", encoding="utf8") as fp: - fp.write(" ".join(labels_list)) - - def is_subset(label_dict_a: Dict[str, List[str]], label_dict_b: Dict[str, List[str]]): """ This function returns True if label_dict_a is a subset of label_dict_b. diff --git a/test/analysis.py b/test/analysis.py index 8d6581e..598d3e7 100644 --- a/test/analysis.py +++ b/test/analysis.py @@ -36,8 +36,7 @@ # filesystem access test_scripts_dot_py = { - "linear_pytorch.py": """ -# labels: test_group::selftest license::mit framework::pytorch tags::selftest,small + "linear_pytorch.py": """# labels: test_group::selftest license::mit framework::pytorch tags::selftest,small import torch import argparse @@ -235,8 +234,10 @@ def test_05_cache(self): ] ) build_name = f"linear_pytorch_{model_hash}" - labels_found = labels.load_from_cache(cache_dir, build_name) != {} - assert cache_is_lean(cache_dir, build_name) and labels_found + labels_found = filesystem.Stats(cache_dir, build_name).stats[ + filesystem.Keys.LABELS + ] + assert cache_is_lean(cache_dir, build_name) and labels_found != {}, labels_found def test_06_generic_args(self): output = run_cli( diff --git a/test/cli.py b/test/cli.py index e6ce708..3cf3835 100644 --- a/test/cli.py +++ b/test/cli.py @@ -311,9 +311,10 @@ def test_021_cli_report(self): ] linear_summary = summary[1] assert len(summary) == len(test_scripts) - assert all( - elem in linear_summary for elem in expected_cols - ), f"Looked for each of {expected_cols} in {linear_summary.keys()}" + for elem in expected_cols: + assert ( + elem in linear_summary + ), f"Couldn't find expected key {elem} in results spreadsheet" # Check whether all rows we expect to be populated are actually populated assert (