Skip to content

Commit

Permalink
Initialize ray from adam_core
Browse files Browse the repository at this point in the history
  • Loading branch information
akoumjian committed Nov 22, 2023
1 parent f4d389b commit 7a797f4
Show file tree
Hide file tree
Showing 11 changed files with 48 additions and 93 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ venv/
ENV/
env.bak/
venv.bak/
.python-version

# Spyder project settings
.spyderproject
Expand Down
3 changes: 2 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ setup_requires =
wheel
setuptools_scm >= 6.0
install_requires =
adam-core @ git+https://github.com/B612-Asteroid-Institute/adam_core@main
adam-core @ git+https://github.com/B612-Asteroid-Institute/adam_core@75a9cb077870c557a628101ca05e8ed57659ee97#egg=adam_core
astropy >= 5.3.1
astroquery
difi
Expand All @@ -40,6 +40,7 @@ install_requires =
numpy
numba
pandas
psutil
pyarrow >= 14.0.0
pydantic < 2.0.0
pyyaml >= 5.1
Expand Down
17 changes: 3 additions & 14 deletions thor/clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import quivr as qv
import ray
from adam_core.propagator import _iterate_chunks
from adam_core.ray_cluster import initialize_use_ray

from .range_and_transform import TransformedDetections

Expand Down Expand Up @@ -507,7 +508,6 @@ def cluster_velocity(
if len(clusters) == 0:
return Clusters.empty(), ClusterMembers.empty()
else:

cluster_ids = []
cluster_num_obs = []
cluster_members_cluster_ids = []
Expand Down Expand Up @@ -633,9 +633,6 @@ def cluster_and_link(
Algorithm to use. Can be "dbscan" or "hotspot_2d".
num_jobs : int, optional
Number of jobs to launch.
parallel_backend : str, optional
Which parallelization backend to use {'ray', 'mp', 'cf'}.
Defaults to using Python's concurrent futures module ('cf').
Returns
-------
Expand Down Expand Up @@ -691,14 +688,8 @@ def cluster_and_link(
mjd0 = mjd[first][0]
dt = mjd - mjd0

if max_processes is None or max_processes > 1:

if not ray.is_initialized():
logger.info(
f"Ray is not initialized. Initializing with {max_processes}..."
)
ray.init(address="auto", num_cpus=max_processes)

use_ray = initialize_use_ray(num_cpus=max_processes)
if use_ray:
# Put all arrays (which can be large) in ray's
# local object store ahead of time
obs_ids_ref = ray.put(obs_ids)
Expand Down Expand Up @@ -742,7 +733,6 @@ def cluster_and_link(
)

else:

for vxi_chunk, vyi_chunk in zip(
_iterate_chunks(vxx, chunk_size), _iterate_chunks(vyy, chunk_size)
):
Expand Down Expand Up @@ -776,7 +766,6 @@ def cluster_and_link(
)

else:

clusters = Clusters.empty()
cluster_members = ClusterMembers.empty()

Expand Down
3 changes: 2 additions & 1 deletion thor/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

class Config(BaseModel):
max_processes: Optional[int] = None
propagator: Literal["PYOORB"] = "PYOORB"
ray_memory_bytes: int = 0
propagator: str = "PYOORB"
cell_radius: float = 10
vx_min: float = -0.1
vx_max: float = 0.1
Expand Down
19 changes: 4 additions & 15 deletions thor/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import quivr as qv
import ray
from adam_core.propagator import PYOORB
from adam_core.ray_cluster import initialize_use_ray

from .checkpointing import create_checkpoint_data, load_initial_checkpoint_values
from .clusters import cluster_and_link
Expand All @@ -25,20 +26,6 @@
logger = logging.getLogger("thor")


def initialize_use_ray(config: Config) -> bool:
use_ray = False
if config.max_processes is None or config.max_processes > 1:
# Initialize ray
if not ray.is_initialized():
logger.info(
f"Ray is not initialized. Initializing with {config.max_processes} cpus..."
)
ray.init(num_cpus=config.max_processes)

use_ray = True
return use_ray


def initialize_test_orbit(
test_orbit: TestOrbits,
working_dir: Optional[str] = None,
Expand Down Expand Up @@ -132,7 +119,9 @@ def link_test_orbit(
else:
raise ValueError(f"Unknown propagator: {config.propagator}")

use_ray = initialize_use_ray(config)
use_ray = initialize_use_ray(
num_cpus=config.max_processes, object_store_bytes=config.ray_memory_bytes
)

refs_to_free = []
if (
Expand Down
10 changes: 3 additions & 7 deletions thor/observations/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import quivr as qv
import ray
from adam_core.coordinates import SphericalCoordinates
from adam_core.ray_cluster import initialize_use_ray

from thor.config import Config
from thor.observations.observations import Observations
Expand Down Expand Up @@ -160,13 +161,8 @@ def apply(
ephemeris = test_orbit.generate_ephemeris_from_observations(observations)

filtered_observations_list = []
if max_processes is None or max_processes > 1:
if not ray.is_initialized():
logger.info(
f"Ray is not initialized. Initializing with {max_processes}..."
)
ray.init(num_cpus=max_processes)

use_ray = initialize_use_ray(num_cpus=max_processes)
if use_ray:
refs_to_free = []
if observations_ref is None:
observations_ref = ray.put(observations)
Expand Down
15 changes: 10 additions & 5 deletions thor/orbit.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,13 @@


class RangedPointSourceDetections(qv.Table):

id = qv.StringColumn()
exposure_id = qv.StringColumn()
coordinates = SphericalCoordinates.as_column()
state_id = qv.Int64Column()


class TestOrbitEphemeris(qv.Table):

id = qv.Int64Column()
ephemeris = Ephemeris.as_column()
observer = Observers.as_column()
Expand Down Expand Up @@ -97,7 +95,6 @@ def range_observations_worker(


class TestOrbits(qv.Table):

orbit_id = qv.StringColumn(default=lambda: uuid.uuid4().hex)
object_id = qv.StringColumn(nullable=True)
bundle_id = qv.Int64Column(nullable=True)
Expand Down Expand Up @@ -199,7 +196,11 @@ def propagate(
The test orbit propagated to the given times.
"""
return propagator.propagate_orbits(
self.to_orbits(), times, max_processes=max_processes, chunk_size=1
self.to_orbits(),
times,
max_processes=max_processes,
chunk_size=1,
parallel_backend="ray",
)

def generate_ephemeris(
Expand All @@ -226,7 +227,11 @@ def generate_ephemeris(
The ephemeris of the test orbit at the given observers.
"""
return propagator.generate_ephemeris(
self.to_orbits(), observers, max_processes=max_processes, chunk_size=1
self.to_orbits(),
observers,
max_processes=max_processes,
chunk_size=1,
parallel_backend="ray",
)

def generate_ephemeris_from_observations(
Expand Down
19 changes: 4 additions & 15 deletions thor/orbits/attribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from adam_core.orbits import Orbits
from adam_core.propagator import PYOORB
from adam_core.propagator.utils import _iterate_chunks
from adam_core.ray_cluster import initialize_use_ray
from sklearn.neighbors import BallTree

from ..observations.observations import Observations
Expand Down Expand Up @@ -165,7 +166,6 @@ def attribution_worker(
radius_rad = np.radians(radius)
residuals = []
for _, ephemeris_i, observations_i in linkage.iterate():

# Extract the observation IDs and times
obs_ids = observations_i.id.to_numpy(zero_copy_only=False)
obs_times = observations_i.coordinates.time.mjd().to_numpy(zero_copy_only=False)
Expand Down Expand Up @@ -279,12 +279,8 @@ def attribute_observations(
observation_indices = np.arange(0, len(observations))

attributions_list = []
if max_processes is None or max_processes > 1:

if not ray.is_initialized():
logger.info(f"Ray is not initialized. Initializing with {max_processes}...")
ray.init(address="auto", max_processes=max_processes)

use_ray = initialize_use_ray(num_cpus=max_processes)
if use_ray:
refs_to_free = []
if orbits_ref is None:
orbits_ref = ray.put(orbits)
Expand Down Expand Up @@ -421,14 +417,8 @@ def merge_and_extend_orbits(
odp_orbits_list = []
odp_orbit_members_list = []
if len(orbits_iter) > 0 and len(observations_iter) > 0:

use_ray = initialize_use_ray(num_cpus=max_processes)
if use_ray:
if not ray.is_initialized():
logger.info(
f"Ray is not initialized. Initializing with {max_processes}..."
)
ray.init(address="auto", max_processes=max_processes)

refs_to_free = []
if observations_ref is None:
observations_ref = ray.put(observations)
Expand All @@ -437,7 +427,6 @@ def merge_and_extend_orbits(

converged = False
while not converged:

if use_ray:
# Orbits will change with differential correction so we need to add them
# to the object store at the start of each iteration (we cannot simply
Expand Down
17 changes: 3 additions & 14 deletions thor/orbits/iod.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from adam_core.coordinates.residuals import Residuals
from adam_core.propagator import PYOORB, Propagator
from adam_core.propagator.utils import _iterate_chunks
from adam_core.ray_cluster import initialize_use_ray

from ..clusters import ClusterMembers
from ..observations.observations import Observations
Expand Down Expand Up @@ -126,13 +127,11 @@ def iod_worker(
propagator: Type[Propagator] = PYOORB,
propagator_kwargs: dict = {},
) -> Tuple[FittedOrbits, FittedOrbitMembers]:

prop = propagator(**propagator_kwargs)

iod_orbits_list = []
iod_orbit_members_list = []
for linkage_id in linkage_ids:

time_start = time.time()
logger.debug(f"Finding initial orbit for linkage {linkage_id}...")

Expand Down Expand Up @@ -379,7 +378,6 @@ def iod(
# belonging to one object yield a good initial orbit but the presence of outlier
# observations is skewing the sum total of the residuals and chi2
elif num_outliers > 0:

logger.debug("Attempting to identify possible outliers.")
for o in range(num_outliers):
# Select i highest observations that contribute to
Expand Down Expand Up @@ -424,11 +422,9 @@ def iod(
j += 1

if not converged or not processable:

return FittedOrbits.empty(), FittedOrbitMembers.empty()

else:

orbit = FittedOrbits.from_kwargs(
orbit_id=orbit_sol.orbit_id,
object_id=orbit_sol.object_id,
Expand Down Expand Up @@ -574,18 +570,11 @@ def initial_orbit_determination(
iod_orbits_list = []
iod_orbit_members_list = []
if len(observations) > 0 and len(linkage_members) > 0:

# Extract linkage IDs
linkage_ids = linkage_members.column(linkage_id_col).unique()

if max_processes is None or max_processes > 1:

if not ray.is_initialized():
logger.info(
f"Ray is not initialized. Initializing with {max_processes}..."
)
ray.init(address="auto", num_cpus=max_processes)

use_ray = initialize_use_ray(num_cpus=max_processes)
if use_ray:
refs_to_free = []
if linkage_members_ref is None:
linkage_members_ref = ray.put(linkage_members)
Expand Down
Loading

0 comments on commit 7a797f4

Please sign in to comment.