-
Notifications
You must be signed in to change notification settings - Fork 0
/
projectMain.m
119 lines (92 loc) · 3.5 KB
/
projectMain.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
clc; clear variables; format long g; close all force;
% Ask MATLAB Compiler to bundle matlabrc: the %#function pragma below marks it
% as a dependency that static analysis might not otherwise find.
%#function matlabrc
warning('off', 'MATLAB:polyshape:repairedBySimplify');
warning('off', 'MATLAB:structOnObject');
warning('off', 'MATLAB:timer:miliSecPrecNotAllowed');
% Workers of an already-running parallel pool keep their own warning state,
% so silence the timer warning on every worker as well.
if ~isempty(gcp('nocreate'))
    pctRunOnAll warning('off', 'MATLAB:timer:miliSecPrecNotAllowed');
end

%% Set paths if not deployed
if ~isdeployed
    restoredefaultpath();
    addpath(genpath('classes'));
    addpath(genpath('helper_methods'));
    addpath(genpath('ships'));
    addpath(genpath('ui'));
end

%% Remove any previously running timers
delete(timerfindall);

%% Start main UI
% mainGUI_App();

%% Train Neural Network Ship
arena = NNS_Arena([-50,50], [-50,50]);
% propObjs = NNS_PropagatedObjectList();
numShips = 2;
for i = 1:numShips
    % Load into a struct (and only the 'ship' variable) so nothing else stored
    % in the MAT-file can overwrite workspace variables such as arena or i.
    % Loading inside the loop gives each iteration its own deserialized ship.
    loaded = load('nn_ship.mat', 'ship');
    ship = loaded.ship;
    % Attach the arena before storing the ship, so ships(i) carries it even if
    % NNS_Ship turns out to be a value class rather than a handle class.
    ship.arena = arena;
    ships(i) = ship; %#ok<SAGROW>
    arena.propObjs.addPropagatedObject(ship);
end
% arena.propObjs.addPropagatedObject(NNS_Ship.createDefaultBasicShip(arena));
simDriver = NNS_SimulationDriver(arena,false);

%% Run GA to train agent
% Fitness and output functions both capture the same driver and ship array.
fun = @(x) gaObjFunc(x, simDriver, ships);
outputFunc = @(options,state,flag) gaOutputFunc(options,state,flag, ships);
% The current parameter vector of ship 1's agent is read here; below it is only
% used to size the GA problem (numel(x) decision variables).
controllers = ships(1).components.getControllerComps();
agent = controllers(1).getAgent();
x = getXVectFromActor(agent);
options = optimoptions("ga", ...
    "PopulationSize",1024, ...
    "UseParallel",true, ...
    "OutputFcn",outputFunc, ...
    "Display","iter", ...
    "PlotFcn",{'gaplotscorediversity', 'gaplotbestf', 'gaplotdistance'}, ...
    'MaxGenerations',2000, ...
    "FunctionTolerance",0, ...
    "FitnessScalingFcn","fitscalingprop", ...
    "CrossoverFcn","crossoverheuristic");
[x,fval,exitflag,output,population,scores] = ga(fun,numel(x),[],[],[],[],[],[],[],options);
% Write the best solution back into ship 1's agent and persist that ship.
setXVectFromActor(agent, x);
ship = ships(1);
save('nn_ship_solved.mat','ship');
% profile off; profile('on', '-historysize',500000000);
% fun(x);
% profile viewer;

%% Local helper functions
function f = gaObjFunc(x, simDriver, nnShips)
%GAOBJFUNC Fitness function for ga: negated mean of score column 2.
%   f = gaObjFunc(x, simDriver, nnShips) loads the parameter vector x into
%   the agent of the first controller component of every ship in nnShips,
%   runs simDriver.driveSimulation(), and returns the negated mean of the
%   second column of the arena scorekeeper's scores (ga minimizes, so a
%   higher score yields a lower cost). The result is averaged over numRuns
%   repetitions (currently 1).
arguments
    x double
    simDriver NNS_SimulationDriver
    nnShips NNS_Ship
end
numRuns = 1;
runCosts = NaN(1, numRuns);
for runIdx = 1:numRuns
    % Push the candidate parameters into each ship's agent.
    for shipIdx = 1:numel(nnShips)
        ctrlComps = nnShips(shipIdx).components.getControllerComps();
        setXVectFromActor(ctrlComps(1).getAgent(), x);
    end
    % Simulate, then turn the scorekeeper result into a cost.
    simDriver.driveSimulation();
    rowScores = simDriver.arena.scorekeeper.getScoresForAllRows();
    runCosts(runIdx) = -mean(rowScores(:,2));
end
f = mean(runCosts);
end
function [state,options,optchanged] = gaOutputFunc(options,state,flag, nnShips)
%GAOUTPUTFUNC ga output function that checkpoints the best individual.
%   [state,options,optchanged] = gaOutputFunc(options,state,flag,nnShips)
%   runs only when flag is 'iter' (case-insensitive). It then:
%     - picks the population member with the lowest score;
%     - loads it into the agent of nnShips(1)'s first controller component;
%     - saves nnShips(1) as variable 'ship' to
%       <pwd>/ships/nn_train/nnship_Gen_<gen>_Score_<score>_<timestamp>.mat,
%       creating the folder if needed.
%   state and options are always returned unchanged, and optchanged is
%   always false.
optchanged = false;
if ~strcmpi(flag, 'iter')
    return;
end

[bestScore, bestIdx] = min(state.Score);
bestX = state.Population(bestIdx,:);
controllers = nnShips(1).components.getControllerComps();
agent = controllers(1).getAgent();
setXVectFromActor(agent, bestX);

% Use datestr's documented identifiers: yyyy = year, mm = month, dd = day,
% HH = hour, MM = minutes, SS = seconds. Do not name this 'date', which would
% shadow the builtin function of that name.
timestamp = datestr(now(), 'yyyymmdd_HHMMSS');
filename = sprintf('nnship_Gen_%u_Score_%0.0f_%s.mat', state.Generation, bestScore, timestamp);
folder = fullfile(pwd, 'ships', 'nn_train');
if ~isfolder(folder)
    mkdir(folder);
end
ship = nnShips(1); %#ok<NASGU> saved by name below
save(fullfile(folder, filename), 'ship');
end