-
Notifications
You must be signed in to change notification settings - Fork 0
/
TerDecomMultibits.m
41 lines (37 loc) · 1.24 KB
/
TerDecomMultibits.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
function [M,a,R] = TerDecomMultibits(W,r,bits,maxIter)
% TERDECOMMULTIBITS Greedy low-bit decomposition W ~ M*a.
% Reproduces Algo.1 of "TERNARY WEIGHT DECOMPOSITION AND BINARY ACTIVATION
% ENCODING FOR FAST AND COMPACT NEURAL NETWORK", generalized from ternary
% to multi-bit level sets.
%
%   [M,a,R] = TerDecomMultibits(W,r,bits)
%   [M,a,R] = TerDecomMultibits(W,r,bits,maxIter)
%
% Inputs
%   W       - n-by-m real matrix to decompose.
%   r       - number of rank-one terms (columns of M / rows of a).
%   bits    - level resolution. Entries of M are taken from
%             (-2^(bits-1):2^(bits-1)) / 2^(bits-1), i.e. evenly spaced
%             values in [-1,1]; bits = 1 gives the ternary set {-1,0,1}.
%   maxIter - (optional, default 60) maximum number of alternating-update
%             sweeps per term. Sweeps stop early once M(:,i) stops changing,
%             which yields the same result as running all sweeps.
%
% Outputs
%   M - n-by-r matrix whose entries are drawn from the level set above.
%   a - r-by-m real coefficient matrix.
%   R - residual after removing every term, i.e. W - M*a accumulated term by
%       term.
%
% Each term i is fitted to the current residual by alternating between
%   * the least-squares row a(i,:) for fixed M(:,i), and
%   * the per-row best level in M(:,i) for fixed a(i,:).
% M(:,i) is initialised with sign(round(randn(...))), so results depend on
% the RNG state. Progress and elapsed time (tic/toc) are printed.
if nargin < 4 || isempty(maxIter)
    maxIter = 60;
end
tic
[n, m] = size(W);
R = W;
% Preallocate so the outputs exist (as n-by-0 / 0-by-m) even when r == 0.
M = zeros(n, r);
a = zeros(r, m);
% Candidate levels are independent of the term and the sweep.
levels = (-2^(bits-1):2^(bits-1)) / 2^(bits-1);
fprintf('TerDecom starts!\n');
for i = 1:r
    M(:,i) = sign(round(randn(n,1)));
    a(i,:) = fitCoefficients(M(:,i), R);
    for iter = 1:maxIter
        newCol = nearestLevels(R, a(i,:), levels);
        % a(i,:) depends only on M(:,i) and R (fixed within this loop), so an
        % unchanged column is a fixed point: further sweeps change nothing.
        converged = isequal(newCol, M(:,i));
        M(:,i) = newCol;
        a(i,:) = fitCoefficients(M(:,i), R);
        if converged
            break;
        end
    end
    R = R - M(:,i)*a(i,:);
    fprintf('iter %d/%d completed\n',i,r);
end
toc
end

function coeff = fitCoefficients(col, R)
% Least-squares row a minimising ||R - col*a||_F for a fixed column col.
% If col is all zeros the denominator is 0 and the result is NaN.
coeff = (col'*R) / (col'*col);
end

function col = nearestLevels(R, coeff, levels)
% For each row j, pick the level c minimising ||R(j,:) - c*coeff||^2.
% Expanding the norm gives ||R(j,:)||^2 - 2*c*<R(j,:),coeff> + c^2*||coeff||^2;
% the first term is constant per row, so it is dropped from the comparison.
proj = R * coeff';                 % n-by-1 inner products <R(j,:), coeff>
sq = coeff * coeff';               % scalar ||coeff||^2
cost = bsxfun(@minus, sq * levels.^2, 2 * proj * levels);   % n-by-numel(levels)
% min returns the first minimiser, which matches the original tie-break.
[~, k] = min(cost, [], 2);
% Rows whose costs are all NaN (e.g. coeff is NaN) fall back to the first level.
allNaN = all(isnan(cost), 2);
if any(allNaN)
    fprintf('exist NaN\n');
    k(allNaN) = 1;
end
col = levels(k);
col = col(:);
end