diff --git a/FlagEmbedding/flag_models.py b/FlagEmbedding/flag_models.py index 121f6533..74932bb1 100644 --- a/FlagEmbedding/flag_models.py +++ b/FlagEmbedding/flag_models.py @@ -106,7 +106,7 @@ def encode(self, if convert_to_numpy: all_embeddings = np.concatenate(all_embeddings, axis=0) else: - all_embeddings = torch.stack(all_embeddings) + all_embeddings = torch.concatenate(all_embeddings) if input_was_string: return all_embeddings[0]