Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JP-3900: Tests, clean up, and code style for residual_fringe #9242

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ repos:
jwst/refpix/.* |
jwst/resample/.* |
jwst/reset/.* |
jwst/residual_fringe/.* |
jwst/rscd/.* |
jwst/saturation/.* |
jwst/scripts/.* |
Expand Down
2 changes: 0 additions & 2 deletions .ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ exclude = [
"jwst/refpix/**.py",
"jwst/resample/**.py",
"jwst/reset/**.py",
"jwst/residual_fringe/**.py",
"jwst/rscd/**.py",
"jwst/saturation/**.py",
"jwst/scripts/**.py",
Expand Down Expand Up @@ -140,7 +139,6 @@ ignore-fully-untyped = true # Turn off annotation checking for fully untyped co
"jwst/refpix/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"]
"jwst/resample/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"]
"jwst/reset/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"]
"jwst/residual_fringe/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"]
"jwst/rscd/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"]
"jwst/saturation/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"]
"jwst/scripts/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"]
Expand Down
1 change: 1 addition & 0 deletions changes/9242.residual_fringe.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor for maintainability, fix intermediate filenames when input datamodel is read from memory, and stop producing an unused intermediate output file.
27 changes: 27 additions & 0 deletions jwst/regtest/test_miri_mrs_spec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,24 @@ def run_spec2(rtdata_module):
Step.from_cmdline(args)


@pytest.fixture(scope='module')
def run_spec2_with_residual_fringe(rtdata_module):
"""Run the Spec2Pipeline on a single exposure"""
rtdata = rtdata_module

# Get the input rate file
rtdata.get_data(INPUT_PATH + '/' + 'jw01024001001_04101_00001_mirifulong_rate.fits')

# Run the pipeline
args = ["calwebb_spec2", rtdata.input,
'--output_file=jw01024001001_04101_00001_mirifulong_rf',
'--steps.residual_fringe.skip=false',
'--steps.residual_fringe.save_results=true',
]

Step.from_cmdline(args)


@pytest.mark.slow
@pytest.mark.bigdata
@pytest.mark.parametrize(
Expand Down Expand Up @@ -83,3 +101,12 @@ def test_miri_mrs_wcs(run_spec2, fitsdiff_default_kwargs, rtdata_module):
xtruth, ytruth = im_truth.meta.wcs.backward_transform(ratruth, dectruth, lamtruth)
assert_allclose(xtest, xtruth)
assert_allclose(ytest, ytruth)


@pytest.mark.slow
@pytest.mark.bigdata
@pytest.mark.parametrize('suffix', ['residual_fringe', 'rf_s3d', 'rf_x1d', 'rf_cal'])
def test_miri_mrs_spec2_with_rf(run_spec2_with_residual_fringe,
fitsdiff_default_kwargs, suffix, rtdata_module):
rtdata = rtdata_module
rt.is_like_truth(rtdata, fitsdiff_default_kwargs, suffix, truth_path=TRUTH_PATH)
4 changes: 3 additions & 1 deletion jwst/residual_fringe/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Correct residual fringes in MIRI MRS data."""

from .residual_fringe_step import ResidualFringeStep

__all__ = ['ResidualFringeStep']
__all__ = ["ResidualFringeStep"]
38 changes: 34 additions & 4 deletions jwst/residual_fringe/fitter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import warnings

import numpy as np
import scipy.interpolate


Expand All @@ -8,6 +9,34 @@ def _lsq_spline(x, y, weights, knots, degree):


def spline_fitter(x, y, weights, knots, degree, reject_outliers=False, domain=10, tolerance=0.0001):
"""
Fit a spline function to 1D data.

Parameters
----------
x : ndarray
Independent variable. Must be increasing.
y : ndarray
Dependent variable, matching dimensions of `x`.
weights : ndarray
Weights for spline fitting. Must be positive.
knots : ndarray
Interior knots for the spline.
degree : int
Degree of the spline to fit, >= 1 and <= 5.
reject_outliers : bool, optional
If True, iteratively fit the data with outlier rejection.
domain : int, optional
Factor controlling the outlier threshold when `reject_outliers`
is True.
tolerance : float, optional
Fit convergence tolerance when reject_outliers` is True.

Returns
-------
callable
The spline function fit to the data.
"""
if not reject_outliers:
return _lsq_spline(x, y, weights, knots, degree)

Expand All @@ -16,7 +45,6 @@ def spline_fitter(x, y, weights, knots, degree, reject_outliers=False, domain=10
def chi_sq(spline, weights):
return np.nansum((y - spline(x)) ** 2 * weights)


# initial fit
spline = _lsq_spline(x, y, weights, knots, degree)
chi = chi_sq(spline, weights)
Expand All @@ -26,11 +54,13 @@ def chi_sq(spline, weights):
deg_of_freedom = np.sum(weights) - nparams

for _ in range(1000 * nparams):
scale = np.sqrt(chi / deg_of_freedom)
with warnings.catch_warnings():
warnings.simplefilter("ignore", RuntimeWarning)
scale = np.sqrt(chi / deg_of_freedom)

# Calculate new weights
resid = (y - spline(x)) / (scale * domain)
new_w = (np.where(resid**2 <= 1, 1 - resid**2, 0.))**2 * weights
new_w = (np.where(resid**2 <= 1, 1 - resid**2, 0.0)) ** 2 * weights

# Fit new model and find chi
spline = _lsq_spline(x, y, new_w, knots, degree)
Expand Down
Loading
Loading