diff --git a/api.env.sample b/api.env.sample
index a74b151..3d3ae09 100644
--- a/api.env.sample
+++ b/api.env.sample
@@ -3,4 +3,5 @@ REDIS_PORT=6379
 TDS_URL=http://data-service.staging.terarium.ai
 SKEMA_RS_URL=http://skema-rs.staging.terarium.ai
 TA1_UNIFIED_URL=http://skema-unified.staging.terarium.ai
-MIT_TR_URL=http://mit-tr.staging.terarium.ai
\ No newline at end of file
+MIT_TR_URL=http://mit-tr.staging.terarium.ai
+LOG_LEVEL=INFO
\ No newline at end of file
diff --git a/api/server.py b/api/server.py
index 0960414..c4abd34 100644
--- a/api/server.py
+++ b/api/server.py
@@ -176,7 +176,7 @@ def profile_dataset(dataset_id: str, artifact_id: Optional[str] = None):
     """
     from utils import create_job
 
-    operation_name = "operations.dataset_profiling_with_document"
+    operation_name = "operations.dataset_card"
 
     options = {
         "dataset_id": dataset_id,
@@ -187,6 +187,32 @@ def profile_dataset(dataset_id: str, artifact_id: Optional[str] = None):
 
     return resp
 
 
+@app.post("/profile_model/{model_id}")
+def profile_model(model_id: str, paper_artifact_id: str):
+    """Profile a model with MIT's profiling service. This takes in a paper artifact and
+    the model's associated code artifact and updates the model (AMR) with the profiled
+    metadata card. It requires that the paper has been extracted with `/pdf_to_text`
+    and that the code has been converted to an AMR with `/code_to_amr`.
+
+    > NOTE: if the paper has not been extracted and the model was not created from code, this WILL fail.
+
+    Args:
+        model_id: the id of the model to profile
+        paper_artifact_id: the id of the paper artifact
+    """
+    from utils import create_job
+
+    operation_name = "operations.model_card"
+
+    options = {
+        "model_id": model_id,
+        "paper_artifact_id": paper_artifact_id
+    }
+
+    resp = create_job(operation_name=operation_name, options=options)
+
+    return resp
+
 @app.post("/link_amr")
 def link_amr(artifact_id: str, model_id: str):
diff --git a/api/utils.py b/api/utils.py
index 3f4ca1a..7c27f61 100644
--- a/api/utils.py
+++ b/api/utils.py
@@ -18,8 +18,14 @@
 from rq.exceptions import NoSuchJobError
 from rq.job import Job
 
+LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()  # default to INFO if not set
+
+numeric_level = getattr(logging, LOG_LEVEL, None)
+if not isinstance(numeric_level, int):
+    raise ValueError(f'Invalid log level: {LOG_LEVEL}')
+
 logging.basicConfig()
-logging.getLogger().setLevel(logging.DEBUG)
+logging.getLogger().setLevel(numeric_level)
 
 
 # REDIS CONNECTION AND QUEUE OBJECTS
diff --git a/docker-compose.yaml b/docker-compose.yaml
index 47f2f70..81ec980 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -6,7 +6,7 @@ networks:
   data-api:
     external: true
 services:
-  api:
+  extraction-api:
     container_name: api-ta1-extraction-service
     build:
       context: ./
diff --git a/workers/operations.py b/workers/operations.py
index f31230c..f14bfc4 100644
--- a/workers/operations.py
+++ b/workers/operations.py
@@ -11,17 +11,26 @@
     put_artifact_extraction_to_tds,
     get_artifact_from_tds,
     get_dataset_from_tds,
+    get_model_from_tds,
+    set_provenance,
+    find_source_code
 )
 
 TDS_API = os.getenv("TDS_URL")
 SKEMA_API = os.getenv("SKEMA_RS_URL")
 UNIFIED_API = os.getenv("TA1_UNIFIED_URL")
 MIT_API = os.getenv("MIT_TR_URL")
+LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()  # default to INFO if not set
 
 import logging
 
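+# validate LOG_LEVEL up front so a bad value fails fast at startup instead of being silently ignored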
+numeric_level = getattr(logging, LOG_LEVEL, None)
+if not isinstance(numeric_level, int):
+    raise ValueError(f'Invalid log level: {LOG_LEVEL}')
+
 logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
+logger.setLevel(numeric_level)
 handler = logging.StreamHandler()
 formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
 handler.setFormatter(formatter)
@@ -59,6 +67,7 @@ def equations_to_amr(*args, **kwargs):
     )
     try:
         amr_json = amr_response.json()
+        logger.debug(f"TA1 response object: {amr_response}")
     except:
         logger.error(f"Failed to parse response from TA1 Service: {amr_response.text}")
 
@@ -112,6 +121,7 @@ def pdf_to_text(*args, **kwargs):
         f"Response received from TA1 with status code: {response.status_code}"
     )
     extraction_json = response.json()
+    logger.debug(f"TA1 response object: {extraction_json}")
     text = ''
     for d in extraction_json:
         text += f"{d['content']}\n"
@@ -180,6 +190,7 @@ def pdf_extractions(*args, **kwargs):
         f"Response received from TA1 with status code: {response.status_code}"
    )
     extraction_json = response.json()
+    logger.debug(f"TA1 response object: {extraction_json}")
     outputs = extraction_json["outputs"]
 
     if isinstance(outputs, dict):
@@ -233,7 +244,7 @@
     return response
 
 
-def dataset_profiling_with_document(*args, **kwargs):
+def dataset_card(*args, **kwargs):
     openai_key = os.getenv("OPENAI_API_KEY")
 
     dataset_id = kwargs.get("dataset_id")
@@ -266,7 +277,7 @@
     resp = requests.post(f"{MIT_API}/cards/get_data_card", params=params, files=files)
     logger.info(f"Response received from MIT with status: {resp.status_code}")
-    logger.debug(f"MIT ANNOTATIONS: {resp.json()}")
+    logger.debug(f"TA1 response object: {resp.json()}")
 
     mit_annotations = resp.json()['DATA_PROFILING_RESULT']
 
 
@@ -301,6 +312,77 @@
     return resp.json()
 
 
+def model_card(*args, **kwargs):
+    openai_key = os.getenv("OPENAI_API_KEY")
+    model_id = kwargs.get("model_id")
+    paper_artifact_id = kwargs.get("paper_artifact_id")
+
+    try:
+        code_artifact_id = find_source_code(model_id)
+        if code_artifact_id:
+            code_artifact_json, code_downloaded_artifact = get_artifact_from_tds(artifact_id=code_artifact_id)
+            code_file = code_downloaded_artifact.decode('utf-8')
+        else:
+            logger.info(f"No associated code artifact found for model {model_id}")
+            code_file = "No available code associated with model."
+    except Exception as e:
+        logger.error(f"Issue finding associated source code: {e}")
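+        # fall back to a placeholder so the MIT request still includes a code_file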
+        code_file = "No available code associated with model."
+
+    logger.debug(f"Code file head (250 chars): {code_file[:250]}")
+
+    paper_artifact_json, paper_downloaded_artifact = get_artifact_from_tds(artifact_id=paper_artifact_id)
+    text_file = paper_artifact_json['metadata'].get('text', 'There is no documentation for this model').encode()
+
+    amr = get_model_from_tds(model_id).json()
+
+    params = {
+        'gpt_key': openai_key
+    }
+
+    files = {
+        'text_file': ('text_file', text_file),
+        'code_file': ('code_file', code_file)
+    }
+
+    logger.info(f"Sending model {model_id} to MIT service")
+
+    resp = requests.post(f"{MIT_API}/cards/get_model_card", params=params, files=files)
+    logger.info(f"Response received from MIT with status: {resp.status_code}")
+    logger.debug(f"TA1 response object: {resp.json()}")
+
+    if resp.status_code == 200:
+        try:
+            card = resp.json()
+
+            amr['description'] = card.get('DESCRIPTION')
+            if not amr.get('metadata', None):
+                amr['metadata'] = {'card': card}
+            else:
+                amr['metadata']['card'] = card
+
+            tds_resp = requests.put(f"{TDS_API}/models/{model_id}", json=amr)
+            if tds_resp.status_code == 200:
+                logger.info(f"Updated model {model_id} in TDS: {tds_resp.status_code}")
+                return {
+                    "status": tds_resp.status_code,
+                    "message": "Model card generated and updated in TDS",
+                }
+            else:
+                raise Exception(f"Error when updating model {model_id} in TDS: {tds_resp.status_code}")
+        except Exception as e:
+            logger.error(f"Failed to generate model card for {model_id}: {e}")
+            return {
+                "status": 500,
+                "message": f"Error: {e}",
+            }
+    else:
+        logger.error(f"Bad response from TA1 for {model_id}: {resp.status_code}")
+        return {
+            "status": resp.status_code,
+            "message": f"Error: {resp.text}",
+        }
 
 
 # dccde3a0-0132-430c-afd8-c67953298f48
 # 77a2dffb-08b3-4f6e-bfe5-83d27ed259c4
@@ -338,6 +420,7 @@ def link_amr(*args, **kwargs):
     skema_amr_linking_url = f"{UNIFIED_API}/metal/link_amr"
 
     response = requests.post(skema_amr_linking_url, files=files, params=params)
+    logger.debug(f"TA1 response object: {response.json()}")
 
     if response.status_code == 200:
         enriched_amr = response.json()
@@ -373,6 +456,7 @@ def code_to_amr(*args, **kwargs):
     artifact_json, downloaded_artifact = get_artifact_from_tds(artifact_id=artifact_id)
 
     code_blob = downloaded_artifact.decode("utf-8")
+    logger.debug(f"Code blob head (250 chars): {code_blob[:250]}")
     code_amr_workflow_url = f"{UNIFIED_API}/workflows/code/snippets-to-pn-amr"
 
     request_payload = {
@@ -390,12 +474,26 @@ def code_to_amr(*args, **kwargs):
 
     try:
         amr_json = amr_response.json()
+        logger.debug(f"TA1 response object: {amr_json}")
     except:
         logger.error(f"Failed to parse response from TA1 Service:\n{amr_response.text}")
 
     if amr_response.status_code == 200 and amr_json:
         tds_responses = put_amr_to_tds(amr_json, name, description)
 
+        put_artifact_extraction_to_tds(
+            artifact_id=artifact_id,
+            name=artifact_json.get("name", None),
+            filename=artifact_json.get("file_names")[0],
+            description=artifact_json.get("description", None),
+            model_id=tds_responses.get("model_id")
+        )
+
+        try:
+            set_provenance(tds_responses.get("model_id"), 'Model', artifact_id, 'Artifact', 'EXTRACTED_FROM')
+        except Exception as e:
+            logger.error(f"Failed to store provenance tying model to code artifact: {e}")
+
     response = {
         "status_code": amr_response.status_code,
         "amr": amr_json,
diff --git a/workers/utils.py b/workers/utils.py
index b74546e..4dbe79c 100644
--- a/workers/utils.py
+++ b/workers/utils.py
@@ -6,8 +6,15 @@
 import pandas
 
 
+LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()  # default to INFO if not set
+
 import logging
 
+numeric_level = getattr(logging, LOG_LEVEL, None)
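+# reject unknown LOG_LEVEL names at import time instead of logging at an unintended level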
+if not isinstance(numeric_level, int):
+    raise ValueError(f'Invalid log level: {LOG_LEVEL}')
+
 logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
+logger.setLevel(numeric_level)
 handler = logging.StreamHandler()
@@ -60,7 +66,7 @@ def put_amr_to_tds(amr_payload, name=None, description=None):
 
 
 def put_artifact_extraction_to_tds(
-    artifact_id, name, description, filename, extractions=None, text=None
+    artifact_id, name, description, filename, extractions=None, text=None, model_id=None
 ):
     if extractions and text:
         metadata = extractions[0]
@@ -69,6 +75,8 @@ def put_artifact_extraction_to_tds(
         metadata = extractions[0]
     elif text:
         metadata = {'text': text}
+    elif model_id:
+        metadata = {'model_id': model_id}
     else:
         metadata = {}
 
@@ -146,3 +154,56 @@ def get_dataset_from_tds(dataset_id):
     csv_string = final_df.to_csv(index=False)
 
     return dataset, final_df, csv_string
+
+def get_model_from_tds(model_id):
+    tds_model_url = f"{TDS_API}/models/{model_id}"
+    model = requests.get(tds_model_url)
+    return model
+
+
+def set_provenance(
+    left_id, left_type, right_id, right_type, relation_type
+):
+    """
+    Creates a provenance record in TDS. Used during code-to-model extraction to
+    associate the code artifact with the model AMR.
+    """
+
+    provenance_payload = {
+        "relation_type": relation_type,
+        "left": left_id,
+        "left_type": left_type,
+        "right": right_id,
+        "right_type": right_type
+    }
+
+    # Create TDS provenance
+    tds_provenance = f"{TDS_API}/provenance"
+    provenance_resp = requests.post(tds_provenance, json=provenance_payload)
+    if provenance_resp.status_code == 200:
+        logger.info(f"Stored provenance to TDS for left {left_id} and right {right_id}")
+    else:
+        logger.error(f"Storing provenance failed: {provenance_resp.text}")
+
+    return {"status": provenance_resp.status_code}
+
+def find_source_code(
+    model_id
+):
+    """
+    For a given model ID, finds the associated source code artifact from which it was extracted.
+    """
+
+    payload = {
+        "root_id": model_id,
+        "root_type": "Model"
+    }
+
+    tds_provenance = f"{TDS_API}/provenance/search?search_type=models_from_code"
+    resp = requests.post(tds_provenance, json=payload)
+    logger.info(f"Provenance code lookup for model ID {model_id}: {resp.json()}")
+    results = resp.json().get('result', [])
+    if len(results) > 0:
+        return results[0]
+    else:
+        return None
\ No newline at end of file