Skip to content

Commit

Permalink
JP-3741: Faster temporary file I/O for outlier detection in on-disk m…
Browse files Browse the repository at this point in the history
…ode (#8782)
  • Loading branch information
emolter authored Sep 27, 2024
1 parent 976c239 commit 15fa0be
Show file tree
Hide file tree
Showing 17 changed files with 866 additions and 382 deletions.
1 change: 1 addition & 0 deletions changes/8782.outlier_detection.0.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Decrease the amount of file I/O required to compute the median when in_memory is set to False.
1 change: 1 addition & 0 deletions changes/8782.outlier_detection.1.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a bug that caused intermediate files to conflict for different slits when a MultiSlitModel was processed.
1 change: 1 addition & 0 deletions changes/8782.resample.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Permit creating drizzled models one at a time in many-to-many mode.
12 changes: 8 additions & 4 deletions docs/jwst/outlier_detection/outlier_detection_imaging.rst
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ Specifically, this routine performs the following operations:
should be used when resampling to create the output mosaic. Any pixel with a
DQ value not included in this value (or list of values) will be ignored when
resampling.
* Resampled images will be written out to disk with the suffix ``_<asn_id>_outlier_i2d.fits``
* When the ``save_intermediate_results`` parameter is set to True,
resampled images will be written out to disk with the suffix ``_<asn_id>_outlier_i2d.fits``
if the input model container has an <asn_id>, otherwise the suffix will be ``_outlier_i2d.fits``
by default.
* **If resampling is turned off** through the use of the ``resample_data`` parameter,
Expand Down Expand Up @@ -162,9 +163,12 @@ during processing includes:
:py:class:`~jwst.resample.ResampleStep` as well, to set whether or not to keep the
resampled images in memory or not.

#. Computing the median image works section-by-section by only keeping 1Mb of each input
in memory at a time. As a result, only the final output product array for the final
median image along with a stack of 1Mb image sections are kept in memory.
#. Computing the median image works by writing the resampled data frames to appendable files
on disk that are split into sections spatially but contain the entire ngroups (i.e., time)
axis. The section size is set to use roughly the same amount of memory as a single resampled
model, and since the resampled models are discarded from memory after this write operation this
choice avoids increasing the memory usage beyond a single resampled model.
Those sections are then read in one at a time to compute the median image.

These changes result in a minimum amount of memory usage during processing at the obvious
expense of reading and writing the products from disk.
Expand Down
63 changes: 44 additions & 19 deletions jwst/outlier_detection/_fileio.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,9 @@
import os

import logging
log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)


def remove_file(fn):
if isinstance(fn, str) and os.path.isfile(fn):
os.remove(fn)
log.info(f"Removing file {fn}")


def save_median(median_model, make_output_path, asn_id=None):
def save_median(median_model, make_output_path):
'''
Save median if requested by user
Expand All @@ -20,13 +12,46 @@ def save_median(median_model, make_output_path, asn_id=None):
median_model : ~jwst.datamodels.ImageModel
The median ImageModel or CubeModel to save
'''
default_suffix = "_outlier_i2d.fits"
if asn_id is None:
suffix_to_remove = default_suffix
else:
suffix_to_remove = f"_{asn_id}{default_suffix}"
median_model_output_path = make_output_path(
basepath=median_model.meta.filename.replace(suffix_to_remove, '.fits'),
suffix='median')
median_model.save(median_model_output_path)
log.info(f"Saved model in {median_model_output_path}")
_save_intermediate_output(median_model, "median", make_output_path)


def save_drizzled(drizzled_model, make_output_path):
expected_tail = "outlier_?2d.fits"
suffix = drizzled_model.meta.filename[-len(expected_tail):-5]
_save_intermediate_output(drizzled_model, suffix, make_output_path)


def save_blot(input_model, blot, make_output_path):
blot_model = _make_blot_model(input_model, blot)
_save_intermediate_output(blot_model, "blot", make_output_path)


def _make_blot_model(input_model, blot):
blot_model = type(input_model)()
blot_model.data = blot
blot_model.update(input_model)
return blot_model


def _save_intermediate_output(model, suffix, make_output_path):
"""
Ensure all intermediate outputs from OutlierDetectionStep have consistent file naming conventions
Notes
-----
self.make_output_path() is updated globally for the step in the main pipeline
to include the asn_id in the output path, so no need to handle it here.
"""

# outlier_?2d is not a known suffix, and make_output_path cannot handle an
# underscore in an unknown suffix, so do a manual string replacement
input_path = model.meta.filename.replace("_outlier_", "_")

# Add a slit name to the output path for MultiSlitModel data if not present
if hasattr(model, "name") and model.name is not None:
if "_"+model.name.lower() not in input_path:
suffix = f"{model.name.lower()}_{suffix}"

