Skip to content

Commit

Permalink
Allow blank feature extractor
Browse files Browse the repository at this point in the history
  • Loading branch information
KristinaUlicna committed Sep 29, 2023
1 parent d834728 commit d2abf14
Showing 1 changed file with 16 additions and 14 deletions.
30 changes: 16 additions & 14 deletions grace/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,24 +60,27 @@ def run_grace(config_file: Union[str, os.PathLike]) -> None:
img_patch_augs = get_transforms(config, "patch")
img_graph_augs = get_transforms(config, "graph")

# Prepare the feature extractor:
if config.extractor_fn is not None:
# Feature extractor:
extractor_model = torch.load(config.extractor_fn)
feature_extractor = FeatureExtractor(
model=extractor_model,
augmentations=img_patch_augs,
normalize=config.normalize,
bbox_size=config.patch_size,
keep_patch_fraction=config.keep_patch_fraction,
)
else:
feature_extractor = lambda x, g: (x, g)
def return_unchanged(image, graph):
return image, graph

# Condition the augmentations to train mode only:
def transform(
image: torch.Tensor, graph: dict, *, in_train_mode: bool = True
) -> Callable:
# Prepare the feature extractor:
if config.extractor_fn is not None:
# Feature extractor:
extractor_model = torch.load(config.extractor_fn)
feature_extractor = FeatureExtractor(
model=extractor_model,
augmentations=img_patch_augs,
normalize=config.normalize,
bbox_size=config.patch_size,
keep_patch_fraction=config.keep_patch_fraction,
)
else:
feature_extractor = return_unchanged

# Ensure augmentations are only run on train data:
if in_train_mode is True:
image, graph = img_graph_augs(image, graph)
Expand All @@ -93,7 +96,6 @@ def prepare_dataset(
verbose: bool = True,
) -> tuple[list]:
# Read the data & terate through images & extract node features:
print(transform_method)
input_data = ImageGraphDataset(
image_dir=image_dir,
grace_dir=grace_dir,
Expand Down

0 comments on commit d2abf14

Please sign in to comment.