-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
21 lines (17 loc) · 691 Bytes
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch.nn.functional as F
from torch import nn
from transformers import BertModel
from configuration import Configuration
class PuncRec(nn.Module):
    """Token-level punctuation classifier on top of a pretrained BERT encoder.

    For every input token position the model emits one logit per entry of
    ``config.punctuation_names``.
    """

    def __init__(self, config: Configuration):
        """Build the encoder and classification head, then move them to ``config.device``.

        Args:
            config: Project configuration. Reads ``flavor`` (passed to
                ``BertModel.from_pretrained``), ``punctuation_names`` (its
                length sets the number of output classes) and ``device``.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained(config.flavor)
        # Size the head from the encoder's own config instead of hard-coding
        # 768, so BERT flavors with a different hidden size also work
        # (bert-base still gets 768 and keeps its original layer shapes).
        hidden_size = self.bert.config.hidden_size
        self.hl = nn.Linear(hidden_size, hidden_size)
        self.punc = nn.Linear(hidden_size, len(config.punctuation_names))
        self.dropout = nn.Dropout(0.2)
        self.to(config.device)

    def forward(self, x, attention_mask=None):
        """Return per-token punctuation logits.

        Args:
            x: Token ids for the encoder. Presumably a (batch, seq) LongTensor
                already on ``config.device``; TODO confirm against caller.
            attention_mask: Optional mask forwarded to BERT so that padding
                positions can be excluded from attention. ``None`` (the
                default) matches the original call ``self.bert(x)``.

        Returns:
            Output of ``self.punc``: one logit per punctuation class for each
            token position of ``last_hidden_state``.
        """
        output = self.bert(x, attention_mask=attention_mask)
        representations = self.dropout(F.gelu(output['last_hidden_state']))
        # NOTE(review): there is no non-linearity between ``hl`` and ``punc``,
        # so at inference (dropout off) the two linear layers compose into a
        # single linear map. Adding an activation would change the outputs of
        # already-trained checkpoints, so the original wiring is kept.
        punc = self.punc(self.dropout(self.hl(representations)))
        return punc