import argparse
import pathlib

import paint.util.paint_mappings as mappings
from paint import PAINT_ROOT
from paint.data import StacClient
from paint.data.dataset import PaintCalibrationDataset
from paint.data.dataset_splits import DatasetSplitter
from paint.util import set_logger_config

set_logger_config()

if __name__ == "__main__":
    """
    This script demonstrates the full workflow for loading a ``torch.Dataset`` based on a given benchmark.
    Specifically, it includes:

    - Downloading the metadata required to generate the dataset splits.
    - Generating the dataset splits.
    - Creating a ``torch.Dataset`` based on these splits, downloading the appropriate data if necessary.
    """
    # Read in arguments.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--metadata_input",
        type=pathlib.Path,
        help="File containing the metadata required to generate the dataset splits.",
        default=f"{PAINT_ROOT}/metadata/calibration_metadata_all_heliostats.csv",
    )
    parser.add_argument(
        "--output_dir",
        type=pathlib.Path,
        help="Root directory to save outputs.",
        default=f"{PAINT_ROOT}/benchmarks",
    )
    parser.add_argument(
        "--split_type",
        type=str,
        help="The split type to apply.",
        choices=[mappings.AZIMUTH_SPLIT, mappings.SOLSTICE_SPLIT],
        default=mappings.AZIMUTH_SPLIT,
    )
    parser.add_argument(
        "--train_size",
        type=int,
        help="The number of training samples required per heliostat - the total training size depends on the "
        "number of heliostats.",
        default=10,
    )
    parser.add_argument(
        "--val_size",
        type=int,
        help="The number of validation samples per heliostat - the total validation size depends on the "
        "number of heliostats.",
        default=30,
    )
    parser.add_argument(
        "--remove_unused_data",
        # ``type=bool`` would treat any non-empty string (including "False") as
        # True, so use a proper boolean flag instead (requires Python 3.9+).
        action=argparse.BooleanOptionalAction,
        help="Whether to remove metadata that is not required to load benchmark splits, but may be useful for plots "
        "or data inspection.",
        default=True,
    )
    parser.add_argument(
        "--item_type",
        type=str,
        help="The type of item to be loaded, i.e. raw image, cropped image, flux image, flux centered image, or "
        "calibration properties.",
        choices=[
            mappings.CALIBRATION_RAW_IMAGE_KEY,
            mappings.CALIBRATION_FLUX_IMAGE_KEY,
            mappings.CALIBRATION_FLUX_CENTERED_IMAGE_KEY,
            mappings.CALIBRATION_PROPERTIES_KEY,
            mappings.CALIBRATION_CROPPED_IMAGE_KEY,
        ],
        # Use the mappings constant rather than a hard-coded string, for consistency with ``choices``.
        default=mappings.CALIBRATION_PROPERTIES_KEY,
    )
    args = parser.parse_args()
    metadata_file = args.metadata_input
    # Check whether the metadata file has already been downloaded; if not, download it.
    if not metadata_file.exists():
        # Create a STAC client to download the metadata.
        output_dir_for_stac = metadata_file.parent
        client = StacClient(output_dir=output_dir_for_stac)
        client.get_heliostat_metadata(heliostats=None)

    # Set the correct folder to save the benchmark splits.
    splits_output_dir = args.output_dir / "splits"
    splitter = DatasetSplitter(
        input_file=args.metadata_input,
        output_dir=splits_output_dir,
        remove_unused_data=args.remove_unused_data,
    )
    # Generate the splits; they are saved automatically to the defined location.
    _ = splitter.get_dataset_splits(
        split_type=args.split_type,
        training_size=args.train_size,
        validation_size=args.val_size,
    )

    # Determine the file name used to automatically load the splits into the dataset.
    dataset_benchmark_file = (
        splits_output_dir
        / f"benchmark_split-{args.split_type}_train-{args.train_size}_validation-{args.val_size}.csv"
    )
    # Set the correct folder for the dataset.
    dataset_output_dir = (
        args.output_dir
        / "datasets"
        / f"benchmark_split-{args.split_type}_train-{args.train_size}_validation-{args.val_size}"
        / args.item_type
    )
    # Determine whether to download the data:
    # the first time this script is executed locally, the data must be downloaded; afterwards it can be reused.
    dataset_download = not dataset_output_dir.exists()
    # Initialize the dataset from the benchmark splits.
    train, test, val = PaintCalibrationDataset.from_benchmark(
        benchmark_file=dataset_benchmark_file,
        root_dir=dataset_output_dir,
        item_type=args.item_type,
        download=dataset_download,
    )
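
    # Illustrative usage sketch (added for demonstration; not part of the
    # original workflow). Assuming the returned objects follow the standard
    # ``torch.utils.data.Dataset`` protocol, i.e. they support ``len()`` and
    # integer indexing, we can inspect the split sizes and peek at one sample.
    print(f"Train / test / validation sizes: {len(train)} / {len(test)} / {len(val)}")
    print(f"First training item ({args.item_type}): {train[0]}")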