Skip to content

Commit

Permalink
JP-3848 MIRI LRS s_region and resample WCS (#9193)
Browse files Browse the repository at this point in the history
  • Loading branch information
melanieclarke authored Mar 7, 2025
2 parents 9d288dd + 8a43968 commit a5a4482
Show file tree
Hide file tree
Showing 10 changed files with 233 additions and 79 deletions.
1 change: 1 addition & 0 deletions changes/9193.assign_wcs.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix MIRI LRS s_region in assign_wcs
1 change: 1 addition & 0 deletions changes/9193.resample.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix MIRI LRS s_region and WCS in resample_spec
12 changes: 9 additions & 3 deletions jwst/assign_wcs/assign_wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import importlib
from gwcs.wcs import WCS
from .util import (update_s_region_spectral, update_s_region_imaging,
update_s_region_nrs_ifu, update_s_region_mrs)
update_s_region_nrs_ifu, update_s_region_mrs,
update_s_region_lrs)
from ..lib.exposure_types import IMAGING_TYPES, SPEC_TYPES, NRS_LAMP_MODE_SPEC_TYPES
from ..lib.dispaxis import get_dispersion_direction
from ..lib.wcs_utils import get_wavelengths
Expand Down Expand Up @@ -72,8 +73,13 @@ def load_wcs(input_model, reference_files={}, nrs_slit_y_range=None):

if output_model.meta.exposure.type.lower() not in exclude_types:
imaging_types = IMAGING_TYPES.copy()
imaging_types.update(['mir_lrs-fixedslit', 'mir_lrs-slitless'])
if output_model.meta.exposure.type.lower() in imaging_types:
imaging_types.update(['mir_lrs-slitless'])
imaging_lrs_types = ['mir_lrs-fixedslit']
if output_model.meta.exposure.type.lower() in imaging_lrs_types:
# uses slits corners in V2, V3 that are read in from the
# lrs specwcs reference file
update_s_region_lrs(output_model, reference_files)
elif output_model.meta.exposure.type.lower() in imaging_types:
try:
update_s_region_imaging(output_model)
except Exception as exc:
Expand Down
68 changes: 28 additions & 40 deletions jwst/assign_wcs/miri.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@
from astropy.modeling import models
from astropy import coordinates as coord
from astropy import units as u
from astropy.io import fits

from scipy.interpolate import UnivariateSpline
import gwcs.coordinate_frames as cf
from gwcs import selector

from stdatamodels.jwst.datamodels import (DistortionModel, FilteroffsetModel,
DistortionMRSModel, WavelengthrangeModel,
RegionsModel, SpecwcsModel)
RegionsModel, SpecwcsModel, MiriLRSSpecwcsModel)
from stdatamodels.jwst.transforms.models import (MIRI_AB2Slice, IdealToV2V3)

