Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Torch Dataset #62

Merged
merged 45 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
45b92ef
extend mappings for dataset
kalebphipps Dec 5, 2024
a7f1ed8
include initial dataset, only basic functions
kalebphipps Dec 5, 2024
6247b25
include logger
kalebphipps Dec 5, 2024
55cdc8b
include class method to download from benchmark file
kalebphipps Dec 5, 2024
e09b7dc
fix logging
kalebphipps Dec 6, 2024
16ce746
fix docstrings
kalebphipps Dec 6, 2024
15263a7
add new function for dataset - not yet finished
kalebphipps Dec 6, 2024
5b4a0a7
include boolean to save in one folder if the data is required for a d…
kalebphipps Dec 6, 2024
b8cc4ce
include from heliostat class method
kalebphipps Dec 6, 2024
98d3b30
update example
kalebphipps Dec 6, 2024
bcd4714
include docstrings
kalebphipps Dec 9, 2024
6ac0b87
example script for the entire workflow including downloading data
kalebphipps Dec 9, 2024
013da2b
fix error with mappings
kalebphipps Dec 9, 2024
32d3389
include torchvision as dependency
kalebphipps Dec 9, 2024
d562be5
improve loading from benchmark method to only consider data available…
kalebphipps Dec 9, 2024
720aa3f
include tests for the dataset
kalebphipps Dec 9, 2024
2133da6
ruff formatting
kalebphipps Dec 9, 2024
140f0d3
Merge branch 'main' into features/torch_dataset
kalebphipps Dec 9, 2024
0f60a77
add test data for dataset
kalebphipps Dec 9, 2024
768901c
rename unused variables
kalebphipps Dec 9, 2024
adb1db8
fix type hints
kalebphipps Dec 9, 2024
9503ecd
include test for custom string representation
kalebphipps Dec 9, 2024
ceb0d04
fix initial orientation values
kalebphipps Dec 9, 2024
515d847
fix formatting
kalebphipps Dec 9, 2024
007edc0
fix formatting
kalebphipps Dec 9, 2024
b215e6d
fix capitalization of index column
kalebphipps Dec 9, 2024
f46bfb9
rename test files
kalebphipps Dec 10, 2024
692549d
remove unused files
kalebphipps Dec 10, 2024
586bfd2
adjust test coverage
kalebphipps Dec 10, 2024
25b86e2
adjust test for new names
kalebphipps Dec 10, 2024
409031e
adjust test for new names and include test of string representation
kalebphipps Dec 10, 2024
0c6c594
adjust names of calibration item identifiers
kalebphipps Dec 10, 2024
0c3aa35
exclude download function from test coverage
kalebphipps Dec 10, 2024
1a7a736
adjust STAC for new file names
kalebphipps Dec 10, 2024
0be6d8a
improve memory efficiency and adjust file names
kalebphipps Dec 10, 2024
b2c41d1
improve argument description and include choices
kalebphipps Dec 10, 2024
7aa2b31
include choices
kalebphipps Dec 10, 2024
b892da3
include choices
kalebphipps Dec 10, 2024
e0a663e
include choices
kalebphipps Dec 10, 2024
b1e6f66
include choices
kalebphipps Dec 10, 2024
4f53c63
make executable
kalebphipps Dec 10, 2024
1b1a7c2
make executable
kalebphipps Dec 10, 2024
fa0781a
fix typos
kalebphipps Dec 10, 2024
afaab7d
update structure and include examples
kalebphipps Dec 10, 2024
50cae29
fix typo
kalebphipps Dec 10, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
437 changes: 437 additions & 0 deletions paint/data/dataset.py

Large diffs are not rendered by default.

