Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Remi-Gau committed Jul 29, 2024
1 parent b6aa23f commit b472555
Show file tree
Hide file tree
Showing 17 changed files with 118 additions and 158 deletions.
48 changes: 21 additions & 27 deletions src/batches/stats/setBatchFactorialDesign.m
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
rfxDir = getRFXdir(opt, nodeName, contrasts{1}, label);
overwriteDir(rfxDir, opt);

assert(exist(fullfile(rfxDir, 'SPM.mat'), 'file') == 0);
assert(~checkSpmMat(rfxDir, opt));

matlabbatch = returnOneWayAnovaBatch(matlabbatch, rfxDir);

Expand All @@ -89,16 +89,7 @@

thisGroup = availableGroups{iGroup};

% grab subjects label from participants.tsv in raw
% and only keep those that are part of the requested subjects
%
% Note that this will lead to different results
% depending on the requested subejcts
%
participants = bids.util.tsvread(fullfile(opt.dir.raw, 'participants.tsv'));
subjectsInGroup = strcmp(participants.(groupColumnHdr), thisGroup);
subjectsLabel = regexprep(participants.participant_id(subjectsInGroup), '^sub-', '');
subjectsLabel = intersect(subjectsLabel, opt.subjects);
subjectsLabel = returnSubjectLabelInGroup(opt, groupColumnHdr, thisGroup);

% collect all con images from all subjects
for iSub = 1:numel(subjectsLabel)
Expand Down Expand Up @@ -183,16 +174,7 @@
for iGroup = 1:numel(availableGroups)

thisGroup = availableGroups{iGroup};

% grab subjects label from participants.tsv in raw
% and only keep those that are part of the requested subjects
%
% Note that this will lead to different results depending on the requested
% subejcts
%
subjectsInGroup = strcmp(participants.(groupColumnHdr), thisGroup);
subjectsLabel = regexprep(participants.participant_id(subjectsInGroup), '^sub-', '');
subjectsLabel = intersect(subjectsLabel, opt.subjects);
subjectsLabel = returnSubjectLabelInGroup(opt, groupColumnHdr, thisGroup);

% collect all con images from all subjects
for iSub = 1:numel(subjectsLabel)
Expand Down Expand Up @@ -220,6 +202,18 @@

end

function subjectsLabel = returnSubjectLabelInGroup(opt, groupColumnHdr, group)
% grab subjects label from participants.tsv in raw
% and only keep those that are part of the requested subjects
%
% Note that this will lead to different results depending on the requested subejcts
%
participants = bids.util.tsvread(fullfile(opt.dir.raw, 'participants.tsv'));
subjectsInGroup = strcmp(participants.(groupColumnHdr), group);
subjectsLabel = regexprep(participants.participant_id(subjectsInGroup), '^sub-', '');
subjectsLabel = intersect(subjectsLabel, opt.subjects);
end

function icell = allocateSubjectsContrasts(opt, subjectsLabel, conImages, iCon)

icell(1).scans = {};
Expand Down Expand Up @@ -256,7 +250,7 @@
rfxDir = getRFXdir(opt, nodeName, contrastName, thisGroup);
overwriteDir(rfxDir, opt);

assert(exist(fullfile(rfxDir, 'SPM.mat'), 'file') == 0);
assert(~checkSpmMat(rfxDir, opt));

icell(1).levels = 1;

Expand Down Expand Up @@ -342,11 +336,11 @@
% TODO refactor
participants = bids.util.tsvread(fullfile(opt.dir.raw, 'participants.tsv'));

model = opt.model.bm;
bm = opt.model.bm;

status = model.validateGroupBy(nodeName, participants);
status = bm.validateGroupBy(nodeName, participants);

[glmType, groupBy] = model.groupLevelGlmType(nodeName, participants);
[glmType, groupBy] = bm.groupLevelGlmType(nodeName, participants);

% only certain type of model supported for now
if ismember(glmType, {'unknown', 'two_sample_t_test'})
Expand All @@ -358,8 +352,8 @@
return
end

datasetLvlContrasts = model.get_contrasts('Name', nodeName);
datasetLvlDummyContrasts = model.get_dummy_contrasts('Name', nodeName);
datasetLvlContrasts = bm.get_contrasts('Name', nodeName);
datasetLvlDummyContrasts = bm.get_dummy_contrasts('Name', nodeName);

if isempty(datasetLvlContrasts) && isempty(datasetLvlDummyContrasts)
msg = sprintf('No contrast specified %s', commonMsg);
Expand Down
14 changes: 7 additions & 7 deletions src/batches/stats/setBatchGroupLevelContrasts.m
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@

printBatchName('group level contrast estimation', opt);

