Skip to content

Commit

Permalink
Better naming for Stats methods and members (#56)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremyfowers authored Dec 5, 2023
1 parent d5390fb commit f8f8093
Show file tree
Hide file tree
Showing 16 changed files with 121 additions and 115 deletions.
2 changes: 1 addition & 1 deletion docs/contribute.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ To add a runtime to a plugin:
- `"RuntimeClass": <class_name>`, where `<class_name>` is a unique name for a Python class that inherits `BaseRT` and implements the runtime.
- For example, `"RuntimeClass": ExampleRT` implements the `example` runtime.
- The interface for the runtime class is defined in [Runtime Class](#runtime-class) below.
- (Optional) `"status_stats": List[str]`: a list of keys from the build stats that should be printed out at the end of benchmarking in the CLI's `Status` output. These keys, and corresponding values, must be set in the runtime class using `self.stats.add_build_stat(key, value)`.
- (Optional) `"status_stats": List[str]`: a list of keys from the build stats that should be printed out at the end of benchmarking in the CLI's `Status` output. These keys, and corresponding values, must be set in the runtime class using `self.stats.save_model_eval_stat(key, value)`.
- (Optional) `"requirement_check": Callable`: a callable that runs before each benchmark. This may be used to check whether the device selected is available and functional before each benchmarking run. Exceptions raised during this callable will halt the benchmark of all selected files.

1. Populate the package with the following files (see [Plugin Directory Layout](#plugin-directory-layout)):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ def benchmark(self) -> MeasuredPerformance:

# Assign values to the stats that will be printed
# out by the CLI when status is reported
self.stats.add_build_stat("magic_perf_points", 42)
self.stats.add_build_stat("super_runtime_points", 100)
self.stats.save_model_eval_stat("magic_perf_points", 42)
self.stats.save_model_eval_stat("super_runtime_points", 100)

return MeasuredPerformance(
mean_latency=self.mean_latency,
Expand Down
62 changes: 33 additions & 29 deletions src/turnkeyml/analyze/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,40 +152,40 @@ def explore_invocation(
invocation_info.stats_keys = []

# Create an ID for the build stats by combining the device and runtime.
# We don't need more info in the stats_id because changes to benchmark_model()
# We don't need more info in the evaluation_id because changes to benchmark_model()
# arguments (e.g., sequence) will trigger a rebuild, which is intended to replace the
# build stats so long as the device and runtime have not changed.
stats_id = f"{tracer_args.device}_{selected_runtime}"
evaluation_id = f"{tracer_args.device}_{selected_runtime}"

stats = fs.Stats(
tracer_args.cache_dir,
build_name,
stats_id,
evaluation_id,
)
invocation_info.stats = stats

# Stats that apply to the model, regardless of build
stats.save_stat(
stats.save_model_stat(
fs.Keys.HASH,
model_info.hash,
)
stats.save_stat(
stats.save_model_stat(
fs.Keys.MODEL_NAME,
tracer_args.script_name,
)
stats.save_stat(
stats.save_model_stat(
fs.Keys.PARAMETERS,
model_info.params,
)
if model_info.model_type != build.ModelType.ONNX_FILE:
stats.save_stat(fs.Keys.CLASS, type(model_info.model).__name__)
stats.save_model_stat(fs.Keys.CLASS, type(model_info.model).__name__)
if fs.Keys.AUTHOR in tracer_args.labels:
stats.save_stat(fs.Keys.AUTHOR, tracer_args.labels[fs.Keys.AUTHOR][0])
stats.save_model_stat(fs.Keys.AUTHOR, tracer_args.labels[fs.Keys.AUTHOR][0])
if fs.Keys.TASK in tracer_args.labels:
stats.save_stat(fs.Keys.TASK, tracer_args.labels[fs.Keys.TASK][0])
stats.save_model_stat(fs.Keys.TASK, tracer_args.labels[fs.Keys.TASK][0])

# Save all of the lables in one place
stats.save_stat(fs.Keys.LABELS, tracer_args.labels)
stats.save_model_stat(fs.Keys.LABELS, tracer_args.labels)

# If the input script is a built-in TurnkeyML model, make a note of
# which one
Expand All @@ -203,18 +203,18 @@ def explore_invocation(
fs.MODELS_DIR,
f"https://github.com/onnx/turnkeyml/tree/{git_hash}/models",
).replace("\\", "/")
stats.save_stat(fs.Keys.MODEL_SCRIPT, relative_path)
stats.save_model_stat(fs.Keys.MODEL_SCRIPT, relative_path)

# Build-specific stats
stats.add_build_stat(
stats.save_model_eval_stat(
fs.Keys.DEVICE_TYPE,
tracer_args.device,
)
stats.add_build_stat(
stats.save_model_eval_stat(
fs.Keys.RUNTIME,
selected_runtime,
)
stats.add_build_stat(
stats.save_model_eval_stat(
fs.Keys.ITERATIONS,
tracer_args.iterations,
)
Expand All @@ -233,12 +233,14 @@ def explore_invocation(
# we will try to catch the exception and note it in the stats.
# If a concluded build still has a status of "running", this means
# there was an uncaught exception.
stats.add_build_stat(fs.Keys.BENCHMARK_STATUS, fs.BenchmarkStatus.RUNNING)
stats.save_model_eval_stat(
fs.Keys.BENCHMARK_STATUS, fs.BenchmarkStatus.RUNNING
)

perf = benchmark_model(
model_info.model,
inputs,
stats_id=stats_id,
evaluation_id=evaluation_id,
device=tracer_args.device,
runtime=selected_runtime,
build_name=build_name,
Expand All @@ -263,7 +265,7 @@ def explore_invocation(
invocation_info.status_message = f"Build Error: {e}"
invocation_info.status_message_color = printing.Colors.WARNING

stats.add_build_stat(fs.Keys.BENCHMARK_STATUS, fs.BenchmarkStatus.FAILED)
stats.save_model_eval_stat(fs.Keys.BENCHMARK_STATUS, fs.BenchmarkStatus.FAILED)

_store_traceback(invocation_info)

Expand All @@ -275,22 +277,22 @@ def explore_invocation(
)
invocation_info.status_message_color = printing.Colors.WARNING

stats.add_build_stat(fs.Keys.BENCHMARK_STATUS, fs.BenchmarkStatus.KILLED)
stats.save_model_eval_stat(fs.Keys.BENCHMARK_STATUS, fs.BenchmarkStatus.KILLED)

except exp.ArgError as e:
# ArgError indicates that some argument to benchmark_model() was
# illegal. In that case we want to halt execution so that users can
# fix their arguments.

stats.add_build_stat(fs.Keys.BENCHMARK_STATUS, fs.BenchmarkStatus.FAILED)
stats.save_model_eval_stat(fs.Keys.BENCHMARK_STATUS, fs.BenchmarkStatus.FAILED)

raise e

except exp.Error as e:
invocation_info.status_message = f"Error: {e}."
invocation_info.status_message_color = printing.Colors.WARNING

stats.add_build_stat(fs.Keys.BENCHMARK_STATUS, fs.BenchmarkStatus.FAILED)
stats.save_model_eval_stat(fs.Keys.BENCHMARK_STATUS, fs.BenchmarkStatus.FAILED)

_store_traceback(invocation_info)

Expand All @@ -300,19 +302,21 @@ def explore_invocation(
invocation_info.status_message = f"Unknown turnkey error: {e}"
invocation_info.status_message_color = printing.Colors.WARNING

stats.add_build_stat(fs.Keys.BENCHMARK_STATUS, fs.BenchmarkStatus.FAILED)
stats.save_model_eval_stat(fs.Keys.BENCHMARK_STATUS, fs.BenchmarkStatus.FAILED)

_store_traceback(invocation_info)
else:
# If there was no exception then we consider the build to be a success
stats.add_build_stat(fs.Keys.BENCHMARK_STATUS, fs.BenchmarkStatus.SUCCESSFUL)
stats.save_model_eval_stat(
fs.Keys.BENCHMARK_STATUS, fs.BenchmarkStatus.SUCCESSFUL
)

finally:
# Ensure that stdout/stderr is not being forwarded before updating status
util.stop_logger_forward()

system_info = build.get_system_info()
stats.save_stat(
stats.save_model_stat(
fs.Keys.SYSTEM_INFO,
system_info,
)
Expand All @@ -324,11 +328,11 @@ def explore_invocation(

# ONNX stats that we want to save into the build's turnkey_stats.yaml file
# so that they can be easily accessed by the report command later
if fs.Keys.ONNX_FILE in stats.build_stats.keys():
if fs.Keys.ONNX_FILE in stats.evaluation_stats.keys():
# Just in case the ONNX file was generated on a different machine:
# strip the state's cache dir, then prepend the current cache dir
final_onnx_file = fs.rebase_cache_dir(
stats.build_stats[fs.Keys.ONNX_FILE],
stats.evaluation_stats[fs.Keys.ONNX_FILE],
build_name,
tracer_args.cache_dir,
)
Expand All @@ -337,22 +341,22 @@ def explore_invocation(
onnx_model_info = util.populate_onnx_model_info(final_onnx_file)
onnx_input_dimensions = util.onnx_input_dimensions(final_onnx_file)

stats.save_stat(
stats.save_model_stat(
fs.Keys.ONNX_OPS_COUNTER,
onnx_ops_counter,
)
stats.save_stat(
stats.save_model_stat(
fs.Keys.ONNX_MODEL_INFO,
onnx_model_info,
)
stats.save_stat(
stats.save_model_stat(
fs.Keys.ONNX_INPUT_DIMENSIONS,
onnx_input_dimensions,
)

if perf:
for key, value in vars(perf).items():
stats.add_build_stat(
stats.save_model_eval_stat(
key=key,
value=value,
)
Expand Down
2 changes: 1 addition & 1 deletion src/turnkeyml/analyze/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def print_invocation(
if unique_invocation.stats_keys is not None:
for key in unique_invocation.stats_keys:
nice_key = _pretty_print_key(key)
value = unique_invocation.stats.build_stats[key]
value = unique_invocation.stats.evaluation_stats[key]
printing.logn(f"{ident}\t\t\t{nice_key}:\t{value}")
print()
else:
Expand Down
30 changes: 20 additions & 10 deletions src/turnkeyml/build/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,10 @@ def fire(self, state: build.State):
if check_model(output_path, success_msg, fail_msg):
state.intermediate_results = [output_path]

stats = fs.Stats(state.cache_dir, state.config.build_name, state.stats_id)
stats.add_build_stat(
stats = fs.Stats(
state.cache_dir, state.config.build_name, state.evaluation_id
)
stats.save_model_eval_stat(
fs.Keys.ONNX_FILE,
output_path,
)
Expand Down Expand Up @@ -307,8 +309,10 @@ def fire(self, state: build.State):
if check_model(output_path, success_msg, fail_msg):
state.intermediate_results = [output_path]

stats = fs.Stats(state.cache_dir, state.config.build_name, state.stats_id)
stats.add_build_stat(
stats = fs.Stats(
state.cache_dir, state.config.build_name, state.evaluation_id
)
stats.save_model_eval_stat(
fs.Keys.ONNX_FILE,
output_path,
)
Expand Down Expand Up @@ -428,8 +432,10 @@ def fire(self, state: build.State):
if check_model(output_path, success_msg, fail_msg):
state.intermediate_results = [output_path]

stats = fs.Stats(state.cache_dir, state.config.build_name, state.stats_id)
stats.add_build_stat(
stats = fs.Stats(
state.cache_dir, state.config.build_name, state.evaluation_id
)
stats.save_model_eval_stat(
fs.Keys.ONNX_FILE,
output_path,
)
Expand Down Expand Up @@ -492,8 +498,10 @@ def fire(self, state: build.State):
if check_model(output_path, success_msg, fail_msg):
state.intermediate_results = [output_path]

stats = fs.Stats(state.cache_dir, state.config.build_name, state.stats_id)
stats.add_build_stat(
stats = fs.Stats(
state.cache_dir, state.config.build_name, state.evaluation_id
)
stats.save_model_eval_stat(
fs.Keys.ONNX_FILE,
output_path,
)
Expand Down Expand Up @@ -596,8 +604,10 @@ def fire(self, state: build.State):
if check_model(output_path, success_msg, fail_msg):
state.intermediate_results = [output_path]

stats = fs.Stats(state.cache_dir, state.config.build_name, state.stats_id)
stats.add_build_stat(
stats = fs.Stats(
state.cache_dir, state.config.build_name, state.evaluation_id
)
stats.save_model_eval_stat(
fs.Keys.ONNX_FILE,
output_path,
)
Expand Down
5 changes: 2 additions & 3 deletions src/turnkeyml/build/hummingbird.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,8 @@ def fire(self, state: build.State):
np.save(state.original_inputs_file, state.inputs)

state.intermediate_results = [output_path]
stats = fs.Stats(state.cache_dir, state.config.build_name)
stats.add_sub_stat(
state.stats_id,
stats = fs.Stats(state.cache_dir, state.config.build_name, state.evaluation_id)
stats.save_model_eval_stat(
fs.Keys.ONNX_FILE,
output_path,
)
Expand Down
4 changes: 2 additions & 2 deletions src/turnkeyml/build/ignition.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def _rebuild_if_needed(

def load_or_make_state(
config: build.Config,
stats_id: str,
evaluation_id: str,
cache_dir: str,
rebuild: str,
model_type: build.ModelType,
Expand All @@ -274,7 +274,7 @@ def load_or_make_state(
"inputs": inputs,
"monitor": monitor,
"rebuild": rebuild,
"stats_id": stats_id,
"evaluation_id": evaluation_id,
"cache_dir": cache_dir,
"config": config,
"model_type": model_type,
Expand Down
6 changes: 3 additions & 3 deletions src/turnkeyml/build/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,8 @@ def launch(self, state: build.State) -> build.State:
raise exp.Error(msg)

# Collect telemetry for the build
stats = fs.Stats(state.cache_dir, state.config.build_name, state.stats_id)
stats.add_build_stat(
stats = fs.Stats(state.cache_dir, state.config.build_name, state.evaluation_id)
stats.save_model_eval_stat(
fs.Keys.ALL_BUILD_STAGES,
self.get_names(),
)
Expand All @@ -292,7 +292,7 @@ def launch(self, state: build.State) -> build.State:
# Collect telemetry about the stage
execution_time = time.time() - start_time

stats.add_build_sub_stat(
stats.save_model_eval_sub_stat(
parent_key=fs.Keys.COMPLETED_BUILD_STAGES,
key=stage.unique_name,
value=execution_time,
Expand Down
6 changes: 3 additions & 3 deletions src/turnkeyml/build_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def build_model(
model: build.UnionValidModelInstanceTypes = None,
inputs: Optional[Dict[str, Any]] = None,
build_name: Optional[str] = None,
stats_id: Optional[str] = "build",
evaluation_id: Optional[str] = "build",
cache_dir: str = filesystem.DEFAULT_CACHE_DIR,
monitor: Optional[bool] = None,
rebuild: Optional[str] = None,
Expand All @@ -30,7 +30,7 @@ def build_model(
build_name: Unique name for the model that will be
used to store the ONNX file and build state on disk. Defaults to the
name of the file that calls build_model().
stats_id: Unique name for build statistics that should persist across multiple
evaluation_id: Unique name for evaluation statistics that should persist across multiple
builds of the same model.
cache_dir: Directory to use as the cache for this build. Output files
from this build will be stored at cache_dir/build_name/
Expand Down Expand Up @@ -94,7 +94,7 @@ def build_model(
# Get the state of the model from the cache if a valid build is available
state = ignition.load_or_make_state(
config=config,
stats_id=stats_id,
evaluation_id=evaluation_id,
cache_dir=parsed_cache_dir,
rebuild=rebuild or build.DEFAULT_REBUILD_POLICY,
model_type=model_type,
Expand Down
Loading

0 comments on commit f8f8093

Please sign in to comment.