diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 685d5f0..6a02e09 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,58 +1,63 @@ -# Automatically build multi-architectural tagged container images and push them to DockerHub -# https://github.com/FNNDSC/cookiecutter-chrisapp/wiki/Automatic-Builds +# Continuous integration testing for ChRIS Plugin. +# https://github.com/FNNDSC/python-chrisapp-template/wiki/Continuous-Integration # -# - targeted platforms: x86_64, PowerPC64, ARM64 -# - master is built as fnndsc/pl-fastsurfer_inference:latest -# - tagged commits are built as fnndsc/pl-fastsurfer_inference: -# - tagged commits are also uploaded to chrisstore.co -# -# In order to use this workflow, see -# https://github.com/FNNDSC/cookiecutter-chrisapp/wiki/Automatic-Builds#steps-to-enable +# - on push and PR: run pytest +# - on push to main: build and push container images as ":latest" +# - on push to semver tag: build and push container image with tag and +# upload plugin description to https://chrisstore.co -name: ci +name: build on: push: - # we have to guess what the name of the default branch is - branches: [ master, main, trunk ] - tags: [ '**' ] + branches: [ main ] + tags: + - "v?[0-9]+.[0-9]+.[0-9]+*" pull_request: - branches: [ master, main, trunk ] + branches: [ main ] jobs: test: - runs-on: ubuntu-20.04 + name: Unit tests + if: false # delete this line to enable automatic testing + runs-on: ubuntu-22.04 steps: - - uses: actions/checkout@v2 - - name: build - run: docker build -t "${GITHUB_REPOSITORY,,}" . - - name: nose tests - run: docker run "${GITHUB_REPOSITORY,,}" nosetests + - uses: actions/checkout@v3 + - uses: docker/setup-buildx-action@v2 + - name: Cache Docker layers + uses: actions/cache@v3 + with: + path: /tmp/.buildx-cache + key: ${{ runner.os }}-buildx-${{ github.sha }} + restore-keys: | + ${{ runner.os }}-buildx- + - name: Build + uses: docker/build-push-action@v3 + with: + build-args: extras_require=dev + context: . + load: true + push: false + tags: "localhost/local/app:dev" + cache-from: type=local,src=/tmp/.buildx-cache + cache-to: type=local,dest=/tmp/.buildx-cache + - name: Run pytest + run: | + docker run -v "$GITHUB_WORKSPACE:/app:ro" -w /app localhost/local/app:dev \ + pytest -o cache_dir=/tmp/pytest - publish: + build: + name: Build if: github.event_name == 'push' || github.event_name == 'release' - runs-on: ubuntu-20.04 - - # we want to both push the build to DockerHub, but also - # keep a local copy so that we can run - # - # docker run fnndsc/pl-app app --json > App.json - # - # buildx currently does not support multiple output locations, - # neither can multi-architectural builds be loaded into docker. - # Here we use a local registry to cache the build. 
-    services:
-      registry:
-        image: registry:2
-        ports:
-          - 5000:5000
+    # needs: [ test ]  # uncomment to require passing tests
+    runs-on: ubuntu-22.04
     steps:
       - name: Get git tag
         id: git_info
         if: startsWith(github.ref, 'refs/tags/')
-        run: echo "::set-output name=tag::${GITHUB_REF##*/}"
-      - name: Decide image tag name
+        run: echo "tag=${GITHUB_REF##*/}" >> $GITHUB_OUTPUT
+      - name: Get project info
         id: determine
         env:
           git_tag: ${{ steps.git_info.outputs.tag }}
@@ -60,51 +65,101 @@ jobs:
           repo="${GITHUB_REPOSITORY,,}"  # to lower case
           # if build triggered by tag, use tag name
           tag="${git_tag:-latest}"
+
+          # if tag is a version number prefixed by 'v', remove the 'v'
+          if [[ "$tag" =~ ^v[0-9].* ]]; then
+            tag="${tag:1}"
+          fi
+
           dock_image=$repo:$tag
           echo $dock_image
-          echo "::set-output name=dock_image::$dock_image"
-          echo "::set-output name=repo::$repo"
-      - uses: actions/checkout@v2
+          echo "dock_image=$dock_image" >> $GITHUB_OUTPUT
+          echo "repo=$repo" >> $GITHUB_OUTPUT

-      # QEMU is for emulating non-x86_64 platforms
-      - uses: docker/setup-qemu-action@v1
-      # buildx is the next-generation docker image builder
-      - uses: docker/setup-buildx-action@v1
+      - uses: actions/checkout@v3
+      # QEMU is used for non-x86_64 builds
+      - uses: docker/setup-qemu-action@v2
+      # buildx adds additional features to docker build
+      - uses: docker/setup-buildx-action@v2
         with:
           driver-opts: network=host
-      # save some time during rebuilds
+      # cache slightly improves rebuild time
      - name: Cache Docker layers
-        uses: actions/cache@v2
+        uses: actions/cache@v3
        with:
          path: /tmp/.buildx-cache
          key: ${{ runner.os }}-buildx-${{ github.sha }}
          restore-keys: |
            ${{ runner.os }}-buildx-
+
+      # Here, we want to do the docker build twice:
+      # the first build is loaded into the local Docker daemon for testing,
+      # the second build pushes the image to Docker Hub and ghcr.io.
+      - name: Build (local only)
+        uses: docker/build-push-action@v3
+        id: docker_build
+        with:
+          context: .
+          file: ./Dockerfile
+          tags: localhost/${{ steps.determine.outputs.dock_image }}
+          load: true
+          cache-from: type=local,src=/tmp/.buildx-cache
+          cache-to: type=local,dest=/tmp/.buildx-cache
+      # If you have directories called examples/incoming/ and examples/outgoing/, then
+      # this step runs your ChRIS plugin with no extra parameters and asserts that it
+      # creates all the expected output files. File contents are not compared.
+      - name: Run examples
+        id: run_examples
+        run: |
+          if ! [ -d 'examples/incoming/' ] || ! [ -d 'examples/outgoing/' ]; then
+            echo "No examples."
+            exit 0
+          fi
+
+          dock_image=localhost/${{ steps.determine.outputs.dock_image }}
+          output_dir=$(mktemp -d)
+          cmd=$(docker image inspect -f '{{ (index .Config.Cmd 0) }}' $dock_image)
+          docker run --rm -u "$(id -u):$(id -g)" \
+            -v "$PWD/examples/incoming:/incoming:ro" \
+            -v "$output_dir:/outgoing:rw" \
+            $dock_image $cmd /incoming /outgoing
+
+          for expected_file in $(find examples/outgoing -type f); do
+            fname="${expected_file##*/}"
+            out_path="$output_dir/$fname"
+            printf "Checking output %s exists..." "$out_path"
+            if [ -f "$out_path" ]; then
+              echo "ok"
+            else
+              echo "not found"
+              exit 1
+            fi
+          done
+
       - name: Login to DockerHub
         id: dockerhub_login
-        uses: docker/login-action@v1
+        uses: docker/login-action@v2
         with:
           username: ${{ secrets.DOCKERHUB_USERNAME }}
           password: ${{ secrets.DOCKERHUB_PASSWORD }}
-
+
       - name: Login to GitHub Container Registry
-        uses: docker/login-action@v1
+        uses: docker/login-action@v2
         with:
           registry: ghcr.io
           username: ${{ github.repository_owner }}
           password: ${{ secrets.GITHUB_TOKEN }}
-
       - name: Build and push
-        uses: docker/build-push-action@v2
-        id: docker_build
+        uses: docker/build-push-action@v3
+        if: github.event_name == 'push' || github.event_name == 'release'
        with:
          context: .
          file: ./Dockerfile
          tags: |
-            ${{ steps.determine.outputs.dock_image }}
-            localhost:5000/${{ steps.determine.outputs.dock_image }}
+            docker.io/${{ steps.determine.outputs.dock_image }}
+            ghcr.io/${{ steps.determine.outputs.dock_image }}
-          platforms: linux/amd64
+          # if non-x86_64 architectures are supported, add them here
+          platforms: linux/amd64  #,linux/arm64,linux/ppc64le
          push: true
          cache-from: type=local,src=/tmp/.buildx-cache
          cache-to: type=local,dest=/tmp/.buildx-cache
@@ -114,49 +169,16 @@ jobs:
        run: |
          repo=${{ steps.determine.outputs.repo }}
          dock_image=${{ steps.determine.outputs.dock_image }}
-          docker pull localhost:5000/$dock_image
-          docker tag localhost:5000/$dock_image $dock_image
-          script=$(docker inspect --format '{{ (index .Config.Cmd 0) }}' $dock_image)
-          json="$(docker run --rm $dock_image $script --json)"
-          jq <<< "$json"  # pretty print in log
-          echo "::set-output name=json::$json"
-          echo "::set-output name=title::$(jq -r '.title' <<< "$json")"
+          docker run --rm localhost/$dock_image chris_plugin_info > /tmp/description.json
+          jq < /tmp/description.json  # pretty print in log
+          echo "title=$(jq -r '.title' < /tmp/description.json)" >> $GITHUB_OUTPUT
+
      - name: Update DockerHub description
-        uses: peter-evans/dockerhub-description@v2
+        uses: peter-evans/dockerhub-description@v3
        continue-on-error: true  # it is not crucial that this works
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_PASSWORD }}
          short-description: ${{ steps.pluginmeta.outputs.title }}
-          readme-filepath: ./README.rst
-          repository: ${{ steps.determine.outputs.repo }}
-
-      - name: Upload to ChRIS Store
-        if: "!endsWith(steps.determine.outputs.dock_image, ':latest')"
-        run: |
-          dock_image=${{ steps.determine.outputs.dock_image }}
-          plname="$(sed 's/^.*\///' <<< $GITHUB_REPOSITORY)" && echo "name=$plname"
-          descriptor_file=$(mktemp --suffix .json)
-          cat > $descriptor_file << ENDOFPLUGINJSONDESCRIPTION
-          ${{ steps.pluginmeta.outputs.json }}
-          ENDOFPLUGINJSONDESCRIPTION
-          res=$(
-            curl -s -u "${{ secrets.CHRIS_STORE_USER }}" "https://chrisstore.co/api/v1/plugins/" \
-              -H 'Accept:application/vnd.collection+json' \
-              -F "name=$plname" \
-              -F "dock_image=$dock_image"  \
-              -F "descriptor_file=@$descriptor_file" \
-              -F "public_repo=https://github.com/${{ github.repository }}"
-          )
-          success=$?
- echo "::debug::$res" - if [ "$success" = "0" ]; then - href="$(jq -r '.collection.items[0].href' <<< "$res")" - echo $href - echo "::set-output name=pluginurl::$href" - else - echo "::error ::Failed upload to ChRIS Store" - echo "$res" - exit $success - fi - + readme-filepath: ./README.md + repository: ${{ steps.determine.outputs.repo }} \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 5f6dd96..e37502a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -21,10 +21,10 @@ # docker run -ti -e HOST_IP=$(ip route | grep -v docker | awk '{if(NF==11) print $9}') --entrypoint /bin/bash local/pl-fastsurfer_inference # -FROM tensorflow/tensorflow:latest-gpu-py3 +FROM deepmi/fastsurfer:gpu-v1.1.1 LABEL maintainer="Martin Reuter (FastSurfer), Sandip Samal (FNNDSC) (sandip.samal@childrens.harvard.edu) " -WORKDIR /usr/local/src +WORKDIR /fastsurfer COPY requirements.txt . RUN pip install -r requirements.txt @@ -32,4 +32,4 @@ RUN pip install -r requirements.txt COPY . . RUN pip install . -CMD ["fastsurfer_inference", "--help"] +ENTRYPOINT [""] diff --git a/checkpoints/Axial_Weights_FastSurferCNN/ckpts/Epoch_30_training_state.pkl b/checkpoints/Axial_Weights_FastSurferCNN/ckpts/Epoch_30_training_state.pkl deleted file mode 100644 index 1d42158..0000000 Binary files a/checkpoints/Axial_Weights_FastSurferCNN/ckpts/Epoch_30_training_state.pkl and /dev/null differ diff --git a/checkpoints/Coronal_Weights_FastSurferCNN/ckpts/Epoch_30_training_state.pkl b/checkpoints/Coronal_Weights_FastSurferCNN/ckpts/Epoch_30_training_state.pkl deleted file mode 100644 index e652daa..0000000 Binary files a/checkpoints/Coronal_Weights_FastSurferCNN/ckpts/Epoch_30_training_state.pkl and /dev/null differ diff --git a/checkpoints/Sagittal_Weights_FastSurferCNN/ckpts/Epoch_30_training_state.pkl b/checkpoints/Sagittal_Weights_FastSurferCNN/ckpts/Epoch_30_training_state.pkl deleted file mode 100644 index 7fe0346..0000000 Binary files a/checkpoints/Sagittal_Weights_FastSurferCNN/ckpts/Epoch_30_training_state.pkl and /dev/null differ diff --git a/data_loader/__init__.py b/data_loader/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/data_loader/augmentation.py b/data_loader/augmentation.py deleted file mode 100644 index c6f9276..0000000 --- a/data_loader/augmentation.py +++ /dev/null @@ -1,136 +0,0 @@ - -# Copyright 2019 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# IMPORTS -import numpy as np -import torch - - -## -# Transformations for evaluation -## -class ToTensorTest(object): - """ - Convert np.ndarrays in sample to Tensors. 
- """ - - def __call__(self, img): - img = img.astype(np.float32) - - # Normalize and clamp between 0 and 1 - img = np.clip(img / 255.0, a_min=0.0, a_max=1.0) - - # swap color axis because - # numpy image: H x W x C - # torch image: C X H X W - img = img.transpose((2, 0, 1)) - - return img - - -## -# Transformations for training -## -class ToTensor(object): - """ - Convert ndarrays in sample to Tensors. - """ - - def __call__(self, sample): - img, label, weight = sample['img'], sample['label'], sample['weight'] - - img = img.astype(np.float32) - - # Normalize image and clamp between 0 and 1 - img = np.clip(img / 255.0, a_min=0.0, a_max=1.0) - - # swap color axis because - # numpy image: H x W x C - # torch image: C X H X W - img = img.transpose((2, 0, 1)) - - return {'img': torch.from_numpy(img), 'label': label, 'weight': weight} - - -class AugmentationPadImage(object): - """ - Pad Image with either zero padding or reflection padding of img, label and weight - """ - - def __init__(self, pad_size=((16, 16), (16, 16)), pad_type="edge"): - - assert isinstance(pad_size, (int, tuple)) - - if isinstance(pad_size, int): - - # Do not pad along the channel dimension - self.pad_size_image = ((pad_size, pad_size), (pad_size, pad_size), (0, 0)) - self.pad_size_mask = ((pad_size, pad_size), (pad_size, pad_size)) - - else: - self.pad_size = pad_size - - self.pad_type = pad_type - - def __call__(self, sample): - img, label, weight = sample['img'], sample['label'], sample['weight'] - - img = np.pad(img, self.pad_size_image, self.pad_type) - label = np.pad(label, self.pad_size_mask, self.pad_type) - weight = np.pad(weight, self.pad_size_mask, self.pad_type) - - return {'img': img, 'label': label, 'weight': weight} - - -class AugmentationRandomCrop(object): - """ - Randomly Crop Image to given size - """ - - def __init__(self, output_size, crop_type='Random'): - - assert isinstance(output_size, (int, tuple)) - - if isinstance(output_size, int): - self.output_size = (output_size, output_size) - - else: - self.output_size = output_size - - self.crop_type = crop_type - - def __call__(self, sample): - img, label, weight = sample['img'], sample['label'], sample['weight'] - - h, w, _ = img.shape - - if self.crop_type == 'Center': - top = (h - self.output_size[0]) // 2 - left = (w - self.output_size[1]) // 2 - - else: - top = np.random.randint(0, h - self.output_size[0]) - left = np.random.randint(0, w - self.output_size[1]) - - bottom = top + self.output_size[0] - right = left + self.output_size[1] - - # print(img.shape) - img = img[top:bottom, left:right, :] - label = label[top:bottom, left:right] - weight = weight[top:bottom, left:right] - - return {'img': img, 'label': label, 'weight': weight} diff --git a/data_loader/conform.py b/data_loader/conform.py deleted file mode 100644 index 823f065..0000000 --- a/data_loader/conform.py +++ /dev/null @@ -1,300 +0,0 @@ - -# Copyright 2019 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -# IMPORTS -import optparse -import sys -import numpy as np -import nibabel as nib - -HELPTEXT = """ -Script to conform an MRI brain image to UCHAR, RAS orientation, and 1mm isotropic voxels - - -USAGE: -conform.py -i -o - - -Dependencies: - Python 3.5 - - Numpy - http://www.numpy.org - - Nibabel to read and write FreeSurfer data - http://nipy.org/nibabel/ - - -Original Author: Martin Reuter -Date: Jul-09-2019 - -""" - -h_input = 'path to input image' -h_output = 'path to ouput image' -h_order = 'order of interpolation (0=nearest,1=linear(default),2=quadratic,3=cubic)' - - -def options_parse(): - """ - Command line option parser - """ - parser = optparse.OptionParser(version='$Id: conform.py,v 1.0 2019/07/19 10:52:08 mreuter Exp $', - usage=HELPTEXT) - parser.add_option('--input', '-i', dest='input', help=h_input) - parser.add_option('--output', '-o', dest='output', help=h_output) - parser.add_option('--order', dest='order', help=h_order, type="int", default=1) - (fin_options, args) = parser.parse_args() - if fin_options.input is None or fin_options.output is None: - sys.exit('ERROR: Please specify input and output images') - return fin_options - - -def map_image(img, out_affine, out_shape, ras2ras=np.array([[1.0, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]), - order=1): - """ - Function to map image to new voxel space (RAS orientation) - - :param nibabel.MGHImage img: the src 3D image with data and affine set - :param np.ndarray out_affine: trg image affine - :param np.ndarray out_shape: the trg shape information - :param np.ndarray ras2ras: ras2ras an additional maping that should be applied (default=id to just reslice) - :param int order: order of interpolation (0=nearest,1=linear(default),2=quadratic,3=cubic) - :return: mapped Image data array - """ - from scipy.ndimage import affine_transform - from numpy.linalg import inv - - # compute vox2vox from src to trg - vox2vox = inv(out_affine) @ ras2ras @ img.affine - - # here we apply the inverse vox2vox (to pull back the src info to the target image) - new_data = affine_transform(img.get_data(), inv(vox2vox), output_shape=out_shape, order=order) - return new_data - - -def getscale(data, dst_min, dst_max, f_low=0.0, f_high=0.999): - """ - Function to get offset and scale of image intensities to robustly rescale to range dst_min..dst_max. - Equivalent to how mri_convert conforms images. 
- - :param np.ndarray data: Image data (intensity values) - :param float dst_min: future minimal intensity value - :param float dst_max: future maximal intensity value - :param f_low: robust cropping at low end (0.0 no cropping) - :param f_high: robust cropping at higher end (0.999 crop one thousandths of high intensity voxels) - :return: returns (adjusted) src_min and scale factor - """ - # get min and max from source - src_min = np.min(data) - src_max = np.max(data) - - if src_min < 0.0: - sys.exit('ERROR: Min value in input is below 0.0!') - - print("Input: min: " + format(src_min) + " max: " + format(src_max)) - - if f_low == 0.0 and f_high == 1.0: - return src_min, 1.0 - - # compute non-zeros and total vox num - nz = (np.abs(data) >= 1e-15).sum() - voxnum = data.shape[0] * data.shape[1] * data.shape[2] - - # compute histogram - histosize = 1000 - bin_size = (src_max - src_min) / histosize - hist, bin_edges = np.histogram(data, histosize) - - # compute cummulative sum - cs = np.concatenate(([0], np.cumsum(hist))) - - # get lower limit - nth = int(f_low * voxnum) - idx = np.where(cs < nth) - - if len(idx[0]) > 0: - idx = idx[0][-1] + 1 - - else: - idx = 0 - - src_min = idx * bin_size + src_min - - # print("bin min: "+format(idx)+" nth: "+format(nth)+" passed: "+format(cs[idx])+"\n") - # get upper limit - nth = voxnum - int((1.0 - f_high) * nz) - idx = np.where(cs >= nth) - - if len(idx[0]) > 0: - idx = idx[0][0] - 2 - - else: - print('ERROR: rescale upper bound not found') - - src_max = idx * bin_size + src_min - # print("bin max: "+format(idx)+" nth: "+format(nth)+" passed: "+format(voxnum-cs[idx])+"\n") - - # scale - if src_min == src_max: - scale = 1.0 - - else: - scale = (dst_max - dst_min) / (src_max - src_min) - - print("rescale: min: " + format(src_min) + " max: " + format(src_max) + " scale: " + format(scale)) - - return src_min, scale - - -def scalecrop(data, dst_min, dst_max, src_min, scale): - """ - Function to crop the intensity ranges to specific min and max values - - :param np.ndarray data: Image data (intensity values) - :param float dst_min: future minimal intensity value - :param float dst_max: future maximal intensity value - :param float src_min: minimal value to consider from source (crops below) - :param float scale: scale value by which source will be shifted - :return: scaled Image data array - """ - data_new = dst_min + scale * (data - src_min) - - # clip - data_new = np.clip(data_new, dst_min, dst_max) - print("Output: min: " + format(data_new.min()) + " max: " + format(data_new.max())) - - return data_new - - -def rescale(data, dst_min, dst_max, f_low=0.0, f_high=0.999): - """ - Function to rescale image intensity values (0-255) - - :param np.ndarray data: Image data (intensity values) - :param float dst_min: future minimal intensity value - :param float dst_max: future maximal intensity value - :param f_low: robust cropping at low end (0.0 no cropping) - :param f_high: robust cropping at higher end (0.999 crop one thousandths of high intensity voxels) - :return: returns scaled Image data array - """ - src_min, scale = getscale(data, dst_min, dst_max, f_low, f_high) - data_new = scalecrop(data, dst_min, dst_max, src_min, scale) - return data_new - - -def conform(img, order=1): - """ - Python version of mri_convert -c, which turns image intensity values into UCHAR, reslices images to standard position, fills up - slices to standard 256x256x256 format and enforces 1 mm isotropic voxel sizes. 
- - Difference to mri_convert -c is that we first interpolate (float image), and then rescale to uchar. mri_convert is - doing it the other way. However, we compute the scale factor from the input to be more similar again - - :param nibabel.MGHImage img: loaded source image - :param int order: interpolation order (0=nearest,1=linear(default),2=quadratic,3=cubic) - :return:nibabel.MGHImage new_img: conformed image - """ - from nibabel.freesurfer.mghformat import MGHHeader - - cwidth = 256 - csize = 1 - h1 = MGHHeader.from_header(img.header) # may copy some parameters if input was MGH format - - h1.set_data_shape([cwidth, cwidth, cwidth, 1]) - h1.set_zooms([csize, csize, csize]) - h1['Mdc'] = [[-1, 0, 0], [0, 0, -1], [0, 1, 0]] - h1['fov'] = cwidth - h1['Pxyz_c'] = img.affine.dot(np.hstack((np.array(img.shape[:3]) / 2.0, [1])))[:3] - - # from_header does not compute Pxyz_c (and probably others) when importing from nii - # Pxyz is the center of the image in world coords - - # get scale for conversion on original input before mapping to be more similar to mri_convert - src_min, scale = getscale(img.get_data(), 0, 255) - - mapped_data = map_image(img, h1.get_affine(), h1.get_data_shape(), order=order) - # print("max: "+format(np.max(mapped_data))) - - if not img.get_data_dtype() == np.dtype(np.uint8): - - if np.max(mapped_data) > 255: - mapped_data = scalecrop(mapped_data, 0, 255, src_min, scale) - - new_data = np.uint8(np.rint(mapped_data)) - new_img = nib.MGHImage(new_data, h1.get_affine(), h1) - - # make sure we store uchar - new_img.set_data_dtype(np.uint8) - - return new_img - - -def is_conform(img, eps=1e-06): - """ - Function to check if an image is already conformed or not (Dimensions: 256x256x256, Voxel size: 1x1x1, and - LIA orientation. - - :param nibabel.MGHImage img: Loaded source image - :param float eps: allowed deviation from zero for LIA orientation check (default 1e-06). - Small inaccuracies can occur through the inversion operation. Already conformed images are - thus sometimes not correctly recognized. The epsilon accounts for these small shifts. - :return: True if image is already conformed, False otherwise - """ - ishape = img.shape - - if len(ishape) > 3 and ishape[3] != 1: - sys.exit('ERROR: Multiple input frames (' + format(img.shape[3]) + ') not supported!') - - # check dimensions - if ishape[0] != 256 or ishape[1] != 256 or ishape[2] != 256: - return False - - # check voxel size - izoom = img.header.get_zooms() - if izoom[0] != 1.0 or izoom[1] != 1.0 or izoom[2] != 1.0: - return False - - # check orientation LIA - iaffine = img.affine[0:3, 0:3] + [[1, 0, 0], [0, 0, -1], [0, 1, 0]] - - if np.max(np.abs(iaffine)) > 0.0 + eps: - return False - - return True - - -if __name__ == "__main__": - # Command Line options are error checking done here - options = options_parse() - - print("Reading input: {} ...".format(options.input)) - image = nib.load(options.input) - - if len(image.shape) > 3 and image.shape[3] != 1: - sys.exit('ERROR: Multiple input frames (' + format(image.shape[3]) + ') not supported!') - - if is_conform(image): - sys.exit("Input " + format(options.input) + " is already conform! 
No output created.\n") - - new_image = conform(image, options.order) - print ("Writing conformed image: {}".format(options.output)) - - nib.save(new_image, options.output) - - sys.exit(0) - - diff --git a/data_loader/load_neuroimaging_data.py b/data_loader/load_neuroimaging_data.py deleted file mode 100644 index d0268df..0000000 --- a/data_loader/load_neuroimaging_data.py +++ /dev/null @@ -1,469 +0,0 @@ - -# Copyright 2019 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# IMPORTS -import nibabel as nib -import numpy as np -import h5py - -from skimage.measure import label -from torch.utils.data.dataset import Dataset -from .conform import is_conform, conform - -## -# Helper Functions -## - - -# Conform an MRI brain image to UCHAR, RAS orientation, and 1mm isotropic voxels -def load_and_conform_image(img_filename, interpol=1): - """ - Function to load MRI image and conform it to UCHAR, RAS orientation and 1mm isotropic voxels size - (if it does not already have this format) - :param str img_filename: path and name of volume to read - :param int interpol: interpolation order for image conformation (0=nearest,1=linear(default),2=quadratic,3=cubic) - :return: - """ - orig = nib.load(img_filename) - - if not is_conform(orig): - print('Conforming image to UCHAR, RAS orientation, and 1mm isotropic voxels') - orig = conform(orig, interpol) - - # Collect header and affine information - header_info = orig.header - affine_info = orig.affine - orig = np.asarray(orig.get_data(), dtype=np.uint8) - - return header_info, affine_info, orig - - -# Transformation for mapping -def transform_axial(vol, coronal2axial=True): - """ - Function to transform volume into Axial axis and back - :param np.ndarray vol: image volume to transform - :param bool coronal2axial: transform from coronal to axial = True (default), - transform from axial to coronal = False - :return: - """ - if coronal2axial: - return np.moveaxis(vol, [0, 1, 2], [1, 2, 0]) - else: - return np.moveaxis(vol, [0, 1, 2], [2, 0, 1]) - - -def transform_sagittal(vol, coronal2sagittal=True): - """ - Function to transform volume into Sagittal axis and back - :param np.ndarray vol: image volume to transform - :param bool coronal2sagittal: transform from coronal to sagittal = True (default), - transform from sagittal to coronal = False - :return: - """ - if coronal2sagittal: - return np.moveaxis(vol, [0, 1, 2], [2, 1, 0]) - else: - return np.moveaxis(vol, [0, 1, 2], [2, 1, 0]) - - -# Thick slice generator (for eval) and blank slices filter (for training) -def get_thick_slices(img_data, slice_thickness=3): - """ - Function to extract thick slices from the image - (feed slice_thickness preceeding and suceeding slices to network, - label only middle one) - :param np.ndarray img_data: 3D MRI image read in with nibabel - :param int slice_thickness: number of slices to stack on top and below slice of interest (default=3) - :return: - """ - h, w, d = img_data.shape - img_data_pad = 
np.expand_dims(np.pad(img_data, ((0, 0), (0, 0), (slice_thickness, slice_thickness)), mode='edge'), - axis=3) - img_data_thick = np.ndarray((h, w, d, 0), dtype=np.uint8) - - for slice_idx in range(2 * slice_thickness + 1): - img_data_thick = np.append(img_data_thick, img_data_pad[:, :, slice_idx:d + slice_idx, :], axis=3) - - return img_data_thick - - -def filter_blank_slices_thick(img_vol, label_vol, weight_vol, threshold=50): - """ - Function to filter blank slices from the volume using the label volume - :param np.ndarray img_vol: orig image volume - :param np.ndarray label_vol: label images (ground truth) - :param np.ndarray weight_vol: weight corresponding to labels - :param int threshold: threshold for number of pixels needed to keep slice (below = dropped) - :return: - """ - # Get indices of all slices with more than threshold labels/pixels - select_slices = (np.sum(label_vol, axis=(0, 1)) > threshold) - - # Retain only slices with more than threshold labels/pixels - img_vol = img_vol[:, :, select_slices, :] - label_vol = label_vol[:, :, select_slices] - weight_vol = weight_vol[:, :, select_slices] - - return img_vol, label_vol, weight_vol - - -# weight map generator -def create_weight_mask(mapped_aseg, max_weight=5, max_edge_weight=5): - """ - Function to create weighted mask - with median frequency balancing and edge-weighting - :param mapped_aseg: - :param max_weight: - :param max_edge_weight: - :return: - """ - unique, counts = np.unique(mapped_aseg, return_counts=True) - - # Median Frequency Balancing - class_wise_weights = np.median(counts) / counts - class_wise_weights[class_wise_weights > max_weight] = max_weight - (h, w, d) = mapped_aseg.shape - - weights_mask = np.reshape(class_wise_weights[mapped_aseg.ravel()], (h, w, d)) - - # Gradient Weighting - (gx, gy, gz) = np.gradient(mapped_aseg) - grad_weight = max_edge_weight * np.asarray(np.power(np.power(gx, 2) + np.power(gy, 2) + np.power(gz, 2), 0.5) > 0, - dtype='float') - - weights_mask += grad_weight - - return weights_mask - - -# Label mapping functions (to aparc (eval) and to label (train)) -def map_label2aparc_aseg(mapped_aseg): - """ - Function to perform look-up table mapping from label space to aparc.DKTatlas+aseg space - :param np.ndarray mapped_aseg: label space segmentation (aparc.DKTatlas + aseg) - :return: - """ - aseg = np.zeros_like(mapped_aseg) - labels = np.array([0, 2, 4, 5, 7, 8, 10, 11, 12, 13, 14, - 15, 16, 17, 18, 24, 26, 28, 31, 41, 43, 44, - 46, 47, 49, 50, 51, 52, 53, 54, 58, 60, 63, - 77, 1002, 1003, 1005, 1006, 1007, 1008, 1009, 1010, 1011, - 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, - 1023, 1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1034, 1035, - 2002, 2005, 2010, 2012, 2013, 2014, 2016, 2017, 2021, 2022, 2023, - 2024, 2025, 2028]) - h, w, d = aseg.shape - - aseg = labels[mapped_aseg.ravel()] - - aseg = aseg.reshape((h, w, d)) - - return aseg - - -def map_aparc_aseg2label(aseg, aseg_nocc=None): - """ - Function to perform look-up table mapping of aparc.DKTatlas+aseg.mgz data to label space - :param np.ndarray aseg: ground truth aparc+aseg - :param None/np.ndarray aseg_nocc: ground truth aseg without corpus callosum segmentation - :return: - """ - aseg_temp = aseg.copy() - aseg[aseg == 80] = 77 # Hypointensities Class - aseg[aseg == 85] = 0 # Optic Chiasma to BKG - aseg[aseg == 62] = 41 # Right Vessel to Right GM - aseg[aseg == 30] = 2 # Left Vessel to Left GM - aseg[aseg == 72] = 24 # 5th Ventricle to CSF - - # If corpus callosum is not removed yet, do it now - if 
aseg_nocc is not None: - cc_mask = (aseg >= 251) & (aseg <= 255) - aseg[cc_mask] = aseg_nocc[cc_mask] - - aseg[aseg == 3] = 0 # Map Remaining Cortical labels to background - aseg[aseg == 42] = 0 - - cortical_label_mask = (aseg >= 2000) & (aseg <= 2999) - aseg[cortical_label_mask] = aseg[cortical_label_mask] - 1000 - - # Preserve Cortical Labels - aseg[aseg_temp == 2014] = 2014 - aseg[aseg_temp == 2028] = 2028 - aseg[aseg_temp == 2012] = 2012 - aseg[aseg_temp == 2016] = 2016 - aseg[aseg_temp == 2002] = 2002 - aseg[aseg_temp == 2023] = 2023 - aseg[aseg_temp == 2017] = 2017 - aseg[aseg_temp == 2024] = 2024 - aseg[aseg_temp == 2010] = 2010 - aseg[aseg_temp == 2013] = 2013 - aseg[aseg_temp == 2025] = 2025 - aseg[aseg_temp == 2022] = 2022 - aseg[aseg_temp == 2021] = 2021 - aseg[aseg_temp == 2005] = 2005 - - labels = np.array([0, 2, 4, 5, 7, 8, 10, 11, 12, 13, 14, - 15, 16, 17, 18, 24, 26, 28, 31, 41, 43, 44, - 46, 47, 49, 50, 51, 52, 53, 54, 58, 60, 63, - 77, 1002, 1003, 1005, 1006, 1007, 1008, 1009, 1010, 1011, - 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, - 1023, 1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1034, 1035, - 2002, 2005, 2010, 2012, 2013, 2014, 2016, 2017, 2021, 2022, 2023, - 2024, 2025, 2028]) - - h, w, d = aseg.shape - lut_aseg = np.zeros(max(labels) + 1, dtype='int') - for idx, value in enumerate(labels): - lut_aseg[value] = idx - - # Remap Label Classes - Perform LUT Mapping - Coronal, Axial - - mapped_aseg = lut_aseg.ravel()[aseg.ravel()] - - mapped_aseg = mapped_aseg.reshape((h, w, d)) - - # Map Sagittal Labels - aseg[aseg == 2] = 41 - aseg[aseg == 3] = 42 - aseg[aseg == 4] = 43 - aseg[aseg == 5] = 44 - aseg[aseg == 7] = 46 - aseg[aseg == 8] = 47 - aseg[aseg == 10] = 49 - aseg[aseg == 11] = 50 - aseg[aseg == 12] = 51 - aseg[aseg == 13] = 52 - aseg[aseg == 17] = 53 - aseg[aseg == 18] = 54 - aseg[aseg == 26] = 58 - aseg[aseg == 28] = 60 - aseg[aseg == 31] = 63 - - cortical_label_mask = (aseg >= 2000) & (aseg <= 2999) - aseg[cortical_label_mask] = aseg[cortical_label_mask] - 1000 - - labels_sag = np.array([0, 14, 15, 16, 24, 41, 43, 44, 46, 47, 49, - 50, 51, 52, 53, 54, 58, 60, 63, 77, 1002, - 1003, 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1013, 1014, - 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024, 1025, - 1026, 1027, 1028, 1029, 1030, 1031, 1034, 1035]) - - h, w, d = aseg.shape - lut_aseg = np.zeros(max(labels_sag) + 1, dtype='int') - for idx, value in enumerate(labels_sag): - lut_aseg[value] = idx - - # Remap Label Classes - Perform LUT Mapping - Coronal, Axial - - mapped_aseg_sag = lut_aseg.ravel()[aseg.ravel()] - - mapped_aseg_sag = mapped_aseg_sag.reshape((h, w, d)) - - return mapped_aseg, mapped_aseg_sag - - -def sagittal_coronal_remap_lookup(x): - """ - Dictionary mapping to convert left labels to corresponding right labels for aseg - :param int x: label to look up - :return: - """ - return { - 2: 41, - 3: 42, - 4: 43, - 5: 44, - 7: 46, - 8: 47, - 10: 49, - 11: 50, - 12: 51, - 13: 52, - 17: 53, - 18: 54, - 26: 58, - 28: 60, - 31: 63, - }[x] - - -def map_prediction_sagittal2full(prediction_sag, num_classes=79): - """ - Function to remap the prediction on the sagittal network to full label space used by coronal and axial networks - (full aparc.DKTatlas+aseg.mgz) - :param prediction_sag: sagittal prediction (labels) - :param int num_classes: number of classes (96 for full classes, 79 for hemi split) - :return: Remapped prediction - """ - if num_classes == 96: - idx_list = np.asarray([0, 5, 6, 7, 8, 9, 10, 11, 12, 13, 1, 2, 3, 14, 
15, 4, 16, - 17, 18, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, - 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, - 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 20, 21, 22, - 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, - 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50], dtype=np.int16) - - else: - idx_list = np.asarray([0, 5, 6, 7, 8, 9, 10, 11, 12, 13, 1, 2, 3, 14, 15, 4, 16, - 17, 18, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, - 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, - 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 20, 22, 27, - 29, 30, 31, 33, 34, 38, 39, 40, 41, 42, 45], dtype=np.int16) - - prediction_full = prediction_sag[:, idx_list, :, :] - return prediction_full - - -# Clean up and class separation -def bbox_3d(img): - """ - Function to extract the three-dimensional bounding box coordinates. - :param np.ndarray img: mri image - :return: - """ - - r = np.any(img, axis=(1, 2)) - c = np.any(img, axis=(0, 2)) - z = np.any(img, axis=(0, 1)) - - rmin, rmax = np.where(r)[0][[0, -1]] - cmin, cmax = np.where(c)[0][[0, -1]] - zmin, zmax = np.where(z)[0][[0, -1]] - - return rmin, rmax, cmin, cmax, zmin, zmax - - -def get_largest_cc(segmentation): - """ - Function to find largest connected component of segmentation. - :param np.ndarray segmentation: segmentation - :return: - """ - labels = label(segmentation, connectivity=3, background=0) - - bincount = np.bincount(labels.flat) - background = np.argmax(bincount) - bincount[background] = -1 - - largest_cc = labels == np.argmax(bincount) - - return largest_cc - - -# Class Operator for image loading (orig only) -class OrigDataThickSlices(Dataset): - """ - Class to load a given image and segmentation and prepare it - for network training. - """ - def __init__(self, img_filename, orig, plane='Axial', slice_thickness=3, transforms=None): - - try: - self.img_filename = img_filename - self.plane = plane - self.slice_thickness = slice_thickness - - # Transform Data as needed - if plane == 'Sagittal': - orig = transform_sagittal(orig) - print('Loading Sagittal') - - elif plane == 'Axial': - orig = transform_axial(orig) - print('Loading Axial') - - else: - print('Loading Coronal.') - - # Create Thick Slices - orig_thick = get_thick_slices(orig, self.slice_thickness) - - # Make 4D - orig_thick = np.transpose(orig_thick, (2, 0, 1, 3)) - self.images = orig_thick - - self.count = self.images.shape[0] - - self.transforms = transforms - - print("Successfully loaded Image from {}".format(img_filename)) - - except Exception as e: - print("Loading failed. 
{}".format(e))
-
-    def __getitem__(self, index):
-
-        img = self.images[index]
-
-        if self.transforms is not None:
-            img = self.transforms(img)
-
-        return {'image': img}
-
-    def __len__(self):
-        return self.count
-
-
-##
-# Dataset loading (for training)
-##
-
-# Operator to load hdf5-file for training
-class AsegDatasetWithAugmentation(Dataset):
-    """
-    Class for loading aseg file with augmentations (transforms)
-    """
-    def __init__(self, params, transforms=None):
-
-        # Load the h5 file and save it to the dataset
-        try:
-            self.params = params
-
-            # Open file in reading mode
-            with h5py.File(self.params['dataset_name'], "r") as hf:
-                self.images = np.array(hf.get('orig_dataset'))
-                self.labels = np.array(hf.get('aseg_dataset'))
-                self.weights = np.array(hf.get('weight_dataset'))
-                self.subjects = np.array(hf.get("subject"))
-
-            self.count = self.images.shape[0]
-            self.transforms = transforms
-
-            print("Successfully loaded {} with plane: {}".format(params["dataset_name"], params["plane"]))
-
-        except Exception as e:
-            print("Loading failed: {}".format(e))
-
-    def get_subject_names(self):
-        return self.subjects
-
-    def __getitem__(self, index):
-
-        img = self.images[index]
-        label = self.labels[index]
-        weight = self.weights[index]
-
-        if self.transforms is not None:
-            tx_sample = self.transforms({'img': img, 'label': label, 'weight': weight})
-            img = tx_sample['img']
-            label = tx_sample['label']
-            weight = tx_sample['weight']
-
-        return {'image': img, 'label': label, 'weight': weight}
-
-    def __len__(self):
-        return self.count
-
diff --git a/fastsurfer_inference.py b/fastsurfer_inference.py
new file mode 100644
index 0000000..7e7a081
--- /dev/null
+++ b/fastsurfer_inference.py
@@ -0,0 +1,207 @@
+#!/usr/bin/env python
+
+from pathlib import Path
+from argparse import ArgumentParser, Namespace, ArgumentDefaultsHelpFormatter
+from chris_plugin import chris_plugin, PathMapper
+from loguru import logger
+from pftag import pftag
+from pflog import pflog
+from datetime import datetime
+import sys
+import os
+import subprocess
+LOG = logger.debug
+
+logger_format = (
+    "{time:YYYY-MM-DD HH:mm:ss} │ "
+    "{level: <5} │ "
+    "{name: >28}::"
+    "{function: <30} @"
+    "{line: <4} ║ "
+    "{message}"
+)
+logger.remove()
+logger.opt(colors = True)
+logger.add(sys.stderr, format=logger_format)
+
+__version__ = '1.3.1'
+
+DISPLAY_TITLE = r"""
+ __          _                  __                _        __
+ / _|        | |                / _|              (_)      / _|
+| |_ __ _ ___| |_ ___ _   _ _ __| |_ ___ _ __      _ _ __ | |_ ___ _ __ ___ _ __   ___ ___
+| _/ _` / __| __/ __| | | | '__| _/ _ \ '__|     | | '_ \| _/ _ \ '__/ _ \ '_ \ / __/ _ \
+| || (_| \__ \ |_\__ \ |_| | |  | || __/ |       | | | | | || __/ | |  __/ | | | (_|  __/
+|_| \__,_|___/\__|___/\__,_|_|  |_| \___|_|      |_|_| |_|_| \___|_|  \___|_| |_|\___\___|
+                                          ______
+                                         |______|
+""" + "\t\t -- version " + __version__ + " --\n\n"
+
+
+parser = ArgumentParser(description='A ChRIS plugin to efficiently perform cortical '
+                                    'parcellation and segmentation on raw brain MRI images',
+                        formatter_class=ArgumentDefaultsHelpFormatter)
+# Required options
+# 1. Directory information (where to read from, where to write to)
+parser.add_argument('--subjectDir', "--sd",
+                    dest='subjectDir',
+                    type=str,
+                    help="directory (relative to <inputDir>) of subjects to process",
+                    default="")
+
+# 2. Options for the MRI volumes
+# (name of in- and output, order of interpolation if not conformed)
+parser.add_argument('--iname', '--t1',
+                    type=str,
+                    dest='iname',
+                    help='name of the input (raw) file to process (default: brain.mgz)',
+                    default='brain.mgz')
+parser.add_argument('--out_name', '--seg',
+                    dest='oname',
+                    type=str,
+                    default='aparc.DKTatlas+aseg.deep.mgz',
+                    help='name of the output segmented file')
+parser.add_argument('--order',
+                    dest='order',
+                    type=int,
+                    default=1,
+                    help="interpolation order")
+
+# 3. Options for log-file and search-tag
+parser.add_argument('--subject',
+                    dest='subject',
+                    type=str,
+                    default="*",
+                    help='subject(s) to process. This expression is globbed.')
+parser.add_argument('--log',
+                    dest='logfile',
+                    type=str,
+                    help='name of logfile (default: deep-seg.log)',
+                    default='deep-seg.log')
+
+# 4. Pre-trained weights -- NB NB NB -- currently (Jan 2021) these CANNOT
+# be set by an end user. These weight files are RELATIVE/INTERNAL to the
+# container
+parser.add_argument('--network_sagittal_path',
+                    dest='network_sagittal_path',
+                    type=str,
+                    help="path to pre-trained sagittal network weights",
+                    default='./checkpoints/Sagittal_Weights_FastSurferCNN/ckpts/Epoch_30_training_state.pkl')
+parser.add_argument('--network_coronal_path',
+                    dest='network_coronal_path',
+                    type=str,
+                    help="path to pre-trained coronal network weights",
+                    default='./checkpoints/Coronal_Weights_FastSurferCNN/ckpts/Epoch_30_training_state.pkl')
+parser.add_argument('--network_axial_path',
+                    dest='network_axial_path',
+                    type=str,
+                    help="path to pre-trained axial network weights",
+                    default='./checkpoints/Axial_Weights_FastSurferCNN/ckpts/Epoch_30_training_state.pkl')
+
+# 5. Clean up and GPU/CPU options (disable cuda, change batchsize)
+# boolean options use store_true so they work as bare flags
+# (argparse's type=bool treats any non-empty string, including "False", as True)
+parser.add_argument('--clean',
+                    dest='cleanup',
+                    action='store_true',
+                    default=True,
+                    help="if specified, clean up segmentation")
+parser.add_argument('--no_cuda',
+                    dest='no_cuda',
+                    action='store_true',
+                    default=False,
+                    help='if specified, do not use GPU')
+parser.add_argument('--batch_size',
+                    dest='batch_size',
+                    type=int,
+                    default=8,
+                    help="batch size for inference (default: 8)")
+parser.add_argument('--simple_run',
+                    dest='simple_run',
+                    action='store_true',
+                    default=False,
+                    help='simplified run: only analyze one subject')
+
+# Adding check to parallel processing, default = false
+parser.add_argument('--run_parallel',
+                    dest='run_parallel',
+                    action='store_true',
+                    default=False,
+                    help='if specified, allow execution on multiple GPUs')
+
+parser.add_argument('--copyInputFiles',
+                    dest='copyInputFiles',
+                    type=str,
+                    default="",
+                    help="if specified, copy i/p files matching the input pattern to o/p dir")
+parser.add_argument('-f', '--fileFilter', default='dcm', type=str,
+                    help='input file filter glob')
+parser.add_argument('-V', '--version', action='version',
+                    version=f'%(prog)s {__version__}')
+parser.add_argument( '--pftelDB',
+                    dest    = 'pftelDB',
+                    default = '',
+                    type    = str,
+                    help    = 'optional pftel server DB path')
+
+
+def preamble_show(options) -> None:
+    """
+    Just show some preamble "noise" in the output terminal
+    """
+
+    LOG(DISPLAY_TITLE)
+
+    LOG("plugin arguments...")
+    for k, v in options.__dict__.items():
+        LOG("%25s: [%s]" % (k, v))
+    LOG("")
+
+    LOG("base environment...")
+    for k, v in os.environ.items():
+        LOG("%25s: [%s]" % (k, v))
+    LOG("")
+
+
+# The main function of this *ChRIS* plugin is denoted by this ``@chris_plugin`` "decorator."
+# Some metadata about the plugin is specified here. There is more metadata specified in setup.py.
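+#
+# As an illustrative sketch only (the console-script name below assumes the usual
+# setup.py entry point for this file, which is not shown in this diff), a ds-type
+# ChRIS plugin is run with an input directory, an output directory, and any of the
+# optional flags defined above:
+#
+#     fastsurfer_inference --iname brain.mgz --subject 'sub-*' incoming/ outgoing/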
+#
+# documentation: https://fnndsc.github.io/chris_plugin/chris_plugin.html#chris_plugin
+@chris_plugin(
+    parser=parser,
+    title='An app to efficiently perform cortical parcellation and segmentation on raw brain MRI images',
+    category='',                 # ref. https://chrisstore.co/plugins
+    min_memory_limit='2Gi',      # supported units: Mi, Gi
+    min_cpu_limit='4000m',       # millicores, e.g. "1000m" = 1 CPU core
+    min_gpu_limit=0              # set min_gpu_limit=1 to enable GPU
+)
+@pflog.tel_logTime(
+    event = 'fastsurfer_inference',
+    log   = 'Create segmentation using FastSurferCNN'
+)
+def main(options: Namespace, inputdir: Path, outputdir: Path):
+    """
+    *ChRIS* plugins usually have two positional arguments: an **input directory** containing
+    input files and an **output directory** where to write output files. Command-line arguments
+    are passed to this main method implicitly when ``main()`` is called below without parameters.
+
+    :param options: non-positional arguments parsed by the parser given to @chris_plugin
+    :param inputdir: directory containing (read-only) input files
+    :param outputdir: directory where to write output files
+    """
+    FS_SCRIPT = "/fastsurfer/run_fastsurfer.sh"
+    preamble_show(options)
+    mapper = PathMapper.file_mapper(inputdir, outputdir, glob=f"**/{options.iname}", fail_if_empty=False)
+    for input_file, output_file in mapper:
+        l_cli_params = [FS_SCRIPT]
+        l_cli_params.extend(["--sd", outputdir,
+                             "--t1", f"{input_file}",
+                             "--sid", f"{options.subject}",
+                             "--seg_only",
+                             "--parallel"])
+        LOG(f"Running {FS_SCRIPT} on input: {input_file.name}")
+        try:
+            subprocess.call(l_cli_params)
+        except Exception as ex:
+            # report the failure, then fall back to calling the
+            # script bare (which prints its usage message)
+            LOG(f"Error while running {FS_SCRIPT}: {ex}")
+            subprocess.call(FS_SCRIPT)
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/fastsurfer_inference/__init__.py b/fastsurfer_inference/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/fastsurfer_inference/__main__.py b/fastsurfer_inference/__main__.py
deleted file mode 100644
index 18d646f..0000000
--- a/fastsurfer_inference/__main__.py
+++ /dev/null
@@ -1,10 +0,0 @@
-from fastsurfer_inference.fastsurfer_inference import Fastsurfer_inference
-
-
-def main():
-    chris_app = Fastsurfer_inference()
-    chris_app.launch()
-
-
-if __name__ == "__main__":
-    main()
diff --git a/fastsurfer_inference/fastsurfer_inference.py b/fastsurfer_inference/fastsurfer_inference.py
deleted file mode 100644
index 174a3ac..0000000
--- a/fastsurfer_inference/fastsurfer_inference.py
+++ /dev/null
@@ -1,782 +0,0 @@
-#
-# fastsurfer_inference ds ChRIS plugin app
-#
-# (c) 2022 Fetal-Neonatal Neuroimaging & Developmental Science Center
-#                   Boston Children's Hospital
-#
-#              http://childrenshospital.org/FNNDSC/
-#                        dev@babyMRI.org
-#
-
-# IMPORTS
-import optparse
-import nibabel as nib
-import numpy as np
-import time
-import sys
-import glob
-import os
-import os.path as op
-import logging
-import torch
-import torch.nn as nn
-import shutil
-
-from torch.autograd import Variable
-from torch.utils.data.dataloader import DataLoader
-from torchvision import transforms, utils
-
-from scipy.ndimage.filters import median_filter, gaussian_filter
-from skimage.measure import label, regionprops
-from skimage.measure import label
-
-from collections import OrderedDict
-from os import makedirs
-
-from data_loader.load_neuroimaging_data import OrigDataThickSlices
-from data_loader.load_neuroimaging_data import map_label2aparc_aseg
-from data_loader.load_neuroimaging_data import map_prediction_sagittal2full
-from data_loader.load_neuroimaging_data import get_largest_cc
-from
data_loader.load_neuroimaging_data import load_and_conform_image - -from data_loader.augmentation import ToTensorTest - -from models.networks import FastSurferCNN - -sys.path.append(os.path.dirname(__file__)) - -# import the Chris app superclass -from chrisapp.base import ChrisApp - - -Gstr_title = r""" - __ _ __ _ __ - / _| | | / _| (_) / _| -| |_ __ _ ___| |_ ___ _ _ _ __| |_ ___ _ __ _ _ __ | |_ ___ _ __ ___ _ __ ___ ___ -| _/ _` / __| __/ __| | | | '__| _/ _ \ '__| | '_ \| _/ _ \ '__/ _ \ '_ \ / __/ _ \ -| || (_| \__ \ |_\__ \ |_| | | | || __/ | | | | | | || __/ | | __/ | | | (_| __/ -|_| \__,_|___/\__|___/\__,_|_| |_| \___|_| |_|_| |_|_| \___|_| \___|_| |_|\___\___| - ______ - |______| -""" - -Gstr_synopsis = """ - - NAME - - fastsurfer_inference.py - - SYNOPSIS - - python fastsurfer_inference.py \\ - [--subjectDir ] \\ - [--subject ] \\ - [--in_name ] \\ - [--out_name ] \\ - [--log ] \\ - [--clean] \\ - [--no_cuda] \\ - [--batch_size ] \\ - [-v ] [--verbosity ] \\ - [--version] \\ - [--man] \\ - [--meta] \\ - - - - BRIEF EXAMPLE - - * Bare bones execution: - - Assume that the ``in`` directory contains a single subject - ``subdir``, which in turn contains single file, called - ``brain.mgz``: - - mkdir out && chmod 777 out - docker run --rm \\ - -v $(pwd)/in:/incoming -v $(pwd)/out:/outgoing \\ - fnndsc/pl-fastsurfer_inference \\ - fastsurfer_inference.py \\ - /incoming /outgoing - - DESCRIPTION - - ``fastsurfer_inference.py`` is a ChRIS plugin constructed around the - FastSurfer application developed by the "Deep Medical Imaging Lab" - (https://deep-mi.org). For the publication describing this method, see - - Henschel L, Conjeti S, Estrada S, Diers K, Fischl B, Reuter M. - "FastSurfer - A fast and accurate deep learning based neuroimaging - pipeline." NeuroImage. 2020. - - https://arxiv.org/pdf/2009.04392.pdf - https://deep-mi.org/static/pub/ewert_2020.bib - - This specific plugin wraps around many of the FastSurfer CLI in - addition to providing ChRIS plugin-specific CLI. It's primary purpose - is to segment and input brain image file (in FreeSurfer ``mgz`` format), - creating in turn as output a labeled ``mgz`` format file. - - ARGS - - [--subjectDir ] - By default, the is assumed to be the . However, - the can be nested relative to the , and can thus - be specified with this flag. - - The is assumed by default to contain one level of sub - directory, and these sub dirs, considered the ``subjects``, each contain - a single ``mgz`` to process. - - [--subject ] - This can denote a sub-set of subject(s) (i.e. sub directory within the - ). The is "globbed", so an expression - like ``--subject 10*`` would process all ``subjects`` starting with the - text string ``10``. Note to protect from shell expansion of wildcard - characters, the argument should be protected in single quotes. - - [--in_name ] - The name of the raw ``.mgz`` file of a subject. The default value is - ``brain.mgz``. The full path to the is constructed - by concatenating - - ``///`` - - [--out_name ] - The order of interpolation: - - 0 = nearest - 1 = linear (default) - 2 = quadratic - 3 = cubic - - [--log ] - The name of the log file containing inference info. Default value is - - ``deep-seg.log`` - - [--clean] - If specified, clean the segmentation. - - [--no_cuda] - If specified, run on CPU, not GPU. Depending on CPU/GPU, your apparent - mileage will vary, but expect orders longer time than compared to a - GPU. 
- - For example, in informal testing, GPU takes about a minute per - subject, while CPU approximately 1.5 hours per subject! - - [--batch_size ] - If specified, copies the input files matching the keywords to output dir. This can be - useful to create an easy association between a given input files and the - segmented output. - - [-v ] [--verbosity ] - Verbosity level for app. Not used currently. - - [--version] - If specified, print version number. - - [--man] - If specified, print (this) man page. - - [--meta] - If specified, print plugin meta data. -""" - - -class Fastsurfer_inference(ChrisApp): - """ - An app to efficiently perform cortical parcellation and segmentation - on raw brain MRI images, using FastSurfer developed by the - Deep Medical Imaging Lab (https://deep-mi.org). - """ - - PACKAGE = __package__ - TITLE = 'An inference app of FastSurfer' - CATEGORY = '' - TYPE = 'ds' - ICON = '' # url of an icon image - MAX_NUMBER_OF_WORKERS = 1 # Override with integer value - MIN_NUMBER_OF_WORKERS = 1 # Override with integer value - MIN_CPU_LIMIT = 2000 # Override with millicore value as string, e.g. '2000m' - MIN_MEMORY_LIMIT = 2000 # Override with string, e.g. '1Gi', '2000Mi' - MIN_GPU_LIMIT = 0 # Override with the minimum number of GPUs, as an integer, for your plugin - MAX_GPU_LIMIT = 0 # Override with the maximum number of GPUs, as an integer, for your plugin - - # Use this dictionary structure to provide key-value output descriptive information - # that may be useful for the next downstream plugin. For example: - # - # { - # "finalOutputFile": "final/file.out", - # "viewer": "genericTextViewer", - # } - # - # The above dictionary is saved when plugin is called with a ``--saveoutputmeta`` - # flag. Note also that all file paths are relative to the system specified - # output directory. - OUTPUT_META_DICT = {} - - def define_parameters(self): - """ - Define the CLI arguments accepted by this plugin app. - Use self.add_argument to specify a new app argument. - """ - # Requiered options - # 1. Directory information (where to read from, where to write to) - self.add_argument( '--subjectDir', - dest = 'subjectDir', - type = str, - optional = True, - help = "directory (relative to ) of subjects to process", - default = "") - - # 2. Options for the MRI volumes - # (name of in and output, order of interpolation if not conformed) - self.add_argument( '--in_name', '-i', - type = str, - dest = 'iname', - help = 'name of the input (raw) file to process (default: brain.mgz)', - optional = True, - default = 'brain.mgz') - self.add_argument( '--out_name', '-o', - dest = 'oname', - type = str, - optional = True, - default = 'aparc.DKTatlas+aseg.deep.mgz', - help = 'name of the output segmented file') - self.add_argument( '--order', - dest = 'order', - type = int, - default = 1, - optional = True, - help = "interpolation order") - - # 3. Options for log-file and search-tag - self.add_argument( '--subject', - dest = 'subject', - type = str, - default = "*", - optional = True, - help = 'subject(s) to process. This expression is globbed.') - self.add_argument( '--log', - dest = 'logfile', - type = str, - optional = True, - help = 'name of logfile (default: deep-seg.log)', - default = 'deep-seg.log') - - # 4. Pre-trained weights -- NB NB NB -- currently (Jan 2021) these CANNOT - # be set by an enduser. 
These weight files are RELATIVE/INTERNAL to the - # container - self.add_argument( '--network_sagittal_path', - dest = 'network_sagittal_path', - type = str, - optional = True, - help = "path to pre-trained sagittal network weights", - default = './checkpoints/Sagittal_Weights_FastSurferCNN/ckpts/Epoch_30_training_state.pkl') - self.add_argument( '--network_coronal_path', - dest = 'network_coronal_path', - type = str, - optional = True, - help = "path to pre-trained coronal network weights", - default = './checkpoints/Coronal_Weights_FastSurferCNN/ckpts/Epoch_30_training_state.pkl') - self.add_argument( '--network_axial_path', - dest = 'network_axial_path', - type = str, - optional = True, - help = "path to pre-trained axial network weights", - default = './checkpoints/Axial_Weights_FastSurferCNN/ckpts/Epoch_30_training_state.pkl') - - # 5. Clean up and GPU/CPU options (disable cuda, change batchsize) - self.add_argument( '--clean', - dest = 'cleanup', - type = bool, - optional = True, - default = True, - help = "if specified, clean up segmentation", - action = 'store_true') - self.add_argument( '--no_cuda', - dest = 'no_cuda', - action = 'store_true', - type = bool, - optional = True, - default = False, - help = 'if specified, do not use GPU') - self.add_argument( '--batch_size', - dest = 'batch_size', - type = int, - default = 8, - optional = True, - help = "batch size for inference (default: 8") - self.add_argument( '--simple_run', - dest = 'simple_run', - action = 'store_true', - optional = True, - default = False, - type = bool, - help = 'simplified run: only analyze one subject') - - # Adding check to parallel processing, default = false - self.add_argument( '--run_parallel', - dest = 'run_parallel', - type = bool, - action = 'store_true', - optional = True , - default = False, - help = 'if specified, allows for execute on multiple GPUs') - - self.add_argument( '--copyInputFiles', - dest = 'copyInputFiles', - type = str, - default = "", - optional = True, - help = "if specified, copy i/p files matching the input regex to o/p dir") - - def fast_surfer_cnn(self,img_filename, save_as, logger, args): - """ - Cortical parcellation of single image - :param str img_filename: name of image file - :param parser.Options args: Arguments (passed via command line) to set up networks - * args.network_sagittal_path: path to sagittal checkpoint (stored pretrained network) - * args.network_coronal_path: path to coronal checkpoint (stored pretrained network) - * args.network_axial_path: path to axial checkpoint (stored pretrained network) - * args.cleanup: Whether to clean up the segmentation (medial filter on certain labels) - * args.no_cuda: Whether to use CUDA (GPU) or not (CPU) - :param logging.logger logger: Logging instance info messages will be written to - :param str save_as: name under which to save prediction. 
- :return None: saves prediction to save_as - """ - start_total = time.time() - logger.info("Reading volume {}".format(img_filename)) - - header_info, affine_info, orig_data = load_and_conform_image(img_filename, interpol=args.order) - - # Copy file(s) from input dir to output dir subject wise - if args.copyInputFiles !="": - - # define source and destination dir - source_dir_path = img_filename.replace(args.iname,'') - dest_dir_path = save_as.replace(args.oname,'') - - # search for files using the given keyword - found_files = glob.glob(source_dir_path+"/*%s*" %args.copyInputFiles, recursive = True) - - # now iteratively copy the files - for file_path in found_files: - logger.info("Copying files from %s" %file_path) - shutil.copy(file_path, dest_dir_path) - - - transform_test = transforms.Compose([ToTensorTest()]) - - test_dataset_axial = OrigDataThickSlices(img_filename, orig_data, transforms=transform_test, plane='Axial') - test_dataset_sagittal = OrigDataThickSlices(img_filename, orig_data, transforms=transform_test, plane='Sagittal') - test_dataset_coronal = OrigDataThickSlices(img_filename, orig_data, transforms=transform_test, plane='Coronal') - - start = time.time() - - test_data_loader = DataLoader(dataset=test_dataset_axial, batch_size=args.batch_size, shuffle=False) - - # Axial View Testing - params_network = {'num_channels': 7, 'num_filters': 64, 'kernel_h': 5, 'kernel_w': 5, 'stride_conv': 1, 'pool': 2, - 'stride_pool': 2, 'num_classes': 79, 'kernel_c': 1, 'kernel_d': 1} - - # Select the model - model = FastSurferCNN(params_network) - - # Put it onto the GPU or CPU - use_cuda = not args.no_cuda and torch.cuda.is_available() - device = torch.device("cuda" if use_cuda else "cpu") - logger.info("Cuda available: {}, # Available GPUS: {}, " - "Cuda user disabled (--no_cuda flag): {}, --> Using device: {}".format(torch.cuda.is_available(), - torch.cuda.device_count(), - args.no_cuda, device)) - - if use_cuda and torch.cuda.device_count() > 1 and args.run_parallel: - model = nn.DataParallel(model) - model_parallel = True - else: - model_parallel = False - - model.to(device) - - # Set up state dict (remapping of names, if not multiple GPUs/CPUs) - logger.info("Loading Axial Net from {}".format(args.network_axial_path)) - - model_state = torch.load(args.network_axial_path, map_location=device) - new_state_dict = OrderedDict() - - for k, v in model_state["model_state_dict"].items(): - - if k[:7] == "module." and not model_parallel: - new_state_dict[k[7:]] = v - - elif k[:7] != "module." and model_parallel: - new_state_dict["module." 
+ k] = v - - else: - new_state_dict[k] = v - - model.load_state_dict(new_state_dict) - - model.eval() - prediction_probability_axial = torch.zeros((256, params_network["num_classes"], 256, 256), dtype=torch.float) - - logger.info("Axial model loaded.") - with torch.no_grad(): - - start_index = 0 - for batch_idx, sample_batch in enumerate(test_data_loader): - images_batch = Variable(sample_batch["image"]) - - if use_cuda: - images_batch = images_batch.cuda() - - temp = model(images_batch) - - prediction_probability_axial[start_index:start_index + temp.shape[0]] = temp.cpu() - start_index += temp.shape[0] - logger.info("--->Batch {} Axial Testing Done.".format(batch_idx)) - - logger.info("Axial View Tested in {:0.4f} seconds".format(time.time() - start)) - - # Coronal View Testing - start = time.time() - - test_data_loader = DataLoader(dataset=test_dataset_coronal, batch_size=args.batch_size, shuffle=False) - - params_network = {'num_channels': 7, 'num_filters': 64, 'kernel_h': 5, 'kernel_w': 5, 'stride_conv': 1, 'pool': 2, - 'stride_pool': 2, 'num_classes': 79, 'kernel_c': 1, 'kernel_d': 1} - - # Select the model - - model = FastSurferCNN(params_network) - - # Put it onto the GPU or CPU - use_cuda = not args.no_cuda and torch.cuda.is_available() - device = torch.device("cuda" if use_cuda else "cpu") - - if use_cuda and torch.cuda.device_count() > 1 and args.run_parallel: - model = nn.DataParallel(model) - model_parallel = True - else: - model_parallel = False - - model.to(device) - - # Set up new state dict (remapping of names, if not multiple GPUs/CPUs) - logger.info("Loading Coronal Net from {}".format(args.network_coronal_path)) - - model_state = torch.load(args.network_coronal_path, map_location=device) - new_state_dict = OrderedDict() - - for k, v in model_state["model_state_dict"].items(): - - if k[:7] == "module." and not model_parallel: - new_state_dict[k[7:]] = v - - elif k[:7] != "module." and model_parallel: - new_state_dict["module." 
+ k] = v - - else: - new_state_dict[k] = v - - model.load_state_dict(new_state_dict) - - model.eval() - prediction_probability_coronal = torch.zeros((256, params_network["num_classes"], 256, 256), dtype=torch.float) - - logger.info("Coronal model loaded.") - start_index = 0 - with torch.no_grad(): - - for batch_idx, sample_batch in enumerate(test_data_loader): - - images_batch = Variable(sample_batch["image"]) - - if use_cuda: - images_batch = images_batch.cuda() - - temp = model(images_batch) - - prediction_probability_coronal[start_index:start_index + temp.shape[0]] = temp.cpu() - start_index += temp.shape[0] - logger.info("--->Batch {} Coronal Testing Done.".format(batch_idx)) - - logger.info("Coronal View Tested in {:0.4f} seconds".format(time.time() - start)) - - start = time.time() - - test_data_loader = DataLoader(dataset=test_dataset_sagittal, batch_size=args.batch_size, shuffle=False) - - params_network = {'num_channels': 7, 'num_filters': 64, 'kernel_h': 5, 'kernel_w': 5, 'stride_conv': 1, 'pool': 2, - 'stride_pool': 2, 'num_classes': 51, 'kernel_c': 1, 'kernel_d': 1} - - # Select the model - model = FastSurferCNN(params_network) - - # Put it onto the GPU or CPU - use_cuda = not args.no_cuda and torch.cuda.is_available() - device = torch.device("cuda" if use_cuda else "cpu") - - if use_cuda and torch.cuda.device_count() > 1 and args.run_parallel: - model = nn.DataParallel(model) - model_parallel = True - else: - model_parallel = False - - model.to(device) - - # Set up new state dict (remapping of names, if not multiple GPUs/CPUs) - logger.info("Loading Sagittal Net from {}".format(args.network_sagittal_path)) - - model_state = torch.load(args.network_sagittal_path, map_location=device) - new_state_dict = OrderedDict() - - for k, v in model_state["model_state_dict"].items(): - - if k[:7] == "module." and not model_parallel: - new_state_dict[k[7:]] = v - - elif k[:7] != "module." and model_parallel: - new_state_dict["module." 
+ k] = v - - else: - new_state_dict[k] = v - - model.load_state_dict(new_state_dict) - - model.eval() - prediction_probability_sagittal = torch.zeros((256, params_network["num_classes"], 256, 256), dtype=torch.float) - - start_index = 0 - with torch.no_grad(): - - for batch_idx, sample_batch in enumerate(test_data_loader): - - images_batch = Variable(sample_batch["image"]) - - if use_cuda: - images_batch = images_batch.cuda() - - temp = model(images_batch) - - prediction_probability_sagittal[start_index:start_index + temp.shape[0]] = temp.cpu() - start_index += temp.shape[0] - logger.info("--->Batch {} Sagittal Testing Done.".format(batch_idx)) - - prediction_probability_sagittal = map_prediction_sagittal2full(prediction_probability_sagittal) - - logger.info("Sagittal View Tested in {:0.4f} seconds".format(time.time() - start)) - - del model, test_dataset_axial, test_dataset_coronal, test_dataset_sagittal, test_data_loader - - start = time.time() - - # Start View Aggregation: change from N,C,X,Y to coronal view with C in last dimension = H,W,D,C - prediction_probability_axial = prediction_probability_axial.permute(3, 0, 2, 1) - prediction_probability_coronal = prediction_probability_coronal.permute(2, 3, 0, 1) - prediction_probability_sagittal = prediction_probability_sagittal.permute(0, 3, 2, 1) - - intermediate_img = torch.add(prediction_probability_axial, prediction_probability_coronal) - del prediction_probability_axial, prediction_probability_coronal - - _, prediction_image = torch.max(torch.add(torch.mul(intermediate_img, 0.4), - torch.mul(prediction_probability_sagittal, 0.2)), 3) - - del prediction_probability_sagittal, intermediate_img - - prediction_image = prediction_image.numpy() - - end = time.time() - start - logger.info("View Aggregation finished in {:0.4f} seconds.".format(end)) - - prediction_image = map_label2aparc_aseg(prediction_image) - - # Quick Fix for 2026 vs 1026; 2029 vs. 1029; 2025 vs. 
1025 - rh_wm = get_largest_cc(prediction_image == 41) - lh_wm = get_largest_cc(prediction_image == 2) - rh_wm = regionprops(label(rh_wm, background=0)) - lh_wm = regionprops(label(lh_wm, background=0)) - centroid_rh = np.asarray(rh_wm[0].centroid) - centroid_lh = np.asarray(lh_wm[0].centroid) - - labels_list = np.array([1003, 1006, 1007, 1008, 1009, 1011, - 1015, 1018, 1019, 1020, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1034, 1035]) - - for label_current in labels_list: - - label_img = label(prediction_image == label_current, connectivity=3, background=0) - - for region in regionprops(label_img): - - if region.label != 0: # To avoid background - - if np.linalg.norm(np.asarray(region.centroid) - centroid_rh) < np.linalg.norm( - np.asarray(region.centroid) - centroid_lh): - mask = label_img == region.label - prediction_image[mask] = label_current + 1000 - - # Quick Fixes for overlapping classes - aseg_lh = gaussian_filter(1000 * np.asarray(prediction_image == 2, dtype=np.float), sigma=3) - aseg_rh = gaussian_filter(1000 * np.asarray(prediction_image == 41, dtype=np.float), sigma=3) - - lh_rh_split = np.argmax(np.concatenate((np.expand_dims(aseg_lh, axis=3), np.expand_dims(aseg_rh, axis=3)), axis=3), - axis=3) - - # Problematic classes: 1026, 1011, 1029, 1019 - for prob_class_lh in [1011, 1019, 1026, 1029]: - prob_class_rh = prob_class_lh + 1000 - mask_lh = ((prediction_image == prob_class_lh) | (prediction_image == prob_class_rh)) & (lh_rh_split == 0) - mask_rh = ((prediction_image == prob_class_lh) | (prediction_image == prob_class_rh)) & (lh_rh_split == 1) - - prediction_image[mask_lh] = prob_class_lh - prediction_image[mask_rh] = prob_class_rh - - # Clean-Up - if args.cleanup is True: - - labels = [2, 4, 5, 7, 8, 10, 11, 12, 13, 14, - 15, 16, 17, 18, 24, 26, 28, 31, 41, 43, 44, - 46, 47, 49, 50, 51, 52, 53, 54, 58, 60, 63, - 77, 1026, 2026] - - start = time.time() - prediction_image_medfilt = median_filter(prediction_image, size=(3, 3, 3)) - mask = np.zeros_like(prediction_image) - tolerance = 25 - - for current_label in labels: - current_class = (prediction_image == current_label) - label_image = label(current_class, connectivity=3) - - for region in regionprops(label_image): - - if region.area <= tolerance: - mask_label = (label_image == region.label) - mask[mask_label] = 1 - - prediction_image[mask == 1] = prediction_image_medfilt[mask == 1] - logger.info("Segmentation Cleaned up in {:0.4f} seconds.".format(time.time() - start)) - - # Saving image - header_info.set_data_dtype(np.int16) - mapped_aseg_img = nib.MGHImage(prediction_image, affine_info, header_info) - mapped_aseg_img.to_filename(save_as) - - logger.info("Saving Segmentation to {}".format(save_as)) - logger.info("Total processing time: {:0.4f} seconds.".format(time.time() - start_total)) - - - - def run(self, options): - """ - Define the code to be run by this plugin app. 
- """ - print(Gstr_title) - print('Version: %s' % self.get_version()) - - # Output the space of CLI - d_options = vars(options) - for k,v in d_options.items(): - print("%20s: %-40s" % (k, v)) - print("") - - # Command Line options and error checking done here - # options = options_parse() - - # Set up the logger - logger = logging.getLogger("eval") - logger.setLevel(logging.DEBUG) - logger.addHandler(logging.StreamHandler(stream=sys.stdout)) - - if options.simple_run: - - # Check if output subject directory exists and create it otherwise - sub_dir, out = op.split(options.oname) - - if not op.exists(sub_dir): - makedirs(sub_dir) - - fast_surfer_cnn(options.iname, options.oname, logger, options) - - else: - - # Prepare subject list to be processed - if options.subjectDir != "": - search_path = op.join( options.inputdir, - options.subjectDir, - options.subject) - subject_directories = glob.glob(search_path) - - else: - search_path = op.join(options.inputdir, options.subject) - subject_directories = glob.glob(search_path) - - # Report number of subjects to be processed and loop over them - data_set_size = len(subject_directories) - logger.info("Total Dataset Size is {}".format(data_set_size)) - - for current_subject in subject_directories: - - subject = current_subject.split("/")[-1] - - # Define volume to process, log-file and name under which final prediction will be saved - - - invol = op.join(current_subject, options.iname) - logfile = op.join(options.outputdir, subject, options.logfile) - save_file_name = op.join(options.outputdir, subject, options.oname) - - - logger.info("Running Fast Surfer on {}".format(subject)) - - # Check if output subject directory exists and create it otherwise - sub_dir, out = op.split(save_file_name) - - if not op.exists(sub_dir): - makedirs(sub_dir) - - # Prepare the log-file (logging to File in subject directory) - fh = logging.FileHandler(logfile, mode='w') - logger.addHandler(fh) - - # Run network - self.fast_surfer_cnn(invol, save_file_name, logger, options) - - logger.removeHandler(fh) - fh.close() - - def show_man_page(self): - """ - Print the app's man page. - """ - print(Gstr_synopsis) - - -# ENTRYPOINT -if __name__ == "__main__": - chris_app = Fastsurfer_inference() - chris_app.launch() diff --git a/fastsurfer_inference/generate_hdf5.py b/fastsurfer_inference/generate_hdf5.py deleted file mode 100644 index 7e6ee78..0000000 --- a/fastsurfer_inference/generate_hdf5.py +++ /dev/null @@ -1,192 +0,0 @@ - -# Copyright 2019 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-
-
-# IMPORTS
-import time
-import glob
-import h5py
-import numpy as np
-import nibabel as nib
-
-from os.path import join
-from data_loader.load_neuroimaging_data import map_aparc_aseg2label, create_weight_mask, transform_sagittal, \
-    transform_axial, get_thick_slices, filter_blank_slices_thick
-
-
-# Class to create hdf5 file
-class PopulationDataset:
-    """
-    Class to load all images in a directory into an hdf5 file.
-    """
-
-    def __init__(self, params):
-
-        self.height = params["height"]
-        self.width = params["width"]
-        self.dataset_name = params["dataset_name"]
-        self.data_path = params["data_path"]
-        self.slice_thickness = params["thickness"]
-        self.orig_name = params["image_name"]
-        self.aparc_name = params["gt_name"]
-        self.aparc_nocc = params["gt_nocc"]
-
-        if params["csv_file"] is not None:
-            with open(params["csv_file"], "r") as s_dirs:
-                self.subject_dirs = [line.strip() for line in s_dirs.readlines()]
-
-        else:
-            self.search_pattern = join(self.data_path, params["pattern"])
-            self.subject_dirs = glob.glob(self.search_pattern)
-
-        self.data_set_size = len(self.subject_dirs)
-
-    def create_hdf5_dataset(self, plane='axial', is_small=False):
-        """
-        Function to store all images in a given directory (or pattern) in an hdf5 file.
-        :param str plane: which plane is processed (coronal, axial or sagittal)
-        :param bool is_small: small hdf5 file for pretraining?
-        :return:
-        """
-        start_d = time.time()
-
-        # Prepare arrays to hold the data
-        orig_dataset = np.ndarray(shape=(256, 256, 0, 2 * self.slice_thickness + 1), dtype=np.uint8)
-        aseg_dataset = np.ndarray(shape=(256, 256, 0), dtype=np.uint8)
-        weight_dataset = np.ndarray(shape=(256, 256, 0), dtype=np.float)
-        subjects = []
-
-        # Loop over all subjects and load orig, aseg and create the weights
-        for idx, current_subject in enumerate(self.subject_dirs):
-
-            try:
-                start = time.time()
-
-                print("Volume Nr: {} Processing MRI Data from {}/{}".format(idx, current_subject, self.orig_name))
-
-                # Load orig and aseg
-                orig = nib.load(join(current_subject, self.orig_name))
-                orig = np.asarray(orig.get_data(), dtype=np.uint8)
-
-                aseg = nib.load(join(current_subject, self.aparc_name))
-
-                print('Processing ground truth segmentation {}'.format(self.aparc_name))
-                aseg = np.asarray(aseg.get_data(), dtype=np.int16)
-
-                if self.aparc_nocc is not None:
-                    aseg_nocc = nib.load(join(current_subject, self.aparc_nocc))
-                    aseg_nocc = np.asarray(aseg_nocc.get_data(), dtype=np.int16)
-
-                else:
-                    aseg_nocc = None
-
-                # Map aseg to label space and create weight masks
-                if plane == 'sagittal':
-                    _, mapped_aseg = map_aparc_aseg2label(aseg, aseg_nocc)
-                    weights = create_weight_mask(mapped_aseg)
-                    orig = transform_sagittal(orig)
-                    mapped_aseg = transform_sagittal(mapped_aseg)
-                    weights = transform_sagittal(weights)
-
-                else:
-                    mapped_aseg, _ = map_aparc_aseg2label(aseg, aseg_nocc)
-                    weights = create_weight_mask(mapped_aseg)
-
-                # Transform Data as needed (swap axis for axial view)
-                if plane == 'axial':
-                    orig = transform_axial(orig)
-                    mapped_aseg = transform_axial(mapped_aseg)
-                    weights = transform_axial(weights)
-
-                # Create Thick Slices, filter out blanks
-                orig_thick = get_thick_slices(orig, self.slice_thickness)
-                orig, mapped_aseg, weights = filter_blank_slices_thick(orig_thick, mapped_aseg, weights)
-
-                # Append finally processed images to arrays
-                orig_dataset = np.append(orig_dataset, orig, axis=2)
-                aseg_dataset = np.append(aseg_dataset, mapped_aseg, axis=2)
-                weight_dataset = np.append(weight_dataset, weights, axis=2)
-
-                sub_name = 
current_subject.split("/")[-1] - subjects.append(sub_name.encode("ascii", "ignore")) - - end = time.time() - start - - print("Volume: {} Finished Data Reading and Appending in {:.3f} seconds.".format(idx, end)) - - if is_small and idx == 2: - break - - except Exception as e: - print("Volume: {} Failed Reading Data. Error: {}".format(idx, e)) - continue - - # Transpose to N, H, W, C and expand_dims for image - - orig_dataset = np.transpose(orig_dataset, (2, 0, 1, 3)) - aseg_dataset = np.transpose(aseg_dataset, (2, 0, 1)) - weight_dataset = np.transpose(weight_dataset, (2, 0, 1)) - - # Write the hdf5 file - with h5py.File(self.dataset_name, "w") as hf: - hf.create_dataset('orig_dataset', data=orig_dataset, compression='gzip') - hf.create_dataset('aseg_dataset', data=aseg_dataset, compression='gzip') - hf.create_dataset('weight_dataset', data=weight_dataset, compression='gzip') - - dt = h5py.special_dtype(vlen=str) - hf.create_dataset("subject", data=subjects, dtype=dt, compression="gzip") - - end_d = time.time() - start_d - print("Successfully written {} in {:.3f} seconds.".format(self.dataset_name, end_d)) - - -if __name__ == "__main__": - - import argparse - - # Training settings - parser = argparse.ArgumentParser(description='HDF5-Creation') - - parser.add_argument('--hdf5_name', type=str, default="testsuite_2.hdf5", - help='path and name of hdf5-dataset (default: testsuite_2.hdf5)') - parser.add_argument('--plane', type=str, default="axial", choices=["axial", "coronal", "sagittal"], - help="Which plane to put into file (axial (default), coronal or sagittal)") - parser.add_argument('--height', type=int, default=256, help='Height of Image (Default 256)') - parser.add_argument('--width', type=int, default=256, help='Width of Image (Default 256)') - parser.add_argument('--data_dir', type=str, default="/testsuite", help="Directory with images to load") - parser.add_argument('--thickness', type=int, default=3, help="Number of pre- and succeeding slices (default: 3)") - parser.add_argument('--csv_file', type=str, default=None, help="Csv-file listing subjects to include in file") - parser.add_argument('--pattern', type=str, help="Pattern to match files in directory.") - parser.add_argument('--image_name', type=str, default="mri/orig.mgz", - help="Default name of original images. FreeSurfer orig.mgz is default (mri/orig.mgz)") - parser.add_argument('--gt_name', type=str, default="mri/aparc.DKTatlas+aseg.mgz", - help="Default name for ground truth segmentations. Default: mri/aparc.DKTatlas+aseg.mgz." - " If Corpus Callosum segmentation is already removed, do not set gt_nocc." - " (e.g. for our internal training set mri/aparc.DKTatlas+aseg.filled.mgz exists already" - " and should be used here instead of mri/aparc.DKTatlas+aseg.mgz). ") - parser.add_argument('--gt_nocc', type=str, default=None, - help="Segmentation without corpus callosum (used to mask this segmentation in ground truth)." - " If the used segmentation was already processed, do not set this argument." 
- " For a normal FreeSurfer input, use mri/aseg.auto_noCCseg.mgz.") - - args = parser.parse_args() - - network_params = {"dataset_name": args.hdf5_name, "height": args.height, "width": args.width, - "data_path": args.data_dir, "thickness": args.thickness, "csv_file": args.csv_file, - "pattern": args.pattern, "image_name": args.image_name, - "gt_name": args.gt_name, "gt_nocc": args.gt_nocc} - - stratified = PopulationDataset(network_params) - stratified.create_hdf5_dataset(plane=args.plane) diff --git a/fastsurfer_inference/tests/__init__.py b/fastsurfer_inference/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/fastsurfer_inference/tests/test_fastsurfer_inference.py b/fastsurfer_inference/tests/test_fastsurfer_inference.py deleted file mode 100644 index ee7204d..0000000 --- a/fastsurfer_inference/tests/test_fastsurfer_inference.py +++ /dev/null @@ -1,33 +0,0 @@ - -from unittest import TestCase -from unittest import mock -from fastsurfer_inference.fastsurfer_inference import Fastsurfer_inference - - -class Fastsurfer_inferenceTests(TestCase): - """ - Test Fastsurfer_inference. - """ - def setUp(self): - self.app = Fastsurfer_inference() - - def test_run(self): - """ - Test the run code. - """ - args = [] - if self.app.TYPE == 'ds': - args.append('inputdir') # you may want to change this inputdir mock - args.append('outputdir') # you may want to change this outputdir mock - - # you may want to add more of your custom defined optional arguments to test - # your app with - # eg. - # args.append('--custom-int') - # args.append(10) - - options = self.app.parse_args(args) - self.app.run(options) - - # write your own assertions - self.assertEqual(options.outputdir, 'outputdir') diff --git a/models/__init__.py b/models/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/models/losses.py b/models/losses.py deleted file mode 100644 index 0ec9e54..0000000 --- a/models/losses.py +++ /dev/null @@ -1,107 +0,0 @@ - -# Copyright 2019 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -# IMPORTS -import torch -import torch.nn as nn -from torch.nn.modules.loss import _Loss -import torch.nn.functional as F - - -class DiceLoss(_Loss): - """ - Dice Loss - """ - - def forward(self, output, target, weights=None, ignore_index=None): - """ - :param output: N x C x H x W Variable - :param target: N x C x W LongTensor with starting class at 0 - :param weights: C FloatTensor with class wise weights - :param int ignore_index: ignore label with index x in the loss calculation - :return: - """ - eps = 0.001 - - encoded_target = output.detach() * 0 - - if ignore_index is not None: - mask = target == ignore_index - target = target.clone() - target[mask] = 0 - encoded_target.scatter_(1, target.unsqueeze(1), 1) - mask = mask.unsqueeze(1).expand_as(encoded_target) - encoded_target[mask] = 0 - - else: - encoded_target.scatter_(1, target.unsqueeze(1), 1) - - if weights is None: - weights = 1 - - intersection = output * encoded_target - numerator = 2 * intersection.sum(0).sum(1).sum(1) - denominator = output + encoded_target - - if ignore_index is not None: - denominator[mask] = 0 - - denominator = denominator.sum(0).sum(1).sum(1) + eps - loss_per_channel = weights * (1 - (numerator / denominator)) # Channel-wise weights - - return loss_per_channel.sum() / output.size(1) - - -class CrossEntropy2D(nn.Module): - """ - 2D Cross-entropy loss implemented as negative log likelihood - """ - - def __init__(self, weight=None, reduction='none'): - super(CrossEntropy2D, self).__init__() - self.nll_loss = nn.CrossEntropyLoss(weight=weight, reduction=reduction) - - def forward(self, inputs, targets): - return self.nll_loss(inputs, targets) - - -class CombinedLoss(nn.Module): - """ - For CrossEntropy the input has to be a long tensor - Args: - -- inputx N x C x H x W - -- target - N x H x W - int type - -- weight - N x H x W - float - """ - - def __init__(self, weight_dice=1, weight_ce=1): - super(CombinedLoss, self).__init__() - self.cross_entropy_loss = CrossEntropy2D() - self.dice_loss = DiceLoss() - self.weight_dice = weight_dice - self.weight_ce = weight_ce - - def forward(self, inputx, target, weight): - target = target.type(torch.LongTensor) # Typecast to long tensor - if inputx.is_cuda: - target = target.cuda() - - input_soft = F.softmax(inputx, dim=1) # Along Class Dimension - dice_val = torch.mean(self.dice_loss(input_soft, target)) - ce_val = torch.mean(torch.mul(self.cross_entropy_loss.forward(inputx, target), weight)) - total_loss = torch.add(torch.mul(dice_val, self.weight_dice), torch.mul(ce_val, self.weight_ce)) - - return total_loss, dice_val, ce_val diff --git a/models/networks.py b/models/networks.py deleted file mode 100644 index e2b4700..0000000 --- a/models/networks.py +++ /dev/null @@ -1,83 +0,0 @@ - -# Copyright 2019 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
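The models/losses.py module deleted above combines weighted cross-entropy with soft Dice; per its docstrings, it expects logits of shape N x C x H x W, integer labels of shape N x H x W and per-pixel weights of the same shape. A minimal smoke test of that contract, assuming the module were still importable:

    import torch
    from models.losses import CombinedLoss   # path as it existed before deletion

    criterion = CombinedLoss(weight_dice=1, weight_ce=1)
    logits  = torch.randn(2, 79, 256, 256)           # N x C x H x W; 79 classes, as in the plugin
    target  = torch.randint(0, 79, (2, 256, 256))    # N x H x W, long
    weights = torch.ones(2, 256, 256)                # per-pixel loss weights
    total, dice_val, ce_val = criterion(logits, target, weights)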
-
-
-# IMPORTS
-import torch.nn as nn
-import models.sub_module as sm
-
-
-class FastSurferCNN(nn.Module):
-    """
-    Network Definition of the Fully Competitive Network
-    * Spatial view aggregation (input 7 slices of which only middle one gets segmented)
-    * Same Number of filters per layer (normally 64)
-    * Dense Connections in blocks
-    * Unpooling instead of transpose convolutions
-    * Concatenations are replaced with Maxout (competitive dense blocks)
-    * Global skip connections are fused by Maxout (global competition)
-    * Loss Function (weighted Cross-Entropy and dice loss)
-    """
-    def __init__(self, params):
-        super(FastSurferCNN, self).__init__()
-
-        # Parameters for the Descending Arm
-        self.encode1 = sm.CompetitiveEncoderBlockInput(params)
-        params['num_channels'] = params['num_filters']
-        self.encode2 = sm.CompetitiveEncoderBlock(params)
-        self.encode3 = sm.CompetitiveEncoderBlock(params)
-        self.encode4 = sm.CompetitiveEncoderBlock(params)
-        self.bottleneck = sm.CompetitiveDenseBlock(params)
-
-        # Parameters for the Ascending Arm
-        params['num_channels'] = params['num_filters']
-        self.decode4 = sm.CompetitiveDecoderBlock(params)
-        self.decode3 = sm.CompetitiveDecoderBlock(params)
-        self.decode2 = sm.CompetitiveDecoderBlock(params)
-        self.decode1 = sm.CompetitiveDecoderBlock(params)
-
-        params['num_channels'] = params['num_filters']
-        self.classifier = sm.ClassifierBlock(params)
-
-        # Code for Network Initialization
-        for m in self.modules():
-            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
-                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
-            elif isinstance(m, nn.BatchNorm2d):
-                nn.init.constant_(m.weight, 1)
-                nn.init.constant_(m.bias, 0)
-
-    def forward(self, x):
-        """
-        Computational graph
-        :param tensor x: input image
-        :return tensor: prediction logits
-        """
-        encoder_output1, skip_encoder_1, indices_1 = self.encode1.forward(x)
-        encoder_output2, skip_encoder_2, indices_2 = self.encode2.forward(encoder_output1)
-        encoder_output3, skip_encoder_3, indices_3 = self.encode3.forward(encoder_output2)
-        encoder_output4, skip_encoder_4, indices_4 = self.encode4.forward(encoder_output3)
-
-        bottleneck = self.bottleneck(encoder_output4)
-
-        decoder_output4 = self.decode4.forward(bottleneck, skip_encoder_4, indices_4)
-        decoder_output3 = self.decode3.forward(decoder_output4, skip_encoder_3, indices_3)
-        decoder_output2 = self.decode2.forward(decoder_output3, skip_encoder_2, indices_2)
-        decoder_output1 = self.decode1.forward(decoder_output2, skip_encoder_1, indices_1)
-
-        logits = self.classifier.forward(decoder_output1)
-
-        return logits
diff --git a/models/solver.py b/models/solver.py
deleted file mode 100644
index b9b4630..0000000
--- a/models/solver.py
+++ /dev/null
@@ -1,446 +0,0 @@
-# Copyright 2019 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
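models/networks.py above only defines the architecture; the deleted plugin instantiated it once per plane with the parameter dict below (copied from its fast_surfer_cnn method; the sagittal net used num_classes = 51 instead of 79). A minimal sketch, assuming the module were still importable:

    import torch
    from models.networks import FastSurferCNN   # path as it existed before deletion

    params_network = {'num_channels': 7, 'num_filters': 64, 'kernel_h': 5, 'kernel_w': 5,
                      'stride_conv': 1, 'pool': 2, 'stride_pool': 2, 'num_classes': 79,
                      'kernel_c': 1, 'kernel_d': 1}

    model = FastSurferCNN(params_network)
    model.eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 7, 256, 256))  # 7 thick-slice channels in
    print(logits.shape)                              # torch.Size([1, 79, 256, 256])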
-
-
-# IMPORTS
-import os
-import torch
-import time
-import matplotlib.pyplot as plt
-import numpy as np
-import itertools
-import glob
-
-from torch.autograd import Variable
-from torch.optim import lr_scheduler
-from torchvision import utils
-from skimage import color
-from models.losses import CombinedLoss
-
-
-##
-# Helper functions
-##
-def create_exp_directory(exp_dir_name):
-    """
-    Function to create a directory if it does not exist yet.
-    :param str exp_dir_name: name of directory to create.
-    :return:
-    """
-    if not os.path.exists(exp_dir_name):
-        try:
-            os.makedirs(exp_dir_name)
-            print("Successfully Created Directory @ {}".format(exp_dir_name))
-        except:
-            print("Directory Creation Failed - Check Path")
-    else:
-        print("Directory {} Exists ".format(exp_dir_name))
-
-
-def dice_confusion_matrix(batch_output, labels_batch, num_classes):
-    """
-    Function to compute the dice confusion matrix.
-    :param batch_output:
-    :param labels_batch:
-    :param num_classes:
-    :return:
-    """
-    dice_cm = torch.zeros(num_classes, num_classes)
-
-    for i in range(num_classes):
-        gt = (labels_batch == i).float()
-
-        for j in range(num_classes):
-            pred = (batch_output == j).float()
-            inter = torch.sum(torch.mul(gt, pred)) + 0.0001
-            union = torch.sum(gt) + torch.sum(pred) + 0.0001
-            dice_cm[i, j] = 2 * torch.div(inter, union)
-
-    avg_dice = torch.mean(torch.diagflat(dice_cm))
-
-    return avg_dice, dice_cm
-
-
-def iou_score(pred_cls, true_cls, nclass=79):
-    """
-    Compute the intersection-over-union score.
-    Both inputs should be categorical (as opposed to one-hot).
-    """
-    intersect_ = []
-    union_ = []
-
-    for i in range(1, nclass):
-        intersect = ((pred_cls == i).float() + (true_cls == i).float()).eq(2).sum().item()
-        union = ((pred_cls == i).float() + (true_cls == i).float()).ge(1).sum().item()
-        intersect_.append(intersect)
-        union_.append(union)
-
-    return np.array(intersect_), np.array(union_)
-
-
-def precision_recall(pred_cls, true_cls, nclass=79):
-    """
-    Function to calculate recall (TP/(TP + FN)) and precision (TP/(TP + FP)) per class
-    :param pytorch.Tensor pred_cls: network prediction (categorical)
-    :param pytorch.Tensor true_cls: ground truth (categorical)
-    :param int nclass: number of classes
-    :return:
-    """
-    tpos_fneg = []
-    tpos_fpos = []
-    tpos = []
-
-    for i in range(1, nclass):
-        all_pred = (pred_cls == i).float()
-        all_gt = (true_cls == i).float()
-
-        tpos.append((all_pred + all_gt).eq(2).sum().item())
-        tpos_fpos.append(all_pred.sum().item())
-        tpos_fneg.append(all_gt.sum().item())
-
-    return np.array(tpos), np.array(tpos_fneg), np.array(tpos_fpos)
-
-
-##
-# Plotting functions
-##
-def plot_predictions(images_batch, labels_batch, batch_output, plt_title, file_save_name):
-    """
-    Function to plot predictions from validation set.
- :param images_batch: - :param labels_batch: - :param batch_output: - :param plt_title: - :param file_save_name: - :return: - """ - - f = plt.figure(figsize=(20, 20)) - n, c, h, w = images_batch.shape - mid_slice = c // 2 - images_batch = torch.unsqueeze(images_batch[:, mid_slice, :, :], 1) - grid = utils.make_grid(images_batch.cpu(), nrow=4) - - plt.subplot(131) - plt.imshow(grid.numpy().transpose((1, 2, 0))) - plt.title('Slices') - - grid = utils.make_grid(labels_batch.unsqueeze_(1).cpu(), nrow=4)[0] - color_grid = color.label2rgb(grid.numpy(), bg_label=0) - plt.subplot(132) - plt.imshow(color_grid) - plt.title('Ground Truth') - - grid = utils.make_grid(batch_output.unsqueeze_(1).cpu(), nrow=4)[0] - color_grid = color.label2rgb(grid.numpy(), bg_label=0) - plt.subplot(133) - plt.imshow(color_grid) - plt.title('Prediction') - - plt.suptitle(plt_title) - plt.tight_layout() - - f.savefig(file_save_name, bbox_inches='tight') - plt.close(f) - plt.gcf().clear() - - -def plot_confusion_matrix(cm, classes, title='Confusion matrix', cmap=plt.cm.Blues, file_save_name="temp.pdf"): - """ - This function prints and plots the confusion matrix. - Normalization can be applied by setting `normalize=True`. - :param cm: - :param classes: - :param title: - :param cmap: - :param file_save_name: - :return: - """ - f = plt.figure(figsize=(35, 35)) - - plt.imshow(cm, interpolation='nearest', cmap=cmap) - plt.title(title) - plt.colorbar() - tick_marks = np.arange(len(classes)) - - plt.xticks(tick_marks, classes, rotation=45) - plt.yticks(tick_marks, classes) - - fmt = '.2f' - thresh = cm.max() / 2. - - for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): - plt.text(j, i, format(cm[i, j], fmt), - horizontalalignment="center", - color="white" if cm[i, j] > thresh else "black") - - plt.tight_layout() - plt.ylabel('True label') - plt.xlabel('Predicted label') - - f.savefig(file_save_name, bbox_inches='tight') - plt.close(f) - plt.gcf().clear() - - -## -# Training routine -## -class Solver(object): - """ - Class for training neural networks - """ - - # gamma is the factor for lowering the lr and step_size is when it gets lowered - default_lr_scheduler_args = {"gamma": 0.05, - "step_size": 5} - - def __init__(self, num_classes, optimizer=torch.optim.Adam, optimizer_args={}, loss_func=CombinedLoss(), - lr_scheduler_args={}): - - # Merge and update the default arguments - optimizer - self.optimizer_args = optimizer_args - - lr_scheduler_args_merged = Solver.default_lr_scheduler_args.copy() - lr_scheduler_args_merged.update(lr_scheduler_args) - - # Merge and update the default arguments - lr scheduler - self.lr_scheduler_args = lr_scheduler_args_merged - - self.optimizer = optimizer - self.loss_func = loss_func - self.num_classes = num_classes - self.classes = list(range(self.num_classes)) - - def train(self, model, train_loader, validation_loader, class_names, num_epochs, - log_params, expdir, scheduler_type, torch_v11, resume=True): - """ - Train Model with provided parameters for optimization - Inputs: - -- model - model to be trained - -- train_loader - training DataLoader Object - -- validation_loader - validation DataLoader Object - -- num_epochs = total number of epochs - -- log_params - parameters for logging the progress - -- expdir --directory to save check points - - """ - create_exp_directory(expdir) # Experimental directory - create_exp_directory(log_params["logdir"]) # Logging Directory - - # Instantiate the optimizer class - optimizer = self.optimizer(model.parameters(), 
**self.optimizer_args) - - # Instantiate the scheduler class - if scheduler_type == "StepLR": - scheduler = lr_scheduler.StepLR(optimizer, step_size=self.lr_scheduler_args["step_size"], - gamma=self.lr_scheduler_args["gamma"]) - else: - scheduler = None - - # Set up logger format - a = "{}\t" * (self.num_classes - 2) + "{}" - - epoch = -1 # To allow for restoration - print('-------> Starting to train') - - # Code for restoring model - if resume: - - try: - prior_model_paths = sorted(glob.glob(os.path.join(expdir, 'Epoch_*')), key=os.path.getmtime) - - if prior_model_paths: - current_model = prior_model_paths.pop() - - state = torch.load(current_model) - - # Restore model dictionary - model.load_state_dict(state["model_state_dict"]) - optimizer.load_state_dict(state["optimizer_state_dict"]) - scheduler.load_state_dict(state["scheduler_state_dict"]) - epoch = state["epoch"] - - print("Successfully restored the model state. Resuming training from Epoch {}".format(epoch + 1)) - - except Exception as e: - print("No model to restore. Resuming training from Epoch 0. {}".format(e)) - - log_params["logger"].info("{} parameters in total".format(sum(x.numel() for x in model.parameters()))) - - while epoch < num_epochs: - - epoch = epoch + 1 - epoch_start = time.time() - - # Update learning rate based on epoch number (only for pytorch version <1.2) - if torch_v11 and scheduler is not None: - scheduler.step() - - loss_batch = np.zeros(1) - - for batch_idx, sample_batch in enumerate(train_loader): - - # Assign data - images_batch, labels_batch, weights_batch = sample_batch['image'], sample_batch['label'], \ - sample_batch['weight'] - - # Map to variables - images_batch = Variable(images_batch) - labels_batch = Variable(labels_batch) - weights_batch = Variable(weights_batch) - - if torch.cuda.is_available(): - images_batch, labels_batch, weights_batch = images_batch.cuda(), labels_batch.cuda(), \ - weights_batch.type(torch.FloatTensor).cuda() - - model.train() # Set to training mode! - optimizer.zero_grad() - - predictions = model(images_batch) - - loss_total, loss_dice, loss_ce = self.loss_func(predictions, labels_batch, weights_batch) - - loss_total.backward() - - optimizer.step() - - loss_batch += loss_total.item() - - if batch_idx % (len(train_loader) // 2) == 0 or batch_idx == len(train_loader) - 1: - log_params["logger"].info("Train Epoch: {} [{}/{}] ({:.0f}%)] " - "with loss: {}".format(epoch, batch_idx, - len(train_loader), - 100. 
* batch_idx / len(train_loader), - loss_batch / (batch_idx + 1))) - - del images_batch, labels_batch, weights_batch, predictions, loss_total, loss_dice, loss_ce - - # Update learning rate at the end based on epoch number (only for pytorch version > 1.1) - if not torch_v11 and scheduler is not None: - scheduler.step() - - epoch_finish = time.time() - epoch_start - - log_params["logger"].info("Train Epoch {} finished in {:.04f} seconds.".format(epoch, epoch_finish)) - # End of Training, time to accumulate results - - # Testing Loop on Training Data - # Set evaluation mode on the model - model.eval() - - val_loss_total = 0 - val_loss_dice = 0 - val_loss_ce = 0 - - ints_ = np.zeros(self.num_classes - 1) - unis_ = np.zeros(self.num_classes - 1) - per_cls_counts_gt = np.zeros(self.num_classes - 1) - per_cls_counts_pred = np.zeros(self.num_classes - 1) - accs = np.zeros(self.num_classes - 1) # -1 to exclude background (still included in val loss) - - with torch.no_grad(): - - if validation_loader is not None: - - val_start = time.time() - cnf_matrix_validation = torch.zeros(self.num_classes, self.num_classes) - - for batch_idx, sample_batch in enumerate(validation_loader): - - images_batch, labels_batch, weights_batch = sample_batch['image'], sample_batch['label'], \ - sample_batch['weight'] - - # Map to variables (no longer necessary after pytorch 0.40) - images_batch = Variable(images_batch) - labels_batch = Variable(labels_batch) - weights_batch = Variable(weights_batch) - - if torch.cuda.is_available(): - images_batch, labels_batch, weights_batch = images_batch.cuda(), labels_batch.cuda(), \ - weights_batch.type(torch.FloatTensor).cuda() - - # Get logits, sum up batch loss and get final predictions (argmax) - predictions = model(images_batch) - loss_total, loss_dice, loss_ce = self.loss_func(predictions, labels_batch, weights_batch) - val_loss_total += loss_total.item() - val_loss_dice += loss_dice.item() - val_loss_ce += loss_ce.item() - - _, batch_output = torch.max(predictions, dim=1) - - # Calculate iou_scores, accuracy and dice confusion matrix + sum over previous batches - int_, uni_ = iou_score(batch_output, labels_batch, self.num_classes) - ints_ += int_ - unis_ += uni_ - - tpos, pcc_gt, pcc_pred = precision_recall(batch_output, labels_batch, self.num_classes) - accs += tpos - per_cls_counts_gt += pcc_gt - per_cls_counts_pred += pcc_pred - - _, cm_batch = dice_confusion_matrix(batch_output, labels_batch, self.num_classes) - cnf_matrix_validation += cm_batch.cpu() - - # Plot sample predictions - if batch_idx == 0: - plt_title = 'Validation Results Epoch ' + str(epoch) - - file_save_name = os.path.join(log_params["logdir"], - 'Epoch_' + str(epoch) + '_Validations_Predictions.pdf') - - plot_predictions(images_batch, labels_batch, batch_output, plt_title, file_save_name) - - del images_batch, labels_batch, weights_batch, predictions, batch_output, \ - int_, uni_, tpos, pcc_gt, pcc_pred, loss_total, loss_dice, loss_ce # cm_batch, - - # Get final measures and log them - ious = ints_ / unis_ - val_loss_total /= (batch_idx + 1) - val_loss_dice /= (batch_idx + 1) - val_loss_ce /= (batch_idx + 1) - cnf_matrix_validation = cnf_matrix_validation / (batch_idx + 1) - val_end = time.time() - val_start - - print("Completed Validation Dataset in {:0.4f} s".format(val_end)) - - save_name = os.path.join(log_params["logdir"], 'Epoch_' + str(epoch) + '_Validation_Dice_CM.pdf') - plot_confusion_matrix(cnf_matrix_validation.cpu().numpy(), self.classes, file_save_name=save_name) - - # Log metrics - 
log_params["logger"].info("[Epoch {} stats]: MIoU: {:.4f}; " - "Mean Recall: {:.4f}; " - "Mean Precision: {:.4f}; " - "Avg loss total: {:.4f}; " - "Avg loss dice: {:.4f}; " - "Avg loss ce: {:.4f}".format(epoch, np.mean(ious), - np.mean(accs / per_cls_counts_gt), - np.mean(accs / per_cls_counts_pred), - val_loss_total, val_loss_dice, val_loss_ce)) - - log_params["logger"].info(a.format(*class_names)) - log_params["logger"].info(a.format(*ious)) - - # Saving Models - - if epoch % log_params["log_iter"] == 0: - save_name = os.path.join(expdir, 'Epoch_' + str(epoch).zfill(2) + '_training_state.pkl') - checkpoint = {"model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(), - "epoch": epoch} - if scheduler is not None: - checkpoint["scheduler_state_dict"] = scheduler.state_dict() - - torch.save(checkpoint, save_name) - - model.train() diff --git a/models/sub_module.py b/models/sub_module.py deleted file mode 100644 index 1fbcdb5..0000000 --- a/models/sub_module.py +++ /dev/null @@ -1,327 +0,0 @@ - -# Copyright 2019 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# IMPORTS -import torch -import torch.nn as nn - - -# Building Blocks -class CompetitiveDenseBlock(nn.Module): - """ - Function to define a competitive dense block comprising of 3 convolutional layers, with BN/ReLU - - Inputs: - -- Params - params = {'num_channels': 1, - 'num_filters': 64, - 'kernel_h': 5, - 'kernel_w': 5, - 'stride_conv': 1, - 'pool': 2, - 'stride_pool': 2, - 'num_classes': 44 - 'kernel_c':1 - 'input':True - } - """ - - def __init__(self, params, outblock=False): - """ - Constructor to initialize the Competitive Dense Block - :param dict params: dictionary with parameters specifiying block architecture - :param bool outblock: Flag indicating if last block (before classifier block) is set up. 
- Default: False - :return None: - """ - super(CompetitiveDenseBlock, self).__init__() - - # Padding to get output tensor of same dimensions - padding_h = int((params['kernel_h'] - 1) / 2) - padding_w = int((params['kernel_w'] - 1) / 2) - - # Sub-layer output sizes for BN; and - conv0_in_size = int(params['num_filters']) # num_channels - conv1_in_size = int(params['num_filters']) - conv2_in_size = int(params['num_filters']) - - # Define the learnable layers - - self.conv0 = nn.Conv2d(in_channels=conv0_in_size, out_channels=params['num_filters'], - kernel_size=(params['kernel_h'], params['kernel_w']), - stride=params['stride_conv'], padding=(padding_h, padding_w)) - - self.conv1 = nn.Conv2d(in_channels=conv1_in_size, out_channels=params['num_filters'], - kernel_size=(params['kernel_h'], params['kernel_w']), - stride=params['stride_conv'], padding=(padding_h, padding_w)) - - # 1 \times 1 convolution for the last block - self.conv2 = nn.Conv2d(in_channels=conv2_in_size, out_channels=params['num_filters'], - kernel_size=(1, 1), - stride=params['stride_conv'], padding=(0, 0)) - - self.bn1 = nn.BatchNorm2d(num_features=conv1_in_size) - self.bn2 = nn.BatchNorm2d(num_features=conv2_in_size) - self.bn3 = nn.BatchNorm2d(num_features=conv2_in_size) - - self.prelu = nn.PReLU() # Learnable ReLU Parameter - self.outblock = outblock - - def forward(self, x): - """ - CompetitiveDenseBlock's computational Graph - {in (Conv - BN from prev. block) -> PReLU} -> {Conv -> BN -> Maxout -> PReLU} x 2 -> {Conv -> BN} -> out - end with batch-normed output to allow maxout across skip-connections - - :param tensor x: input tensor (image or feature map) - :return tensor out: output tensor (processed feature map) - """ - # Activation from pooled input - x0 = self.prelu(x) - - # Convolution block 1 - x0 = self.conv0(x0) - x1_bn = self.bn1(x0) - x0_bn = torch.unsqueeze(x, 4) - x1_bn = torch.unsqueeze(x1_bn, 4) - x1 = torch.cat((x1_bn, x0_bn), dim=4) # Concatenate along the 5th dimension NB x C x H x W x F - x1_max, _ = torch.max(x1, 4) - x1 = self.prelu(x1_max) - - # Convolution block 2 - x1 = self.conv1(x1) - x2_bn = self.bn2(x1) - x2_bn = torch.unsqueeze(x2_bn, 4) - x1_max = torch.unsqueeze(x1_max, 4) - x2 = torch.cat((x2_bn, x1_max), dim=4) # Concatenating along the 5th dimension - x2_max, _ = torch.max(x2, 4) - x2 = self.prelu(x2_max) - - # Convolution block 3 (end with batch-normed output to allow maxout across skip-connections) - out = self.conv2(x2) - - if not self.outblock: - out = self.bn3(out) - - return out - - -class CompetitiveDenseBlockInput(nn.Module): - """ - Function to define a competitive dense block comprising of 3 convolutional layers, with BN/ReLU for input - - Inputs: - -- Params - params = {'num_channels': 1, - 'num_filters': 64, - 'kernel_h': 5, - 'kernel_w': 5, - 'stride_conv': 1, - 'pool': 2, - 'stride_pool': 2, - 'num_classes': 44 - 'kernel_c':1 - 'input':True - } - """ - - def __init__(self, params): - """ - Constructor to initialize the Competitive Dense Block - :param dict params: dictionary with parameters specifiying block architecture - """ - super(CompetitiveDenseBlockInput, self).__init__() - - # Padding to get output tensor of same dimensions - padding_h = int((params['kernel_h'] - 1) / 2) - padding_w = int((params['kernel_w'] - 1) / 2) - - # Sub-layer output sizes for BN; and - conv0_in_size = int(params['num_channels']) - conv1_in_size = int(params['num_filters']) - conv2_in_size = int(params['num_filters']) - - # Define the learnable layers - - self.conv0 = 
nn.Conv2d(in_channels=conv0_in_size, out_channels=params['num_filters'], - kernel_size=(params['kernel_h'], params['kernel_w']), - stride=params['stride_conv'], padding=(padding_h, padding_w)) - - self.conv1 = nn.Conv2d(in_channels=conv1_in_size, out_channels=params['num_filters'], - kernel_size=(params['kernel_h'], params['kernel_w']), - stride=params['stride_conv'], padding=(padding_h, padding_w)) - - # 1 \times 1 convolution for the last block - self.conv2 = nn.Conv2d(in_channels=conv2_in_size, out_channels=params['num_filters'], - kernel_size=(1, 1), - stride=params['stride_conv'], padding=(0, 0)) - - self.bn0 = nn.BatchNorm2d(num_features=conv0_in_size) - self.bn1 = nn.BatchNorm2d(num_features=conv1_in_size) - self.bn2 = nn.BatchNorm2d(num_features=conv2_in_size) - self.bn3 = nn.BatchNorm2d(num_features=conv2_in_size) - - self.prelu = nn.PReLU() # Learnable ReLU Parameter - - def forward(self, x): - """ - CompetitiveDenseBlockInput's computational Graph - in -> BN -> {Conv -> BN -> PReLU} -> {Conv -> BN -> Maxout -> PReLU} -> {Conv -> BN} -> out - - :param tensor x: input tensor (image or feature map) - :return tensor out: output tensor (processed feature map) - """ - # Input batch normalization - x0_bn = self.bn0(x) - - # Convolution block1 - x0 = self.conv0(x0_bn) - x1_bn = self.bn1(x0) - x1 = self.prelu(x1_bn) - - # Convolution block2 - x1 = self.conv1(x1) - x2_bn = self.bn2(x1) - # First Maxout - x1_bn = torch.unsqueeze(x1_bn, 4) - x2_bn = torch.unsqueeze(x2_bn, 4) # Add Singleton Dimension along 5th - x2 = torch.cat((x2_bn, x1_bn), dim=4) # Concatenating along the 5th dimension - x2_max, _ = torch.max(x2, 4) - x2 = self.prelu(x2_max) - - # Convolution block 3 - out = self.conv2(x2) - out = self.bn3(out) - - return out - - -class CompetitiveEncoderBlock(CompetitiveDenseBlock): - """ - Encoder Block = CompetitiveDenseBlock + Max Pooling - """ - - def __init__(self, params): - """ - Encoder Block initialization - :param dict params: parameters like number of channels, stride etc. - """ - super(CompetitiveEncoderBlock, self).__init__(params) - self.maxpool = nn.MaxPool2d(kernel_size=params['pool'], stride=params['stride_pool'], - return_indices=True) # For Unpooling later on with the indices - - def forward(self, x): - """ - CComputational graph for Encoder Block: - * CompetitiveDenseBlock - * Max Pooling (+ retain indices) - - :param tensor x: feature map from previous block - :return: original feature map, maxpooled feature map, maxpool indices - """ - out_block = super(CompetitiveEncoderBlock, self).forward(x) # To be concatenated as Skip Connection - out_encoder, indices = self.maxpool(out_block) # Max Pool as Input to Next Layer - return out_encoder, out_block, indices - - -class CompetitiveEncoderBlockInput(CompetitiveDenseBlockInput): - """ - Encoder Block = CompetitiveDenseBlockInput + Max Pooling - """ - - def __init__(self, params): - """ - Encoder Block initialization - :param dict params: parameters like number of channels, stride etc. 
- """ - super(CompetitiveEncoderBlockInput, self).__init__(params) # The init of CompetitiveDenseBlock takes in params - self.maxpool = nn.MaxPool2d(kernel_size=params['pool'], stride=params['stride_pool'], - return_indices=True) # For Unpooling later on with the indices - - def forward(self, x): - """ - Computational graph for Encoder Block: - * CompetitiveDenseBlockInput - * Max Pooling (+ retain indices) - - :param tensor x: feature map from previous block - :return: original feature map, maxpooled feature map, maxpool indices - """ - out_block = super(CompetitiveEncoderBlockInput, self).forward(x) # To be concatenated as Skip Connection - out_encoder, indices = self.maxpool(out_block) # Max Pool as Input to Next Layer - return out_encoder, out_block, indices - - -class CompetitiveDecoderBlock(CompetitiveDenseBlock): - """ - Decoder Block = (Unpooling + Skip Connection) --> Dense Block - """ - - def __init__(self, params, outblock=False): - """ - Decoder Block initialization - :param dict params: parameters like number of channels, stride etc. - :param bool outblock: Flag, indicating if last block of network before classifier - is created. Default: False - """ - super(CompetitiveDecoderBlock, self).__init__(params, outblock=outblock) - self.unpool = nn.MaxUnpool2d(kernel_size=params['pool'], stride=params['stride_pool']) - - def forward(self, x, out_block, indices): - """ - Computational graph Decoder block: - * Unpooling of feature maps from lower block - * Maxout combination of unpooled map + skip connection - * Forwarding toward CompetitiveDenseBlock - - :param tensor x: input feature map from lower block (gets unpooled and maxed with out_block) - :param tensor out_block: skip connection feature map from the corresponding Encoder - :param tensor indices: indices for unpooling from the corresponding Encoder (maxpool op) - :return: processed feature maps - """ - unpool = self.unpool(x, indices) - unpool = torch.unsqueeze(unpool, 4) - - out_block = torch.unsqueeze(out_block, 4) - concat = torch.cat((unpool, out_block), dim=4) # Competitive Concatenation - concat_max, _ = torch.max(concat, 4) - out_block = super(CompetitiveDecoderBlock, self).forward(concat_max) - - return out_block - - -class ClassifierBlock(nn.Module): - """ - Classification Block - """ - def __init__(self, params): - """ - Classifier Block initialization - :param dict params: parameters like number of channels, stride etc. 
- """ - super(ClassifierBlock, self).__init__() - self.conv = nn.Conv2d(params['num_channels'], params['num_classes'], params['kernel_c'], - params['stride_conv']) # To generate logits - - def forward(self, x): - """ - Computational graph of classifier - :param tensor x: output of last CompetitiveDenseDecoder Block- - :return: logits - """ - logits = self.conv(x) - - return logits diff --git a/requirements.txt b/requirements.txt index bb3b11f..50653fc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,5 @@ -chrisapp~=2.5.3 -nose~=1.3.7 +chris_plugin==0.2.1 pudb -h5py==2.9.0 -nibabel==2.5.1 -numpy==1.17.4 -scikit-image -torch -torchvision +loguru +pftag==1.2.22 +pflog==1.2.26 diff --git a/setup.py b/setup.py index 822c774..4c9c65a 100755 --- a/setup.py +++ b/setup.py @@ -1,19 +1,31 @@ -from os import path from setuptools import setup +import re -with open(path.join(path.dirname(path.abspath(__file__)), 'README.rst')) as f: - readme = f.read() +_version_re = re.compile(r"(?<=^__version__ = (\"|'))(.+)(?=\"|')") + +def get_version(rel_path: str) -> str: + """ + Searches for the ``__version__ = `` line in a source code file. + + https://packaging.python.org/en/latest/guides/single-sourcing-package-version/ + """ + with open(rel_path, 'r') as f: + matches = map(_version_re.search, f) + filtered = filter(lambda m: m is not None, matches) + version = next(filtered, None) + if version is None: + raise RuntimeError(f'Could not find __version__ in {rel_path}') + return version.group(0) setup( name = 'fastsurfer_inference', - version = '1.3.0', + version = get_version('fastsurfer_inference.py'), description = 'An app to efficiently perform cortical parcellation and segmentation on raw brain MRI images', - long_description = readme, author = 'Martin Reuter (FastSurfer), Sandip Samal (FNNDSC) (sandip.samal@childrens.harvard.edu)', author_email = 'dev@babyMRI.org', url = 'http://wiki', - packages = ['fastsurfer_inference','data_loader','models'], - install_requires = ['chrisapp'], + py_modules = ['fastsurfer_inference'], + install_requires = ['chris_plugin'], test_suite = 'nose.collector', tests_require = ['nose'], license = 'MIT', @@ -21,7 +33,7 @@ python_requires = '>=3.6', entry_points = { 'console_scripts': [ - 'fastsurfer_inference = fastsurfer_inference.__main__:main' + 'fastsurfer_inference = fastsurfer_inference:main' ] } )