Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/spacetelescope/jwst into JP…
Browse files Browse the repository at this point in the history
…-3860-assign-wcs
  • Loading branch information
emolter committed Mar 7, 2025
2 parents 673ad54 + a5a4482 commit 3573b62
Show file tree
Hide file tree
Showing 7 changed files with 147 additions and 34 deletions.
1 change: 1 addition & 0 deletions changes/9193.assign_wcs.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix MIRI LRS s_region in assign_wcs
1 change: 1 addition & 0 deletions changes/9193.resample.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix MIRI LRS s_region and WCS in resample_spec
12 changes: 10 additions & 2 deletions jwst/regtest/test_miri_lrs_slit_spec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from jwst.stpipe import Step
from jwst.extract_1d import Extract1dStep

from stcal.alignment import util

@pytest.fixture(scope="module")
def run_pipeline(rtdata_module):
Expand Down Expand Up @@ -67,7 +67,6 @@ def test_miri_lrs_extract1d_from_cal(run_pipeline, rtdata_module, fitsdiff_defau
@pytest.mark.bigdata
def test_miri_lrs_slit_wcs(run_pipeline, rtdata_module, fitsdiff_default_kwargs):
rtdata = rtdata_module

# get input assign_wcs and truth file
output = "jw01530005001_03103_00001_mirimage_assign_wcs.fits"
rtdata.output = output
Expand All @@ -87,3 +86,12 @@ def test_miri_lrs_slit_wcs(run_pipeline, rtdata_module, fitsdiff_default_kwargs)
xtruth, ytruth = im_truth.meta.wcs.backward_transform(ratruth, dectruth, lamtruth)
assert_allclose(xtest, xtruth)
assert_allclose(ytest, ytruth)

# Test the s_region. S_region is formed by footprint which contains
# floats rather than a string. Test footprint
sregion = im.meta.wcsinfo.s_region
sregion_test = im_truth.meta.wcsinfo.s_region
footprint=util.sregion_to_footprint(sregion)
footprint_test = util.sregion_to_footprint(sregion_test)
assert_allclose(footprint, footprint_test)

56 changes: 30 additions & 26 deletions jwst/resample/resample_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,19 @@
)
from astropy.modeling.fitting import LinearLSQFitter
from astropy.stats import sigma_clip

from astropy.utils.exceptions import AstropyUserWarning
from gwcs import wcstools, WCS
from gwcs import coordinate_frames as cf

from stdatamodels.jwst import datamodels

from jwst.assign_wcs.util import compute_scale, wcs_bbox_from_shape, wrap_ra
from jwst.assign_wcs.util import compute_scale, wcs_bbox_from_shape,\
wrap_ra
from jwst.resample import resample_utils
from jwst.resample.resample import ResampleImage
from jwst.datamodels import ModelLibrary


log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)

