forked from Pim-Mostert/decoding-toolbox
-
Notifications
You must be signed in to change notification settings - Fork 1
/
decodeCrossValidation.m
129 lines (101 loc) · 4.68 KB
/
decodeCrossValidation.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
function Xhat = decodeCrossValidation(cfg0, X, Y)
% [Xhat] = decodeCrossValidation(cfg, X, Y)
%    Implements k-fold cross-validation, in which a subset of trials is left out
%    in each iteration as testing data, while training on the remaining trials.
%
%    cfg         Configuration struct that can possess the following fields:
%                .trainfun = [function_name]     The training function that is used for training
%                                                (anything accepted by FEVAL). Required. It is
%                                                called as trainfun(traincfg, X_train, Y_train).
%                .traincfg = [struct]            The configuration struct that will be passed on to
%                                                the training function. Default = [].
%                .decodefun = [function_name]    The decoding function that is used for decoding
%                                                (anything accepted by FEVAL). Required. It is
%                                                called as decodefun(decodecfg, decoder, Y_decode).
%                .decodecfg = [struct]           The configuration struct that will be passed on to
%                                                the decoding function. Default = [].
%                .folds = [cell_array]           A cell-array of length k, where k is the number of
%                                                folds, in which each cell contains a vector with
%                                                the trial numbers belonging to that particular
%                                                fold. Required. The cell array and the vectors
%                                                inside it may be rows or columns.
%                .feedback = 'yes' or 'no'       Whether the function should provide feedback on
%                                                its progress. Default = 'no'.
%
%    X           Matrix of arbitrary dimensions, but of which the last dimension is N, that contains
%                the training information. In each fold, a selection of this matrix (along the last
%                dimension) is sent to the training function. A vector is treated as 1 x N.
%
%    Y           Matrix of arbitrary dimensions, but of which the last dimension corresponds to the
%                number of trials N, that contains the data. In each fold, a selection of this matrix
%                (along the last dimension) is sent to the training and decoding function.
%
%    Xhat        Matrix of dimensions as output by the decoding function, plus an additional
%                dimension of length N, that contains the decoded data. Trials that are not part of
%                any fold are left at zero.
%
%    See also CREATEFOLDS

%    Created by Pim Mostert, 2016

tStart = tic;

%% Validate configuration and fill in defaults
requiredFields = {'trainfun', 'decodefun', 'folds'};
for iField = 1:numel(requiredFields)
    if ~isfield(cfg0, requiredFields{iField})
        error('decodeCrossValidation:missingField', ...
            'cfg.%s must be specified', requiredFields{iField});
    end
end

if ~isfield(cfg0, 'traincfg')
    cfg0.traincfg = [];
end
if ~isfield(cfg0, 'decodecfg')
    cfg0.decodecfg = [];
end
if ~isfield(cfg0, 'feedback')
    cfg0.feedback = 'no';
end
giveFeedback = strcmp(cfg0.feedback, 'yes');

dimsY = size(Y);
numN = dimsY(end);
numFold = length(cfg0.folds);

if numFold < 1
    error('decodeCrossValidation:noFolds', 'cfg.folds must contain at least one fold');
end

% Normalise every fold to a row vector, so that concatenating the training
% folds works regardless of the orientation of cfg.folds or of its contents.
folds = cellfun(@(f) f(:).', cfg0.folds(:), 'UniformOutput', false);

%% Reshape data to allow for arbitrary dimensionality
Y = reshape(Y, [prod(dimsY(1:(end-1))), numN]);

if isvector(X)
    % Non-conjugating transpose: X(:)' would flip the sign of imaginary parts
    X = X(:).';
    dimsX = size(X);
else
    dimsX = size(X);
    X = reshape(X, [prod(dimsX(1:(end-1))), numN]);
end

%% Run all folds; the output is allocated once the first fold reveals its size
Xhat = [];
dimsOut = [];

for iFold = 1:numFold
    tFold = tic;

    index_train = [folds{(1:numFold) ~= iFold}];
    index_decode = folds{iFold};
    numTrain = numel(index_train);
    numDecode = numel(index_decode);

    % Select training data
    Y_train = reshape(Y(:, index_train), [dimsY(1:(end-1)), numTrain]);
    X_train = reshape(X(:, index_train), [dimsX(1:(end-1)), numTrain]);

    % Train decoder
    decoder = feval(cfg0.trainfun, cfg0.traincfg, X_train, Y_train);

    % Select data to be decoded
    Y_decode = reshape(Y(:, index_decode), [dimsY(1:(end-1)), numDecode]);

    % Decode data
    Xhat_curFold = feval(cfg0.decodefun, cfg0.decodecfg, decoder, Y_decode);

    % Allocate memory for results, based on the output of the first fold
    if iFold == 1
        dimsOut = size(Xhat_curFold);

        % The last dimension is the trial dimension and is stripped. Exception:
        % when a single trial is decoded, size() silently drops the trailing
        % singleton trial dimension of an N-D output (e.g. [C x T x 1] is
        % reported as [C T]), in which case every reported dimension belongs
        % to the decoder output itself and must be kept.
        if numDecode ~= 1 || dimsOut(end) == 1
            dimsOut = dimsOut(1:(end-1));
        end

        Xhat = zeros([prod(dimsOut), numN]);
    end

    Xhat(:, index_decode) = reshape(Xhat_curFold, [prod(dimsOut), numDecode]);

    % Feedback
    if giveFeedback
        fprintf('%s: finished fold %g/%g - it took %.2f s\n', mfilename, iFold, numFold, toc(tFold));
    end
end

%% Return
Xhat = reshape(Xhat, [dimsOut, numN]);

if giveFeedback
    fprintf('%s - all finished - it took %.2f s\n', mfilename, toc(tStart));
end

end