Skip to content

Commit

Permalink
adding normalized map support
Browse files Browse the repository at this point in the history
  • Loading branch information
ad12 committed Oct 5, 2018
1 parent 541f89f commit 037a11e
Show file tree
Hide file tree
Showing 11 changed files with 61 additions and 39 deletions.
4 changes: 3 additions & 1 deletion defaults.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
DEFAULT_BATCH_SIZE = 32
DEFAULT_BATCH_SIZE = 32

FIX_VISUALIZATION_BOUNDS = True
2 changes: 1 addition & 1 deletion msk/knee.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from tissues.femoral_cartilage import FemoralCartilage
from utils.quant_vals import QuantitativeValue as QV
from utils.quant_vals import QuantitativeValues as QV
from utils import io_utils

KNEE_KEY = 'knee'
Expand Down
2 changes: 1 addition & 1 deletion pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from models.get_model import get_model
from tissues.femoral_cartilage import FemoralCartilage

from utils.quant_vals import QuantitativeValue as QV, get_qv
from utils.quant_vals import QuantitativeValues as QV, get_qv
import file_constants as fc

from msk import knee
Expand Down
10 changes: 4 additions & 6 deletions scan_sequences/cones.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
__EXPECTED_NUM_ECHO_TIMES__ = 4

__R_SQUARED_THRESHOLD__ = 0.9
__INITIAL_P0_VALS__ = (1.0, 1/30.0)
__INITIAL_T2_STAR_VAL__ = 30.0 # ms

__T2_STAR_LOWER_BOUND__ = 0
__T2_STAR_UPPER_BOUND__ = np.inf
Expand Down Expand Up @@ -56,7 +56,6 @@ def interregister(self, target_path, mask_path=None):
raw_filepaths = dict()

echo_time_inds = natsorted(list(subvolumes.keys()))
print(echo_time_inds)

for i in range(len(echo_time_inds)):
raw_filepath = os.path.join(temp_raw_dirpath, '%03d.nii.gz' % i)
Expand Down Expand Up @@ -107,7 +106,6 @@ def interregister(self, target_path, mask_path=None):
transformation_files = reg_output.transform
warped_files = [(base_echo_time, reg_output.warped_file)]

print(raw_filepaths)
files = []
for echo_time_ind in raw_filepaths.keys():
filepath = raw_filepaths[echo_time_ind]
Expand Down Expand Up @@ -146,8 +144,8 @@ def save_data(self, save_dirpath):

