Skip to content

Commit

Permalink
Allow blank feature extractor
Browse files Browse the repository at this point in the history
  • Loading branch information
KristinaUlicna committed Oct 3, 2023
1 parent 44ef83a commit 99c1042
Showing 1 changed file with 19 additions and 10 deletions.
29 changes: 19 additions & 10 deletions grace/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,22 +56,31 @@ def run_grace(config_file: Union[str, os.PathLike]) -> None:
subfolder_path = run_dir / subfolder
subfolder_path.mkdir(parents=True, exist_ok=True)

# Prepare the feature extractor:
extractor_model = torch.load(config.extractor_fn)
patch_augs = get_transforms(config, "patch")
# Augmentations, if any:
img_patch_augs = get_transforms(config, "patch")
img_graph_augs = get_transforms(config, "graph")
feature_extractor = FeatureExtractor(
model=extractor_model,
augmentations=patch_augs,
normalize=config.normalize,
bbox_size=config.patch_size,
keep_patch_fraction=config.keep_patch_fraction,
)

def return_unchanged(image, graph):
return image, graph

# Condition the augmentations to train mode only:
def transform(
image: torch.Tensor, graph: dict, *, in_train_mode: bool = True
) -> Callable:
# Prepare the feature extractor:
if config.extractor_fn is not None:
# Feature extractor:
extractor_model = torch.load(config.extractor_fn)
feature_extractor = FeatureExtractor(
model=extractor_model,
augmentations=img_patch_augs,
normalize=config.normalize,
bbox_size=config.patch_size,
keep_patch_fraction=config.keep_patch_fraction,
)
else:
feature_extractor = return_unchanged

# Ensure augmentations are only run on train data:
if in_train_mode:
image_aug, graph_aug = img_graph_augs(image, graph)
Expand Down

0 comments on commit 99c1042

Please sign in to comment.