from . import pointing
Expand Down Expand Up @@ -239,7 +237,6 @@ def lrs_xytoabl(input_model, reference_files):
the "specwcs" and "distortion" reference files.
"""

# subarray to full array transform
subarray2full = subarray_transform(input_model)

Expand All @@ -253,19 +250,13 @@ def lrs_xytoabl(input_model, reference_files):
else:
subarray_dist = distortion

ref = fits.open(reference_files['specwcs'])

with ref:
lrsdata = np.array([d for d in ref[1].data])
# Get the zero point from the reference data.
# The zero_point is X, Y (which should be COLUMN, ROW)
# These are 1-indexed in CDP-7 (i.e., SIAF convention) so must be converted to 0-indexed
if input_model.meta.exposure.type.lower() == 'mir_lrs-fixedslit':
zero_point = ref[0].header['imx'] - 1, ref[0].header['imy'] - 1
elif input_model.meta.exposure.type.lower() == 'mir_lrs-slitless':
zero_point = ref[0].header['imxsltl'] - 1, ref[0].header['imysltl'] - 1
# Transform to slitless subarray from full array
zero_point = subarray2full.inverse(zero_point[0], zero_point[1])
refmodel = MiriLRSSpecwcsModel(reference_files['specwcs'])
if input_model.meta.exposure.type.lower() == 'mir_lrs-fixedslit':
zero_point = refmodel.meta.x_ref - 1, refmodel.meta.y_ref - 1
elif input_model.meta.exposure.type.lower() == 'mir_lrs-slitless':
zero_point = refmodel.meta.x_ref_slitless - 1, refmodel.meta.y_ref_slitless - 1
# Transform to slitless subarray from full array
zero_point = subarray2full.inverse(zero_point[0], zero_point[1])

# Figure out the typical along-slice pixel scale at the center of the slit
v2_cen, v3_cen = subarray_dist(zero_point[0], zero_point[1])
Expand All @@ -276,14 +267,14 @@ def lrs_xytoabl(input_model, reference_files):
# centroid trace along the detector in pixels relative to nominal location.
# x0,y0(ul) x1,y1 (ur) x2,y2(lr) x3,y3(ll) define corners of the box within which the distortion
# and wavelength calibration was derived
xcen = lrsdata[:, 0]
ycen = lrsdata[:, 1]
wavetab = lrsdata[:, 2]
x0 = lrsdata[:, 3]
y0 = lrsdata[:, 4]
x1 = lrsdata[:, 5]
y2 = lrsdata[:, 8]

xcen = refmodel.wavetable.x_center
ycen = refmodel.wavetable.y_center
wavetab = refmodel.wavetable.wavelength
x0 = refmodel.wavetable.x0
y0 = refmodel.wavetable.y0
x1 = refmodel.wavetable.x1
y2 = refmodel.wavetable.y2
refmodel.close()
# If in fixed slit mode, define the bounding box using the corner locations provided in
# the CDP reference file.
if input_model.meta.exposure.type.lower() == 'mir_lrs-fixedslit':
Expand Down Expand Up @@ -313,7 +304,6 @@ def lrs_xytoabl(input_model, reference_files):
# This function will give slit dX as a function of Y subarray pixel value
dxmodel = models.Tabular1D(lookup_table=xshiftref, points=ycen_subarray, name='xshiftref',
bounds_error=False, fill_value=np.nan)

if input_model.meta.exposure.type.lower() == 'mir_lrs-fixedslit':
bb_sub = (bb_sub[0], (dxmodel.points[0].min(), dxmodel.points[0].max()))
# Fit for the wavelength as a function of Y
Expand All @@ -325,7 +315,6 @@ def lrs_xytoabl(input_model, reference_files):
# This model will now give the wavelength corresponding to a given Y subarray pixel value
wavemodel = models.Tabular1D(lookup_table=wavereference, points=ycen_subarray, name='waveref',
bounds_error=False, fill_value=np.nan)

# Wavelength barycentric correction
try:
velosys = input_model.meta.wcsinfo.velosys
Expand Down Expand Up @@ -383,6 +372,7 @@ def lrs_xytoabl(input_model, reference_files):

return dettoabl


def lrs_abltov2v3l(input_model, reference_files):
"""
The second part of LRS-FIXEDSLIT and LRS-SLITLESS WCS pipeline.
Expand All @@ -405,19 +395,16 @@ def lrs_abltov2v3l(input_model, reference_files):
else:
subarray_dist = distortion

ref = fits.open(reference_files['specwcs'])

with ref:
# Get the zero point from the reference data.
# The zero_point is X, Y (which should be COLUMN, ROW)
# These are 1-indexed in CDP-7 (i.e., SIAF convention) so must be converted to 0-indexed
if input_model.meta.exposure.type.lower() == 'mir_lrs-fixedslit':
zero_point = ref[0].header['imx'] - 1, ref[0].header['imy'] - 1
elif input_model.meta.exposure.type.lower() == 'mir_lrs-slitless':
zero_point = ref[0].header['imxsltl'] - 1, ref[0].header['imysltl'] - 1
# Transform to slitless subarray from full array
zero_point = subarray2full.inverse(zero_point[0], zero_point[1])

refmodel = MiriLRSSpecwcsModel(reference_files['specwcs'])
if input_model.meta.exposure.type.lower() == 'mir_lrs-fixedslit':
zero_point = refmodel.meta.x_ref - 1, refmodel.meta.y_ref - 1
elif input_model.meta.exposure.type.lower() == 'mir_lrs-slitless':
zero_point = refmodel.meta.x_ref_slitless - 1, \
refmodel.meta.y_ref_slitless - 1
# Transform to slitless subarray from full array
zero_point = subarray2full.inverse(zero_point[0], zero_point[1])

refmodel.close()
# Figure out the typical along-slice pixel scale at the center of the slit
v2_cen, v3_cen = subarray_dist(zero_point[0], zero_point[1])
v2_off, v3_off = subarray_dist(zero_point[0] + 1, zero_point[1])
Expand Down Expand Up @@ -447,6 +434,7 @@ def lrs_abltov2v3l(input_model, reference_files):

return abl_to_v2v3l


