Skip to content

Commit

Permalink
Merge branch 'public-dev' into public-main
Browse files Browse the repository at this point in the history
  • Loading branch information
amogh7joshi committed Oct 31, 2022
2 parents ef78bc1 + 43c20a2 commit 8db71d4
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 18 deletions.
2 changes: 1 addition & 1 deletion agml/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def custom(cls, name, dataset_path = None, classes = None, **kwargs):
kwargs.update(json.load(f))

# Infer the classes for image classification/object detection.
classes = kwargs['classes']
classes = kwargs.pop('classes')
if classes is None:
if task == 'semantic_segmentation':
raise ValueError(
Expand Down
12 changes: 8 additions & 4 deletions agml/models/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,14 @@ class ClassificationModel(AgMLModelBase):
parameter `net`, and you'll need to implement methods like `training_step`,
`configure_optimizers`, etc. See PyTorch Lightning for more information.
"""
serializable = frozenset(("model", ))
state_override = serializable
serializable = frozenset(("model", "regression"))
state_override = frozenset(("model",))

def __init__(self, num_classes):
def __init__(self, num_classes, regression = False):
# Construct the network and load in pretrained weights.
super(ClassificationModel, self).__init__()
self._num_classes = num_classes
self._regression = regression
self.net = self._construct_sub_net(num_classes)

@auto_move_data
Expand Down Expand Up @@ -169,7 +170,10 @@ def predict(self, images):
A `np.ndarray` with integer labels for each image.
"""
images = self.preprocess_input(images)
return self._to_out(torch.squeeze(torch.argmax(self.forward(images), 1)))
out = self.forward(images)
if not self._regression: # standard classification
out = torch.argmax(out, 1)
return self._to_out(torch.squeeze(out))

def evaluate(self, loader, **kwargs):
"""Runs an accuracy evaluation on the given loader.
Expand Down
19 changes: 10 additions & 9 deletions agml/models/metrics/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,22 +63,21 @@ def mean_average_precision(
# Calculate average precision for each class.
pred_boxes, true_boxes = torch.tensor(predicted_boxes), torch.tensor(truth_boxes)
for c in range(num_classes):
# If there are no predictions, then the per-class AP is 0.
if len(pred_boxes) == 0:
average_precisions.append(torch.tensor(0.0))
continue

# Get the predictions and targets corresponding to this class.
detections = pred_boxes[torch.where(pred_boxes[:, 1] == c)[0]].tolist()
ground_truths = true_boxes[torch.where(true_boxes[:, 1] == c)[0]].tolist()
torch_gt = torch.tensor(ground_truths)

# If there are no ground truths, then the per-class AP is 0.
if len(ground_truths) == 0:
average_precisions.append(0.0)
average_precisions.append(torch.tensor(0.0))
continue

# Get all of the unique data samples and create a dictionary
# storing all of the corresponding bounding boxes for each sample.
training_ids = torch.unique(true_boxes[:, 0])
truth_samples_by_id = {
idx.numpy().item(): true_boxes[torch.where(true_boxes[:, 0] == idx)] # noqa
for idx in training_ids}

# Determine the number of boxes for each of the training samples.
numpy_gt = torch.tensor(ground_truths)
amount_bboxes = {int(k.numpy().item()): torch.zeros(v) for k, v in zip(
Expand All @@ -101,7 +100,8 @@ def mean_average_precision(

# Only take out the ground_truths that have the same
# training idx as the detection.
ground_truth_img = truth_samples_by_id[update_num]
ground_truth_img = torch_gt[
torch.where(torch_gt[:, 0] == update_num)[0]]

# Get the bounding box with the highest IoU.
ious = torch.tensor([bbox_iou(
Expand Down Expand Up @@ -198,6 +198,7 @@ def update(self, pred_data, gt_data):
pred_scores = np.squeeze(pred_scores)
pred_labels, gt_labels, pred_scores = \
self._scalar_to_array(pred_labels, gt_labels, pred_scores)
gt_boxes = np.squeeze(gt_boxes)
if pred_boxes.ndim == 1:
pred_boxes = np.expand_dims(pred_boxes, axis = 0)
if gt_boxes.ndim == 1:
Expand Down
22 changes: 22 additions & 0 deletions agml/models/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,28 @@
import albumentations as A
from albumentations.pytorch import ToTensorV2

from agml.models.tools import imagenet_style_process as _isp


def imagenet_preprocess(image, size = None):
"""Preprocesses a single input image to ImageNet standards.
The preprocessing steps are applied logically; if the images
are passed with preprocessing already having been applied, for
instance, the images are already resized or they are already been
normalized, the operation is not applied again, for efficiency.
Preprocessing includes the following steps:
1. Resizing the image to size (224, 224).
2. Performing normalization with ImageNet parameters.
3. Converting the image into a PyTorch tensor format.
as well as other intermediate steps such as adding a channel
dimension for two-channel inputs, for example.
"""
return _isp(image, size = size)


class EfficientDetPreprocessor(object):
"""A preprocessor which prepares a data sample for `EfficientDet`.
Expand Down
19 changes: 15 additions & 4 deletions agml/viz/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,27 @@ def set_colormap(colormap):
1. "default": Traditional matplotlib RGB colors.
2. "agriculture": Various shades of green (for agriculture).
If you want to set a custom colormap, then pass a list of RGB
values which will be used as the colormap.
Parameters
----------
colormap : str
The colormap to set.
"""
global _COLORMAP_CHOICE, _COLORMAPS
colormap = colormap.lower()
if colormap not in _COLORMAPS.keys():
raise ValueError(f"Invalid colormap {colormap} received.")
_COLORMAP_CHOICE = colormap
if isinstance(colormap, list):
if not all(len(i) == 3 for i in colormap):
raise ValueError(
"If you want a custom colormap, then pass a list of RGB values.")
elif isinstance(colormap, str):
colormap = colormap.lower()
if colormap not in _COLORMAPS.keys():
raise ValueError(f"Invalid colormap {colormap} received.")
else:
raise TypeError(f"Invalid colormap of type {type(colormap)}.")
_COLORMAPS['custom'] = colormap
_COLORMAP_CHOICE = 'custom'


def auto_resolve_image(f):
Expand Down

0 comments on commit 8db71d4

Please sign in to comment.