-
Notifications
You must be signed in to change notification settings - Fork 0
/
models.py
27 lines (22 loc) · 791 Bytes
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
from torch import nn
import transformers
from transformers import BertModel
class RetrieverEncoder(nn.Module):
    """Bi-encoder retriever holding separate passage and query BERT encoders.

    Both encoders are loaded by two separate ``BertModel.from_pretrained``
    calls on the same checkpoint, so they are distinct module instances
    (presumably so the two towers can be fine-tuned independently).

    Args:
        pretrained_name: Model id or path given to
            ``BertModel.from_pretrained`` for both encoders. Defaults to
            ``"skt/kobert-base-v1"``, the checkpoint this class previously
            hard-coded, so existing ``RetrieverEncoder()`` calls are unchanged.
    """

    def __init__(self, pretrained_name: str = "skt/kobert-base-v1"):
        super().__init__()
        # Two independent loads -> two separate encoder instances.
        self.passage_encoder = BertModel.from_pretrained(pretrained_name)
        self.query_encoder = BertModel.from_pretrained(pretrained_name)

    def forward(self, input_ids: torch.Tensor, att_mask: torch.Tensor, type: str):
        """Encode a batch with either the passage or the query encoder.

        Args:
            input_ids: Token ids, forwarded to the encoder as ``input_ids``.
            att_mask: Attention mask, forwarded as ``attention_mask``.
            type: ``"passage"`` selects the passage encoder; every other
                value selects the query encoder. The name shadows the
                builtin ``type`` but is kept for callers passing it by
                keyword.

        Returns:
            The selected encoder output's ``pooler_output`` (presumably
            shaped (batch, hidden_size) per the BertModel docs — confirm).
        """
        # NOTE(review): any value other than "passage" (including typos such
        # as "passages") silently uses the query encoder; consider validating
        # once all callers' values are known.
        encoder = self.passage_encoder if type == "passage" else self.query_encoder
        return encoder(input_ids=input_ids, attention_mask=att_mask).pooler_output
# Lookup table from a string model key to the nn.Module class it names.
MODEL_DICT = dict(retriever=RetrieverEncoder)