From 13d150ec4d93cb0201672e593590a6d20eed6c05 Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Sat, 1 May 2021 20:08:49 +0800 Subject: [PATCH] Cleanup data pipeline (#103) * Remove DataPipeline * Fix prediction with image path * Minor fixes * Add unittest for contains_any_tensor --- test/test_data_pipeline.py | 13 +++- test/test_engine.py | 3 +- yolort/data/__init__.py | 4 +- yolort/data/_helper.py | 22 +++++- yolort/data/data_module.py | 68 +++++++----------- yolort/data/data_pipeline.py | 92 ------------------------ yolort/data/detection_pipeline.py | 69 ------------------ yolort/models/yolo_module.py | 112 ++++++++++++++++-------------- 8 files changed, 119 insertions(+), 264 deletions(-) delete mode 100644 yolort/data/data_pipeline.py delete mode 100644 yolort/data/detection_pipeline.py diff --git a/test/test_data_pipeline.py b/test/test_data_pipeline.py index 0c88f26e..cf478fbd 100644 --- a/test/test_data_pipeline.py +++ b/test/test_data_pipeline.py @@ -1,16 +1,25 @@ # Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved. from pathlib import Path import unittest +import numpy as np +import torch from torch import Tensor -from yolort.data import DetectionDataModule -import yolort.data._helper as data_helper +from yolort.data import DetectionDataModule, contains_any_tensor, _helper as data_helper from typing import Dict class DataPipelineTester(unittest.TestCase): + def test_contains_any_tensor(self): + dummy_numpy = np.random.randn(3, 6) + self.assertFalse(contains_any_tensor(dummy_numpy)) + dummy_tensor = torch.randn(3, 6) + self.assertTrue(contains_any_tensor(dummy_tensor)) + dummy_tensors = [torch.randn(3, 6), torch.randn(9, 5)] + self.assertTrue(contains_any_tensor(dummy_tensors)) + def test_get_dataset(self): # Acquire the images and labels from the coco128 dataset train_dataset = data_helper.get_dataset(data_root='data-bin', mode='train') diff --git a/test/test_engine.py b/test/test_engine.py index 2f6fc36b..87e4bdc9 100644 --- a/test/test_engine.py +++ b/test/test_engine.py @@ -8,8 +8,7 @@ import pytorch_lightning as pl -from yolort.data import COCOEvaluator, DetectionDataModule -import yolort.data._helper as data_helper +from yolort.data import COCOEvaluator, DetectionDataModule, _helper as data_helper from yolort.models import yolov5s from yolort.models.yolo import yolov5_darknet_pan_s_r31 diff --git a/yolort/data/__init__.py b/yolort/data/__init__.py index 56103127..c24a08ce 100644 --- a/yolort/data/__init__.py +++ b/yolort/data/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved. -from .coco_eval import COCOEvaluator -from .data_pipeline import DataPipeline from .data_module import DetectionDataModule, VOCDetectionDataModule, COCODetectionDataModule +from .coco_eval import COCOEvaluator +from ._helper import contains_any_tensor diff --git a/yolort/data/_helper.py b/yolort/data/_helper.py index dbf3d6e0..f76a0032 100644 --- a/yolort/data/_helper.py +++ b/yolort/data/_helper.py @@ -1,16 +1,17 @@ # Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved. 
-import random from pathlib import Path, PosixPath from zipfile import ZipFile import torch -from torchvision import ops +from torch import Tensor + +from typing import Type, Any from .coco import COCODetection from .transforms import collate_fn, default_train_transforms, default_val_transforms import logging -logger = logging.getLogger(__name__) + def get_coco_api_from_dataset(dataset): for _ in range(10): @@ -24,6 +25,19 @@ def get_coco_api_from_dataset(dataset): raise NotImplementedError("Currently only supports COCO datasets") +def contains_any_tensor(value: Any, dtype: Type = Tensor) -> bool: + """ + Determine whether or not a list contains any Type + """ + if isinstance(value, dtype): + return True + if isinstance(value, (list, tuple)): + return any(contains_any_tensor(v, dtype=dtype) for v in value) + elif isinstance(value, dict): + return any(contains_any_tensor(v, dtype=dtype) for v in value.values()) + return False + + def prepare_coco128( data_path: PosixPath, dirname: str = 'coco128', @@ -35,6 +49,8 @@ def prepare_coco128( data_path (PosixPath): root path of coco128 dataset. dirname (str): the directory name of coco128 dataset. Default: 'coco128'. """ + logger = logging.getLogger(__name__) + if not data_path.is_dir(): logger.info(f'Create a new directory: {data_path}') data_path.mkdir(parents=True, exist_ok=True) diff --git a/yolort/data/data_module.py b/yolort/data/data_module.py index 5c2e9d81..51f32e13 100644 --- a/yolort/data/data_module.py +++ b/yolort/data/data_module.py @@ -6,13 +6,11 @@ from pytorch_lightning import LightningDataModule +from typing import Callable, List, Any, Optional + from .transforms import collate_fn, default_train_transforms, default_val_transforms from .voc import VOCDetection from .coco import COCODetection -from .data_pipeline import DataPipeline -from .detection_pipeline import ObjectDetectionDataPipeline - -from typing import Callable, List, Any, Optional class DetectionDataModule(LightningDataModule): @@ -79,19 +77,33 @@ def val_dataloader(self, batch_size: int = 16) -> None: return loader - @property - def data_pipeline(self) -> DataPipeline: - if self._data_pipeline is None: - self._data_pipeline = self.default_pipeline() - return self._data_pipeline - @data_pipeline.setter - def data_pipeline(self, data_pipeline) -> None: - self._data_pipeline = data_pipeline +class COCODetectionDataModule(DetectionDataModule): + def __init__( + self, + data_path: str, + year: str = "2017", + train_transform: Optional[Callable] = default_train_transforms, + val_transform: Optional[Callable] = default_val_transforms, + batch_size: int = 1, + num_workers: int = 0, + *args: Any, + **kwargs: Any, + ) -> None: + train_dataset = self.build_datasets( + data_path, image_set='train', year=year, transforms=train_transform) + val_dataset = self.build_datasets( + data_path, image_set='val', year=year, transforms=val_transform) + + super().__init__(train_dataset=train_dataset, val_dataset=val_dataset, + batch_size=batch_size, num_workers=num_workers, *args, **kwargs) + + self.num_classes = 80 @staticmethod - def default_pipeline() -> DataPipeline: - return ObjectDetectionDataPipeline() + def build_datasets(data_path, image_set, year, transforms): + ann_file = Path(data_path) / 'annotations' / f"instances_{image_set}{year}.json" + return COCODetection(data_path, ann_file, transforms()) class VOCDetectionDataModule(DetectionDataModule): @@ -134,31 +146,3 @@ def build_datasets(data_path, image_set, years, transforms): return datasets[0], num_classes else: return 
torch.utils.data.ConcatDataset(datasets), num_classes - - -class COCODetectionDataModule(DetectionDataModule): - def __init__( - self, - data_path: str, - year: str = "2017", - train_transform: Optional[Callable] = default_train_transforms, - val_transform: Optional[Callable] = default_val_transforms, - batch_size: int = 1, - num_workers: int = 0, - *args: Any, - **kwargs: Any, - ) -> None: - train_dataset = self.build_datasets( - data_path, image_set='train', year=year, transforms=train_transform) - val_dataset = self.build_datasets( - data_path, image_set='val', year=year, transforms=val_transform) - - super().__init__(train_dataset=train_dataset, val_dataset=val_dataset, - batch_size=batch_size, num_workers=num_workers, *args, **kwargs) - - self.num_classes = 80 - - @staticmethod - def build_datasets(data_path, image_set, year, transforms): - ann_file = Path(data_path).joinpath('annotations').joinpath(f"instances_{image_set}{year}.json") - return COCODetection(data_path, ann_file, transforms()) diff --git a/yolort/data/data_pipeline.py b/yolort/data/data_pipeline.py deleted file mode 100644 index 00e76c80..00000000 --- a/yolort/data/data_pipeline.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from torch import Tensor -from torch.utils.data.dataloader import default_collate - -from typing import Any - - -class DataPipeline: - """ - This class purpose is to facilitate the conversion of raw data to processed or batched data and back. - Several hooks are provided for maximum flexibility. - - Example:: - - .. code-block:: python - - class MyTextDataPipeline(DataPipeline): - def __init__(self, tokenizer, padder): - self.tokenizer = tokenizer - self.padder = padder - - def before_collate(self, samples): - # encode each input sequence - return [self.tokenizer.encode(sample) for sample in samplers] - - def after_collate(self, batch): - # pad tensor elements to the maximum length in the batch - return self.padder(batch) - - def after_uncollate(self, samples): - # decode each input sequence - return [self.tokenizer.decode(sample) for sample in samples] - - """ - - def before_collate(self, samples: Any) -> Any: - """Override to apply transformations to samples""" - return samples - - def collate(self, samples: Any) -> Any: - """Override to convert a set of samples to a batch""" - if not isinstance(samples, Tensor): - return default_collate(samples) - return samples - - def after_collate(self, batch: Any) -> Any: - """Override to apply transformations to the batch""" - return batch - - def collate_fn(self, samples: Any) -> Any: - """ - Utility function to convert raw data to batched data - - ``collate_fn`` as used in ``torch.utils.data.DataLoader``. - To avoid the before/after collate transformations, please use ``collate``. 
- """ - samples = self.before_collate(samples) - batch = self.collate(samples) - batch = self.after_collate(batch) - return batch - - def before_uncollate(self, batch: Any) -> Any: - """Override to apply transformations to the batch""" - return batch - - def uncollate(self, batch: Any) -> Any: - """Override to convert a batch to a set of samples""" - samples = batch - return samples - - def after_uncollate(self, samples: Any) -> Any: - """Override to apply transformations to samples""" - return samples - - def uncollate_fn(self, batch: Any) -> Any: - """Utility function to convert batched data back to raw data""" - batch = self.before_uncollate(batch) - samples = self.uncollate(batch) - samples = self.after_uncollate(samples) - return samples diff --git a/yolort/data/detection_pipeline.py b/yolort/data/detection_pipeline.py deleted file mode 100644 index 0b04528d..00000000 --- a/yolort/data/detection_pipeline.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved. -from collections.abc import Sequence - -from torch import Tensor -from torchvision.io import read_image - -from .transforms import collate_fn -from .data_pipeline import DataPipeline - -from typing import Callable, Any, Optional, Type - - -class ObjectDetectionDataPipeline(DataPipeline): - """ - Modified from: - - """ - def __init__(self, loader: Optional[Callable] = None): - if loader is None: - loader = lambda x: read_image(x) / 255. - self._loader = loader - - def before_collate(self, samples: Any) -> Any: - if _contains_any_tensor(samples, Tensor): - return samples - - if isinstance(samples, str): - samples = [samples] - - if isinstance(samples, (list, tuple)) and all(isinstance(p, str) for p in samples): - outputs = [] - for sample in samples: - output = self._loader(sample) - outputs.append(output) - return outputs - - raise NotImplementedError("The samples should either be a tensor or path, a list of paths or tensors.") - - def collate(self, samples: Any) -> Any: - if not isinstance(samples, Tensor): - elem = samples[0] - - if isinstance(elem, Sequence): - return collate_fn(samples) - - return list(samples) - - return samples.unsqueeze(dim=0) - - def after_collate(self, batch: Any) -> Any: - return (batch["x"], batch["target"]) if isinstance(batch, dict) else (batch, None) - - -def _contains_any_tensor(value: Any, dtype: Type = Tensor) -> bool: - """ - TODO: we should refactor FlashDatasetFolder to better integrate - with DataPipeline. That way, we wouldn't need this check. - This is because we are running transforms in both places. - - Ref: - - """ - if isinstance(value, dtype): - return True - if isinstance(value, (list, tuple)): - return any(_contains_any_tensor(v, dtype=dtype) for v in value) - elif isinstance(value, dict): - return any(_contains_any_tensor(v, dtype=dtype) for v in value.values()) - return False diff --git a/yolort/models/yolo_module.py b/yolort/models/yolo_module.py index 5b465498..588b3f82 100644 --- a/yolort/models/yolo_module.py +++ b/yolort/models/yolo_module.py @@ -5,15 +5,15 @@ import torch from torch import Tensor +from torchvision.io import read_image from pytorch_lightning import LightningModule +from typing import Any, Callable, List, Dict, Tuple, Optional, Union from . 
import yolo from .transform import YOLOTransform from ._utils import _evaluate_iou -from ..data import DetectionDataModule, DataPipeline, COCOEvaluator - -from typing import Any, List, Dict, Tuple, Optional, Union +from ..data import COCOEvaluator, contains_any_tensor __all__ = ['YOLOModule'] @@ -51,8 +51,6 @@ def __init__( self.transform = YOLOTransform(min_size, max_size) - self._data_pipeline = None - # metrics self.evaluator = None if annotation_path is not None: @@ -166,7 +164,7 @@ def test_step(self, batch, batch_idx): The test step. """ images, targets = batch - images = list(image.to(self.device) for image in images) + images = list(image.to(next(self.parameters()).device) for image in images) preds = self._forward_impl(images) results = self.evaluator(preds, targets) # log step metric @@ -175,69 +173,79 @@ def test_step(self, batch, batch_idx): def test_epoch_end(self, outputs): return self.log('coco_eval', self.evaluator.compute()) + def configure_optimizers(self): + return torch.optim.SGD( + self.model.parameters(), + lr=self.lr, + momentum=0.9, + weight_decay=0.005, + ) + @torch.no_grad() def predict( self, x: Any, - batch_idx: Optional[int] = None, - skip_collate_fn: bool = False, - dataloader_idx: Optional[int] = None, - data_pipeline: Optional[DataPipeline] = None, - ) -> Any: + image_loader: Optional[Callable] = None, + ) -> List[Dict[str, Tensor]]: """ Predict function for raw data or processed data - Args: - x: Input to predict. Can be raw data or processed data. + image_loader: Utility function to convert raw data to Tensor. - batch_idx: Batch index - - dataloader_idx: Dataloader index + Returns: + The post-processed model predictions. + """ + image_loader = image_loader or self.default_loader + images = self.collate_images(x, image_loader) + outputs = self.forward(images) + return outputs - skip_collate_fn: Whether to skip the collate step. - this is required when passing data already processed - for the model, for example, data from a dataloader + def default_loader(self, img_path: str) -> Tensor: + """ + Default loader of read a image path. - data_pipeline: Use this to override the current data pipeline + Args: + img_path (str): a image path Returns: - The post-processed model predictions - + Tensor, processed tensor for prediction. """ - data_pipeline = data_pipeline or self.data_pipeline - batch = x if skip_collate_fn else data_pipeline.collate_fn(x) - images, _ = batch if len(batch) == 2 and isinstance(batch, (list, tuple)) else (batch, None) - images = [img.to(self.device) for img in images] - predictions = self.forward(images) - output = data_pipeline.uncollate_fn(predictions) # TODO: pass batch and x - return output + return read_image(img_path) / 255. - def configure_optimizers(self): - return torch.optim.SGD( - self.model.parameters(), - lr=self.lr, - momentum=0.9, - weight_decay=0.005, - ) + def collate_images(self, samples: Any, image_loader: Callable) -> List[Tensor]: + """ + Prepare source samples for inference. - @torch.jit.unused - @property - def data_pipeline(self) -> DataPipeline: - # we need to save the pipeline in case this class - # is loaded from checkpoint and used to predict - if not self._data_pipeline: - self._data_pipeline = self.default_pipeline() - return self._data_pipeline - - @data_pipeline.setter - def data_pipeline(self, data_pipeline: DataPipeline) -> None: - self._data_pipeline = data_pipeline + Args: + samples (Any): samples source, support the following various types: + - str or List[str]: a image path or list of image paths. 
+                - Tensor or List[Tensor]: a tensor or list of tensors.
 
-    @staticmethod
-    def default_pipeline() -> DataPipeline:
-        """Pipeline to use when there is no datamodule or it has not defined its pipeline"""
-        return DetectionDataModule.default_pipeline()
+        Returns:
+            List[Tensor], The processed image samples.
+        """
+        p = next(self.parameters())  # for device and type
+        if isinstance(samples, Tensor):
+            return [samples.to(p.device).type_as(p)]
+
+        if contains_any_tensor(samples):
+            return [sample.to(p.device).type_as(p) for sample in samples]
+
+        if isinstance(samples, str):
+            samples = [samples]
+
+        if isinstance(samples, (list, tuple)) and all(isinstance(p, str) for p in samples):
+            outputs = []
+            for sample in samples:
+                output = image_loader(sample).to(p.device).type_as(p)
+                outputs.append(output)
+            return outputs
+
+        raise NotImplementedError(
+            f"The type of the sample is {type(samples)}, which is currently not supported. The "
+            "samples should be either a tensor, a list of tensors, an image path or a list of image paths."
+        )
 
     @staticmethod
     def add_model_specific_args(parent_parser):
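
A minimal usage sketch of the prediction path after this cleanup (the image path and
the score threshold below are placeholders rather than part of this patch; yolov5s is
the factory already exercised in test_engine.py, and the outputs follow the torchvision
detection convention of boxes / labels / scores):

    import torch

    from yolort.data import contains_any_tensor
    from yolort.models import yolov5s

    # contains_any_tensor() is True only when a Tensor is found somewhere in the
    # (possibly nested) value; collate_images() uses it to tell pre-loaded tensors
    # apart from image paths.
    assert not contains_any_tensor('bus.jpg')
    assert contains_any_tensor([torch.rand(3, 416, 352)])

    model = yolov5s(pretrained=True, score_thresh=0.45)
    model = model.eval()

    # A single image path: collate_images() reads it with default_loader(),
    # i.e. read_image(path) / 255., before running inference.
    predictions = model.predict('bus.jpg')

    # Pre-loaded tensors are moved to the model's device/dtype and passed through.
    images = [torch.rand(3, 416, 352), torch.rand(3, 480, 384)]
    predictions = model.predict(images)

    print(predictions[0]['boxes'], predictions[0]['scores'], predictions[0]['labels'])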