-
Notifications
You must be signed in to change notification settings - Fork 1
/
class_qbr.py
79 lines (60 loc) · 2.2 KB
/
class_qbr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# -*- coding: utf-8 -*-
import numpy as np
from collections import deque
from class_sota import Baseline
###########################################################################################
# Queue-Based Resampling (QBR) #
###########################################################################################
class QBR(Baseline):
    """Queue-Based Resampling (QBR) learner.

    Keeps two FIFO queues of recent examples: one for examples whose label
    equals 0 ("neg") and one for every other label ("pos").  Each class queue
    starts with room for a single example and grows by one slot every time it
    becomes full, until its capacity reaches half of the total queue-size
    budget (integer division, never below the initial capacity of 1).  From
    then on, a new example of that class evicts the oldest one.  The training
    set returned by ``get_training_set`` is the union of both queues.
    """

    ###############
    # Constructor #
    ###############
    def __init__(self, model, queue_size_budget):
        """Create a QBR learner.

        :param model: model object handed to ``Baseline.__init__``.
            ``get_training_set`` later calls ``self.model.change_minibatch_size``
            (presumably ``Baseline`` stores the model as ``self.model`` --
            confirm in class_sota).
        :param queue_size_budget: total number of examples the two class
            queues may hold together; each queue grows up to half of it.
        """
        Baseline.__init__(self, model)
        # budget (shared by both class queues)
        self.budget = queue_size_budget
        # init queues: every class starts with room for exactly one example
        self.xs_neg = deque(maxlen=1)
        self.ys_neg = deque(maxlen=1)
        self.xs_pos = deque(maxlen=1)
        self.ys_pos = deque(maxlen=1)

    #############
    # Auxiliary #
    #############
    def adapt_queue(self, q, q_cap):
        """Rebuild the x/y queues of one class with a new maximum length.

        Current contents are preserved; if ``q_cap`` were smaller than the
        current length, ``deque`` would keep only the ``q_cap`` most recent
        items.

        :param q: ``'neg'`` or ``'pos'``; any other value is ignored.
        :param q_cap: new ``maxlen`` for both the x and the y queue.
        """
        if q == 'neg':
            self.xs_neg = deque(self.xs_neg, q_cap)
            self.ys_neg = deque(self.ys_neg, q_cap)
        elif q == 'pos':
            self.xs_pos = deque(self.xs_pos, q_cap)
            self.ys_pos = deque(self.ys_pos, q_cap)

    #######
    # API #
    #######
    def get_training_set(self, n_features):
        """Return all queued examples as one batch and set full-batch GD.

        :param n_features: number of features per example; used to reshape
            the stacked inputs.
        :returns: ``(x, y)`` numpy arrays of shapes ``(size, n_features)``
            and ``(size, 1)``, where ``size`` is the number of queued
            examples.  Negative-queue examples come first, then positives.
        """
        # merge queues (negatives first, then positives)
        xs = list(self.xs_neg) + list(self.xs_pos)
        ys = list(self.ys_neg) + list(self.ys_pos)
        # convert merged queues to np arrays
        size = len(ys)  # current number of queued examples
        x = np.array(xs).reshape(size, n_features)
        y = np.array(ys).reshape(size, 1)
        # batch GD: one minibatch covers the whole merged queue
        # NOTE(review): size is 0 if nothing was appended yet -- confirm the
        # model tolerates a zero minibatch size in that case.
        self.model.change_minibatch_size(size)
        return x, y

    def append_to_queues(self, x, y):
        """Enqueue one example in its class queue, growing the queue if full.

        :param x: feature vector of the example (stored as given).
        :param y: label; ``y == 0`` goes to the negative queue, anything else
            to the positive queue.
        """
        if y == 0:
            q, xs, ys = 'neg', self.xs_neg, self.ys_neg
        else:
            q, xs, ys = 'pos', self.xs_pos, self.ys_pos
        xs.append(x)
        ys.append(y)
        # Once this queue is full, add one slot so the next example of this
        # class does not evict an older one -- until the per-class cap is
        # reached.  Integer division keeps the two queues together within the
        # budget: with true division an odd budget 2k+1 let each queue grow to
        # k+1, i.e. 2k+2 examples in total.
        capacity = ys.maxlen
        if len(ys) == capacity and capacity < self.budget // 2:
            self.adapt_queue(q, capacity + 1)