Expand Down Expand Up @@ -163,6 +164,7 @@ def __init__(self, input_models, pixfrac=1.0, kernel="square",
# Any other customizations (crpix, crval, rotation) are ignored.
if resample_utils.is_sky_like(input_models[0].meta.wcs.output_frame):
if input_models[0].meta.instrument.name != "NIRSPEC":

output_wcs = self.build_interpolated_output_wcs(
input_models,
pixel_scale_ratio=pixel_scale_ratio
Expand Down Expand Up @@ -576,13 +578,15 @@ def build_interpolated_output_wcs(self, input_models, pixel_scale_ratio=1.0):
all_dec_slit = []
xstop = 0

all_wcs = [m.meta.wcs for m in input_models]
for im, model in enumerate(input_models):
wcs = model.meta.wcs
bbox = wcs.bounding_box
grid = wcstools.grid_from_bounding_box(bbox)
ra, dec, lam = np.array(wcs(*grid))
# Handle vertical (MIRI) or horizontal (NIRSpec) dispersion. The
# following 2 variables are 0 or 1, i.e. zero-indexed in x,y WCS order

# Handle vertical (MIRI). The following 2 variables are
# 0 or 1, i.e. zero-indexed in x,y WCS order
spectral_axis = find_dispersion_axis(model)
spatial_axis = spectral_axis ^ 1

Expand All @@ -599,7 +603,7 @@ def build_interpolated_output_wcs(self, input_models, pixel_scale_ratio=1.0):
# sampling.

# Steps to do this for first input model:
# 1. find the middle of the spectrum in wavelength
# 1. Find the middle of the spectrum in wavelength
# 2. Pull out the ra and dec at the center of the slit.
# 3. Find the mean ra,dec and the center of the slit this will
# represent the tangent point
Expand All @@ -614,7 +618,7 @@ def build_interpolated_output_wcs(self, input_models, pixel_scale_ratio=1.0):
lam_center_index = int((bbox[spectral_axis][1] -
bbox[spectral_axis][0]) / 2)
if spatial_axis == 0:
# MIRI LRS, the WCS x axis is spatial
# MIRI LRS spectral = 1, the spatial axis = 0
ra_slice = ra[lam_center_index, :]
dec_slice = dec[lam_center_index, :]
else:
Expand All @@ -637,10 +641,10 @@ def build_interpolated_output_wcs(self, input_models, pixel_scale_ratio=1.0):
x_tan, y_tan = undist2sky1.inverse(ra, dec)

# pull out data from center
if spectral_axis == 0: # MIRI LRS, the WCS x axis is spatial
if spectral_axis == 0:
x_tan_array = x_tan.T[lam_center_index]
y_tan_array = y_tan.T[lam_center_index]
else:
else: # MIRI LRS Spectral Axis = 1, the WCS x axis is spatial
x_tan_array = x_tan[lam_center_index]
y_tan_array = y_tan[lam_center_index]

Expand Down Expand Up @@ -750,26 +754,26 @@ def build_interpolated_output_wcs(self, input_models, pixel_scale_ratio=1.0):

native2celestial = RotateNative2Celestial(ra_center_final, dec_center_final, 180)
undist2sky = tan | native2celestial
# find the spatial size of the output - same in x,y
if swap_xy:
_, x_tan_all = undist2sky.inverse(all_ra, all_dec)
pix_to_tan_slope = pix_to_ytan.slope
else:
x_tan_all, _ = undist2sky.inverse(all_ra, all_dec)
pix_to_tan_slope = pix_to_xtan.slope

x_min = np.amin(x_tan_all)
x_max = np.amax(x_tan_all)
x_size = int(np.ceil((x_max - x_min) / np.absolute(pix_to_tan_slope)))
## Use all the wcs
min_tan_x, max_tan_x, min_tan_y, max_tan_y = self._max_spatial_extent(
all_wcs, undist2sky.inverse)
diff_y = np.abs(max_tan_y - min_tan_y)
diff_x = np.abs(max_tan_x - min_tan_x)
pix_to_tan_slope_y = np.abs(pix_to_ytan.slope)
slope_sign_y = np.sign(pix_to_ytan.slope)
pix_to_tan_slope_x = np.abs(pix_to_xtan.slope)
slope_sign_x = np.sign(pix_to_xtan.slope)

if swap_xy:
pix_to_ytan.intercept = -0.5 * (x_size - 1) * pix_to_ytan.slope
ny = int(np.ceil(diff_y / pix_to_tan_slope_y))
else:
pix_to_xtan.intercept = -0.5 * (x_size - 1) * pix_to_xtan.slope
ny = int(np.ceil(diff_x / pix_to_tan_slope_x))

# single model: use size of x_tan_array
# to be consistent with method before
if len(input_models) == 1:
x_size = int(np.ceil(xstop))
offset_y = (ny)/2 * pix_to_tan_slope_y
offset_x = (ny)/2 * pix_to_tan_slope_x
pix_to_ytan.intercept = - slope_sign_y * offset_y
pix_to_xtan.intercept = - slope_sign_x * offset_x

# define the output wcs
transform = mapping | (pix_to_xtan & pix_to_ytan | undist2sky) & pix_to_wavelength
Expand All @@ -789,14 +793,13 @@ def build_interpolated_output_wcs(self, input_models, pixel_scale_ratio=1.0):
# compute the output array size in WCS axes order, i.e. (x, y)
output_array_size = [0, 0]
output_array_size[spectral_axis] = int(np.ceil(len(wavelength_array)))
output_array_size[spatial_axis] = x_size
output_array_size[spatial_axis] = ny

# turn the size into a numpy shape in (y, x) order
output_wcs.array_shape = output_array_size[::-1]
output_wcs.pixel_shape = output_array_size
bounding_box = wcs_bbox_from_shape(output_array_size[::-1])
output_wcs.bounding_box = bounding_box

return output_wcs

def build_nirspec_lamp_output_wcs(self, input_models, pixel_scale_ratio):
Expand Down Expand Up @@ -1011,3 +1014,4 @@ def compute_spectral_pixel_scale(wcs, fiducial=None, disp_axis=1):

pixel_scale = compute_scale(wcs, fiducial, disp_axis=disp_axis)
return float(pixel_scale)

12 changes: 9 additions & 3 deletions jwst/resample/resample_spec_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from jwst.datamodels import ModelContainer, ModelLibrary
from jwst.lib.pipe_utils import match_nans_and_flags
from jwst.lib.wcs_utils import get_wavelengths
from jwst.resample.resample_utils import load_custom_wcs
from jwst.resample.resample_utils import load_custom_wcs, find_miri_lrs_sregion

from . import resample_spec, ResampleStep
from ..exp_to_source import multislit_to_container
Expand Down Expand Up @@ -265,9 +265,15 @@ def _process_slit(self, input_models):
else:
result.meta.resample.pixel_scale_ratio = resamp.pixel_scale_ratio
result.meta.resample.pixfrac = self.pixfrac
self.update_slit_metadata(result)
update_s_region_spectral(result)

self.update_slit_metadata(result)
if result.meta.exposure.type.lower() == 'mir_lrs-fixedslit':
s_region_model1 = input_models[0].meta.wcsinfo.s_region
s_region = find_miri_lrs_sregion(s_region_model1, result.meta.wcs)
result.meta.wcsinfo.s_region = s_region
self.log.info(f'Updating S_REGION: {s_region}.')
else:
update_s_region_spectral(result)
return result

def update_slit_metadata(self, model):
Expand Down
86 changes: 86 additions & 0 deletions jwst/resample/resample_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,15 @@
from drizzle.utils import decode_context as _drizzle_decode_context

from stdatamodels.jwst.datamodels.dqflags import pixel
from astropy.coordinates import SkyCoord

from stcal.alignment.util import (
compute_scale,
wcs_from_sregions,
)
from gwcs import wcstools

from stcal.alignment.util import compute_s_region_keyword
from stcal.resample import UnsupportedWCSError
from stcal.resample.utils import compute_mean_pixel_area
from stcal.resample.utils import (
Expand Down Expand Up @@ -462,3 +466,85 @@ def load_custom_wcs(asdf_wcs_file, output_shape=None):
"pixel_scale": user_pixel_scale,
}
return wcs_dict


def find_miri_lrs_sregion(sregion_model1, wcs):
""" Find s region for MIRI LRS resampled data.
Parameters
----------
sregion_model1 : string
s_regions of the first input model
wcs : gwcs.WCS
Spatial/spectral WCS.
Returns
-------
sregion : string
s_region for the resample data.
"""
# use the first sregion to set the width of the slit
spatial_box = sregion_model1
s = spatial_box.split(' ')
a1 = float(s[3])
b1 = float(s[4])
a2 = float(s[5])
b2 = float(s[6])
a3 = float(s[7])
b3 = float(s[8])
a4 = float(s[9])
b4 = float(s[10])

# convert each corner to SkyCoord
coord1 = SkyCoord(a1, b1, unit='deg')
coord2 = SkyCoord(a2, b2, unit='deg')
coord3 = SkyCoord(a3, b3, unit='deg')
coord4 = SkyCoord(a4, b4, unit='deg')

# Find the distance between the corners
# corners are counterclockwize from 1,2,3,4
sep1 = coord1.separation(coord2)
sep2 = coord2.separation(coord3)
sep3 = coord3.separation(coord4)
sep4 = coord4.separation(coord1)

# use the separation values so we can find the min value later
sep = [sep1.value, sep2.value, sep3.value, sep4.value]

# the minimum separation is the slit width
min_sep = np.min(sep)
min_sep = min_sep* u.deg # set the units to degrees

log.info(f'Estimated MIRI LRS slit width: {min_sep*3600} arcsec.')
# now use the combined WCS to map all pixels to the slit center
bbox = wcs.bounding_box
grid = wcstools.grid_from_bounding_box(bbox)
ra, dec, _ = np.array(wcs(*grid))
ra = ra.flatten()
dec = dec.flatten()
# ra and dec are the values along the output resampled slit center
# using the first point and last point find the position angle
star1 = SkyCoord(ra[0]*u.deg, dec[0]*u.deg, frame='icrs')
star2 = SkyCoord(ra[-1]*u.deg, dec[-1]*u.deg, frame='icrs')
position_angle = star1.position_angle(star2).to(u.deg)

# 90 degrees to the position angle of the slit will define s_region
pos_angle = position_angle - 90.0*u.deg

star_c1 = star1.directional_offset_by(pos_angle, min_sep/2)
star_c2 = star1.directional_offset_by(pos_angle, -min_sep/2)
star_c3 = star2.directional_offset_by(pos_angle, min_sep/2)
star_c4 = star2.directional_offset_by(pos_angle, -min_sep/2)

# set these values to footprint
# ra,dec corners are in counter-clockwise direction
footprint = [star_c1.ra.value, star_c1.dec.value,
star_c3.ra.value, star_c3.dec.value,
star_c4.ra.value, star_c4.dec.value,
star_c2.ra.value, star_c2.dec.value]
footprint = np.array(footprint)
s_region = compute_s_region_keyword(footprint)
return s_region



13 changes: 10 additions & 3 deletions jwst/resample/tests/test_resample_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ def _set_photom_kwd(im):
def miri_rate_model():
xsize = 72
ysize = 416
sregion = 'POLYGON ICRS 10.355323877 -22.353560934 10.355437846 -22.353464295 ' + \
'10.354477543 -22.352498313 10.354363599 -22.352595345'

shape = (ysize, xsize)
im = ImageModel(shape)
im.data += 5
Expand All @@ -55,7 +58,8 @@ def miri_rate_model():
'v2_ref': -453.5134,
'v3_ref': -373.4826,
'v3yangle': 0.0,
'vparity': -1}
'vparity': -1,
's_region': sregion}
im.meta.instrument = {
'detector': 'MIRIMAGE',
'filter': 'P750L',
Expand Down Expand Up @@ -115,6 +119,8 @@ def miri_cal(miri_rate):
def miri_rate_zero_crossing():
xsize = 1032
ysize = 1024
sregion = 'POLYGON ICRS 10.355323877 -22.353560934 10.355437846 -22.353464295 ' + \
'10.354477543 -22.352498313 10.354363599 -22.352595345'
shape = (ysize, xsize)
im = ImageModel(shape)
im.var_rnoise = np.random.random(shape)
Expand All @@ -125,7 +131,8 @@ def miri_rate_zero_crossing():
'v2_ref': -415.0690466121227,
'v3_ref': -400.575920398547,
'v3yangle': 0.0,
'vparity': -1}
'vparity': -1,
's_region': sregion}
im.meta.instrument = {
'detector': 'MIRIMAGE',
'filter': 'P750L',
Expand Down Expand Up @@ -519,7 +526,7 @@ def test_pixel_scale_ratio_spec_miri(miri_cal, ratio, units):

@pytest.mark.parametrize("units", ["MJy", "MJy/sr"])
@pytest.mark.parametrize("ratio", [0.7, 1.0, 1.3])
def test_pixel_scale_ratio_spec_miri_pair(miri_rate_pair, ratio, units):
def test_pixel_scale_ratio_1spec_miri_pair(miri_rate_pair, ratio, units):
im1, im2 = miri_rate_pair
_set_photom_kwd(im1)
_set_photom_kwd(im2)
Expand Down

0 comments on commit 3573b62

Please sign in to comment.