forked from andrewtwilson/multi-armed-bandit
-
Notifications
You must be signed in to change notification settings - Fork 1
/
multi-armed_bandit.cpp
143 lines (115 loc) · 3.13 KB
/
multi-armed_bandit.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#include <iostream>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <vector>
#include "data_to_plot.h"
// This is a test line.
using namespace std;
// One slot machine. Its payouts are spread symmetrically around a fixed
// centre value supplied at construction time.
class armedbandit
{
    int max_reward; // centre of the payout range used by onepull()

public:
    // Remembers `maxreward` as the value every payout is offset from.
    armedbandit(int maxreward) : max_reward(maxreward) {}

    // Simulates a single pull of the lever.
    //
    // Two numbers are drawn from rand(), in this order:
    //   1. a "coin" in [0, 1] that selects the direction of the offset
    //      (above 0.5 -> upwards, otherwise downwards);
    //   2. a fraction in [0, 1] that scales the offset up to 100.
    //
    // Returns max_reward plus/minus that offset, i.e. a value in
    // [max_reward - 100, max_reward + 100].
    double onepull()
    {
        const double coin = static_cast<double>(rand()) / RAND_MAX;
        const double swing = (coin > 0.5) ? 100 : -100;
        const double fraction = static_cast<double>(rand()) / RAND_MAX;
        return fraction * swing + max_reward;
    }
};
// Epsilon-greedy learner that keeps one value estimate per slot machine.
// All state is public; the member functions are defined out of line.
class Qlearner
{
public:
    std::vector<double> q;  // value estimate per machine; filled by init_q()
    int index_of_best{-1};  // last index found by greedy_action(); -1 until it has run
    int times{5000};        // number of episodes
    int action{0};          // index of the machine chosen most recently
    int machines{5};        // number of slot machines; other code may assume this
                            // count (e.g. callers building the bandits) - check
                            // every use before changing it
    double alpha{0.1};      // learning rate
    double reward{0.0};     // R: reward observed for the most recent pull
    double gamma{0.0};      // discount factor
    double epsilon{0.1};    // chance of picking a random machine

    void init_q();          // fill q with the starting estimates
    int decide();           // choose a machine (random or greedy)
    int greedy_action();    // index of the machine with the highest estimate
    int rand_action();      // index of a randomly chosen machine
    void update_q();        // fold `reward` into the estimate for `action`
};
// (Re)builds the value table with one entry per machine, each starting at 100.
//
// The table is replaced rather than appended to, so calling this more than
// once resets the estimates instead of leaving q with more entries than
// there are machines.
void Qlearner::init_q()
{
    // A non-positive machine count yields an empty table; the guard also
    // keeps a negative int from being converted into a huge unsigned size.
    const size_t count = machines > 0 ? static_cast<size_t>(machines) : 0;
    q.assign(count, 100.0); // starting q-value for each slot machine
}
// Chooses the machine to pull next using epsilon-greedy selection.
//
// A number in [0, 1] is drawn from rand(); if it is below `epsilon` a random
// machine is taken (rand_action), otherwise the machine with the highest
// current estimate (greedy_action). The choice is stored in `action` and
// returned.
int Qlearner::decide()
{
    // The selection is made exactly once per call. Repeating it `times` times
    // would only discard every draw but the last, since a single action is
    // returned, and would make a full run quadratic in the episode count.
    const double roll = static_cast<double>(rand()) / RAND_MAX;
    action = (roll < epsilon) ? rand_action() : greedy_action();
    return action;
}
// Finds the machine with the highest current estimate.
//
// Scans the first `machines` entries of q and returns the index of the
// largest one; on ties the lowest index wins. The result is also stored in
// `index_of_best`.
//
// If `machines` is not positive there is nothing to compare and the current
// `index_of_best` is returned unchanged. q.at() throws std::out_of_range if
// q holds fewer than `machines` entries (e.g. init_q() was never called).
int Qlearner::greedy_action()
{
    if (machines <= 0)
        return index_of_best;

    // Seed the search with the first entry instead of a magic "very small"
    // constant, so the result is correct for any range of estimates.
    int best = 0;
    double best_value = q.at(0);
    for (int i = 1; i < machines; ++i)
    {
        const double candidate = q.at(i);
        if (candidate > best_value)
        {
            best = i;
            best_value = candidate;
        }
    }

    index_of_best = best;
    return index_of_best;
}
// Picks a machine at random: the next rand() value reduced modulo `machines`.
// `machines` must be non-zero, since it is used as the divisor.
int Qlearner::rand_action()
{
    return rand() % machines;
}
// Folds the latest `reward` into the estimate of the machine in `action`
// and logs the result.
//
// Output: one line naming the chosen bandit (1-based), then one line with
// every current q-value separated by single spaces.
// q.at() throws std::out_of_range if `action` is not a valid index into q.
void Qlearner::update_q()
{
    cout << "Qlearner chose bandit # " << action + 1 << endl;

    // Move the estimate a fraction `alpha` of the way towards the target.
    // NOTE(review): `gamma` is added to the reward as-is rather than scaling
    // an estimate of future value; this is harmless while gamma is 0.0 but
    // looks unintended - confirm before using a non-zero discount factor.
    double& estimate = q.at(action);
    estimate = estimate + alpha * (reward + gamma - estimate);

    // Print every estimate, however many machines there are, instead of
    // assuming exactly five entries.
    bool first = true;
    for (const double value : q)
    {
        if (!first)
            cout << " ";
        cout << value;
        first = false;
    }
    cout << endl;
}
int main()
{
srand(time(NULL));
statistics_library lib;
lib.test_stats();
cout << "COMPLETE!!" << endl;
return 0;
}
/*armedbandit bandit1(100), bandit2(200), bandit3(100), bandit4(150), bandit5(220); // must have the number of different bandits = "machines" variable
Qlearner yup;
vector<armedbandit> casino; // vector to store our slot machines
casino.push_back(bandit1);
casino.push_back(bandit2);
casino.push_back(bandit3);
casino.push_back(bandit4);
casino.push_back(bandit5);
//for (int j = 0; j < yup.machines; j++)
//{
// casino.push_back(0);
//}
yup.init_q();
for (int i = 0; i < yup.times; i++)
{
int action = yup.decide();
yup.reward = casino.at(action).onepull();
yup.update_q();
}
*/
//return 0;
//}