if self.t2star_map is not None:
data = {'data': self.t2star_map}
io_utils.save_h5(os.path.join(save_dirpath, '%s.h5' % qv.QuantitativeValue.T2_STAR.name.lower()), data)
io_utils.save_nifti(os.path.join(save_dirpath, '%s.nii.gz' % qv.QuantitativeValue.T2_STAR.name.lower()),
io_utils.save_h5(os.path.join(save_dirpath, '%s.h5' % qv.QuantitativeValues.T2_STAR.name.lower()), data)
io_utils.save_nifti(os.path.join(save_dirpath, '%s.nii.gz' % qv.QuantitativeValues.T2_STAR.name.lower()),
self.t2star_map, self.pixel_spacing)

# Save interregistered files
Expand Down Expand Up @@ -183,7 +181,7 @@ def generate_t2_star_map(self):
svs.append(svr)

svs = np.concatenate(svs)
vals, r_squared = qv.fit_monoexp_tc(spin_lock_times, svs, __INITIAL_P0_VALS__)
vals, r_squared = qv.fit_monoexp_tc(spin_lock_times, svs, __INITIAL_T2_STAR_VAL__)

map_unfiltered = vals.reshape(original_shape)
r_squared = r_squared.reshape(original_shape)
Expand Down
9 changes: 4 additions & 5 deletions scan_sequences/cube_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

__EXPECTED_NUM_SPIN_LOCK_TIMES__ = 4
__R_SQUARED_THRESHOLD__ = 0.9
__INITIAL_P0_VALS__ = (1.0, 1/30.0)
__INITIAL_T1_RHO_VAL__ = 70.0

__T1_RHO_LOWER_BOUND__ = 0
__T1_RHO_UPPER_BOUND__ = np.inf
Expand Down Expand Up @@ -136,9 +136,8 @@ def generate_t1_rho_map(self):
svs.append(svr)

svs = np.concatenate(svs)
print(svs.shape)

vals, r_squared = qv.fit_monoexp_tc(spin_lock_times, svs, __INITIAL_P0_VALS__)
vals, r_squared = qv.fit_monoexp_tc(spin_lock_times, svs, __INITIAL_T1_RHO_VAL__)

map_unfiltered = vals.reshape(original_shape)
r_squared = r_squared.reshape(original_shape)
Expand Down Expand Up @@ -213,8 +212,8 @@ def save_data(self, save_dirpath):

if self.t1rho_map is not None:
data = {'data': self.t1rho_map}
io_utils.save_h5(os.path.join(save_dirpath, '%s.h5' % qv.QuantitativeValue.T1_RHO.name.lower()), data)
io_utils.save_nifti(os.path.join(save_dirpath, '%s.nii.gz' % qv.QuantitativeValue.T1_RHO.name.lower()), self.t1rho_map,
io_utils.save_h5(os.path.join(save_dirpath, '%s.h5' % qv.QuantitativeValues.T1_RHO.name.lower()), data)
io_utils.save_nifti(os.path.join(save_dirpath, '%s.nii.gz' % qv.QuantitativeValues.T1_RHO.name.lower()), self.t1rho_map,
self.pixel_spacing)

# Save interregistered files
Expand Down
6 changes: 3 additions & 3 deletions scan_sequences/dess.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from scan_sequences.scans import TargetSequence
from utils import dicom_utils, im_utils, io_utils
from utils.quant_vals import QuantitativeValue
from utils.quant_vals import QuantitativeValues

class Dess(TargetSequence):
NAME = 'dess'
Expand Down Expand Up @@ -162,8 +162,8 @@ def save_data(self, save_dirpath):

save_dirpath = self.__save_dir__(save_dirpath)
data = {'data': self.t2map}
io_utils.save_h5(os.path.join(save_dirpath, '%s.h5' % QuantitativeValue.T2.name.lower()), data)
io_utils.save_nifti(os.path.join(save_dirpath, '%s.nii.gz' % QuantitativeValue.T2.name.lower()), self.t2map, self.pixel_spacing)
io_utils.save_h5(os.path.join(save_dirpath, '%s.h5' % QuantitativeValues.T2.name.lower()), data)
io_utils.save_nifti(os.path.join(save_dirpath, '%s.nii.gz' % QuantitativeValues.T2.name.lower()), self.t2map, self.pixel_spacing)

# write echos
for i in range(len(self.subvolumes)):
Expand Down
2 changes: 1 addition & 1 deletion scripts/test-script
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/bin/bash
WEIGHTS_DIRECTORY=""
WEIGHTS_DIRECTORY="/Users/arjundesai/Documents/stanford/research/msk_pipeline_raw/weights"
if [ -z "$WEIGHTS_DIRECTORY" ]; then
echo "Please define WEIGHTS_DIRECTORY in script. Use the absolute path"
exit 125
Expand Down
4 changes: 2 additions & 2 deletions test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from scan_sequences.cube_quant import CubeQuant

from utils import io_utils, dicom_utils
from utils.quant_vals import QuantitativeValue
from utils.quant_vals import QuantitativeValues

import file_constants as fc

Expand Down Expand Up @@ -117,7 +117,7 @@ def test_t2_map_load(self):

scan = pipeline.handle_dess(vargin)

scan.tissues[0].calc_quant_vals(scan.t2map, QuantitativeValue.T2)
scan.tissues[0].calc_quant_vals(scan.t2map, QuantitativeValues.T2)

print(scan.t2map.shape)

Expand Down
34 changes: 28 additions & 6 deletions tissues/femoral_cartilage.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,15 @@

import nipy.labs.mask as nlm

from utils.quant_vals import QuantitativeValue
from utils.quant_vals import QuantitativeValues
import matplotlib.pyplot as plt

import defaults
import warnings

BOUNDS = {QuantitativeValues.T2: 100.0,
QuantitativeValues.T1_RHO: 150.0,
QuantitativeValues.T2_STAR: 100.0}

class FemoralCartilage(Tissue):
ID = 1
Expand Down Expand Up @@ -246,9 +252,10 @@ def calc_quant_vals(self, quant_map, map_type):
sagital_keys = ['anterior', 'central', 'posterior']
df = pd.DataFrame(data=np.transpose(tissue_values), index=sagital_keys, columns=pd.MultiIndex.from_tuples(zip(depth_keys, coronal_keys)))

maps = [{'title': 'T2 deep', 'data': deep, 'xlabel': 'Slice', 'ylabel': 'Angle (binned)', 'filename': 't2deep.png'},
{'title': 'T2 superficial', 'data': superficial, 'xlabel': 'Slice', 'ylabel': 'Angle (binned)', 'filename': 't2superficial.png'},
{'title': 'T2 total', 'data': total, 'xlabel': 'Slice', 'ylabel': 'Angle (binned)', 'filename': 't2total.png'}]
qv_name = map_type.name
maps = [{'title': '%s deep' % qv_name, 'data': deep, 'xlabel': 'Slice', 'ylabel': 'Angle (binned)', 'filename': '%s_deep.png' % qv_name},
{'title': '%s superficial' % qv_name, 'data': superficial, 'xlabel': 'Slice', 'ylabel': 'Angle (binned)', 'filename': '%s_superficial.png' % qv_name},
{'title': '%s total' % qv_name, 'data': total, 'xlabel': 'Slice', 'ylabel': 'Angle (binned)', 'filename': '%s_total.png' % qv_name}]

self.__store_quant_vals__(maps, df, map_type)

Expand All @@ -261,7 +268,7 @@ def __save_quant_data__(self, dirpath):
q_names = []
dfs = []

for quant_val in QuantitativeValue:
for quant_val in QuantitativeValues:
if quant_val.name not in self.quant_vals.keys():
continue

Expand All @@ -276,12 +283,27 @@ def __save_quant_data__(self, dirpath):
ylabel = 'Angle (binned)'
title = q_map_data['title']
data_map = q_map_data['data']

plt.clf()
plt.imshow(data_map, cmap='jet')

upper_bound = BOUNDS[quant_val]
is_picture_written = False
if defaults.FIX_VISUALIZATION_BOUNDS:
if np.sum(data_map <= upper_bound) == 0:
plt.imshow(data_map, cmap='jet', vmin=0.0, vmax=BOUNDS[quant_val])
is_picture_written = True
else:
warnings.warn('%s: Pixel value exceeded upper bound (%0.1f). Using normalized scale.'
% (quant_val.name, upper_bound))

if not is_picture_written:
plt.imshow(data_map, cmap='jet')

plt.xlabel(xlabel)
plt.ylabel(ylabel)
plt.title(title)
plt.colorbar()

plt.savefig(filepath)

if len(dfs) > 0:
Expand Down
2 changes: 1 addition & 1 deletion tissues/tissue.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
import os
from utils import io_utils
from utils.quant_vals import QuantitativeValue
from utils.quant_vals import QuantitativeValues
import cv2
import numpy as np

Expand Down
25 changes: 13 additions & 12 deletions utils/quant_vals.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,19 @@
__R_SQUARED_THRESHOLD__ = 0.9


class QuantitativeValue(Enum):
class QuantitativeValues(Enum):
T1_RHO = 1
T2 = 2
T2_STAR = 3


def get_qv(id):
for qv in QuantitativeValue:
for qv in QuantitativeValues:
if qv.name.lower() == id or qv.value == id:
return qv


def fit_mono_exp(x, y, p0=None):
def __fit_mono_exp__(x, y, p0=None):
def func(t, a, b):
exp = np.exp(b * t)
return a * exp
Expand All @@ -40,8 +40,9 @@ def func(t, a, b):
return popt, r_squared


def fit_monoexp_tc(x, ys, p0):
vals = np.zeros([1, ys.shape[-1]])
def fit_monoexp_tc(x, ys, tc0):
p0 = (1.0, -1/tc0)
time_constants = np.zeros([1, ys.shape[-1]])
r_squared = np.zeros([1, ys.shape[-1]])

warned_negative = False
Expand All @@ -56,17 +57,17 @@ def fit_monoexp_tc(x, ys, p0):
continue

try:
params, r2 = fit_mono_exp(x, y, p0=p0)
t1_rho = abs(params[-1])
params, r2 = __fit_mono_exp__(x, y, p0=p0)
tc = 1 / abs(params[-1])
except RuntimeError:
t1_rho, r2 = (np.nan, 0.0)
tc, r2 = (np.nan, 0.0)

vals[..., i] = t1_rho
time_constants[..., i] = tc
r_squared[..., i] = r2

return vals, r_squared
return time_constants, r_squared


if __name__ == '__main__':
print(type(QuantitativeValue.T1_RHO.name))
print(QuantitativeValue.T1_RHO.value== 1)
print(type(QuantitativeValues.T1_RHO.name))
print(QuantitativeValues.T1_RHO.value == 1)

0 comments on commit 037a11e

Please sign in to comment.