From 0ceb02b1522afe5c1a483781ac83cc480d7856a2 Mon Sep 17 00:00:00 2001 From: Andrea Vallebueno Date: Sat, 9 Nov 2024 18:17:22 -0800 Subject: [PATCH] Account for un-approx variances --- glove_v/variance.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/glove_v/variance.py b/glove_v/variance.py index cc889da..1df8403 100644 --- a/glove_v/variance.py +++ b/glove_v/variance.py @@ -40,7 +40,7 @@ def load_variance( diagonal = f.get_tensor(f"diag_{word_idx}") return np.diag(diagonal) - # Otherwise, it must be an SVD approximation + # SVD approximation elif f"U_{word_idx}" in f.keys(): U = f.get_tensor(f"U_{word_idx}") s = f.get_tensor(f"s_{word_idx}") @@ -49,6 +49,11 @@ def load_variance( # Reconstruct using SVD components: U * diag(s) * Vt return U @ np.diag(s) @ Vt + # Complete approximation (SVD unavailable) + elif f"complete_{word_idx}" in f.keys(): + complete = f.get_tensor(f"complete_{word_idx}") + return complete + else: raise KeyError( f"[ERROR No approximation found for word index {word_idx}"