Skip to content

Commit

Permalink
Merge pull request #43 from esheldon/joblib-config
Browse files Browse the repository at this point in the history
ENH allow variable number of jobs for joblib
  • Loading branch information
esheldon authored Jan 13, 2020
2 parents 4699eee + bf094d7 commit 3d7851a
Showing 1 changed file with 30 additions and 11 deletions.
41 changes: 30 additions & 11 deletions meds/maker.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,12 +339,13 @@ def _write_psf_cutouts_joblib(self):
psf_data, file_ids, rows, cols))

# run them all in parallel
with joblib.Parallel(
n_jobs=-1,
backend='multiprocessing',
with joblib.parallel_backend(
self._joblib_backend,
inner_max_num_threads=self._joblib_threads):
outputs = joblib.Parallel(
n_jobs=self._joblib_max_workers,
max_nbytes=None,
verbose=50) as parallel:
outputs = parallel(jobs)
verbose=50)(jobs)

# write to disk
# at this point all of the PSFs we need are in memory on a
Expand Down Expand Up @@ -421,7 +422,7 @@ def _write_psf_cutouts(self):

print('writing psf cutouts')

if self.get('use_joblib', False):
if self._use_joblib:
self._write_psf_cutouts_joblib()
else:
self._write_psf_cutouts_serial()
Expand Down Expand Up @@ -891,10 +892,13 @@ def _do_sky2image(self, wcs, ra, dec, color=None):
"""
# the cut at 250 eliminates cases where multiprocessing is
# slower or the same due to overheads
if self.get('use_joblib', False) and len(ra) > 250:
if self._use_joblib and len(ra) > 250:
import joblib
n_jobs = joblib.externals.loky.cpu_count()

if self._joblib_max_workers > 0:
n_jobs = min(self._joblib_max_workers, n_jobs)

n_per_job = len(ra) // n_jobs
if n_jobs * n_per_job < len(ra):
n_per_job += 1
Expand Down Expand Up @@ -924,12 +928,13 @@ def _do_sky2image(self, wcs, ra, dec, color=None):
)
)

with joblib.Parallel(
with joblib.parallel_backend(
self._joblib_backend,
inner_max_num_threads=self._joblib_threads):
outputs = joblib.Parallel(
n_jobs=n_jobs,
backend='multiprocessing',
max_nbytes=None,
verbose=50) as parallel:
outputs = parallel(jobs)
verbose=50)(jobs)

col = []
row = []
Expand Down Expand Up @@ -1334,6 +1339,20 @@ def _load_config(self, config):
if 'psf_type' in self:
self['psf'] = {'type': self['psf_type']}

if 'joblib' in self:
self._use_joblib = True
else:
self._use_joblib = self.get('use_joblib', False)

self._joblib_backend = self.get(
'joblib', {}).get('backend', 'multiprocessing')
self._joblib_max_workers = self.get(
'joblib', {}).get('max_workers', -1)
if self._joblib_backend == 'loky':
self._joblib_threads = 1
else:
self._joblib_threads = None


def _psf_rec_func(output_path, psf_data, file_ids, rows, cols):
import joblib
Expand Down

0 comments on commit 3d7851a

Please sign in to comment.