diff --git a/.gitignore b/.gitignore index 944dbe3..084509c 100644 --- a/.gitignore +++ b/.gitignore @@ -143,4 +143,4 @@ __pycache__/ *prithvi/ nul/ notebooks/ -!notebooks/*.ipynb \ No newline at end of file +!notebooks/*.ipynb diff --git a/instageo/model/run.py b/instageo/model/run.py index 839f5be..8d2d217 100644 --- a/instageo/model/run.py +++ b/instageo/model/run.py @@ -84,7 +84,8 @@ def get_device() -> str: logging.info("Neither GPU nor TPU is available. Using CPU...") return device -def custom_collate_fn(batch): + +def custom_collate_fn(batch: tuple[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]: """Test DataLoader Collate Function. This function is a convenient wrapper around the PyTorch DataLoader class, @@ -100,6 +101,7 @@ def custom_collate_fn(batch): labels = torch.cat([a[1] for a in batch], 0) return data, labels + def create_dataloader( dataset: Dataset, batch_size: int, @@ -537,9 +539,7 @@ def main(cfg: DictConfig) -> None: constant_multiplier=cfg.dataloader.constant_multiplier, ) test_loader = create_dataloader( - test_dataset, - batch_size=batch_size, - collate_fn=custom_collate_fn + test_dataset, batch_size=batch_size, collate_fn=custom_collate_fn ) model = PrithviSegmentationModule.load_from_checkpoint( checkpoint_path,