Skip to content

Commit

Permalink
linting scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
frcaud committed Dec 1, 2022
1 parent a4c1ac8 commit 73f1ab8
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 5 deletions.
3 changes: 2 additions & 1 deletion download_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@

os.makedirs(PATH_DATA, exist_ok=True)

URL_DATA = 'ftp://ftp.cea.fr/pub/unati/people/educhesnay/data/brain_anatomy_schizophrenia_data/sz_public_202211.zip' # PUBLIC DATASET
URL_DATA = 'ftp://ftp.cea.fr/pub/unati/people/educhesnay/data/\
brain_anatomy_schizophrenia_data/sz_public_202211.zip' # PUBLIC DATASET


def fetch_data(urls, dst, verbose=1):
Expand Down
11 changes: 8 additions & 3 deletions problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@


N_FOLDS = 5
problem_title = 'Predict schizophrenia from brain grey matter (classification)'
problem_title = 'Predict schizophrenia from \
brain grey matter (classification)'

_target_column_name = 'diagnosis'
_prediction_label_names = ['control', 'schizophrenia']

# A type (class) which will be used to create wrapper objects for y_pred
Predictions = rw.prediction_types.make_multiclass(label_names=_prediction_label_names)
Predictions = rw.prediction_types.make_multiclass(
label_names=_prediction_label_names)
# An object implementing the workflow
workflow = rw.workflows.Estimator()

Expand All @@ -24,10 +26,12 @@
rw.score_types.Accuracy(name='acc')
]


def get_cv(X, y):
cv_train = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=0)
return cv_train.split(X, y)


def _read_data(path, dataset, datatype=['rois', 'vbm']):
""" Read data.
Expand Down Expand Up @@ -65,7 +69,8 @@ def _read_data(path, dataset, datatype=['rois', 'vbm']):

# Read 3d images and mask
if'vbm' in datatype:
imgs_arr_zip = np.load(os.path.join(path, 'data', "%s_vbm.npz" % dataset))
imgs_arr_zip = np.load(os.path.join(path, 'data',
"%s_vbm.npz" % dataset))
x_img_arr = imgs_arr_zip['imgs_arr'].squeeze()
mask_arr = imgs_arr_zip['mask_arr']
x_img_arr = x_img_arr[:, mask_arr]
Expand Down
6 changes: 5 additions & 1 deletion submissions/starting_kit/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

class ROIsFeatureExtractor(BaseEstimator, TransformerMixin):
"""Select only the 284 ROIs features:"""

def fit(self, X, y):
return self

Expand All @@ -27,6 +28,9 @@ def get_estimator():
"""Build your estimator here."""
estimator = make_pipeline(
ROIsFeatureExtractor(),
MLPClassifier(random_state=1, hidden_layer_sizes=(200, 150, 100, 50, 25, )))
MLPClassifier(random_state=1,
hidden_layer_sizes=(200, 150, 100, 50, 25, ),
)
)

return estimator

0 comments on commit 73f1ab8

Please sign in to comment.