-
Notifications
You must be signed in to change notification settings - Fork 0
/
utilities.py
151 lines (117 loc) · 3.61 KB
/
utilities.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from abc import ABC, abstractmethod
import torch
import torch.nn as nn
import warnings
from torch.utils.data import Dataset
from abc import ABC, abstractmethod
import os
class AbstractNetwork(ABC, nn.Module):
    """Abstract ``nn.Module`` base class that carries a current-task marker.

    Concrete subclasses must implement :meth:`build_net`,
    :meth:`eval_forward` and :meth:`embedding`; the class cannot be
    instantiated until all three are provided.
    """

    def __init__(self, outputs):
        """Initialise the module and the task bookkeeping.

        Args:
            outputs: Stored unchanged as ``self.output_size``.
        """
        super().__init__()
        self.output_size = outputs
        # Backing field of the ``task`` property.
        self._task = 0
        # Starts empty; nothing in this class adds to it.
        self.used_tasks = set()

    @abstractmethod
    def build_net(self):
        """Subclass hook; no behaviour is defined here.

        NOTE(review): presumably constructs the layers -- confirm against
        the concrete subclasses.
        """

    @abstractmethod
    def eval_forward(self, x):
        """Subclass hook taking an input ``x``; no behaviour is defined here.

        NOTE(review): presumably the evaluation-time forward pass -- confirm
        against the concrete subclasses.
        """

    @abstractmethod
    def embedding(self, x):
        """Subclass hook taking an input ``x``; no behaviour is defined here.

        NOTE(review): presumably returns an embedding of ``x`` -- confirm
        against the concrete subclasses.
        """

    @property
    def task(self):
        """Return the currently selected task value (initially ``0``)."""
        return self._task

    @task.setter
    def task(self, value):
        """Store ``value`` as the current task.

        No validation or clamping against ``output_size`` is performed and
        ``used_tasks`` is not updated.
        """
        self._task = value
class GeneralDatasetLoader(ABC, Dataset):
    """Abstract ``Dataset`` base class with train/test phase and task tracking.

    The class keeps a current phase (``'train'`` or ``'test'``), a current
    task index, and a set of download-related fields. Subclasses must
    implement :meth:`getIterator` and :meth:`load_dataset`; ``__getitem__``,
    ``__len__`` and :meth:`download_dataset` raise ``NotImplementedError``
    until overridden.
    """

    def __init__(self, folder: str, transform=None, target_transform=None, *args, **kwargs):
        """Set every field to its initial value; no dataset is loaded here.

        Args:
            folder: Base folder; ``download_path`` is ``<folder>/download``.
            transform: Stored as ``self.transform``; not used by this class.
            target_transform: Stored as ``self.target_transform``; not used
                by this class.
            *args: Accepted and ignored.
            **kwargs: Accepted and ignored.
        """
        # Zero-argument super() runs the next initialiser in the MRO.
        # ``super(Dataset).__init__()`` only re-initialised a throwaway
        # unbound ``super`` object and never reached any parent class.
        super().__init__()

        self.folder = folder
        self.transform = transform
        self.target_transform = target_transform

        # Download fields. The None values are placeholders that this class
        # never fills in -- presumably subclasses set them (TODO confirm).
        self.url = None
        self.filename = None
        self.unzipped_folder = None
        self.download_path = os.path.join(folder, 'download')
        self.download = False
        self.force_download = False

        # Phase / task state.
        self._phase = 'train'
        self._current_task = 0
        self.train_split = 1.0
        self.task_manager = None

        # Dataset contents and task bookkeeping; all unset at this point.
        self.X = None
        self.Y = None
        self.class_to_idx = None
        self.idx_to_class = None
        self.task2idx = None
        self._n_tasks = None

    def train_phase(self):
        """Switch the phase marker to ``'train'``."""
        self._phase = 'train'

    def test_phase(self):
        """Switch the phase marker to ``'test'``."""
        self._phase = 'test'

    def next_task(self, round_robin=False):
        """Advance the current task index by one.

        Args:
            round_robin: When true, the index wraps around modulo the number
                of tasks and the call always succeeds.

        Returns:
            ``True`` if the index was advanced (or wrapped). ``False`` when
            ``round_robin`` is false and the index would pass the last task;
            in that case a warning is issued and the index stays on the last
            task.
        """
        self._current_task += 1
        if round_robin:
            self._current_task %= self._n_tasks
            return True

        last_task = self._n_tasks - 1
        if self._current_task > last_task:
            warnings.warn("No more tasks...")
            self._current_task = last_task
            return False
        return True

    def reset(self):
        """Return to the ``'train'`` phase and to task ``0``."""
        self._phase = 'train'
        self._current_task = 0

    @property
    def tasks_number(self):
        """Return the stored number of tasks (``None`` until it is set)."""
        return self._n_tasks

    @property
    def phase(self):
        """Return the current phase marker (``'train'`` or ``'test'``)."""
        return self._phase

    @property
    def task(self):
        """Return the current task index."""
        return self._current_task

    @task.setter
    def task(self, value):
        """Store ``value`` as the current task index without range checks."""
        self._current_task = value

    def task_mask(self, task=None):
        """Return the key of ``task2idx`` at position ``task``.

        Args:
            task: Position in the key order of ``task2idx``; defaults to the
                current task index when ``None``.

        Returns:
            The ``task``-th key of ``self.task2idx``.
        """
        if task is None:
            task = self._current_task
        return list(self.task2idx)[task]

    @abstractmethod
    def getIterator(self, batch_size, task=None):
        """Subclass hook; raises ``NotImplementedError`` if called directly.

        NOTE(review): presumably returns a batch iterator for ``task`` --
        confirm against the concrete subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def load_dataset(self):
        """Subclass hook; raises ``NotImplementedError`` if called directly."""
        raise NotImplementedError

    def __getitem__(self, index):
        """Not implemented in the base class."""
        raise NotImplementedError

    def __len__(self):
        """Not implemented in the base class."""
        raise NotImplementedError

    def download_dataset(self):
        """Not implemented in the base class."""
        raise NotImplementedError

    def already_downloaded(self):
        """Report whether ``download_path`` exists and contains any entry.

        The path is printed to stdout, and the directory is created when it
        is missing.

        Returns:
            ``False`` if the directory had to be created or is empty,
            ``True`` otherwise.
        """
        print(self.download_path)
        if not os.path.exists(self.download_path):
            # exist_ok avoids a crash if the directory appears between the
            # existence check and the creation.
            os.makedirs(self.download_path, exist_ok=True)
            return False
        return bool(os.listdir(self.download_path))