diff --git a/deepcalcium/models/spikes/unet_1d_segmentation.py b/deepcalcium/models/spikes/unet_1d_segmentation.py index 6246ce2..91e0b51 100644 --- a/deepcalcium/models/spikes/unet_1d_segmentation.py +++ b/deepcalcium/models/spikes/unet_1d_segmentation.py @@ -45,7 +45,7 @@ def on_epoch_end(self, epoch, logs): dpi=120) -def unet1d(window_shape=(128,), nb_filters_base=32, conv_kernel_init='he_normal', prop_dropout_base=0.1, margin=4): +def unet1d(window_shape=(128,), nb_filters_base=32, conv_kernel_init='he_normal', prop_dropout_base=0.05, margin=4): """Builds and returns the UNet architecture using Keras. # Arguments window_shape: tuple of one integer defining the input/output window shape. @@ -134,10 +134,12 @@ def conv_layer(nbf, x): x = conv_layer(nfb, x) x = conv_layer(nfb, x) - # Apply the error margin before softmax activation. - x = MaxPooling1D(margin + 1, strides=1, padding='same')(x) - x = Conv1D(2, 1, activation='softmax')(x) + x = Conv1D(2, 1)(x) + x = MaxPooling1D(margin+1, strides=1, padding='same')(x) + x = Activation('softmax')(x) + #x = Lambda(lambda x: x[:, :, 1:])(x) + #x = MaxPooling1D(margin + 1, strides=1, padding='same')(x) x = Lambda(lambda x: x[:, :, -1])(x) model = Model(inputs=inputs, outputs=x) @@ -205,7 +207,7 @@ def __init__(self, cpdir='%s/spikes_unet1d' % CHECKPOINTS_DIR, def fit(self, dataset_paths, shape=(4096,), error_margin=4., batch=20, nb_epochs=20, val_type='random_split', prop_trn=0.8, - prop_val=0.2, nb_folds=5, keras_callbacks=[], optimizer=Adam(0.001)): + prop_val=0.2, nb_folds=5, keras_callbacks=[], optimizer=Adam(0.002)): """Constructs model based on parameters and trains with the given data. Internally, the function uses a local function to abstract the training for both validation types. @@ -282,10 +284,7 @@ def loss(yt, yp): ModelCheckpoint('%s/%d_model_val_F2_{val_F2:3f}_{epoch:03d}.hdf5' % cpt, monitor='val_F2', mode='max', verbose=1, save_best_only=True), CSVLogger('%s/%d_metrics.csv' % cpt), - MetricsPlotCallback('%s/%d_metrics.png' % cpt), - ReduceLROnPlateau(monitor='val_F2', factor=0.5, min_lr=0.0001, - patience=max(10, int(nb_epochs * 0.2)), - mode='max', epsilon=1e-2, verbose=1) + MetricsPlotCallback('%s/%d_metrics.png' % cpt) ] # Train.