Skip to content

Commit

Permalink
Utilise model metadata to resolve model output type and shape
Browse files Browse the repository at this point in the history
This is currently only implemented for the single API (run) (#1184)
  • Loading branch information
DhanshreeA authored Jul 3, 2024
1 parent 0de7021 commit debe65e
Showing 1 changed file with 86 additions and 15 deletions.
101 changes: 86 additions & 15 deletions ersilia/hub/fetch/actions/sniff.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,12 @@
from ....io.input import ExampleGenerator
from ....io.pure import PureDataTyper
from ....io.annotated import AnnotatedDataTyper
from ....default import API_SCHEMA_FILE, MODEL_SIZE_FILE, METADATA_JSON_FILE
from ....default import API_SCHEMA_FILE, MODEL_SIZE_FILE, METADATA_JSON_FILE, PREDEFINED_EXAMPLE_FILENAME
from ....utils.exceptions_utils.exceptions import EmptyOutputError
from ....utils.exceptions_utils.fetch_exceptions import (
OutputDataTypesNotConsistentError,
)


N = 3

BUILTIN_EXAMPLE_FILE_NAME = "example.csv"


class BuiltinExampleReader(ErsiliaBase):
def __init__(self, model_id, config_json):
ErsiliaBase.__init__(self, config_json=config_json, credentials_json=None)
Expand All @@ -33,7 +27,7 @@ def __init__(self, model_id, config_json):
self.model_id,
"artifacts",
"framework",
BUILTIN_EXAMPLE_FILE_NAME,
PREDEFINED_EXAMPLE_FILENAME,
)

def has_builtin_example(self):
Expand All @@ -42,7 +36,7 @@ def has_builtin_example(self):
else:
return False

def example(self, n):
def example(self, n=3):
data = []
with open(self.example_file, "r") as f:
reader = csv.reader(f)
Expand All @@ -66,7 +60,7 @@ def __init__(self, model_id, config_json):
er = BuiltinExampleReader(model_id, config_json=config_json)
if er.has_builtin_example():
self.logger.debug("Built-in example found")
self.inputs = er.example(N)
self.inputs = er.example()
else:
self.logger.debug("No built-in example available. Generating a test one.")
eg = ExampleGenerator(model_id, config_json=config_json)
Expand Down Expand Up @@ -196,6 +190,59 @@ def _get_schema(self, results):
self.logger.debug("Schema: {0}".format(schema))
self.logger.debug("Done with the schema!")
return schema

@throw_ersilia_exception
def _get_schema_type_for_simple_run_api_case(self):

# read metadata
dest_dir = self._model_path(self.model_id)
metadata_file = os.path.join(dest_dir, METADATA_JSON_FILE)
if not os.path.exists(metadata_file):
self.logger.debug("Metadata file not available (yet)")
return None
with open(metadata_file, "r") as f:
metadata = json.load(f)

# get output type from metadata.json
output_type = metadata["Output Type"]
if len(output_type) == 1:
self.logger.debug("Output type is {0}".format(output_type[0]))
output_type = output_type[0]
elif len(output_type) == 2:
if set(output_type) == set(["Integer", "Float"]):
output_type = "Float"
else:
return None
else:
return None
if output_type not in ["Float", "String"]:
return None

# get output shape from metadata.json
output_shape = metadata["Output Shape"]
if output_shape not in ["Single", "List"]:
return None

def resolve_output_meta_in_schema(output_type, output_shape):
if output_shape == "Single" and output_type == "Float":
return "numeric"
if output_shape == "Single" and output_type == "String":
return "string"
if output_shape == "List" and output_type == "Float":
return "numeric_array"
if output_shape == "List" and output_type == "String":
return "string_array"

output_meta_in_schema = resolve_output_meta_in_schema(output_type, output_shape)

return output_meta_in_schema

def _try_to_resolve_output_shape(self, meta, output_type):
if output_type == "numeric_array" or output_type == "string_array":
if type(meta) is list:
shape = (len(meta),)
return shape
return None

@throw_ersilia_exception
def sniff(self):
Expand All @@ -210,7 +257,11 @@ def sniff(self):
self.model.autoservice.serve()
self.logger.debug("Iterating over APIs")
all_schemas = {}
for api_name in self.model.autoservice.get_apis():

is_schema_done = False
api_names = self.model.autoservice.get_apis()

def get_results(api_name):
self.logger.debug("Running API: {0}".format(api_name))
self.logger.debug(self.inputs)
results = [
Expand All @@ -221,10 +272,30 @@ def sniff(self):
for r in results:
if not r["output"]:
raise EmptyOutputError(model_id=self.model_id, api_name=api_name)
self.logger.debug("Getting schema for API {0}...".format(api_name))
schema = self._get_schema(results)
self.logger.debug("This is the schema {0}".format(schema))
all_schemas[api_name] = schema
return results

# try to get schema without making calculations, just reading metadata and output file
# this only works when the 'run' API is the only one.
if len(api_names) == 1 and api_names[0] == "run":
schema_type_backup = self._get_schema_type_for_simple_run_api_case()
self.logger.debug("This is the schema {0}".format(schema_type_backup))

if not is_schema_done:
for api_name in self.model.autoservice.get_apis():
self.logger.debug("Getting schema for API {0}...".format(api_name))
results = get_results(api_name)
schema = self._get_schema(results)
self.logger.debug("This is the schema {0}".format(schema))
if api_name == "run":
if "outcome" in schema["output"]:
if schema["output"]["outcome"]["type"] is None:
schema["output"]["outcome"]["type"] = schema_type_backup
if "shape" not in schema["output"]["outcome"]:
shape = self._try_to_resolve_output_shape(schema["output"]["outcome"]["meta"], schema["output"]["outcome"]["type"])
if shape is not None:
schema["output"]["outcome"]["shape"] = shape
all_schemas[api_name] = schema

path = os.path.join(self._model_path(self.model_id), API_SCHEMA_FILE)
with open(path, "w") as f:
json.dump(all_schemas, f, indent=4)
Expand Down

0 comments on commit debe65e

Please sign in to comment.