diff --git a/paint/preprocessing/focal_spot_extractor.py b/paint/preprocessing/focal_spot_extractor.py
new file mode 100644
index 00000000..0e7787e0
--- /dev/null
+++ b/paint/preprocessing/focal_spot_extractor.py
@@ -0,0 +1,291 @@
+import json
+from pathlib import Path
+from typing import Union
+
+import cv2
+import torch
+
+import paint.util.paint_mappings as mappings
+
+
+class FocalSpot:
+    """
+    Data class for storing focal spot information.
+
+    Attributes
+    ----------
+    flux : torch.Tensor
+        Intensity image of the detected focal spot.
+    aim_point_image : tuple[float, float]
+        Detected center of intensity in relative image coordinates (x, y).
+    aim_point : torch.Tensor
+        Detected center of intensity in global coordinates (E, N, U).
+    """
+
+    def __init__(self) -> None:
+        """
+        Initialize attributes with default placeholder values.
+
+        Objects are intended to be created using ``from_flux()`` or ``load()`` rather than directly.
+        """
+        self.flux: torch.Tensor = torch.empty(0)  # Default to an empty tensor.
+        self.aim_point_image: tuple[float, float] = (0.0, 0.0)  # Default to (0.0, 0.0).
+        self.aim_point: torch.Tensor = torch.empty(3)  # Placeholder tensor of the required shape.
+
+    @classmethod
+    def from_flux(
+        cls,
+        flux: torch.Tensor,
+        aim_point_image: tuple[float, float],
+        aim_point: torch.Tensor,
+    ) -> "FocalSpot":
+        """
+        Create a ``FocalSpot`` object from provided data.
+
+        Parameters
+        ----------
+        flux : torch.Tensor
+            Intensity image of the detected focal spot.
+        aim_point_image : tuple[float, float]
+            Detected center of intensity in relative image coordinates (x, y).
+        aim_point : torch.Tensor
+            Detected center of intensity in global coordinates (E, N, U).
+
+        Returns
+        -------
+        FocalSpot
+            Initialized ``FocalSpot`` object.
+        """
+        instance = cls()
+        instance.flux = flux
+        instance.aim_point_image = aim_point_image
+        instance.aim_point = aim_point
+        return instance
+
+    @classmethod
+    def load(cls, file_path: Union[str, Path]) -> "FocalSpot":
+        """
+        Load a ``FocalSpot`` object from disk.
+
+        Parameters
+        ----------
+        file_path : Union[str, Path]
+            Base path for the focal spot files (without extensions).
+
+        Returns
+        -------
+        FocalSpot
+            Loaded ``FocalSpot`` object.
+        """
+        base_path = Path(file_path)
+        flux_path = base_path.with_name(f"{base_path.name}_flux.png")
+        metadata_path = base_path.with_name(f"{base_path.name}_metadata.json")
+
+        if not flux_path.exists() or not metadata_path.exists():
+            raise FileNotFoundError(
+                f"Missing required files: {flux_path}, {metadata_path}"
+            )
+
+        # Load flux as a grayscale image and convert to a tensor.
+        flux_image = cv2.imread(str(flux_path), cv2.IMREAD_GRAYSCALE)
+        if flux_image is None:
+            raise ValueError("Flux PNG file is missing or corrupted.")
+        flux = torch.tensor(flux_image, dtype=torch.float32) / 255.0
+
+        # Load metadata.
+        with open(metadata_path, "r") as json_file:
+            metadata = json.load(json_file)
+
+        aim_point_image = tuple(metadata["aim_point_image"])
+        aim_point = torch.tensor(metadata["aim_point"])
+
+        # Create the FocalSpot instance.
+        return cls.from_flux(flux, aim_point_image, aim_point)
+
+    def save(self, save_path: Path) -> None:
+        """
+        Save the focal spot object to disk.
+
+        Parameters
+        ----------
+        save_path : Path
+            Base path (without extension) to save the focal spot data.
+        """
+        # Ensure all fields are initialized.
+        if self.flux is None or self.aim_point_image is None or self.aim_point is None:
+            raise ValueError("FocalSpot object is not fully initialized.")
+
+        # Normalize flux to the range [0, 255] for saving as PNG.
+        normalized_flux = (self.flux / self.flux.max() * 255).clamp(0, 255).byte()
+        flux_image = normalized_flux.numpy()
+
+        # Save flux as a single-channel grayscale PNG.
+        cv2.imwrite(f"{save_path}_flux.png", flux_image)
+
+        # Save metadata to JSON.
+        metadata = {
+            "aim_point_image": self.aim_point_image,
+            "aim_point": self.aim_point.tolist(),
+        }
+        with open(f"{save_path}_metadata.json", "w") as json_file:
+            json.dump(metadata, json_file, indent=4)
+ """ + # Ensure all fields are initialized + if self.flux is None or self.aim_point_image is None or self.aim_point is None: + raise ValueError("FocalSpot object is not fully initialized.") + + # Normalize flux to the range [0, 255] for saving as PNG + normalized_flux = (self.flux / self.flux.max() * 255).clamp(0, 255).byte() + flux_image = normalized_flux.numpy() + + # Save flux as a single-channel grayscale PNG + cv2.imwrite(f"{save_path}_flux.png", flux_image) + + # Save metadata to JSON + metadata = { + "aim_point_image": self.aim_point_image, + "aim_point": self.aim_point.tolist(), + } + with open(f"{save_path}_metadata.json", "w") as json_file: + json.dump(metadata, json_file, indent=4) + + +def get_marker_coordinates(target: Union[str, int]) -> tuple[torch.Tensor, ...]: + """ + Get the specific marker coordinates based on the target. + + Parameters + ---------- + target : Union[str, int] + Target ID used to determine which marker coordinates should be used. + + Returns + ------- + tuple[torch.Tensor, ...] + Coordinates for the markers from the considered target. + """ + # Select markers based on the target identifier. + # These markers are predefined in `paint.util.paint_mappings`. + if target in {1, 7, mappings.STJ_LOWER}: + markers = mappings.STJ_LOWER_ENU + elif target in {4, 5, 6, mappings.STJ_UPPER}: + markers = mappings.STJ_UPPER_ENU + elif target in {3, mappings.MFT}: + markers = mappings.MFT_ENU + else: + # Raise an error if the target is not recognized. + raise ValueError(f"Unsupported target value: {target}") + + # Convert marker coordinates to PyTorch tensors for further computation. + return tuple(torch.tensor(marker) for marker in markers) + + +def convert_xy_to_enu( + aim_point_image: tuple[float, float], target: Union[str, int] +) -> torch.Tensor: + """ + Convert (x, y) coordinates to an aimpoint in the target image based on marker positions. + + Parameters + ---------- + aim_point_image : tuple[float, float] + Relative aim point coordinates (x, y). + target : Union[str, int] + The target object containing marker information. + + Returns + ------- + torch.Tensor + Aimpoint coordinates in the target image. + """ + # Retrieve the marker positions for the given target. + ( + marker_left_top, + marker_left_bottom, + marker_right_top, + marker_right_bottom, + ) = get_marker_coordinates(target) + + # Compute the horizontal directional vector (dx). + # This vector spans the distance between the left and right markers. + dx = 0.5 * (marker_right_top - marker_left_top) + 0.5 * ( + marker_right_bottom - marker_left_bottom + ) + + # Compute the vertical directional vector (dy). + # This vector spans the distance between the top and bottom markers. + dy = 0.5 * (marker_left_bottom - marker_left_top) + 0.5 * ( + marker_right_bottom - marker_right_top + ) + + # Calculate the aimpoint in global coordinates. + # The aimpoint is a combination of the top-left marker position and a weighted + # sum of dx and dy based on the relative aim_point_image coordinates. + return marker_left_top + aim_point_image[0] * dx + aim_point_image[1] * dy + + +def compute_center_of_intensity( + flux: torch.Tensor, threshold: float = 0.0 +) -> tuple[float, float]: + """ + Calculate the center of intensity (x, y) coordinates in a 2D flux image. + + Parameters + ---------- + flux : torch.Tensor + A 2D tensor representing the flux/intensity image. + threshold : float, optional + Minimum intensity value for pixels to be included in the calculation (Default: `0.0`). 
+
+
+def compute_center_of_intensity(
+    flux: torch.Tensor, threshold: float = 0.0
+) -> tuple[float, float]:
+    """
+    Calculate the center of intensity (x, y) coordinates in a 2D flux image.
+
+    Parameters
+    ----------
+    flux : torch.Tensor
+        A 2D tensor representing the flux/intensity image.
+    threshold : float, optional
+        Minimum intensity value for pixels to be included in the calculation (Default: `0.0`).
+
+    Returns
+    -------
+    float
+        x coordinate of the center of intensity.
+    float
+        y coordinate of the center of intensity.
+    """
+    height, width = flux.shape
+
+    # Threshold the flux values. Any values below the threshold are set to zero.
+    flux_thresholded = torch.where(flux >= threshold, flux, torch.zeros_like(flux))
+    total_intensity = flux_thresholded.sum()
+
+    # If all values are below the threshold, default to the center of the image.
+    if total_intensity == 0:
+        return 0.5, 0.5
+
+    # Generate normalized x and y coordinates adjusted for pixel centers.
+    # The "+ 0.5" adjustment ensures coordinates are centered within each pixel.
+    x_coords = (torch.arange(width, dtype=torch.float32) + 0.5) / width
+    y_coords = (torch.arange(height, dtype=torch.float32) + 0.5) / height
+
+    # Compute the center of intensity using weighted sums of the coordinates.
+    x_center = (flux_thresholded.sum(dim=0) * x_coords).sum() / total_intensity
+    y_center = (flux_thresholded.sum(dim=1) * y_coords).sum() / total_intensity
+
+    return x_center.item(), y_center.item()
+
+
+def detect_focal_spot(
+    image: torch.Tensor,
+    target: Union[str, int],
+    utis_model: torch.nn.Module,
+) -> FocalSpot:
+    """
+    Detect a focal spot within an image using UTIS focal spot segmentation.
+
+    Parameters
+    ----------
+    image : torch.Tensor
+        Cropped image containing the focal spot to be detected, with the shape (height, width).
+    target : Union[str, int]
+        Target ID used to select the marker coordinates for the aim point calculation.
+    utis_model : torch.nn.Module
+        Loaded UTIS model used to compute the flux.
+
+    Returns
+    -------
+    FocalSpot
+        Detected focal spot with flux, image coordinates, and global aim point.
+    """
+    if image.dim() != 2:
+        raise ValueError(f"Expected a 2D tensor, got {image.dim()} dimensions.")
+
+    # Pass the input image through the UTIS model to generate a flux image.
+    # The input image is unsqueezed to add batch and channel dimensions.
+    # The output flux is extracted from the first channel of the result.
+    flux = utis_model(image.unsqueeze(0).unsqueeze(0))[0, 0].detach()
+
+    # Compute the center of intensity in image coordinates.
+    aim_point_image = compute_center_of_intensity(flux)
+
+    # Convert the center of intensity to global ENU coordinates using marker data.
+    aimpoint_global = convert_xy_to_enu(aim_point_image, target)
+
+    # Return a FocalSpot object containing the flux, image coordinates, and global aim point.
+    return FocalSpot.from_flux(
+        flux=flux,
+        aim_point_image=aim_point_image,
+        aim_point=aimpoint_global,
+    )
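For orientation, a minimal usage sketch of the new module (editorial note, not part of the diff): it assumes a UTIS TorchScript checkpoint is already available locally, and uses a random stand-in for a cropped grayscale image in [0, 1]; the file names are placeholders.

    from pathlib import Path

    import torch

    import paint.util.paint_mappings as mappings
    from paint.preprocessing.focal_spot_extractor import FocalSpot, detect_focal_spot

    # Placeholder checkpoint path and input image; a real run would use a cropped calibration image.
    utis_model = torch.jit.load("utis_model_scripted.pt", map_location="cpu")
    image = torch.rand(256, 256)

    focal_spot = detect_focal_spot(image, target=mappings.STJ_LOWER, utis_model=utis_model)
    focal_spot.save(Path("example_focal_spot"))  # writes example_focal_spot_flux.png and example_focal_spot_metadata.json
    reloaded = FocalSpot.load("example_focal_spot")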
diff --git a/paint/util/paint_mappings.py b/paint/util/paint_mappings.py
index d5025043..e7b8b8d6 100644
--- a/paint/util/paint_mappings.py
+++ b/paint/util/paint_mappings.py
@@ -84,6 +84,12 @@
     6.387886052532388,
     133.719,
 ]
+STJ_UPPER_ENU = [
+    [4.32, -3.23, 46.70],
+    [4.32, -3.23, 39.48],
+    [-4.30, -3.23, 46.70],
+    [-4.30, -3.23, 39.48],
+]
 STJ_LOWER_COORDINATES = [
     [50.913391839040266, 6.38788603808917, 119.268],
     [50.91339186595943, 6.387886052532388, 126.47],
@@ -99,6 +105,12 @@
     6.387886052532388,
     126.506,
 ]
+STJ_LOWER_ENU = [
+    [4.33, -3.23, 39.49],
+    [4.33, -3.23, 32.27],
+    [-4.29, -3.23, 39.49],
+    [-4.29, -3.23, 32.27],
+]
 MFT_COORDINATES = [
     [50.91339634341573, 6.387612841591359, 135.789],
     [50.91339628900999, 6.387612983329584, 142.175],
@@ -114,6 +126,12 @@
     6.387612983329584,
     142.175,
 ]
+MFT_ENU = [
+    [-14.88, -2.83, 55.17],
+    [-14.88, -2.83, 48.78],
+    [-20.29, -2.83, 55.17],
+    [-20.29, -2.83, 48.78],
+]
 RECEIVER_COORDINATES = [
     [50.913406544144294, 6.387853925842858, 139.86],
     [50.91342645401072, 6.387854205350705, 144.592],
@@ -242,6 +260,10 @@
 FOCAL_SPOT_KEY = "focal_spot"
 UTIS_KEY = "UTIS"
 UTIS_URL = "https://github.com/DLR-SF/UTIS-HeliostatBeamCharacterization"
+UTIS_MODEL_CHECKPOINT = (
+    "https://github.com/DLR-SF/UTIS-HeliostatBeamCharacterization/"
+    "raw/main/trained_models/utis_model_scripted.pt"
+)
 HELIOS_KEY = "HeliOS"
 TARGET_OFFSET_E = "TargetOffsetE"
 TARGET_OFFSET_N = "TargetOffsetN"
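To illustrate how these ENU marker lists are consumed (a sketch, not part of the diff): `convert_xy_to_enu` spans the target with the horizontal and vertical marker differences, so a relative aim point of (0.5, 0.5) lands roughly at the midpoint of the four markers.

    import torch

    import paint.util.paint_mappings as mappings
    from paint.preprocessing.focal_spot_extractor import convert_xy_to_enu

    # With the STJ_LOWER_ENU markers added above, the relative image centre maps to
    # roughly E = 0.02, N = -3.23, U = 35.88 (the midpoint of the four markers).
    center = convert_xy_to_enu((0.5, 0.5), mappings.STJ_LOWER)
    print(center)  # approximately tensor([ 0.0200, -3.2300, 35.8800])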
diff --git a/preprocessing-scripts/copy_processed_images_to_correct_location.py b/preprocessing-scripts/copy_processed_images_to_correct_location.py
old mode 100644
new mode 100755
diff --git a/preprocessing-scripts/crop_image_and_detect_aimpoint.py b/preprocessing-scripts/crop_image_and_detect_aimpoint.py
new file mode 100755
index 00000000..8e2080fa
--- /dev/null
+++ b/preprocessing-scripts/crop_image_and_detect_aimpoint.py
@@ -0,0 +1,169 @@
+#!/usr/bin/env python
+import argparse
+import json
+import tempfile
+from pathlib import Path
+
+import requests
+import torch
+
+import paint.util.paint_mappings as mappings
+from paint import PAINT_ROOT
+from paint.data.stac_client import StacClient
+from paint.preprocessing.focal_spot_extractor import detect_focal_spot
+from paint.preprocessing.target_cropper import crop_image_with_template_matching
+
+
+def load_model_from_url(url: str) -> torch.jit.ScriptModule:
+    """
+    Download the model checkpoint from the given URL and load it directly without storing it permanently.
+
+    Parameters
+    ----------
+    url : str
+        URL of the model checkpoint.
+
+    Returns
+    -------
+    torch.jit.ScriptModule
+        Loaded PyTorch model.
+
+    Raises
+    ------
+    Exception
+        If the model cannot be loaded or the download fails.
+    """
+    print(f"Downloading checkpoint from {url}...")
+    response = requests.get(url, stream=True)
+    if response.status_code == 200:
+        print("Checkpoint downloaded successfully. Loading the model...")
+        with tempfile.NamedTemporaryFile(delete=True) as temp_file:
+            for chunk in response.iter_content(chunk_size=8192):
+                temp_file.write(chunk)
+            temp_file.flush()
+            try:
+                model = torch.jit.load(temp_file.name, map_location="cpu")
+                print("Model loaded successfully.")
+                return model
+            except Exception as e:
+                print(f"Failed to load the model: {e}")
+                raise
+    else:
+        print(f"Failed to download the checkpoint. Status code: {response.status_code}")
+        response.raise_for_status()
+
+
+def main(args: argparse.Namespace) -> None:
+    """
+    Process calibration images and extract focal spots using UTIS.
+
+    This function downloads and processes heliostat calibration data before loading the UTIS model from a given URL.
+    It then applies UTIS to detect focal spots for the specified heliostat and measurement ID.
+
+    Parameters
+    ----------
+    args : argparse.Namespace
+        Parsed command-line arguments.
+    """
+    # Base directory for all outputs.
+    output_dir = Path(args.output_dir)
+
+    # Subdirectories for specific types of outputs.
+    downloaded_data_dir = output_dir / "downloaded_data"
+    extracted_focal_spots_dir = output_dir / "extracted_focal_spots"
+
+    # Create STAC client and download heliostat data.
+    client = StacClient(output_dir=downloaded_data_dir)
+
+    # Load the UTIS model directly from the URL.
+    loaded_model = load_model_from_url(mappings.UTIS_MODEL_CHECKPOINT)
+
+    # Load a single measurement.
+    client.get_single_calibration_item_by_id(
+        heliostat_id=args.heliostat,
+        item_id=args.measurement_id,
+        filtered_calibration_keys=args.filtered_calibration,
+    )
+
+    # Define paths for the current measurement.
+    image_path = (
+        downloaded_data_dir
+        / args.heliostat
+        / mappings.SAVE_CALIBRATION
+        / f"{args.measurement_id}_raw.png"
+    )
+    calibration_properties_path = (
+        downloaded_data_dir
+        / args.heliostat
+        / mappings.SAVE_CALIBRATION
+        / f"{args.measurement_id}-calibration-properties.json"
+    )
+
+    # Extract the target from the calibration properties.
+    with open(calibration_properties_path, "r") as file:
+        calibration_data = json.load(file)
+    target = calibration_data[mappings.TARGET_NAME_KEY]
+
+    # Crop the image using template matching.
+    cropped_image = crop_image_with_template_matching(image_path, target, n_grid=256)
+
+    # Convert to grayscale by using only the green color channel and convert to a torch tensor.
+    cropped_image_tensor = (
+        torch.tensor(cropped_image[:, :, 1], dtype=torch.float32) / 255.0
+    )
+
+    # Detect the focal spot.
+    focal_spot = detect_focal_spot(cropped_image_tensor, target, loaded_model)
+
+    # Log the detected aim point.
+    print(
+        f"Heliostat {args.heliostat}: For measurement {args.measurement_id}, "
+        f"focal spot detected at {focal_spot.aim_point.tolist()}."
+    )
+
+    # Save the focal spot.
+    focal_spot_path = (
+        extracted_focal_spots_dir / f"{args.heliostat}_{args.measurement_id}_focal_spot"
+    )
+    focal_spot_path.parent.mkdir(parents=True, exist_ok=True)
+    focal_spot.save(focal_spot_path)
+
+
+if __name__ == "__main__":
+    # Parse command-line arguments.
+    parser = argparse.ArgumentParser(
+        description="Process calibration data for heliostats."
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        help="Base path to save outputs. Downloaded data and extracted focal spots will be saved in separate subdirectories.",
+        default=f"{PAINT_ROOT}/outputs",
+    )
+    parser.add_argument(
+        "--heliostat",
+        type=str,
+        help="Heliostat to be downloaded.",
+        default="AA23",
+    )
+    parser.add_argument(
+        "--collections",
+        type=str,
+        help="List of collections to be downloaded.",
+        nargs="+",
+        default=["calibration"],
+    )
+    parser.add_argument(
+        "--filtered_calibration",
+        type=str,
+        help="List of calibration items to download.",
+        nargs="+",
+        default=["raw_image", "calibration_properties"],
+    )
+    parser.add_argument(
+        "--measurement_id",
+        type=int,
+        help="ID of the measurement to apply the algorithm to.",
+        default=123819,
+    )
+    arguments = parser.parse_args()
+    main(arguments)
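As a usage sketch (not part of the diff), the script can also be driven programmatically with the same flags it defines above; this is a hypothetical invocation that assumes network access for the STAC download and the UTIS checkpoint URL, using the example heliostat and measurement ID from the defaults.

    import runpy
    import sys

    # Equivalent to running the script from the command line with these arguments.
    sys.argv = [
        "crop_image_and_detect_aimpoint.py",
        "--heliostat", "AA23",
        "--measurement_id", "123819",
        "--output_dir", "outputs",
    ]
    runpy.run_path(
        "preprocessing-scripts/crop_image_and_detect_aimpoint.py",
        run_name="__main__",
    )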
diff --git a/tests/preprocessing/test_data/utis_model_scripted.pt b/tests/preprocessing/test_data/utis_model_scripted.pt
new file mode 100644
index 00000000..4b3db015
Binary files /dev/null and b/tests/preprocessing/test_data/utis_model_scripted.pt differ
diff --git a/tests/preprocessing/test_focal_spot_extractor.py b/tests/preprocessing/test_focal_spot_extractor.py
new file mode 100644
index 00000000..3b0dab66
--- /dev/null
+++ b/tests/preprocessing/test_focal_spot_extractor.py
@@ -0,0 +1,179 @@
+from pathlib import Path
+
+import cv2
+import pytest
+import torch
+
+from paint.preprocessing.focal_spot_extractor import (
+    FocalSpot,
+    compute_center_of_intensity,
+    detect_focal_spot,
+    get_marker_coordinates,
+)
+from paint.util.paint_mappings import MFT, STJ_LOWER, STJ_UPPER
+
+
+@pytest.fixture(scope="module")
+def sample_focal_spot() -> FocalSpot:
+    """
+    Create a sample ``FocalSpot`` object for testing purposes.
+
+    Returns
+    -------
+    FocalSpot
+        Sample ``FocalSpot`` object.
+    """
+    flux = torch.ones(256, 256)
+    aim_point_image = (0.5, 0.5)
+    aim_point = torch.tensor([1.0, 1.0, 1.0])
+    return FocalSpot.from_flux(flux, aim_point_image, aim_point)
+
+
+@pytest.fixture(scope="module")
+def temp_dir(tmp_path_factory: pytest.TempPathFactory) -> Path:
+    """
+    Create a temporary directory for testing file I/O.
+
+    Parameters
+    ----------
+    tmp_path_factory : pytest.TempPathFactory
+        Factory used to create the temporary directory.
+
+    Returns
+    -------
+    Path
+        Temporary directory for testing purposes.
+    """
+    return tmp_path_factory.mktemp("temp_dir")
+
+
+@pytest.fixture(scope="module")
+def utis_model() -> torch.jit.ScriptModule:
+    """
+    Load the UTIS model once for all test cases.
+
+    Returns
+    -------
+    torch.jit.ScriptModule
+        The loaded UTIS model.
+    """
+    return torch.jit.load(
+        "tests/preprocessing/test_data/utis_model_scripted.pt", map_location="cpu"
+    )
+
+
+def test_focal_spot_save_and_load(sample_focal_spot: FocalSpot, temp_dir: Path) -> None:
+    """
+    Test saving and loading the ``FocalSpot`` object.
+
+    Parameters
+    ----------
+    sample_focal_spot : FocalSpot
+        Sample ``FocalSpot`` object.
+    temp_dir : Path
+        Temporary directory for testing purposes.
+    """
+    save_path = temp_dir / "test_focal_spot"
+
+    # Save the focal spot.
+    sample_focal_spot.save(save_path)
+
+    # Check if the files exist.
+    assert (save_path.with_name(f"{save_path.name}_flux.png")).exists()
+    assert (save_path.with_name(f"{save_path.name}_metadata.json")).exists()
+
+    # Load the focal spot.
+    loaded_focal_spot = FocalSpot.load(save_path)
+
+    # Validate the loaded data.
+    torch.testing.assert_close(loaded_focal_spot.flux, sample_focal_spot.flux)
+    assert loaded_focal_spot.aim_point_image == sample_focal_spot.aim_point_image
+    torch.testing.assert_close(loaded_focal_spot.aim_point, sample_focal_spot.aim_point)
+
+
+@pytest.mark.parametrize(
+    "flux, threshold, expected_center",
+    [
+        (torch.zeros(10, 10), 0.0, (0.5, 0.5)),  # No intensity defaults to the center.
+    ],
+)
+def test_compute_center_of_intensity(
+    flux: torch.Tensor, threshold: float, expected_center: tuple[float, float]
+) -> None:
+    """
+    Test ``compute_center_of_intensity`` for various flux distributions.
+
+    Parameters
+    ----------
+    flux : torch.Tensor
+        A 2D tensor representing the flux/intensity image.
+    threshold : float
+        Minimum intensity value for pixels to be included in the calculation.
+    expected_center : tuple[float, float]
+        Expected center of intensity.
+    """
+    assert compute_center_of_intensity(flux, threshold) == pytest.approx(
+        expected_center, rel=1e-2
+    )
+
+
+def test_get_marker_coordinates_invalid_target() -> None:
+    """Test that ``get_marker_coordinates`` raises a ``ValueError`` for invalid targets."""
+    with pytest.raises(ValueError, match="Unsupported target value: INVALID_TARGET"):
+        get_marker_coordinates("INVALID_TARGET")
+
+
+@pytest.mark.parametrize(
+    "image_path, target, expected_aimpoint",
+    [
+        (
+            Path("tests/preprocessing/test_data/stj_lower.png"),
+            STJ_LOWER,
+            torch.tensor([0.9369, -3.2300, 36.4410]),
+        ),
+        (
+            Path("tests/preprocessing/test_data/stj_upper.png"),
+            STJ_UPPER,
+            torch.tensor([0.0996, -3.2300, 43.0595]),
+        ),
+        (
+            Path("tests/preprocessing/test_data/mft_crop.png"),
+            MFT,
+            torch.tensor([-17.5122, -2.8299, 52.0571]),
+        ),
+    ],
+)
+def test_detect_focal_spot(
+    image_path: Path,
+    target: str,
+    expected_aimpoint: torch.Tensor,
+    utis_model: torch.jit.ScriptModule,
+) -> None:
+    """
+    Test ``detect_focal_spot`` end-to-end using a loaded UTIS model and test images.
+
+    Parameters
+    ----------
+    image_path : Path
+        Path to the test image to load.
+    target : str
+        Target to consider.
+    expected_aimpoint : torch.Tensor
+        Expected aim point.
+    utis_model : torch.jit.ScriptModule
+        Loaded UTIS model.
+    """
+    # Load the image.
+    image = cv2.imread(str(image_path), cv2.IMREAD_GRAYSCALE)
+    assert image is not None, f"Image not found at {image_path}"
+    image_tensor = torch.tensor(image, dtype=torch.float32) / 255.0
+
+    # Perform focal spot detection.
+    focal_spot = detect_focal_spot(image_tensor, target, utis_model=utis_model)
+
+    # Validate the focal spot.
+    assert isinstance(focal_spot, FocalSpot)
+    assert focal_spot.flux.shape == image_tensor.shape
+    assert isinstance(focal_spot.aim_point_image, tuple)
+    assert isinstance(focal_spot.aim_point, torch.Tensor)
+
+    # Check that the detected aim point matches the expected aim point.
+    torch.testing.assert_close(
+        focal_spot.aim_point, expected_aimpoint, atol=1e-3, rtol=1e-3
+    )