Skip to content

Commit

Permalink
Utilise model metadata to resolve model output type and shape; this is currently only implemented for the single API (run) case
Browse files Browse the repository at this point in the history
  • Loading branch information
DhanshreeA committed Jul 3, 2024
1 parent 0de7021 commit ba81e82
Showing 1 changed file with 86 additions and 15 deletions.
101 changes: 86 additions & 15 deletions ersilia/hub/fetch/actions/sniff.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,12 @@
from ....io.input import ExampleGenerator
from ....io.pure import PureDataTyper
from ....io.annotated import AnnotatedDataTyper
from ....default import API_SCHEMA_FILE, MODEL_SIZE_FILE, METADATA_JSON_FILE
from ....default import API_SCHEMA_FILE, MODEL_SIZE_FILE, METADATA_JSON_FILE, PREDEFINED_EXAMPLE_FILENAME
from ....utils.exceptions_utils.exceptions import EmptyOutputError
from ....utils.exceptions_utils.fetch_exceptions import (
OutputDataTypesNotConsistentError,
)


N = 3

BUILTIN_EXAMPLE_FILE_NAME = "example.csv"


class BuiltinExampleReader(ErsiliaBase):
def __init__(self, model_id, config_json):
ErsiliaBase.__init__(self, config_json=config_json, credentials_json=None)
Expand All @@ -33,7 +27,7 @@ def __init__(self, model_id, config_json):
self.model_id,
"artifacts",
"framework",
BUILTIN_EXAMPLE_FILE_NAME,
PREDEFINED_EXAMPLE_FILENAME,
)

def has_builtin_example(self):
Expand All @@ -42,7 +36,7 @@ def has_builtin_example(self):
else:
return False

def example(self, n):
def example(self, n=3):
data = []
with open(self.example_file, "r") as f:
reader = csv.reader(f)
Expand All @@ -66,7 +60,7 @@ def __init__(self, model_id, config_json):
er = BuiltinExampleReader(model_id, config_json=config_json)
if er.has_builtin_example():
self.logger.debug("Built-in example found")
self.inputs = er.example(N)
self.inputs = er.example()
else:
self.logger.debug("No built-in example available. Generating a test one.")
eg = ExampleGenerator(model_id, config_json=config_json)
Expand Down Expand Up @@ -196,6 +190,59 @@ def _get_schema(self, results):
self.logger.debug("Schema: {0}".format(schema))
self.logger.debug("Done with the schema!")
return schema

@throw_ersilia_exception
def _get_schema_type_for_simple_run_api_case(self):
    """Resolve the schema output type from the model's metadata file.

    Reads the metadata JSON file located in the model directory and maps its
    "Output Type" and "Output Shape" entries onto a schema type name. This is
    a best-effort lookup: whenever the metadata is missing, incomplete or
    describes a combination that cannot be mapped, None is returned so the
    caller can fall back to other means.

    Returns:
        One of "numeric", "string", "numeric_array" or "string_array", or
        None when the type cannot be resolved from metadata.
    """
    # read metadata
    dest_dir = self._model_path(self.model_id)
    metadata_file = os.path.join(dest_dir, METADATA_JSON_FILE)
    if not os.path.exists(metadata_file):
        self.logger.debug("Metadata file not available (yet)")
        return None
    with open(metadata_file, "r") as f:
        metadata = json.load(f)
    if not isinstance(metadata, dict):
        return None

    # get output type from metadata.json; a missing or malformed entry is
    # treated like any other unresolvable case instead of raising KeyError
    output_types = metadata.get("Output Type")
    if not isinstance(output_types, (list, tuple)):
        return None
    if len(output_types) == 1:
        self.logger.debug("Output type is {0}".format(output_types[0]))
        output_type = output_types[0]
    elif len(output_types) == 2 and set(output_types) == {"Integer", "Float"}:
        # a mix of integers and floats is represented as floats
        output_type = "Float"
    else:
        return None
    if output_type not in ("Float", "String"):
        return None

    # get output shape from metadata.json
    output_shape = metadata.get("Output Shape")
    if output_shape not in ("Single", "List"):
        return None

    # both values are validated strings at this point, so the lookup is total
    schema_type_by_meta = {
        ("Single", "Float"): "numeric",
        ("Single", "String"): "string",
        ("List", "Float"): "numeric_array",
        ("List", "String"): "string_array",
    }
    return schema_type_by_meta[(output_shape, output_type)]

def _try_to_resolve_output_shape(self, meta, output_type):
    """Derive a one-dimensional output shape from the schema metadata.

    Args:
        meta: The "meta" entry of an output field in the schema.
        output_type: The schema type name of that output field.

    Returns:
        A 1-tuple ``(len(meta),)`` when ``output_type`` is "numeric_array" or
        "string_array" and ``meta`` is a list; None otherwise.
    """
    # only array-like outputs have a shape that can be derived from meta
    if output_type not in ("numeric_array", "string_array"):
        return None
    # isinstance (rather than an exact type check) also accepts list subclasses
    if isinstance(meta, list):
        return (len(meta),)
    return None

@throw_ersilia_exception
def sniff(self):
Expand All @@ -210,7 +257,11 @@ def sniff(self):
self.model.autoservice.serve()
self.logger.debug("Iterating over APIs")
all_schemas = {}
for api_name in self.model.autoservice.get_apis():

is_schema_done = False
api_names = self.model.autoservice.get_apis()

def get_results(api_name):
self.logger.debug("Running API: {0}".format(api_name))
self.logger.debug(self.inputs)
results = [
Expand All @@ -221,10 +272,30 @@ def sniff(self):
for r in results:
if not r["output"]:
raise EmptyOutputError(model_id=self.model_id, api_name=api_name)
self.logger.debug("Getting schema for API {0}...".format(api_name))
schema = self._get_schema(results)
self.logger.debug("This is the schema {0}".format(schema))
all_schemas[api_name] = schema
return results

# try to get schema without making calculations, just reading metadata and output file
# this only works when the 'run' API is the only one.
if len(api_names) == 1 and api_names[0] == "run":
schema_type_backup = self._get_schema_type_for_simple_run_api_case()
self.logger.debug("This is the schema {0}".format(schema_type_backup))

if not is_schema_done:
for api_name in self.model.autoservice.get_apis():
self.logger.debug("Getting schema for API {0}...".format(api_name))
results = get_results(api_name)
schema = self._get_schema(results)
self.logger.debug("This is the schema {0}".format(schema))
if api_name == "run":
if "outcome" in schema["output"]:
if schema["output"]["outcome"]["type"] is None:
schema["output"]["outcome"]["type"] = schema_type_backup
if "shape" not in schema["output"]["outcome"]:
shape = self._try_to_resolve_output_shape(schema["output"]["outcome"]["meta"], schema["output"]["outcome"]["type"])
if shape is not None:
schema["output"]["outcome"]["shape"] = shape
all_schemas[api_name] = schema

path = os.path.join(self._model_path(self.model_id), API_SCHEMA_FILE)
with open(path, "w") as f:
json.dump(all_schemas, f, indent=4)
Expand Down

0 comments on commit ba81e82

Please sign in to comment.