model = opt.model.bm;
bm = opt.model.bm;

participants = bids.util.tsvread(fullfile(opt.dir.raw, 'participants.tsv'));

groupColumnHdr = model.getGroupColumnHdrFromGroupBy(nodeName, participants);
groupColumnHdr = bm.getGroupColumnHdrFromGroupBy(nodeName, participants);
availableGroups = getAvailableGroups(opt, groupColumnHdr);

[groupGlmType, groupBy] = model.groupLevelGlmType(nodeName, participants);
[groupGlmType, groupBy] = bm.groupLevelGlmType(nodeName, participants);
switch groupGlmType

case 'one_sample_t_test'
Expand Down Expand Up @@ -80,10 +80,10 @@
% through the Edge filter.
% Then generate the between group contrasts.

edge = model.get_edge('Destination', nodeName);
edge = bm.get_edge('Destination', nodeName);
contrastsList = edge.Filter.contrast;

thisContrast = model.get_contrasts('Name', nodeName);
thisContrast = bm.get_contrasts('Name', nodeName);

for j = 1:numel(contrastsList)

Expand Down Expand Up @@ -127,13 +127,13 @@
if any(ismember(designMatrix, fieldnames(participants)))
% TODO will this ignore the contrasts define at other levels
% and not passed through the filter ?
edge = model.get_edge('Destination', nodeName);
edge = bm.get_edge('Destination', nodeName);
contrastsList = edge.Filter.contrast;
end

for j = 1:numel(contrastsList)

thisContrast = model.get_contrasts('Name', nodeName);
thisContrast = bm.get_contrasts('Name', nodeName);

spmMatFile = fullfile(getRFXdir(opt, nodeName, contrastsList{j}), 'SPM.mat');

Expand Down
10 changes: 5 additions & 5 deletions src/batches/stats/setBatchSubjectLevelContrasts.m
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,20 @@
printBatchName('subject level contrasts specification', opt);

spmMatFile = fullfile(getFFXdir(subLabel, opt), 'SPM.mat');
if noSPMmat(opt, subLabel, spmMatFile)
if ~checkSpmMat(dir, opt)
return
end

load(spmMatFile, 'SPM');

model = opt.model.bm;
model.validateConstrasts();
bm = opt.model.bm;
bm.validateConstrasts();

% Create Contrasts
if nargin < 4 || isempty(nodeName)
contrasts = specifyContrasts(model, SPM);
contrasts = specifyContrasts(bm, SPM);
else
contrasts = specifyContrasts(model, SPM, nodeName);
contrasts = specifyContrasts(bm, SPM, nodeName);
end

validateContrasts(contrasts);
Expand Down
12 changes: 6 additions & 6 deletions src/batches/stats/setBatchSubjectLevelGLMSpec.m
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
logger('ERROR', msg, 'filename', mfilename(), 'id', 'missingRawDir');
end

opt.model.bm.getModelType();
bm = opt.model.bm;

printBatchName('specify subject level fmri model', opt);

Expand Down Expand Up @@ -87,11 +87,11 @@

fmri_spec.fact = struct('name', {}, 'levels', {});

fmri_spec.mthresh = opt.model.bm.getInclusiveMaskThreshold();
fmri_spec.mthresh = bm.getInclusiveMaskThreshold();

fmri_spec.bases.hrf.derivs = opt.model.bm.getHRFderivatives();
fmri_spec.bases.hrf.derivs = bm.getHRFderivatives();

fmri_spec.cvi = opt.model.bm.getSerialCorrelationCorrection();
fmri_spec.cvi = bm.getSerialCorrelationCorrection();

%% List scans, onsets, confounds for each task / session / run
subLabel = regexify(subLabel);
Expand Down Expand Up @@ -167,7 +167,7 @@
% multicondition selection
fmri_spec.sess(iSpmSess).cond = struct('name', {}, 'onset', {}, 'duration', {});

fmri_spec.sess(iSpmSess).hpf = opt.model.bm.getHighPassFilter();
fmri_spec.sess(iSpmSess).hpf = bm.getHighPassFilter();

end

Expand All @@ -185,7 +185,7 @@
matlabbatch{end + 1}.spm.stats.fmri_design = fmri_spec;

else
node = opt.model.bm.get_root_node;
node = model.get_root_node();

fmri_spec.mask = {getInclusiveMask(opt, node.Name, BIDS, subLabel)};
matlabbatch{end + 1}.spm.stats.fmri_spec = fmri_spec;
Expand Down
8 changes: 5 additions & 3 deletions src/batches/stats/setBatchTwoSampleTTest.m
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,9 @@
% }
% }

edge = opt.model.bm.get_edge('Destination', nodeName);
bm = opt.model.bm;