44 changes: 35 additions & 9 deletions paint/data/stac_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def _process_single_heliostat_item(
heliostat_catalog_id: str,
save_folder: str,
pbar: tqdm,
for_dataset: bool = False,
) -> None:
"""
Process a single heliostat item.
Expand All @@ -197,6 +198,8 @@ def _process_single_heliostat_item(
Name of the folder to save the collection items in.
pbar : tqdm
Progress bar.
for_dataset : bool
Whether to save all items in one folder for a dataset (Default: ``False``).
"""
item_time = item.properties["datetime"]
if start_date and end_date:
Expand All @@ -218,12 +221,14 @@ def _process_single_heliostat_item(
for key, asset in item.assets.items():
if key in filtered_calibration_keys:
self._download_heliostat_asset(
asset, heliostat_catalog_id, save_folder
asset, heliostat_catalog_id, save_folder, for_dataset
)
else:
# Process all assets in the item.
for asset in item.assets.values():
self._download_heliostat_asset(asset, heliostat_catalog_id, save_folder)
self._download_heliostat_asset(
asset, heliostat_catalog_id, save_folder, for_dataset
)
pbar.update(1)

def _process_heliostat_items(
Expand All @@ -235,6 +240,7 @@ def _process_heliostat_items(
collection_id: str,
heliostat_catalog_id: str,
save_folder: str,
for_dataset: bool = False,
) -> None:
"""
Process and download items, either with or without date filtering and calibration key filtering.
Expand All @@ -259,6 +265,8 @@ def _process_heliostat_items(
ID of the considered heliostat catalog.
save_folder : str
Name of the folder to save the collection items in.
for_dataset : bool
Whether to save all items in one folder for a dataset (Default: ``False``).
"""
with tqdm(
total=len(items),
Expand All @@ -277,6 +285,7 @@ def _process_heliostat_items(
heliostat_catalog_id,
save_folder,
pbar,
for_dataset,
)
for item in items
]
Expand All @@ -288,7 +297,11 @@ def _process_heliostat_items(
log.error(f"Error processing heliostat item: {e}")

def _download_heliostat_asset(
self, asset: pystac.asset.Asset, heliostat_catalog_id: str, save_folder: str
self,
asset: pystac.asset.Asset,
heliostat_catalog_id: str,
save_folder: str,
for_dataset: bool = False,
) -> None:
"""
Download the asset from the provided URL and save it to the selected location.
Expand All @@ -301,15 +314,20 @@ def _download_heliostat_asset(
ID of the considered heliostat catalog.
save_folder : str
Name of the folder to save the asset in.
for_dataset : bool
Whether to save all items in one folder for a dataset (Default: ``False``).
"""
url = asset.href
file_end = url.split("/")[-1]
file_name = (
self.output_dir
/ heliostat_catalog_id.split("-")[0]
/ save_folder
/ file_end
)
if for_dataset:
file_name = self.output_dir / file_end
else:
file_name = (
self.output_dir
/ heliostat_catalog_id.split("-")[0]
/ save_folder
/ file_end
)
file_name.parent.mkdir(parents=True, exist_ok=True)
self.download_file(url, file_name)

Expand All @@ -321,6 +339,7 @@ def _download_heliostat_data(
start_date: Union[datetime, None] = None,
end_date: Union[datetime, None] = None,
filtered_calibration_keys: Union[list[str], None] = None,
for_dataset: bool = False,
) -> None:
"""
Download items from a specified collection from a heliostat and save them to a designated path.
Expand All @@ -343,6 +362,8 @@ def _download_heliostat_data(
List of keys to filter the calibration data. These keys must be one of: ``raw_image``, ``cropped_image``,
``flux_image``, ``flux_centered_image``, ``calibration_properties``. If no list is provided, all calibration
data is downloaded (Default: ``None``).
for_dataset : bool
Whether to save all items in one folder for a dataset (Default: ``False``).
"""
child = self.get_child(parent=heliostat_catalog, child_id=collection_id)
if child is None:
Expand All @@ -368,6 +389,7 @@ def _download_heliostat_data(
collection_id,
heliostat_catalog.id,
save_folder,
for_dataset,
)

@staticmethod
Expand Down Expand Up @@ -404,6 +426,7 @@ def get_heliostat_data(
start_date: Union[datetime, None] = None,
end_date: Union[datetime, None] = None,
filtered_calibration_keys: Union[list[str], None] = None,
for_dataset: bool = False,
) -> None:
"""
Download data for one or more heliostats.
Expand All @@ -426,6 +449,8 @@ def get_heliostat_data(
List of keys to filter the calibration data. These keys must be one of: ``raw_image``, ``cropped_image``,
``flux_image``, ``flux_centered_image``, ``calibration_properties``. If no list is provided, all calibration
data is downloaded (Default: ``None``).
for_dataset : bool
Whether to save all items in one folder for a dataset (Default: ``False``).
"""
# Check if keys provided to the filtered_calibration_key dictionary are acceptable.
if heliostats is None:
Expand Down Expand Up @@ -515,6 +540,7 @@ def get_heliostat_data(
start_date,
end_date,
filtered_calibration_keys,
for_dataset,
)

# Download deflectometry data.
Expand Down
4 changes: 2 additions & 2 deletions paint/preprocessing/binary_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,11 @@ def convert_to_h5(
_unused_facet_translation_vectors[f] = torch.tensor(
facet_header_data[1:4], dtype=torch.float
)
_unused_canting_n[f] = torch.tensor(
_unused_canting_e[f] = torch.tensor(
facet_header_data[4:7],
dtype=torch.float,
)
_unused_canting_e[f] = torch.tensor(
_unused_canting_n[f] = torch.tensor(
facet_header_data[7:10],
dtype=torch.float,
)
Expand Down
4 changes: 3 additions & 1 deletion paint/preprocessing/focal_spot_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ def __init__(self) -> None:
"""
self.flux: torch.Tensor = torch.empty(0) # Default to an empty tensor
self.aim_point_image: tuple[float, float] = (0.0, 0.0) # Default to (0.0, 0.0)
self.aim_point: torch.Tensor = torch.empty(3) # Default to an empty tensor of required shape
self.aim_point: torch.Tensor = torch.empty(
3
kalebphipps marked this conversation as resolved.
Show resolved Hide resolved
) # Default to an empty tensor of required shape

@classmethod
def from_flux(
Expand Down
8 changes: 8 additions & 0 deletions paint/util/paint_mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,3 +452,11 @@
SPLIT_KEY = "Split"
DISTANCE_WINTER = "distance_winter"
DISTANCE_SUMMER = "distance_summer"


# File-name suffixes/identifiers used to locate calibration dataset items on disk.
# NOTE(review): these appear to correspond to the calibration keys ``raw_image``,
# ``cropped_image``, ``flux_image``, ``flux_centered_image`` and
# ``calibration_properties`` used by the STAC client — confirm against dataset.py.
CALIBRATION_RAW_IMAGE_IDENTIFIER = "_raw.png"
CALIBRATION_CROPPED_IMAGE_IDENTIFIER = "_cropped.png"
CALIBRATION_FLUX_IMAGE_IDENTIFIER = "_flux.png"
CALIBRATION_FLUX_CENTERED_IMAGE_IDENTIFIER = "_flux_centered.png"
CALIBRATION_PROPERTIES_IDENTIFIER = "-calibration-properties.json"
kalebphipps marked this conversation as resolved.
Show resolved Hide resolved
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ dependencies = [
"numpy",
"pandas",
"torch",
"torchvision",
"matplotlib",
"colorlog",
"wetterdienst==0.95.1",
Expand Down
118 changes: 118 additions & 0 deletions scripts/example_benchmark_dataset_full_workflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import argparse
import pathlib

from paint import PAINT_ROOT
from paint.data import StacClient
from paint.data.dataset import PaintCalibrationDataset
from paint.data.dataset_splits import DatasetSplitter
from paint.util import set_logger_config

set_logger_config()


def _parse_bool(value: str) -> bool:
    """
    Parse a command-line string into a boolean.

    ``argparse`` with ``type=bool`` treats any non-empty string — including
    ``"False"`` — as ``True``, so an explicit parser is required to make the
    flag controllable from the command line.

    Parameters
    ----------
    value : str
        The string to interpret as a boolean.

    Returns
    -------
    bool
        The parsed boolean value.

    Raises
    ------
    argparse.ArgumentTypeError
        If the string cannot be interpreted as a boolean.
    """
    lowered = value.strip().lower()
    if lowered in {"true", "t", "yes", "y", "1"}:
        return True
    if lowered in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"Cannot interpret '{value}' as a boolean.")


if __name__ == "__main__":
    """
    This script demonstrates the full workflow for loading a ``torch.Dataset`` based on a given benchmark, specifically
    it includes:
    - Downloading the necessary metadata to generate the dataset splits.
    - Generating the dataset splits.
    - Creating a ``torch.Dataset`` based on these splits, if necessary, downloading the appropriate data.
    """
    # Read in arguments.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--metadata_input",
        type=pathlib.Path,
        help="File containing the metadata required to generate the dataset splits.",
        default=f"{PAINT_ROOT}/metadata/calibration_metadata_all_heliostats.csv",
    )
    parser.add_argument(
        "--output_dir",
        type=pathlib.Path,
        help="Root directory to save outputs.",
        default=f"{PAINT_ROOT}/benchmarks",
    )
    parser.add_argument(
        "--split_type",
        type=str,
        help="The split type to apply.",
        default="azimuth",
    )
    parser.add_argument(
        "--train_size",
        type=int,
        help="The size of the training split.",
        default=10,
    )
    parser.add_argument(
        "--val_size",
        type=int,
        help="The size of the validation split.",
        default=30,
    )
    parser.add_argument(
        "--remove_unused_data",
        # ``type=bool`` would silently parse the string "False" as True;
        # use an explicit parser so the flag can actually be disabled.
        type=_parse_bool,
        help="Whether to remove unused data or not.",
        default=True,
    )
    parser.add_argument(
        "--item_type",
        type=str,
        help="The type of item to be loaded -- i.e. raw image, cropped image, flux image, or flux centered image",
        default="calibration_properties",
    )
    args = parser.parse_args()

    metadata_file = args.metadata_input

    # Check if the metadata file has already been downloaded, if not download it.
    if not metadata_file.exists():
        # Create STAC client to download the metadata.
        output_dir_for_stac = metadata_file.parent
        client = StacClient(output_dir=output_dir_for_stac)
        client.get_heliostat_metadata(heliostats=None)

    # Set the correct folder to save the benchmark splits.
    splits_output_dir = args.output_dir / "splits"
    splitter = DatasetSplitter(
        input_file=args.metadata_input,
        output_dir=splits_output_dir,
        remove_unused_data=args.remove_unused_data,
    )

    # Generate the splits, they will be saved automatically to the defined location.
    _ = splitter.get_dataset_splits(
        split_type=args.split_type,
        training_size=args.train_size,
        validation_size=args.val_size,
    )

    # Determine name to automatically load the splits into the dataset.
    dataset_benchmark_file = (
        splits_output_dir
        / f"benchmark_split-{args.split_type}_train-{args.train_size}_validation-{args.val_size}.csv"
    )

    # Set the correct folder for the dataset.
    dataset_output_dir = (
        args.output_dir
        / "datasets"
        / f"benchmark_split-{args.split_type}_train-{args.train_size}_validation-{args.val_size}"
        / args.item_type
    )

    # The first time this script is executed locally, the data must be downloaded;
    # afterwards the existing folder is reused.
    dataset_download = not dataset_output_dir.exists()

    # Initialize dataset from benchmark splits.
    train, test, val = PaintCalibrationDataset.from_benchmark(
        benchmark_file=dataset_benchmark_file,
        root_dir=dataset_output_dir,
        item_type=args.item_type,
        download=dataset_download,
    )
61 changes: 61 additions & 0 deletions scripts/example_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import argparse

from paint import PAINT_ROOT
from paint.data.dataset import PaintCalibrationDataset
from paint.util import set_logger_config

set_logger_config()


if __name__ == "__main__":
    """
    Example script demonstrating the three ways a ``PaintCalibrationDataset`` can be created:

    1.) Directly, when the data has already been downloaded and a ``root_dir`` containing it is available.
    2.) From a benchmark file that maps various calibration IDs to the split they belong to. This requires
        both a ``benchmark_file`` and a ``root_dir`` in which the data is saved.
    3.) From a single heliostat or a list of heliostats. In this case, all calibration items for the
        provided heliostats are used to create the dataset. This requires a list of heliostats.
    """
    # Read in arguments.
    argument_parser = argparse.ArgumentParser()
    argument_parser.add_argument(
        "--item_type",
        type=str,
        help="The type of item to be loaded -- i.e. raw image, cropped image, flux image, or flux centered image",
        default="calibration_properties",
    )
    parsed_args = argument_parser.parse_args()

    # 1.) Direct creation: point the root dir at a directory that already contains the data.
    existing_data_dir = f"{PAINT_ROOT}/benchmark_datasets/Benchmark_testy_test/test"
    direct_dataset = PaintCalibrationDataset(
        root_dir=existing_data_dir,
        item_ids=None,
        item_type=parsed_args.item_type,
    )

    # 2.) Creation from a benchmark file: provide the path of the benchmark file and the
    # location the downloaded data should be saved to.
    split_definition_file = f"{PAINT_ROOT}/benchmark_test/benchmark_testy_test.csv"
    benchmark_data_dir = f"{PAINT_ROOT}/dataset_benchmark_test"
    train_split, test_split, val_split = PaintCalibrationDataset.from_benchmark(
        benchmark_file=split_definition_file,
        root_dir=benchmark_data_dir,
        item_type=parsed_args.item_type,
        download=True,
    )

    # 3.) Creation from heliostats: provide the heliostat names (``None`` downloads data for
    # all heliostats) and the location the data should be saved to.
    selected_heliostats = ["AA39", "AA23"]
    heliostat_data_dir = f"{PAINT_ROOT}/dataset_heliostat_test"
    heliostat_dataset = PaintCalibrationDataset.from_heliostats(
        heliostats=selected_heliostats,
        root_dir=heliostat_data_dir,
        item_type=parsed_args.item_type,
        download=True,
    )
Loading
Loading