Use the weights available at https://lmb.informatik.uni-freiburg.de/resources/opensource/unet/ for TensorFlow 2 (Keras) #78

Open
maxclac opened this issue Mar 1, 2021 · 6 comments

@maxclac

maxclac commented Mar 1, 2021

Hi everyone!

I am trying to use the Caffe weights available at https://lmb.informatik.uni-freiburg.de/resources/opensource/unet/ for a model implemented in TensorFlow.

To be more specific, I first read the 2d_cell_net_v0.modeldef.h5 file to get the layer names in the right order:

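```python
import h5py

layer_names = []
with h5py.File('2d_cell_net_v0.modeldef.h5', 'r') as mod:
    prototxt = mod['model_prototxt'][()]  # `.value` no longer exists in h5py 3
if isinstance(prototxt, bytes):
    prototxt = prototxt.decode('utf8')
for line in prototxt.split('\n'):
    if line.startswith('layer'):  # assumes one layer definition per line
        name = line.split('name:')[1].split(' ')[1].replace('\'', '')
        layer_names.append(name)
```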
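The `startswith('layer')` check above only works if each layer definition, including its `name:` field, sits on a single line of the prototxt. A slightly more tolerant sketch, assuming `name:` is the first field inside every `layer { ... }` (or legacy `layers { ... }`) block, extracts the names with a regular expression instead. `layer_names_from_prototxt` is a hypothetical helper, not part of any library.

```python
import re

# Assumption: `name:` is the first field of every `layer {` / `layers {` block.
# The brace and the name may sit on different lines (\s also matches newlines).
_LAYER_NAME_RE = re.compile(r"""\blayers?\s*\{\s*name:\s*["']([^"']+)["']""")


def layer_names_from_prototxt(prototxt):
    """Return the layer names in the order they appear in a Caffe prototxt string."""
    return _LAYER_NAME_RE.findall(prototxt)
```

The top-level `name:` of the network is not matched, because it is not preceded by `layer {`. The string passed in would be the decoded contents of `mod['model_prototxt']`.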

Then I loop through the layers stored in the 2d_cell_net_v0.caffemodel.h5 file and look for the corresponding layers in the TensorFlow Keras model. While doing this, I noticed that the weight arrays have different shapes in the two frameworks, so I transposed them (and additionally swapped the last two axes for the up-convolution layers). Finally, I set each Keras layer's weights to the extracted Caffe weights:

```python
import h5py
import numpy as np
import networks  # the poster's own module defining UNet_Freiburg (not shown)

size_image = 512
unet = networks.UNet_Freiburg((size_image, size_image, 1))
caffe_weights = []
unet_weights = h5py.File('2d_cell_net_v0.caffemodel.h5', 'r')
data = unet_weights['data']

for layer_name in layer_names:  # Caffe layer order from the prototxt
    for layer in data:
        l = unet_weights['/data/' + layer]
        name = l.name.split('/')[-1]
        if name == layer_name:
            if '0' in l:  # skip layers without stored parameters
                # '0' holds the weights, '1' the bias
                weight_array = np.array(l['0']).T
                bias_array = np.array(l['1']).T
                if 'up' in name:  # upconv_* layers
                    weight_array = np.swapaxes(weight_array, 2, 3)
                shape_weights = weight_array.shape
                shape_bias = bias_array.shape

                for ul in unet.layers:
                    if ul.name == name:
                        ul.set_weights([weight_array, bias_array])
```
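The nested loops above scan every Caffe layer and every Keras layer for each name, and any layer whose name has no counterpart on the other side is skipped silently; a Keras layer that never receives Caffe weights keeps its random `he_normal` initialization. A minimal sketch, assuming the Caffe parameters have first been collected into a plain dict `name -> (weights, bias)`, pairs them by name in one pass and reports the unmatched names on both sides. `match_layers` is a hypothetical helper.

```python
def match_layers(caffe_params, keras_layer_names):
    """Pair Caffe parameter blobs with Keras layers of the same name.

    caffe_params: dict mapping layer name -> (weights, bias), e.g. collected
        from the groups under /data in the .caffemodel.h5 file (assumed input).
    keras_layer_names: names of the Keras model's layers, in model order.
    Returns (matched, caffe_only, keras_only).
    """
    keras_names = list(keras_layer_names)
    keras_set = set(keras_names)
    matched = {n: p for n, p in caffe_params.items() if n in keras_set}
    caffe_only = [n for n in caffe_params if n not in keras_set]
    keras_only = [n for n in keras_names if n not in caffe_params]
    return matched, caffe_only, keras_only
```

`keras_only` also lists parameter-free layers (input, pooling, dropout); only the ones that have weights matter. Each matched pair can then be applied through Keras' `unet.get_layer(name).set_weights([...])` instead of looping over `unet.layers`.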

The results are not what I expected:

[Screenshot: caffe_tf]

Could anyone help me with this? What am I missing?

@ThorstenFalk
Collaborator

In principle your approach is fine. Caffe stores the weights in the order (cOut, cIn, y, x) for convolution and (cIn, cOut, y, x) for up-convolution. Probably the dimensions need to be permuted further to match the corresponding Keras layers. I don't know the native dimension order of Keras/TensorFlow by heart, but this is my first guess at what causes your strange outputs.
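Following the Caffe axis orders given above, here is a minimal NumPy sketch of the permutation. Keras `Conv2D` kernels are laid out as (y, x, cIn, cOut) and `Conv2DTranspose` kernels as (y, x, cOut, cIn), so both Caffe layouts map with the same `np.transpose(w, (2, 3, 1, 0))`. A plain `.T` reverses all four axes and yields (x, y, cIn, cOut): for square kernels the shape is identical, so `set_weights` accepts it, but every kernel ends up spatially transposed. The deconvolution mapping assumes the Keras model uses `Conv2DTranspose` with kernel 2 and stride 2 (the `UpSampling2D` + `Conv2D` pair is a different operation) and that no kernel flip is needed; treat that part as an unverified assumption and check it numerically. The function names are hypothetical.

```python
import numpy as np


def caffe_conv_to_keras(w):
    """Caffe Convolution weights (cOut, cIn, y, x) -> Keras Conv2D kernel (y, x, cIn, cOut)."""
    return np.transpose(w, (2, 3, 1, 0))


def caffe_deconv_to_keras(w):
    """Caffe Deconvolution weights (cIn, cOut, y, x) -> Keras Conv2DTranspose
    kernel (y, x, cOut, cIn). Assumes no spatial flip is required."""
    return np.transpose(w, (2, 3, 1, 0))


def caffe_bias_to_keras(b):
    """Flatten a bias blob to the 1-D (cOut,) vector Keras expects
    (older Caffe versions may store it with extra singleton axes)."""
    return np.ravel(b)
```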

What kind of padding do you use? The plugin processes the image in overlapping tiles and uses mirroring to extrapolate data across the image boundaries. Did you use zero padding? That would explain this wide bright border.
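To mimic the plugin's boundary handling instead of zero padding, the input can be mirror-padded before it goes through a network with valid convolutions. A minimal NumPy sketch, assuming a (H, W) or (H, W, C) image; `mirror_pad` is a hypothetical helper, and since it is not stated whether the plugin repeats the edge pixel (`mode="symmetric"`) or not (`mode="reflect"`), the mode is a parameter. For the 4-level U-Net of the original paper (572x572 input, 388x388 output), the border would be (572 - 388) / 2 = 92 pixels per side.

```python
import numpy as np


def mirror_pad(image, border, mode="reflect"):
    """Extend the two spatial axes of `image` by `border` pixels per side.

    mode="reflect" mirrors without repeating the edge pixel,
    mode="symmetric" repeats it; channel axes (if any) are left untouched.
    """
    pad = [(border, border), (border, border)] + [(0, 0)] * (image.ndim - 2)
    return np.pad(image, pad, mode=mode)
```

Overlapping tiling is not shown here; each tile would be cut from the padded image with the same border around it.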

@maxclac
Author

maxclac commented Mar 2, 2021

In TensorFlow, I use the Conv2D layer with padding='same'. I changed it to 'valid', but now the layer concatenations fail. My network is defined like this:

```python
from tensorflow import keras as tfk
from tensorflow.keras.layers import (Conv2D, Dropout, MaxPooling2D,
                                     UpSampling2D, concatenate)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam


def UNet_Freiburg(shape):
    inputs = tfk.layers.Input(shape=shape, name='input')
    conv1 = Conv2D(64, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_d0a-b')(inputs)
    conv1 = Conv2D(64, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_d0b-c')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2), name='maxpool1')(conv1)
    conv2 = Conv2D(128, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_d1a-b')(pool1)
    conv2 = Conv2D(128, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_d1b-c')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2), name='maxpool2')(conv2)
    conv3 = Conv2D(256, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_d2a-b')(pool2)
    conv3 = Conv2D(256, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_d2b-c')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2), name='maxpool3')(conv3)
    conv4 = Conv2D(512, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_d3a-b')(pool3)
    conv4 = Conv2D(512, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_d3b-c')(conv4)
    drop4 = Dropout(0.5)(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2), name='maxpool4')(drop4)

    conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_d4a-b')(pool4)
    conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_d4b-c')(conv5)
    drop5 = Dropout(0.5)(conv5)

    up6 = Conv2D(512, 2, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='upconv_d4c_u3a')(UpSampling2D(size = (2,2))(drop5))
    merge6 = concatenate([drop4,up6], axis = 3)
    conv6 = Conv2D(512, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_u3b-c')(merge6)
    conv6 = Conv2D(512, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_u3c-d')(conv6)

    up7 = Conv2D(256, 2, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='upconv_u3d_u2a')(UpSampling2D(size = (2,2))(conv6))
    merge7 = concatenate([conv3,up7], axis = 3)
    conv7 = Conv2D(256, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_u2b-c')(merge7)
    conv7 = Conv2D(256, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_u2c-d')(conv7)

    up8 = Conv2D(128, 2, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='upconv_u2d_u1a')(UpSampling2D(size = (2,2))(conv7))
    merge8 = concatenate([conv2,up8], axis = 3)
    conv8 = Conv2D(128, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_u1b-c')(merge8)
    conv8 = Conv2D(128, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_u1c-d')(conv8)

    up9 = Conv2D(128, 2, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='upconv_u1d_u0a')(UpSampling2D(size = (2,2))(conv8))
    merge9 = concatenate([conv1,up9], axis = 3)
    conv9 = Conv2D(128, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_u0b-c')(merge9)
    conv9 = Conv2D(128, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_u0c-d')(conv9)
    conv9 = Conv2D(2, 1, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_u0d-score')(conv9)
    conv10 = Conv2D(1, 1, activation = 'sigmoid')(conv9)

    model = Model(inputs, conv10)

    model.compile(optimizer = Adam(learning_rate = 1e-4), loss = 'binary_crossentropy', metrics = ['accuracy'])
    return model
```

and I get this error message:

```
ValueError: A `Concatenate` layer requires inputs with matching shapes except for the concat axis. Got inputs shapes: [(None, 56, 56, 512), (None, 47, 47, 512)]
```

I guess I need to restructure the whole network to accommodate the change to valid padding.

@ThorstenFalk
Collaborator

Ah, OK, you used padding. You're right: U-Net uses valid convolutions, so the left blob (the skip connection from the analysis path) must be cropped to match the spatial shape of the right blob (the up-convolved features).
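A small sketch of the size bookkeeping behind this cropping, assuming two valid 3x3 convolutions per level, 2x2 max pooling, and up-convolutions that exactly double the spatial size, as in the original U-Net. `unet_valid_shapes` is a hypothetical helper; it also rejects input sizes that produce an odd size before a pooling step, where the analysis and synthesis paths stop lining up.

```python
def unet_valid_shapes(in_size, n_levels=4, convs_per_level=2, kernel=3):
    """Trace one spatial dimension through a U-Net with valid convolutions.

    Returns (out_size, crops), where crops[l] is the number of pixels to crop
    from each side of the level-l skip connection before concatenation.
    """
    shrink = convs_per_level * (kernel - 1)  # pixels lost per level
    size, skips = in_size, []
    for _ in range(n_levels):                # analysis path
        size -= shrink
        if size <= 0 or size % 2:
            raise ValueError(f"size {size} cannot be halved by 2x2 pooling")
        skips.append(size)
        size //= 2
    size -= shrink                           # bottom level
    if size <= 0:
        raise ValueError("input too small")
    crops = [0] * n_levels
    for l in reversed(range(n_levels)):      # synthesis path
        size *= 2                            # up-convolution doubles the size
        crops[l] = (skips[l] - size) // 2
        size -= shrink
    return size, crops
```

For a 572-pixel input this gives an output of 388 and per-side crops of 88, 40, 16 and 4 pixels for levels 0 to 3, which agree with the `Cropping2D(cropping=int(pad / 2))` values produced by the class posted below for `n_levels=4`. A 512-pixel input is rejected: the sizes run 508, 254, 250, 125, 121, and 121 cannot be halved by the third pooling step.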

This is my Keras implementation of U-Net:

```python
# Imports assumed (they are not part of the original post); TF 1.x API
from time import time

import h5py
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.callbacks import TensorBoard


class Unet2D:

  def __init__(self, snapshot=None, n_channels=1, n_classes=2, n_levels=4,
               n_features=64, name="U-Net"):

    self.concat_blobs = []

    self.n_channels = n_channels
    self.n_classes = n_classes
    self.n_levels = n_levels
    self.n_features = n_features
    self.name = name

    self.trainModel, self.padding = self._createModel(True)
    self.testModel, _ = self._createModel(False)

    if snapshot is not None:
      self.trainModel.load_weights(snapshot)
      self.testModel.load_weights(snapshot)

  def _weighted_categorical_crossentropy(self, y_true, y_pred, weights):
    # Weighted per-pixel softmax cross-entropy (TF 1.x tf.losses API)
    return tf.losses.softmax_cross_entropy(
      y_true, y_pred, weights=weights, reduction=tf.losses.Reduction.MEAN)

  def _createModel(self, training):

    data = keras.layers.Input(shape=(None, None, self.n_channels), name="data")

    concat_blobs = []

    if training:
      labels = keras.layers.Input(
        shape=(None, None, self.n_classes), name="labels")
      weights = keras.layers.Input(shape=(None, None), name="weights")

    # Modules of the analysis path consist of two convolutions and max pooling
    for l in range(self.n_levels):
      t = keras.layers.LeakyReLU(alpha=0.1)(
        keras.layers.Conv2D(
          2**l * self.n_features, 3, padding="valid",
          kernel_initializer="he_normal",
          name="conv_d{}a-b".format(l))(data if l == 0 else t))
      concat_blobs.append(
        keras.layers.LeakyReLU(alpha=0.1)(
          keras.layers.Conv2D(
            2**l * self.n_features, 3, padding="valid",
            kernel_initializer="he_normal", name="conv_d{}b-c".format(l))(t)))
      t = keras.layers.MaxPooling2D(pool_size=(2, 2))(concat_blobs[-1])

    # Deepest layer has two convolutions only
    t = keras.layers.LeakyReLU(alpha=0.1)(
      keras.layers.Conv2D(
        2**self.n_levels * self.n_features, 3, padding="valid",
        kernel_initializer="he_normal",
        name="conv_d{}a-b".format(self.n_levels))(t))
    t = keras.layers.LeakyReLU(alpha=0.1)(
      keras.layers.Conv2D(
        2**self.n_levels * self.n_features, 3, padding="valid",
        kernel_initializer="he_normal",
        name="conv_d{}b-c".format(self.n_levels))(t))
    pad = 8  # size difference between skip blob and up-convolved blob (deepest level)

    # Modules in the synthesis path consist of up-convolution,
    # concatenation and two convolutions
    for l in range(self.n_levels - 1, -1, -1):
      name = "upconv_{}{}{}_u{}a".format(
        *(("d", l+1, "c", l) if l == self.n_levels - 1 else ("u", l+1, "d", l)))
      t = keras.layers.LeakyReLU(alpha=0.1)(
        keras.layers.Conv2D(
          2**l * self.n_features, 2, padding="same",
          kernel_initializer="he_normal", name=name)(
            keras.layers.UpSampling2D(size = (2,2))(t)))
      # Crop the skip connection (left blob) to the size of the right blob
      t = keras.layers.Concatenate()(
        [keras.layers.Cropping2D(cropping=int(pad / 2))(concat_blobs[l]), t])
      pad = 2 * (pad + 8)
      t = keras.layers.LeakyReLU(alpha=0.1)(
        keras.layers.Conv2D(
          2**l * self.n_features, 3, padding="valid",
          kernel_initializer="he_normal", name="conv_u{}b-c".format(l))(t))
      t = keras.layers.LeakyReLU(alpha=0.1)(
        keras.layers.Conv2D(
          2**l * self.n_features, 3, padding="valid",
          kernel_initializer="he_normal", name="conv_u{}c-d".format(l))(t))
    pad /= 2  # total size difference between input and output

    score = keras.layers.Conv2D(
      self.n_classes, 1, kernel_initializer = 'he_normal',
      name="conv_u0d-score")(t)
    softmax_score = keras.layers.Softmax()(score)

    if training:
      model = keras.Model(inputs=[data, labels, weights], outputs=softmax_score)
      model.add_loss(
        self._weighted_categorical_crossentropy(labels, score, weights))
      adam = keras.optimizers.Adam(
        lr=0.00001, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0,
        amsgrad=False)
      model.compile(optimizer=adam, loss=None)
    else:
      model = keras.Model(inputs=data, outputs=softmax_score)

    return model, int(pad)

  def loadCaffeModelH5(self, path):
    train_layer_dict = dict([(layer.name, layer)
                             for layer in self.trainModel.layers])
    test_layer_dict = dict([(layer.name, layer)
                            for layer in self.testModel.layers])
    pre = h5py.File(path, 'r')
    l = list(pre['data'].keys())
    for i in l:
      # Blobs '0' (kernel) and '1' (bias) are passed on as stored,
      # without permuting any axes
      kernel = pre['data'][i]['0'][()]
      bias = pre['data'][i]['1'][()]
      train_layer_dict[i].set_weights([kernel,bias])
      test_layer_dict[i].set_weights([kernel,bias])
    pre.close()

  def train(self, sample_generator, validation_generator=None,
            n_epochs=100, snapshot_interval=1, snapshot_prefix=None):

    callbacks = [TensorBoard(log_dir="logs/{}-{}".format(self.name, time()))]
    if snapshot_prefix is not None:
      callbacks.append(keras.callbacks.ModelCheckpoint(
        snapshot_prefix + ".{epoch:04d}.h5", mode='auto',
        period=snapshot_interval))
    self.trainModel.fit_generator(
      sample_generator, epochs=n_epochs, validation_data=validation_generator,
      verbose=1, callbacks=callbacks)

  def predict(self, tile_generator):

    smscores = []
    segmentations = []

    for tileIdx in range(len(tile_generator)):
      tile = tile_generator[tileIdx]
      outIdx = tile[0]["image_index"]
      outShape = tile[0]["image_shape"]
      outSlice = tile[0]["out_slice"]
      inSlice = tile[0]["in_slice"]
      softmax_score = self.testModel.predict(tile[0]["data"], verbose=1)
      # Allocate the full-size output arrays the first time an image appears
      if len(smscores) < outIdx + 1:
        smscores.append(np.empty((*outShape, self.n_classes)))
        segmentations.append(np.empty(outShape))
      smscores[outIdx][outSlice] = softmax_score[0][inSlice]
      segmentations[outIdx][outSlice] = np.argmax(
        softmax_score[0], axis=-1)[inSlice]

    return smscores, segmentations
```

I think this should make everything clear. It even includes loading the weights.

All the best,
Thorsten

@maxclac
Author

maxclac commented Mar 3, 2021

Thank you! I will have a look.

@maxclac
Author

maxclac commented Mar 5, 2021

Unfortunately, this is not TensorFlow 2 code, so I will have to do some work to adapt it to my environment.
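The TF1-specific piece that matters most for a port is the loss: `tf.losses.softmax_cross_entropy` with `reduction=tf.losses.Reduction.MEAN` computes the per-pixel softmax cross-entropy, multiplies it by `weights`, and divides the weighted sum by the sum of the weights; in TF2 the function is still available under `tf.compat.v1.losses`. As a reference for checking a port, here is a NumPy sketch of that quantity. `weighted_softmax_ce` is a hypothetical helper, and the shapes (N, H, W, C) for labels and logits and (N, H, W) for weights are assumed from the class above.

```python
import numpy as np


def weighted_softmax_ce(labels, logits, weights):
    """Weighted softmax cross-entropy: sum(w * ce) / sum(w).

    labels: one-hot array (N, H, W, C); logits: (N, H, W, C); weights: (N, H, W).
    """
    z = logits - logits.max(axis=-1, keepdims=True)  # for numerical stability
    log_softmax = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    ce = -(labels * log_softmax).sum(axis=-1)        # per-pixel loss, (N, H, W)
    return float((weights * ce).sum() / weights.sum())
```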

@maxclac
Author

maxclac commented Mar 22, 2021

Hi again!
Instead of porting the code to TF2, I simply went back to TensorFlow 1.14.
Now I have a problem that I already had before, namely that in

```python
def loadCaffeModelH5(self, path):
    train_layer_dict = dict([(layer.name, layer)
                             for layer in self.trainModel.layers])
    test_layer_dict = dict([(layer.name, layer)
                            for layer in self.testModel.layers])
    pre = h5py.File(path, 'r')
    l = list(pre['data'].keys())
    for i in l:
        print(i)
        print(pre['data'][i].keys())
        try:
            kernel = pre['data'][i]['0'][()]
            bias = pre['data'][i]['1'][()]
            train_layer_dict[i].set_weights([kernel, bias])
            test_layer_dict[i].set_weights([kernel, bias])
        except KeyError:
            # no stored blobs for this layer, or no Keras layer with this name
            pass
    pre.close()
```

the shapes of the Caffe weights conflict with the shapes expected by TF:

```
ValueError: Layer weight shape (3, 3, 1, 64) not compatible with provided weight shape (64, 1, 3, 3)
```

This could perhaps be solved by transposing the arrays, but then there is still the issue with the upconv layers:

```
ValueError: Layer weight shape (3, 3, 128, 64) not compatible with provided weight shape (3, 3, 192, 128)
```

These are the same problems I had before opening this issue. Am I missing something?
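The second error cannot be fixed by reordering axes: a (3, 3, 192, 128) array has 221,184 entries and a (3, 3, 128, 64) kernel has 73,728, so no transposition turns one into the other. The kernel is also 3x3, whereas the up-convolutions in the class above use 2x2 kernels, so the mismatch points at a regular convolution whose channel counts differ between the Caffe model and the Keras definition. A minimal sketch, assuming the Caffe weight shapes and the Keras kernel shapes have been collected into dicts keyed by layer name, lists every such mismatch at once. `report_shape_mismatches` is a hypothetical helper.

```python
def report_shape_mismatches(caffe_shapes, keras_shapes):
    """Compare Caffe Convolution weight shapes with Keras Conv2D kernel shapes.

    caffe_shapes: name -> Caffe shape (cOut, cIn, y, x)
    keras_shapes: name -> Keras kernel shape (y, x, cIn, cOut)
    Deconvolution blobs are stored as (cIn, cOut, y, x) and are not handled here.
    Returns [(name, caffe_shape_in_keras_order, keras_shape)] for every layer
    present in both dicts whose shapes disagree.
    """
    problems = []
    for name, (c_out, c_in, y, x) in caffe_shapes.items():
        if name not in keras_shapes:
            continue
        expected = tuple(keras_shapes[name])
        converted = (y, x, c_in, c_out)
        if converted != expected:
            problems.append((name, converted, expected))
    return problems
```

The Caffe shapes can be read with h5py as `pre['data'][name]['0'].shape`, and the Keras kernel shapes as `layer.get_weights()[0].shape` for each layer that has weights.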
