"""
Created on Wednesday Jan 16 2019
@author: Seyed Mohammad Asghari
@github: https://github.com/s3yyy3d-m
"""
import random

from sum_tree import SumTree as ST


class Memory(object):
    e = 0.05  # small constant added to each error so no transition gets zero priority

    def __init__(self, capacity, pr_scale):
        self.capacity = capacity
        self.memory = ST(self.capacity)
        self.pr_scale = pr_scale  # exponent controlling how strongly priorities skew sampling
        self.max_pr = 0

    def get_priority(self, error):
        return (error + self.e) ** self.pr_scale

    def remember(self, sample, error):
        # New samples are stored with the running maximum priority so they are
        # guaranteed a chance of being sampled at least once.
        p = self.get_priority(error)
        self.max_pr = max(self.max_pr, p)
        self.memory.add(self.max_pr, sample)

    def sample(self, n):
        sample_batch = []
        sample_batch_indices = []
        sample_batch_priorities = []
        # Split the total priority mass into n equal segments and draw one
        # sample uniformly from each, so the batch covers the priority range.
        segment_length = self.memory.total() / n
        for i in range(n):
            left = segment_length * i
            right = segment_length * (i + 1)
            s = random.uniform(left, right)
            idx, pr, data = self.memory.get(s)
            sample_batch.append((idx, data))
            sample_batch_indices.append(idx)
            sample_batch_priorities.append(pr)
        return [sample_batch, sample_batch_indices, sample_batch_priorities]

    def update(self, batch_indices, errors):
        # Refresh stored priorities from the newly computed errors.
        for i in range(len(batch_indices)):
            p = self.get_priority(errors[i])
            self.memory.update(batch_indices[i], p)
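
# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal outline of the intended workflow, assuming the SumTree imported
# above exposes add/get/update/total as used by Memory. The transition tuple
# layout (state, action, reward, next_state, done) is a common DQN convention
# and an assumption here, not something this file prescribes.
#
#   memory = Memory(capacity=2 ** 14, pr_scale=0.6)
#   memory.remember((state, action, reward, next_state, done), error=1.0)
#   batch, indices, priorities = memory.sample(32)
#   # ... compute new TD errors for the sampled transitions ...
#   memory.update(indices, new_errors)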