Skip to content

Commit

Permalink
refinement of fiber thermal shift & minor bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
ajmejia committed Dec 10, 2024
1 parent 591fa98 commit fc8dcf9
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 7 deletions.
6 changes: 4 additions & 2 deletions python/lvmdrp/core/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,8 +750,10 @@ def measure_fiber_shifts(self, ref_image, trace_cent, columns=[500, 1000, 1500,
s2 = bn.nanmedian(self._data[50:-50,2000-500:2000+500], axis=1)
guess_shift = align_blocks(s1, s2)

if guess_shift > 6:
log.warning(f"measuring fiber thermal shift too large {guess_shift = } pixels")
if numpy.abs(guess_shift) > 6:
log.warning(f"measuring guess fiber thermal shift too large {guess_shift = } pixels")
else:
log.info(f"measured guess fiber thermal shift {guess_shift = } pixels")

shifts = numpy.zeros(len(columns))
select_blocks = [9]
Expand Down
40 changes: 35 additions & 5 deletions python/lvmdrp/core/spectrum1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy
import bottleneck as bn
from astropy.io import fits as pyfits
from astropy.stats import biweight_location
from numpy import polynomial
from scipy.linalg import norm
from scipy import signal, interpolate, ndimage, sparse
Expand Down Expand Up @@ -225,7 +226,7 @@ def _cross_match(

def _normalize_peaks(data, ref, min_peak_dist):
data_ = numpy.asarray(data).copy()
dat_peaks, dat_peak_pars = signal.find_peaks(data_, distance=min_peak_dist, rel_height=0.5, width=(2,4), prominence=1.5)
dat_peaks, dat_peak_pars = signal.find_peaks(data_, distance=min_peak_dist)

ref_ = numpy.asarray(ref).copy()
ref_peaks, ref_peak_pars = signal.find_peaks(ref_, distance=min_peak_dist, rel_height=0.5, width=(2,4), prominence=1.5)
Expand All @@ -238,16 +239,19 @@ def _normalize_peaks(data, ref, min_peak_dist):
dat_norm = numpy.interp(numpy.arange(data_.shape[0]), dat_peaks, data_[dat_peaks])
ref_norm = numpy.interp(numpy.arange(data_.shape[0]), ref_peaks, ref_[ref_peaks])

# import matplotlib.pyplot as plt
# plt.figure()
# plt.plot(dat_norm)
# plt.plot(ref_norm)

ref_ = ref_ / ref_norm
data_ = data_ / dat_norm
# ref_ = ref_ / ref_norm * dat_norm / numpy.median(data_)
# data_ = data_ / numpy.median(data_)

# import matplotlib.pyplot as plt
# plt.figure()
# plt.plot(dat_norm, "-b")
# plt.vlines(dat_peaks, 0, 1, lw=1, color="tab:blue")
# # plt.plot(ref_norm, "-r")
# # plt.vlines(ref_peaks, 0, 1, lw=1, color="tab:red")

return data_, ref_, dat_peaks, dat_peak_pars, ref_peaks, ref_peak_pars


Expand Down Expand Up @@ -285,19 +289,37 @@ def _choose_cc_peak(cc, shifts, min_shift, max_shift):
# print(ccp, sum_cc)
return ccp[numpy.argmax(sum_cc)]


def align_blocks(ref_spec, obs_spec, median_box=21):
"""Cross-correlate median-filtered versions of fiber profile data and model to get coarse alignment"""
# sigma-clip spectra to half median to remove fiber features
obs_avg = biweight_location(obs_spec, ignore_nan=True)
ref_avg = biweight_location(ref_spec, ignore_nan=True)
obs_spec = numpy.clip(obs_spec, 0, 0.5*obs_avg)
ref_spec = numpy.clip(ref_spec, 0, 0.5*ref_avg)

obs_median = signal.medfilt(obs_spec, median_box)
ref_median = signal.medfilt(ref_spec, median_box)

obs_median /= numpy.median(obs_median)
ref_median /= numpy.median(ref_median)

# import matplotlib.pyplot as plt
# plt.figure()
# pixels = numpy.arange(obs_spec.size)
# plt.step(pixels, obs_median, where="mid", color="green")
# plt.step(pixels, ref_median, where="mid", color="purple")
# # plt.axhline(obs_avg, ls="--")
# # plt.axhline(obs_avg+obs_std, ls=":")

cc = signal.correlate(obs_median, ref_median, mode="same")

shifts = signal.correlation_lags(len(obs_spec), len(ref_spec), mode="same")
best_shift = shifts[numpy.argmax(cc)]

# plt.figure()
# plt.step(shifts, cc, where="mid")

return best_shift


Expand Down Expand Up @@ -512,6 +534,14 @@ def _fiber_cc_match(
)
cross_corr = signal.correlate(obs_spec_, ref_spec_, mode="same")

# import matplotlib.pyplot as plt
# plt.figure()
# pixels = numpy.arange(obs_spec_.size)
# plt.step(pixels, obs_spec_, where="mid")
# plt.step(pixels, ref_spec_, where="mid")
# plt.figure()
# plt.step(shifts, cross_corr, where="mid")

# Normalize the cross correlation
cross_corr = cross_corr.astype(numpy.float32)
cross_corr /= norm(ref_spec_) * norm(obs_spec_)
Expand Down

0 comments on commit fc8dcf9

Please sign in to comment.