Skip to content

Commit

Permalink
fix: run precommit
Browse files Browse the repository at this point in the history
  • Loading branch information
Alikerin committed Sep 16, 2024
1 parent c548b7f commit ab44f27
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -143,4 +143,4 @@ __pycache__/
*prithvi/
nul/
notebooks/
!notebooks/*.ipynb
!notebooks/*.ipynb
8 changes: 4 additions & 4 deletions instageo/model/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ def get_device() -> str:
logging.info("Neither GPU nor TPU is available. Using CPU...")
return device

def custom_collate_fn(batch):

def custom_collate_fn(batch: tuple[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
"""Test DataLoader Collate Function.
This function is a convenient wrapper around the PyTorch DataLoader class,
Expand All @@ -100,6 +101,7 @@ def custom_collate_fn(batch):
labels = torch.cat([a[1] for a in batch], 0)
return data, labels


def create_dataloader(
dataset: Dataset,
batch_size: int,
Expand Down Expand Up @@ -537,9 +539,7 @@ def main(cfg: DictConfig) -> None:
constant_multiplier=cfg.dataloader.constant_multiplier,
)
test_loader = create_dataloader(
test_dataset,
batch_size=batch_size,
collate_fn=custom_collate_fn
test_dataset, batch_size=batch_size, collate_fn=custom_collate_fn
)
model = PrithviSegmentationModule.load_from_checkpoint(
checkpoint_path,
Expand Down

0 comments on commit ab44f27

Please sign in to comment.