Skip to content

Commit

Permalink
Hypertension classification from ECG (#536)
Browse files Browse the repository at this point in the history
write t1 maps, remove np.bool
  • Loading branch information
lucidtronix authored Aug 17, 2023
1 parent 01bb7ed commit f592069
Show file tree
Hide file tree
Showing 24 changed files with 369 additions and 112 deletions.
2 changes: 1 addition & 1 deletion docker/vm_boot_images/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ARG BASE_IMAGE
FROM ${BASE_IMAGE}

LABEL maintainer="James Pirruccello <jamesp@broadinstitute.org>"
LABEL maintainer="Sam Freesun Friedman <sam@broadinstitute.org>"

# Setup time zone (or else docker build hangs)
ENV TZ=America/New_York
Expand Down
1 change: 1 addition & 0 deletions docker/vm_boot_images/config/tensorflow-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,4 @@ opencv-python
blosc
boto3
ml4ht==0.0.10
umap-learn[plot]
2 changes: 1 addition & 1 deletion docker/vm_boot_images/config/ubuntu.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
# Other necessities
apt-get update
echo "ttf-mscorefonts-installer msttcorefonts/accepted-mscorefonts-eula select true" | debconf-set-selections
apt-get install -y wget unzip curl python3-pydot python3-pydot-ng graphviz ttf-mscorefonts-installer git pip
apt-get install -y wget unzip curl python3-pydot python3-pydot-ng graphviz ttf-mscorefonts-installer git pip ffmpeg
21 changes: 15 additions & 6 deletions ingest/cmd/build_curl_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,15 @@


FORM_TEXT = """
<form name="fetch" action="https://biota.osc.ox.ac.uk/dataset.cgi" method="post">
<input type="hidden" name="id" value="671599"/>
<input type="hidden" name="s" value="305736"/>
<input type="hidden" name="t" value="1684504514"/>
<input type="hidden" name="i" value="67.244.49.54"/>
<input type="hidden" name="v" value="da5aa919c0119423d8335cf169f51bb2a834f2967558e3a45f0f49d0157d6428"/>
<input class="btn_glow" type="submit" value="Fetch"/>
</form>
"""


Expand All @@ -12,12 +21,12 @@

test = """
<form name="fetch" action="https://biota.osc.ox.ac.uk/dataset.cgi" method="post">
<input type="hidden" name="id" value="AAA">
<input type="hidden" name="s" value="BBB">
<input type="hidden" name="t" value="CCC">
<input type="hidden" name="i" value="DDD">
<input type="hidden" name="v" value="EEE">
<input class="sub_go" type="submit" value="Fetch">
<input type="hidden" name="id" value="671600"/>
<input type="hidden" name="s" value="305736"/>
<input type="hidden" name="t" value="1684501586"/>
<input type="hidden" name="i" value="67.244.49.54"/>
<input type="hidden" name="v" value="891f3ec7f3388d4c7a0c094ef1abde73f44c356f2732dade6a7921d9770dd095"/>
<input class="btn_glow" type="submit" value="Fetch"/>
</form>
"""

Expand Down
8 changes: 4 additions & 4 deletions ml4h/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def parse_args():
parser.add_argument('--z', default=48, type=int, help='z tensor resolution')
parser.add_argument('--t', default=48, type=int, help='Number of time slices')
parser.add_argument('--mlp_concat', default=False, action='store_true', help='Concatenate input with every multiplayer perceptron layer.') # TODO: should be the same style as u_connect
parser.add_argument('--dense_layers', nargs='*', default=[32], type=int, help='List of number of hidden units in neural nets dense layers.')
parser.add_argument('--dense_layers', nargs='*', default=[256], type=int, help='List of number of hidden units in neural nets dense layers.')
parser.add_argument('--dense_regularize_rate', default=0.0, type=float, help='Rate parameter for dense_regularize.')
parser.add_argument('--dense_regularize', default=None, choices=list(DENSE_REGULARIZATION_CLASSES), help='Type of regularization layer for dense layers.')
parser.add_argument('--dense_normalize', default=None, choices=list(NORMALIZATION_CLASSES), help='Type of normalization layer for dense layers.')
Expand Down Expand Up @@ -241,9 +241,9 @@ def parse_args():
'If not specified, default 0.1 is used. If default ratios are used with train_csv, some tensors may be ignored because ratios do not sum to 1.',
)
parser.add_argument('--test_steps', default=32, type=int, help='Number of batches to use for testing.')
parser.add_argument('--training_steps', default=72, type=int, help='Number of training batches to examine in an epoch.')
parser.add_argument('--validation_steps', default=18, type=int, help='Number of validation batches to examine in an epoch validation.')
parser.add_argument('--learning_rate', default=0.0002, type=float, help='Learning rate during training.')
parser.add_argument('--training_steps', default=96, type=int, help='Number of training batches to examine in an epoch.')
parser.add_argument('--validation_steps', default=32, type=int, help='Number of validation batches to examine in an epoch validation.')
parser.add_argument('--learning_rate', default=0.00005, type=float, help='Learning rate during training.')
parser.add_argument('--mixup_alpha', default=0, type=float, help='If positive apply mixup and sample from a Beta with this value as shape parameter alpha.')
parser.add_argument(
'--label_weights', nargs='*', type=float,
Expand Down
6 changes: 3 additions & 3 deletions ml4h/defines.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ def __str__(self):
'interventricular_septum': 5, 'interatrial_septum': 6, 'crista_terminalis': 7,
}
MRI_SAX_PAP_SEGMENTED_CHANNEL_MAP = {
'background': 0, 'RV_free_wall': 1, 'interventricular_septum': 2, 'LV_free_wall': 3, 'LV_pap': 4, 'LV_cavity': 5,
'RV_cavity': 6, 'thoracic_cavity': 7, 'liver': 8, 'stomach': 9, 'spleen': 10, 'kidney': 12, 'body': 11,
'left_atrium': 13, 'right_atrium': 14, 'aorta': 15, 'pulmonary_artery': 16,
'background': 0, 'body': 1, 'thoracic_cavity': 2, 'liver': 3, 'stomach': 4, 'spleen': 5, 'kidney': 6,
'interventricular_septum': 7, 'LV_free_wall': 8, 'anterolateral_pap': 9, 'posteromedial_pap': 10, 'LV_cavity': 11,
'RV_free_wall': 12, 'RV_cavity': 13,
}
MRI_SAX_SEGMENTED_CHANNEL_MAP = {
'background': 0, 'RV_free_wall': 1, 'interventricular_septum': 2, 'LV_free_wall': 3, 'LV_cavity': 4,
Expand Down
2 changes: 1 addition & 1 deletion ml4h/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ def dashes(n): return '-' * n
def _unpack_truth_into_events(truth, intervals):
event_time = np.argmin(np.diff(truth[:, :intervals]), axis=-1)
event_time[truth[:, intervals-1] == 1] = intervals-1 # If the sample is never censored set event time to max time
event_indicator = np.sum(truth[:, intervals:], axis=-1).astype(np.bool)
event_indicator = np.sum(truth[:, intervals:], axis=-1).astype(bool)
return event_indicator, event_time


Expand Down
22 changes: 22 additions & 0 deletions ml4h/models/basic_blocks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Dict, List

import numpy as np

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, LSTM, Concatenate, Flatten
Expand Down Expand Up @@ -237,3 +239,23 @@ def __call__(self, x: Tensor, intermediates: Dict[TensorMap, List[Tensor]] = Non
x = tf.keras.layers.Dropout(self.dropout_rate)(x)
x = self.final_layer(x)
return x


def random_gaussian_noise(x, scalar=0.1):
return x + np.random.randn(*x.shape[1:].as_list()) * scalar


class RandomGauss(tf.keras.layers.Layer):
def __init__(self, scalar=0.1, **kwargs):
super().__init__(**kwargs)
self.scalar = scalar

def call(self, x):
return random_gaussian_noise(x, self.scalar)

def get_config(self):
config = super().get_config()
config.update({
"scalar": self.scalar,
})
return config
8 changes: 4 additions & 4 deletions ml4h/models/legacy_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import tensorflow_probability as tfp

from ml4h.metrics import get_metric_dict
from ml4h.models.model_factory import _get_custom_objects
from ml4h.models.model_factory import get_custom_objects
from ml4h.plots import plot_metric_history
from ml4h.TensorMap import TensorMap, Interpretation
from ml4h.optimizers import get_optimizer, NON_KERAS_OPTIMIZERS
Expand Down Expand Up @@ -984,7 +984,7 @@ def legacy_multimodal_multitask_model(
"""
tensor_maps_out = parent_sort(tensor_maps_out)
u_connect: DefaultDict[TensorMap, Set[TensorMap]] = u_connect or defaultdict(set)
custom_dict = _get_custom_objects(tensor_maps_out)
custom_dict = get_custom_objects(tensor_maps_out)
opt = get_optimizer(
optimizer, learning_rate, steps_per_epoch=training_steps, learning_rate_schedule=learning_rate_schedule,
optimizer_kwargs=kwargs.get('optimizer_kwargs'),
Expand Down Expand Up @@ -1225,7 +1225,7 @@ def make_paired_autoencoder_model(
multimodal_merge: str = 'average',
**kwargs
) -> Model:
custom_dict = _get_custom_objects(kwargs['tensor_maps_out'])
custom_dict = get_custom_objects(kwargs['tensor_maps_out'])
opt = get_optimizer(
kwargs['optimizer'], kwargs['learning_rate'], steps_per_epoch=kwargs['training_steps'],
learning_rate_schedule=kwargs['learning_rate_schedule'], optimizer_kwargs=kwargs.get('optimizer_kwargs'),
Expand Down Expand Up @@ -1491,7 +1491,7 @@ def get_model_inputs_outputs(
models_inputs_outputs = dict()

for model_file in model_files:
custom = _get_custom_objects(tensor_maps_out)
custom = get_custom_objects(tensor_maps_out)
logging.info(f'custom keys: {list(custom.keys())}')
m = load_model(model_file, custom_objects=custom, compile=False)
model_inputs_outputs = defaultdict(list)
Expand Down
4 changes: 3 additions & 1 deletion ml4h/models/merge_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,9 @@ def get_config(self):
def call(self, inputs):
# We use `add_loss` to create a regularization loss
# that depends on the inputs.
self.add_loss(self.weight * contrastive_difference(inputs[0], inputs[1], self.batch_size, self.temperature))
contrastive_loss = self.weight * contrastive_difference(inputs[0], inputs[1], self.batch_size, self.temperature)
self.add_loss(contrastive_loss)
self.add_metric(contrastive_loss, name="contrastive_loss")
return inputs


Expand Down
10 changes: 6 additions & 4 deletions ml4h/models/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from ml4h.models.transformer_blocks import TransformerDecoder, TransformerEncoder, PositionalEncoding, MultiHeadAttention
from ml4h.models.merge_blocks import GlobalAveragePoolBlock, EncodeIdentityBlock, L2LossLayer, CosineLossLayer, VariationalDiagNormal
from ml4h.models.merge_blocks import FlatConcatDenseBlock, FlatConcatBlock, AverageBlock, PairLossBlock, ReduceMean, ContrastiveLossLayer
from ml4h.models.basic_blocks import ModelAsBlock, LSTMEncoderBlock, LanguageDecoderBlock, DenseEncoder, DenseDecoder, LinearDecoder, PartitionedLinearDecoder, LanguagePredictionBlock
from ml4h.models.basic_blocks import ModelAsBlock, LSTMEncoderBlock, LanguageDecoderBlock, DenseEncoder, DenseDecoder
from ml4h.models.basic_blocks import LinearDecoder, PartitionedLinearDecoder, LanguagePredictionBlock, RandomGauss


BLOCK_CLASSES = {
Expand Down Expand Up @@ -100,7 +101,7 @@ def make_multimodal_multitask_model(
"""
tensor_maps_out = parent_sort(tensor_maps_out)
u_connect: DefaultDict[TensorMap, Set[TensorMap]] = u_connect or defaultdict(set)
custom_dict = _get_custom_objects(tensor_maps_out)
custom_dict = get_custom_objects(tensor_maps_out)
opt = get_optimizer(
optimizer, learning_rate, steps_per_epoch=training_steps, learning_rate_schedule=learning_rate_schedule,
optimizer_kwargs=kwargs.get('optimizer_kwargs'),
Expand Down Expand Up @@ -294,13 +295,14 @@ def _load_model_encoders_and_decoders(
return m, encoders, decoders, merger


def _get_custom_objects(tensor_maps_out: List[TensorMap]) -> Dict[str, Any]:
def get_custom_objects(tensor_maps_out: List[TensorMap]) -> Dict[str, Any]:
custom_objects = {
obj.__name__: obj
for obj in chain(
NON_KERAS_OPTIMIZERS.values(), ACTIVATION_FUNCTIONS.values(), NORMALIZATION_CLASSES.values(),
[
VariationalDiagNormal, L2LossLayer, CosineLossLayer, ContrastiveLossLayer, PositionalEncoding, MultiHeadAttention,
VariationalDiagNormal, L2LossLayer, CosineLossLayer, ContrastiveLossLayer, PositionalEncoding,
MultiHeadAttention, RandomGauss,
KerasLayer,
],
)
Expand Down
9 changes: 6 additions & 3 deletions ml4h/models/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, Callback

from ml4h.TensorMap import TensorMap
from ml4h.plots import plot_metric_history
from ml4h.defines import IMAGE_EXT, MODEL_EXT
from ml4h.models.inspect import plot_and_time_model
from ml4h.models.model_factory import _get_custom_objects
from ml4h.models.model_factory import get_custom_objects


def train_model_from_generators(
Expand All @@ -27,7 +28,7 @@ def train_model_from_generators(
run_id: str,
inspect_model: bool,
inspect_show_labels: bool,
output_tensor_maps = [],
output_tensor_maps: List[TensorMap] = [],
return_history: bool = False,
plot: bool = True,
save_last_model: bool = False,
Expand All @@ -49,11 +50,13 @@ def train_model_from_generators(
:param run_id: User-chosen string identifying this run
:param inspect_model: If True, measure training and inference runtime of the model and generate architecture plot.
:param inspect_show_labels: If True, show labels on the architecture plot.
:param output_tensor_maps: List of output TensorMap
:param return_history: If true return history from training and don't plot the training history
:param plot: If true, plots the metrics for train and validation set at the end of each epoch
:param save_last_model: If true saves the model weights from last epoch otherwise saves model with best validation loss
:return: The optimized model.
"""
model_file = os.path.join(output_folder, run_id, run_id + MODEL_EXT)
if not os.path.exists(os.path.dirname(model_file)):
Expand All @@ -70,7 +73,7 @@ def train_model_from_generators(
)

logging.info('Model weights saved at: %s' % model_file)
custom_dict = _get_custom_objects(output_tensor_maps)
custom_dict = get_custom_objects(output_tensor_maps)
model = load_model(model_file, custom_objects=custom_dict, compile=False)
model.compile(optimizer='adam', loss='mse')
if plot:
Expand Down
16 changes: 8 additions & 8 deletions ml4h/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def evaluate_predictions(
y_predictions, y_truth, tm.channel_map, title, folder, dpi, width, height,
)
plot_prediction_calibration(
y_predictions, y_truth, tm.channel_map, title, folder, dpi, width, height,
y_predictions, y_truth, tm.channel_map, title, folder, 10, dpi, width, height,
)
performance_metrics.update(
subplot_roc_per_class(
Expand Down Expand Up @@ -218,7 +218,7 @@ def evaluate_predictions(
),
)
plot_prediction_calibration(
y_predictions, y_truth, tm.channel_map, title, folder, dpi, width, height,
y_predictions, y_truth, tm.channel_map, title, folder, 10, dpi, width, height,
)
rocs.append((y_predictions, y_truth, tm.channel_map))
elif tm.is_categorical() and tm.axes() == 3:
Expand All @@ -242,7 +242,7 @@ def evaluate_predictions(
),
)
plot_prediction_calibration(
y_predictions, y_truth, tm.channel_map, title, folder, dpi, width, height,
y_predictions, y_truth, tm.channel_map, title, folder, 10, dpi, width, height,
)
rocs.append((y_predictions, y_truth, tm.channel_map))
elif tm.is_categorical() and tm.axes() == 4:
Expand All @@ -269,7 +269,7 @@ def evaluate_predictions(
),
)
plot_prediction_calibration(
y_predictions, y_truth, tm.channel_map, title, folder, dpi, width, height,
y_predictions, y_truth, tm.channel_map, title, folder, 10, dpi, width, height,
)
rocs.append((y_predictions, y_truth, tm.channel_map))
elif tm.is_survival_curve():
Expand Down Expand Up @@ -557,9 +557,9 @@ def plot_prediction_calibrations(
:param width: Width in inches of the figure
:param height: Height in inches of the figure
"""
_ = plt.figure(figsize=(width, height), dpi=dpi)
ax1 = plt.subplot2grid((3, 1), (0, 0), rowspan=2)
ax2 = plt.subplot2grid((3, 1), (2, 0))
_, (ax1, ax2) = plt.subplots(3, figsize=(width, height * 2), dpi=dpi)
#ax1 = plt.subplot2grid((3, 1), (0, 0), rowspan=2)
#ax2 = plt.subplot2grid((3, 1), (2, 0))

true_sums = np.sum(truth, axis=0)
ax1.plot([0, 1], [0, 1], "k:", label="Perfectly calibrated Brier score: 0.0")
Expand Down Expand Up @@ -632,7 +632,7 @@ def plot_prediction_calibration(
:param width: Width in inches of the figure
:param height: Height in inches of the figure
"""
_, (ax1, ax3, ax2) = plt.subplots(3, figsize=(width, height), dpi=dpi)
fig, (ax1, ax3, ax2) = plt.subplots(3, 1, figsize=(width*2, height*4), dpi=dpi)

true_sums = np.sum(truth, axis=0)
ax1.plot([0, 1], [0, 1], "k:", label="Perfectly calibrated Brier score: 0.0")
Expand Down
10 changes: 8 additions & 2 deletions ml4h/tensorize/tensor_writer_ukbb.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@
'cine_segmented_sax_b8', 'cine_segmented_sax_b9', 'cine_segmented_sax_b10', 'cine_segmented_sax_b11', 'cine_segmented_sax_b12',
'cine_segmented_sax_b13', 'cine_segmented_sax_inlinevf', 'cine_segmented_lax_inlinevf', 'cine_segmented_ao_dist',
'cine_segmented_lvot', 'flow_250_tp_aov_bh_epat@c_p', 'flow_250_tp_aov_bh_epat@c', 'flow_250_tp_aov_bh_epat@c_mag',
'shmolli_192i_b1_sax_b1s_sax_b1s_sax_b1s_t1map', 'shmolli_192i_sax_b2s_sax_b2s_sax_b2s_t1map',
'shmolli_192i_b2_sax_b2s_sax_b2s_sax_b2s_t1map',
'shmolli_192i_b3_sax_b3s_sax_b3s_sax_b3s_t1map', 'shmolli_192i_b4_sax_b4s_sax_b4s_sax_b4s_t1map',
'shmolli_192i_b5_sax_b5s_sax_b5s_sax_b5s_t1map', 'shmolli_192i_b6_sax_b6s_sax_b6s_sax_b6s_t1map',
'shmolli_192i_b7_sax_b7s_sax_b7s_sax_b7s_t1map',

]
MRI_CARDIAC_SERIES_SEGMENTED = [series+'_segmented' for series in MRI_CARDIAC_SERIES]
MRI_BRAIN_SERIES = ['t1_p2_1mm_fov256_sag_ti_880', 't2_flair_sag_p2_1mm_fs_ellip_pf78']
Expand All @@ -64,7 +70,7 @@
MRI_LIVER_IDEAL_PROTOCOL = ['lms_ideal_optimised_low_flip_6dyn', 'lms_ideal_optimised_low_flip_6dyn_12bit']

DICOM_MRI_FIELDS = [
'20209', '20208', '20210', '20212', '20213', '20204', '20203', '20254', '20216', '20220', '20218',
'20209', '20208', '20210', '20212', '20213', '20214', '20204', '20203', '20254', '20216', '20220', '20218',
'20227', '20225', '20217', '20158',
]

Expand Down Expand Up @@ -185,7 +191,7 @@ def write_tensors_from_dicom_pngs(
try:
png = imageio.imread(os.path.join(png_path, dicom_file + png_postfix))
full_tensor = np.zeros((x, y), dtype=np.float32)
full_tensor[:png.shape[0], :png.shape[1]] = png
full_tensor[:png.shape[0], :png.shape[1]] = png[..., 0]
tensor_file = os.path.join(tensors, str(sample_id) + TENSOR_EXT)
if not os.path.exists(os.path.dirname(tensor_file)):
os.makedirs(os.path.dirname(tensor_file))
Expand Down
Loading

0 comments on commit f592069

Please sign in to comment.