Skip to content

Commit

Permalink
Minor edits from review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
melanieclarke committed Mar 5, 2025
1 parent e81e063 commit 796f70c
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 100 deletions.
185 changes: 94 additions & 91 deletions jwst/residual_fringe/residual_fringe.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)

# Noise factor for DER_SNR spectroscopic signal-to-noise calculation
# (see Stoehr, ADASS 2008: https://archive.stsci.edu/vodocs/der_snr.pdf)
DER_SNR_FACTOR = 1.482602 / np.sqrt(6)


class ResidualFringeCorrection:
"""Calculate and apply correction for residual fringes."""
Expand All @@ -42,10 +46,10 @@ def __init__(
regions_reference_file : str
Path to REGIONS reference file.
ignore_regions : dict
Wavelength regions to ignore. Keys are "num", "min", and "max.
Wavelength regions to ignore. Keys are "num", "min", and "max".
Values are the number of regions specified (int), the list
of minimum wavelength values, and the list of maximum wavelength
values. Minimum and maximum lists must match.
values. Length of minimum and maximum lists must match.
save_intermediate_results : bool, optional
If True, intermediate files are saved to disk.
transmission_level : int, optional
Expand Down Expand Up @@ -90,9 +94,7 @@ def __init__(

def do_correction(self):
"""
Apply residual fringe correction.
Correction is applied to a model copied from self.input_model.
Apply residual fringe correction to a copy of self.input_model.
Returns
-------
Expand Down Expand Up @@ -267,17 +269,17 @@ def do_correction(self):
# reasonable signal. If the SNR < min_snr (CDP), pass
n = len(test_flux)
signal = np.nanmean(test_flux)
noise = 0.6052697 * np.nanmedian(
noise = DER_SNR_FACTOR * np.nanmedian(
np.abs(2.0 * test_flux[2 : n - 2] - test_flux[0 : n - 4] - test_flux[4:n])
)

snr2 = 0.0 # initialize
snr2 = 0.0
if noise != 0:
snr2 = signal / noise

# Sometimes can return nan, inf for bad data so include this in check
if snr2 < min_snr[0]:
log.debug(f"SNR too low not fitting column {col}, {snr2}, {min_snr[0]}")
log.debug(f"SNR too low; not fitting column {col}, {snr2}, {min_snr[0]}")
continue

log.debug(f"Fitting column {col}")
Expand Down Expand Up @@ -372,89 +374,90 @@ def do_correction(self):
try:
for fn, ff in enumerate(ffreq):
# ignore place holder fringes
if ff > 1e-03:
log.debug(f" Start ffreq = {ff}")

# check if snr criteria is met for fringe component,
# should always be true for fringe 1
if snr2 > min_snr[fn]:
log.debug(" Fit spectral baseline")

bg_fit, bgindx = utils.fit_1d_background_complex(
proc_data,
weights_feat,
col_wnum,
ffreq=ffreq[fn],
channel=c,
)

# get the residual fringes as fraction of signal
res_fringes = np.divide(
proc_data,
bg_fit,
out=np.zeros_like(proc_data),
where=bg_fit != 0,
)
res_fringes = np.subtract(
res_fringes, 1, where=res_fringes != 0
)
res_fringes *= np.where(col_weight > 1e-07, 1, 1e-08)

# fit the residual fringes
log.debug(" Set up Bayes evidence")
(
res_fringe_fit,
wpix_num,
opt_nfringe,
peak_freq,
freq_min,
freq_max,
) = utils.fit_1d_fringes_bayes_evidence(
res_fringes,
weights_feat,
col_wnum,
ffreq[fn],
dffreq[fn],
max_nfringes[fn],
pgram_res[fn],
col_snr2,
)

# check for fit blowing up, reset rfc fit to 0, raise a flag
log.debug(" Check residual fringe fit for bad fit regions")
res_fringe_fit, res_fringe_fit_flag = utils.check_res_fringes(
res_fringe_fit, col_max_amp
)

# correct for residual fringes
log.debug(" Divide out residual fringe fit")
_, _, _, env, u_x, u_y = utils.fit_envelope(
np.arange(res_fringe_fit.shape[0]), res_fringe_fit
)

rfc_factors = 1 / (
res_fringe_fit * (col_weight > 1e-05).astype(int) + 1
)
proc_data *= rfc_factors
proc_factors *= rfc_factors

# handle nans or infs that may exist
proc_data = np.nan_to_num(proc_data, posinf=1e-08, neginf=1e-08)
proc_data[proc_data < 0] = 1e-08

out_table.add_row(
(
ss,
col,
fn,
snr2,
pgram_res[fn],
opt_nfringe,
peak_freq,
freq_min,
freq_max,
)
)
if ff <= 1e-03:
continue

# check if snr criteria is met for fringe component,
# should always be true for fringe 1
if snr2 <= min_snr[fn]:
continue

Check warning on line 383 in jwst/residual_fringe/residual_fringe.py

View check run for this annotation

Codecov / codecov/patch

jwst/residual_fringe/residual_fringe.py#L383

Added line #L383 was not covered by tests

log.debug(f" Start ffreq = {ff}")
log.debug(" Fit spectral baseline")

bg_fit, bgindx = utils.fit_1d_background_complex(
proc_data,
weights_feat,
col_wnum,
ffreq=ffreq[fn],
channel=c,
)

# get the residual fringes as fraction of signal
res_fringes = np.divide(
proc_data,
bg_fit,
out=np.zeros_like(proc_data),
where=bg_fit != 0,
)
res_fringes = np.subtract(res_fringes, 1, where=res_fringes != 0)
res_fringes *= np.where(col_weight > 1e-07, 1, 1e-08)

# fit the residual fringes
log.debug(" Set up Bayes evidence")
(
res_fringe_fit,
wpix_num,
opt_nfringe,
peak_freq,
freq_min,
freq_max,
) = utils.fit_1d_fringes_bayes_evidence(
res_fringes,
weights_feat,
col_wnum,
ffreq[fn],
dffreq[fn],
max_nfringes[fn],
pgram_res[fn],
col_snr2,
)

# check for fit blowing up, reset rfc fit to 0, raise a flag
log.debug(" Check residual fringe fit for bad fit regions")
res_fringe_fit, res_fringe_fit_flag = utils.check_res_fringes(
res_fringe_fit, col_max_amp
)

# correct for residual fringes
log.debug(" Divide out residual fringe fit")
_, _, _, env, u_x, u_y = utils.fit_envelope(
np.arange(res_fringe_fit.shape[0]), res_fringe_fit
)

rfc_factors = 1 / (
res_fringe_fit * (col_weight > 1e-05).astype(int) + 1
)
proc_data *= rfc_factors
proc_factors *= rfc_factors

# handle nans or infs that may exist
proc_data = np.nan_to_num(proc_data, posinf=1e-08, neginf=1e-08)
proc_data[proc_data < 0] = 1e-08

out_table.add_row(
(
ss,
col,
fn,
snr2,
pgram_res[fn],
opt_nfringe,
peak_freq,
freq_min,
freq_max,
)
)

# define fringe sub after all fringe components corrections
fringe_sub = proc_data.copy()
Expand Down
2 changes: 1 addition & 1 deletion jwst/residual_fringe/tests/test_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def miri_image():
@pytest.fixture()
def step_log_watcher(monkeypatch):
# Set a log watcher to check for a log message at any level
# in the emicorr step
# in the residual_fringe step
watcher = LogWatcher("")
logger = logging.getLogger("stpipe.ResidualFringeStep")
for level in ["debug", "info", "warning", "error"]:
Expand Down
10 changes: 7 additions & 3 deletions jwst/residual_fringe/tests/test_residual_fringe.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def one_slice(*args):
@pytest.fixture()
def module_log_watcher(monkeypatch):
# Set a log watcher to check for a log message at any level
# in the emicorr module
# in the residual_fringe module
watcher = LogWatcher("")
logger = logging.getLogger("jwst.residual_fringe.residual_fringe")
for level in ["debug", "info", "warning", "error"]:
Expand Down Expand Up @@ -157,8 +157,12 @@ def test_rf1d(linear_spectrum, fringed_spectrum):


def test_get_wavemap():
# Test the _get_wavemap function directly, since
# all full calls to the correction method mock it
"""
Test the _get_wavemap function directly.
A separate test is needed, since calls to the higher level correction method
mock this function for synthetic data simplicity.
"""
model = datamodels.IFUImageModel()

# Mock a WCS that returns 1 for wavelengths
Expand Down
10 changes: 5 additions & 5 deletions jwst/residual_fringe/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ def slice_info(slice_map, channel):
slices_in_channel : ndarray of int
1D array of slice IDs included in the channel.
xrange_channel : ndarray of int
1D array with two elements: minimum and maximum x values
1D array with two elements: minimum and maximum x indices
for the channel.
slice_x_ranges : ndarray of int
N x 3 array for N slices, where the first column is the slice ID,
second column is the minimum x value for the slice,
and the third column is the maximum x value for the slice.
second column is the minimum x index for the slice,
and the third column is the maximum x index for the slice.
all_slice_masks : ndarray of int
N x nx x ny for N slices, matching the x and y shape of the
input slice_map. Values are 1 for pixels included in the slice,
Expand Down Expand Up @@ -137,7 +137,7 @@ def fill_wavenumbers(wnums):

def multi_sine(n_sines):
"""
Create a mult-sine model.
Create a multi-sine model.
Parameters
----------
Expand All @@ -146,7 +146,7 @@ def multi_sine(n_sines):
Returns
-------
model : SineModel
model : BayesicFitting.SineModel
The model composed of n sines.
"""
# make the first sine
Expand Down

0 comments on commit 796f70c

Please sign in to comment.