output_path = make_output_path(input_path, suffix=suffix)
model.save(output_path)
log.info(f"Saved {suffix} model in {output_path}")
4 changes: 1 addition & 3 deletions jwst/outlier_detection/coron.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
Submodule for performing outlier detection on coronagraphy data.
"""


import logging

import numpy as np
Expand All @@ -27,7 +26,6 @@ def detect_outliers(
good_bits,
maskpt,
snr,
asn_id,
make_output_path,
):
"""
Expand Down Expand Up @@ -56,7 +54,7 @@ def detect_outliers(
median_model.update(input_model)
median_model.meta.wcs = input_model.meta.wcs

save_median(median_model, make_output_path, asn_id)
save_median(median_model, make_output_path)
del median_model

# Perform outlier detection using statistical comparisons between
Expand Down
71 changes: 16 additions & 55 deletions jwst/outlier_detection/imaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,16 @@
Submodule for performing outlier detection on imaging data.
"""

import copy
import logging
import os

from stdatamodels.jwst import datamodels

from jwst.datamodels import ModelLibrary
from jwst.resample import resample
from jwst.resample.resample_utils import build_driz_weight
from jwst.stpipe.utilities import record_step_status

from .utils import create_median, flag_model_crs, flag_resampled_model_crs
from ._fileio import remove_file, save_median
from .utils import (flag_model_crs,
flag_resampled_model_crs,
median_without_resampling,
median_with_resampling)

log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)
Expand All @@ -40,7 +37,6 @@ def detect_outliers(
fillval,
allowed_memory,
in_memory,
asn_id,
make_output_path,
):
"""
Expand All @@ -58,20 +54,10 @@ def detect_outliers(
log.warning("Outlier detection will be skipped")
record_step_status(input_models, "outlier_detection", False)
return input_models

if resample_data:
# Start by creating resampled/mosaic images for
# each group of exposures
with input_models:
example_model = input_models.borrow(0)
output_path = make_output_path(basepath=example_model.meta.filename,
suffix='')
input_models.shelve(example_model, modify=False)
del example_model
output_path = os.path.dirname(output_path)
resamp = resample.ResampleData(
input_models,
output=output_path,
single=True,
blendheaders=False,
wht_type=weight_type,
Expand All @@ -80,46 +66,21 @@ def detect_outliers(
fillval=fillval,
good_bits=good_bits,
in_memory=in_memory,
asn_id=asn_id,
allowed_memory=allowed_memory,
)
median_wcs = resamp.output_wcs
drizzled_models = resamp.do_drizzle(input_models)
median_data, median_wcs = median_with_resampling(input_models,
resamp,
maskpt,
save_intermediate_results=save_intermediate_results,
make_output_path=make_output_path,)
else:
# for non-dithered data, the resampled image is just the original image
drizzled_models = input_models
with input_models:
for i, model in enumerate(input_models):
model.wht = build_driz_weight(
model,
weight_type=weight_type,
good_bits=good_bits)
# copy for when saving median and input is a filename?
if i == 0:
median_wcs = copy.deepcopy(model.meta.wcs)
input_models.shelve(model, modify=True)

# Perform median combination on set of drizzled mosaics
median_data = create_median(drizzled_models, maskpt)
median_data, median_wcs = median_without_resampling(input_models,
maskpt,
weight_type,
good_bits,
save_intermediate_results=save_intermediate_results,
make_output_path=make_output_path,)

if save_intermediate_results:
# make a median model
with drizzled_models:
example_model = drizzled_models.borrow(0)
drizzled_models.shelve(example_model, modify=False)
median_model = datamodels.ImageModel(median_data)
median_model.update(example_model)
median_model.meta.wcs = median_wcs
del example_model

save_median(median_model, make_output_path, asn_id)
del median_model
else:
# since we're not saving intermediate results if the drizzled models
# were written to disk, remove them
if not in_memory:
for fn in drizzled_models.asn["products"][0]["members"]:
remove_file(fn["expname"])

# Perform outlier detection using statistical comparisons between
# each original input image and its blotted version of the median image
Expand Down
13 changes: 6 additions & 7 deletions jwst/outlier_detection/outlier_detection_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ def process(self, input_data):
self.log.info(f"Outlier Detection mode: {mode}")

# determine the asn_id (if not set by the pipeline)
asn_id = self._get_asn_id(input_data)
self.log.info(f"Outlier Detection asn_id: {asn_id}")
self._get_asn_id(input_data)

snr1, snr2 = [float(v) for v in self.snr.split()]
scale1, scale2 = [float(v) for v in self.scale.split()]
Expand All @@ -94,7 +93,6 @@ def process(self, input_data):
self.maskpt,
self.rolling_window_width,
snr1,
asn_id,
self.make_output_path,
)
elif mode == 'coron':
Expand All @@ -104,7 +102,6 @@ def process(self, input_data):
self.good_bits,
self.maskpt,
snr1,
asn_id,
self.make_output_path,
)
elif mode == 'imaging':
Expand All @@ -125,7 +122,6 @@ def process(self, input_data):
self.fillval,
self.allowed_memory,
self.in_memory,
asn_id,
self.make_output_path,
)
elif mode == 'spec':
Expand All @@ -145,7 +141,6 @@ def process(self, input_data):
self.kernel,
self.fillval,
self.in_memory,
asn_id,
self.make_output_path,
)
elif mode == 'ifu':
Expand Down Expand Up @@ -203,6 +198,9 @@ def _guess_mode(self, input_models):
return None

def _get_asn_id(self, input_models):
"""Find association ID for any allowed input model type,
and update make_output_path such that the association ID
is included in intermediate and output file names."""
# handle if input_models isn't open
if isinstance(input_models, (str, dict)):
input_models = datamodels.open(input_models, asn_n_members=1)
Expand All @@ -227,7 +225,8 @@ def _get_asn_id(self, input_models):
_make_output_path,
asn_id=asn_id
)
return asn_id
self.log.info(f"Outlier Detection asn_id: {asn_id}")
return

def _set_status(self, input_models, status):
# this might be called with the input which might be a filename or path
Expand Down
Loading

0 comments on commit 15fa0be

Please sign in to comment.