
Commit

Refactor tracking to support a single SMILES string input
Malikbadmus committed Jul 1, 2024
1 parent fece745 commit e161461
Showing 3 changed files with 76 additions and 29 deletions.
14 changes: 7 additions & 7 deletions ersilia/core/model.py
@@ -496,13 +496,13 @@ def run(
result = self._run(
input=input, output=output, batch_size=batch_size, track_run=track_run
)
# Start tracking model run if track flag is used in serve
if self._run_tracker is not None and track_run:
self._run_tracker.track(
input=input, result=result, meta=self._model_info
)
self._run_tracker.log(result=result, meta=self._model_info)
return result
# Start tracking model run if track flag is used in serve
if self._run_tracker is not None and track_run:
self._run_tracker.track(
input=input, result=result, meta=self._model_info
)
self._run_tracker.log(result=result, meta=self._model_info)
return result

@property
def paths(self):
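For context, a minimal usage sketch of the refactored run path: with this change, input may be a CSV file path or a single SMILES string, and the tracker only fires when the track flag was passed at serve time. This is a hedged sketch, not part of the commit; the top-level import, the model identifier, and the output path are illustrative assumptions.

# Hedged sketch: assumed import path and illustrative identifiers.
from ersilia import ErsiliaModel

model = ErsiliaModel("eos3b5e")  # illustrative model identifier
model.serve()
result = model.run(
    input="CCO",               # a single SMILES string instead of a CSV path
    output="predictions.csv",  # illustrative output path
    batch_size=100,
    track_run=True,            # the run is tracked only when this flag is set
)
model.close()
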
59 changes: 37 additions & 22 deletions ersilia/core/tracking.py
@@ -20,11 +20,28 @@
from collections import defaultdict
from ..default import EOS, ERSILIA_RUNS_FOLDER
from ..utils.docker import SimpleDocker
from ..utils.csvfile import CsvDataLoader
from ..io.output_logger import TabularResultLogger
from botocore.exceptions import ClientError, NoCredentialsError




def flatten_dict(data):
"""
This will flatten the nested dictionaries from the generator into a single-level dictionary,
where keys from all levels are merged into one dictionary.
:flat_dict: Result returned in a dictionary
"""
flat_dict = {}
for outer_key, inner_dict in data.items():
for inner_key, value in inner_dict.items():
flat_dict[inner_key] = value
return flat_dict



def log_files_metrics(file_log, model_id):
"""
This function will log the number of errors and warnings in the log files.
@@ -279,19 +296,6 @@ def upload_to_cddvault(output_df, api_key):
return False


def read_csv(file_path):
"""
Reads a CSV file and returns the data as a list of dictionaries.
:param file_path: Path to the CSV file.
:return: A list of dictionaries containing the CSV data.
"""
with open(file_path, mode="r") as file:
reader = csv.DictReader(file)
data = [row for row in reader]
return data


def get_nan_counts(data_list):
"""
Calculates the number of None values in each key of a list of dictionaries.
@@ -535,8 +539,14 @@ def track(self, input, result, meta):
"""
self.time_start = datetime.now()
self.docker_client = SimpleDocker()
self.data = CsvDataLoader()
json_dict = {}
input_data = read_csv(input)

if os.path.isfile(input):
input_data = self.data.read(input)
else:
input_data = [{"SMILES": input}]

# Create a temporary file to store the result if it is a generator
if isinstance(result, types.GeneratorType):

@@ -549,15 +559,20 @@ def track(self, input, result, meta):
temp_output_file = tempfile.NamedTemporaryFile(
delete=False, suffix=".csv", dir=tmp_dir
)
temp_output_path = temp_output_file.name

flat_data_list = [flatten_dict(row) for row in result]
if flat_data_list:
header = list(flat_data_list[0].keys())
temp_output_path = temp_output_file.name
with open(temp_output_path, "w", newline="") as csvfile:
csvWriter = csv.writer(csvfile)
for row in result:
csvWriter.writerow(row)
result_data = read_csv(temp_output_path)
os.remove(temp_output_path)
csvWriter = csv.DictWriter(csvfile, fieldnames=header)
csvWriter.writeheader()
for flat_data in flat_data_list:
csvWriter.writerow(flat_data)
result_data = self.data.read(temp_output_path)
os.remove(temp_output_path)
else:
result_data = read_csv(result)
result_data = self.data.read(result)

session = Session(config_json=self.config_json)
model_id = meta["metadata"].get("Identifier", "Unknown")
@@ -600,4 +615,4 @@ def track(self, input, result, meta):
def log(self, result, meta):
self.log_result(result)
self.log_meta(meta)
self.log_logs()
self.log_logs()
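
Taken together, the tracking changes do two things: a bare SMILES string is now accepted where a CSV path used to be required, and generator results are flattened with flatten_dict before being written through csv.DictWriter. Below is a small self-contained sketch of both behaviors, using made-up row data and an illustrative SMILES value rather than anything from the commit.

import csv
import os
import tempfile


def flatten_dict(data):
    # Merge the keys of every inner dictionary into one flat dictionary,
    # mirroring the helper added in this diff.
    flat_dict = {}
    for outer_key, inner_dict in data.items():
        for inner_key, value in inner_dict.items():
            flat_dict[inner_key] = value
    return flat_dict


# Normalize the run input: an existing file path is read as CSV, anything
# else is treated as a single SMILES string (the value below is illustrative).
user_input = "CC(=O)OC1=CC=CC=C1C(=O)O"
if os.path.isfile(user_input):
    with open(user_input, mode="r") as fh:
        input_data = list(csv.DictReader(fh))
else:
    input_data = [{"SMILES": user_input}]

# Flatten nested result rows and persist them with csv.DictWriter, as track() does.
rows = [{"input": {"key": "key-0", "input": user_input}, "output": {"value": 180.16}}]
flat_rows = [flatten_dict(row) for row in rows]
with tempfile.NamedTemporaryFile(
    mode="w", suffix=".csv", delete=False, newline=""
) as tmp:
    writer = csv.DictWriter(tmp, fieldnames=list(flat_rows[0].keys()))
    writer.writeheader()
    writer.writerows(flat_rows)
print(input_data, tmp.name)
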
32 changes: 32 additions & 0 deletions ersilia/utils/csvfile.py
@@ -1,4 +1,6 @@
import os
import csv
import json


class CsvDataLoader(object):
@@ -21,3 +23,33 @@ def load(self, csv_file):
self.keys += [r[0]]
self.inputs += [r[1]]
self.values += [r[-len(self.features) :]]



def _read_csv_tsv(self, file_path, delimiter):
with open(file_path, mode='r') as file:
reader = csv.DictReader(file, delimiter=delimiter)
data = [row for row in reader]
return data

def _read_json(self, file_path):
with open(file_path, mode="r") as file:
return json.load(file)


def read(self, file_path):
"""
Reads a file and returns the data as a list of dictionaries.
:param file_path: Path to the CSV file.
:return: A list of dictionaries containing the CSV data.
"""

file_extension = os.path.splitext(file_path)[1].lower()
if file_extension == '.json':
return self._read_json(file_path)
elif file_extension in ['.csv', '.tsv']:
delimiter = '\t' if file_extension == '.tsv' else ','
return self._read_csv_tsv(file_path, delimiter)
else:
raise ValueError("Unsupported file format")
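
A brief usage sketch of the new read dispatcher follows; the file names are placeholders, and the behavior mirrors the extension check above (JSON via json.load, CSV and TSV via csv.DictReader, anything else raises ValueError).

from ersilia.utils.csvfile import CsvDataLoader

loader = CsvDataLoader()

# Each call returns a list of dictionaries keyed by column name (or the parsed JSON).
csv_rows = loader.read("predictions.csv")    # comma-delimited
tsv_rows = loader.read("predictions.tsv")    # tab-delimited
json_rows = loader.read("predictions.json")  # parsed with json.load

try:
    loader.read("predictions.xlsx")
except ValueError as err:
    print(err)  # Unsupported file format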