edge = bm.get_edge('Destination', nodeName);

if isfield(edge, 'Filter') && ...
isfield(edge.Filter, 'contrast') && ...
Expand All @@ -195,12 +197,12 @@
else

% TODO?? can't imagine a 2 sample t-test with dummy contrasts
node = opt.model.bm.get_nodes('Name', nodeName);
node = bm.get_nodes('Name', nodeName);

% if no specific dummy contrasts mentioned also include all contrasts from previous levels
% or if contrasts are mentioned we grab them
if isfield(node, 'Contrasts')
tmp = getContrastsList(opt.model.bm, nodeName);
tmp = getContrastsList(bm, nodeName);
for i = 1:numel(tmp)
contrastsList{end + 1} = tmp{i}.Name;
end
Expand Down
12 changes: 6 additions & 6 deletions src/bids_model/getContrastsListForFactorialDesign.m
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
% assuming we want to only average / comparisons at the group level
participants = bids.util.tsvread(fullfile(opt.dir.raw, 'participants.tsv'));

model = opt.model.bm;
bm = opt.model.bm;

groupGlmType = model.groupLevelGlmType(nodeName, participants);
groupGlmType = bm.groupLevelGlmType(nodeName, participants);

if ismember(groupGlmType, {'one_sample_t_test', 'one_way_anova'})

edge = model.get_edge('Destination', nodeName);
edge = bm.get_edge('Destination', nodeName);

if isfield(edge, 'Filter') && ...
isfield(edge.Filter, 'contrast') && ...
Expand All @@ -34,14 +34,14 @@
else

% this assumes DummyContrasts exist
contrastsList = getDummyContrastsList(opt.model.bm, nodeName, participants);
contrastsList = getDummyContrastsList(bm, nodeName, participants);

node = model.get_nodes('Name', nodeName);
node = bm.get_nodes('Name', nodeName);

% if no specific dummy contrasts mentioned also include all contrasts from previous levels
% or if contrasts are mentioned we grab them
if ~isfield(node.DummyContrasts, 'Contrasts') || isfield(node, 'Contrasts')
tmp = getContrastsList(opt.model.bm, nodeName, columns);
tmp = getContrastsList(bm, nodeName, columns);
for i = 1:numel(tmp)
contrastsList{end + 1} = tmp{i}.Name;
end
Expand Down
8 changes: 5 additions & 3 deletions src/bids_model/getInclusiveMask.m
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@

% (C) Copyright 2022 bidspm developers

bm = opt.model.bm;

if nargin < 2
[mask, nodeName] = opt.model.bm.getModelMask();
[mask, nodeName] = bm.getModelMask();
else
[mask, nodeName] = opt.model.bm.getModelMask('Name', nodeName);
[mask, nodeName] = bm.getModelMask('Name', nodeName);
end

node = opt.model.bm.get_nodes('Name', nodeName);
node = bm.get_nodes('Name', nodeName);

% TODO refactor with bidsResults part for checking background for montage
if isstruct(mask)
Expand Down
45 changes: 0 additions & 45 deletions src/messages/noSPMmat.m

This file was deleted.

8 changes: 5 additions & 3 deletions src/stats/subject_level/convertOnsetTsvToMat.m
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,10 @@
opt.model.bm = BidsModel('file', opt.model.file);
end

varToConvolve = opt.model.bm.getVariablesToConvolve();
designMatrix = opt.model.bm.getBidsDesignMatrix();
bm = opt.model.bm;

varToConvolve = bm.getVariablesToConvolve();
designMatrix = bm.getBidsDesignMatrix();
designMatrix = removeIntercept(designMatrix);

% conditions to be filled according to the conditions present in each run
Expand All @@ -125,7 +127,7 @@
condToModel.idx = 1;

% TODO get / apply transformers from a specific node
transformers = opt.model.bm.getBidsTransformers();
transformers = bm.getBidsTransformers();
tsv.content = bids.transformers(transformers, tsv.content);

for iVar = 1:numel(varToConvolve)
Expand Down
6 changes: 4 additions & 2 deletions src/stats/subject_level/createAndReturnCounfoundMatFile.m
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,11 @@
return
end

designMatrix = opt.model.bm.getBidsDesignMatrix();
bm = opt.model.bm;

transformers = opt.model.bm.getBidsTransformers();
designMatrix = bm.getBidsDesignMatrix();

transformers = bm.getBidsTransformers();
content = bids.transformers(transformers, content);

[names, R] = createConfounds(content, designMatrix, opt.glm.maxNbVols); %#ok<*ASGLU>
Expand Down
Loading

0 comments on commit b472555

Please sign in to comment.