From d9e63244d6fdc993fc1b0cbdcc01ee1ac2e7f4f6 Mon Sep 17 00:00:00 2001 From: Melanie Clarke Date: Fri, 28 Feb 2025 16:34:16 -0500 Subject: [PATCH 1/7] Tests for residual_fringe --- jwst/regtest/test_miri_mrs_spec2.py | 27 ++ jwst/residual_fringe/fitter.py | 5 +- jwst/residual_fringe/residual_fringe.py | 455 +++++++++--------- jwst/residual_fringe/residual_fringe_step.py | 34 +- .../tests/test_configuration.py | 55 ++- .../tests/test_residual_fringe.py | 286 +++++++++++ jwst/residual_fringe/utils.py | 39 -- pyproject.toml | 2 +- 8 files changed, 607 insertions(+), 296 deletions(-) create mode 100644 jwst/residual_fringe/tests/test_residual_fringe.py diff --git a/jwst/regtest/test_miri_mrs_spec2.py b/jwst/regtest/test_miri_mrs_spec2.py index 3e4a23cb9f..e2d66d4738 100644 --- a/jwst/regtest/test_miri_mrs_spec2.py +++ b/jwst/regtest/test_miri_mrs_spec2.py @@ -45,6 +45,24 @@ def run_spec2(rtdata_module): Step.from_cmdline(args) +@pytest.fixture(scope='module') +def run_spec2_with_residual_fringe(rtdata_module): + """Run the Spec2Pipeline on a single exposure""" + rtdata = rtdata_module + + # Get the input rate file + rtdata.get_data(INPUT_PATH + '/' + 'jw01024001001_04101_00001_mirifulong_rate.fits') + + # Run the pipeline + args = ["calwebb_spec2", rtdata.input, + '--output_file=jw01024001001_04101_00001_mirifulong_rf', + '--steps.residual_fringe.skip=false', + '--steps.residual_fringe.save_results=true', + ] + + Step.from_cmdline(args) + + @pytest.mark.slow @pytest.mark.bigdata @pytest.mark.parametrize( @@ -83,3 +101,12 @@ def test_miri_mrs_wcs(run_spec2, fitsdiff_default_kwargs, rtdata_module): xtruth, ytruth = im_truth.meta.wcs.backward_transform(ratruth, dectruth, lamtruth) assert_allclose(xtest, xtruth) assert_allclose(ytest, ytruth) + + +@pytest.mark.slow +@pytest.mark.bigdata +@pytest.mark.parametrize('suffix', ['rf_residual_fringe', 'rf_s3d', 'rf_x1d', 'rf_cal']) +def test_miri_mrs_spec2_with_rf(run_spec2_with_residual_fringe, + fitsdiff_default_kwargs, suffix, rtdata_module): + rtdata = rtdata_module + rt.is_like_truth(rtdata, fitsdiff_default_kwargs, suffix, truth_path=TRUTH_PATH) diff --git a/jwst/residual_fringe/fitter.py b/jwst/residual_fringe/fitter.py index 8e531ccae0..1928d0b656 100644 --- a/jwst/residual_fringe/fitter.py +++ b/jwst/residual_fringe/fitter.py @@ -1,3 +1,4 @@ + import numpy as np import scipy.interpolate @@ -16,7 +17,6 @@ def spline_fitter(x, y, weights, knots, degree, reject_outliers=False, domain=10 def chi_sq(spline, weights): return np.nansum((y - spline(x)) ** 2 * weights) - # initial fit spline = _lsq_spline(x, y, weights, knots, degree) chi = chi_sq(spline, weights) @@ -25,6 +25,9 @@ def chi_sq(spline, weights): nparams = len(knots) + (degree + 1) * 2 deg_of_freedom = np.sum(weights) - nparams + if deg_of_freedom <= 0: + raise RuntimeError("Degrees of freedom <= 0") + for _ in range(1000 * nparams): scale = np.sqrt(chi / deg_of_freedom) diff --git a/jwst/residual_fringe/residual_fringe.py b/jwst/residual_fringe/residual_fringe.py index 5f9bc1b3ee..c12e93ceb0 100644 --- a/jwst/residual_fringe/residual_fringe.py +++ b/jwst/residual_fringe/residual_fringe.py @@ -1,42 +1,45 @@ -# -# Module for applying fringe correction -# +"""Apply residual fringe correction.""" -import numpy as np +import logging from functools import partial +import numpy as np +from astropy.table import Table +from astropy.io import ascii, fits +from stdatamodels import fits_support from stdatamodels.jwst import datamodels -from ..stpipe import Step -from astropy.table import Table -from astropy.io import ascii -from astropy.io import fits -from . import utils +from jwst.stpipe import Step +from jwst.residual_fringe import utils -import logging log = logging.getLogger(__name__) log.setLevel(logging.DEBUG) -class ResidualFringeCorrection(): +class ResidualFringeCorrection: def __init__(self, input_model, residual_fringe_reference_file, regions_reference_file, ignore_regions, - **pars): + save_intermediate_results=False, + transmission_level=80, + make_output_path=None): self.input_model = input_model.copy() self.model = input_model.copy() self.residual_fringe_reference_file = residual_fringe_reference_file self.regions_reference_file = regions_reference_file self.ignore_regions = ignore_regions - self.save_intermediate_results = pars['save_intermediate_results'] - self.transmission_level = int(pars['transmission_level']) + self.save_intermediate_results = save_intermediate_results + self.transmission_level = transmission_level + # define how filenames are created - self.make_output_path = pars.get('make_output_path', - partial(Step._make_output_path, None)) + if make_output_path is None: + self.make_output_path = partial(Step._make_output_path, None) + else: + self.make_output_path = make_output_path self.rfc_factors = None self.fit_mask = None @@ -159,9 +162,7 @@ def do_correction(self): units=('', ''), dtype=('i4', '(3,{})f8'.format(n_wav_samples))) - ysize = self.input_model.data.shape[0] - xsize = self.input_model.data.shape[1] - + wave_map = self._get_wave_map() for c in self.channels: num_corrected = 0 log.info("Processing channel {}".format(c)) @@ -170,9 +171,6 @@ def do_correction(self): log.debug(' Slice Ranges {}'.format(slice_x_ranges)) - y, x = np.mgrid[:ysize, :xsize] - _, _, wave_map = self.input_model.meta.wcs(x, y) - # if the user wants to ignore some values use the wave_map array to set the corresponding # weight values to 0 if self.ignore_regions['num'] > 0: @@ -237,213 +235,210 @@ def do_correction(self): # Sometimes can return nan, inf for bad data so include this in check if snr2 < min_snr[0]: - log.debug('SNR too high not fitting column {}, {}, {}'.format(col, snr2, min_snr[0])) - pass - else: - log.debug("Fitting column{} ".format(col)) - log.debug("SNR > {} ".format(min_snr[0])) + log.debug(f'SNR too low not fitting column {col}, {snr2}, {min_snr[0]}') + continue - col_weight = ss_weight[:, col] - col_max_amp = np.interp(col_wmap, self.max_amp['Wavelength'], self.max_amp['Amplitude']) - col_snr2 = np.where(col_snr > 10, 1, 0) # hardcoded at snr > 10 for now + log.debug(f"Fitting column {col}") + log.debug(f"SNR > {min_snr[0]} ") - # get the in-slice pixel indices for replacing in output later - idx = np.where(col_data > 0) + col_weight = ss_weight[:, col] + col_max_amp = np.interp(col_wmap, self.max_amp['Wavelength'], self.max_amp['Amplitude']) + col_snr2 = np.where(col_snr > 10, 1, 0) # hardcoded at snr > 10 for now - # BayesicFitting doesn't like 0s at data or weight array edges so set to small value - # replacing array 0s with arbitrarily low number - col_data[col_data <= 0] = 1e-08 - col_weight[col_weight <= 0] = 1e-08 + # get the in-slice pixel indices for replacing in output later + idx = np.where(col_data > 0) - # check for off-slice pixels and send to be filled with interpolated/extrapolated wnums - # to stop BayesicFitting crashing, will not be fitted anyway - # finding out-of-slice pixels in column and filling + # BayesicFitting doesn't like 0s at data or weight array edges so set to small value + # replacing array 0s with arbitrarily low number + col_data[col_data <= 0] = 1e-08 + col_weight[col_weight <= 0] = 1e-08 - found_bad = np.logical_or(np.isnan(col_wnum), np.isinf(col_wnum)) - num_bad = len(np.where(found_bad)[0]) + # check for off-slice pixels and send to be filled with interpolated/extrapolated wnums + # to stop BayesicFitting crashing, will not be fitted anyway + # finding out-of-slice pixels in column and filling - if num_bad > 0: - col_wnum[found_bad] = 0 - col_wnum = utils.fill_wavenumbers(col_wnum) + found_bad = np.logical_or(np.isnan(col_wnum), np.isinf(col_wnum)) + num_bad = len(np.where(found_bad)[0]) - # do feature finding on slice now column-by-column - log.debug(" starting feature finding") + if num_bad > 0: + col_wnum[found_bad] = 0 + col_wnum = utils.fill_wavenumbers(col_wnum) + + # do feature finding on slice now column-by-column + log.debug(" starting feature finding") + + # narrow features (similar or less than fringe #1 period) + # find spectral features (env is spline fit of troughs and peaks) + env, l_x, l_y, _, _, _ = utils.fit_envelope(np.arange(col_data.shape[0]), col_data) + mod = np.abs(col_data / env) - 1 + + # given signal in mod find location of lines > col_max_amp * 2 (fringe contrast) + # use col_snr to ignore noisy pixels + weight_factors = utils.find_lines(mod * col_snr2, col_max_amp * 2) + weights_feat = col_weight * weight_factors + + # account for fringe 2 on broad features in channels 3 and 4 + # need to smooth out the dichroic fringe as it breaks the feature finding method + if c in [3, 4]: + win = 7 # smoothing window hardcoded to 7 for now (based on testing) + cumsum = np.cumsum(np.insert(col_data, 0, 0)) + sm_col_data = (cumsum[win:] - cumsum[:-win]) / float(win) - # narrow features (similar or less than fringe #1 period) # find spectral features (env is spline fit of troughs and peaks) - env, l_x, l_y, _, _, _ = utils.fit_envelope(np.arange(col_data.shape[0]), col_data) + env, l_x, l_y, _, _, _ = utils.fit_envelope(np.arange(col_data.shape[0]), sm_col_data) mod = np.abs(col_data / env) - 1 - # given signal in mod find location of lines > col_max_amp * 2 (fringe contrast) - # use col_snr to ignore noisy pixels - weight_factors = utils.find_lines(mod * col_snr2, col_max_amp * 2) - weights_feat = col_weight * weight_factors - - # account for fringe 2 on broad features in channels 3 and 4 - # need to smooth out the dichroic fringe as it breaks the feature finding method - if c in [3, 4]: - win = 7 # smoothing window hardcoded to 7 for now (based on testing) - cumsum = np.cumsum(np.insert(col_data, 0, 0)) - sm_col_data = (cumsum[win:] - cumsum[:-win]) / float(win) - - # find spectral features (env is spline fit of troughs and peaks) - env, l_x, l_y, _, _, _ = utils.fit_envelope(np.arange(col_data.shape[0]), sm_col_data) - mod = np.abs(col_data / env) - 1 - - # given signal in mod find location of lines > col_max_amp * 2 - weight_factors = utils.find_lines(mod, col_max_amp * 2) - weights_feat *= weight_factors - - # iterate over the fringe components to fit, initialize pre-contrast, other output arrays - # in case fit fails - proc_data = col_data.copy() - proc_factors = np.ones(col_data.shape) - pre_contrast = 0.0 - bg_fit = col_data.copy() - res_fringes = np.zeros(col_data.shape) - res_fringe_fit = np.zeros(col_data.shape) - res_fringe_fit_flag = np.zeros(col_data.shape) - wpix_num = 1024 - - # check the end points. A single value followed by gap of zero can cause - # problems in the fitting. - index = np.where(weights_feat != 0.0) - length = np.diff(index[0]) - - if weights_feat[0] != 0 and length[0] > 1: - weights_feat[0] = 1e-08 - - if weights_feat[-1] != 0 and length[-1] > 1: - weights_feat[-1] = 1e-08 - - # jane added this - fit can fail in evidence function. - # once we replace evidence function with astropy routine - we can test - # removing setting weights < 0.003 to zero (1e-08) - - weights_feat[weights_feat <= 0.003] = 1e-08 - - # currently the reference file fits one fringe originating in the detector pixels, and a - # second high frequency, low amplitude fringe in channels 3 and 4 which has been - # attributed to the dichroics. - try: - for fn, ff in enumerate(ffreq): - # ignore place holder fringes - if ff > 1e-03: - log.debug(' starting ffreq = {}'.format(ff)) - - # check if snr criteria is met for fringe component, should always be true for fringe 1 - if snr2 > min_snr[fn]: - log.debug(" fitting spectral baseline") - - bg_fit, bgindx = \ - utils.fit_1d_background_complex(proc_data, weights_feat, - col_wnum, ffreq=ffreq[fn], channel=c) - - # get the residual fringes as fraction of signal - res_fringes = np.divide(proc_data, bg_fit, out=np.zeros_like(proc_data), - where=bg_fit != 0) - res_fringes = np.subtract(res_fringes, 1, where=res_fringes != 0) - res_fringes *= np.where(col_weight > 1e-07, 1, 1e-08) - # get the pre-correction contrast using fringe component 1 - # TODO: REMOVE CONTRAST CHECKS - # set dummy values for contrast check parameters until removed - pre_contrast = 0.0 - quality = np.array([np.zeros(col_data.shape), np.zeros(col_data.shape), - np.zeros(col_data.shape)]) - # if fn == 0: - # pre_contrast, quality = utils.fit_quality(col_wnum, - # res_fringes, - # weights_feat, - # ffreq[0], - # dffreq[0]) - # - # log.debug(" pre-correction contrast = {}".format(pre_contrast)) - # - # fit the residual fringes - log.debug(" set up bayes ") - res_fringe_fit, wpix_num, opt_nfringe, peak_freq, freq_min, freq_max = \ - utils.new_fit_1d_fringes_bayes_evidence(res_fringes, - weights_feat, - col_wnum, - ffreq[fn], - dffreq[fn], - min_nfringes=min_nfringes[fn], - max_nfringes=max_nfringes[fn], - pgram_res=pgram_res[fn], - col_snr2=col_snr2) - - # check for fit blowing up, reset rfc fit to 0, raise a flag - log.debug("check residual fringe fit for bad fit regions") - res_fringe_fit, res_fringe_fit_flag = utils.check_res_fringes(res_fringe_fit, - col_max_amp) - - # correct for residual fringes - log.debug(" divide out residual fringe fit, get fringe corrected column") - _, _, _, env, u_x, u_y = utils.fit_envelope(np.arange(res_fringe_fit.shape[0]), - res_fringe_fit) - - rfc_factors = 1 / (res_fringe_fit * (col_weight > 1e-05).astype(int) + 1) - proc_data *= rfc_factors - proc_factors *= rfc_factors - - # handle nans or infs that may exist - proc_data = np.nan_to_num(proc_data, posinf=1e-08, neginf=1e-08) - proc_data[proc_data < 0] = 1e-08 - - out_table.add_row((ss, col, fn, snr2, pre_contrast, pre_contrast, pgram_res[fn], - opt_nfringe, peak_freq, freq_min, freq_max)) - - # define fringe sub after all fringe components corrections - fringe_sub = proc_data.copy() - rfc_factors = proc_factors.copy() - - # get the new fringe contrast - log.debug(" analysing fit quality") - - pbg_fit, pbgindx = utils.fit_1d_background_complex(fringe_sub, - weights_feat, - col_wnum, - ffreq=ffreq[0], channel=c) - - # get the residual fringes as fraction of signal - fit_res = np.divide(fringe_sub, pbg_fit, out=np.zeros_like(fringe_sub), - where=pbg_fit != 0) - fit_res = np.subtract(fit_res, 1, where=fit_res != 0) - fit_res *= np.where(col_weight > 1e-07, 1, 1e-08) - - # TODO: REMOVE CONTRAST CHECKS - # set dummy values for contrast check parameters until removed - contrast = 0.0 - quality = np.array([np.zeros(col_data.shape), np.zeros(col_data.shape), np.zeros(col_data.shape)]) - # contrast, quality = utils.fit_quality(col_wnum, - # fit_res, - # weights_feat, - # ffreq, - # dffreq, - # save_results=self.save_intermediate_results) - - out_table.add_row((ss, col, fn, snr2, pre_contrast, contrast, pgram_res[0], - opt_nfringe, peak_freq, freq_min, freq_max)) - - qual_table.add_row((col, quality)) - correction_quality.append([contrast, pre_contrast]) - log.debug(" residual contrast = {}".format(contrast)) - - # replace the corrected in-slice column pixels in the data_cor array - log.debug(" updating the trace pixels in the output") - output_data[idx, col] = fringe_sub[idx] - self.rfc_factors[idx, col] = rfc_factors[idx] - self.fit_mask[idx, col] = np.ones(1024)[idx] - self.weights_feat[idx, col] = weights_feat[idx] - self.weighted_pix_num[idx, col] = np.ones(1024)[idx] * (wpix_num / 1024) - self.rejected_fit[idx, col] = res_fringe_fit_flag[idx] - self.background_fit[idx, col] = bg_fit[idx] - self.knot_locations[:bgindx.shape[0], col] = bgindx - num_corrected = num_corrected + 1 - - except Exception as e: - log.warning(" Skipping col={} {} ".format(col, ss)) - log.warning(' %s' % (str(e))) + # given signal in mod find location of lines > col_max_amp * 2 + weight_factors = utils.find_lines(mod, col_max_amp * 2) + weights_feat *= weight_factors + + # iterate over the fringe components to fit, initialize pre-contrast, other output arrays + # in case fit fails + proc_data = col_data.copy() + proc_factors = np.ones(col_data.shape) + pre_contrast = 0.0 + bg_fit = col_data.copy() + res_fringe_fit_flag = np.zeros(col_data.shape) + wpix_num = 1024 + + # check the end points. A single value followed by gap of zero can cause + # problems in the fitting. + index = np.where(weights_feat != 0.0) + length = np.diff(index[0]) + #print(weights_feat, index, length) + if weights_feat[0] != 0 and length[0] > 1: + weights_feat[0] = 1e-08 + + if weights_feat[-1] != 0 and length[-1] > 1: + weights_feat[-1] = 1e-08 + + # jane added this - fit can fail in evidence function. + # once we replace evidence function with astropy routine - we can test + # removing setting weights < 0.003 to zero (1e-08) + weights_feat[weights_feat <= 0.003] = 1e-08 + + # currently the reference file fits one fringe originating in the detector pixels, and a + # second high frequency, low amplitude fringe in channels 3 and 4 which has been + # attributed to the dichroics. + try: + for fn, ff in enumerate(ffreq): + # ignore place holder fringes + if ff > 1e-03: + log.debug(' starting ffreq = {}'.format(ff)) + + # check if snr criteria is met for fringe component, should always be true for fringe 1 + if snr2 > min_snr[fn]: + log.debug(" fitting spectral baseline") + + bg_fit, bgindx = \ + utils.fit_1d_background_complex(proc_data, weights_feat, + col_wnum, ffreq=ffreq[fn], channel=c) + + # get the residual fringes as fraction of signal + res_fringes = np.divide(proc_data, bg_fit, out=np.zeros_like(proc_data), + where=bg_fit != 0) + res_fringes = np.subtract(res_fringes, 1, where=res_fringes != 0) + res_fringes *= np.where(col_weight > 1e-07, 1, 1e-08) + # get the pre-correction contrast using fringe component 1 + # TODO: REMOVE CONTRAST CHECKS + # set dummy values for contrast check parameters until removed + pre_contrast = 0.0 + quality = np.array([np.zeros(col_data.shape), np.zeros(col_data.shape), + np.zeros(col_data.shape)]) + # if fn == 0: + # pre_contrast, quality = utils.fit_quality(col_wnum, + # res_fringes, + # weights_feat, + # ffreq[0], + # dffreq[0]) + # + # log.debug(" pre-correction contrast = {}".format(pre_contrast)) + # + # fit the residual fringes + log.debug(" set up bayes ") + res_fringe_fit, wpix_num, opt_nfringe, peak_freq, freq_min, freq_max = \ + utils.new_fit_1d_fringes_bayes_evidence(res_fringes, + weights_feat, + col_wnum, + ffreq[fn], + dffreq[fn], + min_nfringes=min_nfringes[fn], + max_nfringes=max_nfringes[fn], + pgram_res=pgram_res[fn], + col_snr2=col_snr2) + + # check for fit blowing up, reset rfc fit to 0, raise a flag + log.debug("check residual fringe fit for bad fit regions") + res_fringe_fit, res_fringe_fit_flag = utils.check_res_fringes(res_fringe_fit, + col_max_amp) + + # correct for residual fringes + log.debug(" divide out residual fringe fit, get fringe corrected column") + _, _, _, env, u_x, u_y = utils.fit_envelope(np.arange(res_fringe_fit.shape[0]), + res_fringe_fit) + + rfc_factors = 1 / (res_fringe_fit * (col_weight > 1e-05).astype(int) + 1) + proc_data *= rfc_factors + proc_factors *= rfc_factors + + # handle nans or infs that may exist + proc_data = np.nan_to_num(proc_data, posinf=1e-08, neginf=1e-08) + proc_data[proc_data < 0] = 1e-08 + + out_table.add_row((ss, col, fn, snr2, pre_contrast, pre_contrast, pgram_res[fn], + opt_nfringe, peak_freq, freq_min, freq_max)) + + # define fringe sub after all fringe components corrections + fringe_sub = proc_data.copy() + rfc_factors = proc_factors.copy() + + # get the new fringe contrast + log.debug(" analysing fit quality") + + pbg_fit, pbgindx = utils.fit_1d_background_complex(fringe_sub, + weights_feat, + col_wnum, + ffreq=ffreq[0], channel=c) + + # get the residual fringes as fraction of signal + fit_res = np.divide(fringe_sub, pbg_fit, out=np.zeros_like(fringe_sub), + where=pbg_fit != 0) + fit_res = np.subtract(fit_res, 1, where=fit_res != 0) + fit_res *= np.where(col_weight > 1e-07, 1, 1e-08) + + # TODO: REMOVE CONTRAST CHECKS + # set dummy values for contrast check parameters until removed + contrast = 0.0 + quality = np.array([np.zeros(col_data.shape), np.zeros(col_data.shape), np.zeros(col_data.shape)]) + # contrast, quality = utils.fit_quality(col_wnum, + # fit_res, + # weights_feat, + # ffreq, + # dffreq, + # save_results=self.save_intermediate_results) + + out_table.add_row((ss, col, fn, snr2, pre_contrast, contrast, pgram_res[0], + opt_nfringe, peak_freq, freq_min, freq_max)) + + qual_table.add_row((col, quality)) + correction_quality.append([contrast, pre_contrast]) + log.debug(" residual contrast = {}".format(contrast)) + + # replace the corrected in-slice column pixels in the data_cor array + log.debug(" updating the trace pixels in the output") + output_data[idx, col] = fringe_sub[idx] + self.rfc_factors[idx, col] = rfc_factors[idx] + self.fit_mask[idx, col] = np.ones(1024)[idx] + self.weights_feat[idx, col] = weights_feat[idx] + self.weighted_pix_num[idx, col] = np.ones(1024)[idx] * (wpix_num / 1024) + self.rejected_fit[idx, col] = res_fringe_fit_flag[idx] + self.background_fit[idx, col] = bg_fit[idx] + self.knot_locations[:bgindx.shape[0], col] = bgindx + num_corrected = num_corrected + 1 + + except Exception as e: + log.warning(" Skipping col={} {} ".format(col, ss)) + log.warning(' %s' % (str(e))) del ss_data, ss_wmap, ss_weight # end on column @@ -458,7 +453,7 @@ def do_correction(self): # add the statistics to the Table stat_table.add_row((ss, mean, median, stddev, fmax, pmean, pmedian, pstddev, pfmax)) - del slice_x_ranges, all_slice_masks, slices_in_band, wave_map # end of channel + del slice_x_ranges, all_slice_masks, slices_in_band # end of channel log.info('Number of columns corrected for channel {}'.format(num_corrected)) log.info("Processing complete") @@ -489,9 +484,11 @@ def do_correction(self): basepath=self.input_model.meta.filename, suffix='fit_results', ext='.fits') log.info('Saving intermediate fit results output {}'.format(fit_results_name)) - h = fits.open(self.input_model.meta.filename) - hdr = h[0].header - h.close() + + # Get a primary header from the input model + hdul = fits_support.to_fits(self.input_model._instance, self.input_model._schema) + hdr = hdul[0].header + hdul.close() hdu0 = fits.PrimaryHDU(header=hdr) hdu1 = fits.ImageHDU(self.rfc_factors, name='RFC_FACTORS') @@ -536,6 +533,20 @@ def calc_weights(self): weights[np.isnan(weights)] = 0 return weights + def _get_wave_map(self): + """ + Get a wavelength map from the input WCS. + + Returns + ------- + ndarray + 2D map of wavelengths matching self.input.data. + """ + ysize = self.input_model.data.shape[0] + xsize = self.input_model.data.shape[1] + y, x = np.mgrid[:ysize, :xsize] + _, _, wave_map = self.input_model.meta.wcs(x, y) + return wave_map class ErrorNoFringeFlat(Exception): pass diff --git a/jwst/residual_fringe/residual_fringe_step.py b/jwst/residual_fringe/residual_fringe_step.py index fc73abedcf..843f2b93f9 100755 --- a/jwst/residual_fringe/residual_fringe_step.py +++ b/jwst/residual_fringe/residual_fringe_step.py @@ -1,9 +1,7 @@ -#! /usr/bin/env python from stdatamodels.jwst import datamodels -from ..stpipe import Step -from . import residual_fringe -from functools import partial +from jwst.stpipe import Step +from jwst.residual_fringe import residual_fringe __all__ = ["ResidualFringeStep"] @@ -27,7 +25,7 @@ class ResidualFringeStep(Step): ignore_region_min = list(default = None) ignore_region_max = list(default = None) suffix = string(default = 'residual_fringe') - """ # noqa: E501 + """ # noqa: E501 reference_file_types = ['fringefreq', 'regions'] @@ -58,7 +56,7 @@ def process(self, input): ignore_regions['num'] = min_num if min_num > 0: - self.log.info('Ignoring {} wavelength regions'.format(min_num)) + self.log.info(f'Ignoring {min_num} wavelength regions') self.ignore_regions = ignore_regions @@ -67,25 +65,7 @@ def process(self, input): if isinstance(input, datamodels.IFUImageModel): exptype = input.meta.exposure.type else: - raise TypeError("Failed to process file type {}".format(type(input))) - - # Setup output path naming if associations are involved. - asn_id = None - try: - asn_id = self.input.meta.asn_table.asn_id - except (AttributeError, KeyError): - pass - if asn_id is None: - asn_id = self.search_attr('asn_id') - if asn_id is not None: - _make_output_path = self.search_attr( - '_make_output_path', parent_first=True - ) - - self._make_output_path = partial( - _make_output_path, - asn_id=asn_id - ) + raise TypeError(f"Failed to process input type: {type(input)}") # Set up residual fringe correction parameters pars = { @@ -95,8 +75,8 @@ def process(self, input): } if exptype != 'MIR_MRS': - self.log(" Residual Fringe correction is only for MIRI MRS data") - self.log.error("Unsupported ", f"exposure type: {exptype}") + self.log.warning("Residual fringe correction is only for MIRI MRS data") + self.log.warning(f"Input is: {exptype}") input.meta.cal_step.residual_fringe = "SKIPPED" return input diff --git a/jwst/residual_fringe/tests/test_configuration.py b/jwst/residual_fringe/tests/test_configuration.py index 33ab11b589..6fc8cb4d1c 100644 --- a/jwst/residual_fringe/tests/test_configuration.py +++ b/jwst/residual_fringe/tests/test_configuration.py @@ -1,14 +1,14 @@ -""" -Unit test for Residual Fringe Correction for testing interface -""" +"""Unit tests for Residual Fringe Correction step interface.""" + +import logging import pytest import numpy as np - from stdatamodels.jwst import datamodels from jwst.residual_fringe import ResidualFringeStep from jwst.residual_fringe import residual_fringe +from jwst.tests.helpers import LogWatcher @pytest.fixture(scope='function') @@ -25,9 +25,18 @@ def miri_image(): return image -def test_call_residual_fringe(tmp_cwd, miri_image): - """ test defaults of step are set up and user input are defined correctly """ +@pytest.fixture() +def step_log_watcher(monkeypatch): + # Set a log watcher to check for a log message at any level + # in the emicorr step + watcher = LogWatcher("") + logger = logging.getLogger("stpipe.ResidualFringeStep") + for level in ["debug", "info", "warning", "error"]: + monkeypatch.setattr(logger, level, watcher) + return watcher + +def test_bad_ignore_regions(tmp_cwd, miri_image): # testing the ignore_regions_min # There has to be an equal number of min and max ignore region values # --ignore_region_min="4.9," --ignore_region_max='5.5," @@ -42,6 +51,24 @@ def test_call_residual_fringe(tmp_cwd, miri_image): step.run(miri_image) +def test_ignore_regions(tmp_cwd, monkeypatch, miri_image, step_log_watcher): + # Set some reasonable wavelength regions - these should be read in properly + step_log_watcher.message = "Ignoring 2 wavelength regions" + + step = ResidualFringeStep() + step.ignore_region_min = [4.9, 5.7] + step.ignore_region_max = [5.6, 6.5] + step.skip = False + + # monkeypatch the reference file retrieval so step aborts but does + # not error out for this incomplete input + monkeypatch.setattr(step, 'get_reference_file', lambda *args: 'N/A') + + # check for ignore regions log message + step.run(miri_image) + step_log_watcher.assert_seen() + + def test_fringe_flat_applied(tmp_cwd, miri_image): miri_image.meta.cal_step.fringe = 'SKIPPED' @@ -62,3 +89,19 @@ def test_fringe_flat_applied(tmp_cwd, miri_image): with pytest.raises(residual_fringe.ErrorNoFringeFlat): rfc.do_correction() + + +def test_rf_step_wrong_input_type(): + model = datamodels.ImageModel() + with pytest.raises(TypeError, match="Failed to process input type"): + ResidualFringeStep.call(model, skip=False) + + +def test_rf_step_wrong_exptype(miri_image, step_log_watcher): + model = miri_image + model.meta.exposure.type = 'NRS_IFU' + + step_log_watcher.message = "only for MIRI MRS" + result = ResidualFringeStep.call(model, skip=False) + assert result.meta.cal_step.residual_fringe == 'SKIPPED' + step_log_watcher.assert_seen() diff --git a/jwst/residual_fringe/tests/test_residual_fringe.py b/jwst/residual_fringe/tests/test_residual_fringe.py new file mode 100644 index 0000000000..9ba1937456 --- /dev/null +++ b/jwst/residual_fringe/tests/test_residual_fringe.py @@ -0,0 +1,286 @@ +import logging + +import numpy as np +import pytest +from stdatamodels.jwst import datamodels + +from jwst.residual_fringe.residual_fringe import utils, ResidualFringeCorrection +from jwst.residual_fringe.residual_fringe_step import ResidualFringeStep +from jwst.residual_fringe.utils import fit_residual_fringes_1d as rf1d +from jwst.tests.helpers import LogWatcher + + +@pytest.fixture() +def wave(): + x1, x2 = 10.0, 11.75 + wave_array = np.linspace(x1, x2, 1024) + return wave_array + + +@pytest.fixture() +def linear_spectrum(wave): + """Mock a spectrum with a linear continuum.""" + # linear flux signal between min and max wavelengths + x1, x2 = wave.min(), wave.max() + y1, y2 = 0.5, 0.4 + flux = (y2 - y1) / (x2 - x1) * (wave - x1) + y1 + + return wave, flux + + +@pytest.fixture() +def fringed_spectrum(linear_spectrum): + """Mock a spectrum with periodic fringe signature.""" + wave, flux = linear_spectrum + + # add a sinusoid signal on top of the linear flux + amp = 0.01 + period = 0.04 + fringe = amp * np.sin(2 * np.pi * wave / period) + return wave, flux + fringe + + +@pytest.fixture() +def miri_mrs_model_linear(monkeypatch, linear_spectrum): + shape = (1024, 10) + model = datamodels.IFUImageModel(shape) + model.meta.instrument.name = 'MIRI' + model.meta.instrument.detector = 'MIRIFUSHORT' + model.meta.instrument.channel = '12' + model.meta.instrument.band = 'SHORT' + model.meta.exposure.type = 'MIR_MRS' + model.meta.observation.date = '2022-05-01' + model.meta.observation.time = '01:01:01' + model.meta.cal_step.fringe = 'COMPLETE' + + wave, flux = linear_spectrum + model.data[:, :] = flux[:, None] + model.wavelength[:, :] = wave[:, None] + model.err = model.data * 0.01 + + return model + + +@pytest.fixture() +def miri_mrs_model_with_fringe(miri_mrs_model_linear, fringed_spectrum): + wave, flux = fringed_spectrum + model_copy = miri_mrs_model_linear.copy() + model_copy.data[:, :] = flux[:, None] + return model_copy + + +@pytest.fixture() +def mock_wavemap(monkeypatch, wave): + # mock the wavelength function to avoid making a full test WCS + def return_wavelength(*args): + wavemap = np.zeros((1024, 10)) + wavemap[:, :] = wave[:, None] + return wavemap + + monkeypatch.setattr(ResidualFringeCorrection, '_get_wave_map', return_wavelength) + + +@pytest.fixture() +def mock_wavemap_with_nans(monkeypatch, wave): + # mock the wavelength function to avoid making a full test WCS + def return_wavelength(*args): + wavemap = np.zeros((1024, 10)) + wavemap[:, :] = wave[:, None] + + # add some scattered NaN values + wavemap[::20, ::2] = np.nan + return wavemap + + monkeypatch.setattr(ResidualFringeCorrection, '_get_wave_map', return_wavelength) + + +@pytest.fixture() +def mock_slice_info_short(monkeypatch): + # mock a single slice to fit matching the test data, for testing speed + def one_slice(*args): + slices_in_band = [101] + xrange_channel = np.array([[0, 10]]) + slice_x_ranges = np.array([[101, 0, 10]]) + all_slice_masks = np.ones((1, 1024, 10)) + return slices_in_band, xrange_channel, slice_x_ranges, all_slice_masks + + monkeypatch.setattr(utils, 'slice_info', one_slice) + + +@pytest.fixture() +def mock_slice_info_long(monkeypatch): + # mock a single slice to fit matching the test data, for testing speed + def one_slice(*args): + slices_in_band = [301] + xrange_channel = np.array([[0, 10]]) + slice_x_ranges = np.array([[301, 0, 10]]) + all_slice_masks = np.ones((1, 1024, 10)) + return slices_in_band, xrange_channel, slice_x_ranges, all_slice_masks + + monkeypatch.setattr(utils, 'slice_info', one_slice) + + +@pytest.fixture() +def module_log_watcher(monkeypatch): + # Set a log watcher to check for a log message at any level + # in the emicorr module + watcher = LogWatcher("") + logger = logging.getLogger("jwst.residual_fringe.residual_fringe") + for level in ["debug", "info", "warning", "error"]: + monkeypatch.setattr(logger, level, watcher) + return watcher + + +def test_rf1d(linear_spectrum, fringed_spectrum): + """ + Test the performance of the 1d residual defringe routine. + + Input synthetic data mimics a Ch2C spectrum taken from observations + of bright star 16 CygB. + """ + wave, flux = fringed_spectrum + outflux = rf1d(flux, wave, channel=2) + + # defringing won't remove the pure sinusoidal fringe completely, but + # it should be reasonably close to linear and significantly better + # than no correction + expected_flux = linear_spectrum[1] + + # corrected output has small diffs from linear on average + # (edge effects might be larger) + relative_diff_output = np.abs(outflux - expected_flux)/expected_flux + assert np.nanmean(relative_diff_output) < 0.005 + + # input diffs from linear are much bigger + relative_diff_input = np.abs(flux - expected_flux)/expected_flux + assert np.nanmean(relative_diff_input) > 0.01 + + +def test_get_wavemap(): + # Test the _get_wavemap function directly, since + # all full calls to the correction method mock it + model = datamodels.IFUImageModel() + + # Mock a WCS that returns 1 for wavelengths + def return_ones(x, y): + return None, None, np.ones(x.shape) + model.meta.wcs = return_ones + + rf = ResidualFringeCorrection(model, "N/A", "N/A", None) + wavemap = rf._get_wave_map() + assert wavemap.shape == model.data.shape + assert np.all(wavemap == 1.0) + + +@pytest.mark.parametrize('band', ['SHORT', 'MEDIUM', 'LONG']) +def test_rf_step_short(miri_mrs_model_linear, miri_mrs_model_with_fringe, + mock_slice_info_short, mock_wavemap, band): + model = miri_mrs_model_with_fringe + model.meta.instrument.band = band + result = ResidualFringeStep.call(model, skip=False) + + assert result.meta.cal_step.residual_fringe == 'COMPLETE' + + # output should be closer to a linear spectrum than input, + # correction will not be precise + expected = miri_mrs_model_linear.data + relative_diff_input = np.abs(model.data - expected) / expected + relative_diff_output = np.abs(result.data - expected) / expected + assert np.nanmean(relative_diff_output) < np.nanmean(relative_diff_input) + + +@pytest.mark.parametrize('band', ['SHORT', 'MEDIUM', 'LONG']) +def test_rf_step_long(miri_mrs_model_with_fringe, mock_slice_info_long, mock_wavemap, + band, module_log_watcher): + model = miri_mrs_model_with_fringe + model.meta.instrument.detector = 'MIRIFULONG' + model.meta.instrument.channel = '34' + model.meta.instrument.band = band + + # Synthetic input data is reasonable for MIRIFUSHORT, but is expected + # to fail with a warning when treated as MIRIFULONG. + module_log_watcher.message = "Skipping col" + result = ResidualFringeStep.call(model, skip=False) + module_log_watcher.assert_seen() + + # Output data should be identical to input, although step is complete + assert result.meta.cal_step.residual_fringe == 'COMPLETE' + assert np.allclose(model.data, result.data) + + +def test_rf_step_nans_in_wavelength( + miri_mrs_model_linear, miri_mrs_model_with_fringe, + mock_slice_info_short, mock_wavemap_with_nans): + model = miri_mrs_model_with_fringe + + # wavelength array has some scattered NaNs: + # they should be interpolated over and correction should succeed + result = ResidualFringeStep.call(model, skip=False) + + # output should be closer to a linear spectrum than input, + # correction will not be precise + assert result.meta.cal_step.residual_fringe == 'COMPLETE' + expected = miri_mrs_model_linear.data + relative_diff_input = np.abs(model.data - expected) / expected + relative_diff_output = np.abs(result.data - expected) / expected + assert np.nanmean(relative_diff_output) < np.nanmean(relative_diff_input) + + + +def test_rf_step_save_intermediate(tmp_path, miri_mrs_model_with_fringe, + mock_slice_info_short, mock_wavemap): + model = miri_mrs_model_with_fringe + model.meta.filename = 'test.fits' + ResidualFringeStep.call(model, skip=False, output_dir=str(tmp_path), + save_results=True, save_intermediate_results=True) + + output_files = ['test_residual_fringe.fits', + 'test_stat_table.ecsv', + 'test_out_table.ecsv', + 'test_fit_results.fits'] + for output_file in output_files: + assert (tmp_path / output_file).exists() + + +def test_rf_step_ignore_regions(miri_mrs_model_with_fringe, mock_slice_info_short, mock_wavemap): + model = miri_mrs_model_with_fringe + + # ignore all the data + ignore_region_min = [model.wavelength.min()] + ignore_region_max = [model.wavelength.max()] + result = ResidualFringeStep.call( + model, skip=False, ignore_region_min=ignore_region_min, + ignore_region_max=ignore_region_max) + + # output should be the same as input + assert np.allclose(model.data, result.data) + + +def test_rf_step_low_snr(miri_mrs_model_with_fringe, mock_slice_info_short, mock_wavemap, + module_log_watcher): + model = miri_mrs_model_with_fringe + + # set all the data to a very small value so SNR is too low to fit + model.data[:] = 1e-6 + + module_log_watcher.message = "SNR too low" + result = ResidualFringeStep.call(model, skip=False) + module_log_watcher.assert_seen() + + # output should be the same as input + assert np.allclose(model.data, result.data) + + +def test_rf_step_weights_gap(miri_mrs_model_with_fringe, mock_slice_info_short, mock_wavemap): + model = miri_mrs_model_with_fringe + + # set some big emission lines in rows 1 and -2: + # this triggers a weighting edge case that ignores the first + # and last rows + model.data[1, :] = 10 + model.data[-2, :] = 10 + + result = ResidualFringeStep.call(model, skip=False) + + # Fit should complete + assert not np.allclose(result.data, model.data) diff --git a/jwst/residual_fringe/utils.py b/jwst/residual_fringe/utils.py index 3baef3d1bf..95dcaeca9b 100644 --- a/jwst/residual_fringe/utils.py +++ b/jwst/residual_fringe/utils.py @@ -191,45 +191,6 @@ def fit_envelope(wavenum, signal): return pcl(wavenum), l_x, l_y, pcu(wavenum), u_x, u_y -def find_lines_resfringe(signal, max_amp): - """ - *** Replaced with find_lines below. This version does not include some of the - feature finding functionality*** - - Take signal and max amp array, determine location of spectral - features with amplitudes greater than max amp - - :param signal: - :param max_amp: - :return: - """ - - r_x = np.arange(signal.shape[0] - 1) - - # setup the output arrays - signal_check = signal.copy() - weights_factors = np.ones(signal.shape[0]) - - # Detect peaks - - u_y, u_x, l_y, l_x = [], [], [], [] - - for x in r_x: - if ((np.sign(signal_check[x] - signal_check[x - 1]) == 1) and - (np.sign(signal_check[x] - signal_check[x + 1]) == 1)): - u_y.append(signal_check[x]) - u_x.append(x) - - if ((np.sign(signal_check[x] - signal_check[x - 1]) == -1) and - (np.sign(signal_check[x] - signal_check[x + 1]) == -1)): - l_y.append(signal[x]) - l_x.append(x) - - log.debug("find_lines: Found {} peaks {} troughs".format(len(u_x), len(l_x))) - weights_factors[signal_check > max_amp] = 0 - return weights_factors - - def find_lines(signal, max_amp): """ Take signal and max amp array, determine location of spectral diff --git a/pyproject.toml b/pyproject.toml index 1ff2d039fb..2b9c93d4db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ classifiers = [ dependencies = [ "asdf>=3.3,<5", "astropy>=6.1", - "BayesicFitting>=3.0.1", + "BayesicFitting>=3.2.2", "crds>=12.0.3", "drizzle>=2.0.1", # "gwcs>=0.22.0,<0.23.0", From 676e34d83ff53c95c434f040ecb6d0e5c2061537 Mon Sep 17 00:00:00 2001 From: Melanie Clarke Date: Fri, 28 Feb 2025 16:50:08 -0500 Subject: [PATCH 2/7] Remove unused contrast/quality code --- jwst/residual_fringe/residual_fringe.py | 81 ++++------------------- jwst/residual_fringe/utils.py | 85 +------------------------ 2 files changed, 13 insertions(+), 153 deletions(-) diff --git a/jwst/residual_fringe/residual_fringe.py b/jwst/residual_fringe/residual_fringe.py index c12e93ceb0..b7612594f6 100644 --- a/jwst/residual_fringe/residual_fringe.py +++ b/jwst/residual_fringe/residual_fringe.py @@ -147,20 +147,13 @@ def do_correction(self): self.knot_locations = np.full_like(self.input_model.data, np.nan) allregions.close() - # intermediate output product - Table + # intermediate output product - Tables stat_table = Table(names=('Slice', 'mean', 'median', 'stddev', 'max', 'pmean', 'pmedian', 'pstddev', 'pmax'), dtype=('i4', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8')) - out_table = Table(names=('Slice', 'col', 'fringe', 'sn', 'pre_contrast', 'post_contrast', - 'periodogram_res', 'opt_fringes', 'peak_freq', 'freq_min', 'freq_max'), - dtype=('i4', 'i4', 'i4', 'f8', 'f8', 'f8', - 'f8', 'f8', 'f8', 'f8', 'f8')) - - # intermediate output - n_wav_samples = 1024 - qual_table = Table(names=('column', 'quality'), - units=('', ''), - dtype=('i4', '(3,{})f8'.format(n_wav_samples))) + out_table = Table(names=('Slice', 'col', 'fringe', 'sn', 'periodogram_res', + 'opt_fringes', 'peak_freq', 'freq_min', 'freq_max'), + dtype=('i4', 'i4', 'i4', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8')) wave_map = self._get_wave_map() for c in self.channels: @@ -182,15 +175,12 @@ def do_correction(self): for n, ss in enumerate(slices_in_band): log.info(" Processing slice {} =================================".format(ss)) log.debug(" X ranges of slice {} {}".format(slice_x_ranges[n, 1], slice_x_ranges[n, 2])) - # initialise the list to store correction quality numbers for slice - correction_quality = [] # use the mask to set all out-of-slice pixels to 0 in wmap and data # set out-of-slice pixels to 0 in arrays ss_data = all_slice_masks[n] * output_data.copy() ss_wmap = all_slice_masks[n] * wave_map ss_weight = all_slice_masks[n] * self.input_weights.copy() - # ss_mask = all_slice_masks[n] # get the freq_table info for this slice this_row = np.where(self.freq_table['slice'] == float(ss))[0][0] @@ -292,11 +282,10 @@ def do_correction(self): weight_factors = utils.find_lines(mod, col_max_amp * 2) weights_feat *= weight_factors - # iterate over the fringe components to fit, initialize pre-contrast, other output arrays + # iterate over the fringe components to fit, initialize other output arrays # in case fit fails proc_data = col_data.copy() proc_factors = np.ones(col_data.shape) - pre_contrast = 0.0 bg_fit = col_data.copy() res_fringe_fit_flag = np.zeros(col_data.shape) wpix_num = 1024 @@ -339,21 +328,7 @@ def do_correction(self): where=bg_fit != 0) res_fringes = np.subtract(res_fringes, 1, where=res_fringes != 0) res_fringes *= np.where(col_weight > 1e-07, 1, 1e-08) - # get the pre-correction contrast using fringe component 1 - # TODO: REMOVE CONTRAST CHECKS - # set dummy values for contrast check parameters until removed - pre_contrast = 0.0 - quality = np.array([np.zeros(col_data.shape), np.zeros(col_data.shape), - np.zeros(col_data.shape)]) - # if fn == 0: - # pre_contrast, quality = utils.fit_quality(col_wnum, - # res_fringes, - # weights_feat, - # ffreq[0], - # dffreq[0]) - # - # log.debug(" pre-correction contrast = {}".format(pre_contrast)) - # + # fit the residual fringes log.debug(" set up bayes ") res_fringe_fit, wpix_num, opt_nfringe, peak_freq, freq_min, freq_max = \ @@ -385,45 +360,26 @@ def do_correction(self): proc_data = np.nan_to_num(proc_data, posinf=1e-08, neginf=1e-08) proc_data[proc_data < 0] = 1e-08 - out_table.add_row((ss, col, fn, snr2, pre_contrast, pre_contrast, pgram_res[fn], + out_table.add_row((ss, col, fn, snr2, pgram_res[fn], opt_nfringe, peak_freq, freq_min, freq_max)) # define fringe sub after all fringe components corrections fringe_sub = proc_data.copy() rfc_factors = proc_factors.copy() - # get the new fringe contrast - log.debug(" analysing fit quality") - + # get the residual fringes as fraction of signal pbg_fit, pbgindx = utils.fit_1d_background_complex(fringe_sub, weights_feat, col_wnum, ffreq=ffreq[0], channel=c) - - # get the residual fringes as fraction of signal fit_res = np.divide(fringe_sub, pbg_fit, out=np.zeros_like(fringe_sub), where=pbg_fit != 0) fit_res = np.subtract(fit_res, 1, where=fit_res != 0) fit_res *= np.where(col_weight > 1e-07, 1, 1e-08) - # TODO: REMOVE CONTRAST CHECKS - # set dummy values for contrast check parameters until removed - contrast = 0.0 - quality = np.array([np.zeros(col_data.shape), np.zeros(col_data.shape), np.zeros(col_data.shape)]) - # contrast, quality = utils.fit_quality(col_wnum, - # fit_res, - # weights_feat, - # ffreq, - # dffreq, - # save_results=self.save_intermediate_results) - - out_table.add_row((ss, col, fn, snr2, pre_contrast, contrast, pgram_res[0], + out_table.add_row((ss, col, fn, snr2, pgram_res[0], opt_nfringe, peak_freq, freq_min, freq_max)) - qual_table.add_row((col, quality)) - correction_quality.append([contrast, pre_contrast]) - log.debug(" residual contrast = {}".format(contrast)) - # replace the corrected in-slice column pixels in the data_cor array log.debug(" updating the trace pixels in the output") output_data[idx, col] = fringe_sub[idx] @@ -442,17 +398,6 @@ def do_correction(self): del ss_data, ss_wmap, ss_weight # end on column - # asses the fit quality statistics and set up data to make plot outside of step - log.debug(" analysing fit statistics") - if len(correction_quality) > 0: - contrasts = np.asarray(correction_quality)[:, 0] - pre_contrasts = np.asarray(correction_quality)[:, 1] - mean, median, stddev, fmax = utils.fit_quality_stats(contrasts) - pmean, pmedian, pstddev, pfmax = utils.fit_quality_stats(pre_contrasts) - - # add the statistics to the Table - stat_table.add_row((ss, mean, median, stddev, fmax, pmean, pmedian, pstddev, pfmax)) - del slice_x_ranges, all_slice_masks, slices_in_band # end of channel log.info('Number of columns corrected for channel {}'.format(num_corrected)) log.info("Processing complete") @@ -475,15 +420,13 @@ def do_correction(self): out_table_name = self.make_output_path( basepath=self.input_model.meta.filename, suffix='out_table', ext='.ecsv') - log.info(' Saving intermediate Output table {}'.format(out_table_name)) + log.info(f'Saving intermediate output table {out_table_name}') ascii.write(out_table, out_table_name, format='ecsv', fast_writer=False, overwrite=True) - t = fits.BinTableHDU(data=qual_table, name='FIT_QUAL') - fit_results_name = self.make_output_path( basepath=self.input_model.meta.filename, suffix='fit_results', ext='.fits') - log.info('Saving intermediate fit results output {}'.format(fit_results_name)) + log.info(f'Saving intermediate fit results output {fit_results_name}') # Get a primary header from the input model hdul = fits_support.to_fits(self.input_model._instance, self.input_model._schema) @@ -498,7 +441,7 @@ def do_correction(self): hdu5 = fits.ImageHDU(self.background_fit, name='BACKGROUND_FIT') hdu6 = fits.ImageHDU(self.knot_locations, name='KNOT_LOCATIONS') - hdu = fits.HDUList([hdu0, hdu1, hdu2, hdu3, hdu4, hdu5, hdu6, t]) + hdu = fits.HDUList([hdu0, hdu1, hdu2, hdu3, hdu4, hdu5, hdu6]) hdu.writeto(fit_results_name, overwrite=True) hdu.close() diff --git a/jwst/residual_fringe/utils.py b/jwst/residual_fringe/utils.py index 95dcaeca9b..b7205b89c5 100644 --- a/jwst/residual_fringe/utils.py +++ b/jwst/residual_fringe/utils.py @@ -15,7 +15,7 @@ import logging log = logging.getLogger(__name__) log.setLevel(logging.INFO) -# + # hard coded parameters, have been selected based on testing but can be changed NUM_KNOTS = 80 # number of knots for bkg model if no other info provided @@ -107,22 +107,6 @@ def fill_wavenumbers(wnums): return wnums_filled -def fit_quality_stats(stats): - """Get simple statistics for the fits - - :Parameters: - - fit_stats: np.array - the fringe contrast per fit - - :Returns: - - median, stddev, max of stat numpy array - - """ - return np.mean(stats), np.median(stats), np.std(stats), np.amax(stats) - - def multi_sine(n): """ Return a model composed of n sines @@ -455,73 +439,6 @@ def fit_1d_background_complex(flux, weights, wavenum, order=2, ffreq=None, chann return bg_fit, bgindx -def fit_quality(wavenum, res_fringes, weights, ffreq, dffreq, save_results=False): - """Determine the post correction fringe residual - - Fit a single sine model to the corrected array to get the post correction fringe residual - - :Parameters: - - wavenum: numpy array - the wavenum array - - res_fringes: numpy array - the residual fringe fit data - - weights: numpy array - the weights array - - ffreq: float, required - the central scan frequency - - dffreq: float, required - the one-sided interval of scan frequencies - - :Returns: - - fringe_res_amp: numpy array - the post correction fringe residual amplitude - - """ - ffreq, dffreq = 2.8, 0.2 - - # fit the residual with a single sine model - # use a Lomb-Scargle periodogram to get PSD and identify the strongest frequency - freq = np.linspace(ffreq - dffreq, ffreq + dffreq, 100) - - # handle out of slice pixels - res_fringes = np.nan_to_num(res_fringes) - res_fringe_scan = res_fringes[np.where(weights > 1e-05)] - wavenum_scan = wavenum[np.where(weights > 1e-05)] - pgram = LombScargle(wavenum_scan[::-1], res_fringe_scan[::-1]).power(1 / freq) - peak = np.argmax(pgram) - peak_freq = freq[peak] - log.debug("fit_quality: strongest frequency is {}".format(peak_freq)) - - # create the model - mdl = SineModel(pars=[0.1, 0.1], fixed={0: 1 / peak_freq}) - - fitter = LevenbergMarquardtFitter(wavenum[10:-10], mdl) - ftr = RobustShell(fitter, domain=10) - - fr_par = ftr.fit(res_fringes[10:-10], weights=weights[10:-10]) - log.debug("fit_quality: best fit pars: {}".format(fr_par)) - - if np.abs(fr_par[0]) > np.abs(fr_par[1]): - contrast = np.abs(round(fr_par[0] * 2, 3)) - else: - contrast = np.abs(round(fr_par[1] * 2, 3)) - - # make data to return for fit quality - quality = None - if save_results: - best_mdl = SineModel(fixed={0: 1 / peak_freq, 1: fr_par[0], 2: fr_par[1]}) - fit = best_mdl.result(wavenum) - quality = np.array([(10000.0 / wavenum), res_fringes, fit]) - - return contrast, quality - - def new_fit_1d_fringes_bayes_evidence(res_fringes, weights, wavenum, ffreq, dffreq, min_nfringes, max_nfringes, pgram_res, col_snr2): From d455eb77823cde738df48ff4d0d05fad00a06019 Mon Sep 17 00:00:00 2001 From: Melanie Clarke Date: Fri, 28 Feb 2025 17:37:45 -0500 Subject: [PATCH 3/7] Test for slice info --- jwst/residual_fringe/tests/test_utils.py | 39 ++++++++++++++++++++++++ jwst/residual_fringe/utils.py | 15 ++------- 2 files changed, 42 insertions(+), 12 deletions(-) create mode 100644 jwst/residual_fringe/tests/test_utils.py diff --git a/jwst/residual_fringe/tests/test_utils.py b/jwst/residual_fringe/tests/test_utils.py new file mode 100644 index 0000000000..4eecbd7c63 --- /dev/null +++ b/jwst/residual_fringe/tests/test_utils.py @@ -0,0 +1,39 @@ +import numpy as np +import pytest + +from jwst.residual_fringe import utils + +@pytest.fixture() +def slice_map(): + """ + Make a slice map image containing 4 slices. + + There will be two channel 1 slices and two channel 2. + """ + shape = (20, 90) + map_image = np.zeros(shape) + map_image[:, 10:20] = 101 + map_image[:, 30:40] = 102 + map_image[:, 50:60] = 201 + map_image[:, 70:80] = 202 + return map_image + + +def test_slice_info_ch1(slice_map): + result = utils.slice_info(slice_map, 1) + slices_in_band, xrange_channel, slice_x_ranges, all_slice_masks = result + assert np.all(slices_in_band == [101, 102]) + assert np.all(xrange_channel == [9, 40]) + assert np.all(slice_x_ranges == [[101, 9, 20], [102, 29, 40]]) + assert np.sum(all_slice_masks[0]) == 200 + assert np.sum(all_slice_masks[1]) == 200 + + +def test_slice_info_ch2(slice_map): + result = utils.slice_info(slice_map, 2) + slices_in_band, xrange_channel, slice_x_ranges, all_slice_masks = result + assert np.all(slices_in_band == [201, 202]) + assert np.all(xrange_channel == [49, 80]) + assert np.all(slice_x_ranges == [[201, 49, 60], [202, 69, 80]]) + assert np.sum(all_slice_masks[0]) == 200 + assert np.sum(all_slice_masks[1]) == 200 diff --git a/jwst/residual_fringe/utils.py b/jwst/residual_fringe/utils.py index b7205b89c5..9a32c74784 100644 --- a/jwst/residual_fringe/utils.py +++ b/jwst/residual_fringe/utils.py @@ -16,18 +16,9 @@ log = logging.getLogger(__name__) log.setLevel(logging.INFO) - -# hard coded parameters, have been selected based on testing but can be changed -NUM_KNOTS = 80 # number of knots for bkg model if no other info provided - - -def find_nearest(array, value): - """ Utility function to find the index of pixel with value in 'array' nearest to 'value'. - - Used by det_pixel_trace function - """ - idx = (np.abs(array - value)).argmin() - return idx +# Number of knots for bkg model if no other info provided +# Hard coded parameter, has been selected based on testing but can be changed +NUM_KNOTS = 80 def slice_info(slice_map, c): From 2cb3b255ac8a7d0bc91165a7390d6909a3fb88d5 Mon Sep 17 00:00:00 2001 From: Melanie Clarke Date: Tue, 4 Mar 2025 13:43:08 -0500 Subject: [PATCH 4/7] Handle warnings --- jwst/regtest/test_miri_mrs_spec2.py | 2 +- jwst/residual_fringe/fitter.py | 9 ++++----- jwst/residual_fringe/residual_fringe.py | 9 +++++++-- jwst/residual_fringe/utils.py | 1 - 4 files changed, 12 insertions(+), 9 deletions(-) diff --git a/jwst/regtest/test_miri_mrs_spec2.py b/jwst/regtest/test_miri_mrs_spec2.py index e2d66d4738..518baceb07 100644 --- a/jwst/regtest/test_miri_mrs_spec2.py +++ b/jwst/regtest/test_miri_mrs_spec2.py @@ -105,7 +105,7 @@ def test_miri_mrs_wcs(run_spec2, fitsdiff_default_kwargs, rtdata_module): @pytest.mark.slow @pytest.mark.bigdata -@pytest.mark.parametrize('suffix', ['rf_residual_fringe', 'rf_s3d', 'rf_x1d', 'rf_cal']) +@pytest.mark.parametrize('suffix', ['residual_fringe', 'rf_s3d', 'rf_x1d', 'rf_cal']) def test_miri_mrs_spec2_with_rf(run_spec2_with_residual_fringe, fitsdiff_default_kwargs, suffix, rtdata_module): rtdata = rtdata_module diff --git a/jwst/residual_fringe/fitter.py b/jwst/residual_fringe/fitter.py index 1928d0b656..ca01cf4b78 100644 --- a/jwst/residual_fringe/fitter.py +++ b/jwst/residual_fringe/fitter.py @@ -1,6 +1,6 @@ +import warnings import numpy as np - import scipy.interpolate @@ -25,11 +25,10 @@ def chi_sq(spline, weights): nparams = len(knots) + (degree + 1) * 2 deg_of_freedom = np.sum(weights) - nparams - if deg_of_freedom <= 0: - raise RuntimeError("Degrees of freedom <= 0") - for _ in range(1000 * nparams): - scale = np.sqrt(chi / deg_of_freedom) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RuntimeWarning) + scale = np.sqrt(chi / deg_of_freedom) # Calculate new weights resid = (y - spline(x)) / (scale * domain) diff --git a/jwst/residual_fringe/residual_fringe.py b/jwst/residual_fringe/residual_fringe.py index b7612594f6..7b3dc7c8e0 100644 --- a/jwst/residual_fringe/residual_fringe.py +++ b/jwst/residual_fringe/residual_fringe.py @@ -1,6 +1,7 @@ """Apply residual fringe correction.""" import logging +import warnings from functools import partial import numpy as np @@ -208,7 +209,9 @@ def do_correction(self): test_flux = col_data[valid] test_flux[test_flux < 0] = 1e-08 # Transform wavelength in micron to wavenumber in cm^-1. - col_wnum = 10000.0 / col_wmap + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RuntimeWarning) + col_wnum = 10000.0 / col_wmap # use the error array to get col snr, used to remove noisey pixels col_snr = self.model.data.copy()[:, col] / self.model.err.copy()[:, col] @@ -468,7 +471,9 @@ def calc_weights(self): weights = np.zeros(self.input_model.data.shape) for c in np.arange(weights.shape[1]): flux_1d = self.input_model.data[:, c] - w = flux_1d / np.nanmean(flux_1d) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RuntimeWarning) + w = flux_1d / np.nanmean(flux_1d) weights[:, c] = w # replace infs and nans in weights with 0 diff --git a/jwst/residual_fringe/utils.py b/jwst/residual_fringe/utils.py index 9a32c74784..a580e3bd41 100644 --- a/jwst/residual_fringe/utils.py +++ b/jwst/residual_fringe/utils.py @@ -198,7 +198,6 @@ def find_lines(signal, max_amp): for n, amp in enumerate(u_y): max_amp_val = max_amp[u_x[n]] - log.debug("find_lines: check if peak above max amp") if amp > max_amp_val: # peak in x From a7f9c614af8930a20ce5a7640fefc467bf8a6232 Mon Sep 17 00:00:00 2001 From: Melanie Clarke Date: Tue, 4 Mar 2025 15:17:32 -0500 Subject: [PATCH 5/7] Code style fixes and cleanup --- .pre-commit-config.yaml | 1 - .ruff.toml | 2 - jwst/residual_fringe/__init__.py | 4 +- jwst/residual_fringe/fitter.py | 30 +- jwst/residual_fringe/residual_fringe.py | 709 +++++++++-------- jwst/residual_fringe/residual_fringe_step.py | 105 +-- jwst/residual_fringe/tests/__init__.py | 0 .../tests/test_background_fit.py | 35 +- .../tests/test_configuration.py | 52 +- .../tests/test_residual_fringe.py | 103 +-- jwst/residual_fringe/tests/test_utils.py | 9 +- jwst/residual_fringe/utils.py | 717 ++++++++++-------- 12 files changed, 997 insertions(+), 770 deletions(-) create mode 100644 jwst/residual_fringe/tests/__init__.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a56d755dc8..b90da8965f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -58,7 +58,6 @@ repos: jwst/refpix/.* | jwst/resample/.* | jwst/reset/.* | - jwst/residual_fringe/.* | jwst/rscd/.* | jwst/saturation/.* | jwst/scripts/.* | diff --git a/.ruff.toml b/.ruff.toml index 86f9e8e55e..32cb8f8b00 100644 --- a/.ruff.toml +++ b/.ruff.toml @@ -40,7 +40,6 @@ exclude = [ "jwst/refpix/**.py", "jwst/resample/**.py", "jwst/reset/**.py", - "jwst/residual_fringe/**.py", "jwst/rscd/**.py", "jwst/saturation/**.py", "jwst/scripts/**.py", @@ -140,7 +139,6 @@ ignore-fully-untyped = true # Turn off annotation checking for fully untyped co "jwst/refpix/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"] "jwst/resample/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"] "jwst/reset/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"] -"jwst/residual_fringe/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"] "jwst/rscd/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"] "jwst/saturation/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"] "jwst/scripts/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"] diff --git a/jwst/residual_fringe/__init__.py b/jwst/residual_fringe/__init__.py index 8b01f646b0..7f95b44f86 100644 --- a/jwst/residual_fringe/__init__.py +++ b/jwst/residual_fringe/__init__.py @@ -1,3 +1,5 @@ +"""Correct residual fringes in MIRI MRS data.""" + from .residual_fringe_step import ResidualFringeStep -__all__ = ['ResidualFringeStep'] +__all__ = ["ResidualFringeStep"] diff --git a/jwst/residual_fringe/fitter.py b/jwst/residual_fringe/fitter.py index ca01cf4b78..adf8927dd6 100644 --- a/jwst/residual_fringe/fitter.py +++ b/jwst/residual_fringe/fitter.py @@ -9,6 +9,34 @@ def _lsq_spline(x, y, weights, knots, degree): def spline_fitter(x, y, weights, knots, degree, reject_outliers=False, domain=10, tolerance=0.0001): + """ + Fit a spline function to 1D data. + + Parameters + ---------- + x : ndarray + Independent variable. Must be increasing. + y : ndarray + Dependent variable, matching dimensions of `x`. + weights : ndarray + Weights for spline fitting. Must be positive. + knots : ndarray + Interior knots for the spline. + degree : int + Degree of the spline to fit, >= 1 and <= 5. + reject_outliers : bool, optional + If True, iteratively fit the data with outlier rejection. + domain : int, optional + Factor controlling the outlier threshold when `reject_outliers` + is True. + tolerance : float, optional + Fit convergence tolerance when reject_outliers` is True. + + Returns + ------- + callable + The spline function fit to the data. + """ if not reject_outliers: return _lsq_spline(x, y, weights, knots, degree) @@ -32,7 +60,7 @@ def chi_sq(spline, weights): # Calculate new weights resid = (y - spline(x)) / (scale * domain) - new_w = (np.where(resid**2 <= 1, 1 - resid**2, 0.))**2 * weights + new_w = (np.where(resid**2 <= 1, 1 - resid**2, 0.0)) ** 2 * weights # Fit new model and find chi spline = _lsq_spline(x, y, new_w, knots, degree) diff --git a/jwst/residual_fringe/residual_fringe.py b/jwst/residual_fringe/residual_fringe.py index 7b3dc7c8e0..9d6268a2f8 100644 --- a/jwst/residual_fringe/residual_fringe.py +++ b/jwst/residual_fringe/residual_fringe.py @@ -2,11 +2,11 @@ import logging import warnings -from functools import partial import numpy as np from astropy.table import Table -from astropy.io import ascii, fits +from astropy.io import fits +from astropy.io import ascii as astropy_ascii from stdatamodels import fits_support from stdatamodels.jwst import datamodels @@ -18,17 +18,45 @@ class ResidualFringeCorrection: + """Calculate and apply correction for residual fringes.""" + + def __init__( + self, + input_model, + residual_fringe_reference_file, + regions_reference_file, + ignore_regions, + save_intermediate_results=False, + transmission_level=80, + make_output_path=None, + ): + """ + Manage residual fringe correction. - def __init__(self, - input_model, - residual_fringe_reference_file, - regions_reference_file, - ignore_regions, - save_intermediate_results=False, - transmission_level=80, - make_output_path=None): - - self.input_model = input_model.copy() + Parameters + ---------- + input_model : IFUImageModel + Input data to correct. + residual_fringe_reference_file : str + Path to FRINGEFREQ reference file. + regions_reference_file : str + Path to REGIONS reference file. + ignore_regions : dict + Wavelength regions to ignore. Keys are "num", "min", and "max. + Values are the number of regions specified (int), the list + of minimum wavelength values, and the list of maximum wavelength + values. Minimum and maximum lists must match. + save_intermediate_results : bool, optional + If True, intermediate files are saved to disk. + transmission_level : int, optional + The transmission level used to extract the appropriate region + definitions from the REGIONS reference file. + make_output_path : callable or None, optional + If provided, is used to create the output file names when + `save_intermediate_results` is True. If None, filenames + are created with the default `Step.make_output_path` function. + """ + self.input_model = input_model self.model = input_model.copy() self.residual_fringe_reference_file = residual_fringe_reference_file self.regions_reference_file = regions_reference_file @@ -38,7 +66,7 @@ def __init__(self, # define how filenames are created if make_output_path is None: - self.make_output_path = partial(Step._make_output_path, None) + self.make_output_path = Step().make_output_path else: self.make_output_path = make_output_path @@ -61,39 +89,32 @@ def __init__(self, self.diagnostic_mode = True def do_correction(self): - """ - Short Summary - ------------- - Residual Fringe-correct a JWST data model using a residual fringe model - - Parameters - ---------- - input_model: JWST data model - input science data model to be residual fringe-corrected + Apply residual fringe correction. - residual_fringe_model: JWST data model - data model containing residual fringe correction + Correction is applied to a model copied from self.input_model. Returns ------- - output_model: JWST data model - residual fringe-corrected science data model - + output_model : IFUImageModel + Datamodel with correction applied. """ - # Check that the fringe flat has been applied - if self.input_model.meta.cal_step.fringe != 'COMPLETE': - raise ErrorNoFringeFlat("The fringe flat step has not been run on file %s", - self.input_model.meta.filename) + if self.input_model.meta.cal_step.fringe != "COMPLETE": + raise NoFringeFlatError( + f"The fringe flat step has not been run on file {self.input_model.meta.filename}" + ) # Remove any NaN values and flagged DO_NOT_USE pixels from the data prior to processing # Set them to 0 for the residual fringe routine # They will be re-added at the end output_data = self.model.data.copy() - DO_NOT_USE = datamodels.dqflags.pixel["DO_NOT_USE"] - nanval_indx = np.where(np.logical_or(np.bitwise_and(self.model.dq, DO_NOT_USE).astype(bool), - ~np.isfinite(output_data))) + dnu = datamodels.dqflags.pixel["DO_NOT_USE"] + nanval_indx = np.where( + np.logical_or( + np.bitwise_and(self.model.dq, dnu).astype(bool), ~np.isfinite(output_data) + ) + ) output_data[nanval_indx] = 0 # normalise the output_data to remove units @@ -106,9 +127,9 @@ def do_correction(self): # read in the band band = self.input_model.meta.instrument.band.lower() - if band == 'short': + if band == "short": residual_fringe_table = residual_fringe_model.rfc_freq_short_table - elif band == 'medium': + elif band == "medium": residual_fringe_table = residual_fringe_model.rfc_freq_medium_table else: residual_fringe_table = residual_fringe_model.rfc_freq_long_table @@ -123,18 +144,18 @@ def do_correction(self): self.transmission_level = int(self.transmission_level / 10) slice_map = (allregions.regions)[self.transmission_level - 1, :, :].copy() - log.info(" Using {} throughput threshold.".format(self.transmission_level)) + log.info(f" Using {self.transmission_level} throughput threshold.") self.slice_map = slice_map # set up the channels for the detector detector = self.input_model.meta.instrument.detector.lower() - if 'short' in detector: + if "short" in detector: self.channels = [1, 2] - elif 'long' in detector: + elif "long" in detector: self.channels = [3, 4] - log.info("Detector {} {} ".format(detector, self.channels)) + log.info(f"Detector {detector} {self.channels} ") self.input_weights = self.calc_weights() self.weights_feat = self.input_weights.copy() @@ -149,264 +170,347 @@ def do_correction(self): allregions.close() # intermediate output product - Tables - stat_table = Table(names=('Slice', 'mean', 'median', 'stddev', 'max', 'pmean', 'pmedian', 'pstddev', 'pmax'), - dtype=('i4', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8')) - - out_table = Table(names=('Slice', 'col', 'fringe', 'sn', 'periodogram_res', - 'opt_fringes', 'peak_freq', 'freq_min', 'freq_max'), - dtype=('i4', 'i4', 'i4', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8')) + stat_table = Table( + names=( + "Slice", + "mean", + "median", + "stddev", + "max", + "pmean", + "pmedian", + "pstddev", + "pmax", + ), + dtype=("i4", "f8", "f8", "f8", "f8", "f8", "f8", "f8", "f8"), + ) + + out_table = Table( + names=( + "Slice", + "col", + "fringe", + "sn", + "periodogram_res", + "opt_fringes", + "peak_freq", + "freq_min", + "freq_max", + ), + dtype=("i4", "i4", "i4", "f8", "f8", "f8", "f8", "f8", "f8"), + ) wave_map = self._get_wave_map() for c in self.channels: num_corrected = 0 - log.info("Processing channel {}".format(c)) - (slices_in_band, xrange_channel, slice_x_ranges, all_slice_masks) = \ - utils.slice_info(slice_map, c) - - log.debug(' Slice Ranges {}'.format(slice_x_ranges)) - - # if the user wants to ignore some values use the wave_map array to set the corresponding - # weight values to 0 - if self.ignore_regions['num'] > 0: - for r in range(self.ignore_regions['num']): - min_wave = self.ignore_regions['min'][r] - max_wave = self.ignore_regions['max'][r] + log.info(f"Processing channel {c}") + (slices_in_channel, xrange_channel, slice_x_ranges, all_slice_masks) = utils.slice_info( + slice_map, c + ) + + log.debug(f" Slice Ranges {slice_x_ranges}") + + # if the user wants to ignore some values, use the wave_map + # array to set the corresponding weight values to 0 + if self.ignore_regions["num"] > 0: + for r in range(self.ignore_regions["num"]): + min_wave = self.ignore_regions["min"][r] + max_wave = self.ignore_regions["max"][r] self.input_weights[((wave_map > min_wave) & (wave_map < max_wave))] = 0 - for n, ss in enumerate(slices_in_band): - log.info(" Processing slice {} =================================".format(ss)) - log.debug(" X ranges of slice {} {}".format(slice_x_ranges[n, 1], slice_x_ranges[n, 2])) + for n, ss in enumerate(slices_in_channel): + log.info(f" Processing slice {ss} =================================") + log.debug(f" X ranges of slice {slice_x_ranges[n, 1]} {slice_x_ranges[n, 2]}") # use the mask to set all out-of-slice pixels to 0 in wmap and data # set out-of-slice pixels to 0 in arrays - ss_data = all_slice_masks[n] * output_data.copy() + ss_data = all_slice_masks[n] * output_data ss_wmap = all_slice_masks[n] * wave_map - ss_weight = all_slice_masks[n] * self.input_weights.copy() + ss_weight = all_slice_masks[n] * self.input_weights # get the freq_table info for this slice - this_row = np.where(self.freq_table['slice'] == float(ss))[0][0] - log.debug('Row in reference file for slice {}'.format(this_row)) + this_row = np.where(self.freq_table["slice"] == float(ss))[0][0] + log.debug(f"Row in reference file for slice {this_row}") - slice_row = self.freq_table[(self.freq_table['slice'] == float(ss))] - ffreq = slice_row['ffreq'][0] - dffreq = slice_row['dffreq'][0] - min_nfringes = slice_row['min_nfringes'][0] - max_nfringes = slice_row['max_nfringes'][0] - min_snr = slice_row['min_snr'][0] - pgram_res = slice_row['pgram_res'][0] + slice_row = self.freq_table[(self.freq_table["slice"] == float(ss))] + ffreq = slice_row["ffreq"][0] + dffreq = slice_row["dffreq"][0] + max_nfringes = slice_row["max_nfringes"][0] + min_snr = slice_row["min_snr"][0] + pgram_res = slice_row["pgram_res"][0] # cycle through the cols and fit the fringes for col in np.arange(slice_x_ranges[n, 1], slice_x_ranges[n, 2]): col_data = ss_data[:, col] col_wmap = ss_wmap[:, col] + # because of the curvature of the slices there can be + # large regions not falling on a column valid = np.logical_and((col_wmap > 0), ~np.isnan(col_wmap)) - # because of the curvature of the slices there can be large regions not falling on a column num_good = len(np.where(valid)[0]) - # Need at least 50 pixels in column to proceed - - if num_good > 50: - test_flux = col_data[valid] - test_flux[test_flux < 0] = 1e-08 - # Transform wavelength in micron to wavenumber in cm^-1. - with warnings.catch_warnings(): - warnings.simplefilter("ignore", RuntimeWarning) - col_wnum = 10000.0 / col_wmap - - # use the error array to get col snr, used to remove noisey pixels - col_snr = self.model.data.copy()[:, col] / self.model.err.copy()[:, col] - - # do some checks on column to make sure there is reasonable signal. If the SNR < min_snr (CDP), pass - # determine SNR for this column of data - n = len(test_flux) - signal = np.nanmean(test_flux) - noise = 0.6052697 * np.nanmedian(np.abs(2.0 * test_flux[2:n - 2] - test_flux[0:n - 4] - test_flux[4:n])) - - snr2 = 0.0 # initialize - if noise != 0: - snr2 = signal / noise - - # Sometimes can return nan, inf for bad data so include this in check - if snr2 < min_snr[0]: - log.debug(f'SNR too low not fitting column {col}, {snr2}, {min_snr[0]}') - continue - log.debug(f"Fitting column {col}") - log.debug(f"SNR > {min_snr[0]} ") - - col_weight = ss_weight[:, col] - col_max_amp = np.interp(col_wmap, self.max_amp['Wavelength'], self.max_amp['Amplitude']) - col_snr2 = np.where(col_snr > 10, 1, 0) # hardcoded at snr > 10 for now - - # get the in-slice pixel indices for replacing in output later - idx = np.where(col_data > 0) - - # BayesicFitting doesn't like 0s at data or weight array edges so set to small value - # replacing array 0s with arbitrarily low number - col_data[col_data <= 0] = 1e-08 - col_weight[col_weight <= 0] = 1e-08 - - # check for off-slice pixels and send to be filled with interpolated/extrapolated wnums - # to stop BayesicFitting crashing, will not be fitted anyway - # finding out-of-slice pixels in column and filling - - found_bad = np.logical_or(np.isnan(col_wnum), np.isinf(col_wnum)) - num_bad = len(np.where(found_bad)[0]) - - if num_bad > 0: - col_wnum[found_bad] = 0 - col_wnum = utils.fill_wavenumbers(col_wnum) - - # do feature finding on slice now column-by-column - log.debug(" starting feature finding") + # Need at least 50 pixels in column to proceed + if num_good <= 50: + continue + + test_flux = col_data[valid] + test_flux[test_flux < 0] = 1e-08 + # Transform wavelength in micron to wavenumber in cm^-1. + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RuntimeWarning) + col_wnum = 10000.0 / col_wmap + + # use the error array to get col snr, used to remove noisy pixels + col_snr = self.model.data[:, col] / self.model.err[:, col] + + # Do some checks on column to make sure there is + # reasonable signal. If the SNR < min_snr (CDP), pass + n = len(test_flux) + signal = np.nanmean(test_flux) + noise = 0.6052697 * np.nanmedian( + np.abs(2.0 * test_flux[2 : n - 2] - test_flux[0 : n - 4] - test_flux[4:n]) + ) + + snr2 = 0.0 # initialize + if noise != 0: + snr2 = signal / noise + + # Sometimes can return nan, inf for bad data so include this in check + if snr2 < min_snr[0]: + log.debug(f"SNR too low not fitting column {col}, {snr2}, {min_snr[0]}") + continue + + log.debug(f"Fitting column {col}") + log.debug(f"SNR > {min_snr[0]} ") + + col_weight = ss_weight[:, col] + col_max_amp = np.interp( + col_wmap, self.max_amp["Wavelength"], self.max_amp["Amplitude"] + ) + col_snr2 = np.where(col_snr > 10, 1, 0) # hardcoded at snr > 10 for now + + # get the in-slice pixel indices for replacing in output later + idx = np.where(col_data > 0) + + # BayesicFitting doesn't like zeros at data or weight array + # edges so set zeros to an arbitrarily small value + col_data[col_data <= 0] = 1e-08 + col_weight[col_weight <= 0] = 1e-08 + + # Check for off-slice pixels and send to be filled with + # interpolated/extrapolated wnums to stop BayesicFitting from + # crashing. They will not be fitted anyway. + found_bad = np.logical_or(np.isnan(col_wnum), np.isinf(col_wnum)) + num_bad = len(np.where(found_bad)[0]) + + if num_bad > 0: + col_wnum[found_bad] = 0 + col_wnum = utils.fill_wavenumbers(col_wnum) + + # do feature finding on slice now column-by-column + log.debug(" Starting feature finding") + + # narrow features (similar or less than fringe #1 period) + # find spectral features (env is spline fit of troughs and peaks) + env, l_x, l_y, _, _, _ = utils.fit_envelope( + np.arange(col_data.shape[0]), col_data + ) + mod = np.abs(col_data / env) - 1 + + # Use col_snr to ignore noisy pixels: + # given signal in mod, find location of + # lines > col_max_amp * 2 (fringe contrast) + weight_factors = utils.find_lines(mod * col_snr2, col_max_amp * 2) + weights_feat = col_weight * weight_factors + + # account for fringe 2 on broad features in channels 3 and 4 + # need to smooth out the dichroic fringe as it breaks + # the feature finding method + if c in [3, 4]: + # smoothing window hardcoded to 7 for now (based on testing) + win = 7 + cumsum = np.cumsum(np.insert(col_data, 0, 0)) + sm_col_data = (cumsum[win:] - cumsum[:-win]) / float(win) - # narrow features (similar or less than fringe #1 period) # find spectral features (env is spline fit of troughs and peaks) - env, l_x, l_y, _, _, _ = utils.fit_envelope(np.arange(col_data.shape[0]), col_data) + env, l_x, l_y, _, _, _ = utils.fit_envelope( + np.arange(col_data.shape[0]), sm_col_data + ) mod = np.abs(col_data / env) - 1 - # given signal in mod find location of lines > col_max_amp * 2 (fringe contrast) - # use col_snr to ignore noisy pixels - weight_factors = utils.find_lines(mod * col_snr2, col_max_amp * 2) - weights_feat = col_weight * weight_factors - - # account for fringe 2 on broad features in channels 3 and 4 - # need to smooth out the dichroic fringe as it breaks the feature finding method - if c in [3, 4]: - win = 7 # smoothing window hardcoded to 7 for now (based on testing) - cumsum = np.cumsum(np.insert(col_data, 0, 0)) - sm_col_data = (cumsum[win:] - cumsum[:-win]) / float(win) - - # find spectral features (env is spline fit of troughs and peaks) - env, l_x, l_y, _, _, _ = utils.fit_envelope(np.arange(col_data.shape[0]), sm_col_data) - mod = np.abs(col_data / env) - 1 - - # given signal in mod find location of lines > col_max_amp * 2 - weight_factors = utils.find_lines(mod, col_max_amp * 2) - weights_feat *= weight_factors - - # iterate over the fringe components to fit, initialize other output arrays - # in case fit fails - proc_data = col_data.copy() - proc_factors = np.ones(col_data.shape) - bg_fit = col_data.copy() - res_fringe_fit_flag = np.zeros(col_data.shape) - wpix_num = 1024 - - # check the end points. A single value followed by gap of zero can cause - # problems in the fitting. - index = np.where(weights_feat != 0.0) - length = np.diff(index[0]) - #print(weights_feat, index, length) - if weights_feat[0] != 0 and length[0] > 1: - weights_feat[0] = 1e-08 - - if weights_feat[-1] != 0 and length[-1] > 1: - weights_feat[-1] = 1e-08 - - # jane added this - fit can fail in evidence function. - # once we replace evidence function with astropy routine - we can test - # removing setting weights < 0.003 to zero (1e-08) - weights_feat[weights_feat <= 0.003] = 1e-08 - - # currently the reference file fits one fringe originating in the detector pixels, and a - # second high frequency, low amplitude fringe in channels 3 and 4 which has been - # attributed to the dichroics. - try: - for fn, ff in enumerate(ffreq): - # ignore place holder fringes - if ff > 1e-03: - log.debug(' starting ffreq = {}'.format(ff)) - - # check if snr criteria is met for fringe component, should always be true for fringe 1 - if snr2 > min_snr[fn]: - log.debug(" fitting spectral baseline") - - bg_fit, bgindx = \ - utils.fit_1d_background_complex(proc_data, weights_feat, - col_wnum, ffreq=ffreq[fn], channel=c) - - # get the residual fringes as fraction of signal - res_fringes = np.divide(proc_data, bg_fit, out=np.zeros_like(proc_data), - where=bg_fit != 0) - res_fringes = np.subtract(res_fringes, 1, where=res_fringes != 0) - res_fringes *= np.where(col_weight > 1e-07, 1, 1e-08) - - # fit the residual fringes - log.debug(" set up bayes ") - res_fringe_fit, wpix_num, opt_nfringe, peak_freq, freq_min, freq_max = \ - utils.new_fit_1d_fringes_bayes_evidence(res_fringes, - weights_feat, - col_wnum, - ffreq[fn], - dffreq[fn], - min_nfringes=min_nfringes[fn], - max_nfringes=max_nfringes[fn], - pgram_res=pgram_res[fn], - col_snr2=col_snr2) - - # check for fit blowing up, reset rfc fit to 0, raise a flag - log.debug("check residual fringe fit for bad fit regions") - res_fringe_fit, res_fringe_fit_flag = utils.check_res_fringes(res_fringe_fit, - col_max_amp) - - # correct for residual fringes - log.debug(" divide out residual fringe fit, get fringe corrected column") - _, _, _, env, u_x, u_y = utils.fit_envelope(np.arange(res_fringe_fit.shape[0]), - res_fringe_fit) - - rfc_factors = 1 / (res_fringe_fit * (col_weight > 1e-05).astype(int) + 1) - proc_data *= rfc_factors - proc_factors *= rfc_factors - - # handle nans or infs that may exist - proc_data = np.nan_to_num(proc_data, posinf=1e-08, neginf=1e-08) - proc_data[proc_data < 0] = 1e-08 - - out_table.add_row((ss, col, fn, snr2, pgram_res[fn], - opt_nfringe, peak_freq, freq_min, freq_max)) - - # define fringe sub after all fringe components corrections - fringe_sub = proc_data.copy() - rfc_factors = proc_factors.copy() - - # get the residual fringes as fraction of signal - pbg_fit, pbgindx = utils.fit_1d_background_complex(fringe_sub, - weights_feat, - col_wnum, - ffreq=ffreq[0], channel=c) - fit_res = np.divide(fringe_sub, pbg_fit, out=np.zeros_like(fringe_sub), - where=pbg_fit != 0) - fit_res = np.subtract(fit_res, 1, where=fit_res != 0) - fit_res *= np.where(col_weight > 1e-07, 1, 1e-08) - - out_table.add_row((ss, col, fn, snr2, pgram_res[0], - opt_nfringe, peak_freq, freq_min, freq_max)) - - # replace the corrected in-slice column pixels in the data_cor array - log.debug(" updating the trace pixels in the output") - output_data[idx, col] = fringe_sub[idx] - self.rfc_factors[idx, col] = rfc_factors[idx] - self.fit_mask[idx, col] = np.ones(1024)[idx] - self.weights_feat[idx, col] = weights_feat[idx] - self.weighted_pix_num[idx, col] = np.ones(1024)[idx] * (wpix_num / 1024) - self.rejected_fit[idx, col] = res_fringe_fit_flag[idx] - self.background_fit[idx, col] = bg_fit[idx] - self.knot_locations[:bgindx.shape[0], col] = bgindx - num_corrected = num_corrected + 1 - - except Exception as e: - log.warning(" Skipping col={} {} ".format(col, ss)) - log.warning(' %s' % (str(e))) - - del ss_data, ss_wmap, ss_weight # end on column - - del slice_x_ranges, all_slice_masks, slices_in_band # end of channel - log.info('Number of columns corrected for channel {}'.format(num_corrected)) + # given signal in mod find location of lines > col_max_amp * 2 + weight_factors = utils.find_lines(mod, col_max_amp * 2) + weights_feat *= weight_factors + + # iterate over the fringe components to fit, initialize other output arrays + # in case fit fails + proc_data = col_data.copy() + proc_factors = np.ones(col_data.shape) + bg_fit = col_data.copy() + res_fringe_fit_flag = np.zeros(col_data.shape) + wpix_num = 1024 + + # check the end points. A single value followed by gap of zero can cause + # problems in the fitting. + index = np.where(weights_feat != 0.0) + length = np.diff(index[0]) + + if weights_feat[0] != 0 and length[0] > 1: + weights_feat[0] = 1e-08 + + if weights_feat[-1] != 0 and length[-1] > 1: + weights_feat[-1] = 1e-08 + + # jane added this - fit can fail in evidence function. + # once we replace evidence function with astropy routine - we can test + # removing setting weights < 0.003 to zero (1e-08) + weights_feat[weights_feat <= 0.003] = 1e-08 + + # currently the reference file fits one fringe originating in the + # detector pixels, and a second high frequency, low amplitude fringe + # in channels 3 and 4 which has been attributed to the dichroics. + try: + for fn, ff in enumerate(ffreq): + # ignore place holder fringes + if ff > 1e-03: + log.debug(f" Start ffreq = {ff}") + + # check if snr criteria is met for fringe component, + # should always be true for fringe 1 + if snr2 > min_snr[fn]: + log.debug(" Fit spectral baseline") + + bg_fit, bgindx = utils.fit_1d_background_complex( + proc_data, + weights_feat, + col_wnum, + ffreq=ffreq[fn], + channel=c, + ) + + # get the residual fringes as fraction of signal + res_fringes = np.divide( + proc_data, + bg_fit, + out=np.zeros_like(proc_data), + where=bg_fit != 0, + ) + res_fringes = np.subtract( + res_fringes, 1, where=res_fringes != 0 + ) + res_fringes *= np.where(col_weight > 1e-07, 1, 1e-08) + + # fit the residual fringes + log.debug(" Set up Bayes evidence") + ( + res_fringe_fit, + wpix_num, + opt_nfringe, + peak_freq, + freq_min, + freq_max, + ) = utils.fit_1d_fringes_bayes_evidence( + res_fringes, + weights_feat, + col_wnum, + ffreq[fn], + dffreq[fn], + max_nfringes[fn], + pgram_res[fn], + col_snr2, + ) + + # check for fit blowing up, reset rfc fit to 0, raise a flag + log.debug(" Check residual fringe fit for bad fit regions") + res_fringe_fit, res_fringe_fit_flag = utils.check_res_fringes( + res_fringe_fit, col_max_amp + ) + + # correct for residual fringes + log.debug(" Divide out residual fringe fit") + _, _, _, env, u_x, u_y = utils.fit_envelope( + np.arange(res_fringe_fit.shape[0]), res_fringe_fit + ) + + rfc_factors = 1 / ( + res_fringe_fit * (col_weight > 1e-05).astype(int) + 1 + ) + proc_data *= rfc_factors + proc_factors *= rfc_factors + + # handle nans or infs that may exist + proc_data = np.nan_to_num(proc_data, posinf=1e-08, neginf=1e-08) + proc_data[proc_data < 0] = 1e-08 + + out_table.add_row( + ( + ss, + col, + fn, + snr2, + pgram_res[fn], + opt_nfringe, + peak_freq, + freq_min, + freq_max, + ) + ) + + # define fringe sub after all fringe components corrections + fringe_sub = proc_data.copy() + rfc_factors = proc_factors.copy() + + # get the residual fringes as fraction of signal + pbg_fit, pbgindx = utils.fit_1d_background_complex( + fringe_sub, weights_feat, col_wnum, ffreq=ffreq[0], channel=c + ) + fit_res = np.divide( + fringe_sub, + pbg_fit, + out=np.zeros_like(fringe_sub), + where=pbg_fit != 0, + ) + fit_res = np.subtract(fit_res, 1, where=fit_res != 0) + fit_res *= np.where(col_weight > 1e-07, 1, 1e-08) + + out_table.add_row( + ( + ss, + col, + fn, + snr2, + pgram_res[0], + opt_nfringe, + peak_freq, + freq_min, + freq_max, + ) + ) + + # replace the corrected in-slice column pixels in the data_cor array + log.debug(" Update the trace pixels in the output") + output_data[idx, col] = fringe_sub[idx] + self.rfc_factors[idx, col] = rfc_factors[idx] + self.fit_mask[idx, col] = np.ones(1024)[idx] + self.weights_feat[idx, col] = weights_feat[idx] + self.weighted_pix_num[idx, col] = np.ones(1024)[idx] * (wpix_num / 1024) + self.rejected_fit[idx, col] = res_fringe_fit_flag[idx] + self.background_fit[idx, col] = bg_fit[idx] + self.knot_locations[: bgindx.shape[0], col] = bgindx + num_corrected = num_corrected + 1 + + except Exception as e: + log.warning(f" Skipping col={col} {ss}:") + log.warning(f" {str(e)}") + + del ss_data, ss_wmap, ss_weight # end of column + + del slice_x_ranges, all_slice_masks, slices_in_channel # end of channel + log.info(f"Number of columns corrected for channel {num_corrected}") log.info("Processing complete") # add units back to output data - log.info(" adding units back to output array") + log.debug("Adding units back to output array") output_data *= normalization_factor # Add NaNs back to output data output_data[nanval_indx] = np.nan @@ -415,34 +519,38 @@ def do_correction(self): if self.save_intermediate_results: stat_table_name = self.make_output_path( - basepath=self.input_model.meta.filename, - suffix='stat_table', ext='.ecsv') - log.info(' Saving intermediate Stat table {}'.format(stat_table_name)) - ascii.write(stat_table, stat_table_name, format='ecsv', fast_writer=False, overwrite=True) + basepath=self.input_model.meta.filename, suffix="stat_table", ext=".ecsv" + ) + log.info(f"Saving intermediate stats table {stat_table_name}") + astropy_ascii.write( + stat_table, stat_table_name, format="ecsv", fast_writer=False, overwrite=True + ) out_table_name = self.make_output_path( - basepath=self.input_model.meta.filename, - suffix='out_table', ext='.ecsv') - log.info(f'Saving intermediate output table {out_table_name}') - ascii.write(out_table, out_table_name, format='ecsv', fast_writer=False, overwrite=True) + basepath=self.input_model.meta.filename, suffix="out_table", ext=".ecsv" + ) + log.info(f"Saving intermediate output table {out_table_name}") + astropy_ascii.write( + out_table, out_table_name, format="ecsv", fast_writer=False, overwrite=True + ) fit_results_name = self.make_output_path( - basepath=self.input_model.meta.filename, - suffix='fit_results', ext='.fits') - log.info(f'Saving intermediate fit results output {fit_results_name}') + basepath=self.input_model.meta.filename, suffix="fit_results", ext=".fits" + ) + log.info(f"Saving intermediate fit results output {fit_results_name}") # Get a primary header from the input model - hdul = fits_support.to_fits(self.input_model._instance, self.input_model._schema) + hdul = fits_support.to_fits(self.input_model._instance, self.input_model._schema) # noqa: SLF001 hdr = hdul[0].header hdul.close() hdu0 = fits.PrimaryHDU(header=hdr) - hdu1 = fits.ImageHDU(self.rfc_factors, name='RFC_FACTORS') - hdu2 = fits.ImageHDU(self.fit_mask, name='FIT_MASK') - hdu3 = fits.ImageHDU(self.weights_feat, name='WEIGHTS_FEATURES') - hdu4 = fits.ImageHDU(self.weighted_pix_num, name='WEIGHTED_PIXEL_FRACTION') - hdu5 = fits.ImageHDU(self.background_fit, name='BACKGROUND_FIT') - hdu6 = fits.ImageHDU(self.knot_locations, name='KNOT_LOCATIONS') + hdu1 = fits.ImageHDU(self.rfc_factors, name="RFC_FACTORS") + hdu2 = fits.ImageHDU(self.fit_mask, name="FIT_MASK") + hdu3 = fits.ImageHDU(self.weights_feat, name="WEIGHTS_FEATURES") + hdu4 = fits.ImageHDU(self.weighted_pix_num, name="WEIGHTED_PIXEL_FRACTION") + hdu5 = fits.ImageHDU(self.background_fit, name="BACKGROUND_FIT") + hdu6 = fits.ImageHDU(self.knot_locations, name="KNOT_LOCATIONS") hdu = fits.HDUList([hdu0, hdu1, hdu2, hdu3, hdu4, hdu5, hdu6]) hdu.writeto(fit_results_name, overwrite=True) @@ -451,23 +559,19 @@ def do_correction(self): return self.model def calc_weights(self): + """ + Make a weights array based on flux. - """Make a weights array based on flux. This is a placeholder function, - for now just returns a normalised flux array to use as weights array. - This is because any smoothing results in incorrect rfc around emission lines. + This is a placeholder function. For now, it just returns a normalised + flux array to use as a weights array. This is because any smoothing + results in incorrect fringe correction around emission lines. This can be changed in the future if need be. - :Parameters: - - flux: numpy array, required - the 1D array of fluxes - - :Returns: - - weights array which is just a copy of the flux array, normalised by the mean - + Returns + ------- + weights : ndarray + Weights array. """ - weights = np.zeros(self.input_model.data.shape) for c in np.arange(weights.shape[1]): flux_1d = self.input_model.data[:, c] @@ -496,5 +600,8 @@ def _get_wave_map(self): _, _, wave_map = self.input_model.meta.wcs(x, y) return wave_map -class ErrorNoFringeFlat(Exception): + +class NoFringeFlatError(Exception): + """Error raised when the input has not been fringe flat corrected.""" + pass diff --git a/jwst/residual_fringe/residual_fringe_step.py b/jwst/residual_fringe/residual_fringe_step.py index 843f2b93f9..8f8c3f5e73 100755 --- a/jwst/residual_fringe/residual_fringe_step.py +++ b/jwst/residual_fringe/residual_fringe_step.py @@ -8,15 +8,12 @@ class ResidualFringeStep(Step): """ - ResidualFringeStep: Apply residual fringe correction to a science image - using parameters in the residual fringe reference file. + Apply residual fringe correction to a MIRI MRS image. - Parameters - ---------- - input_data : asn file or single file + Requires frequency parameters provided in the FRINGEFREQ reference file. """ - class_alias = 'residual_fringe' + class_alias = "residual_fringe" spec = """ skip = boolean(default=True) @@ -27,91 +24,101 @@ class ResidualFringeStep(Step): suffix = string(default = 'residual_fringe') """ # noqa: E501 - reference_file_types = ['fringefreq', 'regions'] + reference_file_types = ["fringefreq", "regions"] - def process(self, input): + def process(self, input_data): + """ + Perform the residual fringe correction. + + Parameters + ---------- + input_data : str or IFUImageModel + Input data to correct. Must be a MIRI MRS IFU image. + + Returns + ------- + IFUImageModel + The corrected datamodel. + """ self.transmission_level = 80 # sets the transmission level to use in the regions file # 80% is what other steps use. # set up the dictionary to ignore wavelength regions in the residual fringe correction - ignore_regions = {} - ignore_regions['num'] = 0 - ignore_regions['min'] = [] - ignore_regions['max'] = [] + ignore_regions = {"num": 0, "min": [], "max": []} if self.ignore_region_min is not None: for region in self.ignore_region_min: - ignore_regions['min'].append(float(region)) + ignore_regions["min"].append(float(region)) - min_num = len(ignore_regions['min']) + min_num = len(ignore_regions["min"]) if self.ignore_region_max is not None: for region in self.ignore_region_max: - ignore_regions['max'].append(float(region)) - max_num = len(ignore_regions['max']) + ignore_regions["max"].append(float(region)) + max_num = len(ignore_regions["max"]) if max_num != min_num: self.log.error("Number of minimum and maximum wavelengths to ignore are not the same") raise ValueError("Number of ignore_region_min does not match ignore_region_max") - ignore_regions['num'] = min_num + ignore_regions["num"] = min_num if min_num > 0: - self.log.info(f'Ignoring {min_num} wavelength regions') + self.log.info(f"Ignoring {min_num} wavelength regions") self.ignore_regions = ignore_regions - input = datamodels.open(input) + input_data = datamodels.open(input_data) - if isinstance(input, datamodels.IFUImageModel): - exptype = input.meta.exposure.type + if isinstance(input_data, datamodels.IFUImageModel): + exptype = input_data.meta.exposure.type else: - raise TypeError(f"Failed to process input type: {type(input)}") + raise TypeError(f"Failed to process input type: {type(input_data)}") # Set up residual fringe correction parameters pars = { - 'transmission_level': self.transmission_level, - 'save_intermediate_results': self.save_intermediate_results, - 'make_output_path': self.make_output_path + "transmission_level": self.transmission_level, + "save_intermediate_results": self.save_intermediate_results, + "make_output_path": self.make_output_path, } - if exptype != 'MIR_MRS': + if exptype != "MIR_MRS": self.log.warning("Residual fringe correction is only for MIRI MRS data") self.log.warning(f"Input is: {exptype}") - input.meta.cal_step.residual_fringe = "SKIPPED" - return input + input_data.meta.cal_step.residual_fringe = "SKIPPED" + return input_data # 1. set up the reference files # 2. correct the model # 3. return from step - self.residual_fringe_filename = self.get_reference_file(input, 'fringefreq') - self.log.info('Using FRINGEFREQ reference file:{}'. - format(self.residual_fringe_filename)) + self.residual_fringe_filename = self.get_reference_file(input_data, "fringefreq") + self.log.info(f"Using FRINGEFREQ reference file:{self.residual_fringe_filename}") # set up regions reference file - self.regions_filename = self.get_reference_file(input, 'regions') - self.log.info('Using MRS regions reference file: {}'. - format(self.regions_filename)) + self.regions_filename = self.get_reference_file(input_data, "regions") + self.log.info(f"Using MRS regions reference file: {self.regions_filename}") # Check for a valid reference files. If they are not found skip step - if self.residual_fringe_filename == 'N/A' or self.regions_filename == 'N/A': - if self.residual_fringe_filename == 'N/A': - self.log.warning('No FRINGEFREQ reference file found') - self.log.warning('Residual Fringe step will be skipped') + if self.residual_fringe_filename == "N/A" or self.regions_filename == "N/A": + if self.residual_fringe_filename == "N/A": + self.log.warning("No FRINGEFREQ reference file found") + self.log.warning("Residual Fringe step will be skipped") - if self.regions_filename == 'N/A': - self.log.warning('No MRS regions reference file found') - self.log.warning('Residual Fringe step will be skipped') + if self.regions_filename == "N/A": + self.log.warning("No MRS regions reference file found") + self.log.warning("Residual Fringe step will be skipped") - input.meta.cal_step.residual_fringe = "SKIPPED" - return input + input_data.meta.cal_step.residual_fringe = "SKIPPED" + return input_data # Do the correction - rfc = residual_fringe.ResidualFringeCorrection(input, - self.residual_fringe_filename, - self.regions_filename, - self.ignore_regions, - **pars) + rfc = residual_fringe.ResidualFringeCorrection( + input_data, + self.residual_fringe_filename, + self.regions_filename, + self.ignore_regions, + **pars, + ) result = rfc.do_correction() - result.meta.cal_step.residual_fringe = 'COMPLETE' + result.meta.cal_step.residual_fringe = "COMPLETE" return result diff --git a/jwst/residual_fringe/tests/__init__.py b/jwst/residual_fringe/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/jwst/residual_fringe/tests/test_background_fit.py b/jwst/residual_fringe/tests/test_background_fit.py index a235ae9b31..0ccac5b82f 100644 --- a/jwst/residual_fringe/tests/test_background_fit.py +++ b/jwst/residual_fringe/tests/test_background_fit.py @@ -1,6 +1,5 @@ -""" -Unit test for Residual Fringe Correction fitting of the background -""" +"""Unit test for Residual Fringe Correction fitting of the background.""" + import pytest from pathlib import Path @@ -11,9 +10,22 @@ def read_fit_column(file): - """ This is really a small regression test, testing that the background fitting is working """ - - # Data was pulled out of an exposure by modifying residual_fringe.py to write out a column of data + """ + Read some small sample data for testing. + + Parameters + ---------- + file : str + File name, should be stored in the same directory + as this file. + + Returns + ------- + tuple + Test data: col_data, col_weight, col_wnum, bg_fit, store_freq. + """ + # Data was pulled out of an exposure by modifying + # residual_fringe.py to write out a column of data # The function we are testing is fit_1d_background_complex. file_dir = Path(__file__).parent.resolve() @@ -24,18 +36,15 @@ def read_fit_column(file): col_weight = hdu[2].data col_wnum = hdu[3].data bg_fit = hdu[4].data - store_freq = hdu[0].header['FFREQ'] + store_freq = hdu[0].header["FFREQ"] - return col_data, col_weight, col_wnum, bg_fit, store_freq + return col_data, col_weight, col_wnum, bg_fit, store_freq -@pytest.mark.parametrize("file", ['good_col.fits', 'edge_col.fits']) +@pytest.mark.parametrize("file", ["good_col.fits", "edge_col.fits"]) def test_background_fit(file): - """ test fit_1d_background_complex""" - (col_data, col_weight, col_wnum, bg_fit, store_freq) = read_fit_column(file) - bg_fit2, _ = utils.fit_1d_background_complex(col_data, col_weight, - col_wnum, ffreq=store_freq) + bg_fit2, _ = utils.fit_1d_background_complex(col_data, col_weight, col_wnum, ffreq=store_freq) assert_allclose(bg_fit, bg_fit2, atol=0.001) diff --git a/jwst/residual_fringe/tests/test_configuration.py b/jwst/residual_fringe/tests/test_configuration.py index 6fc8cb4d1c..05b0bd0440 100644 --- a/jwst/residual_fringe/tests/test_configuration.py +++ b/jwst/residual_fringe/tests/test_configuration.py @@ -11,17 +11,19 @@ from jwst.tests.helpers import LogWatcher -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def miri_image(): - image = datamodels.IFUImageModel((20, 20)) - image.data = np.random.random((20, 20)) - image.meta.instrument.name = 'MIRI' - image.meta.instrument.detector = 'MIRIFULONG' - image.meta.exposure.type = 'MIR_MRS' - image.meta.instrument.channel = '12' - image.meta.instrument.band = 'SHORT' - image.meta.filename = 'test_miri.fits' + image.meta.instrument.name = "MIRI" + image.meta.instrument.detector = "MIRIFULONG" + image.meta.exposure.type = "MIR_MRS" + image.meta.instrument.channel = "12" + image.meta.instrument.band = "SHORT" + image.meta.filename = "test_miri.fits" + + rng = np.random.default_rng(42) + image.data = rng.random((20, 20)) + return image @@ -62,7 +64,7 @@ def test_ignore_regions(tmp_cwd, monkeypatch, miri_image, step_log_watcher): # monkeypatch the reference file retrieval so step aborts but does # not error out for this incomplete input - monkeypatch.setattr(step, 'get_reference_file', lambda *args: 'N/A') + monkeypatch.setattr(step, "get_reference_file", lambda *args: "N/A") # check for ignore regions log message step.run(miri_image) @@ -70,24 +72,24 @@ def test_ignore_regions(tmp_cwd, monkeypatch, miri_image, step_log_watcher): def test_fringe_flat_applied(tmp_cwd, miri_image): - - miri_image.meta.cal_step.fringe = 'SKIPPED' + miri_image.meta.cal_step.fringe = "SKIPPED" residual_fringe_reference_file = None regions_reference_file = None save_intermediate_results = False transmission_level = 2 ignore_regions = {} - pars = {'save_intermediate_results': save_intermediate_results, - 'transmission_level': transmission_level} - - rfc = residual_fringe.ResidualFringeCorrection(miri_image, - residual_fringe_reference_file, - regions_reference_file, - ignore_regions, - **pars) - # test that the fringe flat step has to be already run on the data before running residual fringe step - - with pytest.raises(residual_fringe.ErrorNoFringeFlat): + pars = { + "save_intermediate_results": save_intermediate_results, + "transmission_level": transmission_level, + } + + rfc = residual_fringe.ResidualFringeCorrection( + miri_image, residual_fringe_reference_file, regions_reference_file, ignore_regions, **pars + ) + + # test that the fringe flat step has to be already run + # on the data before running residual fringe step + with pytest.raises(residual_fringe.NoFringeFlatError): rfc.do_correction() @@ -99,9 +101,9 @@ def test_rf_step_wrong_input_type(): def test_rf_step_wrong_exptype(miri_image, step_log_watcher): model = miri_image - model.meta.exposure.type = 'NRS_IFU' + model.meta.exposure.type = "NRS_IFU" step_log_watcher.message = "only for MIRI MRS" result = ResidualFringeStep.call(model, skip=False) - assert result.meta.cal_step.residual_fringe == 'SKIPPED' + assert result.meta.cal_step.residual_fringe == "SKIPPED" step_log_watcher.assert_seen() diff --git a/jwst/residual_fringe/tests/test_residual_fringe.py b/jwst/residual_fringe/tests/test_residual_fringe.py index 9ba1937456..b39f8e6d80 100644 --- a/jwst/residual_fringe/tests/test_residual_fringe.py +++ b/jwst/residual_fringe/tests/test_residual_fringe.py @@ -44,14 +44,14 @@ def fringed_spectrum(linear_spectrum): def miri_mrs_model_linear(monkeypatch, linear_spectrum): shape = (1024, 10) model = datamodels.IFUImageModel(shape) - model.meta.instrument.name = 'MIRI' - model.meta.instrument.detector = 'MIRIFUSHORT' - model.meta.instrument.channel = '12' - model.meta.instrument.band = 'SHORT' - model.meta.exposure.type = 'MIR_MRS' - model.meta.observation.date = '2022-05-01' - model.meta.observation.time = '01:01:01' - model.meta.cal_step.fringe = 'COMPLETE' + model.meta.instrument.name = "MIRI" + model.meta.instrument.detector = "MIRIFUSHORT" + model.meta.instrument.channel = "12" + model.meta.instrument.band = "SHORT" + model.meta.exposure.type = "MIR_MRS" + model.meta.observation.date = "2022-05-01" + model.meta.observation.time = "01:01:01" + model.meta.cal_step.fringe = "COMPLETE" wave, flux = linear_spectrum model.data[:, :] = flux[:, None] @@ -77,7 +77,7 @@ def return_wavelength(*args): wavemap[:, :] = wave[:, None] return wavemap - monkeypatch.setattr(ResidualFringeCorrection, '_get_wave_map', return_wavelength) + monkeypatch.setattr(ResidualFringeCorrection, "_get_wave_map", return_wavelength) @pytest.fixture() @@ -91,33 +91,33 @@ def return_wavelength(*args): wavemap[::20, ::2] = np.nan return wavemap - monkeypatch.setattr(ResidualFringeCorrection, '_get_wave_map', return_wavelength) + monkeypatch.setattr(ResidualFringeCorrection, "_get_wave_map", return_wavelength) @pytest.fixture() def mock_slice_info_short(monkeypatch): # mock a single slice to fit matching the test data, for testing speed def one_slice(*args): - slices_in_band = [101] + slices_in_channel = [101] xrange_channel = np.array([[0, 10]]) slice_x_ranges = np.array([[101, 0, 10]]) all_slice_masks = np.ones((1, 1024, 10)) - return slices_in_band, xrange_channel, slice_x_ranges, all_slice_masks + return slices_in_channel, xrange_channel, slice_x_ranges, all_slice_masks - monkeypatch.setattr(utils, 'slice_info', one_slice) + monkeypatch.setattr(utils, "slice_info", one_slice) @pytest.fixture() def mock_slice_info_long(monkeypatch): # mock a single slice to fit matching the test data, for testing speed def one_slice(*args): - slices_in_band = [301] + slices_in_channel = [301] xrange_channel = np.array([[0, 10]]) slice_x_ranges = np.array([[301, 0, 10]]) all_slice_masks = np.ones((1, 1024, 10)) - return slices_in_band, xrange_channel, slice_x_ranges, all_slice_masks + return slices_in_channel, xrange_channel, slice_x_ranges, all_slice_masks - monkeypatch.setattr(utils, 'slice_info', one_slice) + monkeypatch.setattr(utils, "slice_info", one_slice) @pytest.fixture() @@ -148,11 +148,11 @@ def test_rf1d(linear_spectrum, fringed_spectrum): # corrected output has small diffs from linear on average # (edge effects might be larger) - relative_diff_output = np.abs(outflux - expected_flux)/expected_flux + relative_diff_output = np.abs(outflux - expected_flux) / expected_flux assert np.nanmean(relative_diff_output) < 0.005 # input diffs from linear are much bigger - relative_diff_input = np.abs(flux - expected_flux)/expected_flux + relative_diff_input = np.abs(flux - expected_flux) / expected_flux assert np.nanmean(relative_diff_input) > 0.01 @@ -164,6 +164,7 @@ def test_get_wavemap(): # Mock a WCS that returns 1 for wavelengths def return_ones(x, y): return None, None, np.ones(x.shape) + model.meta.wcs = return_ones rf = ResidualFringeCorrection(model, "N/A", "N/A", None) @@ -172,14 +173,15 @@ def return_ones(x, y): assert np.all(wavemap == 1.0) -@pytest.mark.parametrize('band', ['SHORT', 'MEDIUM', 'LONG']) -def test_rf_step_short(miri_mrs_model_linear, miri_mrs_model_with_fringe, - mock_slice_info_short, mock_wavemap, band): +@pytest.mark.parametrize("band", ["SHORT", "MEDIUM", "LONG"]) +def test_rf_step_short( + miri_mrs_model_linear, miri_mrs_model_with_fringe, mock_slice_info_short, mock_wavemap, band +): model = miri_mrs_model_with_fringe model.meta.instrument.band = band result = ResidualFringeStep.call(model, skip=False) - assert result.meta.cal_step.residual_fringe == 'COMPLETE' + assert result.meta.cal_step.residual_fringe == "COMPLETE" # output should be closer to a linear spectrum than input, # correction will not be precise @@ -189,12 +191,13 @@ def test_rf_step_short(miri_mrs_model_linear, miri_mrs_model_with_fringe, assert np.nanmean(relative_diff_output) < np.nanmean(relative_diff_input) -@pytest.mark.parametrize('band', ['SHORT', 'MEDIUM', 'LONG']) -def test_rf_step_long(miri_mrs_model_with_fringe, mock_slice_info_long, mock_wavemap, - band, module_log_watcher): +@pytest.mark.parametrize("band", ["SHORT", "MEDIUM", "LONG"]) +def test_rf_step_long( + miri_mrs_model_with_fringe, mock_slice_info_long, mock_wavemap, band, module_log_watcher +): model = miri_mrs_model_with_fringe - model.meta.instrument.detector = 'MIRIFULONG' - model.meta.instrument.channel = '34' + model.meta.instrument.detector = "MIRIFULONG" + model.meta.instrument.channel = "34" model.meta.instrument.band = band # Synthetic input data is reasonable for MIRIFUSHORT, but is expected @@ -204,13 +207,13 @@ def test_rf_step_long(miri_mrs_model_with_fringe, mock_slice_info_long, mock_wav module_log_watcher.assert_seen() # Output data should be identical to input, although step is complete - assert result.meta.cal_step.residual_fringe == 'COMPLETE' + assert result.meta.cal_step.residual_fringe == "COMPLETE" assert np.allclose(model.data, result.data) def test_rf_step_nans_in_wavelength( - miri_mrs_model_linear, miri_mrs_model_with_fringe, - mock_slice_info_short, mock_wavemap_with_nans): + miri_mrs_model_linear, miri_mrs_model_with_fringe, mock_slice_info_short, mock_wavemap_with_nans +): model = miri_mrs_model_with_fringe # wavelength array has some scattered NaNs: @@ -219,25 +222,32 @@ def test_rf_step_nans_in_wavelength( # output should be closer to a linear spectrum than input, # correction will not be precise - assert result.meta.cal_step.residual_fringe == 'COMPLETE' + assert result.meta.cal_step.residual_fringe == "COMPLETE" expected = miri_mrs_model_linear.data relative_diff_input = np.abs(model.data - expected) / expected relative_diff_output = np.abs(result.data - expected) / expected assert np.nanmean(relative_diff_output) < np.nanmean(relative_diff_input) - -def test_rf_step_save_intermediate(tmp_path, miri_mrs_model_with_fringe, - mock_slice_info_short, mock_wavemap): +def test_rf_step_save_intermediate( + tmp_path, miri_mrs_model_with_fringe, mock_slice_info_short, mock_wavemap +): model = miri_mrs_model_with_fringe - model.meta.filename = 'test.fits' - ResidualFringeStep.call(model, skip=False, output_dir=str(tmp_path), - save_results=True, save_intermediate_results=True) - - output_files = ['test_residual_fringe.fits', - 'test_stat_table.ecsv', - 'test_out_table.ecsv', - 'test_fit_results.fits'] + model.meta.filename = "test.fits" + ResidualFringeStep.call( + model, + skip=False, + output_dir=str(tmp_path), + save_results=True, + save_intermediate_results=True, + ) + + output_files = [ + "test_residual_fringe.fits", + "test_stat_table.ecsv", + "test_out_table.ecsv", + "test_fit_results.fits", + ] for output_file in output_files: assert (tmp_path / output_file).exists() @@ -249,15 +259,16 @@ def test_rf_step_ignore_regions(miri_mrs_model_with_fringe, mock_slice_info_shor ignore_region_min = [model.wavelength.min()] ignore_region_max = [model.wavelength.max()] result = ResidualFringeStep.call( - model, skip=False, ignore_region_min=ignore_region_min, - ignore_region_max=ignore_region_max) + model, skip=False, ignore_region_min=ignore_region_min, ignore_region_max=ignore_region_max + ) # output should be the same as input assert np.allclose(model.data, result.data) -def test_rf_step_low_snr(miri_mrs_model_with_fringe, mock_slice_info_short, mock_wavemap, - module_log_watcher): +def test_rf_step_low_snr( + miri_mrs_model_with_fringe, mock_slice_info_short, mock_wavemap, module_log_watcher +): model = miri_mrs_model_with_fringe # set all the data to a very small value so SNR is too low to fit diff --git a/jwst/residual_fringe/tests/test_utils.py b/jwst/residual_fringe/tests/test_utils.py index 4eecbd7c63..d5521e161f 100644 --- a/jwst/residual_fringe/tests/test_utils.py +++ b/jwst/residual_fringe/tests/test_utils.py @@ -3,6 +3,7 @@ from jwst.residual_fringe import utils + @pytest.fixture() def slice_map(): """ @@ -21,8 +22,8 @@ def slice_map(): def test_slice_info_ch1(slice_map): result = utils.slice_info(slice_map, 1) - slices_in_band, xrange_channel, slice_x_ranges, all_slice_masks = result - assert np.all(slices_in_band == [101, 102]) + slices_in_channel, xrange_channel, slice_x_ranges, all_slice_masks = result + assert np.all(slices_in_channel == [101, 102]) assert np.all(xrange_channel == [9, 40]) assert np.all(slice_x_ranges == [[101, 9, 20], [102, 29, 40]]) assert np.sum(all_slice_masks[0]) == 200 @@ -31,8 +32,8 @@ def test_slice_info_ch1(slice_map): def test_slice_info_ch2(slice_map): result = utils.slice_info(slice_map, 2) - slices_in_band, xrange_channel, slice_x_ranges, all_slice_masks = result - assert np.all(slices_in_band == [201, 202]) + slices_in_channel, xrange_channel, slice_x_ranges, all_slice_masks = result + assert np.all(slices_in_channel == [201, 202]) assert np.all(xrange_channel == [49, 80]) assert np.all(slice_x_ranges == [[201, 49, 60], [202, 69, 80]]) assert np.sum(all_slice_masks[0]) == 200 diff --git a/jwst/residual_fringe/utils.py b/jwst/residual_fringe/utils.py index a580e3bd41..5eb895586d 100644 --- a/jwst/residual_fringe/utils.py +++ b/jwst/residual_fringe/utils.py @@ -1,3 +1,5 @@ +import logging + import numpy as np import math import numpy.polynomial.polynomial as poly @@ -10,9 +12,9 @@ from BayesicFitting import ConstantModel from BayesicFitting import Fitter -from .fitter import spline_fitter +from jwst.residual_fringe.fitter import spline_fitter + -import logging log = logging.getLogger(__name__) log.setLevel(logging.INFO) @@ -21,66 +23,101 @@ NUM_KNOTS = 80 -def slice_info(slice_map, c): - """ Function to take slice map and channel and find pixels in the slice and xrange of each slice +def slice_info(slice_map, channel): """ + Identify pixels by slice. - slice_inventory = np.unique(slice_map) - slices_in_band = slice_inventory[np.where((slice_inventory >= 100 * c) - & (slice_inventory < 100 * (c + 1)))] + Parameters + ---------- + slice_map : ndarray of int + 2D image containing slice identification values by pixel. + Slice ID values are integers with the value 100 * channel number + + slice number. Pixels not included in a slice have value 0. + channel : int + Channel number. - log.info("Number of slices in band {} ".format(slices_in_band.shape[0])) - slice_x_ranges = np.zeros((slices_in_band.shape[0], 3), dtype=int) - all_slice_masks = np.zeros((slices_in_band.shape[0], slice_map.shape[0], slice_map.shape[1])) - for n, s in enumerate(slices_in_band): + Returns + ------- + slices_in_channel : ndarray of int + 1D array of slice IDs included in the channel. + xrange_channel : ndarray of int + 1D array with two elements: minimum and maximum x values + for the channel. + slice_x_ranges : ndarray of int + N x 3 array for N slices, where the first column is the slice ID, + second column is the minimum x value for the slice, + and the third column is the maximum x value for the slice. + all_slice_masks : ndarray of int + N x nx x ny for N slices, matching the x and y shape of the + input slice_map. Values are 1 for pixels included in the slice, + 0 otherwise. + """ + slice_inventory = np.unique(slice_map) + slices_in_channel = slice_inventory[ + np.where((slice_inventory >= 100 * channel) & (slice_inventory < 100 * (channel + 1))) + ] + + log.info(f"Number of slices in channel {slices_in_channel.shape[0]} ") + slice_x_ranges = np.zeros((slices_in_channel.shape[0], 3), dtype=int) + all_slice_masks = np.zeros((slices_in_channel.shape[0], slice_map.shape[0], slice_map.shape[1])) + for n, s in enumerate(slices_in_channel): # create a mask of the slice pixels = np.where(slice_map == s) - slice = np.zeros(slice_map.shape) + slice_mask = np.zeros(slice_map.shape) - slice[pixels] = 1 + slice_mask[pixels] = 1 # add this to the all_slice_mask array - all_slice_masks[n] = slice + all_slice_masks[n] = slice_mask # get the indices at the start and end of the slice - collapsed_slice = np.sum(slice, axis=0) + collapsed_slice = np.sum(slice_mask, axis=0) indices = np.where(collapsed_slice[:-1] != collapsed_slice[1:])[0] - slice_x_ranges[n, 0], slice_x_ranges[n, 1], slice_x_ranges[n, 2] = int(s), \ - int(np.amin(indices)), int(np.amax(indices) + 1) + slice_x_ranges[n, 0], slice_x_ranges[n, 1], slice_x_ranges[n, 2] = ( + int(s), + int(np.amin(indices)), + int(np.amax(indices) + 1), + ) - log.debug("For slice {} x ranges of slices region {}, {}". - format(slice_x_ranges[n, 0], slice_x_ranges[n, 1], slice_x_ranges[n, 2])) + log.debug( + f"For slice {slice_x_ranges[n, 0]} x ranges of slices " + f"region {slice_x_ranges[n, 1]}, {slice_x_ranges[n, 2]}" + ) - log.debug("Min and max x pixel values of all slices in channel {} {}". - format(np.amin(slice_x_ranges[:, 1]), np.amax(slice_x_ranges[:, 2]))) + log.debug( + "Min and max x pixel values of all slices " + f"in channel {np.amin(slice_x_ranges[:, 1])} {np.amax(slice_x_ranges[:, 2])}" + ) xrange_channel = np.zeros(2) xrange_channel[0] = np.amin(slice_x_ranges[:, 1]) xrange_channel[1] = np.amax(slice_x_ranges[:, 2]) - result = (slices_in_band, xrange_channel, slice_x_ranges, all_slice_masks) - return result + return slices_in_channel, xrange_channel, slice_x_ranges, all_slice_masks def fill_wavenumbers(wnums): """ - Function to take a wavenumber array with missing values (e.g., columns with on-slice and - off-slice pixels), fit the good points using a polynomial, then run use the coefficients - to estimate wavenumbers on the off-slice pixels. - Note that these new values are physically meaningless but having them in the wavenum array - stops the BayesicFitting package from crashing with a LinAlgErr. - - :Parameters: + Fill in missing wavenumber values. - wnums: numpy array, required - the wavenumber array + Given a wavenumber array with missing values (e.g., columns with + on-slice and off-slice pixels), fit the good points using a + polynomial, then use the coefficients to estimate wavenumbers + on the off-slice pixels. - :Returns: + Note that these new values are physically meaningless but having them + in the wavenum array stops the BayesicFitting package from crashing with + a LinAlgErr. - wnums_filled: numpy array - the wavenumber array with off-slice pixels filled + Parameters + ---------- + wnums : ndarray + The wavenumber array. + Returns + ------- + wnums_filled : ndarray + The wavenumber array with off-slice pixels filled """ - # set the off-slice pixels to nans and get their indices wnums[wnums == 0] = np.nan idx = np.isfinite(wnums) @@ -98,29 +135,28 @@ def fill_wavenumbers(wnums): return wnums_filled -def multi_sine(n): +def multi_sine(n_sines): """ - Return a model composed of n sines - - :Parameters: + Create a mult-sine model. - n: int, required - number of sines - - :Returns: + Parameters + ---------- + n_sines : int + Number of sines to include. - mdl, BayesFitting model - the model composed of n sines + Returns + ------- + model : SineModel + The model composed of n sines. """ - # make the first sine mdl = SineModel() # make a copy model = mdl.copy() - # add the copy n-1 times - for i in range(1, n): + # add the copy n - 1 times + for _ in range(1, n_sines): mdl.addModel(model.copy()) # clean @@ -130,13 +166,31 @@ def multi_sine(n): def fit_envelope(wavenum, signal): - """ Fit the upper and lower envelope of signal using a univariate spline - - :param wavenum: - :param signal: - :return: """ + Fit the upper and lower envelope of signal using a univariate spline. + + Parameters + ---------- + wavenum : ndarray + Wavenumber values. + signal : ndarray + Signal values + Returns + ------- + lower_fit : ndarray + Fit to the lower envelope. + l_x : list + Input lower wavenum values. + l_y : list + Input lower signal values. + upper_fit : ndarray + Fit to the upper envelope. + u_x : list + Input upper wavenum values. + u_y : list + Input lower wavenum values. + """ # Detect troughs and mark their location. Define endpoints l_x = [wavenum[0]] l_y = [signal[0]] @@ -144,16 +198,19 @@ def fit_envelope(wavenum, signal): u_y = [signal[0]] for k in np.arange(1, len(signal) - 1): - if ((np.sign(signal[k] - signal[k - 1]) == -1) and - ((np.sign(signal[k] - signal[k + 1])) == -1)): + if (np.sign(signal[k] - signal[k - 1]) == -1) and ( + (np.sign(signal[k] - signal[k + 1])) == -1 + ): l_x.append(wavenum[k]) l_y.append(signal[k]) - if ((np.sign(signal[k] - signal[k - 1]) == 1) and - ((np.sign(signal[k] - signal[k + 1])) == 1)): + if (np.sign(signal[k] - signal[k - 1]) == 1) and ( + (np.sign(signal[k] - signal[k + 1])) == 1 + ): u_x.append(wavenum[k]) u_y.append(signal[k]) - # Append the last value of (s) to the interpolating values. This forces the model to use the same ending point + # Append the last value of (s) to the interpolating values. + # This forces the model to use the same ending point l_x.append(wavenum[-1]) l_y.append(signal[-1]) u_x.append(wavenum[-1]) @@ -168,14 +225,23 @@ def fit_envelope(wavenum, signal): def find_lines(signal, max_amp): """ - Take signal and max amp array, determine location of spectral - features with amplitudes greater than max amp + Determine the location of large spectral features. - :param signal: - :param max_amp: - :return: - """ + Parameters + ---------- + signal : ndarray + Signal data. + max_amp : ndarray + Maximum amplitude, by column. Features larger than + this value are flagged. + Returns + ------- + weights : ndarray + 1D array matching signal dimensions, containing 0 values + for large features and 1 values where no features were + detected. + """ r_x = np.arange(signal.shape[0] - 1) # setup the output arrays @@ -186,20 +252,22 @@ def find_lines(signal, max_amp): u_y, u_x, l_y, l_x = [], [], [], [] for x in r_x: - if (np.sign(signal_check[x] - signal_check[x - 1]) == 1) and \ - (np.sign(signal_check[x] - signal_check[x + 1]) == 1): + if (np.sign(signal_check[x] - signal_check[x - 1]) == 1) and ( + np.sign(signal_check[x] - signal_check[x + 1]) == 1 + ): u_y.append(signal_check[x]) u_x.append(x) - if (np.sign(signal_check[x] - signal_check[x - 1]) == -1) and \ - (np.sign(signal_check[x] - signal_check[x + 1]) == -1): + if (np.sign(signal_check[x] - signal_check[x - 1]) == -1) and ( + np.sign(signal_check[x] - signal_check[x + 1]) == -1 + ): l_y.append(signal[x]) l_x.append(x) for n, amp in enumerate(u_y): max_amp_val = max_amp[u_x[n]] + log.debug("find_lines: check if peak above max amp") if amp > max_amp_val: - # peak in x # log.debug("find_lines: flagging neighbours") xpeaks = [u_x[n] - 1, u_x[n], u_x[n] + 1] @@ -208,10 +276,9 @@ def find_lines(signal, max_amp): # find nearest troughs for xp in xpeaks: - log.debug("find_lines: checking ind {}".format(xp)) + log.debug(f"find_lines: checking ind {xp}") try: - x1 = l_x[np.argsort(np.abs(l_x - xp))[0]] try: @@ -237,7 +304,7 @@ def find_lines(signal, max_amp): except IndexError: pass - log.debug("find_lines: Found {} peaks {} troughs".format(len(u_x), len(l_x))) + log.debug(f"find_lines: Found {len(u_x)} peaks {len(l_x)} troughs") weights_factors[signal_check > max_amp * 2] = 0 # catch any remaining # weights_factors[signal_check > np.amax(max_amp)] = 0 @@ -246,132 +313,131 @@ def find_lines(signal, max_amp): def check_res_fringes(res_fringe_fit, max_amp): """ - Check for regions where res fringe fit runs away (greater than max amp), - set the beat where this happens to 0 to avoid making the fringes worse - - :Parameters: - - res_fringe_fit: numpy array, required - the residual fringe fit - - max_amp: numpy array, required - the maximum amplitude array + Check for regions with bad fringe fits. - :Returns: + Set the beat where this happens to 0 to avoid making the fringes worse. - res_fringe_fit: numpy array - the residual fringe fit with exploding fit regions removed - - flats: numpy array - flags where the fit was rejected + Parameters + ---------- + res_fringe_fit : ndarray + The residual fringe fit. + max_amp : ndarray + The maximum amplitude array. + Returns + ------- + res_fringe_fit: ndarray + The residual fringe fit with bad fit regions removed and replaced + with 0. + flags: ndarray + 1D flag array indicating where the fit was altered, matching + the size of the first dimension of `res_fringe_fit`. + 1 indicates a bad fit region; 0 indicates a good region, left + unchanged. """ - flags = np.zeros(res_fringe_fit.shape[0]) # get fit envelope npix = np.arange(res_fringe_fit.shape[0]) lenv_fit, _, _, uenv_fit, _, _ = fit_envelope(npix, res_fringe_fit) - # get the indices of the nodes (where uenv slope goes from negative to positive), add 0 and 1023 + # get the indices of the nodes (where uenv slope goes from + # negative to positive), add 0 and 1023 node_ind = [0] for k in np.arange(1, len(uenv_fit) - 1): - if (np.sign(uenv_fit[k] - uenv_fit[k - 1]) == -1) and ((np.sign(uenv_fit[k] - uenv_fit[k + 1])) == -1): + if (np.sign(uenv_fit[k] - uenv_fit[k - 1]) == -1) and ( + (np.sign(uenv_fit[k] - uenv_fit[k + 1])) == -1 + ): node_ind.append(k) node_ind.append(res_fringe_fit.shape[0] - 1) node_ind = np.asarray(node_ind) - log.debug("check_res_fringes: found {} nodes".format(len(node_ind))) + log.debug(f"check_res_fringes: found {len(node_ind)} nodes") # find where res_fringes goes above max_amp runaway_rfc = np.argwhere((np.abs(lenv_fit) + np.abs(uenv_fit)) > (max_amp * 2)) # check which signal env the blow ups are located in and set to 1, and set a flag array if len(runaway_rfc) > 0: - log.debug("check_res_fringes: {} data points exceed threshold".format(len(runaway_rfc))) + log.debug(f"check_res_fringes: {len(runaway_rfc)} data points exceed threshold") log.debug("check_res_fringes: resetting fits to related beats") for i in runaway_rfc: # find where the index is compared to the nodes node_loc = np.searchsorted(node_ind, i) - # set the res_fringes between the nodes to 1 + # set the res_fringes between the nodes to 0 lind = node_ind[node_loc - 1] uind = node_ind[node_loc] - res_fringe_fit[lind[0]:uind[0]] = 0 - flags[lind[0]:uind[0]] = 1 # set flag to 1 for reject fit region + res_fringe_fit[lind[0] : uind[0]] = 0 + flags[lind[0] : uind[0]] = 1 # set flag to 1 for reject fit region return res_fringe_fit, flags def interp_helper(mask): - """Helper function to for interpolating in feature gaps. - - :Parameters: - - mask: numpy array, required - the 1D mask array (weights) + """ + Create a convenience function for indexing low-weight values. - :Returns: + Low-weight is defined to be a value < 1e-5. - - logical indices of NaNs - - index, a function, with signature indices= index(logical_indices), - to convert logical indices to 'equivalent' indices + Parameters + ---------- + mask : ndarray + The 1D mask array (weights). + Returns + ------- + index_array : ndarray of bool + Boolean index array for low weight pixels. + index_function : callable + A function, with signature indices = index_function(index_array), + to convert logical indices to equivalent direct index values. """ return mask < 1e-05, lambda z: z.nonzero()[0] -def fit_1d_background_complex(flux, weights, wavenum, order=2, ffreq=None, channel=1, test=False): - """Fit the background signal using a piecewise spline of n knots. Note that this will also try to identify - obvious emission lines and flag them so they aren't considered in the fitting. - - :Parameters: - - flux: numpy array, required - the 1D array of fluxes - - weights: numpy array, required - the 1D array of weights - - wavenum: numpy array, required - the 1D array of wavenum - - order: int, optional, default=2 - the order of the Splines model - - ffreq: float, optional, default=None - the expected fringe frequency, used to determine number of knots. If None, - defaults to NUM_KNOTS constant - - channel: int, optional, default=1 - the channel processed. used to determine if other arrays need to be reversed given the direction of increasing - wavelength down the detector in MIRIFULONG - - :Returns: - - bg_fit: numpy array - the fitted background +def fit_1d_background_complex(flux, weights, wavenum, ffreq=None, channel=1): + """ + Fit the background signal using a piecewise spline. - bgindx: numpy array - the location of the knots + Note that this will also try to identify obvious emission lines + and flag them, so they aren't considered in the fitting. - fitter: BayesicFitting object - fitter object, mainly used for testing + Parameters + ---------- + flux : ndarray + 1D array of fluxes. + weights : ndarray + 1D array of weights. + wavenum : ndarray + 1D array of wavenumbers. + ffreq : float, optional + The expected fringe frequency, used to determine number of knots. + If None, defaults to NUM_KNOTS constant + channel : int, optional + The channel to process. Used to determine if other arrays + need to be reversed given the direction of increasing + wavelength down the detector in MIRIFULONG. + Returns + ------- + bg_fit : ndarray + The fitted background. + bgindx: ndarray + The location of the knots. """ - # first get the weighted pixel fraction weighted_pix_frac = (weights > 1e-05).sum() / flux.shape[0] # define number of knots using fringe freq, want 1 knot per period if ffreq is not None: - log.debug("fit_1d_background_complex: knot positions for {} cm-1".format(ffreq)) + log.debug(f"fit_1d_background_complex: knot positions for {ffreq} cm-1") nknots = int((np.amax(wavenum) - np.amin(wavenum)) / (ffreq)) else: - log.debug("fit_1d_background_complex: using num_knots={}".format(NUM_KNOTS)) + log.debug(f"fit_1d_background_complex: using num_knots={NUM_KNOTS}") nknots = int((flux.shape[0] / 1024) * NUM_KNOTS) - log.debug("fit_1d_background_complex: number of knots = {}".format(nknots)) + log.debug(f"fit_1d_background_complex: number of knots = {nknots}") # recale wavenums to around 1 for bayesicfitting factor = np.amin(wavenum) @@ -379,18 +445,18 @@ def fit_1d_background_complex(flux, weights, wavenum, order=2, ffreq=None, chann # get number of fringe periods in array nper = (np.amax(wavenum) - np.amin(wavenum)) // ffreq - log.debug("fit_1d_background_complex: column is {} fringe periods".format(nper)) + log.debug(f"fit_1d_background_complex: column is {nper} fringe periods") # now reduce by the weighted pixel fraction to see how many can be fitted nper_cor = int(nper * weighted_pix_frac) - log.debug("fit_1d_background_complex: column has {} weighted fringe periods".format(nper_cor)) + log.debug(f"fit_1d_background_complex: column has {nper_cor} weighted fringe periods") # require at least 5 sine periods to fit if nper < 5: log.info(" not enough weighted data, no fit performed") return flux.copy(), np.zeros(flux.shape[0]), None - bgindx = new_make_knots(flux.copy(), int(nknots), weights=weights.copy()) + bgindx = make_knots(flux.copy(), int(nknots), weights=weights.copy()) bgknots = wavenum_scaled[bgindx].astype(float) # Reverse (and clip) the fit data as scipy/astropy need monotone increasing data for SW detector @@ -405,13 +471,14 @@ def fit_1d_background_complex(flux, weights, wavenum, order=2, ffreq=None, chann y = flux[::-1] w = weights[::-1] else: - raise ValueError('channel not in 1-4') + raise ValueError("channel not in 1-4") # Fit the spline - # robust fitting causing problems for fringe 2 in channels 3 and 4, just use the fitter class if ffreq > 1.5: bg_model = spline_fitter(x, y, w, t, 2, reject_outliers=True) else: + # robust fitting causing problems for fringe 2 in channels 3 and 4, + # just use the fitter class bg_model = spline_fitter(x, y, w, t, 1, reject_outliers=False) # fit the background @@ -429,46 +496,48 @@ def fit_1d_background_complex(flux, weights, wavenum, order=2, ffreq=None, chann return bg_fit, bgindx -def new_fit_1d_fringes_bayes_evidence(res_fringes, weights, wavenum, ffreq, dffreq, min_nfringes, max_nfringes, - pgram_res, col_snr2): - - """Fit the residual fringe signal.- Improved method - Takes an input 1D array of residual fringes and fits using the supplied mode in the BayesicFitting package: - :Parameters: - res_fringes: numpy array, required - the 1D array with residual fringes - weights: numpy array, required - the 1D array of weights - ffreq: float, required - the central scan frequency - dffreq: float, required - the one-sided interval of scan frequencies - min_nfringes: int, required - the minimum number of fringes to check - max_nfringes: int, required - the maximum number of fringes to check - pgram_res: float, optional - resolution of the periodogram scan in cm-1 - wavenum: numpy array, required - the 1D array of wavenum - :Returns: - res_fringe_fit: numpy array - the residual fringe fit data +def fit_1d_fringes_bayes_evidence( + res_fringes, weights, wavenum, ffreq, dffreq, max_nfringes, pgram_res, col_snr2 +): + """ + Fit the residual fringe signal. + + Takes an input 1D array of residual fringes and fits using the + supplied mode in the BayesicFitting package. + + Parameters + ---------- + res_fringes : ndarray + The 1D array with residual fringes. + weights : ndarray + The 1D array of weights + wavenum : ndarray + The 1D array of wavenum. + ffreq : float + The central scan frequency + dffreq : float + The one-sided interval of scan frequencies. + max_nfringes : int + The maximum number of fringes to check. + pgram_res : float + Resolution of the periodogram scan in cm-1. + col_snr2 : ndarray + Location of pixels with sufficient SNR to fit. + + Returns + ------- + res_fringe_fit : ndarray + The residual fringe fit data. """ # initialize output to none res_fringe_fit = None - weighted_pix_num = None - peak_freq = None - freq_min = None - freq_max = None # get the number of weighted pixels weighted_pix_num = (weights > 1e-05).sum() - # set the maximum array size, always 1024 # get scan res res = np.around((2 * dffreq) / pgram_res).astype(int) - log.debug("fit_1d_fringes_bayes: scan res = {}".format(res)) + log.debug(f"fit_1d_fringes_bayes: scan res = {res}") factor = np.amin(wavenum) wavenum = wavenum.copy() / factor @@ -487,7 +556,6 @@ def new_fit_1d_fringes_bayes_evidence(res_fringes, weights, wavenum, ffreq, dffr res_fringes_proc = res_fringes.copy() nfringes = 0 keep_dict = {} - best_mdl = None fitted_frequencies = [] # get the initial evidence from ConstantModel @@ -495,12 +563,10 @@ def new_fit_1d_fringes_bayes_evidence(res_fringes, weights, wavenum, ffreq, dffr sftr = Fitter(wavenum, sdml) _ = sftr.fit(res_fringes, weights=weights) evidence1 = sftr.getEvidence(limits=[-3, 10], noiseLimits=[0.001, 10]) - log.debug( - "fit_1d_fringes_bayes_evidence: Initial Evidence: {}".format(evidence1)) + log.debug(f"fit_1d_fringes_bayes_evidence: Initial Evidence: {evidence1}") for f in np.arange(max_nfringes): - log.debug( - "Starting fringe {}".format(f + 1)) + log.debug(f"Starting fringe {f + 1}") # get the scan arrays weights *= col_snr2 @@ -511,15 +577,19 @@ def new_fit_1d_fringes_bayes_evidence(res_fringes, weights, wavenum, ffreq, dffr log.debug("fit_1d_fringes_bayes_evidence: get the periodogram") pgram = LombScargle(wavenum_scan[::-1], res_fringe_scan[::-1]).power(1 / freq) - log.debug("fit_1d_fringes_bayes_evidence: get the most significant frequency in the periodogram") + log.debug( + "fit_1d_fringes_bayes_evidence: get the most significant frequency in the periodogram" + ) peak = np.argmax(pgram) - freqs = 1. / freq[peak] + freqs = 1.0 / freq[peak] # fix the most significant frequency in the fixed dict that is passed to fitter keep_ind = nfringes * 3 keep_dict[keep_ind] = freqs - log.debug("fit_1d_fringes_bayes_evidence: creating multisine model of {} freqs".format(nfringes + 1)) + log.debug( + f"fit_1d_fringes_bayes_evidence: creating multisine model of {nfringes + 1} freqs" + ) mdl = multi_sine(nfringes + 1) # fit the multi-sine model and get evidence @@ -536,7 +606,8 @@ def new_fit_1d_fringes_bayes_evidence(res_fringes, weights, wavenum, ffreq, dffr ftr = RobustShell(fitter, domain=10) pars = ftr.fit(res_fringes, weights=weights) - # try get evidence (may fail for large component fits to noisy data, set to very negative value + # try to get evidence (may fail for large component + # fits to noisy data, set to very negative value) try: evidence2 = fitter.getEvidence(limits=[-3, 10], noiseLimits=[0.001, 10]) except ValueError: @@ -555,7 +626,8 @@ def new_fit_1d_fringes_bayes_evidence(res_fringes, weights, wavenum, ffreq, dffr fitter = LevenbergMarquardtFitter(wavenum, mdl, verbose=0) pars = fitter.fit(res_fringes, weights=weights) - # try get evidence (may fail for large component fits to noisy data, set to very negative value + # try to get evidence (may fail for large component + # fits to noisy data, set to very negative value) try: evidence2 = fitter.getEvidence(limits=[-3, 10], noiseLimits=[0.001, 10]) except ValueError: @@ -563,20 +635,23 @@ def new_fit_1d_fringes_bayes_evidence(res_fringes, weights, wavenum, ffreq, dffr except Exception: evidence2 = -1e9 - log.debug("fit_1d_fringes_bayes_evidence: nfringe={} ev={} chi={}".format(nfringes, evidence2, fitter.chisq)) + log.debug( + f"fit_1d_fringes_bayes_evidence: nfringe={nfringes + 1} " + f"ev={evidence2} chi={fitter.chisq}" + ) bayes_factor = evidence2 - evidence1 - log.debug( - "fit_1d_fringes_bayes_evidence: bayes factor={}".format(bayes_factor)) - if bayes_factor > 1: # strong evidence thresh (log(bayes factor)>1, Kass and Raftery 1995) + log.debug(f"fit_1d_fringes_bayes_evidence: bayes factor={bayes_factor}") + if bayes_factor > 1: + # strong evidence threshold (log(bayes factor)>1, Kass and Raftery 1995) evidence1 = evidence2 best_mdl = mdl.copy() fitted_frequencies.append(freqs) log.debug( - "fit_1d_fringes_bayes_evidence: strong evidence for nfringes={} ".format(nfringes + 1)) + f"fit_1d_fringes_bayes_evidence: strong evidence for nfringes={nfringes + 1} " + ) else: - log.debug( - "fit_1d_fringes_bayes_evidence: no evidence for nfringes={}".format(nfringes + 1)) + log.debug(f"fit_1d_fringes_bayes_evidence: no evidence for nfringes={nfringes + 1}") break # subtract the fringes for this frequency @@ -584,7 +659,7 @@ def new_fit_1d_fringes_bayes_evidence(res_fringes, weights, wavenum, ffreq, dffr res_fringes_proc = res_fringes.copy() - res_fringe_fit nfringes += 1 - log.debug("fit_1d_fringes_bayes_evidence: optimal={} fringes".format(nfringes)) + log.debug(f"fit_1d_fringes_bayes_evidence: optimal={nfringes} fringes") # create outputs to return fitted_frequencies = (1 / np.asarray(fitted_frequencies)) * factor @@ -595,28 +670,29 @@ def new_fit_1d_fringes_bayes_evidence(res_fringes, weights, wavenum, ffreq, dffr return res_fringe_fit, weighted_pix_num, nfringes, peak_freq, freq_min, freq_max -def new_make_knots(flux, nknots=20, weights=None): - """Defines knot positions for piecewise models. This simply splits the array into sections. It does - NOT take into account the shape of the data. - - :Parameters: - - flux: numpy array, required - the flux array or any array of the same dimension - - nknots: int, optional, default=20 - the number of knots to create (excluding 0 and 1023) - - weights: numpy array, optional, default=None - optionally supply a weights array. This will be used to add knots at the edge of bad pixels or features +def make_knots(flux, nknots=20, weights=None): + """ + Define knot positions for piecewise models. - :Returns: + This function simply splits the array into sections. It does + NOT take into account the shape of the data. - knot_idx, numpy array - the indices of the knots + Parameters + ---------- + flux : ndarray + The flux array or any array of the same dimension. + nknots : int, optional + The number of knots to create (excluding 0 and 1023). + weights : ndarray or None, optional + Optionally supply a weights array. This will be used to + add knots at the edge of bad pixels or features. + Returns + ------- + knot_idx : ndarray + The indices of the knots. """ - log.debug("new_make_knots: creating {} knots on flux array".format(nknots)) + log.debug(f"make_knots: creating {nknots} knots on flux array") # handle nans or infs that may exist flux = np.nan_to_num(flux, posinf=1e-08, neginf=1e-08) @@ -653,8 +729,7 @@ def new_make_knots(flux, nknots=20, weights=None): # if the weights array is supplied, determine the edges of good data and set knots there if weights is not None: - - log.debug("new_make_knots: adding knots at edges of bad pixels in weights array") + log.debug("make_knots: adding knots at edges of bad pixels in weights array") # if there are bad pixels in the flux array with flux~0, # add these to weights array if not already there @@ -676,7 +751,6 @@ def new_make_knots(flux, nknots=20, weights=None): # we don't need knots in the bad pixels so ignore these if largest > 1e-03: - # check if the absolute values are almost equal if math.isclose(largest, np.abs(wd), rel_tol=1e-01): # if so, set the index and adjust depending on whether the @@ -701,7 +775,9 @@ def new_make_knots(flux, nknots=20, weights=None): return knot_idx.astype(int) -# RFC1D additions ====================================== + +# The below functions were added to enable residual fringe correction +# in 1D extracted data. # Define some constants describing the two central fringe frequencies # (primary fringe, and dichroic fringe) and a range around them to search for residual fringes, @@ -711,60 +787,46 @@ def new_make_knots(flux, nknots=20, weights=None): MAX_NFRINGES_1d = [10, 15] MAXAMP_1d = 0.2 -# functions -def fit_1d_background_complex_1d(flux, weights, wavenum, order=2, ffreq=None, channel=1, test=False): - """Fit the background signal using a piecewise spline of n knots. Note that this will also try to identify - obvious emission lines and flag them so they aren't considered in the fitting. - Parameters - ---------- - flux : numpy array, required - the 1D array of fluxes - - weights : numpy array, required - the 1D array of weights - - wavenum : numpy array, required - the 1D array of wavenum - - order : int, optional, default=2 - the order of the Splines model +def fit_1d_background_complex_1d(flux, weights, wavenum, ffreq=None): + """ + Fit the background signal using a piecewise spline of n knots. - ffreq : float, optional, default=None - the expected fringe frequency, used to determine number of knots. If None, - defaults to NUM_KNOTS constant + Note that this will also try to identify obvious emission lines and + flag them so they aren't considered in the fitting. - channel : int, optional, default=1 - the channel processed. used to determine if other arrays need to be reversed given the direction of increasing - wavelength down the detector in MIRIFULONG + Parameters + ---------- + flux : ndarray + The 1D array of fluxes. + weights : ndarray + The 1D array of weights. + wavenum : ndarray + The 1D array of wavenum. + ffreq : float or None, optional + The expected fringe frequency, used to determine number of knots. + If None, defaults to NUM_KNOTS constant. Returns ------- - - bg_fit : numpy array - the fitted background - - bgindx : numpy array - the location of the knots - - fitter : BayesicFitting object - fitter object, mainly used for testing - + bg_fit : ndarray + The fitted background. + bgindx : ndarray + The location of the knots. """ - # first get the weighted pixel fraction weighted_pix_frac = (weights > 1e-05).sum() / flux.shape[0] # define number of knots using fringe freq, want 1 knot per period if ffreq is not None: - log.debug("fit_1d_background_complex: knot positions for {} cm-1".format(ffreq)) + log.debug(f"fit_1d_background_complex: knot positions for {ffreq} cm-1") nknots = int((np.amax(wavenum) - np.amin(wavenum)) / (ffreq)) else: - log.debug("fit_1d_background_complex: using num_knots={}".format(NUM_KNOTS)) + log.debug(f"fit_1d_background_complex: using num_knots={NUM_KNOTS}") nknots = int((flux.shape[0] / 1024) * NUM_KNOTS) - log.debug("fit_1d_background_complex: number of knots = {}".format(nknots)) + log.debug(f"fit_1d_background_complex: number of knots = {nknots}") # recale wavenums to around 1 for bayesicfitting factor = np.amin(wavenum) @@ -772,32 +834,31 @@ def fit_1d_background_complex_1d(flux, weights, wavenum, order=2, ffreq=None, ch # get number of fringe periods in array nper = (np.amax(wavenum) - np.amin(wavenum)) // ffreq - log.debug("fit_1d_background_complex: column is {} fringe periods".format(nper)) + log.debug(f"fit_1d_background_complex: column is {nper} fringe periods") # now reduce by the weighted pixel fraction to see how many can be fitted nper_cor = int(nper * weighted_pix_frac) - log.debug("fit_1d_background_complex: column has {} weighted fringe periods".format(nper_cor)) + log.debug(f"fit_1d_background_complex: column has {nper_cor} weighted fringe periods") # require at least 5 sine periods to fit if nper < 5: log.info(" not enough weighted data, no fit performed") return flux.copy(), np.zeros(flux.shape[0]), None - bgindx = new_make_knots(flux.copy(), int(nknots), weights=weights.copy()) + bgindx = make_knots(flux.copy(), int(nknots), weights=weights.copy()) bgknots = wavenum_scaled[bgindx].astype(float) # Reverse (and clip) the fit data as scipy/astropy need monotone increasing data for SW detector - t = bgknots[::-1][1:-1] x = wavenum_scaled[::-1] y = flux[::-1] w = weights[::-1] # Fit the spline - # TODO: robust fitting causing problems for fringe 2, change to just using fitter there if ffreq > 1.5: bg_model = spline_fitter(x, y, w, t, 2, reject_outliers=True) else: + # robust fitting causing problems for fringe 2, change to just using fitter there bg_model = spline_fitter(x, y, w, t, 1, reject_outliers=False) # fit the background @@ -815,41 +876,39 @@ def fit_1d_background_complex_1d(flux, weights, wavenum, order=2, ffreq=None, ch return bg_fit, bgindx -def new_fit_1d_fringes_bayes_evidence_1d(res_fringes, weights, wavenum, ffreq, dffreq, min_nfringes, max_nfringes, - pgram_res): - """Fit the residual fringe signal.- 1d version - Takes an input 1D array of residual fringes and fits using the supplied mode in the BayesicFitting package. +def fit_1d_fringes_bayes_evidence_1d( + res_fringes, weights, wavenum, ffreq, dffreq, max_nfringes, pgram_res +): + """ + Fit the residual fringe signal in 1D. + + Takes an input 1D array of residual fringes and fits them using + the supplied mode in the BayesicFitting package. Parameters ---------- - res_fringes : numpy array, required - the 1D array with residual fringes - weights : numpy array, required - the 1D array of weights - ffreq : float, required - the central scan frequency - dffreq : float, required - the one-sided interval of scan frequencies - min_nfringes : int, required - the minimum number of fringes to check - max_nfringes : int, required - the maximum number of fringes to check - pgram_res : float, optional - resolution of the periodogram scan in cm-1 - wavenum : numpy array, required - the 1D array of wavenum + res_fringes : ndarray + The 1D array with residual fringes. + weights : ndarray + The 1D array of weights. + wavenum : ndarray + The 1D array of wavenum. + ffreq : float + The central scan frequency. + dffreq : float + The one-sided interval of scan frequencies. + max_nfringes : int + The maximum number of fringes to check. + pgram_res : float + Resolution of the periodogram scan in cm-1. Returns ------- - res_fringe_fit : numpy array - the residual fringe fit data + res_fringe_fit : ndarray + The residual fringe fit data. """ # initialize output to none res_fringe_fit = None - weighted_pix_num = None - peak_freq = None - freq_min = None - freq_max = None # get the number of weighted pixels weighted_pix_num = (weights > 1e-05).sum() @@ -872,7 +931,6 @@ def new_fit_1d_fringes_bayes_evidence_1d(res_fringes, weights, wavenum, ffreq, d res_fringes_proc = res_fringes.copy() nfringes = 0 keep_dict = {} - best_mdl = None fitted_frequencies = [] # get the initial evidence from ConstantModel @@ -881,7 +939,7 @@ def new_fit_1d_fringes_bayes_evidence_1d(res_fringes, weights, wavenum, ffreq, d _ = sftr.fit(res_fringes, weights=weights) evidence1 = sftr.getEvidence(limits=[-2, 1000], noiseLimits=[0.001, 1]) - for n in np.arange(max_nfringes): + for _ in range(max_nfringes): # get the scan arrays res_fringe_scan = res_fringes_proc[np.where(weights > 1e-05)] wavenum_scan = wavenum[np.where(weights > 1e-05)] @@ -890,7 +948,7 @@ def new_fit_1d_fringes_bayes_evidence_1d(res_fringes, weights, wavenum, ffreq, d pgram = LombScargle(wavenum_scan[::-1], res_fringe_scan[::-1]).power(1 / freq) peak = np.argmax(pgram) - freqs = 1. / freq[peak] + freqs = 1.0 / freq[peak] # fix the most significant frequency in the fixed dict that is passed to fitter keep_ind = nfringes * 3 @@ -911,7 +969,8 @@ def new_fit_1d_fringes_bayes_evidence_1d(res_fringes, weights, wavenum, ffreq, d ftr = RobustShell(fitter, domain=10) pars = ftr.fit(res_fringes, weights=weights) - # try get evidence (may fail for large component fits to noisy data, set to very negative value + # try to get evidence (may fail for large component + # fits to noisy data, set to very negative value) try: evidence2 = fitter.getEvidence(limits=[-2, 1000], noiseLimits=[0.001, 1]) except ValueError: @@ -940,28 +999,29 @@ def new_fit_1d_fringes_bayes_evidence_1d(res_fringes, weights, wavenum, ffreq, d return res_fringe_fit, weighted_pix_num, nfringes, peak_freq, freq_min, freq_max + def fit_residual_fringes_1d(flux, wavelength, channel=1, dichroic_only=False, max_amp=None): - """This is the wrapper function for 1d residual fringe correction. + """ + Fit residual fringes in 1D. Parameters ---------- - flux : numpy array, required - The 1D array of fluxes - wavelength : numpy array, required - The 1D array of wavelengths - channel : integer, optional - The MRS spectral channel - dichroic_only : boolean, optional - Fit only dichroic fringes - max_amp : numpy array, optional - The maximum amplitude array + flux : ndarray + The 1D array of fluxes. + wavelength : ndarray + The 1D array of wavelengths. + channel : int, optional + The MRS spectral channel. + dichroic_only : bool, optional + Fit only dichroic fringes. + max_amp : ndarray, optional + The maximum amplitude array. Returns ------- - output : numpy array - Modified version of input flux array + output : ndarray + Modified version of input flux array. """ - # Restrict to just the non-zero positive fluxes indx = np.where(flux > 0) useflux = flux[indx] @@ -992,7 +1052,9 @@ def fit_residual_fringes_1d(flux, wavelength, channel=1, dichroic_only=False, ma if dichroic_only is True: if channel not in [3, 4]: - raise ValueError('Dichroic fringe should only be removed from channels 3 and 4, stopping!') + raise ValueError( + "Dichroic fringe should only be removed from channels 3 and 4, stopping!" + ) ffreq_vals = [FFREQ_1d[1]] dffreq_vals = [DFFREQ_1d[1]] @@ -1028,36 +1090,37 @@ def fit_residual_fringes_1d(flux, wavelength, channel=1, dichroic_only=False, ma for m, proc_data in enumerate(proc_arr): for n, ffreq in enumerate(ffreq_vals): - - # fit background - try: - bg_fit, bgindx = fit_1d_background_complex_1d(proc_data, weights_feat, - wavenum, ffreq=ffreq, channel=1) - except Exception as e: - raise e + bg_fit, bgindx = fit_1d_background_complex_1d( + proc_data, weights_feat, wavenum, ffreq=ffreq + ) # get the residual fringes as fraction of signal - res_fringes = np.divide(proc_data, bg_fit, out=np.zeros_like(proc_data), - where=bg_fit != 0) + res_fringes = np.divide( + proc_data, bg_fit, out=np.zeros_like(proc_data), where=bg_fit != 0 + ) res_fringes = np.subtract(res_fringes, 1, where=res_fringes != 0) res_fringes *= np.where(weights > 1e-07, 1, 1e-08) # fit the residual fringes - try: - res_fringe_fit, wpix_num, opt_nfringes, peak_freq, freq_min, freq_max = new_fit_1d_fringes_bayes_evidence_1d( - res_fringes, weights_feat, - wavenum, ffreq, dffreq_vals[n], min_nfringes=0, - max_nfringes=max_nfringes_vals[n], pgram_res=0.001) - - except Exception as e: - raise e + res_fringe_fit, wpix_num, opt_nfringes, peak_freq, freq_min, freq_max = ( + fit_1d_fringes_bayes_evidence_1d( + res_fringes, + weights_feat, + wavenum, + ffreq, + dffreq_vals[n], + max_nfringes_vals[n], + 0.001, + ) + ) # check for fit blowing up, reset rfc fit to 0, raise a flag res_fringe_fit, res_fringe_fit_flag = check_res_fringes(res_fringe_fit, max_amp) # correct for residual fringes - _, _, _, env, u_x, u_y = fit_envelope(np.arange(res_fringe_fit.shape[0]), - res_fringe_fit) + _, _, _, env, u_x, u_y = fit_envelope( + np.arange(res_fringe_fit.shape[0]), res_fringe_fit + ) rfc_factors = 1 / (res_fringe_fit * (weights > 1e-05).astype(int) + 1) proc_data *= rfc_factors From b1b801330ede18b7292f8f322bcab1e0cb8d2808 Mon Sep 17 00:00:00 2001 From: Melanie Clarke Date: Tue, 4 Mar 2025 15:49:51 -0500 Subject: [PATCH 6/7] Add change note --- changes/9242.residual_fringe.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 changes/9242.residual_fringe.rst diff --git a/changes/9242.residual_fringe.rst b/changes/9242.residual_fringe.rst new file mode 100644 index 0000000000..3114820cff --- /dev/null +++ b/changes/9242.residual_fringe.rst @@ -0,0 +1 @@ +Refactor for maintainability, fix intermediate filenames when input datamodel is read from memory, and stop producing an unused intermediate output file. From c018352278227f2bf993716cbe44eaaf4dfe3e19 Mon Sep 17 00:00:00 2001 From: Melanie Clarke Date: Wed, 5 Mar 2025 13:07:59 -0500 Subject: [PATCH 7/7] Minor edits from review comments --- jwst/residual_fringe/residual_fringe.py | 185 +++++++++--------- .../tests/test_configuration.py | 2 +- .../tests/test_residual_fringe.py | 10 +- jwst/residual_fringe/utils.py | 10 +- 4 files changed, 107 insertions(+), 100 deletions(-) diff --git a/jwst/residual_fringe/residual_fringe.py b/jwst/residual_fringe/residual_fringe.py index 9d6268a2f8..a3cf628655 100644 --- a/jwst/residual_fringe/residual_fringe.py +++ b/jwst/residual_fringe/residual_fringe.py @@ -16,6 +16,10 @@ log = logging.getLogger(__name__) log.setLevel(logging.DEBUG) +# Noise factor for DER_SNR spectroscopic signal-to-noise calculation +# (see Stoehr, ADASS 2008: https://archive.stsci.edu/vodocs/der_snr.pdf) +DER_SNR_FACTOR = 1.482602 / np.sqrt(6) + class ResidualFringeCorrection: """Calculate and apply correction for residual fringes.""" @@ -42,10 +46,10 @@ def __init__( regions_reference_file : str Path to REGIONS reference file. ignore_regions : dict - Wavelength regions to ignore. Keys are "num", "min", and "max. + Wavelength regions to ignore. Keys are "num", "min", and "max". Values are the number of regions specified (int), the list of minimum wavelength values, and the list of maximum wavelength - values. Minimum and maximum lists must match. + values. Length of minimum and maximum lists must match. save_intermediate_results : bool, optional If True, intermediate files are saved to disk. transmission_level : int, optional @@ -90,9 +94,7 @@ def __init__( def do_correction(self): """ - Apply residual fringe correction. - - Correction is applied to a model copied from self.input_model. + Apply residual fringe correction to a copy of self.input_model. Returns ------- @@ -267,17 +269,17 @@ def do_correction(self): # reasonable signal. If the SNR < min_snr (CDP), pass n = len(test_flux) signal = np.nanmean(test_flux) - noise = 0.6052697 * np.nanmedian( + noise = DER_SNR_FACTOR * np.nanmedian( np.abs(2.0 * test_flux[2 : n - 2] - test_flux[0 : n - 4] - test_flux[4:n]) ) - snr2 = 0.0 # initialize + snr2 = 0.0 if noise != 0: snr2 = signal / noise # Sometimes can return nan, inf for bad data so include this in check if snr2 < min_snr[0]: - log.debug(f"SNR too low not fitting column {col}, {snr2}, {min_snr[0]}") + log.debug(f"SNR too low; not fitting column {col}, {snr2}, {min_snr[0]}") continue log.debug(f"Fitting column {col}") @@ -372,89 +374,90 @@ def do_correction(self): try: for fn, ff in enumerate(ffreq): # ignore place holder fringes - if ff > 1e-03: - log.debug(f" Start ffreq = {ff}") - - # check if snr criteria is met for fringe component, - # should always be true for fringe 1 - if snr2 > min_snr[fn]: - log.debug(" Fit spectral baseline") - - bg_fit, bgindx = utils.fit_1d_background_complex( - proc_data, - weights_feat, - col_wnum, - ffreq=ffreq[fn], - channel=c, - ) - - # get the residual fringes as fraction of signal - res_fringes = np.divide( - proc_data, - bg_fit, - out=np.zeros_like(proc_data), - where=bg_fit != 0, - ) - res_fringes = np.subtract( - res_fringes, 1, where=res_fringes != 0 - ) - res_fringes *= np.where(col_weight > 1e-07, 1, 1e-08) - - # fit the residual fringes - log.debug(" Set up Bayes evidence") - ( - res_fringe_fit, - wpix_num, - opt_nfringe, - peak_freq, - freq_min, - freq_max, - ) = utils.fit_1d_fringes_bayes_evidence( - res_fringes, - weights_feat, - col_wnum, - ffreq[fn], - dffreq[fn], - max_nfringes[fn], - pgram_res[fn], - col_snr2, - ) - - # check for fit blowing up, reset rfc fit to 0, raise a flag - log.debug(" Check residual fringe fit for bad fit regions") - res_fringe_fit, res_fringe_fit_flag = utils.check_res_fringes( - res_fringe_fit, col_max_amp - ) - - # correct for residual fringes - log.debug(" Divide out residual fringe fit") - _, _, _, env, u_x, u_y = utils.fit_envelope( - np.arange(res_fringe_fit.shape[0]), res_fringe_fit - ) - - rfc_factors = 1 / ( - res_fringe_fit * (col_weight > 1e-05).astype(int) + 1 - ) - proc_data *= rfc_factors - proc_factors *= rfc_factors - - # handle nans or infs that may exist - proc_data = np.nan_to_num(proc_data, posinf=1e-08, neginf=1e-08) - proc_data[proc_data < 0] = 1e-08 - - out_table.add_row( - ( - ss, - col, - fn, - snr2, - pgram_res[fn], - opt_nfringe, - peak_freq, - freq_min, - freq_max, - ) - ) + if ff <= 1e-03: + continue + + # check if snr criteria is met for fringe component, + # should always be true for fringe 1 + if snr2 <= min_snr[fn]: + continue + + log.debug(f" Start ffreq = {ff}") + log.debug(" Fit spectral baseline") + + bg_fit, bgindx = utils.fit_1d_background_complex( + proc_data, + weights_feat, + col_wnum, + ffreq=ffreq[fn], + channel=c, + ) + + # get the residual fringes as fraction of signal + res_fringes = np.divide( + proc_data, + bg_fit, + out=np.zeros_like(proc_data), + where=bg_fit != 0, + ) + res_fringes = np.subtract(res_fringes, 1, where=res_fringes != 0) + res_fringes *= np.where(col_weight > 1e-07, 1, 1e-08) + + # fit the residual fringes + log.debug(" Set up Bayes evidence") + ( + res_fringe_fit, + wpix_num, + opt_nfringe, + peak_freq, + freq_min, + freq_max, + ) = utils.fit_1d_fringes_bayes_evidence( + res_fringes, + weights_feat, + col_wnum, + ffreq[fn], + dffreq[fn], + max_nfringes[fn], + pgram_res[fn], + col_snr2, + ) + + # check for fit blowing up, reset rfc fit to 0, raise a flag + log.debug(" Check residual fringe fit for bad fit regions") + res_fringe_fit, res_fringe_fit_flag = utils.check_res_fringes( + res_fringe_fit, col_max_amp + ) + + # correct for residual fringes + log.debug(" Divide out residual fringe fit") + _, _, _, env, u_x, u_y = utils.fit_envelope( + np.arange(res_fringe_fit.shape[0]), res_fringe_fit + ) + + rfc_factors = 1 / ( + res_fringe_fit * (col_weight > 1e-05).astype(int) + 1 + ) + proc_data *= rfc_factors + proc_factors *= rfc_factors + + # handle nans or infs that may exist + proc_data = np.nan_to_num(proc_data, posinf=1e-08, neginf=1e-08) + proc_data[proc_data < 0] = 1e-08 + + out_table.add_row( + ( + ss, + col, + fn, + snr2, + pgram_res[fn], + opt_nfringe, + peak_freq, + freq_min, + freq_max, + ) + ) # define fringe sub after all fringe components corrections fringe_sub = proc_data.copy() diff --git a/jwst/residual_fringe/tests/test_configuration.py b/jwst/residual_fringe/tests/test_configuration.py index 05b0bd0440..0e3bd9987a 100644 --- a/jwst/residual_fringe/tests/test_configuration.py +++ b/jwst/residual_fringe/tests/test_configuration.py @@ -30,7 +30,7 @@ def miri_image(): @pytest.fixture() def step_log_watcher(monkeypatch): # Set a log watcher to check for a log message at any level - # in the emicorr step + # in the residual_fringe step watcher = LogWatcher("") logger = logging.getLogger("stpipe.ResidualFringeStep") for level in ["debug", "info", "warning", "error"]: diff --git a/jwst/residual_fringe/tests/test_residual_fringe.py b/jwst/residual_fringe/tests/test_residual_fringe.py index b39f8e6d80..02b8c6ca8c 100644 --- a/jwst/residual_fringe/tests/test_residual_fringe.py +++ b/jwst/residual_fringe/tests/test_residual_fringe.py @@ -123,7 +123,7 @@ def one_slice(*args): @pytest.fixture() def module_log_watcher(monkeypatch): # Set a log watcher to check for a log message at any level - # in the emicorr module + # in the residual_fringe module watcher = LogWatcher("") logger = logging.getLogger("jwst.residual_fringe.residual_fringe") for level in ["debug", "info", "warning", "error"]: @@ -157,8 +157,12 @@ def test_rf1d(linear_spectrum, fringed_spectrum): def test_get_wavemap(): - # Test the _get_wavemap function directly, since - # all full calls to the correction method mock it + """ + Test the _get_wavemap function directly. + + A separate test is needed, since calls to the higher level correction method + mock this function for synthetic data simplicity. + """ model = datamodels.IFUImageModel() # Mock a WCS that returns 1 for wavelengths diff --git a/jwst/residual_fringe/utils.py b/jwst/residual_fringe/utils.py index 5eb895586d..f1ab2efc00 100644 --- a/jwst/residual_fringe/utils.py +++ b/jwst/residual_fringe/utils.py @@ -41,12 +41,12 @@ def slice_info(slice_map, channel): slices_in_channel : ndarray of int 1D array of slice IDs included in the channel. xrange_channel : ndarray of int - 1D array with two elements: minimum and maximum x values + 1D array with two elements: minimum and maximum x indices for the channel. slice_x_ranges : ndarray of int N x 3 array for N slices, where the first column is the slice ID, - second column is the minimum x value for the slice, - and the third column is the maximum x value for the slice. + second column is the minimum x index for the slice, + and the third column is the maximum x index for the slice. all_slice_masks : ndarray of int N x nx x ny for N slices, matching the x and y shape of the input slice_map. Values are 1 for pixels included in the slice, @@ -137,7 +137,7 @@ def fill_wavenumbers(wnums): def multi_sine(n_sines): """ - Create a mult-sine model. + Create a multi-sine model. Parameters ---------- @@ -146,7 +146,7 @@ def multi_sine(n_sines): Returns ------- - model : SineModel + model : BayesicFitting.SineModel The model composed of n sines. """ # make the first sine