Skip to content

Commit

Permalink
implement basic ops debug (#151)
Browse files Browse the repository at this point in the history
  • Loading branch information
mosheraboh authored Aug 2, 2022
1 parent 711bdf2 commit bdd907a
Show file tree
Hide file tree
Showing 19 changed files with 379 additions and 660 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
===============================
MNIST classfier implementation that demonstrate end to end training, inference and evaluation using FuseMedML
MNIST classifier implementation that demonstrates end-to-end training, inference and evaluation using FuseMedML
"""

import copy
Expand Down Expand Up @@ -116,7 +116,7 @@ def create_model() -> torch.nn.Module:
model=torch_model,
model_inputs=["data.image"],
post_forward_processing_function=perform_softmax,
model_outputs=["logits.classification", "output.classification"],
model_outputs=["model.logits.classification", "model.output.classification"],
)
return model

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def __init__(self, model_dir: str, opt_lr: float, opt_weight_decay: float, **kwa
model=torch_model,
model_inputs=["data.image"],
post_forward_processing_function=perform_softmax,
model_outputs=["logits.classification", "output.classification"],
model_outputs=["model.logits.classification", "model.output.classification"],
)

# losses
Expand All @@ -119,7 +119,7 @@ def forward(self, batch_dict: NDict) -> NDict:
## Step
def training_step(self, batch_dict: NDict, batch_idx: int) -> dict:
# run forward function and store the outputs in batch_dict["model"]
batch_dict["model"] = self.forward(batch_dict)
batch_dict = self.forward(batch_dict)
# given the batch_dict and FuseMedML style losses - compute the losses, return the total loss and save losses values in batch_dict["losses"]
total_loss = fuse_pl.step_losses(self._losses, batch_dict)
# given the batch_dict and FuseMedML style losses - collect the required values to compute the metrics on epoch_end
Expand All @@ -130,7 +130,7 @@ def training_step(self, batch_dict: NDict, batch_idx: int) -> dict:

def validation_step(self, batch_dict: NDict, batch_idx: int) -> dict:
# run forward function and store the outputs in batch_dict["model"]
batch_dict["model"] = self.forward(batch_dict)
batch_dict = self.forward(batch_dict)
# given the batch_dict and FuseMedML style losses - compute the losses, return the total loss (ignored) and save losses values in batch_dict["losses"]
_ = fuse_pl.step_losses(self._losses, batch_dict)
# given the batch_dict and FuseMedML style losses - collect the required values to compute the metrics on epoch_end
Expand All @@ -145,7 +145,7 @@ def predict_step(self, batch_dict: NDict, batch_idx: int) -> dict:
"Error: predict_step expectes list of prediction keys to extract from batch_dict. Please specify it using set_predictions_keys() method "
)
# run forward function and store the outputs in batch_dict["model"]
batch_dict["model"] = self.forward(batch_dict)
batch_dict = self.forward(batch_dict)
# extract the required keys - defined in self.set_predictions_keys()
return fuse_pl.step_extract_predictions(self._prediction_keys, batch_dict)

Expand Down
7 changes: 7 additions & 0 deletions fuse/data/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,13 @@ The following operators are useful when implementing a common pipeline:
* OpToTensor - convert many different types to PyTorch tensor
* OpOneHotToNumber - convert one-hot encoding vectors into numbers

[**Debug operators**](ops/ops_debug.py)

* OpPrintKeys - print the keys available at this point in the pipeline. Use OpDebugBase constructor arguments to limit the samples to debug.
* OpPrintShapes - print the shapes of all tensors, numpy arrays and sequences. Use OpDebugBase constructor arguments to limit the samples to debug.
* OpPrintTypes - print the types of all keys. Use OpDebugBase constructor arguments to limit the samples to debug.


**Imaging operators**
See [fuseimg package](../../fuseimg/data/README.md)

125 changes: 125 additions & 0 deletions fuse/data/ops/ops_debug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from abc import abstractmethod
from typing import Hashable, List, Sequence, Optional
from fuse.data.utils.sample import get_sample_id
from fuse.utils import NDict
from fuse.data import OpBase
import numpy
import torch


class OpDebugBase(OpBase):
    """
    Base class for debug operations.
    Provides the ability to limit the samples to debug: to an explicit list of sample ids
    and/or to the first num_samples samples that pass that filter.
    Subclasses implement self.call_debug instead of self.__call__.
    """

    def __init__(
        self,
        name: Optional[str] = None,
        sample_ids: Optional[List[Hashable]] = None,
        num_samples: Optional[int] = None,
    ):
        """
        :param name: string identifier - might be useful when the debug op display or save information into a file
        :param sample_ids: apply for the specified sample ids. To apply for all set to None.
        :param num_samples: apply for the first num_samples (counted by this op instance).
                            If None (or any other falsy value such as 0 / False), will apply for all.
        """
        super().__init__()
        self._name = name
        self._sample_ids = sample_ids
        self._num_samples = num_samples
        # number of samples this instance already debugged - compared against num_samples
        self._num_samples_done = 0

    def reset(self, name: Optional[str] = None) -> None:
        """
        Reset operation state: zero the debugged-samples counter and replace the name.
        :param name: the new string identifier. Note that omitting it sets the name to None.
        """
        self._num_samples_done = 0
        self._name = name

    def should_debug_sample(self, sample_dict: NDict) -> bool:
        """
        Decide whether the given sample should be debugged.
        Side effect: when the answer is True, the debugged-samples counter is incremented.
        :param sample_dict: the sample to check
        :return: False if the num_samples limit was already reached or the sample id is not in sample_ids, True otherwise
        """
        # truthiness check on purpose: None / 0 / False all mean "no limit"
        if self._num_samples and self._num_samples_done >= self._num_samples:
            return False

        if self._sample_ids is not None:
            sid = get_sample_id(sample_dict)
            if sid not in self._sample_ids:
                return False

        # only samples that passed both filters count toward the limit
        self._num_samples_done += 1
        return True

    def __call__(self, sample_dict: NDict, **kwargs) -> NDict:
        """
        Run the debug implementation for the selected samples only.
        :param sample_dict: the sample to (possibly) debug
        :param kwargs: forwarded as is to self.call_debug
        :return: the very same sample_dict object that was passed in (the return value of call_debug is ignored)
        """
        if self.should_debug_sample(sample_dict):
            self.call_debug(sample_dict, **kwargs)
        return sample_dict

    @abstractmethod
    def call_debug(self, sample_dict: NDict, **kwargs) -> None:
        """
        The actual debug op implementation.
        Implementations should accept **kwargs, since __call__ forwards every keyword argument it gets.
        """
        raise NotImplementedError


class OpPrintKeys(OpDebugBase):
    """
    Print list of available keys at a given point in the data pipeline
    It's recommended, but not a must, to run it in a single process.
    To force single process, add at the top of your script:
    ```
    from fuse.utils.utils_debug import FuseDebug
    FuseDebug("debug")
    ```
    Example:
    ```
    (OpPrintKeys(num_samples=1), dict()),
    ```
    """

    def call_debug(self, sample_dict: NDict, **kwargs) -> None:
        """
        Print the sample id followed by every key path of the sample, one per line.
        :param sample_dict: the sample to print the keys of
        :param kwargs: ignored - accepted because OpDebugBase.__call__ forwards its keyword arguments here
        """
        print(f"Sample {get_sample_id(sample_dict)} keys:")
        for key in sample_dict.keypaths():
            print(f"{key}")


class OpPrintShapes(OpDebugBase):
    """
    Print the shapes/length of every torch tensor / numpy array / sequence
    Add at the top of your script to force single process:
    ```
    from fuse.utils.utils_debug import FuseDebug
    FuseDebug("debug")
    ```
    Example:
    ```
    (OpPrintShapes(num_samples=1), dict()),
    ```
    """

    def call_debug(self, sample_dict: NDict, **kwargs) -> None:
        """
        Print the sample id, then the shape of every tensor / numpy array and the length of every
        non-string sequence. Values of any other type are skipped silently.
        :param sample_dict: the sample to print the shapes of
        :param kwargs: ignored - accepted because OpDebugBase.__call__ forwards its keyword arguments here
        """
        print(f"Sample {get_sample_id(sample_dict)} shapes:")
        for key in sample_dict.keypaths():
            value = sample_dict[key]
            if isinstance(value, torch.Tensor):
                print(f"{key} is tensor with shape: {value.shape}")
            elif isinstance(value, numpy.ndarray):
                print(f"{key} is numpy array with shape: {value.shape}")
            # str is a Sequence too - exclude it so plain strings are not reported as sequences
            elif not isinstance(value, str) and isinstance(value, Sequence):
                print(f"{key} is sequence with length: {len(value)}")


class OpPrintTypes(OpDebugBase):
    """
    Print the type of each key
    Add at the top of your script to force single process:
    ```
    from fuse.utils.utils_debug import FuseDebug
    FuseDebug("debug")
    ```
    Example:
    ```
    (OpPrintTypes(num_samples=1), dict()),
    ```
    """

    def call_debug(self, sample_dict: NDict, **kwargs) -> None:
        """
        Print the sample id, then one line per key path with the type name of the value stored there.
        :param sample_dict: the sample to print the types of
        :param kwargs: ignored - accepted because OpDebugBase.__call__ forwards its keyword arguments here
        """
        print(f"Sample {get_sample_id(sample_dict)} types:")
        for key in sample_dict.keypaths():
            value = sample_dict[key]
            print(f"{key} - {type(value).__name__}")
192 changes: 0 additions & 192 deletions fuse/data/ops/ops_visprobe.py

This file was deleted.

Loading

0 comments on commit bdd907a

Please sign in to comment.