diff --git a/comet/models/predict_writer.py b/comet/models/predict_writer.py index 7adde1c..7d95989 100644 --- a/comet/models/predict_writer.py +++ b/comet/models/predict_writer.py @@ -86,7 +86,7 @@ def flatten(list): def flatten_predictions(predictions): flatten_pred = Prediction( - scores=torch.cat([pred.scores for pred in predictions], dim=0) + scores=torch.cat([pred["scores"] for pred in predictions], dim=0) ) if "metadata" in predictions[0]: flatten_pred["metadata"] = flatten_metadata(