Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a standalone detect checkpoint function #158

Merged
merged 2 commits into from
Feb 12, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
249 changes: 153 additions & 96 deletions thor/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,42 +120,86 @@ def create_checkpoint_data(stage: VALID_STAGES, **data) -> CheckpointData:
raise ValueError(f"Invalid stage: {stage}")


def detect_checkpoint_stage(test_orbit_directory: pathlib.Path) -> VALID_STAGES:
    """
    Looks for existing files and indicates the next stage to run.

    Checkpoint files are probed from the latest pipeline stage back to the
    earliest, so the most advanced stage whose outputs are present on disk
    decides where the pipeline resumes.

    Parameters
    ----------
    test_orbit_directory : pathlib.Path
        Working directory that may contain checkpoint parquet files.

    Returns
    -------
    VALID_STAGES
        Name of the next stage to run, or "complete" when both recovered
        orbit files are present.

    Raises
    ------
    ValueError
        If `test_orbit_directory` exists but is not a directory.
    """

    def has(*names: str) -> bool:
        # A stage only counts as finished when all of its output files exist.
        return all((test_orbit_directory / name).exists() for name in names)

    # The existence check has to come first: Path.is_dir() is False for a
    # missing path, so testing is_dir() first would raise for a directory
    # that simply has not been created yet instead of starting from scratch.
    if not test_orbit_directory.exists():
        logger.info("Working directory does not exist, starting at beginning.")
        return "filter_observations"

    if not test_orbit_directory.is_dir():
        raise ValueError(f"{test_orbit_directory} is not a directory")

    # Filtered observations are the first checkpoint written; without them
    # nothing downstream is usable, whatever else is in the directory.
    if not has("filtered_observations.parquet"):
        logger.info(
            "No filtered observations found, starting stage filter_observations"
        )
        return "filter_observations"

    if has("recovered_orbits.parquet", "recovered_orbit_members.parquet"):
        logger.info("Found recovered orbits, pipeline is complete.")
        return "complete"

    if has("od_orbits.parquet", "od_orbit_members.parquet"):
        logger.info("Found OD orbits, starting stage recover_orbits")
        return "recover_orbits"

    if has("iod_orbits.parquet", "iod_orbit_members.parquet"):
        logger.info("Found IOD orbits, starting stage differential_correction")
        return "differential_correction"

    if has("clusters.parquet", "cluster_members.parquet"):
        logger.info("Found clusters, starting stage initial_orbit_determination")
        return "initial_orbit_determination"

    if has("transformed_detections.parquet"):
        logger.info("Found transformed detections, starting stage cluster_and_link")
        return "cluster_and_link"

    # Filtered observations were confirmed above and no later checkpoint was
    # found, so the only remaining possibility is the stage right after
    # filtering.
    logger.info("Found filtered observations, starting stage range_and_transform")
    return "range_and_transform"


def load_initial_checkpoint_values(
test_orbit_directory: Optional[pathlib.Path] = None,
test_orbit_directory: Optional[Union[pathlib.Path, str]] = None,
) -> CheckpointData:
"""
Check for completed stages and return values from disk if they exist.

We want to avoid loading objects into memory that are not required.
"""
stage: VALID_STAGES = "filter_observations"
# Without a checkpoint directory, we always start at the beginning
if isinstance(test_orbit_directory, str):
test_orbit_directory = pathlib.Path(test_orbit_directory)

if test_orbit_directory is None:
return create_checkpoint_data(stage)
logger.info("Not using a workign directory, start at beginning.")
return create_checkpoint_data("filter_observations")

# filtered_observations is always needed when it exists
filtered_observations_path = pathlib.Path(
test_orbit_directory, "filtered_observations.parquet"
)
# If it doesn't exist, start at the beginning.
if not filtered_observations_path.exists():
return create_checkpoint_data(stage)
logger.info("Found filtered observations")
filtered_observations = Observations.from_parquet(filtered_observations_path)
stage: VALID_STAGES = detect_checkpoint_stage(test_orbit_directory)

