main.py
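"""Train a MatryoshkaAdaptor for all-MiniLM-L6-v2 embeddings on the MTEB HotpotQA dataset."""
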
import logging

import datasets
import ptvsd
import torch
from sentence_transformers import SentenceTransformer

from matryoshka_adaptor.model import MatryoshkaAdaptor
from matryoshka_adaptor.train import training_loop

logger = logging.getLogger(__name__)
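
# Set debug_mode to True to block start-up until a ptvsd debugger attaches on port 5678.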
debug_mode = False
if debug_mode:
    print("Waiting for debugger to attach...")
    ptvsd.enable_attach(address=("localhost", 5678))
    ptvsd.wait_for_attach()
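
# Model, dataset, and training hyperparameters.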
embedding_dim_full = 384
model_string = "sentence-transformers/all-MiniLM-L6-v2"
dataset = "mteb/hotpotqa"
n_epochs_unsupervised = 0
n_epochs_supervised = 6
k = 20  # defined here but not passed to the training_loop call below
learning_rate = 1e-5
query_batch_size = 800
n_non_matching_augmentations = 4
corpus_batch_size = 256
encoding_batch_size = 512
unsupervised_learning_model_file = "adaptor_hotpotqa_unsupervised.pt"
supervised_learning_model_file = "adaptor_hotpotqa_supervised.pt"

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # HotpotQA ships as separate corpus, query, and query-document association splits.
    corpus = datasets.load_dataset(dataset, "corpus", split="corpus")
    queries = datasets.load_dataset(dataset, "queries", split="queries")
    associations = datasets.load_dataset(dataset, split="train")

    original_model = SentenceTransformer(model_string)
    adaptor = MatryoshkaAdaptor(embedding_dim_full)

    # Resume training from the previously saved supervised checkpoint.
    adaptor.load_state_dict(
        torch.load(supervised_learning_model_file, map_location=device)
    )
    adaptor.to(device)

    training_loop(
        corpus=corpus,
        queries=queries,
        associations=associations,
        original_model=original_model,
        adaptor_model=adaptor,
        n_epochs_unsupervised=n_epochs_unsupervised,
        n_epochs_supervised=n_epochs_supervised,
        query_batch_size=query_batch_size,
        corpus_batch_size=corpus_batch_size,
        n_non_matching_augmentations=n_non_matching_augmentations,
        encoding_batch_size=encoding_batch_size,
        unsupervised_learning_model_file=unsupervised_learning_model_file,
        supervised_learning_model_file=supervised_learning_model_file,
        learning_rate_unsupervised=learning_rate,
        learning_rate_supervised=learning_rate / 20,
        gamma=10,
    )