From 41f84e8d08874e874d74adc77968e1c685f9bf58 Mon Sep 17 00:00:00 2001 From: Riccardo Orlando Date: Thu, 11 Jul 2024 09:35:12 +0000 Subject: [PATCH] Fix error when missing keys in dataset --- goldenretriever/data/datasets.py | 10 ++++++++-- goldenretriever/version.py | 2 +- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/goldenretriever/data/datasets.py b/goldenretriever/data/datasets.py index 1dd7e64..ea49e80 100644 --- a/goldenretriever/data/datasets.py +++ b/goldenretriever/data/datasets.py @@ -494,11 +494,17 @@ def load_fn( if max_positives != -1: positives = positives[:max_positives] - negatives = list(set([n["text"] for n in sample["negative_ctxs"]])) + if "negative_ctxs" in sample: + negatives = list(set([n["text"] for n in sample["negative_ctxs"]])) + else: + negatives = [] if max_negatives != -1: negatives = negatives[:max_negatives] - hard_negatives = list(set([h["text"] for h in sample["hard_negative_ctxs"]])) + if "hard_negative_ctxs" in sample: + hard_negatives = list(set([h["text"] for h in sample["hard_negative_ctxs"]])) + else: + hard_negatives = [] if max_hard_negatives != -1: hard_negatives = hard_negatives[:max_hard_negatives] diff --git a/goldenretriever/version.py b/goldenretriever/version.py index 000a79a..ac0ebe7 100644 --- a/goldenretriever/version.py +++ b/goldenretriever/version.py @@ -4,7 +4,7 @@ _MINOR = "9" # On main and in a nightly release the patch should be one ahead of the last # released build. -_PATCH = "1" +_PATCH = "2" # This is mainly for nightly builds which have the suffix ".dev$DATE". See # https://semver.org/#is-v123-a-semantic-version for the semantics. _SUFFIX = os.environ.get("GOLDENRETRIEVER_VERSION_SUFFIX", "")