if filtered_observations.fragmented():
filtered_observations = qv.defragment(filtered_observations)
if stage == "filter_observations":
return create_checkpoint_data(stage)

# If the pipeline was started but we have recovered_orbits already, we
# are done and should exit early.
recovered_orbits_path = pathlib.Path(
test_orbit_directory, "recovered_orbits.parquet"
)
recovered_orbit_members_path = pathlib.Path(
test_orbit_directory, "recovered_orbit_members.parquet"
)
if recovered_orbits_path.exists() and recovered_orbit_members_path.exists():
logger.info("Found recovered orbits in checkpoint")
# If we've already completed the pipeline, we can load the recovered orbits
if stage == "complete":
recovered_orbits_path = pathlib.Path(
test_orbit_directory, "recovered_orbits.parquet"
)
recovered_orbit_members_path = pathlib.Path(
test_orbit_directory, "recovered_orbit_members.parquet"
)
recovered_orbits = FittedOrbits.from_parquet(recovered_orbits_path)
recovered_orbit_members = FittedOrbitMembers.from_parquet(
recovered_orbit_members_path
Expand All @@ -167,88 +211,101 @@ def load_initial_checkpoint_values(
recovered_orbit_members = qv.defragment(recovered_orbit_members)

return create_checkpoint_data(
"complete",
stage,
recovered_orbits=recovered_orbits,
recovered_orbit_members=recovered_orbit_members,
)

# Now with filtered_observations available, we can check for the later
# stages in reverse order.
od_orbits_path = pathlib.Path(test_orbit_directory, "od_orbits.parquet")
od_orbit_members_path = pathlib.Path(
test_orbit_directory, "od_orbit_members.parquet"
)
if od_orbits_path.exists() and od_orbit_members_path.exists():
logger.info("Found OD orbits in checkpoint")
od_orbits = FittedOrbits.from_parquet(od_orbits_path)
od_orbit_members = FittedOrbitMembers.from_parquet(od_orbit_members_path)

if od_orbits.fragmented():
od_orbits = qv.defragment(od_orbits)
if od_orbit_members.fragmented():
od_orbit_members = qv.defragment(od_orbit_members)

return create_checkpoint_data(
"recover_orbits",
filtered_observations=filtered_observations,
od_orbits=od_orbits,
od_orbit_members=od_orbit_members,
)

iod_orbits_path = pathlib.Path(test_orbit_directory, "iod_orbits.parquet")
iod_orbit_members_path = pathlib.Path(
test_orbit_directory, "iod_orbit_members.parquet"
# Filtered observations are required for all other stages
filtered_observations_path = pathlib.Path(
test_orbit_directory, "filtered_observations.parquet"
)
if iod_orbits_path.exists() and iod_orbit_members_path.exists():
logger.info("Found IOD orbits")
iod_orbits = FittedOrbits.from_parquet(iod_orbits_path)
iod_orbit_members = FittedOrbitMembers.from_parquet(iod_orbit_members_path)

if iod_orbits.fragmented():
iod_orbits = qv.defragment(iod_orbits)
if iod_orbit_members.fragmented():
iod_orbit_members = qv.defragment(iod_orbit_members)
filtered_observations = Observations.from_parquet(filtered_observations_path)
if filtered_observations.fragmented():
filtered_observations = qv.defragment(filtered_observations)

return create_checkpoint_data(
"differential_correction",
filtered_observations=filtered_observations,
iod_orbits=iod_orbits,
iod_orbit_members=iod_orbit_members,
if stage == "recover_orbits":
od_orbits_path = pathlib.Path(test_orbit_directory, "od_orbits.parquet")
od_orbit_members_path = pathlib.Path(
test_orbit_directory, "od_orbit_members.parquet"
)

clusters_path = pathlib.Path(test_orbit_directory, "clusters.parquet")
cluster_members_path = pathlib.Path(test_orbit_directory, "cluster_members.parquet")
if clusters_path.exists() and cluster_members_path.exists():
logger.info("Found clusters")
clusters = Clusters.from_parquet(clusters_path)
cluster_members = ClusterMembers.from_parquet(cluster_members_path)

if clusters.fragmented():
clusters = qv.defragment(clusters)
if cluster_members.fragmented():
cluster_members = qv.defragment(cluster_members)

return create_checkpoint_data(
"initial_orbit_determination",
filtered_observations=filtered_observations,
clusters=clusters,
cluster_members=cluster_members,
if od_orbits_path.exists() and od_orbit_members_path.exists():
logger.info("Found OD orbits in checkpoint")
od_orbits = FittedOrbits.from_parquet(od_orbits_path)
od_orbit_members = FittedOrbitMembers.from_parquet(od_orbit_members_path)

if od_orbits.fragmented():
od_orbits = qv.defragment(od_orbits)
if od_orbit_members.fragmented():
od_orbit_members = qv.defragment(od_orbit_members)

return create_checkpoint_data(
"recover_orbits",
filtered_observations=filtered_observations,
od_orbits=od_orbits,
od_orbit_members=od_orbit_members,
)

if stage == "differential_correction":
iod_orbits_path = pathlib.Path(test_orbit_directory, "iod_orbits.parquet")
iod_orbit_members_path = pathlib.Path(
test_orbit_directory, "iod_orbit_members.parquet"
)

transformed_detections_path = pathlib.Path(
test_orbit_directory, "transformed_detections.parquet"
)
if transformed_detections_path.exists():
logger.info("Found transformed detections")
transformed_detections = TransformedDetections.from_parquet(
transformed_detections_path
if iod_orbits_path.exists() and iod_orbit_members_path.exists():
logger.info("Found IOD orbits")
iod_orbits = FittedOrbits.from_parquet(iod_orbits_path)
iod_orbit_members = FittedOrbitMembers.from_parquet(iod_orbit_members_path)

if iod_orbits.fragmented():
iod_orbits = qv.defragment(iod_orbits)
if iod_orbit_members.fragmented():
iod_orbit_members = qv.defragment(iod_orbit_members)

return create_checkpoint_data(
"differential_correction",
filtered_observations=filtered_observations,
iod_orbits=iod_orbits,
iod_orbit_members=iod_orbit_members,
)

if stage == "initial_orbit_determination":
clusters_path = pathlib.Path(test_orbit_directory, "clusters.parquet")
cluster_members_path = pathlib.Path(
test_orbit_directory, "cluster_members.parquet"
)

return create_checkpoint_data(
"cluster_and_link",
filtered_observations=filtered_observations,
transformed_detections=transformed_detections,
if clusters_path.exists() and cluster_members_path.exists():
logger.info("Found clusters")
clusters = Clusters.from_parquet(clusters_path)
cluster_members = ClusterMembers.from_parquet(cluster_members_path)

if clusters.fragmented():
clusters = qv.defragment(clusters)
if cluster_members.fragmented():
cluster_members = qv.defragment(cluster_members)

return create_checkpoint_data(
"initial_orbit_determination",
filtered_observations=filtered_observations,
clusters=clusters,
cluster_members=cluster_members,
)

if stage == "cluster_and_link":
transformed_detections_path = pathlib.Path(
test_orbit_directory, "transformed_detections.parquet"
)
if transformed_detections_path.exists():
logger.info("Found transformed detections")
transformed_detections = TransformedDetections.from_parquet(
transformed_detections_path
)

return create_checkpoint_data(
"cluster_and_link",
filtered_observations=filtered_observations,
transformed_detections=transformed_detections,
)

return create_checkpoint_data(
"range_and_transform", filtered_observations=filtered_observations
Expand Down
Loading