From ab44f27c833813226d2a095a8c101a609c91c9bb Mon Sep 17 00:00:00 2001 From: Ibrahim Salihu Yusuf Date: Mon, 16 Sep 2024 07:42:35 +0200 Subject: [PATCH] fix: run precommit --- .gitignore | 2 +- instageo/model/run.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 944dbe3..084509c 100644 --- a/.gitignore +++ b/.gitignore @@ -143,4 +143,4 @@ __pycache__/ *prithvi/ nul/ notebooks/ -!notebooks/*.ipynb \ No newline at end of file +!notebooks/*.ipynb diff --git a/instageo/model/run.py b/instageo/model/run.py index 839f5be..8d2d217 100644 --- a/instageo/model/run.py +++ b/instageo/model/run.py @@ -84,7 +84,8 @@ def get_device() -> str: logging.info("Neither GPU nor TPU is available. Using CPU...") return device -def custom_collate_fn(batch): + +def custom_collate_fn(batch: tuple[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]: """Test DataLoader Collate Function. This function is a convenient wrapper around the PyTorch DataLoader class, @@ -100,6 +101,7 @@ def custom_collate_fn(batch): labels = torch.cat([a[1] for a in batch], 0) return data, labels + def create_dataloader( dataset: Dataset, batch_size: int, @@ -537,9 +539,7 @@ def main(cfg: DictConfig) -> None: constant_multiplier=cfg.dataloader.constant_multiplier, ) test_loader = create_dataloader( - test_dataset, - batch_size=batch_size, - collate_fn=custom_collate_fn + test_dataset, batch_size=batch_size, collate_fn=custom_collate_fn ) model = PrithviSegmentationModule.load_from_checkpoint( checkpoint_path,