-
Notifications
You must be signed in to change notification settings - Fork 3
/
Batch.py
49 lines (38 loc) · 1.23 KB
/
Batch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from dataclasses import dataclass
from typing import Optional

import torch
@dataclass
class Batch:
    """A group of tensors that travel together, tagged with ``batch_id``.

    After construction, ``batch_size`` holds ``inputs.shape[0]``.

    ``to``, ``clone`` and ``clip`` never modify the instance they are called
    on. Each builds a new ``Batch`` carrying the same ``batch_id``, except
    ``clip(None)``, which returns the instance itself.
    """

    batch_id: int
    inputs: torch.Tensor
    labels: torch.Tensor
    # For PIPA experiment we use this field to store identity label.
    aux: Optional[torch.Tensor] = None

    def __post_init__(self):
        # Set as a plain attribute rather than a dataclass field, so it is
        # not part of the generated __init__, __repr__ or __eq__.
        self.batch_size = self.inputs.shape[0]

    def _apply(self, fn):
        """Return a new ``Batch`` with ``fn`` applied to every tensor held.

        ``aux`` is optional; when it is ``None`` it is passed through as
        ``None`` instead of being handed to ``fn``.
        """
        inputs = fn(self.inputs)
        labels = fn(self.labels)
        aux = fn(self.aux) if self.aux is not None else None
        return Batch(self.batch_id, inputs, labels, aux)

    def to(self, device):
        """Return a new ``Batch`` whose tensors are the result of ``.to(device)``."""
        return self._apply(lambda tensor: tensor.to(device))

    def clone(self):
        """Return a new ``Batch`` whose tensors are ``.clone()`` copies."""
        return self._apply(lambda tensor: tensor.clone())

    def clip(self, batch_size):
        """Return a ``Batch`` limited to the first ``batch_size`` entries.

        Each tensor is sliced as ``tensor[:batch_size]``; no explicit copy is
        made. When ``batch_size`` is ``None`` the instance itself is returned.
        """
        if batch_size is None:
            return self
        return self._apply(lambda tensor: tensor[:batch_size])