-
Notifications
You must be signed in to change notification settings - Fork 0
/
pipeline.m
220 lines (187 loc) · 9.22 KB
/
pipeline.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
classdef pipeline < baseEstimator & baseTransformer
    % PIPELINE Chain of named transformers, optionally ending in an estimator.
    %
    %   p = pipeline(steps) builds a pipeline from STEPS, a non-empty cell
    %   array of {name, object} pairs. Every step except the last must be a
    %   baseTransformer; the last step may be a baseEstimator or a
    %   baseTransformer. Each step object is duplicated with copy() so the
    %   pipeline does not share the caller's instances.
    %
    %   p = pipeline(steps, 'verbose', tf) additionally sets the verbose
    %   flag, which makes fit/transform/predict print the name of each step
    %   as it is processed.
    %
    %   Hyperparameters of the constituent steps are exposed through
    %   get_params/set_params under names of the form
    %   '<step name>__<hyperparameter name>'.

    properties
        % when true, progress messages are printed for each step
        verbose = false;
    end

    properties (SetAccess = private)
        % The following are private because names and associated objects
        % need to be modified simultaneously. Use set_transformer and
        % set_estimator below to change them.
        transformers = {};
        transformer_names = {};
        estimator = {};
        estimator_name = [];
    end

    properties (Access = {?baseTransformer, ?baseEstimator})
        % cell array of '<step name>__<hyperparameter name>' strings
        hyper_params = {};
    end

    methods
        function obj = pipeline(steps, varargin)
            % Construct a pipeline from a cell array of {name, object} steps.
            %
            % Inputs:
            %   steps    - non-empty cell array of {name, object} pairs
            %   varargin - optional name/value pairs; only 'verbose' is
            %              recognized, anything else is ignored
            assert(iscell(steps) && ~isempty(steps), ...
                'pipeline:invalidSteps', ...
                'steps must be a non-empty cell array of {name, object} pairs');

            for i = 1:length(steps) - 1
                % assert formats the message itself, and only on failure
                assert(isa(steps{i}{2}, 'baseTransformer'), ...
                    'All but last step must be transformers, but element %d is type %s', ...
                    i, class(steps{i}{2}));
                obj.transformer_names{end+1} = steps{i}{1};
                obj.transformers{end+1} = copy(steps{i}{2});
            end

            % the estimator check comes first, so a final step that is both
            % an estimator and a transformer is stored as the estimator
            if isa(steps{end}{2}, 'baseEstimator')
                obj.estimator_name = steps{end}{1};
                obj.estimator = copy(steps{end}{2});
            elseif isa(steps{end}{2}, 'baseTransformer')
                obj.transformer_names{end+1} = steps{end}{1};
                obj.transformers{end+1} = copy(steps{end}{2});
            else
                error('Last step must be a transformer or a estimator');
            end

            for i = 1:length(varargin)
                if ischar(varargin{i}) && strcmp(varargin{i}, 'verbose')
                    % guard against 'verbose' being passed without a value
                    assert(i < length(varargin), ...
                        'pipeline:missingOptionValue', ...
                        'Option ''verbose'' must be followed by a value');
                    obj.verbose = varargin{i+1};
                end
            end

            % set hyperparameters to match whatever the constituent step
            % hyperparameters are
            obj.refresh_hyper_params();
        end

        % fit all transformers and any estimators
        function fit(obj, dat, Y)
            % Each transformer is fit with fit_transform(dat, Y) and its
            % output becomes the input of the next step. The estimator, if
            % any, is fit on the output of the last transformer. Records
            % the elapsed time in obj.fitTime and sets obj.isFitted.
            t0 = tic;
            for i = 1:length(obj.transformers)
                % output from one transformer is input to the next
                if obj.verbose, fprintf('Fitting %s\n', obj.transformer_names{i}); end
                dat = obj.transformers{i}.fit_transform(dat, Y);
            end

            if ~isempty(obj.estimator)
                if obj.verbose, fprintf('Fitting %s\n', obj.estimator_name); end
                obj.estimator.fit(dat, Y);
            end

            obj.isFitted = true;
            obj.fitTime = toc(t0);
        end

        % apply all transforms
        function dat = transform(obj, dat, varargin)
            % varargin is accepted for interface compatibility but unused.
            for i = 1:length(obj.transformers)
                if obj.verbose, fprintf('Applying %s\n', obj.transformer_names{i}); end
                % output from one transformer is input to the next
                dat = obj.transformers{i}.transform(dat);
            end
        end

        % apply all transforms, then the estimator's score_samples
        function yfit_raw = score_samples(obj, dat, varargin)
            assert(~isempty(obj.estimator), ...
                'This pipeline does not terminate in a estimator. Try pipeline.transform() instead');

            dat = obj.transform(dat, varargin{:});
            if obj.verbose, fprintf('Applying %s\n', obj.estimator_name); end
            yfit_raw = obj.estimator.score_samples(dat);
        end

        % apply all transforms, then the estimator's predict
        function yfit = predict(obj, dat, varargin)
            assert(~isempty(obj.estimator), ...
                'This pipeline does not terminate in a estimator. Try pipeline.transform() instead');

            dat = obj.transform(dat, varargin{:});
            if obj.verbose, fprintf('Applying %s\n', obj.estimator_name); end
            yfit = obj.estimator.predict(dat);
        end

        % forward to the estimator's score_null (no transforms are applied)
        function yfit_null = score_null(obj, varargin)
            assert(~isempty(obj.estimator), ...
                'This pipeline does not terminate in a estimator. Try pipeline.transform() instead');
            yfit_null = obj.estimator.score_null(varargin{:});
        end

        % forward to the estimator's predict_null (no transforms are applied)
        function yfit_null = predict_null(obj, varargin)
            assert(~isempty(obj.estimator), ...
                'This pipeline does not terminate in a estimator. Try pipeline.transform() instead');
            yfit_null = obj.estimator.predict_null(varargin{:});
        end

        % return the cell array of '<step name>__<hyperparameter>' names
        function params = get_params(obj)
            params = obj.hyper_params;
        end

        % finds object to modify and calls its obj.set_params(passThrough,
        % hyp_val) where passThrough are the residual tokens of hyp_name
        % after removing the target object name. In most cases the residual
        % token will be a hyperparameter name, but if you're using a
        % pipeline of pipelines then the residual token could be another
        % parameter of the form class__param, in which case the function
        % recurses.
        function set_params(obj, hyp_name, hyp_val)
            tokens = strsplit(hyp_name, '__');
            target = tokens{1};
            passThrough = strjoin(tokens(2:end), '__');

            % transformers are searched first; the first matching name wins
            for i = 1:length(obj.transformers)
                if strcmp(target, obj.transformer_names{i})
                    obj.transformers{i}.set_params(passThrough, hyp_val);
                    return
                end
            end

            if ~isempty(obj.estimator) && strcmp(target, obj.estimator_name)
                obj.estimator.set_params(passThrough, hyp_val);
                return
            end

            % nothing matched: report it rather than silently doing nothing
            warning('pipeline:set_params:unknownStep', ...
                'No pipeline step is named ''%s''; hyperparameter ''%s'' was not set.', ...
                target, hyp_name);
        end

        % replace all transformers with copies of the supplied steps
        function obj = set_transformer(obj, transformers)
            % transformers is a cell array of {name, baseTransformer} pairs.
            obj.transformers = {};
            obj.transformer_names = {};
            for i = 1:length(transformers)
                assert(isa(transformers{i}{2}, 'baseTransformer'), 'All steps must be transformers');
                obj.transformer_names{end+1} = transformers{i}{1};
                obj.transformers{end+1} = copy(transformers{i}{2});
            end
            % keep the advertised hyperparameter names in sync with the
            % steps that are actually present
            obj.refresh_hyper_params();
            % NOTE(review): unlike set_estimator, the fitted flag is not
            % reset here; the original reset was commented out - confirm
            % this is intentional.
            %obj.isFitted = false;
        end

        % replace the estimator with a copy of the supplied step
        function obj = set_estimator(obj, estimator)
            % estimator is a {name, baseEstimator} pair. Marks the pipeline
            % as not fitted.
            assert(isa(estimator{2}, 'baseEstimator'), 'estimator must be type Estimator');
            obj.estimator_name = estimator{1};
            obj.estimator = copy(estimator{2});
            % keep the advertised hyperparameter names in sync with the
            % steps that are actually present
            obj.refresh_hyper_params();
            obj.isFitted = false;
        end
    end

    methods (Access = protected)
        % deep-copy hook invoked by copy(); duplicates contained handles
        function newObj = copyElement(obj)
            newObj = copyElement@matlab.mixin.Copyable(obj);

            fnames = fieldnames(obj);

            % transformers get a dedicated deep copy, so they are excluded
            % from the generic per-property handling below
            newObj.transformers = copyCell(obj.transformers);
            fnames(ismember(fnames, 'transformers')) = [];

            for i = 1:length(fnames)
                if isa(obj.(fnames{i}), 'cell')
                    hasHandles = checkCellsForHandles(obj.(fnames{i}));
                    if hasHandles
                        % deep copy of handles nested in cells is not
                        % supported; blank the cell in the copy so it does
                        % not alias the original's handles
                        try
                            newObj.(fnames{i}) = cell(size(obj.(fnames{i})));
                            warning('%s.%s has handle objects, but corresponding deep copy support hasn''t been implemented. Dropping %s.',class(obj), fnames{i}, fnames{i});
                        catch
                            error('%s.%s has handle objects, corresponding copy support hasn''t been implemented, and element cannot be dropped. Cannot complete deep copy.',class(obj), fnames{i});
                        end
                    end
                elseif isa(obj.(fnames{i}), 'matlab.mixin.Copyable')
                    newObj.(fnames{i}) = copy(obj.(fnames{i}));
                elseif isa(obj.(fnames{i}), 'handle') % implicitly: & ~isa(obj.(fnames{i}), 'matlab.mixin.Copyable')
                    % the issue here is that function handles that are
                    % copied can contain references to the object they
                    % belong to, but these references will continue to
                    % point to the original object, and not the copy
                    % because matlab cannot parse these function handles
                    % appropriately.
                    warning('%s.%s is a handle but not copyable. This can lead to unepected behavior and is not ideal', class(obj), fnames{i});
                end
            end
        end
    end

    methods (Access = private)
        % rebuild hyper_params from the current transformers and estimator
        function refresh_hyper_params(obj)
            params = {};
            for i = 1:length(obj.transformers)
                these_params = obj.transformers{i}.get_params();
                prefix = [obj.transformer_names{i}, '__'];
                params = [params, cellfun(@(x1)([prefix, x1]), ...
                    these_params, 'UniformOutput', false)]; %#ok<AGROW>
            end

            if ~isempty(obj.estimator)
                % silence this warning only for the duration of the call,
                % then restore whatever state the caller had (also on error)
                warnState = warning('off', 'bayesOptCV:get_params');
                restoreWarn = onCleanup(@() warning(warnState)); %#ok<NASGU>
                these_params = obj.estimator.get_params();
                prefix = [obj.estimator_name, '__'];
                params = [params, cellfun(@(x1)([prefix, x1]), ...
                    these_params, 'UniformOutput', false)];
            end

            obj.hyper_params = params;
        end
    end
end