Model card #7

Merged 5 commits on Aug 3, 2023

3 changes: 2 additions & 1 deletion api.env.sample
@@ -3,4 +3,5 @@ REDIS_PORT=6379
TDS_URL=http://data-service.staging.terarium.ai
SKEMA_RS_URL=http://skema-rs.staging.terarium.ai
TA1_UNIFIED_URL=http://skema-unified.staging.terarium.ai
MIT_TR_URL=http://mit-tr.staging.terarium.ai
MIT_TR_URL=http://mit-tr.staging.terarium.ai
LOG_LEVEL=INFO
28 changes: 27 additions & 1 deletion api/server.py
@@ -176,7 +176,7 @@ def profile_dataset(dataset_id: str, artifact_id: Optional[str] = None):
"""
from utils import create_job

operation_name = "operations.dataset_profiling_with_document"
operation_name = "operations.dataset_card"

options = {
"dataset_id": dataset_id,
@@ -187,6 +187,32 @@ def profile_dataset(dataset_id: str, artifact_id: Optional[str] = None):

    return resp

@app.post("/profile_model/{model_id}")
def profile_model(model_id: str, paper_artifact_id: str):
"""Profile model with MIT's profiling service. This takes in a paper and code artifact
and updates a model (AMR) with the profiled metadata card. It requires that the paper
has been extracted with `/pdf_to_text` and the code has been converted to an AMR
with `/code_to_amr`

> NOTE: if nothing the paper is not extracted and the model not created from code this WILL fail.

Args:
model_id: the id of the model to profile
paper_artifact_id: the id of the paper artifact
"""
from utils import create_job

operation_name = "operations.model_card"

options = {
"model_id": model_id,
"paper_artifact_id": paper_artifact_id
}

resp = create_job(operation_name=operation_name, options=options)

return resp


@app.post("/link_amr")
def link_amr(artifact_id: str, model_id: str):
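
To smoke-test the new endpoint, a request along these lines should work; the base URL and IDs are hypothetical placeholders, and `paper_artifact_id` is sent as a query parameter to match the FastAPI signature above:

```python
import requests

# Hypothetical values; substitute your deployment's host and real TDS IDs.
BASE_URL = "http://localhost:8000"
model_id = "example-model-id"
paper_artifact_id = "example-paper-artifact-id"

# paper_artifact_id is a bare (non-path) parameter, so FastAPI expects it
# in the query string.
resp = requests.post(
    f"{BASE_URL}/profile_model/{model_id}",
    params={"paper_artifact_id": paper_artifact_id},
)
print(resp.json())  # job metadata returned by create_job
```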
8 changes: 7 additions & 1 deletion api/utils.py
@@ -18,8 +18,14 @@
from rq.exceptions import NoSuchJobError
from rq.job import Job

LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper() # default to INFO if not set

numeric_level = getattr(logging, LOG_LEVEL, None)
if not isinstance(numeric_level, int):
    raise ValueError(f'Invalid log level: {LOG_LEVEL}')

logging.basicConfig()
logging.getLogger().setLevel(logging.DEBUG)
logging.getLogger().setLevel(numeric_level)


# REDIS CONNECTION AND QUEUE OBJECTS
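
As a sanity check of the new level handling, this standalone sketch (illustrative values only) shows both the happy path and the failure mode:

```python
import logging
import os

os.environ["LOG_LEVEL"] = "debug"
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()  # "DEBUG"
print(getattr(logging, LOG_LEVEL, None))            # 10, i.e. logging.DEBUG

os.environ["LOG_LEVEL"] = "verbose"                 # not a real logging level
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()  # "VERBOSE"
numeric_level = getattr(logging, LOG_LEVEL, None)   # None
if not isinstance(numeric_level, int):
    raise ValueError(f"Invalid log level: {LOG_LEVEL}")  # raised at import time
```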
2 changes: 1 addition & 1 deletion docker-compose.yaml
@@ -6,7 +6,7 @@ networks:
  data-api:
    external: true
services:
  api:
  extraction-api:
    container_name: api-ta1-extraction-service
    build:
      context: ./
105 changes: 102 additions & 3 deletions workers/operations.py
@@ -11,17 +11,25 @@
    put_artifact_extraction_to_tds,
    get_artifact_from_tds,
    get_dataset_from_tds,
    get_model_from_tds,
    set_provenance,
    find_source_code
)

TDS_API = os.getenv("TDS_URL")
SKEMA_API = os.getenv("SKEMA_RS_URL")
UNIFIED_API = os.getenv("TA1_UNIFIED_URL")
MIT_API = os.getenv("MIT_TR_URL")
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper() # default to INFO if not set

import logging

numeric_level = getattr(logging, LOG_LEVEL, None)
if not isinstance(numeric_level, int):
    raise ValueError(f'Invalid log level: {LOG_LEVEL}')

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.setLevel(numeric_level)
handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
@@ -59,6 +67,7 @@ def equations_to_amr(*args, **kwargs):
    )
    try:
        amr_json = amr_response.json()
        logger.debug(f"TA 1 response object: {amr_response}")
    except:
        logger.error(f"Failed to parse response from TA1 Service: {amr_response.text}")

@@ -112,6 +121,7 @@ def pdf_to_text(*args, **kwargs):
        f"Response received from TA1 with status code: {response.status_code}"
    )
    extraction_json = response.json()
    logger.debug(f"TA 1 response object: {extraction_json}")
    text = ''
    for d in extraction_json:
        text += f"{d['content']}\n"
@@ -180,6 +190,7 @@ def pdf_extractions(*args, **kwargs):
        f"Response received from TA1 with status code: {response.status_code}"
    )
    extraction_json = response.json()
    logger.debug(f"TA 1 response object: {extraction_json}")
    outputs = extraction_json["outputs"]

    if isinstance(outputs, dict):
Expand Down Expand Up @@ -233,7 +244,7 @@ def pdf_extractions(*args, **kwargs):
return response


def dataset_profiling_with_document(*args, **kwargs):
def data_card(*args, **kwargs):
openai_key = os.getenv("OPENAI_API_KEY")

dataset_id = kwargs.get("dataset_id")
@@ -266,7 +277,7 @@ def dataset_profiling_with_document(*args, **kwargs):
    resp = requests.post(f"{MIT_API}/cards/get_data_card", params=params, files=files)

    logger.info(f"Response received from MIT with status: {resp.status_code}")
    logger.debug(f"MIT ANNOTATIONS: {resp.json()}")
    logger.debug(f"TA 1 response object: {resp.json()}")

    mit_annotations = resp.json()['DATA_PROFILING_RESULT']

@@ -301,6 +312,77 @@

    return resp.json()

def model_card(*args, **kwargs):
    openai_key = os.getenv("OPENAI_API_KEY")
    model_id = kwargs.get("model_id")
    paper_artifact_id = kwargs.get("paper_artifact_id")

    try:
        code_artifact_id = find_source_code(model_id)
        if code_artifact_id:
            code_artifact_json, code_downloaded_artifact = get_artifact_from_tds(artifact_id=code_artifact_id)
            code_file = code_downloaded_artifact.decode('utf-8')
        else:
            logger.info(f"No associated code artifact found for model {model_id}")
            code_file = "No available code associated with model."
    except Exception as e:
        logger.error(f"Issue finding associated source code: {e}")
        code_file = "No available code associated with model."

    logger.debug(f"Code file head (250 chars): {code_file[:250]}")

    paper_artifact_json, paper_downloaded_artifact = get_artifact_from_tds(artifact_id=paper_artifact_id)
    text_file = paper_artifact_json['metadata'].get('text', 'There is no documentation for this model').encode()

    amr = get_model_from_tds(model_id).json()

    params = {
        'gpt_key': openai_key
    }

    files = {
        'text_file': ('text_file', text_file),
        'code_file': ('doc_file', code_file)
    }

    logger.info(f"Sending model {model_id} to MIT service")

    resp = requests.post(f"{MIT_API}/cards/get_model_card", params=params, files=files)
    logger.info(f"Response received from MIT with status: {resp.status_code}")
    logger.debug(f"TA 1 response object: {resp.json()}")

    if resp.status_code == 200:
        try:
            card = resp.json()
            sys.stdout.flush()

            amr['description'] = card.get('DESCRIPTION')
            if not amr.get('metadata', None):
                amr['metadata'] = {'card': card}
            else:
                amr['metadata']['card'] = card

            tds_resp = requests.put(f"{TDS_API}/models/{model_id}", json=amr)
            if tds_resp.status_code == 200:
                logger.info(f"Updated model {model_id} in TDS: {tds_resp.status_code}")
                return {
                    "status": tds_resp.status_code,
                    "message": "Model card generated and updated in TDS",
                }
            else:
                raise Exception(f"Error when updating model {model_id} in TDS: {tds_resp.status_code}")
        except Exception as e:
            logger.error(f"Failed to generate model card for {model_id}: {e}")
            return {
                "status": 500,
                "message": f"Error: {e}",
            }
    else:
        logger.error(f"Bad response from TA1 for {model_id}: {resp.status_code}")
        return {
            "status": resp.status_code,
            "message": f"Error: {resp.text}",
        }

# dccde3a0-0132-430c-afd8-c67953298f48
# 77a2dffb-08b3-4f6e-bfe5-83d27ed259c4
@@ -338,6 +420,7 @@ def link_amr(*args, **kwargs):
    skema_amr_linking_url = f"{UNIFIED_API}/metal/link_amr"

    response = requests.post(skema_amr_linking_url, files=files, params=params)
    logger.debug(f"TA 1 response object: {response.json()}")

    if response.status_code == 200:
        enriched_amr = response.json()
@@ -373,6 +456,7 @@ def code_to_amr(*args, **kwargs):
    artifact_json, downloaded_artifact = get_artifact_from_tds(artifact_id=artifact_id)

    code_blob = downloaded_artifact.decode("utf-8")
    logger.info(code_blob[:250])
    code_amr_workflow_url = f"{UNIFIED_API}/workflows/code/snippets-to-pn-amr"

    request_payload = {
@@ -390,12 +474,27 @@

    try:
        amr_json = amr_response.json()
        logger.debug(f"TA 1 response object: {amr_json}")
    except:
        logger.error(f"Failed to parse response from TA1 Service:\n{amr_response.text}")
        pass

    if amr_response.status_code == 200 and amr_json:
        tds_responses = put_amr_to_tds(amr_json, name, description)

        put_artifact_extraction_to_tds(
            artifact_id=artifact_id,
            name=artifact_json.get("name", None),
            filename=artifact_json.get("file_names")[0],
            description=artifact_json.get("description", None),
            model_id=tds_responses.get("model_id")
        )

        try:
            set_provenance(tds_responses.get("model_id"), 'Model', artifact_id, 'Artifact', 'EXTRACTED_FROM')
        except Exception as e:
            logger.error(f"Failed to store provenance tying model to code artifact: {e}")

        response = {
            "status_code": amr_response.status_code,
            "amr": amr_json,
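
For local testing, the worker operation can also be invoked directly, bypassing Redis/RQ. A minimal sketch, assuming OPENAI_API_KEY, TDS_URL, and MIT_TR_URL are set in the environment and using hypothetical IDs:

```python
from operations import model_card

# Hypothetical IDs; in practice these come from TDS. Requires a reachable
# MIT profiling service and TDS instance.
result = model_card(
    model_id="example-model-id",
    paper_artifact_id="example-paper-artifact-id",
)
# On success: {"status": 200, "message": "Model card generated and updated in TDS"}
print(result)
```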
63 changes: 62 additions & 1 deletion workers/utils.py
@@ -6,8 +6,14 @@

import pandas

LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper() # default to INFO if not set

import logging

numeric_level = getattr(logging, LOG_LEVEL, None)
if not isinstance(numeric_level, int):
    raise ValueError(f'Invalid log level: {LOG_LEVEL}')

logger = logging.getLogger(__name__)
logger.setLevel(numeric_level)
handler = logging.StreamHandler()
@@ -60,7 +66,7 @@ def put_amr_to_tds(amr_payload, name=None, description=None):


def put_artifact_extraction_to_tds(
    artifact_id, name, description, filename, extractions=None, text=None
    artifact_id, name, description, filename, extractions=None, text=None, model_id=None
):
    if extractions and text:
        metadata = extractions[0]
@@ -69,6 +75,8 @@
        metadata = extractions[0]
    elif text:
        metadata = {'text': text}
    elif model_id:
        metadata = {'model_id': model_id}
    else:
        metadata = {}
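
The new `model_id` branch lets a code artifact carry a back-pointer to the model extracted from it. A minimal sketch of that path, with hypothetical values:

```python
from utils import put_artifact_extraction_to_tds

# With no extractions or text supplied, the new elif stores
# {"model_id": ...} as the artifact metadata (values are hypothetical).
put_artifact_extraction_to_tds(
    artifact_id="example-code-artifact-id",
    name="model.py",
    description="Source code the example model was extracted from",
    filename="model.py",
    model_id="example-model-id",
)
```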

@@ -146,3 +154,56 @@ def get_dataset_from_tds(dataset_id):
    csv_string = final_df.to_csv(index=False)

    return dataset, final_df, csv_string


def get_model_from_tds(model_id):
    tds_model_url = f"{TDS_API}/models/{model_id}"
    model = requests.get(tds_model_url)
    return model


def set_provenance(
    left_id, left_type, right_id, right_type, relation_type
):
    """
    Creates a provenance record in TDS. Used during code-to-AMR extraction to associate
    the code artifact with the model AMR.
    """

    provenance_payload = {
        "relation_type": relation_type,
        "left": left_id,
        "left_type": left_type,
        "right": right_id,
        "right_type": right_type
    }

    # Create TDS provenance
    tds_provenance = f"{TDS_API}/provenance"
    provenance_resp = requests.post(tds_provenance, json=provenance_payload)
    if provenance_resp.status_code == 200:
        logger.info(f"Stored provenance to TDS for left {left_id} and right {right_id}")
    else:
        logger.error(f"Storing provenance failed: {provenance_resp.text}")

    return {"status": provenance_resp.status_code}

def find_source_code(
    model_id
):
    """
    For a given model id, finds the associated source code artifact from which it was extracted.
    """

    payload = {
        "root_id": model_id,
        "root_type": "Model"
    }

    tds_provenance = f"{TDS_API}/provenance/search?search_type=models_from_code"
    resp = requests.post(tds_provenance, json=payload)
    logger.info(f"Provenance code lookup for model ID {model_id}: {resp.json()}")
    results = resp.json().get('result', [])
    if len(results) > 0:
        return results[0]
    else:
        return None
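
Taken together, the two helpers form a provenance round trip. A minimal sketch, assuming TDS is reachable and using hypothetical IDs:

```python
from utils import find_source_code, set_provenance

# Hypothetical IDs for illustration.
model_id = "example-model-id"
artifact_id = "example-code-artifact-id"

# Record that the model was extracted from the code artifact...
set_provenance(model_id, "Model", artifact_id, "Artifact", "EXTRACTED_FROM")

# ...then recover the code artifact later from the model alone.
assert find_source_code(model_id) == artifact_id
```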