def ifu(input_model, reference_files):
"""
The MIRI MRS WCS pipeline.
Expand Down
51 changes: 49 additions & 2 deletions jwst/assign_wcs/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
from stpipe.exceptions import StpipeExitException
from stcal.alignment.util import compute_s_region_keyword, compute_s_region_imaging

from stdatamodels.jwst.datamodels import WavelengthrangeModel
from stdatamodels.jwst.datamodels import WavelengthrangeModel, MiriLRSSpecwcsModel
from stdatamodels.jwst.transforms.models import GrismObject

from ..lib.catalog_utils import SkyObject
from jwst.lib.catalog_utils import SkyObject


log = logging.getLogger(__name__)
Expand Down Expand Up @@ -807,6 +807,51 @@ def update_s_region_imaging(model):
model.meta.wcsinfo.s_region = s_region


def update_s_region_lrs(model, reference_files):
"""
Update ``S_REGION`` using V2,V3 of the slit corners from reference file.
s_region for model is updated in place.
Parameters
----------
model : DataModel
Input model
reference_files : list
List of reference files for assign_wcs.
"""
refmodel = MiriLRSSpecwcsModel(reference_files['specwcs'])

v2vert1 = refmodel.meta.v2_vert1
v2vert2 = refmodel.meta.v2_vert2
v2vert3 = refmodel.meta.v2_vert3
v2vert4 = refmodel.meta.v2_vert4

v3vert1 = refmodel.meta.v3_vert1
v3vert2 = refmodel.meta.v3_vert2
v3vert3 = refmodel.meta.v3_vert3
v3vert4 = refmodel.meta.v3_vert4

refmodel.close()
v2 = [v2vert1, v2vert2, v2vert3, v2vert4]
v3 = [v3vert1, v3vert2, v3vert3, v3vert4]

if (any(elem is None for elem in v2) or
any(elem is None for elem in v3)):
log.info("The V2,V3 coordinates of the MIRI LRS-Fixed slit contains NaN values.")
log.info("The s_region will not be updated")

lam = 7.0 # wavelength does not matter for s region so just assign a value in range of LRS
s = model.meta.wcs.transform('v2v3', 'world', v2, v3, lam)
a = s[0]
b = s[1]
footprint = np.array([[a[0], b[0]],
[a[1], b[1]],
[a[2], b[2]],
[a[3], b[3]]])

update_s_region_keyword(model, footprint)

def compute_footprint_spectral(model):
"""
Determine spatial footprint for spectral observations using the instrument model.
Expand Down Expand Up @@ -844,9 +889,11 @@ def compute_footprint_spectral(model):
return footprint, (lam_min, lam_max)



def update_s_region_spectral(model):
""" Update the S_REGION keyword.
"""

footprint, spectral_region = compute_footprint_spectral(model)
update_s_region_keyword(model, footprint)
model.meta.wcsinfo.spectral_region = spectral_region
Expand Down
12 changes: 10 additions & 2 deletions jwst/regtest/test_miri_lrs_slit_spec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from jwst.stpipe import Step
from jwst.extract_1d import Extract1dStep

from stcal.alignment import util

@pytest.fixture(scope="module")
def run_pipeline(rtdata_module):
Expand Down Expand Up @@ -67,7 +67,6 @@ def test_miri_lrs_extract1d_from_cal(run_pipeline, rtdata_module, fitsdiff_defau
@pytest.mark.bigdata
def test_miri_lrs_slit_wcs(run_pipeline, rtdata_module, fitsdiff_default_kwargs):
rtdata = rtdata_module

# get input assign_wcs and truth file
output = "jw01530005001_03103_00001_mirimage_assign_wcs.fits"
rtdata.output = output
Expand All @@ -87,3 +86,12 @@ def test_miri_lrs_slit_wcs(run_pipeline, rtdata_module, fitsdiff_default_kwargs)
xtruth, ytruth = im_truth.meta.wcs.backward_transform(ratruth, dectruth, lamtruth)
assert_allclose(xtest, xtruth)
assert_allclose(ytest, ytruth)

# Test the s_region. S_region is formed by footprint which contains
# floats rather than a string. Test footprint
sregion = im.meta.wcsinfo.s_region
sregion_test = im_truth.meta.wcsinfo.s_region
footprint=util.sregion_to_footprint(sregion)
footprint_test = util.sregion_to_footprint(sregion_test)
assert_allclose(footprint, footprint_test)

Loading

0 comments on commit a5a4482

Please sign in to comment.