Skip to content

Commit

Permalink
Merge pull request #775 from DhanshreeA/master
Browse files Browse the repository at this point in the history
Correctly handle Pair of Lists example generation
  • Loading branch information
GemmaTuron committed Aug 15, 2023
2 parents 37d16ff + d13fe0a commit 03b90ef
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions ersilia/io/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def __init__(self, model_id, config_json=None):
if type(self.input_shape) is InputShapeList:
self._flatten = self._flatten_list
if type(self.input_shape) is InputShapePairOfLists:
self._flatten = self._flatten_single
self._flatten = self._flatten_pair_of_lists

@throw_ersilia_exception
def check_model_id(self, model_id):
Expand All @@ -204,10 +204,10 @@ def _get_delimiter(file_name):
return ","

def _flatten_single(self, datum):
return datum
return [datum]

def _flatten_list(self, datum):
return self._string_delimiter.join(datum)
return [self._string_delimiter.join(datum)]

def _flatten_pair_of_lists(self, datum):
return [
Expand Down Expand Up @@ -242,7 +242,7 @@ def example(self, n_samples, file_name, simple):
writer = csv.writer(f, delimiter=delimiter)
if simple:
for v in self.IO.example(n_samples):
writer.writerow([self._flatten(v["input"])])
writer.writerow(self._flatten(v["input"]))
else:
writer.writerow(["key", "input", "text"])
for v in self.IO.example(n_samples):
Expand Down

0 comments on commit 03b90ef

Please sign in to comment.