-
Notifications
You must be signed in to change notification settings - Fork 3
/
testsvm.m
34 lines (30 loc) · 1.04 KB
/
testsvm.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
function testsvm(quantity)
%TESTSVM Fit a linear SVM to two overlapping 2-D disks and plot the result.
%   TESTSVM() draws 1000 points from each of two unit disks, centred at
%   (1,1) (class 'A') and (2.1,2.1) (class 'B'). It fits an SVM with
%   FITCSVM (default linear kernel) and estimates the model's 10-fold
%   cross-validation loss. It then plots the samples, the support vectors
%   and the SVM decision boundary in a new dark-background figure.
%
%   TESTSVM(QUANTITY) draws QUANTITY points per class instead of 1000.

if nargin < 1 || isempty(quantity)
    quantity = 1000;
end

% Generate data + labels
X = [sampleDisk(quantity, [1 1]); sampleDisk(quantity, [2.1 2.1])];
y = [repmat({'A'}, quantity, 1); repmat({'B'}, quantity, 1)];

% Make SVM and estimate its generalization error
m = fitcsvm(X, y);
sv = m.SupportVectors;
cv = crossval(m);            % crossval defaults to 10-fold partitioning
classloss = kfoldLoss(cv);

% Plot figure
fig = figure;
whitebg(fig, 'k')            % style the figure just created, not figure 1
gscatter(X(:,1), X(:,2), y, [], [], 6*[1,1])
hold on
plot(sv(:,1), sv(:,2), 'yo', 'MarkerSize', 3, 'DisplayName', 'Support vectors')
title(sprintf('Linear SVM, 10-fold CV loss = %.3f', classloss))

% Plot SVM line
plotLinearBoundary(m)
hold off
end

function P = sampleDisk(n, centre)
%SAMPLEDISK Draw N points in the unit disk around the 1x2 point CENTRE.
%   Radius and angle are each drawn uniformly, so points are denser near
%   the centre than a uniform-area sample would be. Radius is drawn
%   before angle to keep the original random-number order.
r = rand(n, 1);
theta = 2*pi*rand(n, 1);
P = repmat(centre, n, 1) + [r.*cos(theta), r.*sin(theta)];
end

function plotLinearBoundary(m)
%PLOTLINEARBOUNDARY Draw the decision boundary f(x) = 0 of linear SVM M.
%   For a linear kernel, fitcsvm scores x as f(x) = (x/s)*Beta + Bias with
%   s = M.KernelParameters.Scale, so the boundary is the line
%   Beta(1)*x1 + Beta(2)*x2 + s*Bias = 0. Assumes M was trained without
%   standardization (the fitcsvm default), so Beta applies to raw X.
%   The line spans the middle half of the current axis range.
coef = m.Beta;
offset = m.KernelParameters.Scale * m.Bias;
if coef(2) ~= 0
    xlimits = xlim;
    linex = linspace((3*xlimits(1)+xlimits(2))/4, (xlimits(1)+3*xlimits(2))/4);
    liney = -(coef(1)*linex + offset)/coef(2);
else
    % Boundary is vertical: x1 = -offset/coef(1)
    ylimits = ylim;
    liney = linspace((3*ylimits(1)+ylimits(2))/4, (ylimits(1)+3*ylimits(2))/4);
    linex = repmat(-offset/coef(1), size(liney));
end
plot(linex, liney, 'DisplayName', 'Decision boundary')
end