From 6a60f2b4ba855810de2902a0427f02213c2d10a9 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Fri, 7 Mar 2025 12:52:35 -0500 Subject: [PATCH 1/3] Apply code style rules to assign_wcs module --- .pre-commit-config.yaml | 1 - .ruff.toml | 2 - jwst/assign_wcs/__init__.py | 14 +- jwst/assign_wcs/assign_wcs.py | 139 +- jwst/assign_wcs/assign_wcs_step.py | 110 +- jwst/assign_wcs/fgs.py | 94 +- jwst/assign_wcs/miri.py | 617 +++--- jwst/assign_wcs/nircam.py | 358 ++-- jwst/assign_wcs/niriss.py | 378 ++-- jwst/assign_wcs/nirspec.py | 1684 ++++++++++------- jwst/assign_wcs/pointing.py | 200 +- jwst/assign_wcs/tools/__init__.py | 1 + jwst/assign_wcs/tools/miri/__init__.py | 1 + jwst/assign_wcs/tools/nirspec/__init__.py | 1 + .../nirspec/create_configuration_test.py | 204 +- jwst/assign_wcs/util.py | 805 ++++---- 16 files changed, 2762 insertions(+), 1847 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a56d755dc8..ed07d22161 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -37,7 +37,6 @@ repos: - id: numpydoc-validation exclude: | (?x)^( - jwst/assign_wcs/.* | jwst/associations/.* | jwst/background/.* | jwst/coron/.* | diff --git a/.ruff.toml b/.ruff.toml index 86f9e8e55e..9c1004d7fd 100644 --- a/.ruff.toml +++ b/.ruff.toml @@ -19,7 +19,6 @@ quote-style = "double" indent-style = "space" docstring-code-format = true exclude = [ - "jwst/assign_wcs/**.py", "jwst/associations/**.py", "jwst/background/**.py", "jwst/coron/**.py", @@ -119,7 +118,6 @@ ignore-fully-untyped = true # Turn off annotation checking for fully untyped co "jwst/associations/tests*" = [ "F841", # unused variable ] -"jwst/assign_wcs/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"] "jwst/associations/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"] "jwst/background/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"] "jwst/coron/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"] diff --git a/jwst/assign_wcs/__init__.py b/jwst/assign_wcs/__init__.py index bd27aca1fa..688c426562 100644 --- a/jwst/assign_wcs/__init__.py +++ b/jwst/assign_wcs/__init__.py @@ -1,7 +1,15 @@ +"""Assign WCS information to JWST data models.""" + from .assign_wcs_step import AssignWcsStep -from .nirspec import (nrs_wcs_set_input, nrs_ifu_wcs, get_spectral_order_wrange) +from .nirspec import nrs_wcs_set_input, nrs_ifu_wcs, get_spectral_order_wrange from .niriss import niriss_soss_set_input from .util import update_fits_wcsinfo -__all__ = ['AssignWcsStep', "nrs_wcs_set_input", "nrs_ifu_wcs", "get_spectral_order_wrange", - "niriss_soss_set_input", "update_fits_wcsinfo"] +__all__ = [ + "AssignWcsStep", + "nrs_wcs_set_input", + "nrs_ifu_wcs", + "get_spectral_order_wrange", + "niriss_soss_set_input", + "update_fits_wcsinfo", +] diff --git a/jwst/assign_wcs/assign_wcs.py b/jwst/assign_wcs/assign_wcs.py index ca61e6d188..fdc64cd093 100644 --- a/jwst/assign_wcs/assign_wcs.py +++ b/jwst/assign_wcs/assign_wcs.py @@ -1,8 +1,13 @@ import logging import importlib from gwcs.wcs import WCS -from .util import (update_s_region_spectral, update_s_region_imaging, - update_s_region_nrs_ifu, update_s_region_mrs) +from .util import ( + update_s_region_spectral, + update_s_region_imaging, + update_s_region_nrs_ifu, + update_s_region_mrs, + update_s_region_lrs, +) from ..lib.exposure_types import IMAGING_TYPES, SPEC_TYPES, NRS_LAMP_MODE_SPEC_TYPES from ..lib.dispaxis import get_dispersion_direction from ..lib.wcs_utils import get_wavelengths @@ -14,91 +19,113 @@ __all__ = ["load_wcs"] -def load_wcs(input_model, reference_files={}, nrs_slit_y_range=None): +def load_wcs(input_model, reference_files=None, nrs_slit_y_range=None): """ Create a gWCS object and store it in ``Model.meta``. Parameters ---------- input_model : `~jwst.datamodels.JwstDataModel` - The exposure. + The input data model. reference_files : dict - A dict {reftype: reference_file_name} containing all - reference files that apply to this exposure. + Mapping between reftype (keys) and reference file name (vals). nrs_slit_y_range : list The slit y-range for a Nirspec slit. The center is (0, 0). + + Returns + ------- + output_model : `~jwst.datamodels.JwstDataModel` + The data model with the WCS information in the meta attribute. """ - if reference_files: + if reference_files is not None: for ref_type, ref_file in reference_files.items(): if ref_file not in ["N/A", ""]: reference_files[ref_type] = ref_file else: reference_files[ref_type] = None - if not any(reference_files.values()): + if (reference_files is None) or (not any(reference_files.values())): log.critical("assign_wcs needs reference files to compute the WCS, none were passed") raise ValueError("assign_wcs needs reference files to compute the WCS, none were passed") instrument = input_model.meta.instrument.name.lower() - mod = importlib.import_module('.' + instrument, 'jwst.assign_wcs') + mod = importlib.import_module("." + instrument, "jwst.assign_wcs") - if input_model.meta.exposure.type.lower() in SPEC_TYPES or \ - input_model.meta.instrument.lamp_mode.lower() in NRS_LAMP_MODE_SPEC_TYPES: + if ( + input_model.meta.exposure.type.lower() in SPEC_TYPES + or input_model.meta.instrument.lamp_mode.lower() in NRS_LAMP_MODE_SPEC_TYPES + ): input_model.meta.wcsinfo.specsys = "BARYCENT" - input_model.meta.wcsinfo.dispersion_direction = \ - get_dispersion_direction( - input_model.meta.exposure.type, - input_model.meta.instrument.grating, - input_model.meta.instrument.filter, - input_model.meta.instrument.pupil) + input_model.meta.wcsinfo.dispersion_direction = get_dispersion_direction( + input_model.meta.exposure.type, + input_model.meta.instrument.grating, + input_model.meta.instrument.filter, + input_model.meta.instrument.pupil, + ) - if instrument.lower() == 'nirspec': + if instrument.lower() == "nirspec": pipeline = mod.create_pipeline(input_model, reference_files, slit_y_range=nrs_slit_y_range) else: pipeline = mod.create_pipeline(input_model, reference_files) # Initialize the output model as a copy of the input # Make the copy after the WCS pipeline is created in order to pass updates to the model. if pipeline is None: - input_model.meta.cal_step.assign_wcs = 'SKIPPED' + input_model.meta.cal_step.assign_wcs = "SKIPPED" log.warning("assign_wcs: SKIPPED") return input_model - else: - output_model = input_model.copy() - wcs = WCS(pipeline) - output_model.meta.wcs = wcs - output_model.meta.cal_step.assign_wcs = 'COMPLETE' - exclude_types = ['nrc_wfss', 'nrc_tsgrism', 'nis_wfss', - 'nrs_fixedslit', 'nrs_msaspec', - 'nrs_autowave', 'nrs_autoflat', 'nrs_lamp', - 'nrs_brightobj', 'nis_soss'] - if output_model.meta.exposure.type.lower() not in exclude_types: - imaging_types = IMAGING_TYPES.copy() - imaging_types.update(['mir_lrs-fixedslit', 'mir_lrs-slitless']) - if output_model.meta.exposure.type.lower() in imaging_types: - try: - update_s_region_imaging(output_model) - except Exception as exc: - log.error("Unable to update S_REGION for type {}: {}".format( - output_model.meta.exposure.type, exc)) - else: - log.info("assign_wcs updated S_REGION to {0}".format( - output_model.meta.wcsinfo.s_region)) - if output_model.meta.exposure.type.lower() == 'mir_lrs-slitless': - output_model.wavelength = get_wavelengths(output_model) - elif output_model.meta.exposure.type.lower() == "nrs_ifu": - update_s_region_nrs_ifu(output_model, mod) - elif output_model.meta.exposure.type.lower() == 'mir_mrs': - update_s_region_mrs(output_model) + output_model = input_model.copy() + wcs = WCS(pipeline) + output_model.meta.wcs = wcs + output_model.meta.cal_step.assign_wcs = "COMPLETE" + exclude_types = [ + "nrc_wfss", + "nrc_tsgrism", + "nis_wfss", + "nrs_fixedslit", + "nrs_msaspec", + "nrs_autowave", + "nrs_autoflat", + "nrs_lamp", + "nrs_brightobj", + "nis_soss", + ] + + if output_model.meta.exposure.type.lower() not in exclude_types: + imaging_types = IMAGING_TYPES.copy() + imaging_types.update(["mir_lrs-slitless"]) + imaging_lrs_types = ["mir_lrs-fixedslit"] + if output_model.meta.exposure.type.lower() in imaging_lrs_types: + # uses slits corners in V2, V3 that are read in from the + # lrs specwcs reference file + update_s_region_lrs(output_model, reference_files) + elif output_model.meta.exposure.type.lower() in imaging_types: + try: + update_s_region_imaging(output_model) + except Exception as exc: + log.error( + f"Unable to update S_REGION for type {output_model.meta.exposure.type}: {exc}" + ) else: - try: - update_s_region_spectral(output_model) - except Exception as exc: - log.info("Unable to update S_REGION for type {}: {}".format( - output_model.meta.exposure.type, exc)) + log.info(f"assign_wcs updated S_REGION to {output_model.meta.wcsinfo.s_region}") + if output_model.meta.exposure.type.lower() == "mir_lrs-slitless": + output_model.wavelength = get_wavelengths(output_model) + elif output_model.meta.exposure.type.lower() == "nrs_ifu": + update_s_region_nrs_ifu(output_model, mod) + elif output_model.meta.exposure.type.lower() == "mir_mrs": + update_s_region_mrs(output_model) + else: + try: + update_s_region_spectral(output_model) + except Exception as exc: + log.info( + f"Unable to update S_REGION for type {output_model.meta.exposure.type}: {exc}" + ) - # Store position of dithered pointing location in metadata for later spectral extraction - if output_model.meta.exposure.type.lower() == 'mir_lrs-fixedslit': - store_dithered_position(output_model) - log.debug(f"Storing dithered pointing location information:" - f"{output_model.meta.dither.dithered_ra} {output_model.meta.dither.dithered_dec}") + # Store position of dithered pointing location in metadata for later spectral extraction + if output_model.meta.exposure.type.lower() == "mir_lrs-fixedslit": + store_dithered_position(output_model) + log.debug( + "Storing dithered pointing location information: " + f"{output_model.meta.dither.dithered_ra} {output_model.meta.dither.dithered_dec}" + ) log.info("COMPLETED assign_wcs") return output_model diff --git a/jwst/assign_wcs/assign_wcs_step.py b/jwst/assign_wcs/assign_wcs_step.py index 8d5ba034d8..94e6fb7b29 100755 --- a/jwst/assign_wcs/assign_wcs_step.py +++ b/jwst/assign_wcs/assign_wcs_step.py @@ -5,8 +5,7 @@ from ..lib.exposure_types import IMAGING_TYPES import logging from .assign_wcs import load_wcs -from .util import (MSAFileError, wfss_imaging_wcs, - wcs_bbox_from_shape, update_fits_wcsinfo) +from .util import MSAFileError, wfss_imaging_wcs, wcs_bbox_from_shape, update_fits_wcsinfo from .nircam import imaging as nircam_imaging from .niriss import imaging as niriss_imaging @@ -17,7 +16,7 @@ __all__ = ["AssignWcsStep"] -WFSS_TYPES = set(['nrc_wfss', 'nis_wfss']) +WFSS_TYPES = {"nrc_wfss", "nis_wfss"} class AssignWcsStep(Step): @@ -41,11 +40,6 @@ class AssignWcsStep(Step): specwcs Wavelength calibration models (MIRI, NIRCAM, NIRISS) regions Stores location of the regions on the detector (MIRI) wavelengthrange Typical wavelength ranges (MIRI, NIRCAM, NIRISS, NIRSPEC) - - Parameters - ---------- - input : `~jwst.datamodels.ImageModel`, `~jwst.datamodels.IFUImageModel`, `~jwst.datamodels.CubeModel` - Input exposure. """ class_alias = "assign_wcs" @@ -57,40 +51,68 @@ class AssignWcsStep(Step): sip_max_inv_pix_error = float(default=0.01) # max err for SIP fit, inverse. sip_inv_degree = integer(max=6, default=None) # degree for inverse SIP fit, None to use best fit. sip_npoints = integer(default=12) # number of points for SIP - slit_y_low = float(default=-.55) # The lower edge of a slit (NIRSpec only). - slit_y_high = float(default=.55) # The upper edge of a slit (NIRSpec only). - """ # noqa: E501 - - reference_file_types = ['distortion', 'filteroffset', 'specwcs', 'regions', - 'wavelengthrange', 'camera', 'collimator', 'disperser', - 'fore', 'fpa', 'msa', 'ote', 'ifupost', - 'ifufore', 'ifuslicer'] - - def process(self, input, *args, **kwargs): + slit_y_low = float(default=-.55) # The lower edge of a slit. + slit_y_high = float(default=.55) # The upper edge of a slit. + """ # noqa: E501 + + reference_file_types = [ + "distortion", + "filteroffset", + "specwcs", + "regions", + "wavelengthrange", + "camera", + "collimator", + "disperser", + "fore", + "fpa", + "msa", + "ote", + "ifupost", + "ifufore", + "ifuslicer", + ] + + def process(self, input_data): + """ + Run the assign_wcs step. + + Parameters + ---------- + input_data : JwstDataModel or str + Either a jwst data model or a string that is the path to one. + + Returns + ------- + result : JwstDataModel + The data model with the WCS information added. + """ reference_file_names = {} - with datamodels.open(input) as input_model: + with datamodels.open(input_data) as input_model: # If input type is not supported, log warning, set to 'skipped', exit - if not (isinstance(input_model, datamodels.ImageModel) or - isinstance(input_model, datamodels.CubeModel) or - isinstance(input_model, datamodels.IFUImageModel)): + if not ( + isinstance(input_model, datamodels.ImageModel) + or isinstance(input_model, datamodels.CubeModel) + or isinstance(input_model, datamodels.IFUImageModel) + ): log.warning("Input dataset type is not supported.") log.warning("assign_wcs expects ImageModel, IFUImageModel or CubeModel as input.") log.warning("Skipping assign_wcs step.") result = input_model.copy() - result.meta.cal_step.assign_wcs = 'SKIPPED' + result.meta.cal_step.assign_wcs = "SKIPPED" return result for reftype in self.reference_file_types: reffile = self.get_reference_file(input_model, reftype) reference_file_names[reftype] = reffile if reffile else "" - log.debug(f'reference files used in assign_wcs: {reference_file_names}') + log.debug(f"reference files used in assign_wcs: {reference_file_names}") # Get the MSA metadata file if needed and add to reffiles if input_model.meta.exposure.type == "NRS_MSASPEC": msa_metadata_file = input_model.meta.instrument.msa_metadata_file if msa_metadata_file is not None and msa_metadata_file.strip() not in ["", "N/A"]: msa_metadata_file = self.make_input_path(msa_metadata_file) - reference_file_names['msametafile'] = msa_metadata_file + reference_file_names["msametafile"] = msa_metadata_file else: message = "MSA metadata file (MSAMETFL) is required for NRS_MSASPEC exposures." log.error(message) @@ -98,7 +120,10 @@ def process(self, input, *args, **kwargs): slit_y_range = [self.slit_y_low, self.slit_y_high] result = load_wcs(input_model, reference_file_names, slit_y_range) - if not (result.meta.exposure.type.lower() in (IMAGING_TYPES.union(WFSS_TYPES)) and self.sip_approx): + if not ( + result.meta.exposure.type.lower() in (IMAGING_TYPES.union(WFSS_TYPES)) + and self.sip_approx + ): return result result_exptype = result.meta.exposure.type.lower() @@ -112,32 +137,39 @@ def process(self, input, *args, **kwargs): max_inv_pix_error=self.sip_max_inv_pix_error, inv_degree=self.sip_inv_degree, npoints=self.sip_npoints, - crpix=None + crpix=None, ) except (ValueError, RuntimeError) as e: - log.warning("Failed to update 'meta.wcsinfo' with FITS SIP " - "approximation. Reported error is:") + log.warning( + "Failed to update 'meta.wcsinfo' with FITS SIP " + "approximation. Reported error is:" + ) log.warning(f'"{e.args[0]}"') else: # WFSS modes try: # A bounding_box is needed for the imaging WCS bbox = wcs_bbox_from_shape(result.data.shape) - if result_exptype == 'nis_wfss': + if result_exptype == "nis_wfss": imaging_func = niriss_imaging else: imaging_func = nircam_imaging - wfss_imaging_wcs(result, imaging_func, bbox=bbox, - max_pix_error=self.sip_max_pix_error, - degree=self.sip_degree, - max_inv_pix_error=self.sip_max_inv_pix_error, - inv_degree=self.sip_inv_degree, - npoints=self.sip_npoints, - ) + wfss_imaging_wcs( + result, + imaging_func, + bbox=bbox, + max_pix_error=self.sip_max_pix_error, + degree=self.sip_degree, + max_inv_pix_error=self.sip_max_inv_pix_error, + inv_degree=self.sip_inv_degree, + npoints=self.sip_npoints, + ) except (ValueError, RuntimeError) as e: - log.warning("Failed to update 'meta.wcsinfo' with FITS SIP " - "approximation. Reported error is:") + log.warning( + "Failed to update 'meta.wcsinfo' with FITS SIP " + "approximation. Reported error is:" + ) log.warning(f'"{e.args[0]}"') return result diff --git a/jwst/assign_wcs/fgs.py b/jwst/assign_wcs/fgs.py index 8b0a9522fe..d7dcadc0cd 100644 --- a/jwst/assign_wcs/fgs.py +++ b/jwst/assign_wcs/fgs.py @@ -1,6 +1,5 @@ -""" -FGS WCS pipeline - depends on EXP_TYPE. -""" +"""FGS WCS pipeline - depends on EXP_TYPE.""" + import logging from astropy import units as u @@ -10,8 +9,12 @@ from stdatamodels.jwst.datamodels import DistortionModel -from .util import (not_implemented_mode, subarray_transform, - transform_bbox_from_shape, bounding_box_from_subarray) +from .util import ( + not_implemented_mode, + subarray_transform, + transform_bbox_from_shape, + bounding_box_from_subarray, +) from . import pointing log = logging.getLogger(__name__) @@ -28,45 +31,49 @@ def create_pipeline(input_model, reference_files): Parameters ---------- input_model : `~jwst.datamodels.JwstDataModel` - The data model. + The input data model. reference_files : dict - {reftype: file_name} mapping. - Reference files. + Mapping between reftype (keys) and reference file name (vals). + + Returns + ------- + pipeline : list + The WCS pipeline, suitable for input into `gwcs.WCS`. """ exp_type = input_model.meta.exposure.type.lower() pipeline = exp_type2transform[exp_type](input_model, reference_files) - log.info("Creating a FGS {0} pipeline with references {1}".format( - exp_type, reference_files)) + log.info(f"Creating a FGS {exp_type} pipeline with references {reference_files}") return pipeline def imaging(input_model, reference_files): """ - The FGS imaging WCS pipeline. + Create the WCS pipeline for FGS imaging data. It includes 3 coordinate frames - "detector", "v2v3" and "world". - Uses a ``distortion`` reference file. Parameters ---------- - input_model : `~jwst.datamodels.JwstDataModel` - The data model. + input_model : ImageModel + The input data model. reference_files : dict - {reftype: file_name} mapping. - Reference files. + Mapping between reftype (keys) and reference file name (vals). + Requires 'distortion' reference files Returns ------- pipeline : list - The WCS pipeline. + The WCS pipeline, suitable for input into `gwcs.WCS`. """ # Create coordinate frames for the ``imaging`` mode. - detector = cf.Frame2D(name='detector', axes_order=(0, 1), unit=(u.pix, u.pix)) - v2v3 = cf.Frame2D(name='v2v3', axes_order=(0, 1), axes_names=('v2', 'v3'), - unit=(u.arcsec, u.arcsec)) - v2v3vacorr = cf.Frame2D(name='v2v3vacorr', axes_order=(0, 1), - axes_names=('v2', 'v3'), unit=(u.arcsec, u.arcsec)) - world = cf.CelestialFrame(name='world', reference_frame=coord.ICRS()) + detector = cf.Frame2D(name="detector", axes_order=(0, 1), unit=(u.pix, u.pix)) + v2v3 = cf.Frame2D( + name="v2v3", axes_order=(0, 1), axes_names=("v2", "v3"), unit=(u.arcsec, u.arcsec) + ) + v2v3vacorr = cf.Frame2D( + name="v2v3vacorr", axes_order=(0, 1), axes_names=("v2", "v3"), unit=(u.arcsec, u.arcsec) + ) + world = cf.CelestialFrame(name="world", reference_frame=coord.ICRS()) # Create the v2v3 to sky transform. tel2sky = pointing.v23tosky(input_model) @@ -88,29 +95,41 @@ def imaging(input_model, reference_files): va_corr = pointing.dva_corr_model( va_scale=input_model.meta.velocity_aberration.scale_factor, v2_ref=input_model.meta.wcsinfo.v2_ref, - v3_ref=input_model.meta.wcsinfo.v3_ref + v3_ref=input_model.meta.wcsinfo.v3_ref, ) - pipeline = [(detector, distortion), - (v2v3, va_corr), - (v2v3vacorr, tel2sky), - (world, None)] + pipeline = [(detector, distortion), (v2v3, va_corr), (v2v3vacorr, tel2sky), (world, None)] return pipeline def imaging_distortion(input_model, reference_files): """ Create the transform from "detector" to "v2v3". + + Parameters + ---------- + input_model : ImageModel + The input data model. + reference_files : dict + Mapping between reftype (keys) and reference file name (vals). + Requires 'distortion' reference file. + + Returns + ------- + transform : `astropy.modeling.Model + The transform from "detector" to "v2v3". """ - dist = DistortionModel(reference_files['distortion']) + dist = DistortionModel(reference_files["distortion"]) transform = dist.model # Check if the transform in the reference file has a ``bounding_box``. # If not set a ``bounding_box`` equal to the size of the image. try: - transform.bounding_box + transform.bounding_box # noqa: B018 except NotImplementedError: - bind_bounding_box(transform, transform_bbox_from_shape(input_model.data.shape, order="F"), order="F") + bind_bounding_box( + transform, transform_bbox_from_shape(input_model.data.shape, order="F"), order="F" + ) dist.close() return transform @@ -118,9 +137,10 @@ def imaging_distortion(input_model, reference_files): # EXP_TYPE to function mapping. # The function creates the WCS pipeline. -exp_type2transform = {'fgs_image': imaging, - 'fgs_focus': imaging, - 'fgs_skyflat': not_implemented_mode, - 'fgs_intflat': not_implemented_mode, - 'fgs_dark': not_implemented_mode - } +exp_type2transform = { + "fgs_image": imaging, + "fgs_focus": imaging, + "fgs_skyflat": not_implemented_mode, + "fgs_intflat": not_implemented_mode, + "fgs_dark": not_implemented_mode, +} diff --git a/jwst/assign_wcs/miri.py b/jwst/assign_wcs/miri.py index fd3df3a264..4528dd3c5f 100644 --- a/jwst/assign_wcs/miri.py +++ b/jwst/assign_wcs/miri.py @@ -1,25 +1,33 @@ -import os.path +from pathlib import Path import logging import numpy as np from astropy.modeling import bind_bounding_box from astropy.modeling import models from astropy import coordinates as coord from astropy import units as u -from astropy.io import fits - from scipy.interpolate import UnivariateSpline import gwcs.coordinate_frames as cf from gwcs import selector -from stdatamodels.jwst.datamodels import (DistortionModel, FilteroffsetModel, - DistortionMRSModel, WavelengthrangeModel, - RegionsModel, SpecwcsModel) -from stdatamodels.jwst.transforms.models import (MIRI_AB2Slice, IdealToV2V3) +from stdatamodels.jwst.datamodels import ( + DistortionModel, + FilteroffsetModel, + DistortionMRSModel, + WavelengthrangeModel, + RegionsModel, + SpecwcsModel, + MiriLRSSpecwcsModel, +) +from stdatamodels.jwst.transforms.models import MIRI_AB2Slice, IdealToV2V3 from . import pointing -from .util import (not_implemented_mode, subarray_transform, - velocity_correction, transform_bbox_from_shape, - bounding_box_from_subarray) +from .util import ( + not_implemented_mode, + subarray_transform, + velocity_correction, + transform_bbox_from_shape, + bounding_box_from_subarray, +) log = logging.getLogger(__name__) @@ -35,45 +43,51 @@ def create_pipeline(input_model, reference_files): Parameters ---------- - input_model : `jwst.datamodels.ImagingModel`, `~jwst.datamodels.IFUImageModel`, - `~jwst.datamodels.CubeModel` - Data model. + input_model : ImageModel, IFUImageModel, CubeModel + The input data model. reference_files : dict - {reftype: reference file name} mapping. + Mapping between reftype (keys) and reference file name (vals). + Returns + ------- + pipeline : list + The WCS pipeline, suitable for input into `gwcs.WCS`. """ exp_type = input_model.meta.exposure.type.lower() pipeline = exp_type2transform[exp_type](input_model, reference_files) if pipeline: - log.info("Created a MIRI {0} pipeline with references {1}".format( - exp_type, reference_files)) + log.info("Created a MIRI {exp_type} pipeline with references {reference_files}") return pipeline def imaging(input_model, reference_files): """ - The MIRI Imaging WCS pipeline. + Create the WCS pipeline for MIRI imaging data. - It includes three coordinate frames - - "detector", "v2v3" and "world". + It includes three coordinate frames - "detector", "v2v3" and "world". Parameters ---------- - input_model : `jwst.datamodels.ImagingModel` - Data model. + input_model : ImageModel, IFUImageModel, CubeModel + The input data model. reference_files : dict - Dictionary {reftype: reference file name}. - Uses "distortion" and "filteroffset" reference files. + Mapping between reftype (keys) and reference file name (vals). + Requires 'distortion' and 'filteroffset' reference files. + Returns + ------- + pipeline : list + The WCS pipeline, suitable for input into `gwcs.WCS`. """ - # Create the Frames - detector = cf.Frame2D(name='detector', axes_order=(0, 1), unit=(u.pix, u.pix)) - v2v3 = cf.Frame2D(name='v2v3', axes_order=(0, 1), axes_names=('v2', 'v3'), - unit=(u.arcsec, u.arcsec)) - v2v3vacorr = cf.Frame2D(name='v2v3vacorr', axes_order=(0, 1), - axes_names=('v2', 'v3'), unit=(u.arcsec, u.arcsec)) - world = cf.CelestialFrame(reference_frame=coord.ICRS(), name='world') + detector = cf.Frame2D(name="detector", axes_order=(0, 1), unit=(u.pix, u.pix)) + v2v3 = cf.Frame2D( + name="v2v3", axes_order=(0, 1), axes_names=("v2", "v3"), unit=(u.arcsec, u.arcsec) + ) + v2v3vacorr = cf.Frame2D( + name="v2v3vacorr", axes_order=(0, 1), axes_names=("v2", "v3"), unit=(u.arcsec, u.arcsec) + ) + world = cf.CelestialFrame(reference_frame=coord.ICRS(), name="world") # Create the transforms distortion = imaging_distortion(input_model, reference_files) @@ -87,26 +101,24 @@ def imaging(input_model, reference_files): else: # TODO: remove setting the bounding box if it is set in the new ref file. try: - distortion.bounding_box + distortion.bounding_box # noqa: B018 except NotImplementedError: shape = input_model.data.shape - bind_bounding_box(distortion, ((3.5, shape[1] - 4.5), (-0.5, shape[0] - 0.5)), order="F") + bind_bounding_box( + distortion, ((3.5, shape[1] - 4.5), (-0.5, shape[0] - 0.5)), order="F" + ) # Compute differential velocity aberration (DVA) correction: va_corr = pointing.dva_corr_model( va_scale=input_model.meta.velocity_aberration.scale_factor, v2_ref=input_model.meta.wcsinfo.v2_ref, - v3_ref=input_model.meta.wcsinfo.v3_ref + v3_ref=input_model.meta.wcsinfo.v3_ref, ) tel2sky = pointing.v23tosky(input_model) # Create the pipeline - pipeline = [(detector, distortion), - (v2v3, va_corr), - (v2v3vacorr, tel2sky), - (world, None) - ] + pipeline = [(detector, distortion), (v2v3, va_corr), (v2v3vacorr, tel2sky), (world, None)] return pipeline @@ -122,9 +134,21 @@ def imaging_distortion(input_model, reference_files): 4. Apply the TI matrix (this gives V2/V3 coordinates) (uses "distortion" ref file) 5. Apply V2V3 --> sky transform + Parameters + ---------- + input_model : ImageModel, IFUImageModel, or CubeModel + The input data model. + reference_files : dict + Mapping between reftype (keys) and reference file name (vals). + Requires 'distortion' and 'filteroffset' reference files. + + Returns + ------- + distortion : `astropy.modeling.Model` + The transform from "detector" to "v2v3". """ # Read in the distortion. - with DistortionModel(reference_files['distortion']) as dist: + with DistortionModel(reference_files["distortion"]) as dist: distortion = dist.model # Check if the transform in the reference file has a ``bounding_box``. @@ -136,7 +160,7 @@ def imaging_distortion(input_model, reference_files): # Add an offset for the filter obsfilter = input_model.meta.instrument.filter - with FilteroffsetModel(reference_files['filteroffset']) as filter_offset: + with FilteroffsetModel(reference_files["filteroffset"]) as filter_offset: filters = filter_offset.filters col_offset = None @@ -155,7 +179,7 @@ def imaging_distortion(input_model, reference_files): bind_bounding_box( distortion, transform_bbox_from_shape(input_model.data.shape, order="F") if bbox is None else bbox, - order="F" + order="F", ) return distortion @@ -163,48 +187,65 @@ def imaging_distortion(input_model, reference_files): def lrs(input_model, reference_files): """ - The LRS-FIXEDSLIT and LRS-SLITLESS WCS pipeline. + Create the WCS pipeline for LRS-FIXEDSLIT and LRS-SLITLESS data. - Notes - ----- - It includes three coordinate frames - - "detector", "v2v3" and "world". + It includes three coordinate frames - "detector", "v2v3" and "world". "v2v3" and "world" each have (spatial, spatial, spectral) components. - Uses the "specwcs" and "distortion" reference files. + Parameters + ---------- + input_model : ImageModel or CubeModel + The input data model. + reference_files : dict + Mapping between reftype (keys) and reference file name (vals). + Requires 'distortion' and 'specwcs' reference files. + Returns + ------- + pipeline : list + The WCS pipeline, suitable for input into `gwcs.WCS`. """ # Define the various coordinate frames. # Original detector frame - detector = cf.Frame2D(name='detector', axes_order=(0, 1), unit=(u.pix, u.pix)) + detector = cf.Frame2D(name="detector", axes_order=(0, 1), unit=(u.pix, u.pix)) # Intermediate slit frame - alpha_beta = cf.Frame2D(name='alpha_beta_spatial', axes_order=(0, 1), - unit=(u.arcsec, u.arcsec), axes_names=('alpha', 'beta')) - spec_local = cf.SpectralFrame(name='alpha_beta_spectral', axes_order=(2,), - unit=(u.micron,), axes_names=('lambda',)) - miri_focal = cf.CompositeFrame([alpha_beta, spec_local], name='alpha_beta') - + alpha_beta = cf.Frame2D( + name="alpha_beta_spatial", + axes_order=(0, 1), + unit=(u.arcsec, u.arcsec), + axes_names=("alpha", "beta"), + ) + spec_local = cf.SpectralFrame( + name="alpha_beta_spectral", axes_order=(2,), unit=(u.micron,), axes_names=("lambda",) + ) + miri_focal = cf.CompositeFrame([alpha_beta, spec_local], name="alpha_beta") # Spectral component - spec = cf.SpectralFrame(name='spec', axes_order=(2,), unit=(u.micron,), axes_names=('lambda',)) + spec = cf.SpectralFrame(name="spec", axes_order=(2,), unit=(u.micron,), axes_names=("lambda",)) # v2v3 spatial component - v2v3_spatial = cf.Frame2D(name='v2v3_spatial', axes_order=(0, 1), unit=(u.arcsec, u.arcsec), - axes_names=('v2', 'v3')) + v2v3_spatial = cf.Frame2D( + name="v2v3_spatial", axes_order=(0, 1), unit=(u.arcsec, u.arcsec), axes_names=("v2", "v3") + ) v2v3vacorr_spatial = cf.Frame2D( - name='v2v3vacorr_spatial', + name="v2v3vacorr_spatial", axes_order=(0, 1), unit=(u.arcsec, u.arcsec), - axes_names=('v2', 'v3') + axes_names=("v2", "v3"), ) # v2v3 spatial+spectra - v2v3 = cf.CompositeFrame([v2v3_spatial, spec], name='v2v3') - v2v3vacorr = cf.CompositeFrame([v2v3vacorr_spatial, spec], name='v2v3vacorr') + v2v3 = cf.CompositeFrame([v2v3_spatial, spec], name="v2v3") + v2v3vacorr = cf.CompositeFrame([v2v3vacorr_spatial, spec], name="v2v3vacorr") # 'icrs' frame which is the spatial sky component - icrs = cf.CelestialFrame(name='icrs', reference_frame=coord.ICRS(), - axes_order=(0, 1), unit=(u.deg, u.deg), axes_names=('RA', 'DEC')) + icrs = cf.CelestialFrame( + name="icrs", + reference_frame=coord.ICRS(), + axes_order=(0, 1), + unit=(u.deg, u.deg), + axes_names=("RA", "DEC"), + ) # Final 'world' composite frame with spatial and spectral components world = cf.CompositeFrame(name="world", frames=[icrs, spec]) @@ -218,33 +259,43 @@ def lrs(input_model, reference_files): va_corr = pointing.dva_corr_model( va_scale=input_model.meta.velocity_aberration.scale_factor, v2_ref=input_model.meta.wcsinfo.v2_ref, - v3_ref=input_model.meta.wcsinfo.v3_ref + v3_ref=input_model.meta.wcsinfo.v3_ref, ) & models.Identity(1) # Put the transforms together into a single pipeline - pipeline = [(detector, dettoabl), - (miri_focal, abltov2v3l), - (v2v3, va_corr), - (v2v3vacorr, teltosky), - (world, None)] + pipeline = [ + (detector, dettoabl), + (miri_focal, abltov2v3l), + (v2v3, va_corr), + (v2v3vacorr, teltosky), + (world, None), + ] return pipeline def lrs_xytoabl(input_model, reference_files): """ - The first part of LRS-FIXEDSLIT and LRS-SLITLESS WCS pipeline. + Build the first transform for the LRS-FIXEDSLIT and LRS-SLITLESS WCS pipeline. - Transform from subarray (x, y) to (alpha, beta, lambda) using - the "specwcs" and "distortion" reference files. + Parameters + ---------- + input_model : ImageModel, IFUImageModel, or CubeModel + The input data model. + reference_files : dict + Mapping between reftype (keys) and reference file name (vals). + Requires 'distortion' and 'specwcs' reference files. + Returns + ------- + dettoabl : `astropy.modeling.Model` + The transform from subarray (x, y) to (alpha, beta, lambda). """ - # subarray to full array transform subarray2full = subarray_transform(input_model) # full array to v2v3 transform for the ordinary imager - with DistortionModel(reference_files['distortion']) as dist: + with DistortionModel(reference_files["distortion"]) as dist: distortion = dist.model # Combine models to create subarray to v2v3 distortion @@ -253,54 +304,55 @@ def lrs_xytoabl(input_model, reference_files): else: subarray_dist = distortion - ref = fits.open(reference_files['specwcs']) - - with ref: - lrsdata = np.array([d for d in ref[1].data]) - # Get the zero point from the reference data. - # The zero_point is X, Y (which should be COLUMN, ROW) - # These are 1-indexed in CDP-7 (i.e., SIAF convention) so must be converted to 0-indexed - if input_model.meta.exposure.type.lower() == 'mir_lrs-fixedslit': - zero_point = ref[0].header['imx'] - 1, ref[0].header['imy'] - 1 - elif input_model.meta.exposure.type.lower() == 'mir_lrs-slitless': - zero_point = ref[0].header['imxsltl'] - 1, ref[0].header['imysltl'] - 1 - # Transform to slitless subarray from full array - zero_point = subarray2full.inverse(zero_point[0], zero_point[1]) + refmodel = MiriLRSSpecwcsModel(reference_files["specwcs"]) + if input_model.meta.exposure.type.lower() == "mir_lrs-fixedslit": + zero_point = refmodel.meta.x_ref - 1, refmodel.meta.y_ref - 1 + elif input_model.meta.exposure.type.lower() == "mir_lrs-slitless": + zero_point = refmodel.meta.x_ref_slitless - 1, refmodel.meta.y_ref_slitless - 1 + # Transform to slitless subarray from full array + zero_point = subarray2full.inverse(zero_point[0], zero_point[1]) # Figure out the typical along-slice pixel scale at the center of the slit v2_cen, v3_cen = subarray_dist(zero_point[0], zero_point[1]) v2_off, v3_off = subarray_dist(zero_point[0] + 1, zero_point[1]) - pscale = np.sqrt(np.power(v2_cen - v2_off, 2) + np.power(v3_cen - v3_off,2)) + pscale = np.sqrt(np.power(v2_cen - v2_off, 2) + np.power(v3_cen - v3_off, 2)) # In the lrsdata reference table, X_center,y_center,wavelength describe the location of the # centroid trace along the detector in pixels relative to nominal location. # x0,y0(ul) x1,y1 (ur) x2,y2(lr) x3,y3(ll) define corners of the box within which the distortion # and wavelength calibration was derived - xcen = lrsdata[:, 0] - ycen = lrsdata[:, 1] - wavetab = lrsdata[:, 2] - x0 = lrsdata[:, 3] - y0 = lrsdata[:, 4] - x1 = lrsdata[:, 5] - y2 = lrsdata[:, 8] - + xcen = refmodel.wavetable.x_center + ycen = refmodel.wavetable.y_center + wavetab = refmodel.wavetable.wavelength + x0 = refmodel.wavetable.x0 + y0 = refmodel.wavetable.y0 + x1 = refmodel.wavetable.x1 + y2 = refmodel.wavetable.y2 + refmodel.close() # If in fixed slit mode, define the bounding box using the corner locations provided in # the CDP reference file. - if input_model.meta.exposure.type.lower() == 'mir_lrs-fixedslit': - - bb_sub = ((np.floor(x0.min() + zero_point[0]) - 0.5, np.ceil(x1.max() + zero_point[0]) + 0.5), - (np.floor(y2.min() + zero_point[1]) - 0.5, np.ceil(y0.max() + zero_point[1]) + 0.5)) + if input_model.meta.exposure.type.lower() == "mir_lrs-fixedslit": + bb_sub = ( + (np.floor(x0.min() + zero_point[0]) - 0.5, np.ceil(x1.max() + zero_point[0]) + 0.5), + (np.floor(y2.min() + zero_point[1]) - 0.5, np.ceil(y0.max() + zero_point[1]) + 0.5), + ) # If in slitless mode, define the bounding box X locations using the subarray x boundaries # and the y locations using the corner locations in the CDP reference file. Make sure to # omit the 4 reference pixels on the left edge of slitless subarray. - if input_model.meta.exposure.type.lower() == 'mir_lrs-slitless': - bb_sub = ((input_model.meta.subarray.xstart - 1 + 4 - 0.5, input_model.meta.subarray.xsize - 1 + 0.5), - (np.floor(y2.min() + zero_point[1]) - 0.5, np.ceil(y0.max() + zero_point[1]) + 0.5)) + if input_model.meta.exposure.type.lower() == "mir_lrs-slitless": + bb_sub = ( + ( + input_model.meta.subarray.xstart - 1 + 4 - 0.5, + input_model.meta.subarray.xsize - 1 + 0.5, + ), + (np.floor(y2.min() + zero_point[1]) - 0.5, np.ceil(y0.max() + zero_point[1]) + 0.5), + ) # Now deal with the fact that the spectral trace isn't perfectly up and down along detector. - # This information is contained in the xcenter/ycenter values in the CDP table, but we'll handle it - # as a simple x shift using a linear fit to this relation provided by the CDP. + # This information is contained in the xcenter/ycenter values in the CDP table, + # but we'll handle it as a simple x shift using a linear fit + # to this relation provided by the CDP. # First convert the values in CDP table to subarray x/y xcen_subarray = xcen + zero_point[0] ycen_subarray = ycen + zero_point[1] @@ -311,10 +363,14 @@ def lrs_xytoabl(input_model, reference_files): # Evaluate the fit at the y reference points xshiftref = spl(ycen_subarray) # This function will give slit dX as a function of Y subarray pixel value - dxmodel = models.Tabular1D(lookup_table=xshiftref, points=ycen_subarray, name='xshiftref', - bounds_error=False, fill_value=np.nan) - - if input_model.meta.exposure.type.lower() == 'mir_lrs-fixedslit': + dxmodel = models.Tabular1D( + lookup_table=xshiftref, + points=ycen_subarray, + name="xshiftref", + bounds_error=False, + fill_value=np.nan, + ) + if input_model.meta.exposure.type.lower() == "mir_lrs-fixedslit": bb_sub = (bb_sub[0], (dxmodel.points[0].min(), dxmodel.points[0].max())) # Fit for the wavelength as a function of Y # Reverse the vectors so that yinv is increasing (needed for spline fitting function) @@ -323,9 +379,13 @@ def lrs_xytoabl(input_model, reference_files): # Evaluate the fit at the y reference points wavereference = spl(ycen_subarray) # This model will now give the wavelength corresponding to a given Y subarray pixel value - wavemodel = models.Tabular1D(lookup_table=wavereference, points=ycen_subarray, name='waveref', - bounds_error=False, fill_value=np.nan) - + wavemodel = models.Tabular1D( + lookup_table=wavereference, + points=ycen_subarray, + name="waveref", + bounds_error=False, + fill_value=np.nan, + ) # Wavelength barycentric correction try: velosys = input_model.meta.wcsinfo.velosys @@ -335,7 +395,9 @@ def lrs_xytoabl(input_model, reference_files): if velosys is not None: velocity_corr = velocity_correction(input_model.meta.wcsinfo.velosys) wavemodel = wavemodel | velocity_corr - log.info("Applied Barycentric velocity correction : {}".format(velocity_corr[1].amplitude.value)) + log.info( + f"Applied Barycentric velocity correction : {velocity_corr[1].amplitude.value}" + ) # What is the effective slit X as a function of subarray x,y? xmodel = models.Mapping([0], n_inputs=2) - (models.Mapping([1], n_inputs=2) | dxmodel) @@ -344,18 +406,20 @@ def lrs_xytoabl(input_model, reference_files): # What is the effective XY as a function of subarray x,y? xymodel = models.Mapping((0, 1, 0, 1)) | xmodel & ymodel # What is the alpha as a function of slit XY? - alphamodel = models.Mapping([0], n_inputs=2) | \ - models.Shift(-zero_point[0]) | \ - models.Polynomial1D(1, c0=0, c1=pscale) + alphamodel = ( + models.Mapping([0], n_inputs=2) + | models.Shift(-zero_point[0]) + | models.Polynomial1D(1, c0=0, c1=pscale) + ) # What is the alpha,beta as a function of slit XY? (beta is always zero) abmodel = models.Mapping((0, 1, 0)) | alphamodel & models.Const1D(0) # Define a shift by the reference point and immediately back again - # This doesn't do anything effectively, but it stores the reference point for later use in pathloss - reftransform = models.Shift(-zero_point[0]) & \ - models.Shift(-zero_point[1]) | \ - models.Shift(+zero_point[0]) & \ - models.Shift(+zero_point[1]) + # This doesn't do anything effectively, + # but it stores the reference point for later use in pathloss + reftransform = models.Shift(-zero_point[0]) & models.Shift(-zero_point[1]) | models.Shift( + +zero_point[0] + ) & models.Shift(+zero_point[1]) # Put the transforms together xytoab = reftransform | xymodel | abmodel @@ -365,7 +429,7 @@ def lrs_xytoabl(input_model, reference_files): # Construct the inverse distortion model (alpha,beta,wavelength -> xsub,ysub) # Go from alpha to slit-X - slitxmodel = models.Polynomial1D(1, c0=0, c1=1/pscale) | models.Shift(zero_point[0]) + slitxmodel = models.Polynomial1D(1, c0=0, c1=1 / pscale) | models.Shift(zero_point[0]) # Go from lambda to real y lam_to_y = wavemodel.inverse # Go from slit-x and real y to real-x @@ -378,25 +442,35 @@ def lrs_xytoabl(input_model, reference_files): # Go from alpha,beta,lam, to real x,y dettoabl.inverse = models.Mapping((0, 1, 2, 0, 1, 2)) | aa & bb - # Bounding box is the subarray bounding box, because we're assuming subarray coordinates passed in + # Bounding box is the subarray bounding box, + # because we're assuming subarray coordinates passed in bind_bounding_box(dettoabl, bb_sub, order="F") return dettoabl + def lrs_abltov2v3l(input_model, reference_files): """ - The second part of LRS-FIXEDSLIT and LRS-SLITLESS WCS pipeline. + Build the first transform for the LRS-FIXEDSLIT and LRS-SLITLESS WCS pipeline. - Transform from (alpha, beta, lambda) to (v2, v3, lambda) using - the "specwcs" and "distortion" reference files. + Parameters + ---------- + input_model : ImageModel, IFUImageModel, or CubeModel + The input data model. + reference_files : dict + Mapping between reftype (keys) and reference file name (vals). + Requires 'distortion' and 'specwcs' reference files. + Returns + ------- + abl_to_v2v3l : `astropy.modeling.Model` + The transform from (alpha, beta, lambda) to (v2, v3, lambda). """ - # subarray to full array transform subarray2full = subarray_transform(input_model) # full array to v2v3 transform for the ordinary imager - with DistortionModel(reference_files['distortion']) as dist: + with DistortionModel(reference_files["distortion"]) as dist: distortion = dist.model # Combine models to create subarray to v2v3 distortion @@ -405,23 +479,19 @@ def lrs_abltov2v3l(input_model, reference_files): else: subarray_dist = distortion - ref = fits.open(reference_files['specwcs']) - - with ref: - # Get the zero point from the reference data. - # The zero_point is X, Y (which should be COLUMN, ROW) - # These are 1-indexed in CDP-7 (i.e., SIAF convention) so must be converted to 0-indexed - if input_model.meta.exposure.type.lower() == 'mir_lrs-fixedslit': - zero_point = ref[0].header['imx'] - 1, ref[0].header['imy'] - 1 - elif input_model.meta.exposure.type.lower() == 'mir_lrs-slitless': - zero_point = ref[0].header['imxsltl'] - 1, ref[0].header['imysltl'] - 1 - # Transform to slitless subarray from full array - zero_point = subarray2full.inverse(zero_point[0], zero_point[1]) + refmodel = MiriLRSSpecwcsModel(reference_files["specwcs"]) + if input_model.meta.exposure.type.lower() == "mir_lrs-fixedslit": + zero_point = refmodel.meta.x_ref - 1, refmodel.meta.y_ref - 1 + elif input_model.meta.exposure.type.lower() == "mir_lrs-slitless": + zero_point = refmodel.meta.x_ref_slitless - 1, refmodel.meta.y_ref_slitless - 1 + # Transform to slitless subarray from full array + zero_point = subarray2full.inverse(zero_point[0], zero_point[1]) + refmodel.close() # Figure out the typical along-slice pixel scale at the center of the slit v2_cen, v3_cen = subarray_dist(zero_point[0], zero_point[1]) v2_off, v3_off = subarray_dist(zero_point[0] + 1, zero_point[1]) - pscale = np.sqrt(np.power(v2_cen - v2_off, 2) + np.power(v3_cen - v3_off,2)) + pscale = np.sqrt(np.power(v2_cen - v2_off, 2) + np.power(v3_cen - v3_off, 2)) # Go from alpha to slit-X slitxmodel = models.Polynomial1D(1, c0=0, c1=1 / pscale) | models.Shift(zero_point[0]) @@ -443,61 +513,90 @@ def lrs_abltov2v3l(input_model, reference_files): # Go from v2,v3 to alpha, beta aa = v2v3_to_xydet | alphamodel & betamodel # Go from v2,v3,lambda to alpha,beta,lambda - abl_to_v2v3l.inverse = models.Mapping((0,1,2)) | aa & models.Identity(1) + abl_to_v2v3l.inverse = models.Mapping((0, 1, 2)) | aa & models.Identity(1) return abl_to_v2v3l + def ifu(input_model, reference_files): """ - The MIRI MRS WCS pipeline. + Create the WCS pipeline for MIRI IFU data. + + It has the following coordinate frames: "detector", "alpha_beta", "v2v3", "world". - It has the following coordinate frames: - "detector", "alpha_beta", "v2v3", "world". + Parameters + ---------- + input_model : ImageModel or CubeModel + The input data model. + reference_files : dict + Mapping between reftype (keys) and reference file name (vals). + Requires 'distortion', 'specwcs', 'regions', and 'wavelengthrange' reference files. - It uses the "distortion", "regions", "specwcs" - and "wavelengthrange" reference files. + Returns + ------- + pipeline : list + The WCS pipeline, suitable for input into `gwcs.WCS`. """ # Define coordinate frames. - detector = cf.Frame2D(name='detector', axes_order=(0, 1), unit=(u.pix, u.pix)) - alpha_beta = cf.Frame2D(name='alpha_beta_spatial', axes_order=(0, 1), - unit=(u.arcsec, u.arcsec), axes_names=('alpha', 'beta')) - spec_local = cf.SpectralFrame(name='alpha_beta_spectral', axes_order=(2,), - unit=(u.micron,), axes_names=('lambda',)) - miri_focal = cf.CompositeFrame([alpha_beta, spec_local], name='alpha_beta') - v23_spatial = cf.Frame2D(name='v2v3_spatial', axes_order=(0, 1), - unit=(u.arcsec, u.arcsec), axes_names=('v2', 'v3')) - v2v3vacorr_spatial = cf.Frame2D(name='v2v3vacorr_spatial', axes_order=(0, 1), - unit=(u.arcsec, u.arcsec), axes_names=('v2', 'v3')) - - spec = cf.SpectralFrame(name='spectral', axes_order=(2,), unit=(u.micron,), - axes_names=('lambda',)) - v2v3 = cf.CompositeFrame([v23_spatial, spec], name='v2v3') - v2v3vacorr = cf.CompositeFrame([v2v3vacorr_spatial, spec], name='v2v3vacorr') - icrs = cf.CelestialFrame(name='icrs', reference_frame=coord.ICRS(), - axes_order=(0, 1), unit=(u.deg, u.deg), axes_names=('RA', 'DEC')) - world = cf.CompositeFrame([icrs, spec], name='world') + detector = cf.Frame2D(name="detector", axes_order=(0, 1), unit=(u.pix, u.pix)) + alpha_beta = cf.Frame2D( + name="alpha_beta_spatial", + axes_order=(0, 1), + unit=(u.arcsec, u.arcsec), + axes_names=("alpha", "beta"), + ) + spec_local = cf.SpectralFrame( + name="alpha_beta_spectral", axes_order=(2,), unit=(u.micron,), axes_names=("lambda",) + ) + miri_focal = cf.CompositeFrame([alpha_beta, spec_local], name="alpha_beta") + v23_spatial = cf.Frame2D( + name="v2v3_spatial", axes_order=(0, 1), unit=(u.arcsec, u.arcsec), axes_names=("v2", "v3") + ) + v2v3vacorr_spatial = cf.Frame2D( + name="v2v3vacorr_spatial", + axes_order=(0, 1), + unit=(u.arcsec, u.arcsec), + axes_names=("v2", "v3"), + ) + + spec = cf.SpectralFrame( + name="spectral", axes_order=(2,), unit=(u.micron,), axes_names=("lambda",) + ) + v2v3 = cf.CompositeFrame([v23_spatial, spec], name="v2v3") + v2v3vacorr = cf.CompositeFrame([v2v3vacorr_spatial, spec], name="v2v3vacorr") + icrs = cf.CelestialFrame( + name="icrs", + reference_frame=coord.ICRS(), + axes_order=(0, 1), + unit=(u.deg, u.deg), + axes_names=("RA", "DEC"), + ) + world = cf.CompositeFrame([icrs, spec], name="world") # Define the actual transforms - det2abl = (detector_to_abl(input_model, reference_files)).rename( - "detector_to_abl") + det2abl = (detector_to_abl(input_model, reference_files)).rename("detector_to_abl") abl2v2v3l = (abl_to_v2v3l(input_model, reference_files)).rename("abl_to_v2v3l") # Compute differential velocity aberration (DVA) correction: va_corr = pointing.dva_corr_model( va_scale=input_model.meta.velocity_aberration.scale_factor, v2_ref=input_model.meta.wcsinfo.v2_ref, - v3_ref=input_model.meta.wcsinfo.v3_ref + v3_ref=input_model.meta.wcsinfo.v3_ref, ) & models.Identity(1) tel2sky = pointing.v23tosky(input_model) & models.Identity(1) # Put the transforms together into a single transform - bind_bounding_box(det2abl, transform_bbox_from_shape(input_model.data.shape, order="F"), order="F") - pipeline = [(detector, det2abl), - (miri_focal, abl2v2v3l), - (v2v3, va_corr), - (v2v3vacorr, tel2sky), - (world, None)] + bind_bounding_box( + det2abl, transform_bbox_from_shape(input_model.data.shape, order="F"), order="F" + ) + pipeline = [ + (detector, det2abl), + (miri_focal, abl2v2v3l), + (v2v3, va_corr), + (v2v3vacorr, tel2sky), + (world, None), + ] return pipeline @@ -516,21 +615,34 @@ def detector_to_abl(input_model, reference_files): {channel_wave_range (): LabelMapperDict} {beta: slice_number} selector is {slice_number: x_transform & y_transform} + + Parameters + ---------- + input_model : ImageModel or CubeModel + The input data model. + reference_files : dict + Mapping between reftype (keys) and reference file name (vals). + Requires 'distortion', 'specwcs', 'regions', and 'wavelengthrange' reference files. + + Returns + ------- + det2alpha_beta : `astropy.modeling.Model` + The transform from "detector" to "alpha_beta" frame. """ band = input_model.meta.instrument.band channel = input_model.meta.instrument.channel # used to read the wavelength range - with DistortionMRSModel(reference_files['distortion']) as dist: + with DistortionMRSModel(reference_files["distortion"]) as dist: alpha_model = dist.alpha_model beta_model = dist.beta_model x_model = dist.x_model y_model = dist.y_model - bzero = dict(zip(dist.bzero.channel_band, dist.bzero.beta_zero)) - bdel = dict(zip(dist.bdel.channel_band, dist.bdel.delta_beta)) + bzero = dict(zip(dist.bzero.channel_band, dist.bzero.beta_zero, strict=True)) + bdel = dict(zip(dist.bdel.channel_band, dist.bdel.delta_beta, strict=True)) slices = dist.slices - with SpecwcsModel(reference_files['specwcs']) as f: + with SpecwcsModel(reference_files["specwcs"]) as f: lambda_model = f.model try: @@ -541,9 +653,11 @@ def detector_to_abl(input_model, reference_files): if velosys is not None: velocity_corr = velocity_correction(input_model.meta.wcsinfo.velosys) lambda_model = [m | velocity_corr for m in lambda_model] - log.info("Applied Barycentric velocity correction : {}".format(velocity_corr[1].amplitude.value)) + log.info( + f"Applied Barycentric velocity correction : {velocity_corr[1].amplitude.value}" + ) - with RegionsModel(reference_files['regions']) as f: + with RegionsModel(reference_files["regions"]) as f: allregions = f.regions.copy() # Use the 80% throughput slice mask regions = allregions[7, :, :] @@ -552,29 +666,33 @@ def detector_to_abl(input_model, reference_files): transforms = {} for i, sl in enumerate(slices): - forward = models.Mapping([1, 0, 0, 1, 0]) | \ - alpha_model[i] & beta_model[i] & lambda_model[i] + forward = models.Mapping([1, 0, 0, 1, 0]) | alpha_model[i] & beta_model[i] & lambda_model[i] inv = models.Mapping([2, 0, 2, 0]) | x_model[i] & y_model[i] forward.inverse = inv transforms[sl] = forward - with WavelengthrangeModel(reference_files['wavelengthrange']) as f: - wr = dict(zip(f.waverange_selector, f.wavelengthrange)) + with WavelengthrangeModel(reference_files["wavelengthrange"]) as f: + wr = dict(zip(f.waverange_selector, f.wavelengthrange, strict=True)) ch_dict = {} for c in channel: cb = c + band mapper = MIRI_AB2Slice(bzero[cb], bdel[cb], c) - lm = selector.LabelMapper(inputs=('alpha', 'beta', 'lam'), - mapper=mapper, inputs_mapping=models.Mapping((1,), n_inputs=3)) + lm = selector.LabelMapper( + inputs=("alpha", "beta", "lam"), + mapper=mapper, + inputs_mapping=models.Mapping((1,), n_inputs=3), + ) ch_dict[tuple(wr[cb])] = lm - alpha_beta_mapper = selector.LabelMapperRange(('alpha', 'beta', 'lam'), ch_dict, - models.Mapping((2,))) + alpha_beta_mapper = selector.LabelMapperRange( + ("alpha", "beta", "lam"), ch_dict, models.Mapping((2,)) + ) label_mapper.inverse = alpha_beta_mapper - det2alpha_beta = selector.RegionsSelector(('x', 'y'), ('alpha', 'beta', 'lam'), - label_mapper=label_mapper, selector=transforms) + det2alpha_beta = selector.RegionsSelector( + ("x", "y"), ("alpha", "beta", "lam"), label_mapper=label_mapper, selector=transforms + ) return det2alpha_beta @@ -593,27 +711,41 @@ def abl_to_v2v3l(input_model, reference_files): label_mapper is LabelMapperDict() {channel_wave_range (): channel_number} selector is {channel_number: v22ab & v32ab} + + Parameters + ---------- + input_model : ImageModel or CubeModel + The input data model. + reference_files : dict + Mapping between reftype (keys) and reference file name (vals). + Requires 'distortion' and 'wavelengthrange' reference files. + + Returns + ------- + abl2v2v3l : `astropy.modeling.Model` + The transform from "alpha_beta" to "v2v3" frame. """ band = input_model.meta.instrument.band channel = input_model.meta.instrument.channel # used to read the wavelength range channels = [c + band for c in channel] - with DistortionMRSModel(reference_files['distortion']) as dist: - v23 = dict(zip(dist.abv2v3_model.channel_band, dist.abv2v3_model.model)) + with DistortionMRSModel(reference_files["distortion"]) as dist: + v23 = dict(zip(dist.abv2v3_model.channel_band, dist.abv2v3_model.model, strict=True)) - with WavelengthrangeModel(reference_files['wavelengthrange']) as f: - wr = dict(zip(f.waverange_selector, f.wavelengthrange)) + with WavelengthrangeModel(reference_files["wavelengthrange"]) as f: + wr = dict(zip(f.waverange_selector, f.wavelengthrange, strict=True)) dict_mapper = {} sel = {} # Since there are two channels in each reference file we need to loop over them for c in channels: ch = int(c[0]) - dict_mapper[tuple(wr[c])] = (models.Mapping((2,), name="mapping_lam") | - models.Const1D(ch, name="channel #")) - ident1 = models.Identity(1, name='identity_lam') - ident1._inputs = ('lam',) + dict_mapper[tuple(wr[c])] = models.Mapping((2,), name="mapping_lam") | models.Const1D( + ch, name="channel #" + ) + ident1 = models.Identity(1, name="identity_lam") + ident1._inputs = ("lam",) # noqa: SLF001 chan_v23 = v23[c] v23chan_backward = chan_v23.inverse @@ -624,31 +756,39 @@ def abl_to_v2v3l(input_model, reference_files): v23c = v23_spatial & ident1 sel[ch] = v23c - wave_range_mapper = selector.LabelMapperRange(('alpha', 'beta', 'lam'), dict_mapper, - inputs_mapping=models.Mapping([2, ])) + wave_range_mapper = selector.LabelMapperRange( + ("alpha", "beta", "lam"), + dict_mapper, + inputs_mapping=models.Mapping( + [ + 2, + ] + ), + ) wave_range_mapper.inverse = wave_range_mapper.copy() - abl2v2v3l = selector.RegionsSelector(('alpha', 'beta', 'lam'), ('v2', 'v3', 'lam'), - label_mapper=wave_range_mapper, - selector=sel) + abl2v2v3l = selector.RegionsSelector( + ("alpha", "beta", "lam"), ("v2", "v3", "lam"), label_mapper=wave_range_mapper, selector=sel + ) return abl2v2v3l -exp_type2transform = {'mir_image': imaging, - 'mir_tacq': imaging, - 'mir_lyot': imaging, - 'mir_4qpm': imaging, - 'mir_coroncal': imaging, - 'mir_lrs-fixedslit': lrs, - 'mir_lrs-slitless': lrs, - 'mir_mrs': ifu, - 'mir_flatmrs': not_implemented_mode, - 'mir_flatimage': not_implemented_mode, - 'mir_flat-mrs': not_implemented_mode, - 'mir_flat-image': not_implemented_mode, - 'mir_dark': not_implemented_mode, - 'mir_taconfirm': imaging, - } +exp_type2transform = { + "mir_image": imaging, + "mir_tacq": imaging, + "mir_lyot": imaging, + "mir_4qpm": imaging, + "mir_coroncal": imaging, + "mir_lrs-fixedslit": lrs, + "mir_lrs-slitless": lrs, + "mir_mrs": ifu, + "mir_flatmrs": not_implemented_mode, + "mir_flatimage": not_implemented_mode, + "mir_flat-mrs": not_implemented_mode, + "mir_flat-image": not_implemented_mode, + "mir_dark": not_implemented_mode, + "mir_taconfirm": imaging, +} def get_wavelength_range(input_model, path=None): @@ -659,48 +799,57 @@ def get_wavelength_range(input_model, path=None): Parameters ---------- - input_model : `jwst.datamodels.ImageModel` + input_model : ImageModel Data model after assign_wcs has been run. path : str Directory where the reference file is. (optional) + + Returns + ------- + wave_range : set + A set of tuples containing the channel and wavelength + range for each channel used in the WCS. """ - fname = input_model.meta.ref_file.wavelengthrange.name.split('/')[-1] - if path is None and not os.path.exists(fname): - raise IOError("Reference file {0} not found. Please specify a path.".format(fname)) + fname = Path(input_model.meta.ref_file.wavelengthrange.name.split("/")[-1]) + if path is None and not fname.exists(): + raise OSError(f"Reference file {fname} not found. Please specify a path.") else: - fname = os.path.join(path, fname) + fname = Path(path) / fname f = WavelengthrangeModel(fname) - wave_range = f.tree['wavelengthrange'].copy() - wave_channels = f.tree['channels'] + wave_range = f.tree["wavelengthrange"].copy() + wave_channels = f.tree["channels"] f.close() - wr = dict(zip(wave_channels, wave_range)) + wr = dict(zip(wave_channels, wave_range, strict=True)) channel = input_model.meta.instrument.channel band = input_model.meta.instrument.band - return dict([(ch + band, wr[ch + band]) for ch in channel]) + return {(ch + band, wr[ch + band]) for ch in channel} def store_dithered_position(input_model): - """Store the location of the dithered pointing - location in the dither metadata + """ + Store the location of the dithered pointing location in the dither metadata. Parameters ---------- - input_model : `jwst.datamodels.ImageModel` + input_model : ImageModel Data model containing dither offset information """ # V2_ref and v3_ref should be in arcsec idltov23 = IdealToV2V3( input_model.meta.wcsinfo.v3yangle, - input_model.meta.wcsinfo.v2_ref, input_model.meta.wcsinfo.v3_ref, - input_model.meta.wcsinfo.vparity + input_model.meta.wcsinfo.v2_ref, + input_model.meta.wcsinfo.v3_ref, + input_model.meta.wcsinfo.vparity, ) - dithered_v2, dithered_v3 = idltov23(input_model.meta.dither.x_offset, input_model.meta.dither.y_offset) + dithered_v2, dithered_v3 = idltov23( + input_model.meta.dither.x_offset, input_model.meta.dither.y_offset + ) - v23toworld = input_model.meta.wcs.get_transform('v2v3', 'world') + v23toworld = input_model.meta.wcs.get_transform("v2v3", "world") # v23toworld requires a wavelength along with v2, v3, but value does not affect return dithered_ra, dithered_dec, _ = v23toworld(dithered_v2, dithered_v3, 0.0) diff --git a/jwst/assign_wcs/nircam.py b/jwst/assign_wcs/nircam.py index 1c8c7a2fea..2c7dc45851 100644 --- a/jwst/assign_wcs/nircam.py +++ b/jwst/assign_wcs/nircam.py @@ -8,14 +8,21 @@ import asdf -from stdatamodels.jwst.datamodels import (ImageModel, NIRCAMGrismModel, DistortionModel) -from stdatamodels.jwst.transforms.models import (NIRCAMForwardRowGrismDispersion, - NIRCAMForwardColumnGrismDispersion, - NIRCAMBackwardGrismDispersion) +from stdatamodels.jwst.datamodels import ImageModel, NIRCAMGrismModel, DistortionModel +from stdatamodels.jwst.transforms.models import ( + NIRCAMForwardRowGrismDispersion, + NIRCAMForwardColumnGrismDispersion, + NIRCAMBackwardGrismDispersion, +) from . import pointing -from .util import (not_implemented_mode, subarray_transform, velocity_correction, - transform_bbox_from_shape, bounding_box_from_subarray) +from .util import ( + not_implemented_mode, + subarray_transform, + velocity_correction, + transform_bbox_from_shape, + bounding_box_from_subarray, +) from ..lib.reffile_utils import find_row @@ -32,19 +39,17 @@ def create_pipeline(input_model, reference_files): Parameters ---------- - input_model : `~jwst.datamodel.JwstDataModel` - Input datamodel for processing - reference_files : dict {reftype: reference file name} - The dictionary of reference file names and their associated files. + input_model : `~jwst.datamodels.JwstDataModel` + The input data model. + reference_files : dict + Mapping between reftype (keys) and reference file name (vals). Returns ------- pipeline : list - The pipeline list that is returned is suitable for - input into gwcs.wcs.WCS to create a GWCS object. + The WCS pipeline, suitable for input into `gwcs.WCS`. """ - - log.debug(f'reference files used in NIRCAM WCS pipeline: {reference_files}') + log.debug(f"reference files used in NIRCAM WCS pipeline: {reference_files}") exp_type = input_model.meta.exposure.type.lower() pipeline = exp_type2transform[exp_type](input_model, reference_files) @@ -53,33 +58,31 @@ def create_pipeline(input_model, reference_files): def imaging(input_model, reference_files): """ - The NIRCAM imaging WCS pipeline. + Create the WCS pipeline for NIRCAM imaging data. + + It includes three coordinate frames - "detector", "v2v3", and "world" Parameters ---------- - input_model : `~jwst.datamodel.JwstDataModel` - Input datamodel for processing + input_model : ImageModel + The input data model. reference_files : dict - The dictionary of reference file names and their associated files - {reftype: reference file name}. + Mapping between reftype (keys) and reference file name (vals). + Requires 'distortion' and filteroffset' reference files. Returns ------- pipeline : list - The pipeline list that is returned is suitable for - input into gwcs.wcs.WCS to create a GWCS object. - - Notes - ----- - It includes three coordinate frames - "detector", "v2v3", and "world", - and uses the "distortion" reference file. + The WCS pipeline, suitable for input into `gwcs.WCS`. """ - detector = cf.Frame2D(name='detector', axes_order=(0, 1), unit=(u.pix, u.pix)) - v2v3 = cf.Frame2D(name='v2v3', axes_order=(0, 1), axes_names=('v2', 'v3'), - unit=(u.arcsec, u.arcsec)) - v2v3vacorr = cf.Frame2D(name='v2v3vacorr', axes_order=(0, 1), - axes_names=('v2', 'v3'), unit=(u.arcsec, u.arcsec)) - world = cf.CelestialFrame(reference_frame=coord.ICRS(), name='world') + detector = cf.Frame2D(name="detector", axes_order=(0, 1), unit=(u.pix, u.pix)) + v2v3 = cf.Frame2D( + name="v2v3", axes_order=(0, 1), axes_names=("v2", "v3"), unit=(u.arcsec, u.arcsec) + ) + v2v3vacorr = cf.Frame2D( + name="v2v3vacorr", axes_order=(0, 1), axes_names=("v2", "v3"), unit=(u.arcsec, u.arcsec) + ) + world = cf.CelestialFrame(reference_frame=coord.ICRS(), name="world") distortion = imaging_distortion(input_model, reference_files) subarray2full = subarray_transform(input_model) @@ -94,14 +97,11 @@ def imaging(input_model, reference_files): va_corr = pointing.dva_corr_model( va_scale=input_model.meta.velocity_aberration.scale_factor, v2_ref=input_model.meta.wcsinfo.v2_ref, - v3_ref=input_model.meta.wcsinfo.v3_ref + v3_ref=input_model.meta.wcsinfo.v3_ref, ) tel2sky = pointing.v23tosky(input_model) - pipeline = [(detector, distortion), - (v2v3, va_corr), - (v2v3vacorr, tel2sky), - (world, None)] + pipeline = [(detector, distortion), (v2v3, va_corr), (v2v3vacorr, tel2sky), (world, None)] return pipeline @@ -111,17 +111,18 @@ def imaging_distortion(input_model, reference_files): Parameters ---------- - input_model : `~jwst.datamodel.JwstDataModel` - Input datamodel for processing + input_model : ImageModel or CubeModel + The input data model. reference_files : dict - The dictionary of reference file names and their associated files. + Mapping between reftype (keys) and reference file name (vals). + Requires 'distortion' and filteroffset' reference files. Returns ------- - The transform model - + distortion : `astropy.modeling.Model` + The transform from "detector" to "v2v3". """ - dist = DistortionModel(reference_files['distortion']) + dist = DistortionModel(reference_files["distortion"]) transform = dist.model try: @@ -136,19 +137,19 @@ def imaging_distortion(input_model, reference_files): dist.close() # Add an offset for the filter - if reference_files['filteroffset'] is not None: + if reference_files["filteroffset"] is not None: obsfilter = input_model.meta.instrument.filter obspupil = input_model.meta.instrument.pupil - with asdf.open(reference_files['filteroffset']) as filter_offset: - filters = filter_offset.tree['filters'] + with asdf.open(reference_files["filteroffset"]) as filter_offset: + filters = filter_offset.tree["filters"] - match_keys = {'filter': obsfilter, 'pupil': obspupil} + match_keys = {"filter": obsfilter, "pupil": obspupil} row = find_row(filters, match_keys) if row is not None: - col_offset = row.get('column_offset', 'N/A') - row_offset = row.get('row_offset', 'N/A') + col_offset = row.get("column_offset", "N/A") + row_offset = row.get("row_offset", "N/A") log.debug(f"Offsets from filteroffset file are {col_offset}, {row_offset}") - if col_offset != 'N/A' and row_offset != 'N/A': + if col_offset != "N/A" and row_offset != "N/A": transform = Shift(col_offset) & Shift(row_offset) | transform else: log.debug("No match in fitleroffset file.") @@ -158,27 +159,28 @@ def imaging_distortion(input_model, reference_files): bind_bounding_box( transform, transform_bbox_from_shape(input_model.data.shape, order="F") if bbox is None else bbox, - order="F" + order="F", ) return transform def tsgrism(input_model, reference_files): - """Create WCS pipeline for a NIRCAM Time Series Grism observation. + """ + Create WCS pipeline for a NIRCAM Time Series Grism observation. Parameters ---------- - input_model : `~jwst.datamodels.ImagingModel` - The input datamodel, derived from datamodels + input_model : CubeModel + The input data model. reference_files : dict - Dictionary of reference file names {reftype: reference file name}. + Mapping between reftype (keys) and reference file name (vals). + Requires 'distortion', 'filteroffset' and 'specwcs' reference files. Returns ------- pipeline : list - The pipeline list that is returned is suitable for - input into gwcs.wcs.WCS to create a GWCS object. + The WCS pipeline, suitable for input into `gwcs.WCS`. Notes ----- @@ -195,23 +197,22 @@ def tsgrism(input_model, reference_files): offset special requirements may be encoded in the X_OFFSET parameter, but those are handled in extract_2d. """ - # make sure this is a grism image if "NRC_TSGRISM" != input_model.meta.exposure.type: - raise ValueError('The input exposure is not a NIRCAM time series grism') + raise ValueError("The input exposure is not a NIRCAM time series grism") if input_model.meta.instrument.module != "A": - raise ValueError('NRC_TSGRISM mode only supports module A') + raise ValueError("NRC_TSGRISM mode only supports module A") if input_model.meta.instrument.pupil != "GRISMR": - raise ValueError('NRC_TSGRIM mode only supports GRISMR') + raise ValueError("NRC_TSGRIM mode only supports GRISMR") frames = create_coord_frames() # translate the x,y detector-in to x,y detector out coordinates # Get the disperser parameters which are defined as a model for each # spectral order - with NIRCAMGrismModel(reference_files['specwcs']) as f: + with NIRCAMGrismModel(reference_files["specwcs"]) as f: displ = f.displ dispx = f.dispx dispy = f.dispy @@ -220,19 +221,23 @@ def tsgrism(input_model, reference_files): orders = f.orders # now create the appropriate model for the grismr - det2det = NIRCAMForwardRowGrismDispersion(orders, - lmodels=displ, - xmodels=dispx, - ymodels=dispy, - inv_lmodels=invdispl, - inv_xmodels=invdispx) - - det2det.inverse = NIRCAMBackwardGrismDispersion(orders, - lmodels=displ, - xmodels=dispx, - ymodels=dispy, - inv_lmodels=invdispl, - inv_xmodels=invdispx) + det2det = NIRCAMForwardRowGrismDispersion( + orders, + lmodels=displ, + xmodels=dispx, + ymodels=dispy, + inv_lmodels=invdispl, + inv_xmodels=invdispx, + ) + + det2det.inverse = NIRCAMBackwardGrismDispersion( + orders, + lmodels=displ, + xmodels=dispx, + ymodels=dispy, + inv_lmodels=invdispl, + inv_xmodels=invdispx, + ) # Add in the wavelength shift from the velocity dispersion try: @@ -241,7 +246,7 @@ def tsgrism(input_model, reference_files): pass if velosys is not None: velocity_corr = velocity_correction(input_model.meta.wcsinfo.velosys) - log.info("Added Barycentric velocity correction: {}".format(velocity_corr[1].amplitude.value)) + log.info(f"Added Barycentric velocity correction: {velocity_corr[1].amplitude.value}") det2det = det2det | Mapping((0, 1, 2, 3)) | Identity(2) & velocity_corr & Identity(1) # input into the forward transform is x,y,x0,y0,order @@ -255,10 +260,10 @@ def tsgrism(input_model, reference_files): xc, yc = (input_model.meta.wcsinfo.siaf_xref_sci, input_model.meta.wcsinfo.siaf_yref_sci) if xc is None: - raise ValueError('XREF_SCI is missing.') + raise ValueError("XREF_SCI is missing.") if yc is None: - raise ValueError('YREF_SCI is missing.') + raise ValueError("YREF_SCI is missing.") xcenter = Const1D(xc) xcenter.inverse = Const1D(xc) @@ -274,13 +279,16 @@ def tsgrism(input_model, reference_files): # get the shift to full frame coordinates sub_trans = subarray_transform(input_model) if sub_trans is not None: - sub2direct = (sub_trans & Identity(1) | Mapping((0, 1, 0, 1, 2)) | - (Identity(2) & xcenter & ycenter & Identity(1)) | - det2det) + sub2direct = ( + sub_trans & Identity(1) + | Mapping((0, 1, 0, 1, 2)) + | (Identity(2) & xcenter & ycenter & Identity(1)) + | det2det + ) else: - sub2direct = (Mapping((0, 1, 0, 1, 2)) | - (Identity(2) & xcenter & ycenter & Identity(1)) | - det2det) + sub2direct = ( + Mapping((0, 1, 0, 1, 2)) | (Identity(2) & xcenter & ycenter & Identity(1)) | det2det + ) # take us from full frame detector to v2v3 distortion = imaging_distortion(input_model, reference_files) & Identity(2) @@ -289,7 +297,7 @@ def tsgrism(input_model, reference_files): va_corr = pointing.dva_corr_model( va_scale=input_model.meta.velocity_aberration.scale_factor, v2_ref=input_model.meta.wcsinfo.v2_ref, - v3_ref=input_model.meta.wcsinfo.v3_ref + v3_ref=input_model.meta.wcsinfo.v3_ref, ) & Identity(2) # v2v3 to the sky @@ -302,11 +310,13 @@ def tsgrism(input_model, reference_files): newinverse = Mapping((0, 1, 0, 1)) | setra & setdec & Identity(2) | t2skyinverse tel2sky.inverse = newinverse - pipeline = [(frames['grism_detector'], sub2direct), - (frames['direct_image'], distortion), - (frames['v2v3'], va_corr), - (frames['v2v3vacorr'], tel2sky), - (frames['world'], None)] + pipeline = [ + (frames["grism_detector"], sub2direct), + (frames["direct_image"], distortion), + (frames["v2v3"], va_corr), + (frames["v2v3vacorr"], tel2sky), + (frames["world"], None), + ] return pipeline @@ -317,16 +327,16 @@ def wfss(input_model, reference_files): Parameters ---------- - input_model: `~jwst.datamodels.ImagingModel` - The input datamodel, derived from datamodels - reference_files: dict - Dictionary {reftype: reference file name}. + input_model : ImageModel + The input data model. + reference_files : dict + Mapping between reftype (keys) and reference file name (vals). + Requires 'distortion', 'filteroffset', and 'specwcs' reference files. Returns ------- pipeline : list - The pipeline list that is returned is suitable for - input into gwcs.wcs.WCS to create a GWCS object. + The WCS pipeline, suitable for input into `gwcs.WCS`. Notes ----- @@ -358,25 +368,29 @@ def wfss(input_model, reference_files): bounding box are saved in the photometry catalog in units of RA, DEC so they can be translated to pixels by the dispersed image's imaging-wcs. """ - # The input is the grism image if not isinstance(input_model, ImageModel): - raise TypeError('The input data model must be an ImageModel.') + raise TypeError("The input data model must be an ImageModel.") # make sure this is a grism image if "NRC_WFSS" not in input_model.meta.exposure.type: - raise ValueError('The input exposure is not a NIRCAM grism') + raise ValueError("The input exposure is not a NIRCAM grism") # Create the empty detector as a 2D coordinate frame in pixel units - gdetector = cf.Frame2D(name='grism_detector', axes_order=(0, 1), - axes_names=('x_grism', 'y_grism'), unit=(u.pix, u.pix)) - spec = cf.SpectralFrame(name='spectral', axes_order=(2,), unit=(u.micron,), - axes_names=('wavelength',)) + gdetector = cf.Frame2D( + name="grism_detector", + axes_order=(0, 1), + axes_names=("x_grism", "y_grism"), + unit=(u.pix, u.pix), + ) + spec = cf.SpectralFrame( + name="spectral", axes_order=(2,), unit=(u.micron,), axes_names=("wavelength",) + ) # translate the x,y detector-in to x,y detector out coordinates # Get the disperser parameters which are defined as a model for each # spectral order - with NIRCAMGrismModel(reference_files['specwcs']) as f: + with NIRCAMGrismModel(reference_files["specwcs"]) as f: displ = f.displ dispx = f.dispx dispy = f.dispy @@ -387,30 +401,36 @@ def wfss(input_model, reference_files): # now create the appropriate model for the grism[R/C] if "GRISMR" in input_model.meta.instrument.pupil: - det2det = NIRCAMForwardRowGrismDispersion(orders, - lmodels=displ, - xmodels=dispx, - ymodels=dispy, - inv_lmodels=invdispl, - inv_xmodels=invdispx, - inv_ymodels=invdispy) + det2det = NIRCAMForwardRowGrismDispersion( + orders, + lmodels=displ, + xmodels=dispx, + ymodels=dispy, + inv_lmodels=invdispl, + inv_xmodels=invdispx, + inv_ymodels=invdispy, + ) elif "GRISMC" in input_model.meta.instrument.pupil: - det2det = NIRCAMForwardColumnGrismDispersion(orders, - lmodels=displ, - xmodels=dispx, - ymodels=dispy, - inv_lmodels=invdispl, - inv_xmodels=invdispx, - inv_ymodels=invdispy) - - det2det.inverse = NIRCAMBackwardGrismDispersion(orders, - lmodels=displ, - xmodels=dispx, - ymodels=dispy, - inv_lmodels=invdispl, - inv_xmodels=invdispx, - inv_ymodels=invdispy) + det2det = NIRCAMForwardColumnGrismDispersion( + orders, + lmodels=displ, + xmodels=dispx, + ymodels=dispy, + inv_lmodels=invdispl, + inv_xmodels=invdispx, + inv_ymodels=invdispy, + ) + + det2det.inverse = NIRCAMBackwardGrismDispersion( + orders, + lmodels=displ, + xmodels=dispx, + ymodels=dispy, + inv_lmodels=invdispl, + inv_xmodels=invdispx, + inv_ymodels=invdispy, + ) # Add in the wavelength shift from the velocity dispersion try: @@ -419,7 +439,7 @@ def wfss(input_model, reference_files): pass if velosys is not None: velocity_corr = velocity_correction(input_model.meta.wcsinfo.velosys) - log.info("Added Barycentric velocity correction: {}".format(velocity_corr[1].amplitude.value)) + log.info(f"Added Barycentric velocity correction: {velocity_corr[1].amplitude.value}") det2det = det2det | Mapping((0, 1, 2, 3)) | Identity(2) & velocity_corr & Identity(1) # create the pipeline to construct a WCS object for the whole image @@ -440,51 +460,67 @@ def wfss(input_model, reference_files): # pass the x0,y0, wave, order, through the pipeline imagepipe = [] world = image_pipeline.pop()[0] - world.name = 'sky' + world.name = "sky" for cframe, trans in image_pipeline: trans = trans & (Identity(2)) name = cframe.name - cframe.name = name + 'spatial' - spatial_and_spectral = cf.CompositeFrame([cframe, spec], - name=name) + cframe.name = name + "spatial" + spatial_and_spectral = cf.CompositeFrame([cframe, spec], name=name) imagepipe.append((spatial_and_spectral, trans)) # Output frame is Celestial + Spectral - imagepipe.append((cf.CompositeFrame([world, spec], name='world'), None)) + imagepipe.append((cf.CompositeFrame([world, spec], name="world"), None)) grism_pipeline.extend(imagepipe) return grism_pipeline def create_coord_frames(): - gdetector = cf.Frame2D(name='grism_detector', axes_order=(0, 1), unit=(u.pix, u.pix)) - detector = cf.Frame2D(name='full_detector', axes_order=(0, 1), - axes_names=('dx', 'dy'), unit=(u.pix, u.pix)) - v2v3_spatial = cf.Frame2D(name='v2v3_spatial', axes_order=(0, 1), - axes_names=('v2', 'v3'), unit=(u.deg, u.deg)) - v2v3vacorr_spatial = cf.Frame2D(name='v2v3vacorr_spatial', axes_order=(0, 1), - axes_names=('v2', 'v3'), unit=(u.arcsec, u.arcsec)) - sky_frame = cf.CelestialFrame(reference_frame=coord.ICRS(), name='icrs') - spec = cf.SpectralFrame(name='spectral', axes_order=(2,), unit=(u.micron,), - axes_names=('wavelength',)) - frames = {'grism_detector': gdetector, - 'direct_image': cf.CompositeFrame([detector, spec], name='direct_image'), - 'v2v3': cf.CompositeFrame([v2v3_spatial, spec], name='v2v3'), - 'v2v3vacorr': cf.CompositeFrame([v2v3vacorr_spatial, spec], name='v2v3vacorr'), - 'world': cf.CompositeFrame([sky_frame, spec], name='world') - } + """ + Create the coordinate frames for NIRCAM imaging and grism modes. + + Returns + ------- + frames : dict + Dictionary of the coordinate frames. + """ + gdetector = cf.Frame2D(name="grism_detector", axes_order=(0, 1), unit=(u.pix, u.pix)) + detector = cf.Frame2D( + name="full_detector", axes_order=(0, 1), axes_names=("dx", "dy"), unit=(u.pix, u.pix) + ) + v2v3_spatial = cf.Frame2D( + name="v2v3_spatial", axes_order=(0, 1), axes_names=("v2", "v3"), unit=(u.deg, u.deg) + ) + v2v3vacorr_spatial = cf.Frame2D( + name="v2v3vacorr_spatial", + axes_order=(0, 1), + axes_names=("v2", "v3"), + unit=(u.arcsec, u.arcsec), + ) + sky_frame = cf.CelestialFrame(reference_frame=coord.ICRS(), name="icrs") + spec = cf.SpectralFrame( + name="spectral", axes_order=(2,), unit=(u.micron,), axes_names=("wavelength",) + ) + frames = { + "grism_detector": gdetector, + "direct_image": cf.CompositeFrame([detector, spec], name="direct_image"), + "v2v3": cf.CompositeFrame([v2v3_spatial, spec], name="v2v3"), + "v2v3vacorr": cf.CompositeFrame([v2v3vacorr_spatial, spec], name="v2v3vacorr"), + "world": cf.CompositeFrame([sky_frame, spec], name="world"), + } return frames -exp_type2transform = {'nrc_image': imaging, - 'nrc_wfss': wfss, - 'nrc_tacq': imaging, - 'nrc_taconfirm': imaging, - 'nrc_coron': imaging, - 'nrc_focus': imaging, - 'nrc_tsimage': imaging, - 'nrc_tsgrism': tsgrism, - 'nrc_led': not_implemented_mode, - 'nrc_dark': not_implemented_mode, - 'nrc_flat': not_implemented_mode, - 'nrc_grism': not_implemented_mode, - } +exp_type2transform = { + "nrc_image": imaging, + "nrc_wfss": wfss, + "nrc_tacq": imaging, + "nrc_taconfirm": imaging, + "nrc_coron": imaging, + "nrc_focus": imaging, + "nrc_tsimage": imaging, + "nrc_tsgrism": tsgrism, + "nrc_led": not_implemented_mode, + "nrc_dark": not_implemented_mode, + "nrc_flat": not_implemented_mode, + "nrc_grism": not_implemented_mode, +} diff --git a/jwst/assign_wcs/niriss.py b/jwst/assign_wcs/niriss.py index 754cb50f06..157e1c2df0 100644 --- a/jwst/assign_wcs/niriss.py +++ b/jwst/assign_wcs/niriss.py @@ -1,4 +1,5 @@ import logging +import warnings import asdf from astropy import coordinates as coord @@ -10,14 +11,20 @@ from gwcs import wcs from stdatamodels.jwst.datamodels import ImageModel, NIRISSGrismModel, DistortionModel -from stdatamodels.jwst.transforms.models import (NirissSOSSModel, - NIRISSForwardRowGrismDispersion, - NIRISSBackwardGrismDispersion, - NIRISSForwardColumnGrismDispersion) - -from .util import (not_implemented_mode, subarray_transform, - velocity_correction, bounding_box_from_subarray, - transform_bbox_from_shape) +from stdatamodels.jwst.transforms.models import ( + NirissSOSSModel, + NIRISSForwardRowGrismDispersion, + NIRISSBackwardGrismDispersion, + NIRISSForwardColumnGrismDispersion, +) + +from .util import ( + not_implemented_mode, + subarray_transform, + velocity_correction, + bounding_box_from_subarray, + transform_bbox_from_shape, +) from . import pointing from ..lib.reffile_utils import find_row @@ -28,23 +35,22 @@ def create_pipeline(input_model, reference_files): - """Create the WCS pipeline based on EXP_TYPE. + """ + Create the WCS pipeline based on EXP_TYPE. Parameters ---------- - input_model : `~jwst.datamodel.JwstDataModel` - Input datamodel for processing + input_model : JwstDataModel + The input data model. reference_files : dict - The dictionary of reference file names and their associated files - {reftype: reference file name}. + Mapping between reftype (keys) and reference file name (vals). Returns ------- pipeline : list - The pipeline list that is returned is suitable for - input into gwcs.wcs.WCS to create a GWCS object. + The WCS pipeline, suitable for input into `gwcs.WCS`. """ - log.debug(f'reference files used in NIRISS WCS pipeline: {reference_files}') + log.debug(f"reference files used in NIRISS WCS pipeline: {reference_files}") exp_type = input_model.meta.exposure.type.lower() pipeline = exp_type2transform[exp_type](input_model, reference_files) @@ -53,45 +59,48 @@ def create_pipeline(input_model, reference_files): def niriss_soss_set_input(model, order_number): """ - Extract a WCS fr a specific spectral order. + Extract a WCS for a specific spectral order. Parameters ---------- model : `~jwst.datamodels.ImageModel` An instance of an ImageModel order_number : int - the spectral order + The spectral order Returns ------- - WCS - the WCS corresponding to the spectral order. - + gwcs.WCS + The WCS corresponding to the spectral order. """ - # Make sure the spectral order is available. if order_number < 1 or order_number > 3: - raise ValueError('Order must be between 1 and 3') + raise ValueError("Order must be between 1 and 3") # Return the correct transform based on the order_number obj = model.meta.wcs.forward_transform.get_model(order_number) # use the size of the input subarray - detector = cf.Frame2D(name='detector', axes_order=(0, 1), unit=(u.pix, u.pix)) - spec = cf.SpectralFrame(name='spectral', axes_order=(2,), unit=(u.micron,), - axes_names=('wavelength',)) - sky = cf.CelestialFrame(reference_frame=coord.ICRS(), - axes_names=('ra', 'dec'), - axes_order=(0, 1), unit=(u.deg, u.deg), name='sky') - world = cf.CompositeFrame([sky, spec], name='world') - pipeline = [(detector, obj), - (world, None) - ] + detector = cf.Frame2D(name="detector", axes_order=(0, 1), unit=(u.pix, u.pix)) + spec = cf.SpectralFrame( + name="spectral", axes_order=(2,), unit=(u.micron,), axes_names=("wavelength",) + ) + sky = cf.CelestialFrame( + reference_frame=coord.ICRS(), + axes_names=("ra", "dec"), + axes_order=(0, 1), + unit=(u.deg, u.deg), + name="sky", + ) + world = cf.CompositeFrame([sky, spec], name="world") + pipeline = [(detector, obj), (world, None)] return wcs.WCS(pipeline) def _niriss_order_bounding_box(input_model, order): import numpy as np + bbox_y = np.array([-0.5, input_model.meta.subarray.ysize - 0.5]) bbox_x = np.array([-0.5, input_model.meta.subarray.xsize - 0.5]) @@ -102,42 +111,62 @@ def _niriss_order_bounding_box(input_model, order): elif order == 3: return tuple(bbox_y), tuple(bbox_x) else: - raise ValueError(f'Invalid spectral order: {order} provided. Spectral order must be 1, 2, or 3.') + raise ValueError( + f"Invalid spectral order: {order} provided. Spectral order must be 1, 2, or 3." + ) def niriss_bounding_box(input_model): - bbox = {(order,): _niriss_order_bounding_box(input_model, order) - for order in [1, 2, 3]} + """ + Create a bounding box for the NIRISS model. + + .. deprecated:: 1.17.2 + :py:func:`niriss_bounding_box` has been deprecated and will be removed + in a future release. + + Parameters + ---------- + input_model : JwstDataModel + The input datamodel. + + Returns + ------- + CompoundBoundingBox + The bounding box for the NIRISS model. + """ + warnings.warn( + "'niriss_bounding_bo()' has been deprecated since 1.17.2 and " + "will be removed in a future release. ", + DeprecationWarning, + stacklevel=2, + ) + + bbox = {(order,): _niriss_order_bounding_box(input_model, order) for order in [1, 2, 3]} model = input_model.meta.wcs.forward_transform - return CompoundBoundingBox.validate(model, bbox, slice_args=[('spectral_order', True)], order='F') + return CompoundBoundingBox.validate( + model, bbox, slice_args=[("spectral_order", True)], order="F" + ) def niriss_soss(input_model, reference_files): """ - The NIRISS SOSS WCS pipeline. + Create the WCS pipeline for NIRISS SOSS data. + + It includes TWO coordinate frames - "detector" and "world". Parameters ---------- - input_model : `~jwst.datamodel.JwstDataModel` - Input datamodel for processing + input_model : ImageModel, IFUImageModel, or CubeModel + The input data model. reference_files : dict - The dictionary of reference file names and their associated files - {reftype: reference file name}. + Mapping between reftype (keys) and reference file name (vals). + Requires 'specwcs' reference file. Returns ------- pipeline : list - The pipeline list that is returned is suitable for - input into gwcs.wcs.WCS to create a GWCS object. - - Notes - ----- - It includes tWO coordinate frames - - "detector" and "world". - - It uses the "specwcs" reference file. + The WCS pipeline, suitable for input into `gwcs.WCS`. """ - # Get the target RA and DEC, they will be used for setting the WCS RA # and DEC based on a conversation with Kevin Volk. try: @@ -145,26 +174,35 @@ def niriss_soss(input_model, reference_files): target_dec = float(input_model.meta.target.dec) except TypeError: # There was an error getting the target RA and DEC, so we are not going to continue. - raise ValueError('Problem getting the TARG_RA or TARG_DEC from input model {}'.format(input_model)) + raise ValueError( + "Problem getting the TARG_RA or TARG_DEC from input model {input_model}" + ) from None # Define the frames - detector = cf.Frame2D(name='detector', axes_order=(0, 1), unit=(u.pix, u.pix)) - spec = cf.SpectralFrame(name='spectral', axes_order=(2,), unit=(u.micron,), - axes_names=('wavelength',)) - sky = cf.CelestialFrame(reference_frame=coord.ICRS(), - axes_names=('ra', 'dec'), - axes_order=(0, 1), unit=(u.deg, u.deg), name='sky') - world = cf.CompositeFrame([sky, spec], name='world') + detector = cf.Frame2D(name="detector", axes_order=(0, 1), unit=(u.pix, u.pix)) + spec = cf.SpectralFrame( + name="spectral", axes_order=(2,), unit=(u.micron,), axes_names=("wavelength",) + ) + sky = cf.CelestialFrame( + reference_frame=coord.ICRS(), + axes_names=("ra", "dec"), + axes_order=(0, 1), + unit=(u.deg, u.deg), + name="sky", + ) + world = cf.CompositeFrame([sky, spec], name="world") try: - bytesio_or_path = reference_files['specwcs'] + bytesio_or_path = reference_files["specwcs"] with asdf.open(bytesio_or_path) as af: wl1 = af.tree[1].copy() wl2 = af.tree[2].copy() wl3 = af.tree[3].copy() except Exception as e: - raise IOError(f"Error reading wavelength correction from {reference_files['specwcs']}") from e + raise OSError( + f"Error reading wavelength correction from {reference_files['specwcs']}" + ) from e velosys = input_model.meta.wcsinfo.velosys if velosys is not None: @@ -176,15 +214,15 @@ def niriss_soss(input_model, reference_files): # Reverse the order of inputs passed to Tabular because it's in python order in modeling. # Consider changing it in modeling ? - cm_order1 = (Mapping((0, 1, 1, 0)) | - (Const1D(target_ra) & Const1D(target_dec) & wl1) - ).rename('Order1') - cm_order2 = (Mapping((0, 1, 1, 0)) | - (Const1D(target_ra) & Const1D(target_dec) & wl2) - ).rename('Order2') - cm_order3 = (Mapping((0, 1, 1, 0)) | - (Const1D(target_ra) & Const1D(target_dec) & wl3) - ).rename('Order3') + cm_order1 = (Mapping((0, 1, 1, 0)) | (Const1D(target_ra) & Const1D(target_dec) & wl1)).rename( + "Order1" + ) + cm_order2 = (Mapping((0, 1, 1, 0)) | (Const1D(target_ra) & Const1D(target_dec) & wl2)).rename( + "Order2" + ) + cm_order3 = (Mapping((0, 1, 1, 0)) | (Const1D(target_ra) & Const1D(target_dec) & wl3)).rename( + "Order3" + ) subarray2full = subarray_transform(input_model) if subarray2full is not None: @@ -192,55 +230,52 @@ def niriss_soss(input_model, reference_files): cm_order2 = subarray2full | cm_order2 cm_order3 = subarray2full | cm_order3 - bbox = ((-0.5, input_model.meta.subarray.xsize - 0.5), - (-0.5, input_model.meta.subarray.ysize - 0.5)) - bind_bounding_box(cm_order1, bbox, order='F') - bind_bounding_box(cm_order2, bbox, order='F') - bind_bounding_box(cm_order3, bbox, order='F') + bbox = ( + (-0.5, input_model.meta.subarray.xsize - 0.5), + (-0.5, input_model.meta.subarray.ysize - 0.5), + ) + bind_bounding_box(cm_order1, bbox, order="F") + bind_bounding_box(cm_order2, bbox, order="F") + bind_bounding_box(cm_order3, bbox, order="F") # Define the transforms, they should accept (x,y) and return (ra, dec, lambda) - soss_model = NirissSOSSModel([1, 2, 3], - [cm_order1, cm_order2, cm_order3] - ).rename('3-order SOSS Model') + soss_model = NirissSOSSModel([1, 2, 3], [cm_order1, cm_order2, cm_order3]).rename( + "3-order SOSS Model" + ) # Define the pipeline based on the frames and models above. - pipeline = [(detector, soss_model), - (world, None) - ] + pipeline = [(detector, soss_model), (world, None)] return pipeline def imaging(input_model, reference_files): """ - The NIRISS imaging WCS pipeline. + Create the WCS pipeline for NIRISS imaging data. + + It includes three coordinate frames - "detector" "v2v3" and "world". Parameters ---------- - input_model : `~jwst.datamodel.JwstDataModel` - Input datamodel for processing + input_model : ImageModel + The input data model. reference_files : dict - The dictionary of reference file names and their associated files - {reftype: reference file name}. + Mapping between reftype (keys) and reference file name (vals). + Requires 'distortion' reference file. Returns ------- pipeline : list - The pipeline list that is returned is suitable for - input into gwcs.wcs.WCS to create a GWCS object. - - Notes - ----- - It includes three coordinate frames - - "detector" "v2v3" and "world". - It uses the "distortion" reference file. + The WCS pipeline, suitable for input into `gwcs.WCS`. """ - detector = cf.Frame2D(name='detector', axes_order=(0, 1), unit=(u.pix, u.pix)) - v2v3 = cf.Frame2D(name='v2v3', axes_order=(0, 1), axes_names=('v2', 'v3'), - unit=(u.arcsec, u.arcsec)) - v2v3vacorr = cf.Frame2D(name='v2v3vacorr', axes_order=(0, 1), - axes_names=('v2', 'v3'), unit=(u.arcsec, u.arcsec)) - world = cf.CelestialFrame(reference_frame=coord.ICRS(), name='world') + detector = cf.Frame2D(name="detector", axes_order=(0, 1), unit=(u.pix, u.pix)) + v2v3 = cf.Frame2D( + name="v2v3", axes_order=(0, 1), axes_names=("v2", "v3"), unit=(u.arcsec, u.arcsec) + ) + v2v3vacorr = cf.Frame2D( + name="v2v3vacorr", axes_order=(0, 1), axes_names=("v2", "v3"), unit=(u.arcsec, u.arcsec) + ) + world = cf.CelestialFrame(reference_frame=coord.ICRS(), name="world") distortion = imaging_distortion(input_model, reference_files) @@ -248,7 +283,7 @@ def imaging(input_model, reference_files): va_corr = pointing.dva_corr_model( va_scale=input_model.meta.velocity_aberration.scale_factor, v2_ref=input_model.meta.wcsinfo.v2_ref, - v3_ref=input_model.meta.wcsinfo.v3_ref + v3_ref=input_model.meta.wcsinfo.v3_ref, ) subarray2full = subarray_transform(input_model) @@ -260,33 +295,32 @@ def imaging(input_model, reference_files): bind_bounding_box(distortion, bounding_box_from_subarray(input_model, order="F"), order="F") tel2sky = pointing.v23tosky(input_model) - pipeline = [(detector, distortion), - (v2v3, va_corr), - (v2v3vacorr, tel2sky), - (world, None)] + pipeline = [(detector, distortion), (v2v3, va_corr), (v2v3vacorr, tel2sky), (world, None)] return pipeline def imaging_distortion(input_model, reference_files): - """ Create the transform from "detector" to "v2v3". + """ + Create the transform from "detector" to "v2v3". Parameters ---------- - input_model : `~jwst.datamodel.JwstDataModel` - Input datamodel for processing + input_model : ImageModel + The input data model. reference_files : dict - The dictionary of reference file names and their associated files. + Mapping between reftype (keys) and reference file name (vals). + Requires 'distortion' and filteroffset' reference files. Returns ------- - The transform model - + distortion : `astropy.modeling.Model` + The transform from "detector" to "v2v3". """ - dist = DistortionModel(reference_files['distortion']) + dist = DistortionModel(reference_files["distortion"]) distortion = dist.model try: - bbox = distortion.bounding_box.bounding_box(order='F') + bbox = distortion.bounding_box.bounding_box(order="F") except NotImplementedError: # Check if the transform in the reference file has a ``bounding_box``. # If not set a ``bounding_box`` equal to the size of the image after @@ -295,20 +329,20 @@ def imaging_distortion(input_model, reference_files): dist.close() # Add an offset for the filter - if reference_files['filteroffset'] is not None: + if reference_files["filteroffset"] is not None: obsfilter = input_model.meta.instrument.filter obspupil = input_model.meta.instrument.pupil - with asdf.open(reference_files['filteroffset']) as filter_offset: - filters = filter_offset.tree['filters'] + with asdf.open(reference_files["filteroffset"]) as filter_offset: + filters = filter_offset.tree["filters"] - match_keys = {'filter': obsfilter, 'pupil': obspupil} + match_keys = {"filter": obsfilter, "pupil": obspupil} row = find_row(filters, match_keys) if row is not None: - col_offset = row.get('column_offset', 'N/A') - row_offset = row.get('row_offset', 'N/A') + col_offset = row.get("column_offset", "N/A") + row_offset = row.get("row_offset", "N/A") log.info(f"Offsets from filteroffset file are {col_offset}, {row_offset}") - if col_offset != 'N/A' and row_offset != 'N/A': + if col_offset != "N/A" and row_offset != "N/A": distortion = Shift(col_offset) & Shift(row_offset) | distortion else: log.debug("No match in fitleroffset file.") @@ -318,7 +352,7 @@ def imaging_distortion(input_model, reference_files): bind_bounding_box( distortion, transform_bbox_from_shape(input_model.data.shape, order="F") if bbox is None else bbox, - order="F" + order="F", ) return distortion @@ -330,24 +364,19 @@ def wfss(input_model, reference_files): Parameters ---------- - input_model: `~jwst.datamodels.ImagingModel` - The input datamodel, derived from datamodels - reference_files: dict - Dictionary specifying reference file names + input_model : ImageModel + The input data model. + reference_files : dict + Mapping between reftype (keys) and reference file name (vals). + Requires 'distortion' and 'specwcs' reference files. Returns ------- pipeline : list - The pipeline list that is returned is suitable for - input into gwcs.wcs.WCS to create a GWCS object. + The WCS pipeline, suitable for input into `gwcs.WCS`. Notes ----- - reference_files = { - "specwcs": 'GR150C_F090W.asdf' - "distortion": 'NIRISS_FULL_distortion.asdf' - } - The tree in the grism reference file has a section for each order/beam as well as the link to the filter data file, not sure if there will be a separate passband reference file needed for the wavelength scaling or the @@ -386,27 +415,30 @@ def wfss(input_model, reference_files): which contains the filter names. Source catalog use moved to extract_2d. - """ - # The input is the grism image if not isinstance(input_model, ImageModel): - raise TypeError('The input data model must be an ImageModel.') + raise TypeError("The input data model must be an ImageModel.") # make sure this is a grism image if "NIS_WFSS" != input_model.meta.exposure.type: - raise ValueError('The input exposure is not NIRISS grism') + raise ValueError("The input exposure is not NIRISS grism") # Create the empty detector as a 2D coordinate frame in pixel units - gdetector = cf.Frame2D(name='grism_detector', axes_order=(0, 1), - axes_names=('x_grism', 'y_grism'), unit=(u.pix, u.pix)) - spec = cf.SpectralFrame(name='spectral', axes_order=(2,), unit=(u.micron,), - axes_names=('wavelength',)) + gdetector = cf.Frame2D( + name="grism_detector", + axes_order=(0, 1), + axes_names=("x_grism", "y_grism"), + unit=(u.pix, u.pix), + ) + spec = cf.SpectralFrame( + name="spectral", axes_order=(2,), unit=(u.micron,), axes_names=("wavelength",) + ) # translate the x,y detector-in to x,y detector out coordinates # Get the disperser parameters which are defined as a model for each # spectral order - with NIRISSGrismModel(reference_files['specwcs']) as f: + with NIRISSGrismModel(reference_files["specwcs"]) as f: dispx = f.dispx dispy = f.dispy displ = f.displ @@ -417,33 +449,26 @@ def wfss(input_model, reference_files): # This is the actual rotation from the input model fwcpos = input_model.meta.instrument.filter_position if fwcpos is None: - raise ValueError('FWCPOS keyword value not found in input image') + raise ValueError("FWCPOS keyword value not found in input image") # sep the row and column grism models # In "DMS" orientation (same parity as the sky), the GR150C spectra # are aligned more closely with the rows, and the GR150R spectra are # aligned more closely with the columns. - if input_model.meta.instrument.filter.endswith('C'): - det2det = NIRISSForwardRowGrismDispersion(orders, - lmodels=displ, - xmodels=dispx, - ymodels=dispy, - theta=fwcpos - fwcpos_ref) - elif input_model.meta.instrument.filter.endswith('R'): - det2det = NIRISSForwardColumnGrismDispersion(orders, - lmodels=displ, - xmodels=dispx, - ymodels=dispy, - theta=fwcpos - fwcpos_ref) + if input_model.meta.instrument.filter.endswith("C"): + det2det = NIRISSForwardRowGrismDispersion( + orders, lmodels=displ, xmodels=dispx, ymodels=dispy, theta=fwcpos - fwcpos_ref + ) + elif input_model.meta.instrument.filter.endswith("R"): + det2det = NIRISSForwardColumnGrismDispersion( + orders, lmodels=displ, xmodels=dispx, ymodels=dispy, theta=fwcpos - fwcpos_ref + ) else: - raise ValueError("FILTER keyword {} is not valid." - .format(input_model.meta.instrument.filter)) - - backward = NIRISSBackwardGrismDispersion(orders, - lmodels=invdispl, - xmodels=dispx, - ymodels=dispy, - theta=-(fwcpos - fwcpos_ref)) + raise ValueError("FILTER keyword {input_model.meta.instrument.filter} is not valid.") + + backward = NIRISSBackwardGrismDispersion( + orders, lmodels=invdispl, xmodels=dispx, ymodels=dispy, theta=-(fwcpos - fwcpos_ref) + ) det2det.inverse = backward # Add in the wavelength shift from the velocity dispersion @@ -453,7 +478,7 @@ def wfss(input_model, reference_files): pass if velosys is not None: velocity_corr = velocity_correction(input_model.meta.wcsinfo.velosys) - log.info("Added Barycentric velocity correction: {}".format(velocity_corr[1].amplitude.value)) + log.info(f"Added Barycentric velocity correction: {velocity_corr[1].amplitude.value}") det2det = det2det | Mapping((0, 1, 2, 3)) | Identity(2) & velocity_corr & Identity(1) # create the pipeline to construct a WCS object for the whole image @@ -478,26 +503,27 @@ def wfss(input_model, reference_files): imagepipe = [] world = image_pipeline.pop()[0] - world.name = 'sky' + world.name = "sky" for cframe, trans in image_pipeline: trans = trans & (Identity(2)) name = cframe.name - cframe.name = name + 'spatial' + cframe.name = name + "spatial" spatial_and_spectral = cf.CompositeFrame([cframe, spec], name=name) imagepipe.append((spatial_and_spectral, trans)) - imagepipe.append((cf.CompositeFrame([world, spec], name='world'), None)) + imagepipe.append((cf.CompositeFrame([world, spec], name="world"), None)) grism_pipeline.extend(imagepipe) return grism_pipeline -exp_type2transform = {'nis_image': imaging, - 'nis_wfss': wfss, - 'nis_soss': niriss_soss, - 'nis_ami': imaging, - 'nis_tacq': imaging, - 'nis_taconfirm': imaging, - 'nis_focus': imaging, - 'nis_dark': not_implemented_mode, - 'nis_lamp': not_implemented_mode, - } +exp_type2transform = { + "nis_image": imaging, + "nis_wfss": wfss, + "nis_soss": niriss_soss, + "nis_ami": imaging, + "nis_tacq": imaging, + "nis_taconfirm": imaging, + "nis_focus": imaging, + "nis_dark": not_implemented_mode, + "nis_lamp": not_implemented_mode, +} diff --git a/jwst/assign_wcs/nirspec.py b/jwst/assign_wcs/nirspec.py index 146e2b4def..867dd9b4d8 100644 --- a/jwst/assign_wcs/nirspec.py +++ b/jwst/assign_wcs/nirspec.py @@ -2,8 +2,8 @@ Tools to create the WCS pipeline NIRSPEC modes. Calls create_pipeline() which redirects based on EXP_TYPE. - """ + import logging import numpy as np import copy @@ -16,34 +16,55 @@ from gwcs import coordinate_frames as cf from gwcs.wcstools import grid_from_bounding_box -from stdatamodels.jwst.datamodels import (CollimatorModel, CameraModel, DisperserModel, FOREModel, - IFUFOREModel, MSAModel, OTEModel, IFUPostModel, IFUSlicerModel, - WavelengthrangeModel, FPAModel) -from stdatamodels.jwst.transforms.models import (Rotation3DToGWA, DirCos2Unitless, Slit2Msa, - AngleFromGratingEquation, WavelengthFromGratingEquation, - Gwa2Slit, Unitless2DirCos, Logical, Slit, Snell, - RefractionIndexFromPrism) - -from .util import ( - MSAFileError, - NoDataOnDetectorError, - not_implemented_mode, - velocity_correction +from stdatamodels.jwst.datamodels import ( + CollimatorModel, + CameraModel, + DisperserModel, + FOREModel, + IFUFOREModel, + MSAModel, + OTEModel, + IFUPostModel, + IFUSlicerModel, + WavelengthrangeModel, + FPAModel, ) +from stdatamodels.jwst.transforms.models import ( + Rotation3DToGWA, + DirCos2Unitless, + Slit2Msa, + AngleFromGratingEquation, + WavelengthFromGratingEquation, + Gwa2Slit, + Unitless2DirCos, + Logical, + Slit, + Snell, + RefractionIndexFromPrism, +) + +from .util import MSAFileError, NoDataOnDetectorError, not_implemented_mode, velocity_correction from . import pointing from ..lib.exposure_types import is_nrs_ifu_lamp log = logging.getLogger(__name__) log.setLevel(logging.DEBUG) -FIXED_SLIT_NUMS = {'NONE': 0, 'S200A1': 1, 'S200A2': 2, - 'S400A1': 3, 'S1600A1': 4, 'S200B1': 5} +FIXED_SLIT_NUMS = {"NONE": 0, "S200A1": 1, "S200A2": 2, "S400A1": 3, "S1600A1": 4, "S200B1": 5} # Approximate fallback values for MSA slit scaling MSA_SLIT_SCALES = (1.35, 1.15) -__all__ = ["create_pipeline", "imaging", "ifu", "slits_wcs", "get_open_slits", "nrs_wcs_set_input", - "nrs_ifu_wcs", "get_spectral_order_wrange"] +__all__ = [ + "create_pipeline", + "imaging", + "ifu", + "slits_wcs", + "get_open_slits", + "nrs_wcs_set_input", + "nrs_ifu_wcs", + "get_spectral_order_wrange", +] def create_pipeline(input_model, reference_files, slit_y_range): @@ -52,27 +73,33 @@ def create_pipeline(input_model, reference_files, slit_y_range): Parameters ---------- - input_model : `~jwst.datamodels.ImageModel`, `~jwst.datamodels.IFUImageModel`, `~jwst.datamodels.CubeModel` - The input exposure. + input_model : JwstDataModel + The input data model. reference_files : dict - {reftype: reference_file_name} mapping. - slit_y_range : list + Mapping between reftype (keys) and reference file name (vals). + slit_y_range : tuple The slit Y-range for Nirspec slits, relative to (0, 0) in the center. + + Returns + ------- + pipeline : list + The WCS pipeline, suitable for input into `gwcs.WCS`. """ exp_type = input_model.meta.exposure.type.lower() if input_model.meta.instrument.grating.lower() == "mirror": pipeline = imaging(input_model, reference_files) else: - pipeline = exp_type2transform[exp_type](input_model, reference_files, slit_y_range=slit_y_range) + pipeline = exp_type2transform[exp_type]( + input_model, reference_files, slit_y_range=slit_y_range + ) if pipeline: - log.info("Created a NIRSPEC {0} pipeline with references {1}".format( - exp_type, reference_files)) + log.info(f"Created a NIRSPEC {exp_type} pipeline with references {reference_files}") return pipeline def imaging(input_model, reference_files): """ - Imaging pipeline. + Create the WCS pipeline for NIRSpec imaging data. It has the following coordinate frames: "detector" : the science frame @@ -82,9 +109,22 @@ def imaging(input_model, reference_files): "oteip" : after the FWA "v2v3" and "world" + Parameters + ---------- + input_model : JwstDataModel + The input data model. + reference_files : dict + Mapping between reftype (keys) and reference file name (vals). + Requires 'disperser', 'collimator', 'wavelengthrange', 'fpa', 'camera', + 'ote', and 'fore' reference files. + + Returns + ------- + pipeline : list + The WCS pipeline, suitable for input into `gwcs.WCS`. """ # Get the corrected disperser model - disperser = get_disperser(input_model, reference_files['disperser']) + disperser = get_disperser(input_model, reference_files["disperser"]) # DMS to SCA transform dms2detector = dms_to_sca(input_model) @@ -94,25 +134,24 @@ def imaging(input_model, reference_files): gwa_through = Const1D(-1) * Identity(1) & Const1D(-1) * Identity(1) & Identity(1) - angles = [disperser['theta_x'], disperser['theta_y'], - disperser['theta_z'], disperser['tilt_y']] - rotation = Rotation3DToGWA(angles, axes_order="xyzy", name='rotation').inverse - dircos2unitless = DirCos2Unitless(name='directional_cosines2unitless') + angles = [disperser["theta_x"], disperser["theta_y"], disperser["theta_z"], disperser["tilt_y"]] + rotation = Rotation3DToGWA(angles, axes_order="xyzy", name="rotation").inverse + dircos2unitless = DirCos2Unitless(name="directional_cosines2unitless") - col_model = CollimatorModel(reference_files['collimator']) + col_model = CollimatorModel(reference_files["collimator"]) col = col_model.model col_model.close() # Get the default spectral order and wavelength range and record them in the model. - sporder, wrange = get_spectral_order_wrange(input_model, reference_files['wavelengthrange']) + sporder, wrange = get_spectral_order_wrange(input_model, reference_files["wavelengthrange"]) input_model.meta.wcsinfo.waverange_start = wrange[0] input_model.meta.wcsinfo.waverange_end = wrange[1] input_model.meta.wcsinfo.spectral_order = sporder - lam = wrange[0] + (wrange[1] - wrange[0]) * .5 + lam = wrange[0] + (wrange[1] - wrange[0]) * 0.5 # Scale wavelengths to microns if msa coordinates are terminal - if input_model.meta.instrument.filter == 'OPAQUE': + if input_model.meta.instrument.filter == "OPAQUE": lam *= 1e6 lam_model = Mapping((0, 1, 1)) | Identity(2) & Const1D(lam) @@ -123,50 +162,50 @@ def imaging(input_model, reference_files): # Create coordinate frames in the NIRSPEC WCS pipeline # "detector", "gwa", "msa", "oteip", "v2v3", "v2v3vacorr", "world" det, sca, gwa, msa_frame, oteip, v2v3, v2v3vacorr, world = create_imaging_frames() - if input_model.meta.instrument.filter != 'OPAQUE': - # MSA to OTEIP transform - msa2ote = msa_to_oteip(reference_files) - msa2oteip = msa2ote | Mapping((0, 1), n_inputs=3) - map1 = Mapping((0, 1, 0, 1)) - minv = msa2ote.inverse - del minv.inverse - msa2oteip.inverse = map1 | minv | Mapping((0, 1), n_inputs=3) - - # OTEIP to V2,V3 transform - with OTEModel(reference_files['ote']) as f: - oteip2v23 = f.model - - # Compute differential velocity aberration (DVA) correction: - va_corr = pointing.dva_corr_model( - va_scale=input_model.meta.velocity_aberration.scale_factor, - v2_ref=input_model.meta.wcsinfo.v2_ref, - v3_ref=input_model.meta.wcsinfo.v3_ref - ) - - # V2, V3 to world (RA, DEC) transform - tel2sky = pointing.v23tosky(input_model) - - imaging_pipeline = [(det, dms2detector), - (sca, det2gwa), - (gwa, gwa2msa), - (msa_frame, msa2oteip), - (oteip, oteip2v23), - (v2v3, va_corr), - (v2v3vacorr, tel2sky), - (world, None)] - else: + if input_model.meta.instrument.filter == "OPAQUE": # Pipeline ends with MSA coordinates - imaging_pipeline = [(det, dms2detector), - (sca, det2gwa), - (gwa, gwa2msa), - (msa_frame, None)] + imaging_pipeline = [(det, dms2detector), (sca, det2gwa), (gwa, gwa2msa), (msa_frame, None)] + return imaging_pipeline + + # MSA to OTEIP transform + msa2ote = msa_to_oteip(reference_files) + msa2oteip = msa2ote | Mapping((0, 1), n_inputs=3) + map1 = Mapping((0, 1, 0, 1)) + minv = msa2ote.inverse + del minv.inverse + msa2oteip.inverse = map1 | minv | Mapping((0, 1), n_inputs=3) + + # OTEIP to V2,V3 transform + with OTEModel(reference_files["ote"]) as f: + oteip2v23 = f.model + + # Compute differential velocity aberration (DVA) correction: + va_corr = pointing.dva_corr_model( + va_scale=input_model.meta.velocity_aberration.scale_factor, + v2_ref=input_model.meta.wcsinfo.v2_ref, + v3_ref=input_model.meta.wcsinfo.v3_ref, + ) + + # V2, V3 to world (RA, DEC) transform + tel2sky = pointing.v23tosky(input_model) + + imaging_pipeline = [ + (det, dms2detector), + (sca, det2gwa), + (gwa, gwa2msa), + (msa_frame, msa2oteip), + (oteip, oteip2v23), + (v2v3, va_corr), + (v2v3vacorr, tel2sky), + (world, None), + ] return imaging_pipeline -def ifu(input_model, reference_files, slit_y_range=[-.55, .55]): +def ifu(input_model, reference_files, slit_y_range=(-0.55, 0.55)): """ - The Nirspec IFU WCS pipeline. + Create the WCS pipeline for Nirspec IFU data. The coordinate frames are: "detector" : the science frame @@ -180,41 +219,50 @@ def ifu(input_model, reference_files, slit_y_range=[-.55, .55]): Parameters ---------- - input_model : `~jwst.datamodels.JwstDataModel` + input_model : JwstDataModel The input data model. reference_files : dict - The reference files used for this mode. - slit_y_range : list - The slit dimensions relative to the center of the slit. + Mapping between reftype (keys) and reference file name (vals). + Requires the 'ifufore', 'ifuslicer', 'ifupost', 'disperser', 'wavelengthrange', + 'fpa', 'camera', 'collimator', 'fore', and 'ote' reference files. + slit_y_range : tuple + The slit Y-range for Nirspec slits, relative to (0, 0) in the center. + + Returns + ------- + pipeline : list + The WCS pipeline, suitable for input into `gwcs.WCS`. """ detector = input_model.meta.instrument.detector grating = input_model.meta.instrument.grating - filter = input_model.meta.instrument.filter + filt = input_model.meta.instrument.filter # Check for ifu reference files - if reference_files['ifufore'] is None and \ - reference_files['ifuslicer'] is None and \ - reference_files['ifupost'] is None: + if ( + reference_files["ifufore"] is None + and reference_files["ifuslicer"] is None + and reference_files["ifupost"] is None + ): # No ifu reference files, won't be able to create pipeline - log_message = 'No ifufore, ifuslicer or ifupost reference files' + log_message = "No ifufore, ifuslicer or ifupost reference files" log.critical(log_message) raise RuntimeError(log_message) # Check for data actually being present on NRS2 - log_message = "No IFU slices fall on detector {0}".format(detector) - if detector == "NRS2" and grating.endswith('M'): + log_message = f"No IFU slices fall on detector {detector}" + if detector == "NRS2" and grating.endswith("M"): # Mid-resolution gratings do not project on NRS2. log.critical(log_message) raise NoDataOnDetectorError(log_message) - if detector == "NRS2" and grating == "G140H" and filter == "F070LP": + if detector == "NRS2" and grating == "G140H" and filt == "F070LP": # This combination of grating and filter does not project on NRS2. log.critical(log_message) raise NoDataOnDetectorError(log_message) slits = np.arange(30) # Get the corrected disperser model - disperser = get_disperser(input_model, reference_files['disperser']) + disperser = get_disperser(input_model, reference_files["disperser"]) # Get the default spectral order and wavelength range and record them in the model. - sporder, wrange = get_spectral_order_wrange(input_model, reference_files['wavelengthrange']) + sporder, wrange = get_spectral_order_wrange(input_model, reference_files["wavelengthrange"]) input_model.meta.wcsinfo.waverange_start = wrange[0] input_model.meta.wcsinfo.waverange_end = wrange[1] input_model.meta.wcsinfo.spectral_order = sporder @@ -222,74 +270,82 @@ def ifu(input_model, reference_files, slit_y_range=[-.55, .55]): # DMS to SCA transform dms2detector = dms_to_sca(input_model) # DETECTOR to GWA transform - det2gwa = Identity(2) & detector_to_gwa(reference_files, - input_model.meta.instrument.detector, - disperser) + det2gwa = Identity(2) & detector_to_gwa( + reference_files, input_model.meta.instrument.detector, disperser + ) # GWA to SLIT gwa2slit = gwa_to_ifuslit(slits, input_model, disperser, reference_files, slit_y_range) # SLIT to MSA transform - slit2slicer = ifuslit_to_slicer(slits, reference_files, input_model) + slit2slicer = ifuslit_to_slicer(slits, reference_files) # SLICER to MSA Entrance slicer2msa = slicer_to_msa(reference_files) - det, sca, gwa, slit_frame, slicer_frame, msa_frame, oteip, v2v3, v2v3vacorr, world = create_frames() + det, sca, gwa, slit_frame, slicer_frame, msa_frame, oteip, v2v3, v2v3vacorr, world = ( + create_frames() + ) exp_type = input_model.meta.exposure.type.upper() - is_lamp_exposure = exp_type in ['NRS_LAMP', 'NRS_AUTOWAVE', 'NRS_AUTOFLAT'] + is_lamp_exposure = exp_type in ["NRS_LAMP", "NRS_AUTOWAVE", "NRS_AUTOFLAT"] + + if input_model.meta.instrument.filter == "OPAQUE" or is_lamp_exposure: + # If filter is "OPAQUE" or if internal lamp exposure + # the NIRSPEC WCS pipeline stops at the MSA. + pipeline = [ + (det, dms2detector), + (sca, det2gwa.rename("detector2gwa")), + (gwa, gwa2slit.rename("gwa2slit")), + (slit_frame, slit2slicer), + (slicer_frame, slicer2msa), + (msa_frame, None), + ] + return pipeline + + # MSA to OTEIP transform + msa2oteip = ifu_msa_to_oteip(reference_files) + # OTEIP to V2,V3 transform + # This includes a wavelength unit conversion from meters to microns. + oteip2v23 = oteip_to_v23(reference_files) + + # Compute differential velocity aberration (DVA) correction: + va_corr = pointing.dva_corr_model( + va_scale=input_model.meta.velocity_aberration.scale_factor, + v2_ref=input_model.meta.wcsinfo.v2_ref, + v3_ref=input_model.meta.wcsinfo.v3_ref, + ) & Identity(1) + + # V2, V3 to sky + tel2sky = pointing.v23tosky(input_model) & Identity(1) - if input_model.meta.instrument.filter == 'OPAQUE' or is_lamp_exposure: - # If filter is "OPAQUE" or if internal lamp exposure the NIRSPEC WCS pipeline stops at the MSA. - pipeline = [(det, dms2detector), - (sca, det2gwa.rename('detector2gwa')), - (gwa, gwa2slit.rename('gwa2slit')), - (slit_frame, slit2slicer), - (slicer_frame, slicer2msa), - (msa_frame, None)] - else: - # MSA to OTEIP transform - msa2oteip = ifu_msa_to_oteip(reference_files) - # OTEIP to V2,V3 transform - # This includes a wavelength unit conversion from meters to microns. - oteip2v23 = oteip_to_v23(reference_files, input_model) - - # Compute differential velocity aberration (DVA) correction: - va_corr = pointing.dva_corr_model( - va_scale=input_model.meta.velocity_aberration.scale_factor, - v2_ref=input_model.meta.wcsinfo.v2_ref, - v3_ref=input_model.meta.wcsinfo.v3_ref - ) & Identity(1) - - # V2, V3 to sky - tel2sky = pointing.v23tosky(input_model) & Identity(1) - - # Create coordinate frames in the NIRSPEC WCS pipeline" - # - # The oteip2v2v3 transform converts the wavelength from meters (which is assumed - # in the whole pipeline) to microns (which is the expected output) - # - # "detector", "gwa", "slit_frame", "msa_frame", "oteip", "v2v3", "world" - - pipeline = [(det, dms2detector), - (sca, det2gwa.rename('detector2gwa')), - (gwa, gwa2slit.rename('gwa2slit')), - (slit_frame, slit2slicer), - (slicer_frame, slicer2msa), - (msa_frame, msa2oteip.rename('msa2oteip')), - (oteip, oteip2v23.rename('oteip2v23')), - (v2v3, va_corr), - (v2v3vacorr, tel2sky), - (world, None)] + # Create coordinate frames in the NIRSPEC WCS pipeline" + # + # The oteip2v2v3 transform converts the wavelength from meters (which is assumed + # in the whole pipeline) to microns (which is the expected output) + # + # "detector", "gwa", "slit_frame", "msa_frame", "oteip", "v2v3", "world" + + pipeline = [ + (det, dms2detector), + (sca, det2gwa.rename("detector2gwa")), + (gwa, gwa2slit.rename("gwa2slit")), + (slit_frame, slit2slicer), + (slicer_frame, slicer2msa), + (msa_frame, msa2oteip.rename("msa2oteip")), + (oteip, oteip2v23.rename("oteip2v23")), + (v2v3, va_corr), + (v2v3vacorr, tel2sky), + (world, None), + ] return pipeline def slits_wcs(input_model, reference_files, slit_y_range): """ - The WCS pipeline for MOS and fixed slits. + Create the WCS pipeline for MOS and fixed slits. The coordinate frames are: "detector" : the science frame @@ -303,18 +359,25 @@ def slits_wcs(input_model, reference_files, slit_y_range): Parameters ---------- - input_model : `~jwst.datamodels.JwstDataModel` + input_model : JwstDataModel The input data model. reference_files : dict - The reference files used for this mode. - slit_y_range : list - The slit dimensions relative to the center of the slit. + Mapping between reftype (keys) and reference file name (vals). + Requires the 'msa', 'msametafile', 'disperser', 'wavelengthrange', + 'fpa', 'camera', 'collimator', 'fore', and 'ote' reference files. + slit_y_range : tuple + The slit Y-range for Nirspec slits, relative to (0, 0) in the center. + + Returns + ------- + pipeline : list + The WCS pipeline, suitable for input into `gwcs.WCS`. """ open_slits_id = get_open_slits(input_model, reference_files, slit_y_range) if not open_slits_id: return None n_slits = len(open_slits_id) - log.info("Computing WCS for {0} open slitlets".format(n_slits)) + log.info(f"Computing WCS for {n_slits} open slitlets") msa_pipeline = slitlets_wcs(input_model, reference_files, open_slits_id) @@ -323,29 +386,47 @@ def slits_wcs(input_model, reference_files, slit_y_range): def slitlets_wcs(input_model, reference_files, open_slits_id): """ - Create The WCS pipeline for MOS and Fixed slits for the - specific opened shutters/slits. ``slit_y_range`` is taken from - ``slit.ymin`` and ``slit.ymax``. + Create WCS pipeline for MOS and Fixed slits for the specific opened shutters/slits. + + ``slit_y_range`` is taken from ``slit.ymin`` and ``slit.ymax``. - Note: This function is also used by the ``msaflagopen`` step. + Parameters + ---------- + input_model : JwstDataModel + The input data model. + reference_files : dict + Mapping between reftype (keys) and reference file name (vals). + Requires the 'msa', 'disperser', 'wavelengthrange', + 'fpa', 'camera', 'collimator', 'fore', and 'ote' reference files. + open_slits_id : list + A list of slit IDs that are open. + + Returns + ------- + pipeline : list + The WCS pipeline, suitable for input into `gwcs.WCS`. + + Notes + ----- + This function is also used by the ``msaflagopen`` step. """ # Get the corrected disperser model - disperser = get_disperser(input_model, reference_files['disperser']) + disperser = get_disperser(input_model, reference_files["disperser"]) # Get the default spectral order and wavelength range and record them in the model. - sporder, wrange = get_spectral_order_wrange(input_model, reference_files['wavelengthrange']) + sporder, wrange = get_spectral_order_wrange(input_model, reference_files["wavelengthrange"]) input_model.meta.wcsinfo.waverange_start = wrange[0] input_model.meta.wcsinfo.waverange_end = wrange[1] - log.info("SPORDER= {0}, wrange={1}".format(sporder, wrange)) + log.info(f"SPORDER= {sporder}, wrange={wrange}") input_model.meta.wcsinfo.spectral_order = sporder # DMS to SCA transform dms2detector = dms_to_sca(input_model) - dms2detector.name = 'dms2sca' + dms2detector.name = "dms2sca" # DETECTOR to GWA transform - det2gwa = Identity(2) & detector_to_gwa(reference_files, - input_model.meta.instrument.detector, - disperser) + det2gwa = Identity(2) & detector_to_gwa( + reference_files, input_model.meta.instrument.detector, disperser + ) det2gwa.name = "det2gwa" # GWA to SLIT @@ -353,7 +434,7 @@ def slitlets_wcs(input_model, reference_files, open_slits_id): gwa2slit.name = "gwa2slit" # SLIT to MSA transform - slit2msa = slit_to_msa(open_slits_id, reference_files['msa']) + slit2msa = slit_to_msa(open_slits_id, reference_files["msa"]) slit2msa.name = "slit2msa" # Create coordinate frames in the NIRSPEC WCS pipeline" @@ -363,70 +444,96 @@ def slitlets_wcs(input_model, reference_files, open_slits_id): exp_type = input_model.meta.exposure.type.upper() - is_lamp_exposure = exp_type in ['NRS_LAMP', 'NRS_AUTOWAVE', 'NRS_AUTOFLAT'] + is_lamp_exposure = exp_type in ["NRS_LAMP", "NRS_AUTOWAVE", "NRS_AUTOFLAT"] - if input_model.meta.instrument.filter == 'OPAQUE' or is_lamp_exposure: + if input_model.meta.instrument.filter == "OPAQUE" or is_lamp_exposure: # convert to microns if the pipeline ends earlier - msa_pipeline = [(det, dms2detector), - (sca, det2gwa), - (gwa, gwa2slit), - (slit_frame, slit2msa), - (msa_frame, None)] - else: - # MSA to OTEIP transform - msa2oteip = msa_to_oteip(reference_files) - msa2oteip.name = "msa2oteip" - - # OTEIP to V2,V3 transform - # This includes a wavelength unit conversion from meters to microns. - oteip2v23 = oteip_to_v23(reference_files, input_model) - oteip2v23.name = "oteip2v23" - - # Compute differential velocity aberration (DVA) correction: - va_corr = pointing.dva_corr_model( - va_scale=input_model.meta.velocity_aberration.scale_factor, - v2_ref=input_model.meta.wcsinfo.v2_ref, - v3_ref=input_model.meta.wcsinfo.v3_ref - ) & Identity(1) - - # V2, V3 to sky - tel2sky = pointing.v23tosky(input_model) & Identity(1) - tel2sky.name = "v2v3_to_sky" - - msa_pipeline = [(det, dms2detector), - (sca, det2gwa), - (gwa, gwa2slit), - (slit_frame, slit2msa), - (msa_frame, msa2oteip), - (oteip, oteip2v23), - (v2v3, va_corr), - (v2v3vacorr, tel2sky), - (world, None)] + msa_pipeline = [ + (det, dms2detector), + (sca, det2gwa), + (gwa, gwa2slit), + (slit_frame, slit2msa), + (msa_frame, None), + ] + return msa_pipeline + + # MSA to OTEIP transform + msa2oteip = msa_to_oteip(reference_files) + msa2oteip.name = "msa2oteip" + + # OTEIP to V2,V3 transform + # This includes a wavelength unit conversion from meters to microns. + oteip2v23 = oteip_to_v23(reference_files) + oteip2v23.name = "oteip2v23" + + # Compute differential velocity aberration (DVA) correction: + va_corr = pointing.dva_corr_model( + va_scale=input_model.meta.velocity_aberration.scale_factor, + v2_ref=input_model.meta.wcsinfo.v2_ref, + v3_ref=input_model.meta.wcsinfo.v3_ref, + ) & Identity(1) + + # V2, V3 to sky + tel2sky = pointing.v23tosky(input_model) & Identity(1) + tel2sky.name = "v2v3_to_sky" + + msa_pipeline = [ + (det, dms2detector), + (sca, det2gwa), + (gwa, gwa2slit), + (slit_frame, slit2msa), + (msa_frame, msa2oteip), + (oteip, oteip2v23), + (v2v3, va_corr), + (v2v3vacorr, tel2sky), + (world, None), + ] return msa_pipeline -def get_open_slits(input_model, reference_files=None, slit_y_range=[-.55, .55]): - """Return the opened slits/shutters in a MOS or Fixed Slits exposure. +def get_open_slits(input_model, reference_files=None, slit_y_range=(-0.55, 0.55)): + """ + Return the opened slits/shutters in a MOS or Fixed Slits exposure. + + Parameters + ---------- + input_model : JwstDataModel + The input data model. + reference_files : dict + Mapping between reftype (keys) and reference file name (vals). + Requires the 'msa', 'msametafile', 'disperser', 'wavelengthrange', + 'fpa', 'camera', and 'collimator' reference files. + slit_y_range : tuple + The slit Y-range for Nirspec slits, relative to (0, 0) in the center. + + Returns + ------- + slits : list[Slit] + A list of `~stdatamodels.jwst.transforms.models.Slit` objects. """ exp_type = input_model.meta.exposure.type.lower() lamp_mode = input_model.meta.instrument.lamp_mode if isinstance(lamp_mode, str): lamp_mode = lamp_mode.lower() else: - lamp_mode = 'none' + lamp_mode = "none" # MOS/MSA exposure requiring MSA metadata file - if exp_type in ["nrs_msaspec", "nrs_autoflat"] or ((exp_type in ["nrs_lamp", "nrs_autowave"]) and - (lamp_mode == "msaspec")): + if exp_type in ["nrs_msaspec", "nrs_autoflat"] or ( + (exp_type in ["nrs_lamp", "nrs_autowave"]) and (lamp_mode == "msaspec") + ): prog_id = input_model.meta.observation.program_number.lstrip("0") - msa_metadata_file, msa_metadata_id, dither_point = get_msa_metadata(input_model, reference_files) - if reference_files is not None and 'msa' in reference_files: - slit_scales = get_msa_slit_scales(reference_files['msa']) + msa_metadata_file, msa_metadata_id, dither_point = get_msa_metadata( + input_model, reference_files + ) + if reference_files is not None and "msa" in reference_files: + slit_scales = get_msa_slit_scales(reference_files["msa"]) else: slit_scales = None - slits = get_open_msa_slits(prog_id, msa_metadata_file, msa_metadata_id, - dither_point, slit_y_range, slit_scales) + slits = get_open_msa_slits( + prog_id, msa_metadata_file, msa_metadata_id, dither_point, slit_y_range, slit_scales + ) # Fixed slits exposure (non-TSO) elif exp_type == "nrs_fixedslit": @@ -434,32 +541,48 @@ def get_open_slits(input_model, reference_files=None, slit_y_range=[-.55, .55]): # Bright object (TSO) exposure in S1600A1 fixed slit elif exp_type == "nrs_brightobj": - slits = [Slit('S1600A1', 3, 0, 0, 0, slit_y_range[0], slit_y_range[1], 5, 1)] + slits = [Slit("S1600A1", 3, 0, 0, 0, slit_y_range[0], slit_y_range[1], 5, 1)] # Lamp exposure using fixed slits elif exp_type in ["nrs_lamp", "nrs_autowave"]: - if lamp_mode in ['fixedslit', 'brightobj']: + if lamp_mode in ["fixedslit", "brightobj"]: slits = get_open_fixed_slits(input_model, slit_y_range) else: - raise ValueError("EXP_TYPE {0} is not supported".format(exp_type.upper())) + raise ValueError(f"EXP_TYPE {exp_type.upper()} is not supported") if reference_files is not None and slits: slits = validate_open_slits(input_model, slits, reference_files) - log.info("Slits projected on detector {0}: {1}".format(input_model.meta.instrument.detector, - [sl.name for sl in slits])) + log.info( + f"Slits projected on detector {input_model.meta.instrument.detector}: " + f"{[sl.name for sl in slits]}" + ) if not slits: - log_message = "No open slits fall on detector {0}.".format(input_model.meta.instrument.detector) + log_message = f"No open slits fall on detector {input_model.meta.instrument.detector}." log.critical(log_message) raise NoDataOnDetectorError(log_message) return slits -def get_open_fixed_slits(input_model, slit_y_range=[-.55, .55]): - """ Return the opened fixed slits.""" +def get_open_fixed_slits(input_model, slit_y_range=(-0.55, 0.55)): + """ + Return the opened fixed slits. + + Parameters + ---------- + input_model : JwstDataModel + The input data model. + slit_y_range : tuple + The slit Y-range for Nirspec slits, relative to (0, 0) in the center. + + Returns + ------- + slits : list[Slit] + A list of `~stdatamodels.jwst.transforms.models.Slit` objects. + """ if input_model.meta.subarray.name is None: raise ValueError("Input file is missing SUBARRAY value/keyword.") if input_model.meta.instrument.fixed_slit is None: - input_model.meta.instrument.fixed_slit = 'NONE' + input_model.meta.instrument.fixed_slit = "NONE" primary_slit = input_model.meta.instrument.fixed_slit ylow, yhigh = slit_y_range @@ -474,21 +597,61 @@ def get_open_fixed_slits(input_model, slit_y_range=[-.55, .55]): # quadrant 5. # # Slit(Name, ShutterID, DitherPos, Xcen, Ycen, Ymin, Ymax, Quad, SourceID) - s2a1 = Slit('S200A1', 0, 0, 0, 0, ylow, yhigh, 5, - 1 if primary_slit == 'S200A1' - else 10 * FIXED_SLIT_NUMS[primary_slit] + 1) - s2a2 = Slit('S200A2', 1, 0, 0, 0, ylow, yhigh, 5, - 1 if primary_slit == 'S200A2' - else 10 * FIXED_SLIT_NUMS[primary_slit] + 2) - s4a1 = Slit('S400A1', 2, 0, 0, 0, ylow, yhigh, 5, - 1 if primary_slit == 'S400A1' - else 10 * FIXED_SLIT_NUMS[primary_slit] + 3) - s16a1 = Slit('S1600A1', 3, 0, 0, 0, ylow, yhigh, 5, - 1 if primary_slit == 'S1600A1' - else 10 * FIXED_SLIT_NUMS[primary_slit] + 4) - s2b1 = Slit('S200B1', 4, 0, 0, 0, ylow, yhigh, 5, - 1 if primary_slit == 'S200B1' - else 10 * FIXED_SLIT_NUMS[primary_slit] + 5) + s2a1 = Slit( + "S200A1", + 0, + 0, + 0, + 0, + ylow, + yhigh, + 5, + 1 if primary_slit == "S200A1" else 10 * FIXED_SLIT_NUMS[primary_slit] + 1, + ) + s2a2 = Slit( + "S200A2", + 1, + 0, + 0, + 0, + ylow, + yhigh, + 5, + 1 if primary_slit == "S200A2" else 10 * FIXED_SLIT_NUMS[primary_slit] + 2, + ) + s4a1 = Slit( + "S400A1", + 2, + 0, + 0, + 0, + ylow, + yhigh, + 5, + 1 if primary_slit == "S400A1" else 10 * FIXED_SLIT_NUMS[primary_slit] + 3, + ) + s16a1 = Slit( + "S1600A1", + 3, + 0, + 0, + 0, + ylow, + yhigh, + 5, + 1 if primary_slit == "S1600A1" else 10 * FIXED_SLIT_NUMS[primary_slit] + 4, + ) + s2b1 = Slit( + "S200B1", + 4, + 0, + 0, + 0, + ylow, + yhigh, + 5, + 1 if primary_slit == "S200B1" else 10 * FIXED_SLIT_NUMS[primary_slit] + 5, + ) # Decide which slits need to be added to this exposure subarray = input_model.meta.subarray.name.upper() @@ -499,8 +662,7 @@ def get_open_fixed_slits(input_model, slit_y_range=[-.55, .55]): slits.append(s2a2) elif subarray == "SUBS400A1": slits.append(s4a1) - elif subarray in ("SUB2048", "SUB512", "SUB512S", - "SUB1024A", "SUB1024B"): + elif subarray in ("SUB2048", "SUB512", "SUB512S", "SUB1024A", "SUB1024B"): slits.append(s16a1) elif subarray == "SUBS200B1": slits.append(s2b1) @@ -514,17 +676,29 @@ def get_msa_metadata(input_model, reference_files): """ Get the MSA metadata file (MSAMTFL) and the msa metadata ID (MSAMETID). + Parameters + ---------- + input_model : JwstDataModel + The input data model. + reference_files : dict + Mapping between reftype (keys) and reference file name (vals). + Requires the 'msametafile' reference file. + + Returns + ------- + slits : list[Slit] + A list of `~stdatamodels.jwst.transforms.models.Slit` objects. """ try: - msa_config = reference_files['msametafile'] + msa_config = reference_files["msametafile"] except (KeyError, TypeError): - log.info('MSA metadata file not in reference files dict') - log.info('Getting MSA metadata file from MSAMETFL keyword') + log.info("MSA metadata file not in reference files dict") + log.info("Getting MSA metadata file from MSAMETFL keyword") msa_config = input_model.meta.instrument.msa_metadata_file if msa_config is None: message = "msa_metadata_file is None." log.critical(message) - raise MSAFileError(message) + raise MSAFileError(message) from None msa_metadata_id = input_model.meta.instrument.msa_metadata_id if msa_metadata_id is None: message = "Missing msa_metadata_id (keyword MSAMETID)." @@ -556,17 +730,23 @@ def get_msa_slit_scales(msa_ref_file): msa = MSAModel(msa_ref_file) scales = {} for quadrant in range(1, 5): - msa_quadrant = getattr(msa, 'Q{0}'.format(quadrant)) + msa_quadrant = getattr(msa, f"Q{quadrant}") msa_data = msa_quadrant.data - scale_x = (msa_data['XC'][1] - msa_data['XC'][0]) / msa_data['SIZEX'][0] - scale_y = msa_data['YC'][365] / msa_data['SIZEY'][0] + scale_x = (msa_data["XC"][1] - msa_data["XC"][0]) / msa_data["SIZEX"][0] + scale_y = msa_data["YC"][365] / msa_data["SIZEY"][0] scales[quadrant] = (scale_x, scale_y) return scales -def get_open_msa_slits(prog_id, msa_file, msa_metadata_id, dither_position, - slit_y_range=[-.55, .55], slit_scales=None): +def get_open_msa_slits( + prog_id, + msa_file, + msa_metadata_id, + dither_position, + slit_y_range=(-0.55, 0.55), + slit_scales=None, +): """ Return the opened MOS slitlets. @@ -600,8 +780,11 @@ def get_open_msa_slits(prog_id, msa_file, msa_metadata_id, dither_position, The MSA meta id for the science file, FITS keyword ``MSAMETID``. dither_position : int The index in the dither pattern, FITS keyword ``PATT_NUM``. - slit_y_range : list or tuple of size 2 - The lower and upper limit of the slit. + slit_y_range : tuple + The slit Y-range for Nirspec slits, relative to (0, 0) in the center. + slit_scales : dict + A dictionary of scaling factors for MSA shutters. Keys are integer quadrant values + (one-indexed). Values are 2-tuples of float values (scale_x, scale_y). Returns ------- @@ -610,7 +793,6 @@ def get_open_msa_slits(prog_id, msa_file, msa_metadata_id, dither_position, ("name", "shutter_id", "xcen", "ycen", "ymin", "ymax", "quadrant", "source_id", "shutter_state", "source_name", "source_alias", "stellarity", "source_xpos", "source_ypos", "source_ra", "source_dec") - """ slitlets = [] ylow, yhigh = slit_y_range @@ -619,33 +801,38 @@ def get_open_msa_slits(prog_id, msa_file, msa_metadata_id, dither_position, try: msa_file = fits.open(msa_file, memmap=False) except FileNotFoundError: - message = "Missing MSA meta (MSAMETFL) file {}".format(msa_file) + message = f"Missing MSA meta (MSAMETFL) file {msa_file}" log.error(message) - raise MSAFileError(message) + raise MSAFileError(message) from None except OSError: - message = "Unable to read MSA FITS file (MSAMETFL) {0}".format(msa_file) + message = f"Unable to read MSA FITS file (MSAMETFL) {msa_file}" log.error(message) - raise MSAFileError(message) + raise MSAFileError(message) from None except Exception: - message = "Problem reading MSA metafile (MSAMETFL) {0}".format(msa_file) + message = f"Problem reading MSA metafile (MSAMETFL) {msa_file}" log.error(message) - raise MSAFileError(message) + raise MSAFileError(message) from None # Set an empty dictionary for slit_scales if not provided if slit_scales is None: slit_scales = {} # Get the shutter and source info tables from the _msa.fits file. - msa_conf = msa_file[('SHUTTER_INFO', 1)] # EXTNAME = 'SHUTTER_INFO' + msa_conf = msa_file[("SHUTTER_INFO", 1)] # EXTNAME = 'SHUTTER_INFO' msa_source = msa_file[("SOURCE_INFO", 1)].data # EXTNAME = 'SOURCE_INFO' # First we are going to filter the msa_file data on the msa_metadata_id # and dither_point_index. - msa_data = [x for x in msa_conf.data if x['msa_metadata_id'] == msa_metadata_id - and x['dither_point_index'] == dither_position] - log.debug(f'msa_data with msa_metadata_id = {msa_metadata_id} {msa_data}') - log.info(f'Retrieving open MSA slitlets for msa_metadata_id = {msa_metadata_id} ' - f'and dither_index = {dither_position}') + msa_data = [ + x + for x in msa_conf.data + if x["msa_metadata_id"] == msa_metadata_id and x["dither_point_index"] == dither_position + ] + log.debug(f"msa_data with msa_metadata_id = {msa_metadata_id} {msa_data}") + log.info( + f"Retrieving open MSA slitlets for msa_metadata_id = {msa_metadata_id} " + f"and dither_index = {dither_position}" + ) # Sort the MSA rows by slitlet_id slitlet_sets = {} @@ -653,9 +840,8 @@ def get_open_msa_slits(prog_id, msa_file, msa_metadata_id, dither_position, # Check for fixed slit: if set, then slitlet_id is null is_fs = False try: - fixed_slit = row['fixed_slit'] - if (fixed_slit in FIXED_SLIT_NUMS.keys() - and fixed_slit != 'NONE'): + fixed_slit = row["fixed_slit"] + if fixed_slit in FIXED_SLIT_NUMS.keys() and fixed_slit != "NONE": is_fs = True except (IndexError, ValueError, KeyError): # May be old-style MSA file without a fixed_slit column @@ -666,7 +852,7 @@ def get_open_msa_slits(prog_id, msa_file, msa_metadata_id, dither_position, slitlet_id = fixed_slit else: # MSA - use the slitlet ID - slitlet_id = row['slitlet_id'] + slitlet_id = row["slitlet_id"] # Append the row for the slitlet if slitlet_id in slitlet_sets: @@ -680,17 +866,19 @@ def get_open_msa_slits(prog_id, msa_file, msa_metadata_id, dither_position, # Now let's look at each unique slitlet id for slitlet_id, slitlet_rows in slitlet_sets.items(): # Get the open shutter information from the slitlet rows - open_shutters = [x['shutter_column'] for x in slitlet_rows] + open_shutters = [x["shutter_column"] for x in slitlet_rows] # How many shutters in the slitlet are labeled as "main" or "primary"? - n_main_shutter = len([s for s in slitlet_rows if s['primary_source'] == 'Y']) + n_main_shutter = len([s for s in slitlet_rows if s["primary_source"] == "Y"]) # Check for fixed slit sources defined in the MSA file is_fs = [False] * len(slitlet_rows) for i, slitlet in enumerate(slitlet_rows): try: - if (slitlet['fixed_slit'] in FIXED_SLIT_NUMS.keys() - and slitlet['fixed_slit'] != 'NONE'): + if ( + slitlet["fixed_slit"] in FIXED_SLIT_NUMS.keys() + and slitlet["fixed_slit"] != "NONE" + ): is_fs[i] = True except (IndexError, ValueError, KeyError): # May be old-style MSA file without a fixed_slit column @@ -716,41 +904,42 @@ def get_open_msa_slits(prog_id, msa_file, msa_metadata_id, dither_position, # Source position and id if n_main_shutter == 1: # Source is marked primary - source_id = slitlet['source_id'] - source_xpos = np.nan_to_num(slitlet['estimated_source_in_shutter_x'], nan=0.5) - source_ypos = np.nan_to_num(slitlet['estimated_source_in_shutter_y'], nan=0.5) + source_id = slitlet["source_id"] + source_xpos = np.nan_to_num(slitlet["estimated_source_in_shutter_x"], nan=0.5) + source_ypos = np.nan_to_num(slitlet["estimated_source_in_shutter_y"], nan=0.5) - log.info(f'Found fixed slit {slitlet_id} with source_id = {source_id}.') + log.info(f"Found fixed slit {slitlet_id} with source_id = {source_id}.") # Get source info for this slitlet: # note that slits with a real source assigned have source_id > 0, # while slits with source_id < 0 contain "virtual" sources try: source_name, source_alias, stellarity, source_ra, source_dec = [ - (s['source_name'], s['alias'], s['stellarity'], s['ra'], s['dec']) - for s in msa_source if s['source_id'] == source_id][0] + (s["source_name"], s["alias"], s["stellarity"], s["ra"], s["dec"]) + for s in msa_source + if s["source_id"] == source_id + ][0] except IndexError: # Missing source information: assign a virtual source name log.warning("Could not retrieve source info from MSA file") source_name = f"{prog_id}_VRT{slitlet_id}" - source_alias = "VRT{}".format(slitlet_id) + source_alias = f"VRT{slitlet_id}" stellarity = 0.0 source_ra = 0.0 source_dec = 0.0 else: - log.warning(f'Fixed slit {slitlet_id} is not a primary source; ' - f'skipping it.') + log.warning(f"Fixed slit {slitlet_id} is not a primary source; skipping it.") continue elif any(is_fs): # Unsupported fixed slit configuration - message = ("For slitlet_id = {}, metadata_id = {}, " - "dither_index = {}".format( - slitlet_id, msa_metadata_id, dither_position)) + message = ( + f"For slitlet_id = {slitlet_id}, metadata_id = {msa_metadata_id}, " + f"dither_index = {dither_position}" + ) log.warning(message) - message = ("MSA configuration file has an unsupported " - "fixed slit configuration.") + message = "MSA configuration file has an unsupported fixed slit configuration." log.warning(message) msa_file.close() raise MSAFileError(message) @@ -761,15 +950,16 @@ def get_open_msa_slits(prog_id, msa_file, msa_metadata_id, dither_position, if len(open_shutters) == 1: jmin = jmax = j = open_shutters[0] else: - jmin = min([s['shutter_column'] for s in slitlet_rows]) - jmax = max([s['shutter_column'] for s in slitlet_rows]) + jmin = min([s["shutter_column"] for s in slitlet_rows]) + jmax = max([s["shutter_column"] for s in slitlet_rows]) j = jmin + (jmax - jmin) // 2 ymax = yhigh + margin + (jmax - j) * 1.15 ymin = -(-ylow + margin) + (jmin - j) * 1.15 - quadrant = slitlet_rows[0]['shutter_quadrant'] + quadrant = slitlet_rows[0]["shutter_quadrant"] ycen = j - xcen = slitlet_rows[0]['shutter_row'] # grab the first as they are all the same - shutter_id = np.int64(xcen) + (np.int64(ycen) - 1) * 365 # shutter numbers in MSA file are 1-indexed + xcen = slitlet_rows[0]["shutter_row"] # grab the first as they are all the same + # shutter numbers in MSA file are 1-indexed + shutter_id = np.int64(xcen) + (np.int64(ycen) - 1) * 365 # Background slits all have source_id=0 in the msa_file, # so assign a unique id based on the slitlet_id source_id = slitlet_id @@ -779,24 +969,31 @@ def get_open_msa_slits(prog_id, msa_file, msa_metadata_id, dither_position, source_xpos = 0.5 source_ypos = 0.5 source_name = f"{prog_id}_BKG{slitlet_id}" - source_alias = "BKG{}".format(slitlet_id) + source_alias = f"BKG{slitlet_id}" stellarity = 0.0 source_ra = 0.0 source_dec = 0.0 - log.info(f'Slitlet {slitlet_id} is background only; assigned source_id={source_id}') + log.info(f"Slitlet {slitlet_id} is background only; assigned source_id={source_id}") # There is 1 main shutter: this is a slit containing either a real or virtual source elif n_main_shutter == 1: xcen, ycen, quadrant, source_xpos, source_ypos = [ - (s['shutter_row'], s['shutter_column'], s['shutter_quadrant'], - np.nan_to_num(s['estimated_source_in_shutter_x'], nan=0.5), - np.nan_to_num(s['estimated_source_in_shutter_y'], nan=0.5)) - for s in slitlet_rows if s['background'] == 'N'][0] - shutter_id = np.int64(xcen) + (np.int64(ycen) - 1) * 365 # shutter numbers in MSA file are 1-indexed + ( + s["shutter_row"], + s["shutter_column"], + s["shutter_quadrant"], + np.nan_to_num(s["estimated_source_in_shutter_x"], nan=0.5), + np.nan_to_num(s["estimated_source_in_shutter_y"], nan=0.5), + ) + for s in slitlet_rows + if s["background"] == "N" + ][0] + # shutter numbers in MSA file are 1-indexed + shutter_id = np.int64(xcen) + (np.int64(ycen) - 1) * 365 # y-size - jmin = min([s['shutter_column'] for s in slitlet_rows]) - jmax = max([s['shutter_column'] for s in slitlet_rows]) + jmin = min([s["shutter_column"] for s in slitlet_rows]) + jmax = max([s["shutter_column"] for s in slitlet_rows]) j = ycen ymax = yhigh + margin + (jmax - j) * 1.15 ymin = -(-ylow + margin) + (jmin - j) * 1.15 @@ -804,33 +1001,40 @@ def get_open_msa_slits(prog_id, msa_file, msa_metadata_id, dither_position, # Get the source_id from the primary shutter entry source_id = None for i in range(len(slitlet_rows)): - if slitlet_rows[i]['primary_source'] == 'Y': - source_id = slitlet_rows[i]['source_id'] + if slitlet_rows[i]["primary_source"] == "Y": + source_id = slitlet_rows[i]["source_id"] # Get source info for this slitlet; # note that slits with a real source assigned have source_id > 0, # while slits with source_id < 0 contain "virtual" sources try: source_name, source_alias, stellarity, source_ra, source_dec = [ - (s['source_name'], s['alias'], s['stellarity'], s['ra'], s['dec']) - for s in msa_source if s['source_id'] == source_id][0] + (s["source_name"], s["alias"], s["stellarity"], s["ra"], s["dec"]) + for s in msa_source + if s["source_id"] == source_id + ][0] except IndexError: source_name = f"{prog_id}_VRT{slitlet_id}" - source_alias = "VRT{}".format(slitlet_id) + source_alias = f"VRT{slitlet_id}" stellarity = 0.0 source_ra = 0.0 source_dec = 0.0 - log.warning(f"Could not retrieve source info from MSA file; " - f"assigning virtual source_name={source_name}") + log.warning( + f"Could not retrieve source info from MSA file; " + f"assigning virtual source_name={source_name}" + ) if source_id < 0: - log.info(f'Slitlet {slitlet_id} contains virtual source, ' - f'with source_id={source_id}') + log.info( + f"Slitlet {slitlet_id} contains virtual source, with source_id={source_id}" + ) # More than 1 main shutter: Not allowed! else: - message = ("For slitlet_id = {}, metadata_id = {}, " - "and dither_index = {}".format(slitlet_id, msa_metadata_id, dither_position)) + message = ( + f"For slitlet_id = {slitlet_id}, metadata_id = {msa_metadata_id}, " + f"and dither_index = {dither_position}" + ) log.warning(message) message = "MSA configuration file has more than 1 shutter with primary source" log.warning(message) @@ -855,11 +1059,28 @@ def get_open_msa_slits(prog_id, msa_file, msa_metadata_id, dither_position, # Create the output list of tuples that contain the required # data for further computations - slit_parameters = (slitlet_id, shutter_id, dither_position, xcen, ycen, ymin, ymax, - quadrant, source_id, all_shutters, source_name, source_alias, - stellarity, source_xpos, source_ypos, source_ra, source_dec, - scale_x, scale_y) - log.debug(f'Appending slit: {slit_parameters}') + slit_parameters = ( + slitlet_id, + shutter_id, + dither_position, + xcen, + ycen, + ymin, + ymax, + quadrant, + source_id, + all_shutters, + source_name, + source_alias, + stellarity, + source_xpos, + source_ypos, + source_ra, + source_dec, + scale_x, + scale_y, + ) + log.debug(f"Appending slit: {slit_parameters}") slitlets.append(Slit(*slit_parameters)) msa_file.close() @@ -874,8 +1095,8 @@ def _shutter_id_to_str(open_shutters, ycen): ---------- open_shutters : list List of IDs (shutter_id) of open shutters. - xcen : int - X coordinate of main shutter. + ycen : int + Y coordinate of main shutter. Returns ------- @@ -890,7 +1111,7 @@ def _shutter_id_to_str(open_shutters, ycen): all_shutters[all_shutters == i] = 1 all_shutters[all_shutters != 1] = 0 all_shutters = all_shutters.astype(str) - all_shutters[cen_ind] = 'x' + all_shutters[cen_ind] = "x" return "".join(all_shutters) @@ -900,32 +1121,39 @@ def get_spectral_order_wrange(input_model, wavelengthrange_file): Parameters ---------- - input_model : `~jwst.datamodels.JwstDataModel` + input_model : JwstDataModel The input data model. wavelengthrange_file : str Reference file of type "wavelengthrange". + + Returns + ------- + order : int + The spectral order. + wrange : list + The wavelength range. """ # Nirspec full spectral range - full_range = [.6e-6, 5.3e-6] + full_range = [0.6e-6, 5.3e-6] - filter = input_model.meta.instrument.filter + filt = input_model.meta.instrument.filter lamp = input_model.meta.instrument.lamp_state grating = input_model.meta.instrument.grating exp_type = input_model.meta.exposure.type - is_lamp_exposure = exp_type in ['NRS_LAMP', 'NRS_AUTOWAVE', 'NRS_AUTOFLAT'] + is_lamp_exposure = exp_type in ["NRS_LAMP", "NRS_AUTOWAVE", "NRS_AUTOFLAT"] wave_range_model = WavelengthrangeModel(wavelengthrange_file) wrange_selector = wave_range_model.waverange_selector - if filter == "OPAQUE" or is_lamp_exposure: - keyword = lamp + '_' + grating + if filt == "OPAQUE" or is_lamp_exposure: + keyword = lamp + "_" + grating else: - keyword = filter + '_' + grating + keyword = filt + "_" + grating try: index = wrange_selector.index(keyword) except (KeyError, ValueError): # Combination of filter_grating is not in wavelengthrange file. - gratings = [s.split('_')[1] for s in wrange_selector] + gratings = [s.split("_")[1] for s in wrange_selector] try: index = gratings.index(grating) except ValueError: # grating not in list @@ -934,8 +1162,10 @@ def get_spectral_order_wrange(input_model, wavelengthrange_file): else: order = wave_range_model.order[index] wrange = wave_range_model.wavelengthrange[index] - log.info("Combination {0} missing in wavelengthrange file, setting " - "order to {1} and range to {2}.".format(keyword, order, wrange)) + log.info( + f"Combination {keyword} missing in wavelengthrange file, setting " + f"order to {order} and range to {wrange}." + ) else: # Combination of filter_grating is found in wavelengthrange file. order = wave_range_model.order[index] @@ -945,29 +1175,29 @@ def get_spectral_order_wrange(input_model, wavelengthrange_file): return order, wrange -def ifuslit_to_slicer(slits, reference_files, input_model): +def ifuslit_to_slicer(slits, reference_files): """ - The transform from ``slit_frame`` to ``slicer`` frame. + Create the transform from ``slit_frame`` to ``slicer`` frame. Parameters ---------- slits : list A list of slit IDs for all slices. reference_files : dict - {reference_type: reference_file_name} - input_model : `~jwst.datamodels.IFUImageModel` + Mapping between reftype (keys) and reference file name (vals). + Requires the 'ifuslicer' reference file. Returns ------- - model : `~stdatamodels.jwst.transforms.Slit2Msa` model. + model : `~astropy.modeling.Model`. Transform from ``slit_frame`` to ``slicer`` frame. """ - ifuslicer = IFUSlicerModel(reference_files['ifuslicer']) + ifuslicer = IFUSlicerModel(reference_files["ifuslicer"]) models = [] ifuslicer_model = ifuslicer.model for slit in slits: slitdata = ifuslicer.data[slit] - slitdata_model = (get_slit_location_model(slitdata)).rename('slitdata_model') + slitdata_model = (get_slit_location_model(slitdata)).rename("slitdata_model") slicer_model = slitdata_model | ifuslicer_model msa_transform = slicer_model @@ -979,12 +1209,20 @@ def ifuslit_to_slicer(slits, reference_files, input_model): def slicer_to_msa(reference_files): """ - Transform from slicer coordinates to MSA entrance. + Transform from slicer coordinates to MSA entrance (the IFUFORE transform). - Applies the IFUFORE transform. + Parameters + ---------- + reference_files : dict + Mapping between reftype (keys) and reference file name (vals). + Requires the 'ifufore' reference file. + Returns + ------- + model : `~astropy.modeling.Model` + Transform from ``slicer`` frame to ``msa_frame``. """ - with IFUFOREModel(reference_files['ifufore']) as f: + with IFUFOREModel(reference_files["ifufore"]) as f: ifufore = f.model slicer2fore_mapping = Mapping((0, 1, 2, 2)) slicer2fore_mapping.inverse = Identity(3) @@ -996,7 +1234,7 @@ def slicer_to_msa(reference_files): def slit_to_msa(open_slits, msafile): """ - The transform from ``slit_frame`` to ``msa_frame``. + Create the transform from ``slit_frame`` to ``msa_frame``. Parameters ---------- @@ -1015,7 +1253,7 @@ def slit_to_msa(open_slits, msafile): slits = [] for quadrant in range(1, 6): slits_in_quadrant = [s for s in open_slits if s.quadrant == quadrant] - msa_quadrant = getattr(msa, 'Q{0}'.format(quadrant)) + msa_quadrant = getattr(msa, f"Q{quadrant}") if any(slits_in_quadrant): msa_data = msa_quadrant.data @@ -1038,28 +1276,27 @@ def slit_to_msa(open_slits, msafile): def gwa_to_ifuslit(slits, input_model, disperser, reference_files, slit_y_range): """ - The transform from ``gwa`` to ``slit_frame``. + Create the transform from ``gwa`` to ``slit_frame``. Parameters ---------- slits : list A list of slit IDs for all IFU slits 0-29. + input_model : JwstDataModel + The input data model. disperser : `~jwst.datamodels.DisperserModel` A disperser model with the GWA correction applied to it. - filter : str - The filter used. - grating : str - The grating used in the observation. - reference_files: dict - Dictionary with reference files returned by CRDS. - slit_y_range : list or tuple of size 2 - The lower and upper bounds of a slit. + reference_files : dict + Mapping between reftype (keys) and reference file name (vals). + Requires the 'ifufore' reference file. + slit_y_range : tuple + The slit Y-range for Nirspec slits, relative to (0, 0) in the center. Returns ------- model : `~stdatamodels.jwst.transforms.Gwa2Slit` model. Transform from ``gwa`` frame to ``slit_frame``. - """ + """ ymin, ymax = slit_y_range agreq = angle_from_disperser(disperser, input_model) @@ -1073,33 +1310,41 @@ def gwa_to_ifuslit(slits, input_model, disperser, reference_files, slit_y_range) if velosys is not None: velocity_corr = velocity_correction(input_model.meta.wcsinfo.velosys) lgreq = lgreq | velocity_corr - log.info("Applied Barycentric velocity correction : {}".format(velocity_corr[1].amplitude.value)) + log.info( + f"Applied Barycentric velocity correction : {velocity_corr[1].amplitude.value}" + ) # The wavelength units up to this point are # meters as required by the pipeline but the desired output wavelength units is microns. # So we are going to Scale the spectral units by 1e6 (meters -> microns) - is_lamp_exposure = input_model.meta.exposure.type in ['NRS_LAMP', 'NRS_AUTOWAVE', 'NRS_AUTOFLAT'] - if input_model.meta.instrument.filter == 'OPAQUE' or is_lamp_exposure: + is_lamp_exposure = input_model.meta.exposure.type in [ + "NRS_LAMP", + "NRS_AUTOWAVE", + "NRS_AUTOFLAT", + ] + if input_model.meta.instrument.filter == "OPAQUE" or is_lamp_exposure: lgreq = lgreq | Scale(1e6) - lam_cen = 0.5 * (input_model.meta.wcsinfo.waverange_end - - input_model.meta.wcsinfo.waverange_start - ) + input_model.meta.wcsinfo.waverange_start + lam_cen = ( + 0.5 * (input_model.meta.wcsinfo.waverange_end - input_model.meta.wcsinfo.waverange_start) + + input_model.meta.wcsinfo.waverange_start + ) collimator2gwa = collimator_to_gwa(reference_files, disperser) mask = mask_slit(ymin, ymax) - ifuslicer = IFUSlicerModel(reference_files['ifuslicer']) - ifupost = IFUPostModel(reference_files['ifupost']) + ifuslicer = IFUSlicerModel(reference_files["ifuslicer"]) + ifupost = IFUPostModel(reference_files["ifupost"]) slit_models = [] ifuslicer_model = ifuslicer.model for slit in slits: slitdata = ifuslicer.data[slit] slitdata_model = get_slit_location_model(slitdata) - ifuslicer_transform = (slitdata_model | ifuslicer_model) - ifupost_sl = getattr(ifupost, "slice_{0}".format(slit)) + ifuslicer_transform = slitdata_model | ifuslicer_model + ifupost_sl = getattr(ifupost, f"slice_{slit}") # construct IFU post transform ifupost_transform = _create_ifupost_transform(ifupost_sl) msa2gwa = ifuslicer_transform & Const1D(lam_cen) | ifupost_transform | collimator2gwa - gwa2slit = gwa_to_ymsa(msa2gwa, lam_cen=lam_cen, slit_y_range=slit_y_range) # TODO: Use model sets here + # TODO: Use model sets here + gwa2slit = gwa_to_ymsa(msa2gwa, lam_cen=lam_cen, slit_y_range=slit_y_range) # The comments below list the input coordinates. bgwa2msa = ( @@ -1110,13 +1355,14 @@ def gwa_to_ifuslit(slits, input_model, disperser, reference_files, slit_y_range) # ( 0, sy, alpha_in, beta_in, gamma_in, alpha_out, beta_out) # (0, sy, alpha_in, beta_in,alpha_out) # (0, sy, lambda_computed) - Mapping((0, 1, 0, 1), n_inputs=3) | - Const1D(0) * Identity(1) & Const1D(-1) * Identity(1) & Identity(2) | \ - Identity(1) & gwa2slit & Identity(2) | \ - Mapping((0, 1, 0, 1, 1, 2, 3)) | \ - Identity(2) & msa2gwa & Identity(2) | \ - Mapping((0, 1, 2, 3, 5), n_inputs=7) | \ - Identity(2) & lgreq | mask + Mapping((0, 1, 0, 1), n_inputs=3) + | Const1D(0) * Identity(1) & Const1D(-1) * Identity(1) & Identity(2) + | Identity(1) & gwa2slit & Identity(2) + | Mapping((0, 1, 0, 1, 1, 2, 3)) + | Identity(2) & msa2gwa & Identity(2) + | Mapping((0, 1, 2, 3, 5), n_inputs=7) + | Identity(2) & lgreq + | mask ) # transform from ``msa_frame`` to ``gwa`` frame (before the GWA going from detector to sky). @@ -1130,23 +1376,21 @@ def gwa_to_ifuslit(slits, input_model, disperser, reference_files, slit_y_range) return Gwa2Slit(slits, slit_models) -def gwa_to_slit(open_slits, input_model, disperser, - reference_files): +def gwa_to_slit(open_slits, input_model, disperser, reference_files): """ - The transform from ``gwa`` to ``slit_frame``. + Create the transform from ``gwa`` to ``slit_frame``. Parameters ---------- open_slits : list A list of slit IDs for all open shutters/slitlets. + input_model : JwstDataModel + The input data model. disperser : dict A corrected disperser ASDF object. - filter : str - The filter used. - grating : str - The grating used in the observation. - reference_files: dict - Dictionary with reference files returned by CRDS. + reference_files : dict + Mapping between reftype (keys) and reference file name (vals). + Requires the 'collimator' and 'msa' reference files. Returns ------- @@ -1165,22 +1409,28 @@ def gwa_to_slit(open_slits, input_model, disperser, if velosys is not None: velocity_corr = velocity_correction(input_model.meta.wcsinfo.velosys) lgreq = lgreq | velocity_corr - log.info("Applied Barycentric velocity correction : {}".format(velocity_corr[1].amplitude.value)) + log.info( + f"Applied Barycentric velocity correction : {velocity_corr[1].amplitude.value}" + ) # The wavelength units up to this point are # meters as required by the pipeline but the desired output wavelength units is microns. # So we are going to Scale the spectral units by 1e6 (meters -> microns) - is_lamp_exposure = input_model.meta.exposure.type in ['NRS_LAMP', 'NRS_AUTOWAVE', 'NRS_AUTOFLAT'] - if input_model.meta.instrument.filter == 'OPAQUE' or is_lamp_exposure: + is_lamp_exposure = input_model.meta.exposure.type in [ + "NRS_LAMP", + "NRS_AUTOWAVE", + "NRS_AUTOFLAT", + ] + if input_model.meta.instrument.filter == "OPAQUE" or is_lamp_exposure: lgreq = lgreq | Scale(1e6) - msa = MSAModel(reference_files['msa']) + msa = MSAModel(reference_files["msa"]) slit_models = [] slits = [] for quadrant in range(1, 6): slits_in_quadrant = [s for s in open_slits if s.quadrant == quadrant] - log.info("There are {0} open slits in quadrant {1}".format(len(slits_in_quadrant), quadrant)) - msa_quadrant = getattr(msa, 'Q{0}'.format(quadrant)) + log.info(f"There are {len(slits_in_quadrant)} open slits in quadrant {quadrant}") + msa_quadrant = getattr(msa, f"Q{quadrant}") if any(slits_in_quadrant): msa_model = msa_quadrant.model @@ -1196,14 +1446,20 @@ def gwa_to_slit(open_slits, input_model, disperser, slit_id -= 1 slitdata = msa_data[slit_id] slitdata_model = get_slit_location_model(slitdata) - msa_transform = (slitdata_model | msa_model) - msa2gwa = (msa_transform | collimator2gwa) - gwa2msa = gwa_to_ymsa(msa2gwa, slit=slit, slit_y_range=(slit.ymin, slit.ymax)) # TODO: Use model sets here - bgwa2msa = Mapping((0, 1, 0, 1), n_inputs=3) | \ - Const1D(0) * Identity(1) & Const1D(-1) * Identity(1) & Identity(2) | \ - Identity(1) & gwa2msa & Identity(2) | \ - Mapping((0, 1, 0, 1, 2, 3)) | Identity(2) & msa2gwa & Identity(2) | \ - Mapping((0, 1, 2, 3, 5), n_inputs=7) | Identity(2) & lgreq | mask + msa_transform = slitdata_model | msa_model + msa2gwa = msa_transform | collimator2gwa + # TODO: Use model sets here + gwa2msa = gwa_to_ymsa(msa2gwa, slit=slit, slit_y_range=(slit.ymin, slit.ymax)) + bgwa2msa = ( + Mapping((0, 1, 0, 1), n_inputs=3) + | Const1D(0) * Identity(1) & Const1D(-1) * Identity(1) & Identity(2) + | Identity(1) & gwa2msa & Identity(2) + | Mapping((0, 1, 0, 1, 2, 3)) + | Identity(2) & msa2gwa & Identity(2) + | Mapping((0, 1, 2, 3, 5), n_inputs=7) + | Identity(2) & lgreq + | mask + ) # Mapping((0, 1, 2, 5), n_inputs=7) | Identity(2) & lgreq | mask # and modify lgreq to accept alpha_in, beta_in, alpha_out # msa to before_gwa @@ -1217,61 +1473,96 @@ def gwa_to_slit(open_slits, input_model, disperser, def angle_from_disperser(disperser, input_model): """ + Figure out the angle from the disperser model. + For gratings this returns a form of the grating equation which computes the angle when lambda is known. - For prism data this returns the Snell model. + + Parameters + ---------- + disperser : dict + A corrected disperser ASDF object. + input_model : JwstDataModel + The input data model. + + Returns + ------- + model : `~astropy.modeling.Model`. + Transform from wavelength to angle. """ sporder = input_model.meta.wcsinfo.spectral_order - if input_model.meta.instrument.grating.lower() != 'prism': - agreq = AngleFromGratingEquation(disperser.groovedensity, - sporder, name='alpha_from_greq') + if input_model.meta.instrument.grating.lower() != "prism": + agreq = AngleFromGratingEquation(disperser.groovedensity, sporder, name="alpha_from_greq") return agreq - else: - system_temperature = input_model.meta.instrument.gwa_tilt - system_pressure = disperser['pref'] - snell = Snell(disperser['angle'], disperser['kcoef'], disperser['lcoef'], - disperser['tcoef'], disperser['tref'], disperser['pref'], - system_temperature, system_pressure, name="snell_law") - return snell + system_temperature = input_model.meta.instrument.gwa_tilt + system_pressure = disperser["pref"] + + snell = Snell( + disperser["angle"], + disperser["kcoef"], + disperser["lcoef"], + disperser["tcoef"], + disperser["tref"], + disperser["pref"], + system_temperature, + system_pressure, + name="snell_law", + ) + return snell def wavelength_from_disperser(disperser, input_model): """ + Figure out the wavelength from the disperser model. + For gratings this returns a form of the grating equation which computes lambda when all angles are known. For prism data this returns a lookup table model computing lambda from a known refraction index. + + Parameters + ---------- + disperser : dict + A corrected disperser ASDF object. + input_model : JwstDataModel + The input data model. + + Returns + ------- + model : `~astropy.modeling.Model`. + Transform from angle to wavelength """ sporder = input_model.meta.wcsinfo.spectral_order - if input_model.meta.instrument.grating.lower() != 'prism': - lgreq = WavelengthFromGratingEquation(disperser.groovedensity, - sporder, name='lambda_from_gratingeq') + if input_model.meta.instrument.grating.lower() != "prism": + lgreq = WavelengthFromGratingEquation( + disperser.groovedensity, sporder, name="lambda_from_gratingeq" + ) return lgreq - else: - lam = np.arange(0.5, 6.005, 0.005) * 1e-6 - system_temperature = input_model.meta.instrument.gwa_tilt - if system_temperature is None: - message = "Missing reference temperature (keyword GWA_TILT)." - log.critical(message) - raise KeyError(message) - system_pressure = disperser['pref'] - tref = disperser['tref'] - pref = disperser['pref'] - kcoef = disperser['kcoef'][:] - lcoef = disperser['lcoef'][:] - tcoef = disperser['tcoef'][:] - n = Snell.compute_refraction_index(lam, system_temperature, tref, pref, - system_pressure, kcoef, lcoef, tcoef - ) - n = np.flipud(n) - lam = np.flipud(lam) - n_from_prism = RefractionIndexFromPrism(disperser['angle'], name='n_prism') - - tab = Tabular1D(points=(n,), lookup_table=lam, bounds_error=False) - return n_from_prism | tab + + lam = np.arange(0.5, 6.005, 0.005) * 1e-6 + system_temperature = input_model.meta.instrument.gwa_tilt + if system_temperature is None: + message = "Missing reference temperature (keyword GWA_TILT)." + log.critical(message) + raise KeyError(message) + system_pressure = disperser["pref"] + tref = disperser["tref"] + pref = disperser["pref"] + kcoef = disperser["kcoef"][:] + lcoef = disperser["lcoef"][:] + tcoef = disperser["tcoef"][:] + n = Snell.compute_refraction_index( + lam, system_temperature, tref, pref, system_pressure, kcoef, lcoef, tcoef + ) + n = np.flipud(n) + lam = np.flipud(lam) + n_from_prism = RefractionIndexFromPrism(disperser["angle"], name="n_prism") + + tab = Tabular1D(points=(n,), lookup_table=lam, bounds_error=False) + return n_from_prism | tab def detector_to_gwa(reference_files, detector, disperser): @@ -1280,8 +1571,9 @@ def detector_to_gwa(reference_files, detector, disperser): Parameters ---------- - reference_files: dict - Dictionary with reference files returned by CRDS. + reference_files : dict + Mapping between reftype (keys) and reference file name (vals). + Requires the 'fpa' and 'camera' reference files. detector : str The detector keyword. disperser : dict @@ -1291,19 +1583,17 @@ def detector_to_gwa(reference_files, detector, disperser): ------- model : `~astropy.modeling.core.Model` model. Transform from DETECTOR frame to GWA frame. - """ - with FPAModel(reference_files['fpa']) as f: - fpa = getattr(f, detector.lower() + '_model') - with CameraModel(reference_files['camera']) as f: + with FPAModel(reference_files["fpa"]) as f: + fpa = getattr(f, detector.lower() + "_model") + with CameraModel(reference_files["camera"]) as f: camera = f.model - angles = [disperser['theta_x'], disperser['theta_y'], - disperser['theta_z'], disperser['tilt_y']] - rotation = Rotation3DToGWA(angles, axes_order="xyzy", name='rotation') - u2dircos = Unitless2DirCos(name='unitless2directional_cosines') + angles = [disperser["theta_x"], disperser["theta_y"], disperser["theta_z"], disperser["tilt_y"]] + rotation = Rotation3DToGWA(angles, axes_order="xyzy", name="rotation") + u2dircos = Unitless2DirCos(name="unitless2directional_cosines") # NIRSPEC 1- vs 0- based pixel coordinates issue #1781 - ''' + """ The pipeline works with 0-based pixel coordinates. The Nirspec model, stored in reference files, is also 0-based. However, the algorithm specified by the IDT team specifies that pixel coordinates are 1-based. This is @@ -1322,14 +1612,24 @@ def detector_to_gwa(reference_files, detector, disperser): model = models.Shift(1) & models.Shift(1) | \ models.Shift(-1) & models.Shift(-1) | fpa | camera | u2dircos | rotation - ''' + """ model = fpa | camera | u2dircos | rotation return model def dms_to_sca(input_model): """ - Transforms from ``detector`` to ``sca`` coordinates. + Transform from ``detector`` to ``sca`` coordinates. + + Parameters + ---------- + input_model : JwstDataModel + The input data model. + + Returns + ------- + model : `~astropy.modeling.core.Model` model. + Transform from DMS frame to SCA frame. """ detector = input_model.meta.instrument.detector xstart = input_model.meta.subarray.xstart @@ -1344,42 +1644,46 @@ def dms_to_sca(input_model): # If xstart was 0-based and the inputs were 0-based -> # Shift(+1) subarray2full = models.Shift(xstart - 1) & models.Shift(ystart - 1) - if detector == 'NRS2': + if detector == "NRS2": model = models.Shift(-2047) & models.Shift(-2047) | models.Scale(-1) & models.Scale(-1) - elif detector == 'NRS1': + elif detector == "NRS1": model = models.Identity(2) return subarray2full | model -def mask_slit(ymin=-.55, ymax=.55): +def mask_slit(ymin=-0.55, ymax=0.55): """ - Returns a model which masks out pixels in a NIRSpec cutout outside the slit. + Return a model which masks out pixels in a NIRSpec cutout outside the slit. Uses ymin, ymax for the slit and the wavelength range to define the location of the slit. Parameters ---------- ymin, ymax : float - ymin and ymax relative boundary of a slit. + The relative min, max boundary of a slit. Returns ------- model : `~astropy.modeling.core.Model` A model which takes x_slit, y_slit, lam inputs and substitutes the values outside the slit with NaN. - """ - greater_than_ymax = Logical(condition='GT', compareto=ymax, value=np.nan) - less_than_ymin = Logical(condition='LT', compareto=ymin, value=np.nan) - - model = Mapping((0, 1, 2, 1)) | Identity(3) & (greater_than_ymax | less_than_ymin | models.Scale(0)) | \ - Mapping((0, 1, 3, 2, 3)) | Identity(1) & Mapping((0,), n_inputs=2) + Mapping((1,)) & \ - Mapping((0,), n_inputs=2) + Mapping((1,)) + greater_than_ymax = Logical(condition="GT", compareto=ymax, value=np.nan) + less_than_ymin = Logical(condition="LT", compareto=ymin, value=np.nan) + + model = ( + Mapping((0, 1, 2, 1)) + | Identity(3) & (greater_than_ymax | less_than_ymin | models.Scale(0)) + | Mapping((0, 1, 3, 2, 3)) + | Identity(1) + & Mapping((0,), n_inputs=2) + Mapping((1,)) + & Mapping((0,), n_inputs=2) + Mapping((1,)) + ) model.inverse = Identity(3) return model -def compute_bounding_box(transform, wavelength_range, slit_ymin=-.55, slit_ymax=.55): +def compute_bounding_box(transform, wavelength_range, slit_ymin=-0.55, slit_ymax=0.55): """ Compute the bounding box of the projection of a slit/slice on the detector. @@ -1396,9 +1700,14 @@ def compute_bounding_box(transform, wavelength_range, slit_ymin=-.55, slit_ymax= `nrs_wcs_set_input` uses "detector to slit", validate_open_slits uses "slit to detector". wavelength_range : tuple The wavelength range for the combination of grating and filter. + slit_ymin, slit_ymax : float + The lower and upper bounds of the slit. + Returns + ------- + bbox : tuple + The bounding box of the projection of the slit on the detector. """ - # If transform has inverse then it must be slit to detector if transform.has_inverse(): slit2detector = transform.inverse @@ -1419,11 +1728,9 @@ def bbox_from_range(x_range, y_range): # The -1 on both is technically because the output of slit2detector is 1-based coordinates. # add 10 px margin - pad_x = (max(0, x_range.min() - 1 - 10) - 0.5, - min(2047, x_range.max() - 1 + 10) + 0.5) + pad_x = (max(0, x_range.min() - 1 - 10) - 0.5, min(2047, x_range.max() - 1 + 10) + 0.5) # add 2 px margin - pad_y = (max(0, y_range.min() - 1 - 2) - 0.5, - min(2047, y_range.max() - 1 + 2) + 0.5) + pad_y = (max(0, y_range.min() - 1 - 2) - 0.5, min(2047, y_range.max() - 1 + 2) + 0.5) return pad_x, pad_y @@ -1457,8 +1764,9 @@ def collimator_to_gwa(reference_files, disperser): Parameters ---------- - reference_files: dict - Dictionary with reference files returned by CRDS. + reference_files : dict + Mapping between reftype (keys) and reference file name (vals). + Requires the 'collimator' reference file. disperser : dict A corrected disperser ASDF object. @@ -1466,27 +1774,24 @@ def collimator_to_gwa(reference_files, disperser): ------- model : `~astropy.modeling.core.Model` model. Transform from collimator to ``gwa`` frame. - """ - with CollimatorModel(reference_files['collimator']) as f: + with CollimatorModel(reference_files["collimator"]) as f: collimator = f.model - angles = [disperser['theta_x'], disperser['theta_y'], - disperser['theta_z'], disperser['tilt_y']] - rotation = Rotation3DToGWA(angles, axes_order="xyzy", name='rotation') - u2dircos = Unitless2DirCos(name='unitless2directional_cosines') + angles = [disperser["theta_x"], disperser["theta_y"], disperser["theta_z"], disperser["tilt_y"]] + rotation = Rotation3DToGWA(angles, axes_order="xyzy", name="rotation") + u2dircos = Unitless2DirCos(name="unitless2directional_cosines") return collimator.inverse | u2dircos | rotation def get_disperser(input_model, disperserfile): """ - Return the disperser data model with the GWA - correction applied. + Return the disperser data model with the GWA correction applied. Parameters ---------- - input_model : `jwst.datamodels.JwstDataModel` - The input data model - either an ImageModel or a CubeModel. + input_model : JwstDataModel + The input data model. disperserfile : str The name of the disperser reference file. @@ -1508,47 +1813,47 @@ def correct_tilt(disperser, xtilt, ytilt): Parameters ---------- + disperser : `~jwst.datamodels.DisperserModel` + Disperser information. xtilt : float Value of GWAXTILT keyword - angle in arcsec ytilt : float Value of GWAYTILT keyword - angle in arcsec - disperser : `~jwst.datamodels.DisperserModel` - Disperser information. - - Notes - ----- - The GWA_XTILT keyword is used to correct the THETA_Y angle. - The GWA_YTILT keyword is used to correct the THETA_X angle. Returns ------- disp : `~jwst.datamodels.DisperserModel` Corrected DisperserModel. + Notes + ----- + The GWA_XTILT keyword is used to correct the THETA_Y angle. + The GWA_YTILT keyword is used to correct the THETA_X angle. """ + def _get_correction(gwa_tilt, tilt_angle): phi_exposure = gwa_tilt.tilt_model(tilt_angle) phi_calibrator = gwa_tilt.tilt_model(gwa_tilt.zeroreadings[0]) - del_theta = 0.5 * (phi_exposure - phi_calibrator) / 3600. # in deg + del_theta = 0.5 * (phi_exposure - phi_calibrator) / 3600.0 # in deg return del_theta disp = disperser.copy() disperser.close() - log.info("gwa_ytilt is {0} deg".format(ytilt)) - log.info("gwa_xtilt is {0} deg".format(xtilt)) + log.info(f"gwa_ytilt is {ytilt} deg") + log.info("gwa_xtilt is {xtilt} deg") if xtilt is not None: theta_y_correction = _get_correction(disp.gwa_tiltx, xtilt) - log.info('theta_y correction: {0} deg'.format(theta_y_correction)) - disp['theta_y'] = disp.theta_y + theta_y_correction + log.info(f"theta_y correction: {theta_y_correction} deg") + disp["theta_y"] = disp.theta_y + theta_y_correction else: - log.info('gwa_xtilt not applied') + log.info("gwa_xtilt not applied") if ytilt is not None: theta_x_correction = _get_correction(disp.gwa_tilty, ytilt) - log.info('theta_x correction: {0} deg'.format(theta_x_correction)) + log.info("theta_x correction: {theta_x_correction} deg") disp.theta_x = disp.theta_x + theta_x_correction else: - log.info('gwa_ytilt not applied') + log.info("gwa_ytilt not applied") return disp @@ -1558,19 +1863,20 @@ def ifu_msa_to_oteip(reference_files): Parameters ---------- - reference_files: dict - Dictionary with reference files returned by CRDS. + reference_files : dict + Mapping between reftype (keys) and reference file name (vals). + Requires the 'fore' reference file. Returns ------- model : `~astropy.modeling.core.Model` model. Transform from MSA to OTEIP. """ - with FOREModel(reference_files['fore']) as f: + with FOREModel(reference_files["fore"]) as f: fore = f.model - msa2fore_mapping = Mapping((0, 1, 2, 2), name='msa2fore_mapping') - msa2fore_mapping.inverse = Mapping((0, 1, 2, 2), name='fore2msa') + msa2fore_mapping = Mapping((0, 1, 2, 2), name="msa2fore_mapping") + msa2fore_mapping.inverse = Mapping((0, 1, 2, 2), name="fore2msa") fore_transform = msa2fore_mapping | fore & Identity(1) return fore_transform @@ -1581,41 +1887,41 @@ def msa_to_oteip(reference_files): Parameters ---------- - reference_files: dict - Dictionary with reference files returned by CRDS. + reference_files : dict + Mapping between reftype (keys) and reference file name (vals). + Requires the 'fore' reference file. Returns ------- model : `~astropy.modeling.core.Model` model. Transform from MSA to OTEIP. - """ - with FOREModel(reference_files['fore']) as f: + with FOREModel(reference_files["fore"]) as f: fore = f.model - msa2fore_mapping = Mapping((0, 1, 2, 2), name='msa2fore_mapping') + msa2fore_mapping = Mapping((0, 1, 2, 2), name="msa2fore_mapping") msa2fore_mapping.inverse = Identity(3) return msa2fore_mapping | (fore & Identity(1)) -def oteip_to_v23(reference_files, input_model): +def oteip_to_v23(reference_files): """ Transform from ``oteip`` frame to ``v2v3`` frame. Parameters ---------- - reference_files: dict - Dictionary with reference files returned by CRDS. + reference_files : dict + Mapping between reftype (keys) and reference file name (vals). + Requires the 'ote' reference file. Returns ------- model : `~astropy.modeling.core.Model` model. Transform from ``oteip`` to ``v2v3`` frame. - """ - with OTEModel(reference_files['ote']) as f: + with OTEModel(reference_files["ote"]) as f: ote = f.model - fore2ote_mapping = Identity(3, name='fore2ote_mapping') + fore2ote_mapping = Identity(3, name="fore2ote_mapping") fore2ote_mapping.inverse = Mapping((0, 1, 2, 2)) # Create the transform to v2/v3/lambda. The wavelength units up to this point are # meters as required by the pipeline but the desired output wavelength units is microns. @@ -1630,66 +1936,85 @@ def create_frames(): """ Create the coordinate frames in the NIRSPEC WCS pipeline. - These are - "detector", "gwa", "slit_frame", "msa_frame", "oteip", "v2v3", "world". - """ - det = cf.Frame2D(name='detector', axes_order=(0, 1)) - sca = cf.Frame2D(name='sca', axes_order=(0, 1)) - gwa = cf.Frame2D(name="gwa", axes_order=(0, 1), unit=(u.rad, u.rad), - axes_names=('alpha_in', 'beta_in')) - msa_spatial = cf.Frame2D(name='msa_spatial', axes_order=(0, 1), unit=(u.m, u.m), - axes_names=('x_msa', 'y_msa')) - slit_spatial = cf.Frame2D(name='slit_spatial', axes_order=(0, 1), unit=("", ""), - axes_names=('x_slit', 'y_slit')) - slicer_spatial = cf.Frame2D(name='slicer_spatial', axes_order=(0, 1), unit=("", ""), - axes_names=('x_slicer', 'y_slicer')) - sky = cf.CelestialFrame(name='sky', axes_order=(0, 1), reference_frame=coord.ICRS()) - v2v3_spatial = cf.Frame2D(name='v2v3_spatial', axes_order=(0, 1), - unit=(u.arcsec, u.arcsec), axes_names=('v2', 'v3')) - v2v3vacorr_spatial = cf.Frame2D(name='v2v3vacorr_spatial', axes_order=(0, 1), - unit=(u.arcsec, u.arcsec), axes_names=('v2', 'v3')) + Returns + ------- + det, sca, gwa, slit_frame, slicer_frame, msa_frame, oteip, v2v3, v2v3vacorr, world : tuple + The coordinate frames. Each is a `~gwcs.coordinate_frames.CoordinateFrame` object. + """ + det = cf.Frame2D(name="detector", axes_order=(0, 1)) + sca = cf.Frame2D(name="sca", axes_order=(0, 1)) + gwa = cf.Frame2D( + name="gwa", axes_order=(0, 1), unit=(u.rad, u.rad), axes_names=("alpha_in", "beta_in") + ) + msa_spatial = cf.Frame2D( + name="msa_spatial", axes_order=(0, 1), unit=(u.m, u.m), axes_names=("x_msa", "y_msa") + ) + slit_spatial = cf.Frame2D( + name="slit_spatial", axes_order=(0, 1), unit=("", ""), axes_names=("x_slit", "y_slit") + ) + slicer_spatial = cf.Frame2D( + name="slicer_spatial", axes_order=(0, 1), unit=("", ""), axes_names=("x_slicer", "y_slicer") + ) + sky = cf.CelestialFrame(name="sky", axes_order=(0, 1), reference_frame=coord.ICRS()) + v2v3_spatial = cf.Frame2D( + name="v2v3_spatial", axes_order=(0, 1), unit=(u.arcsec, u.arcsec), axes_names=("v2", "v3") + ) + v2v3vacorr_spatial = cf.Frame2D( + name="v2v3vacorr_spatial", + axes_order=(0, 1), + unit=(u.arcsec, u.arcsec), + axes_names=("v2", "v3"), + ) # The oteip_to_v23 incorporates a scale to convert the spectral units from # meters to microns. So the v2v3 output frame will be in u.deg, u.deg, u.micron - spec = cf.SpectralFrame(name='spectral', axes_order=(2,), unit=(u.micron,), - axes_names=('wavelength',)) - v2v3 = cf.CompositeFrame([v2v3_spatial, spec], name='v2v3') - v2v3vacorr = cf.CompositeFrame([v2v3vacorr_spatial, spec], name='v2v3vacorr') - slit_frame = cf.CompositeFrame([slit_spatial, spec], name='slit_frame') - slicer_frame = cf.CompositeFrame([slicer_spatial, spec], name='slicer') - msa_frame = cf.CompositeFrame([msa_spatial, spec], name='msa_frame') - oteip_spatial = cf.Frame2D(name='oteip', axes_order=(0, 1), unit=(u.deg, u.deg), - axes_names=('X_OTEIP', 'Y_OTEIP')) - oteip = cf.CompositeFrame([oteip_spatial, spec], name='oteip') - world = cf.CompositeFrame([sky, spec], name='world') + spec = cf.SpectralFrame( + name="spectral", axes_order=(2,), unit=(u.micron,), axes_names=("wavelength",) + ) + v2v3 = cf.CompositeFrame([v2v3_spatial, spec], name="v2v3") + v2v3vacorr = cf.CompositeFrame([v2v3vacorr_spatial, spec], name="v2v3vacorr") + slit_frame = cf.CompositeFrame([slit_spatial, spec], name="slit_frame") + slicer_frame = cf.CompositeFrame([slicer_spatial, spec], name="slicer") + msa_frame = cf.CompositeFrame([msa_spatial, spec], name="msa_frame") + oteip_spatial = cf.Frame2D( + name="oteip", axes_order=(0, 1), unit=(u.deg, u.deg), axes_names=("X_OTEIP", "Y_OTEIP") + ) + oteip = cf.CompositeFrame([oteip_spatial, spec], name="oteip") + world = cf.CompositeFrame([sky, spec], name="world") return det, sca, gwa, slit_frame, slicer_frame, msa_frame, oteip, v2v3, v2v3vacorr, world def create_imaging_frames(): """ Create the coordinate frames in the NIRSPEC WCS pipeline. - These are: "detector", "gwa", "msa_frame", "oteip", "v2v3", "v2v3vacorr", - and "world". - """ - det = cf.Frame2D(name='detector', axes_order=(0, 1)) - sca = cf.Frame2D(name='sca', axes_order=(0, 1)) - gwa = cf.Frame2D(name="gwa", axes_order=(0, 1), unit=(u.rad, u.rad), - axes_names=('alpha_in', 'beta_in')) - msa = cf.Frame2D(name='msa', axes_order=(0, 1), unit=(u.m, u.m), - axes_names=('x_msa', 'y_msa')) - v2v3 = cf.Frame2D(name='v2v3', axes_order=(0, 1), unit=(u.arcsec, u.arcsec), - axes_names=('v2', 'v3')) - v2v3vacorr = cf.Frame2D(name='v2v3vacorr', axes_order=(0, 1), - unit=(u.arcsec, u.arcsec), axes_names=('v2', 'v3')) - oteip = cf.Frame2D(name='oteip', axes_order=(0, 1), unit=(u.deg, u.deg), - axes_names=('x_oteip', 'y_oteip')) - world = cf.CelestialFrame(name='world', axes_order=(0, 1), reference_frame=coord.ICRS()) + + Returns + ------- + det, sca, gwa, msa, oteip, v2v3, v2v3vacorr, world : tuple + The coordinate frames. Each is a `~gwcs.coordinate_frames.CoordinateFrame` object. + """ + det = cf.Frame2D(name="detector", axes_order=(0, 1)) + sca = cf.Frame2D(name="sca", axes_order=(0, 1)) + gwa = cf.Frame2D( + name="gwa", axes_order=(0, 1), unit=(u.rad, u.rad), axes_names=("alpha_in", "beta_in") + ) + msa = cf.Frame2D(name="msa", axes_order=(0, 1), unit=(u.m, u.m), axes_names=("x_msa", "y_msa")) + v2v3 = cf.Frame2D( + name="v2v3", axes_order=(0, 1), unit=(u.arcsec, u.arcsec), axes_names=("v2", "v3") + ) + v2v3vacorr = cf.Frame2D( + name="v2v3vacorr", axes_order=(0, 1), unit=(u.arcsec, u.arcsec), axes_names=("v2", "v3") + ) + oteip = cf.Frame2D( + name="oteip", axes_order=(0, 1), unit=(u.deg, u.deg), axes_names=("x_oteip", "y_oteip") + ) + world = cf.CelestialFrame(name="world", axes_order=(0, 1), reference_frame=coord.ICRS()) return det, sca, gwa, msa, oteip, v2v3, v2v3vacorr, world def get_slit_location_model(slitdata): """ - The transform for the absolute position of a slit on the MSA. + Create the transform for the absolute position of a slit on the MSA. Parameters ---------- @@ -1702,13 +2027,14 @@ def get_slit_location_model(slitdata): ------- model : `~astropy.modeling.core.Model` model. A model which transforms relative position on the slit to - absolute positions in the quadrant.. + absolute positions in the quadrant. This is later combined with the quadrant model to return absolute positions in the MSA. """ num, xcenter, ycenter, xsize, ysize = slitdata - model = models.Scale(xsize) & models.Scale(ysize) | \ - models.Shift(xcenter) & models.Shift(ycenter) + model = models.Scale(xsize) & models.Scale(ysize) | models.Shift(xcenter) & models.Shift( + ycenter + ) return model @@ -1724,9 +2050,14 @@ def gwa_to_ymsa(msa2gwa_model, lam_cen=None, slit=None, slit_y_range=None): Central wavelength in meters. slit : `~stdatamodels.jwst.transforms.models.Slit` A Fixed slit or MOS slitlet. - slit_y_range: list or tuple of size 2 - The lower and upper limit of the slit. + slit_y_range : tuple or None + The slit Y-range for Nirspec slits, relative to (0, 0) in the center. Used for IFU mode only. + + Returns + ------- + tab : `~astropy.modeling.Tabular1D` + A 1D lookup table model. """ nstep = 1000 if slit is not None: @@ -1743,13 +2074,11 @@ def gwa_to_ymsa(msa2gwa_model, lam_cen=None, slit=None, slit_y_range=None): cosin_grating_k = msa2gwa_model(dx, dy) beta_in = cosin_grating_k[1] - tab = Tabular1D(points=(beta_in,), - lookup_table=dy, bounds_error=False, name='tabular') + tab = Tabular1D(points=(beta_in,), lookup_table=dy, bounds_error=False, name="tabular") return tab def _get_transforms(input_model, slitnames, return_slits=False): - """ Return a WCS object with necessary transforms for all slits. @@ -1760,9 +2089,8 @@ def _get_transforms(input_model, slitnames, return_slits=False): Parameters ---------- - input_model : `~jwst.datamodels.JwstDataModel` - A data model with a WCS object for the all open slitlets in - an observation. + input_model : JwstDataModel + A data model with a WCS object for the all open slitlets in an observation. slitnames : list of int or str Slit.name of all open slits. return_slits : bool, optional @@ -1779,42 +2107,49 @@ def _get_transforms(input_model, slitnames, return_slits=False): slit2slicer : list of `~astropy.modeling.core.Model` Transform from ``slit_frame`` to ``slicer`` for each input slit open_slits : list of `~stdatamodels.jwst.transforms.models.Slit` - open slits from wcs.get_transform('gwa', 'slit_frame').slits + Open slits from wcs.get_transform('gwa', 'slit_frame').slits Only returned if return_slits is True """ - wcs = copy.deepcopy(input_model.meta.wcs) sca2gwa = copy.deepcopy(wcs.pipeline[1].transform[1:]) - wcs.set_transform('sca', 'gwa', sca2gwa) + wcs.set_transform("sca", "gwa", sca2gwa) - gwa2slit = [copy.deepcopy(wcs.pipeline[2].transform.get_model(slit_name)) - for slit_name in slitnames] + gwa2slit = [ + copy.deepcopy(wcs.pipeline[2].transform.get_model(slit_name)) for slit_name in slitnames + ] - slit2slicer = [copy.deepcopy(wcs.pipeline[3].transform.get_model(slit_name)) - for slit_name in slitnames] + slit2slicer = [ + copy.deepcopy(wcs.pipeline[3].transform.get_model(slit_name)) for slit_name in slitnames + ] if return_slits: - g2s = wcs.get_transform('gwa', 'slit_frame') + g2s = wcs.get_transform("gwa", "slit_frame") open_slits = g2s.slits return wcs, sca2gwa, gwa2slit, slit2slicer, copy.deepcopy(open_slits) else: return wcs, sca2gwa, gwa2slit, slit2slicer -def _nrs_wcs_set_input_lite(input_model, input_wcs, slit_name, transforms, - wavelength_range=None, open_slits=None, - slit_y_low=None, slit_y_high=None): - +def _nrs_wcs_set_input_lite( + input_model, + input_wcs, + slit_name, + transforms, + wavelength_range=None, + open_slits=None, + slit_y_low=None, + slit_y_high=None, +): """ - Return a WCS object for a specific slit, slice or shutter + Return a WCS object for a specific slit, slice or shutter. The lite version of the routine is distinguished from the legacy routine because it does not make a deep copy of the input WCS object. Parameters ---------- - input_model : `~jwst.datamodels.JwstDataModel` + input_model : JwstDataModel A WCS object for the all open slitlets in an observation. input_wcs : `~gwcs.wcs.WCS` A WCS object for the all open slitlets in an observation. This @@ -1823,10 +2158,12 @@ def _nrs_wcs_set_input_lite(input_model, input_wcs, slit_name, transforms, Slit.name of an open slit. transforms : list of `~astropy.modeling.core.Model` Model transforms output from ``_get_transforms`` - wavelength_range: list + wavelength_range : list Wavelength range for the combination of filter and grating. Optional. open_slits : list of slits List of open slits. Optional. + slit_y_low, slit_y_high : float + The lower and upper bounds of the slit. Optional. Returns ------- @@ -1834,10 +2171,9 @@ def _nrs_wcs_set_input_lite(input_model, input_wcs, slit_name, transforms, WCS object for this slit. """ - - def _get_y_range(input_model, open_slits): + def _get_y_range(open_slits): if open_slits is None: - log_message = 'nrs_wcs_set_input_lite must be called with open_slits if not in ifu mode' + log_message = "nrs_wcs_set_input_lite must be called with open_slits if not in ifu mode" log.critical(log_message) raise RuntimeError(log_message) # Need the open slits to get the slit ymin,ymax @@ -1849,25 +2185,28 @@ def _get_y_range(input_model, open_slits): slit_wcs = copy.copy(input_wcs) - slit_wcs.set_transform('sca', 'gwa', transforms[0]) - slit_wcs.set_transform('gwa', 'slit_frame', transforms[1]) + slit_wcs.set_transform("sca", "gwa", transforms[0]) + slit_wcs.set_transform("gwa", "slit_frame", transforms[1]) - is_nirspec_ifu = is_nrs_ifu_lamp(input_model) or input_model.meta.exposure.type.lower() == 'nrs_ifu' + is_nirspec_ifu = ( + is_nrs_ifu_lamp(input_model) or input_model.meta.exposure.type.lower() == "nrs_ifu" + ) if is_nirspec_ifu: - slit_wcs.set_transform('slit_frame', 'slicer', transforms[2] & Identity(1)) + slit_wcs.set_transform("slit_frame", "slicer", transforms[2] & Identity(1)) else: - slit_wcs.set_transform('slit_frame', 'msa_frame', transforms[2] & Identity(1)) + slit_wcs.set_transform("slit_frame", "msa_frame", transforms[2] & Identity(1)) - transform = slit_wcs.get_transform('detector', 'slit_frame') + transform = slit_wcs.get_transform("detector", "slit_frame") if is_nirspec_ifu: bb = compute_bounding_box(transform, wavelength_range) else: if slit_y_low is None or slit_y_high is None: - slit_y_low, slit_y_high = _get_y_range(input_model, open_slits) - bb = compute_bounding_box(transform, wavelength_range, - slit_ymin=slit_y_low, slit_ymax=slit_y_high) + slit_y_low, slit_y_high = _get_y_range(open_slits) + bb = compute_bounding_box( + transform, wavelength_range, slit_ymin=slit_y_low, slit_ymax=slit_y_high + ) slit_wcs.bounding_box = bb return slit_wcs @@ -1875,12 +2214,13 @@ def _get_y_range(input_model, open_slits): def _nrs_wcs_set_input(input_model, slit_name): """ - Returns a WCS object for a specific slit, slice or shutter. + Return a WCS object for a specific slit, slice or shutter. + Does not compute the bounding box. Parameters ---------- - input_model : `~jwst.datamodels.JwstDataModel` + input_model : JwstDataModel A WCS object for the all open slitlets in an observation. slit_name : int or str Slit.name of an open slit. @@ -1890,48 +2230,55 @@ def _nrs_wcs_set_input(input_model, slit_name): wcsobj : `~gwcs.wcs.WCS` WCS object for this slit. """ - wcsobj = input_model.meta.wcs slit_wcs = copy.deepcopy(wcsobj) - slit_wcs.set_transform('sca', 'gwa', wcsobj.pipeline[1].transform[1:]) + slit_wcs.set_transform("sca", "gwa", wcsobj.pipeline[1].transform[1:]) g2s = slit_wcs.pipeline[2].transform - slit_wcs.set_transform('gwa', 'slit_frame', g2s.get_model(slit_name)) + slit_wcs.set_transform("gwa", "slit_frame", g2s.get_model(slit_name)) exp_type = input_model.meta.exposure.type - is_nirspec_ifu = is_nrs_ifu_lamp(input_model) or (exp_type.lower() == 'nrs_ifu') + is_nirspec_ifu = is_nrs_ifu_lamp(input_model) or (exp_type.lower() == "nrs_ifu") if is_nirspec_ifu: - slit_wcs.set_transform('slit_frame', 'slicer', - wcsobj.pipeline[3].transform.get_model(slit_name) & Identity(1)) + slit_wcs.set_transform( + "slit_frame", "slicer", wcsobj.pipeline[3].transform.get_model(slit_name) & Identity(1) + ) else: - slit_wcs.set_transform('slit_frame', 'msa_frame', - wcsobj.pipeline[3].transform.get_model(slit_name) & Identity(1)) + slit_wcs.set_transform( + "slit_frame", + "msa_frame", + wcsobj.pipeline[3].transform.get_model(slit_name) & Identity(1), + ) return slit_wcs -def nrs_wcs_set_input(input_model, slit_name, wavelength_range=None, - slit_y_low=None, slit_y_high=None): +def nrs_wcs_set_input( + input_model, slit_name, wavelength_range=None, slit_y_low=None, slit_y_high=None +): """ - Returns a WCS object for a specific slit, slice or shutter. + Return a WCS object for a specific slit, slice or shutter. Parameters ---------- - input_model : `~jwst.datamodels.JwstDataModel` + input_model : JwstDataModel A WCS object for the all open slitlets in an observation. slit_name : int or str Slit.name of an open slit. - wavelength_range: list + wavelength_range : list Wavelength range for the combination of filter and grating. + slit_y_low, slit_y_high : float + The lower and upper bounds of the slit. Optional. Returns ------- wcsobj : `~gwcs.wcs.WCS` WCS object for this slit. """ + def _get_y_range(input_model): # get the open slits from the model # Need them to get the slit ymin,ymax - g2s = input_model.meta.wcs.get_transform('gwa', 'slit_frame') + g2s = input_model.meta.wcs.get_transform("gwa", "slit_frame") open_slits = g2s.slits slit = [s for s in open_slits if s.name == slit_name][0] return slit.ymin, slit.ymax @@ -1940,15 +2287,18 @@ def _get_y_range(input_model): _, wavelength_range = spectral_order_wrange_from_model(input_model) slit_wcs = _nrs_wcs_set_input(input_model, slit_name) - transform = slit_wcs.get_transform('detector', 'slit_frame') - is_nirspec_ifu = is_nrs_ifu_lamp(input_model) or input_model.meta.exposure.type.lower() == 'nrs_ifu' + transform = slit_wcs.get_transform("detector", "slit_frame") + is_nirspec_ifu = ( + is_nrs_ifu_lamp(input_model) or input_model.meta.exposure.type.lower() == "nrs_ifu" + ) if is_nirspec_ifu: bb = compute_bounding_box(transform, wavelength_range) else: if slit_y_low is None or slit_y_high is None: slit_y_low, slit_y_high = _get_y_range(input_model) - bb = compute_bounding_box(transform, wavelength_range, - slit_ymin=slit_y_low, slit_ymax=slit_y_high) + bb = compute_bounding_box( + transform, wavelength_range, slit_ymin=slit_y_low, slit_ymax=slit_y_high + ) slit_wcs.bounding_box = bb return slit_wcs @@ -1957,56 +2307,64 @@ def _get_y_range(input_model): def validate_open_slits(input_model, open_slits, reference_files): """ Remove slits which do not project on the detector from the list of open slits. + For each slit computes the transform from the slit to the detector and determines the bounding box. Parameters ---------- - input_model : jwst.datamodels.JwstDataModel - Input data model + input_model : JwstDataModel + The input data model. + open_slits : list + List of open slits. + reference_files : dict + Mapping between reftype (keys) and reference file name (vals). + Requires the 'disperser', 'wavelengthrange', 'msa', 'collimator', + 'fpa', and 'camera' reference files Returns ------- - slit2det : dict - A dictionary with the slit to detector transform for each slit, - {slit_id: astropy.modeling.Model} + open_slits : list + List of open slits that project onto the detector. """ def _is_valid_slit(domain): xlow, xhigh = domain[0] ylow, yhigh = domain[1] - if (xlow >= 2048 or ylow >= 2048 or - xhigh <= 0 or yhigh <= 0 or - xhigh - xlow < 2 or yhigh - ylow < 1): + if ( + xlow >= 2048 + or ylow >= 2048 + or xhigh <= 0 + or yhigh <= 0 + or xhigh - xlow < 2 + or yhigh - ylow < 1 + ): return False else: return True det2dms = dms_to_sca(input_model).inverse # read models from reference file - disperser = DisperserModel(reference_files['disperser']) - disperser = correct_tilt(disperser, input_model.meta.instrument.gwa_xtilt, - input_model.meta.instrument.gwa_ytilt) + disperser = DisperserModel(reference_files["disperser"]) + disperser = correct_tilt( + disperser, input_model.meta.instrument.gwa_xtilt, input_model.meta.instrument.gwa_ytilt + ) - order, wrange = get_spectral_order_wrange(input_model, - reference_files['wavelengthrange']) + order, wrange = get_spectral_order_wrange(input_model, reference_files["wavelengthrange"]) input_model.meta.wcsinfo.waverange_start = wrange[0] input_model.meta.wcsinfo.waverange_end = wrange[1] input_model.meta.wcsinfo.spectral_order = order agreq = angle_from_disperser(disperser, input_model) # GWA to detector - det2gwa = detector_to_gwa(reference_files, - input_model.meta.instrument.detector, - disperser) + det2gwa = detector_to_gwa(reference_files, input_model.meta.instrument.detector, disperser) gwa2det = det2gwa.inverse # collimator to GWA collimator2gwa = collimator_to_gwa(reference_files, disperser) - col2det = collimator2gwa & Identity(1) | Mapping((3, 0, 1, 2)) | agreq | \ - gwa2det | det2dms + col2det = collimator2gwa & Identity(1) | Mapping((3, 0, 1, 2)) | agreq | gwa2det | det2dms - slit2msa = slit_to_msa(open_slits, reference_files['msa']) + slit2msa = slit_to_msa(open_slits, reference_files["msa"]) for slit in slit2msa.slits: msa_transform = slit2msa.get_model(slit.name) @@ -2016,8 +2374,10 @@ def _is_valid_slit(domain): valid = _is_valid_slit(bb) if not valid: - log.info("Removing slit {0} from the list of open slits because the " - "WCS bounding_box is completely outside the detector.".format(slit.name)) + log.info( + f"Removing slit {slit.name} from the list of open slits because the " + "WCS bounding_box is completely outside the detector." + ) idx = np.nonzero([s.name == slit.name for s in open_slits])[0][0] open_slits.pop(idx) @@ -2030,9 +2390,15 @@ def spectral_order_wrange_from_model(input_model): Parameters ---------- - input_model : jwst.datamodels.JwstDataModel + input_model : JwstDataModel The data model. Must have been through the assign_wcs step. + Returns + ------- + spectral_order : int + The spectral order. + wrange : list + The wavelength range. """ wrange = [input_model.meta.wcsinfo.waverange_start, input_model.meta.wcsinfo.waverange_end] spectral_order = input_model.meta.wcsinfo.spectral_order @@ -2045,8 +2411,13 @@ def nrs_ifu_wcs(input_model): Parameters ---------- - input_model : jwst.datamodels.JwstDataModel + input_model : JwstDataModel The data model. Must have been through the assign_wcs step. + + Returns + ------- + wcs_list : list + A list of WCSs for all IFU slits. """ _, wrange = spectral_order_wrange_from_model(input_model) wcs_list = [] @@ -2065,6 +2436,10 @@ def _create_ifupost_transform(ifupost_slice): ifupost_slice : `jwst.datamodels.properties.ObjectNode` IFUPost transform for a specific slice + Returns + ------- + model : `~astropy.modeling.core.Model` model. + The transform for this slice. """ linear = ifupost_slice.linear polyx = ifupost_slice.xpoly @@ -2077,16 +2452,16 @@ def _create_ifupost_transform(ifupost_slice): # The wavelength dependent polynomial is # expressed as # poly_independent(x, y) + poly_dependent(x, y) * lambda - model_x = ((Mapping((0, 1), n_inputs=3) | polyx) + - ((Mapping((0, 1), n_inputs=3) | polyx_dist) * - (Mapping((2,)) | Identity(1)))) - model_y = ((Mapping((0, 1), n_inputs=3) | polyy) + - ((Mapping((0, 1), n_inputs=3) | polyy_dist) * - (Mapping((2,)) | Identity(1)))) - - output2poly_mapping = Identity(2, name="{0}_outmap".format('ifupost')) + model_x = (Mapping((0, 1), n_inputs=3) | polyx) + ( + (Mapping((0, 1), n_inputs=3) | polyx_dist) * (Mapping((2,)) | Identity(1)) + ) + model_y = (Mapping((0, 1), n_inputs=3) | polyy) + ( + (Mapping((0, 1), n_inputs=3) | polyy_dist) * (Mapping((2,)) | Identity(1)) + ) + + output2poly_mapping = Identity(2, name="ifupost_outmap") output2poly_mapping.inverse = Mapping([0, 1, 2, 0, 1, 2]) - input2poly_mapping = Mapping([0, 1, 2, 0, 1, 2], name="{0}_inmap".format('ifupost')) + input2poly_mapping = Mapping([0, 1, 2, 0, 1, 2], name="ifupost_inmap") input2poly_mapping.inverse = Identity(2) model_poly = input2poly_mapping | (model_x & model_y) | output2poly_mapping @@ -2095,49 +2470,56 @@ def _create_ifupost_transform(ifupost_slice): def nrs_lamp(input_model, reference_files, slit_y_range): - """Return the appropriate function for lamp data + """ + Return the appropriate function for lamp data. Parameters ---------- - input_model : `~jwst.datamodels.JwstDataModel` + input_model : JwstDataModel The input data model. reference_files : dict - The reference files used for this mode. - slit_y_range : list - The slit dimensions relative to the center of the slit. + Mapping between reftype (keys) and reference file name (vals). + Required files depend on the mode. + slit_y_range : tuple + The slit Y-range for Nirspec slits, relative to (0, 0) in the center. + + Returns + ------- + pipeline : list + The WCS pipeline, suitable for input into `gwcs.WCS`. """ lamp_mode = input_model.meta.instrument.lamp_mode if isinstance(lamp_mode, str): lamp_mode = lamp_mode.lower() else: - lamp_mode = 'none' - if lamp_mode in ['fixedslit', 'brightobj']: + lamp_mode = "none" + if lamp_mode in ["fixedslit", "brightobj"]: return slits_wcs(input_model, reference_files, slit_y_range) - elif lamp_mode == 'ifu': + elif lamp_mode == "ifu": return ifu(input_model, reference_files, slit_y_range) - elif lamp_mode == 'msaspec': + elif lamp_mode == "msaspec": return slits_wcs(input_model, reference_files, slit_y_range) else: return not_implemented_mode(input_model, reference_files, slit_y_range) exp_type2transform = { - 'nrs_autoflat': slits_wcs, - 'nrs_autowave': nrs_lamp, - 'nrs_brightobj': slits_wcs, - 'nrs_confirm': imaging, - 'nrs_dark': not_implemented_mode, - 'nrs_fixedslit': slits_wcs, - 'nrs_focus': imaging, - 'nrs_ifu': ifu, - 'nrs_image': imaging, - 'nrs_lamp': nrs_lamp, - 'nrs_mimf': imaging, - 'nrs_msaspec': slits_wcs, - 'nrs_msata': imaging, - 'nrs_taconfirm': imaging, - 'nrs_tacq': imaging, - 'nrs_taslit': imaging, - 'nrs_verify': imaging, - 'nrs_wata': imaging, + "nrs_autoflat": slits_wcs, + "nrs_autowave": nrs_lamp, + "nrs_brightobj": slits_wcs, + "nrs_confirm": imaging, + "nrs_dark": not_implemented_mode, + "nrs_fixedslit": slits_wcs, + "nrs_focus": imaging, + "nrs_ifu": ifu, + "nrs_image": imaging, + "nrs_lamp": nrs_lamp, + "nrs_mimf": imaging, + "nrs_msaspec": slits_wcs, + "nrs_msata": imaging, + "nrs_taconfirm": imaging, + "nrs_tacq": imaging, + "nrs_taslit": imaging, + "nrs_verify": imaging, + "nrs_wata": imaging, } diff --git a/jwst/assign_wcs/pointing.py b/jwst/assign_wcs/pointing.py index 88c7f7b2b4..a853c648bb 100644 --- a/jwst/assign_wcs/pointing.py +++ b/jwst/assign_wcs/pointing.py @@ -13,8 +13,7 @@ from stdatamodels.jwst.datamodels import JwstDataModel -__all__ = ["compute_roll_ref", "frame_from_model", "fitswcs_transform_from_model", - "dva_corr_model"] +__all__ = ["compute_roll_ref", "frame_from_model", "fitswcs_transform_from_model", "dva_corr_model"] def _v23tosky(v2_ref, v3_ref, roll_ref, ra_ref, dec_ref, wrap_v2_at=180, wrap_lon_at=360): @@ -24,13 +23,34 @@ def _v23tosky(v2_ref, v3_ref, roll_ref, ra_ref, dec_ref, wrap_v2_at=180, wrap_lo # The sky rotation expects values in deg. # This should be removed when models work with quantities. - m = ((Scale(1 / 3600) & Scale(1 / 3600)) | SphericalToCartesian(wrap_lon_at=wrap_v2_at) - | rot | CartesianToSpherical(wrap_lon_at=wrap_lon_at)) - m.name = 'v23tosky' + m = ( + (Scale(1 / 3600) & Scale(1 / 3600)) + | SphericalToCartesian(wrap_lon_at=wrap_v2_at) + | rot + | CartesianToSpherical(wrap_lon_at=wrap_lon_at) + ) + m.name = "v23tosky" return m def v23tosky(input_model, wrap_v2_at=180, wrap_lon_at=360): + """ + Create a model that transforms from V2, V3 to sky coordinates. + + Parameters + ---------- + input_model : JwstDataModel + The input data model. + wrap_v2_at : float + The value at which to wrap the V2 coordinate. + wrap_lon_at : float + The value at which to wrap the longitude. + + Returns + ------- + `~astropy.modeling.Model` + The model that transforms from V2, V3 to sky coordinates. + """ m = _v23tosky( v2_ref=input_model.meta.wcsinfo.v2_ref / 3600, v3_ref=input_model.meta.wcsinfo.v3_ref / 3600, @@ -38,14 +58,14 @@ def v23tosky(input_model, wrap_v2_at=180, wrap_lon_at=360): ra_ref=input_model.meta.wcsinfo.ra_ref, dec_ref=input_model.meta.wcsinfo.dec_ref, wrap_v2_at=wrap_v2_at, - wrap_lon_at=wrap_lon_at + wrap_lon_at=wrap_lon_at, ) return m def compute_roll_ref(v2_ref, v3_ref, roll_ref, ra_ref, dec_ref, new_v2_ref, new_v3_ref): """ - Computes the position of V3 (measured N to E) at the center af an aperture. + Compute the position of V3 (measured N to E) at the center af an aperture. Parameters ---------- @@ -64,7 +84,6 @@ def compute_roll_ref(v2_ref, v3_ref, roll_ref, ra_ref, dec_ref, new_v2_ref, new_ ------- new_roll : float The value of ROLL_REF (in deg) - """ v2 = np.deg2rad(new_v2_ref / 3600) v3 = np.deg2rad(new_v3_ref / 3600) @@ -79,17 +98,19 @@ def compute_roll_ref(v2_ref, v3_ref, roll_ref, ra_ref, dec_ref, new_v2_ref, new_ angles = [v2_ref, -v3_ref, roll_ref, dec_ref, -ra_ref] axes = "zyxyz" - matrices = [rotation_matrix(a, ax) for a, ax in zip(angles, axes)] + matrices = [rotation_matrix(a, ax) for a, ax in zip(angles, axes, strict=True)] m = reduce(np.matmul, matrices[::-1]) return _roll_angle_from_matrix(m, v2, v3) def _roll_angle_from_matrix(matrix, v2, v3): - X = -(matrix[2, 0] * np.cos(v2) + matrix[2, 1] * np.sin(v2)) * \ - np.sin(v3) + matrix[2, 2] * np.cos(v3) - Y = (matrix[0, 0] * matrix[1, 2] - matrix[1, 0] * matrix[0, 2]) * np.cos(v2) + \ - (matrix[0, 1] * matrix[1, 2] - matrix[1, 1] * matrix[0, 2]) * np.sin(v2) - new_roll = np.rad2deg(np.arctan2(Y, X)) + x = -(matrix[2, 0] * np.cos(v2) + matrix[2, 1] * np.sin(v2)) * np.sin(v3) + matrix[ + 2, 2 + ] * np.cos(v3) + y = (matrix[0, 0] * matrix[1, 2] - matrix[1, 0] * matrix[0, 2]) * np.cos(v2) + ( + matrix[0, 1] * matrix[1, 2] - matrix[1, 1] * matrix[0, 2] + ) * np.sin(v2) + new_roll = np.rad2deg(np.arctan2(y, x)) if new_roll < 0: new_roll += 360 return new_roll @@ -101,18 +122,22 @@ def wcsinfo_from_model(input_model): Parameters ---------- - input_model : `~stdatamodels.jwst.datamodels.JwstDataModel` - The input data model + input_model : JwstDataModel + The input data model. + Returns + ------- + wcsinfo : dict + A dictionary with WCS keywords as keys and their values """ - defaults = {'CRPIX': 0, 'CRVAL': 0, 'CDELT': 1., 'CTYPE': "", 'CUNIT': u.Unit("")} + defaults = {"CRPIX": 0, "CRVAL": 0, "CDELT": 1.0, "CTYPE": "", "CUNIT": u.Unit("")} wcsinfo = {} wcsaxes = input_model.meta.wcsinfo.wcsaxes - wcsinfo['WCSAXES'] = wcsaxes - for key in ['CRPIX', 'CRVAL', 'CDELT', 'CTYPE', 'CUNIT']: + wcsinfo["WCSAXES"] = wcsaxes + for key in ["CRPIX", "CRVAL", "CDELT", "CTYPE", "CUNIT"]: val = [] for ax in range(1, wcsaxes + 1): - k = (key + "{0}".format(ax)).lower() + k = (key + f"{ax}").lower() v = getattr(input_model.meta.wcsinfo, k, defaults[key]) val.append(v) wcsinfo[key] = np.array(val) @@ -120,28 +145,31 @@ def wcsinfo_from_model(input_model): pc = np.zeros((wcsaxes, wcsaxes)) for i in range(1, wcsaxes + 1): for j in range(1, wcsaxes + 1): - pc[i - 1, j - 1] = getattr(input_model.meta.wcsinfo, 'pc{0}_{1}'.format(i, j), 1) - wcsinfo['PC'] = pc - wcsinfo['RADESYS'] = input_model.meta.coordinates.reference_frame - wcsinfo['has_cd'] = False + pc[i - 1, j - 1] = getattr(input_model.meta.wcsinfo, f"pc{i}_{j}", 1) + wcsinfo["PC"] = pc + wcsinfo["RADESYS"] = input_model.meta.coordinates.reference_frame + wcsinfo["has_cd"] = False return wcsinfo def fitswcs_transform_from_model(wcsinfo, wavetable=None): """ Create a WCS object using from datamodel.meta.wcsinfo. + Transforms assume 0-based coordinates. Parameters ---------- wcsinfo : dict-like ``~jwst.meta.wcsinfo`` structure. + wavetable : `~astropy.table.Table`, None + A table with wavelength values. If None, a linear transformation + will be used. - Return - ------ + Returns + ------- transform : `~astropy.modeling.core.Model` WCS forward transform - from pixel to world coordinates. - """ spatial_axes, spectral_axes, unknown = gwutils.get_axes(wcsinfo) @@ -150,12 +178,14 @@ def fitswcs_transform_from_model(wcsinfo, wavetable=None): sp_axis = spectral_axes[0] if wavetable is None: # Subtract one from CRPIX which is 1-based. - spectral_transform = astmodels.Shift(-(wcsinfo['CRPIX'][sp_axis] - 1)) | \ - astmodels.Scale(wcsinfo['CDELT'][sp_axis]) | \ - astmodels.Shift(wcsinfo['CRVAL'][sp_axis]) + spectral_transform = ( + astmodels.Shift(-(wcsinfo["CRPIX"][sp_axis] - 1)) + | astmodels.Scale(wcsinfo["CDELT"][sp_axis]) + | astmodels.Shift(wcsinfo["CRVAL"][sp_axis]) + ) else: # Wave dimension is an array that needs to be converted to a table - waves = wavetable['wavelength'].flatten() + waves = wavetable["wavelength"].flatten() spectral_transform = astmodels.Tabular1D(lookup_table=waves) transform = transform & spectral_transform @@ -173,75 +203,101 @@ def frame_from_model(wcsinfo): Parameters ---------- - wcsinfo : `~stdatamodels.jwst.datamodels.JwstDataModel` or dict + wcsinfo : JwstDataModel or dict Either one of the JWST data models or a dict with model.meta.wcsinfo. Returns ------- - frame : `~coordinate_frames.CoordinateFrame` - + frame : `~gwcs.coordinate_frames.CoordinateFrame` + A coordinate frame object corresponding to the input WCS information. """ if isinstance(wcsinfo, JwstDataModel): wcsinfo = wcsinfo_from_model(wcsinfo) - wcsaxes = wcsinfo['WCSAXES'] + wcsaxes = wcsinfo["WCSAXES"] celestial_axes, spectral_axes, other = gwutils.get_axes(wcsinfo) - cunit = wcsinfo['CUNIT'] + cunit = wcsinfo["CUNIT"] frames = [] if celestial_axes: ref_frame = coords.ICRS() - celestial = cf.CelestialFrame(name='sky', axes_order=tuple(celestial_axes), - reference_frame=ref_frame, unit=cunit[celestial_axes], - axes_names=('RA', 'DEC')) + celestial = cf.CelestialFrame( + name="sky", + axes_order=tuple(celestial_axes), + reference_frame=ref_frame, + unit=cunit[celestial_axes], + axes_names=("RA", "DEC"), + ) frames.append(celestial) if spectral_axes: - spec = cf.SpectralFrame(name='spectral', axes_order=tuple(spectral_axes), - unit=cunit[spectral_axes], - axes_names=('wavelength',)) + spec = cf.SpectralFrame( + name="spectral", + axes_order=tuple(spectral_axes), + unit=cunit[spectral_axes], + axes_names=("wavelength",), + ) frames.append(spec) if other: # Make sure these are strings and not np.str_ objects. - axes_names = tuple([str(name) for name in wcsinfo['CTYPE'][other]]) - name = "_".join(wcsinfo['CTYPE'][other]) - spatial = cf.Frame2D(name=name, axes_order=tuple(other), unit=cunit[other], - axes_names=axes_names) + axes_names = tuple([str(name) for name in wcsinfo["CTYPE"][other]]) + name = "_".join(wcsinfo["CTYPE"][other]) + spatial = cf.Frame2D( + name=name, axes_order=tuple(other), unit=cunit[other], axes_names=axes_names + ) frames.append(spatial) if wcsaxes == 2: return frames[0] elif wcsaxes == 3: - world = cf.CompositeFrame(frames, name='world') + world = cf.CompositeFrame(frames, name="world") return world else: - raise ValueError("WCSAXES can be 2 or 3, got {0}".format(wcsaxes)) + raise ValueError("WCSAXES can be 2 or 3, got {wcsaxes}") def create_fitswcs(inp, input_frame=None): - if isinstance(inp, JwstDataModel): - wcsinfo = wcsinfo_from_model(inp) - wavetable = None - spatial_axes, spectral_axes, unknown = gwutils.get_axes(wcsinfo) - if spectral_axes: - sp_axis = spectral_axes[0] - if wcsinfo['CTYPE'][sp_axis] == 'WAVE-TAB': - wavetable = inp.wavetable - transform = fitswcs_transform_from_model(wcsinfo, wavetable=wavetable) - output_frame = frame_from_model(wcsinfo) - else: + """ + Create a WCS object from a JWST data model or a FITS file. + + Parameters + ---------- + inp : JwstDataModel or str + Either a JWST data model or a FITS file. + input_frame : `~gwcs.coordinate_frames.CoordinateFrame`, None + The input coordinate frame. If None, a default frame will be created. + + Returns + ------- + wcsobj : `~gwcs.wcs.WCS` + A WCS object. + """ + if not isinstance(inp, JwstDataModel): raise TypeError("Input is expected to be a JwstDataModel instance or a FITS file.") + wcsinfo = wcsinfo_from_model(inp) + wavetable = None + spatial_axes, spectral_axes, unknown = gwutils.get_axes(wcsinfo) + if spectral_axes: + sp_axis = spectral_axes[0] + if wcsinfo["CTYPE"][sp_axis] == "WAVE-TAB": + wavetable = inp.wavetable + transform = fitswcs_transform_from_model(wcsinfo, wavetable=wavetable) + output_frame = frame_from_model(wcsinfo) if input_frame is None: - wcsaxes = wcsinfo['WCSAXES'] + wcsaxes = wcsinfo["WCSAXES"] if wcsaxes == 2: input_frame = cf.Frame2D(name="detector") elif wcsaxes == 3: - input_frame = cf.CoordinateFrame(name="detector", naxes=3, - axes_order=(0, 1, 2), unit=(u.pix, u.pix, u.pix), - axes_type=["SPATIAL", "SPATIAL", "SPECTRAL"], - axes_names=('x', 'y', 'z'), axis_physical_types=None) + input_frame = cf.CoordinateFrame( + name="detector", + naxes=3, + axes_order=(0, 1, 2), + unit=(u.pix, u.pix, u.pix), + axes_type=["SPATIAL", "SPATIAL", "SPECTRAL"], + axes_names=("x", "y", "z"), + axis_physical_types=None, + ) else: raise TypeError(f"WCSAXES is expected to be 2 or 3, instead it is {wcsaxes}") - pipeline = [(input_frame, transform), - (output_frame, None)] + pipeline = [(input_frame, transform), (output_frame, None)] wcsobj = wcs.WCS(pipeline) return wcsobj @@ -249,8 +305,7 @@ def create_fitswcs(inp, input_frame=None): def dva_corr_model(va_scale, v2_ref, v3_ref): """ - Create transformation that accounts for differential velocity aberration - (scale). + Create transformation that accounts for differential velocity aberration (scale). Parameters ---------- @@ -258,11 +313,9 @@ def dva_corr_model(va_scale, v2_ref, v3_ref): Ratio of the apparent plate scale to the true plate scale. When ``va_scale`` is `None`, it is assumed to be identical to ``1`` and an ``astropy.modeling.models.Identity`` model will be returned. - v2_ref : float, None Telescope ``v2`` coordinate of the reference point in ``arcsec``. When ``v2_ref`` is `None`, it is assumed to be identical to ``0``. - v3_ref : float, None Telescope ``v3`` coordinate of the reference point in ``arcsec``. When ``v3_ref`` is `None`, it is assumed to be identical to ``0``. @@ -272,7 +325,6 @@ def dva_corr_model(va_scale, v2_ref, v3_ref): va_corr : astropy.modeling.CompoundModel, astropy.modeling.models.Identity A 2D compound model that corrects DVA. If ``va_scale`` is `None` or 1 then `astropy.modeling.models.Identity` will be returned. - """ if va_scale is None or va_scale == 1: return Identity(2) @@ -280,7 +332,7 @@ def dva_corr_model(va_scale, v2_ref, v3_ref): if va_scale <= 0: raise ValueError("'Velocity aberration scale must be a positive number.") - va_corr = Scale(va_scale, name='dva_scale_v2') & Scale(va_scale, name='dva_scale_v3') + va_corr = Scale(va_scale, name="dva_scale_v2") & Scale(va_scale, name="dva_scale_v3") if v2_ref is None: v2_ref = 0 @@ -297,6 +349,6 @@ def dva_corr_model(va_scale, v2_ref, v3_ref): v2_shift = (1 - va_scale) * v2_ref v3_shift = (1 - va_scale) * v3_ref - va_corr |= Shift(v2_shift, name='dva_v2_shift') & Shift(v3_shift, name='dva_v3_shift') - va_corr.name = 'DVA_Correction' + va_corr |= Shift(v2_shift, name="dva_v2_shift") & Shift(v3_shift, name="dva_v3_shift") + va_corr.name = "DVA_Correction" return va_corr diff --git a/jwst/assign_wcs/tools/__init__.py b/jwst/assign_wcs/tools/__init__.py index e69de29bb2..309d4452e6 100644 --- a/jwst/assign_wcs/tools/__init__.py +++ b/jwst/assign_wcs/tools/__init__.py @@ -0,0 +1 @@ +"""Testing tools for assign_wcs.""" diff --git a/jwst/assign_wcs/tools/miri/__init__.py b/jwst/assign_wcs/tools/miri/__init__.py index e69de29bb2..784aaaefb6 100644 --- a/jwst/assign_wcs/tools/miri/__init__.py +++ b/jwst/assign_wcs/tools/miri/__init__.py @@ -0,0 +1 @@ +"""Tools for testing MIRI.""" diff --git a/jwst/assign_wcs/tools/nirspec/__init__.py b/jwst/assign_wcs/tools/nirspec/__init__.py index e69de29bb2..a08448293c 100644 --- a/jwst/assign_wcs/tools/nirspec/__init__.py +++ b/jwst/assign_wcs/tools/nirspec/__init__.py @@ -0,0 +1 @@ +"""Tools for testing NIRSPEC MSA metadata.""" diff --git a/jwst/assign_wcs/tools/nirspec/create_configuration_test.py b/jwst/assign_wcs/tools/nirspec/create_configuration_test.py index c9c2b9f4b1..682028f111 100644 --- a/jwst/assign_wcs/tools/nirspec/create_configuration_test.py +++ b/jwst/assign_wcs/tools/nirspec/create_configuration_test.py @@ -4,115 +4,143 @@ def add_configuration_records(filename): + """ + Add in some test configuration records to the MSA metadata file. + + Parameters + ---------- + filename : str + The name of the MSA metadata file to add the configuration records to. + """ # filename = 'jw95065006001_01_msa.fits' # Read in the base fits file. msa_conf = fits.open(filename) # Add in the different configuration test conditions - rectype = np.dtype([('slitlet_id', '>i2'), - ('msa_metadata_id', '>i2'), - ('shutter_quadrant', '>i2'), - ('shutter_row', '>i2'), - ('shutter_column', '>i2'), - ('source_id', '>i2'), - ('background', 'S1'), - ('shutter_state', 'S6'), - ('estimated_source_in_shutter_x', '>f4'), - ('estimated_source_in_shutter_y', '>f4')]) + rectype = np.dtype( + [ + ("slitlet_id", ">i2"), + ("msa_metadata_id", ">i2"), + ("shutter_quadrant", ">i2"), + ("shutter_row", ">i2"), + ("shutter_column", ">i2"), + ("source_id", ">i2"), + ("background", "S1"), + ("shutter_state", "S6"), + ("estimated_source_in_shutter_x", ">f4"), + ("estimated_source_in_shutter_y", ">f4"), + ] + ) # The base one is - base = np.array([(12, 2, 4, 251, 22, 1, 'Y', 'OPEN', np.nan, np.nan), - (12, 2, 4, 251, 23, 1, 'N', 'OPEN', 0.18283921, 0.31907734), - (12, 2, 4, 251, 24, 1, 'Y', 'OPEN', np.nan, np.nan), - (12, 64, 4, 251, 22, 1, 'Y', 'OPEN', np.nan, np.nan), - (12, 64, 4, 251, 23, 1, 'N', 'OPEN', 0.18283921, 0.31907734), - (12, 64, 4, 251, 24, 1, 'Y', 'OPEN', np.nan, np.nan), - (12, 65, 4, 251, 22, 1, 'Y', 'OPEN', np.nan, np.nan), - (12, 65, 4, 251, 23, 1, 'N', 'OPEN', 0.18283921, 0.31907734), - (12, 65, 4, 251, 24, 1, 'Y', 'OPEN', np.nan, np.nan), - (12, 66, 4, 251, 22, 1, 'Y', 'OPEN', np.nan, np.nan), - (12, 66, 4, 251, 23, 1, 'N', 'OPEN', 0.18283921, 0.31907734), - (12, 66, 4, 251, 24, 1, 'Y', 'OPEN', np.nan, np.nan), - (12, 67, 4, 251, 22, 1, 'Y', 'OPEN', np.nan, np.nan), - (12, 67, 4, 251, 23, 1, 'N', 'OPEN', 0.18283921, 0.31907734), - (12, 67, 4, 251, 24, 1, 'Y', 'OPEN', np.nan, np.nan), - (12, 68, 4, 251, 22, 1, 'Y', 'OPEN', np.nan, np.nan), - (12, 68, 4, 251, 23, 1, 'N', 'OPEN', 0.18283921, 0.31907734), - (12, 68, 4, 251, 24, 1, 'Y', 'OPEN', np.nan, np.nan), - (12, 69, 4, 251, 22, 1, 'Y', 'OPEN', np.nan, np.nan), - (12, 69, 4, 251, 23, 1, 'N', 'OPEN', 0.18283921, 0.31907734), - (12, 69, 4, 251, 24, 1, 'Y', 'OPEN', np.nan, np.nan), - (12, 70, 4, 251, 22, 1, 'Y', 'OPEN', np.nan, np.nan), - (12, 70, 4, 251, 23, 1, 'N', 'OPEN', 0.18283921, 0.31907734), - (12, 70, 4, 251, 24, 1, 'Y', 'OPEN', np.nan, np.nan), - (12, 71, 4, 251, 22, 1, 'Y', 'OPEN', np.nan, np.nan), - (12, 71, 4, 251, 23, 1, 'N', 'OPEN', 0.18283921, 0.31907734), - (12, 71, 4, 251, 24, 1, 'Y', 'OPEN', np.nan, np.nan)], - rectype) + base = np.array( + [ + (12, 2, 4, 251, 22, 1, "Y", "OPEN", np.nan, np.nan), + (12, 2, 4, 251, 23, 1, "N", "OPEN", 0.18283921, 0.31907734), + (12, 2, 4, 251, 24, 1, "Y", "OPEN", np.nan, np.nan), + (12, 64, 4, 251, 22, 1, "Y", "OPEN", np.nan, np.nan), + (12, 64, 4, 251, 23, 1, "N", "OPEN", 0.18283921, 0.31907734), + (12, 64, 4, 251, 24, 1, "Y", "OPEN", np.nan, np.nan), + (12, 65, 4, 251, 22, 1, "Y", "OPEN", np.nan, np.nan), + (12, 65, 4, 251, 23, 1, "N", "OPEN", 0.18283921, 0.31907734), + (12, 65, 4, 251, 24, 1, "Y", "OPEN", np.nan, np.nan), + (12, 66, 4, 251, 22, 1, "Y", "OPEN", np.nan, np.nan), + (12, 66, 4, 251, 23, 1, "N", "OPEN", 0.18283921, 0.31907734), + (12, 66, 4, 251, 24, 1, "Y", "OPEN", np.nan, np.nan), + (12, 67, 4, 251, 22, 1, "Y", "OPEN", np.nan, np.nan), + (12, 67, 4, 251, 23, 1, "N", "OPEN", 0.18283921, 0.31907734), + (12, 67, 4, 251, 24, 1, "Y", "OPEN", np.nan, np.nan), + (12, 68, 4, 251, 22, 1, "Y", "OPEN", np.nan, np.nan), + (12, 68, 4, 251, 23, 1, "N", "OPEN", 0.18283921, 0.31907734), + (12, 68, 4, 251, 24, 1, "Y", "OPEN", np.nan, np.nan), + (12, 69, 4, 251, 22, 1, "Y", "OPEN", np.nan, np.nan), + (12, 69, 4, 251, 23, 1, "N", "OPEN", 0.18283921, 0.31907734), + (12, 69, 4, 251, 24, 1, "Y", "OPEN", np.nan, np.nan), + (12, 70, 4, 251, 22, 1, "Y", "OPEN", np.nan, np.nan), + (12, 70, 4, 251, 23, 1, "N", "OPEN", 0.18283921, 0.31907734), + (12, 70, 4, 251, 24, 1, "Y", "OPEN", np.nan, np.nan), + (12, 71, 4, 251, 22, 1, "Y", "OPEN", np.nan, np.nan), + (12, 71, 4, 251, 23, 1, "N", "OPEN", 0.18283921, 0.31907734), + (12, 71, 4, 251, 24, 1, "Y", "OPEN", np.nan, np.nan), + ], + rectype, + ) # Test 1: Kinda normal - test1 = np.array([ - (55, 12, 4, 251, 22, 1, 'Y', 'OPEN', np.nan, np.nan), - (55, 12, 4, 251, 23, 1, 'Y', 'OPEN', 0.18283921, 0.31907734), - (55, 12, 4, 251, 24, 1, 'Y', 'OPEN', np.nan, np.nan), - (55, 12, 4, 251, 25, 1, 'Y', 'OPEN', np.nan, np.nan), - (55, 12, 4, 251, 26, 1, 'N', 'OPEN', 0.18283921, 0.31907734), - ], - rectype) + test1 = np.array( + [ + (55, 12, 4, 251, 22, 1, "Y", "OPEN", np.nan, np.nan), + (55, 12, 4, 251, 23, 1, "Y", "OPEN", 0.18283921, 0.31907734), + (55, 12, 4, 251, 24, 1, "Y", "OPEN", np.nan, np.nan), + (55, 12, 4, 251, 25, 1, "Y", "OPEN", np.nan, np.nan), + (55, 12, 4, 251, 26, 1, "N", "OPEN", 0.18283921, 0.31907734), + ], + rectype, + ) # Test 2: Create slitlet_id set with no background open # This should fail as there should be only one "N" - test2 = np.array([ - (56, 13, 4, 251, 22, 1, 'N', 'OPEN', np.nan, np.nan), - (56, 13, 4, 251, 23, 1, 'N', 'OPEN', 0.18283921, 0.31907734), - (56, 13, 4, 251, 24, 1, 'Y', 'OPEN', np.nan, np.nan), - (56, 13, 4, 251, 25, 1, 'Y', 'OPEN', np.nan, np.nan), - (56, 13, 4, 251, 26, 1, 'N', 'OPEN', 0.18283921, 0.31907734), - ], - rectype) + test2 = np.array( + [ + (56, 13, 4, 251, 22, 1, "N", "OPEN", np.nan, np.nan), + (56, 13, 4, 251, 23, 1, "N", "OPEN", 0.18283921, 0.31907734), + (56, 13, 4, 251, 24, 1, "Y", "OPEN", np.nan, np.nan), + (56, 13, 4, 251, 25, 1, "Y", "OPEN", np.nan, np.nan), + (56, 13, 4, 251, 26, 1, "N", "OPEN", 0.18283921, 0.31907734), + ], + rectype, + ) # Test 3: All background - test3 = np.array([ - (57, 14, 4, 251, 22, 1, 'Y', 'OPEN', np.nan, np.nan), - (57, 14, 4, 251, 23, 1, 'Y', 'OPEN', 0.18283921, 0.31907734), - (57, 14, 4, 251, 24, 1, 'Y', 'OPEN', 0.18283921, 0.31907734), - ], - rectype) + test3 = np.array( + [ + (57, 14, 4, 251, 22, 1, "Y", "OPEN", np.nan, np.nan), + (57, 14, 4, 251, 23, 1, "Y", "OPEN", 0.18283921, 0.31907734), + (57, 14, 4, 251, 24, 1, "Y", "OPEN", 0.18283921, 0.31907734), + ], + rectype, + ) # Test 4: Empty in between - test4 = np.array([ - (58, 15, 4, 251, 22, 1, 'Y', 'OPEN', np.nan, np.nan), - (58, 15, 4, 251, 23, 1, 'Y', 'OPEN', 0.18283921, 0.31907734), - (58, 15, 4, 251, 24, 1, 'N', 'OPEN', 0.18283921, 0.31907734), - (58, 15, 4, 251, 25, 1, 'Y', 'OPEN', 0.18283921, 0.31907734), - (58, 15, 4, 251, 27, 1, 'Y', 'OPEN', 0.18283921, 0.31907734), - (58, 15, 4, 251, 28, 1, 'Y', 'OPEN', 0.18283921, 0.31907734), - ], - rectype) + test4 = np.array( + [ + (58, 15, 4, 251, 22, 1, "Y", "OPEN", np.nan, np.nan), + (58, 15, 4, 251, 23, 1, "Y", "OPEN", 0.18283921, 0.31907734), + (58, 15, 4, 251, 24, 1, "N", "OPEN", 0.18283921, 0.31907734), + (58, 15, 4, 251, 25, 1, "Y", "OPEN", 0.18283921, 0.31907734), + (58, 15, 4, 251, 27, 1, "Y", "OPEN", 0.18283921, 0.31907734), + (58, 15, 4, 251, 28, 1, "Y", "OPEN", 0.18283921, 0.31907734), + ], + rectype, + ) # Test 5: Empty in between - test5 = np.array([ - (59, 16, 4, 256, 22, 1, 'Y', 'OPEN', np.nan, np.nan), - (59, 16, 4, 256, 23, 1, 'Y', 'OPEN', 0.18283921, 0.31907734), - (59, 16, 4, 256, 24, 1, 'N', 'OPEN', 0.18283921, 0.31907734), - (59, 16, 4, 256, 25, 1, 'Y', 'OPEN', 0.18283921, 0.31907734), - (59, 16, 4, 256, 27, 1, 'Y', 'OPEN', 0.18283921, 0.31907734), - (59, 16, 4, 256, 28, 1, 'Y', 'OPEN', 0.18283921, 0.31907734), - (60, 16, 4, 258, 30, 1, 'Y', 'OPEN', np.nan, np.nan), - (60, 16, 4, 258, 31, 1, 'Y', 'OPEN', 0.18283921, 0.31907734), - (60, 16, 4, 258, 32, 1, 'N', 'OPEN', 0.18283921, 0.31907734), - (60, 16, 4, 258, 33, 1, 'Y', 'OPEN', 0.18283921, 0.31907734), - (60, 16, 4, 258, 34, 1, 'Y', 'OPEN', 0.18283921, 0.31907734), - (60, 16, 4, 258, 35, 1, 'Y', 'OPEN', 0.18283921, 0.31907734), - ], - rectype) + test5 = np.array( + [ + (59, 16, 4, 256, 22, 1, "Y", "OPEN", np.nan, np.nan), + (59, 16, 4, 256, 23, 1, "Y", "OPEN", 0.18283921, 0.31907734), + (59, 16, 4, 256, 24, 1, "N", "OPEN", 0.18283921, 0.31907734), + (59, 16, 4, 256, 25, 1, "Y", "OPEN", 0.18283921, 0.31907734), + (59, 16, 4, 256, 27, 1, "Y", "OPEN", 0.18283921, 0.31907734), + (59, 16, 4, 256, 28, 1, "Y", "OPEN", 0.18283921, 0.31907734), + (60, 16, 4, 258, 30, 1, "Y", "OPEN", np.nan, np.nan), + (60, 16, 4, 258, 31, 1, "Y", "OPEN", 0.18283921, 0.31907734), + (60, 16, 4, 258, 32, 1, "N", "OPEN", 0.18283921, 0.31907734), + (60, 16, 4, 258, 33, 1, "Y", "OPEN", 0.18283921, 0.31907734), + (60, 16, 4, 258, 34, 1, "Y", "OPEN", 0.18283921, 0.31907734), + (60, 16, 4, 258, 35, 1, "Y", "OPEN", 0.18283921, 0.31907734), + ], + rectype, + ) # And now write out the msa_conf to a new file - tt = numpy.lib.recfunctions.stack_arrays((base, test1, test2, test3, test4, test5), usemask=False) + tt = numpy.lib.recfunctions.stack_arrays( + (base, test1, test2, test3, test4, test5), usemask=False + ) msa_conf[2].data = tt - msa_conf.writeto('test_configuration_msa.fits') + msa_conf.writeto("test_configuration_msa.fits") -if __name__ == '__main__': - print('Call add_configuration_records(filename) directly.') +if __name__ == "__main__": + print("Call add_configuration_records(filename) directly.") # noqa: T201 diff --git a/jwst/assign_wcs/util.py b/jwst/assign_wcs/util.py index 695f966502..c43b3c785e 100644 --- a/jwst/assign_wcs/util.py +++ b/jwst/assign_wcs/util.py @@ -1,7 +1,5 @@ -""" -Utility function for assign_wcs. +"""Utility functions for assign_wcs.""" -""" import logging import functools import numpy as np @@ -11,7 +9,6 @@ from astropy.modeling import models as astmodels from astropy.table import QTable from astropy.constants import c -from typing import Union, List from gwcs import WCS from gwcs.wcstools import grid_from_bounding_box @@ -19,10 +16,10 @@ from stpipe.exceptions import StpipeExitException from stcal.alignment.util import compute_s_region_keyword, compute_s_region_imaging -from stdatamodels.jwst.datamodels import WavelengthrangeModel +from stdatamodels.jwst.datamodels import WavelengthrangeModel, MiriLRSSpecwcsModel from stdatamodels.jwst.transforms.models import GrismObject -from ..lib.catalog_utils import SkyObject +from jwst.lib.catalog_utils import SkyObject log = logging.getLogger(__name__) @@ -32,19 +29,28 @@ _MAX_SIP_DEGREE = 6 -__all__ = ["reproject", "velocity_correction", - "MSAFileError", "NoDataOnDetectorError", "compute_scale", - "calc_rotation_matrix", "wrap_ra", "update_fits_wcsinfo"] +__all__ = [ + "reproject", + "velocity_correction", + "MSAFileError", + "NoDataOnDetectorError", + "compute_scale", + "calc_rotation_matrix", + "wrap_ra", + "update_fits_wcsinfo", +] class MSAFileError(Exception): + """Exception to raise when MSA shutter configuration file is missing or invalid.""" def __init__(self, message): super(MSAFileError, self).__init__(message) class NoDataOnDetectorError(StpipeExitException): - """WCS solution indicates no data on detector + """ + WCS solution indicates no data on detector. When WCS solutions are available, the solutions indicate that no data will be present, raise this exception. @@ -53,12 +59,11 @@ class NoDataOnDetectorError(StpipeExitException): configurations of the MSA, it is possible that no dispersed spectra will appear on NRS2. This is not a failure of calibration, but needs to be called out in order for the calling architecture to be aware of this. - """ def __init__(self, message=None): if message is None: - message = 'WCS solution indicate that no science is in the data.' + message = "WCS solution indicate that no science is in the data." # The first argument instructs stpipe CLI tools to exit with status # 64 when this exception is raised. super().__init__(64, message) @@ -66,7 +71,7 @@ def __init__(self, message=None): def _domain_to_bounding_box(domain): # TODO: remove this when domain is completely removed - bb = tuple([(item['lower'], item['upper']) for item in domain]) + bb = tuple([(item["lower"], item["upper"]) for item in domain]) if len(bb) == 1: bb = bb[0] return bb @@ -74,8 +79,7 @@ def _domain_to_bounding_box(domain): def reproject(wcs1, wcs2): """ - Given two WCSs return a function which takes pixel coordinates in - the first WCS and computes their location in the second one. + Take in pixel coordinates in the first WCS and computes their location in the second one. It performs the forward transformation of ``wcs1`` followed by the inverse of ``wcs2``. @@ -95,24 +99,27 @@ def reproject(wcs1, wcs2): def _reproject(x, y): sky = wcs1.forward_transform(x, y) return wcs2.backward_transform(*sky) + return _reproject -def compute_scale(wcs: WCS, fiducial: Union[tuple, np.ndarray], - disp_axis: int | None = None, pscale_ratio: float | None = None) -> float: - """Compute scaling transform. +def compute_scale( + wcs: WCS, + fiducial: tuple | np.ndarray, + disp_axis: int | None = None, + pscale_ratio: float | None = None, +) -> float: + """ + Compute scaling transform. Parameters ---------- wcs : `~gwcs.wcs.WCS` Reference WCS object from which to compute a scaling factor. - fiducial : tuple Input fiducial of (RA, DEC) or (RA, DEC, Wavelength) used in calculating reference points. - disp_axis : int Dispersion axis integer. Assumes the same convention as `wcsinfo.dispersion_direction` - pscale_ratio : int Ratio of input to output pixel scale @@ -120,23 +127,24 @@ def compute_scale(wcs: WCS, fiducial: Union[tuple, np.ndarray], ------- scale : float Scaling factor for x and y or cross-dispersion direction. - """ - spectral = 'SPECTRAL' in wcs.output_frame.axes_type + spectral = "SPECTRAL" in wcs.output_frame.axes_type if spectral and disp_axis is None: - raise ValueError('If input WCS is spectral, a disp_axis must be given') + raise ValueError("If input WCS is spectral, a disp_axis must be given") crpix = np.array(wcs.invert(*fiducial, with_bounding_box=False)) delta = np.zeros_like(crpix) - spatial_idx = np.where(np.array(wcs.output_frame.axes_type) == 'SPATIAL')[0] + spatial_idx = np.where(np.array(wcs.output_frame.axes_type) == "SPATIAL")[0] delta[spatial_idx[0]] = 1 crpix_with_offsets = np.vstack((crpix, crpix + delta, crpix + np.roll(delta, 1))).T crval_with_offsets = wcs(*crpix_with_offsets, with_bounding_box=False) - coords = SkyCoord(ra=crval_with_offsets[spatial_idx[0]], dec=crval_with_offsets[spatial_idx[1]], unit="deg") + coords = SkyCoord( + ra=crval_with_offsets[spatial_idx[0]], dec=crval_with_offsets[spatial_idx[1]], unit="deg" + ) xscale: float = np.abs(coords[0].separation(coords[1]).value) yscale: float = np.abs(coords[0].separation(coords[2]).value) @@ -153,25 +161,24 @@ def compute_scale(wcs: WCS, fiducial: Union[tuple, np.ndarray], return scale -def calc_rotation_matrix(roll_ref: float, v3i_yang: float, vparity: int = 1) -> List[float]: - """Calculate the rotation matrix. +def calc_rotation_matrix(roll_ref: float, v3i_yang: float, vparity: int = 1) -> list[float]: + """ + Calculate the rotation matrix. Parameters ---------- roll_ref : float Telescope roll angle of V3 North over East at the ref. point in radians - v3i_yang : float The angle between ideal Y-axis and V3 in radians. - vparity : int The x-axis parity, usually taken from the JWST SIAF parameter VIdlParity. Value should be "1" or "-1". Returns ------- - matrix: [pc1_1, pc1_2, pc2_1, pc2_2] - The rotation matrix + matrix : list + The rotation matrix, [pc1_1, pc1_2, pc2_1, pc2_2] Notes ----- @@ -181,10 +188,9 @@ def calc_rotation_matrix(roll_ref: float, v3i_yang: float, vparity: int = 1) -> | pc1_1 pc2_1 | | pc1_2 pc2_2 | ---------------- - """ if vparity not in (1, -1): - raise ValueError(f'vparity should be 1 or -1. Input was: {vparity}') + raise ValueError(f"vparity should be 1 or -1. Input was: {vparity}") rel_angle = roll_ref - (vparity * v3i_yang) @@ -198,15 +204,26 @@ def calc_rotation_matrix(roll_ref: float, v3i_yang: float, vparity: int = 1) -> def compute_fiducial(wcslist, bounding_box=None): """ - For a celestial footprint this is the center. - For a spectral footprint, it is the beginning of the range. + Compute the 'fiducial point' of a list of WCS objects. This function assumes all WCSs have the same output coordinate frame. - """ + Parameters + ---------- + wcslist : list of `~gwcs.wcs.WCS` + List of WCS objects. + bounding_box : tuple, None + The bounding box of the output frame. + + Returns + ------- + fiducial : np.ndarray + The fiducial point. For a celestial footprint this is the center. + For a spectral footprint, it is the beginning of the range. + """ axes_types = wcslist[0].output_frame.axes_type - spatial_axes = np.array(axes_types) == 'SPATIAL' - spectral_axes = np.array(axes_types) == 'SPECTRAL' + spatial_axes = np.array(axes_types) == "SPATIAL" + spectral_axes = np.array(axes_types) == "SPECTRAL" footprints = np.hstack([w.footprint(bounding_box=bounding_box).T for w in wcslist]) spatial_footprint = footprints[spatial_axes] spectral_footprint = footprints[spectral_axes] @@ -219,11 +236,11 @@ def compute_fiducial(wcslist, bounding_box=None): y = np.cos(lat) * np.sin(lon) z = np.sin(lat) - x_mid = (np.max(x) + np.min(x)) / 2. - y_mid = (np.max(y) + np.min(y)) / 2. - z_mid = (np.max(z) + np.min(z)) / 2. + x_mid = (np.max(x) + np.min(x)) / 2.0 + y_mid = (np.max(y) + np.min(y)) / 2.0 + z_mid = (np.max(z) + np.min(z)) / 2.0 lon_fiducial = np.rad2deg(np.arctan2(y_mid, x_mid)) % 360.0 - lat_fiducial = np.rad2deg(np.arctan2(z_mid, np.sqrt(x_mid ** 2 + y_mid ** 2))) + lat_fiducial = np.rad2deg(np.arctan2(z_mid, np.sqrt(x_mid**2 + y_mid**2))) fiducial[spatial_axes] = lon_fiducial, lat_fiducial if spectral_footprint.any(): fiducial[spectral_axes] = spectral_footprint.min() @@ -232,14 +249,19 @@ def compute_fiducial(wcslist, bounding_box=None): def is_fits(input_img): """ - Returns - -------- - isFits: tuple - An ``(isfits, fitstype)`` tuple. The values of ``isfits`` and - ``fitstype`` are specified as: + Determine if the input is a FITS file. + + Parameters + ---------- + input_img : str, `~astropy.io.fits.HDUList` + The input image to be checked. - - ``isfits``: True|False - - ``fitstype``: if True, one of 'waiver', 'mef', 'simple'; if False, None + Returns + ------- + isfits : bool + True if the input is a FITS file, False otherwise. + fitstype : str + The type of FITS file, one of 'waiver', 'mef', 'simple'; None if not a FITS file. Notes ----- @@ -250,10 +272,9 @@ def is_fits(input_img): error upon opening, this routine will raise that exception for the calling routine/user to handle. """ - isfits = False fitstype = None - names = ['fits', 'fit', 'FITS', 'FIT'] + names = ["fits", "fit", "FITS", "FIT"] # determine if input is a fits file based on extension # Only check type of FITS file if filename ends in valid FITS string f = None @@ -269,7 +290,7 @@ def is_fits(input_img): if isfits: if not f: try: - f = fits.open(input_img, mode='readonly') + f = fits.open(input_img, mode="readonly") fileclose = True except Exception: if f is not None: @@ -279,12 +300,12 @@ def is_fits(input_img): if data0 is not None: try: if isinstance(f[1], fits.TableHDU): - fitstype = 'waiver' + fitstype = "waiver" except IndexError: - fitstype = 'simple' + fitstype = "simple" else: - fitstype = 'mef' + fitstype = "mef" if fileclose: f.close() @@ -297,8 +318,8 @@ def subarray_transform(input_model): Parameters ---------- - input_model : `~jwst.datamodels.JwstDataModel` - Data model. + input_model : JwstDataModel + The input data model. Returns ------- @@ -319,8 +340,7 @@ def subarray_transform(input_model): if ystart is not None and ystart != 1: tr_ystart = astmodels.Shift(ystart - 1) - if (isinstance(tr_xstart, astmodels.Identity) and - isinstance(tr_ystart, astmodels.Identity)): + if isinstance(tr_xstart, astmodels.Identity) and isinstance(tr_ystart, astmodels.Identity): # the case of a full frame observation return None else: @@ -328,18 +348,27 @@ def subarray_transform(input_model): return subarray2full -def not_implemented_mode(input_model, ref, slit_y_range=None): +def not_implemented_mode(input_model, ref, slit_y_range=None): # noqa: ARG001 """ - Return ``None`` if assign_wcs has not been implemented for a mode. + Send an error to the log and return None if assign_wcs has not been implemented for a mode. + + Parameters + ---------- + input_model : JwstDataModel + The input data model. + ref : dict + Mapping between reftype (keys) and reference file name (vals). + slit_y_range : tuple + The slit Y-range for Nirspec slits, relative to (0, 0) in the center. """ exp_type = input_model.meta.exposure.type - message = "WCS for EXP_TYPE of {0} is not implemented.".format(exp_type) + message = f"WCS for EXP_TYPE of {exp_type} is not implemented." log.critical(message) - return None def get_object_info(catalog_name=None): - """Return a list of SkyObjects from the direct image + """ + Return a list of SkyObjects from the direct image. The source_catalog step catalog items are read into a list of SkyObjects which can be referenced by catalog id. Only @@ -354,10 +383,6 @@ def get_object_info(catalog_name=None): ------- objects : list[jwst.transforms.models.SkyObject] A list of SkyObject tuples - - Notes - ----- - """ if isinstance(catalog_name, str): if len(catalog_name) == 0: @@ -365,10 +390,10 @@ def get_object_info(catalog_name=None): log.error(err_text) raise ValueError(err_text) try: - catalog = QTable.read(catalog_name, format='ascii.ecsv') + catalog = QTable.read(catalog_name, format="ascii.ecsv") except FileNotFoundError as e: - log.error("Could not find catalog file: {0}".format(e)) - raise FileNotFoundError("Could not find catalog: {0}".format(e)) + log.error(f"Could not find catalog file: {e}") + raise FileNotFoundError(f"Could not find catalog: {e}") from None elif isinstance(catalog_name, QTable): catalog = catalog_name else: @@ -384,13 +409,13 @@ def get_object_info(catalog_name=None): try: if not set(required_fields).issubset(set(catalog.colnames)): difference = set(required_fields).difference(set(catalog.colnames)) - err_text = "Missing required columns in source catalog: {0}".format(difference) + err_text = f"Missing required columns in source catalog: {difference}" log.error(err_text) raise KeyError(err_text) except AttributeError as e: - err_text = "Problem validating object catalog columns: {0}".format(e) + err_text = f"Problem validating object catalog columns: {e}" log.error(err_text) - raise AttributeError + raise AttributeError(err_text) from None # The columns are named sky_bbox_ll, sky_bbox_ul, sky_bbox_lr, # and sky_bbox_ur, each of which is a SkyCoord (i.e. RA & Dec & frame) at @@ -399,30 +424,35 @@ def get_object_info(catalog_name=None): # (hence, the four separate columns). for row in catalog: - objects.append(SkyObject(label=row['label'], - xcentroid=row['xcentroid'], - ycentroid=row['ycentroid'], - sky_centroid=row['sky_centroid'], - isophotal_abmag=row['isophotal_abmag'], - isophotal_abmag_err=row['isophotal_abmag_err'], - sky_bbox_ll=row['sky_bbox_ll'], - sky_bbox_lr=row['sky_bbox_lr'], - sky_bbox_ul=row['sky_bbox_ul'], - sky_bbox_ur=row['sky_bbox_ur'], - is_extended=row['is_extended'] - ) - ) + objects.append( + SkyObject( + label=row["label"], + xcentroid=row["xcentroid"], + ycentroid=row["ycentroid"], + sky_centroid=row["sky_centroid"], + isophotal_abmag=row["isophotal_abmag"], + isophotal_abmag_err=row["isophotal_abmag_err"], + sky_bbox_ll=row["sky_bbox_ll"], + sky_bbox_lr=row["sky_bbox_lr"], + sky_bbox_ul=row["sky_bbox_ul"], + sky_bbox_ur=row["sky_bbox_ur"], + is_extended=row["is_extended"], + ) + ) return objects -def create_grism_bbox(input_model, - reference_files=None, - mmag_extract=None, - extract_orders=None, - wfss_extract_half_height=None, - wavelength_range=None, - nbright=None): - """Create bounding boxes for each object in the catalog +def create_grism_bbox( + input_model, + reference_files=None, + mmag_extract=None, + extract_orders=None, + wfss_extract_half_height=None, + wavelength_range=None, + nbright=None, +): + """ + Create bounding boxes for each object in the catalog. The sky coordinates in the catalog image are first related to the grism image. They need to go through the WCS object @@ -431,10 +461,9 @@ def create_grism_bbox(input_model, location can then be sent through the trace polynomials to find the spectral location on the grism image for that wavelength and order. - Parameters ---------- - input_model : `jwst.datamodels.ImagingModel` + input_model : ImageModel Data model which holds the grism image reference_files : dict, optional Dictionary of reference file names. @@ -486,7 +515,6 @@ def create_grism_bbox(input_model, If ``wfss_extract_half_height`` is specified it is used to compute the extent in the cross-dispersion direction, which becomes ``2 * wfss_extract_half_height + 1``. ``wfss_extract_half_height`` can only be applied to point source objects. - """ instr_name = input_model.meta.instrument.name if instr_name == "NIRCAM": @@ -503,8 +531,8 @@ def create_grism_bbox(input_model, raise TypeError(message) else: # Get the list of extract_orders and lmin, lmax from the ``wavelengthrange`` reference file. - with WavelengthrangeModel(reference_files['wavelengthrange']) as f: - if 'WFSS' not in f.meta.exposure.type: + with WavelengthrangeModel(reference_files["wavelengthrange"]) as f: + if "WFSS" not in f.meta.exposure.type: err_text = "Wavelengthrange reference file not for WFSS" log.error(err_text) raise ValueError(err_text) @@ -516,9 +544,9 @@ def create_grism_bbox(input_model, wavelength_range = f.get_wfss_wavelength_range(filter_name, extract_orders) if mmag_extract is None: - mmag_extract = 999. # extract all objects, regardless of magnitude + mmag_extract = 999.0 # extract all objects, regardless of magnitude else: - log.info("Extracting objects < abmag = {0}".format(mmag_extract)) + log.info(f"Extracting objects < abmag = {mmag_extract}") if not isinstance(mmag_extract, (int, float)): raise TypeError(f"Expected mmag_extract to be a number, got {mmag_extract}") @@ -530,14 +558,19 @@ def create_grism_bbox(input_model, log.info(f"Getting objects from {input_model.meta.source_catalog}") - return _create_grism_bbox(input_model, mmag_extract, wfss_extract_half_height, wavelength_range, - nbright) - + return _create_grism_bbox( + input_model, mmag_extract, wfss_extract_half_height, wavelength_range, nbright + ) -def _create_grism_bbox(input_model, mmag_extract=None, wfss_extract_half_height=None, - wavelength_range=None, nbright=None): - log.debug(f'Extracting with wavelength_range {wavelength_range}') +def _create_grism_bbox( + input_model, + mmag_extract=None, + wfss_extract_half_height=None, + wavelength_range=None, + nbright=None, +): + log.debug(f"Extracting with wavelength_range {wavelength_range}") # this contains the pure information from the catalog with no translations skyobject_list = get_object_info(input_model.meta.source_catalog) @@ -545,141 +578,181 @@ def _create_grism_bbox(input_model, mmag_extract=None, wfss_extract_half_height= # here, image is in the imaging reference frame, before going through the # dispersion coefficients - sky_to_detector = input_model.meta.wcs.get_transform('world', 'detector') + sky_to_detector = input_model.meta.wcs.get_transform("world", "detector") sky_to_grism = input_model.meta.wcs.backward_transform grism_objects = [] # the return list of GrismObjects for obj in skyobject_list: - if obj.isophotal_abmag is not None: - if obj.isophotal_abmag < mmag_extract: - # could add logic to ignore object if too far off image, - - # save the image frame center of the object - # takes in ra, dec, wavelength, order but wave and order - # don't get used until the detector->grism_detector transform - xcenter, ycenter, _, _ = sky_to_detector(obj.sky_centroid.icrs.ra.value, - obj.sky_centroid.icrs.dec.value, - 1, 1) - - order_bounding = {} - waverange = {} - partial_order = {} - for order in wavelength_range: - # range_select = [(x[2], x[3]) for x in wavelengthrange if (x[0] == order and x[1] == filter_name)] - # The orders of the bounding box in the non-dispersed image - # drive the extraction extent. The location of the min and - # max wavelengths for each order are used to get the - # location of the +/- sides of the bounding box in the - # grism image - lmin, lmax = wavelength_range[order] - ra = np.array([obj.sky_bbox_ll.ra.value, obj.sky_bbox_lr.ra.value, - obj.sky_bbox_ul.ra.value, obj.sky_bbox_ur.ra.value]) - dec = np.array([obj.sky_bbox_ll.dec.value, obj.sky_bbox_lr.dec.value, - obj.sky_bbox_ul.dec.value, obj.sky_bbox_ur.dec.value]) - x1, y1, _, _, _ = sky_to_grism(ra, dec, [lmin] * 4, [order] * 4) - x2, y2, _, _, _ = sky_to_grism(ra, dec, [lmax] * 4, [order] * 4) - - xstack = np.hstack([x1, x2]) - ystack = np.hstack([y1, y2]) - - # Subarrays are only allowed in nircam tsgrism mode. The polynomial transforms - # only work with the full frame coordinates. The code here is called during extract_2d, - # and is creating bounding boxes which should be in the full frame coordinates, it just - # uses the input catalog and the magnitude to limit the objects that need bounding boxes. - - # Tsgrism is always supposed to have the source object at the same pixel, and that is - # hardcoded into the transforms. At least a while ago, the 2d extraction for tsgrism mode - # didn't call this bounding box code. So I think it's safe to leave the subarray - # subtraction out, i.e. do not subtract x/ystart. - - xmin = np.nanmin(xstack) - xmax = np.nanmax(xstack) - ymin = np.nanmin(ystack) - ymax = np.nanmax(ystack) - - if wfss_extract_half_height is not None and not obj.is_extended: - if input_model.meta.wcsinfo.dispersion_direction == 2: - ra_center, dec_center = obj.sky_centroid.ra.value, obj.sky_centroid.dec.value - center, _, _, _, _ = sky_to_grism(ra_center, dec_center, (lmin + lmax) / 2, order) - xmin = center - wfss_extract_half_height - xmax = center + wfss_extract_half_height - elif input_model.meta.wcsinfo.dispersion_direction == 1: - ra_center, dec_center = obj.sky_centroid.ra.value, obj.sky_centroid.dec.value - _, center, _, _, _ = sky_to_grism(ra_center, dec_center, (lmin + lmax) / 2, order) - ymin = center - wfss_extract_half_height - ymax = center + wfss_extract_half_height - else: - raise ValueError("Cannot determine dispersion direction.") - - # Convert floating-point corner values to whole pixel indexes - xmin = gwutils._toindex(xmin) - xmax = gwutils._toindex(xmax) - ymin = gwutils._toindex(ymin) - ymax = gwutils._toindex(ymax) - - # Don't add objects and orders that are entirely off the detector. - # "partial_order" marks objects that are near enough to the detector - # edge to have some spectrum on the detector. - # This is useful because the catalog often is created from a resampled direct - # image that is bigger than the detector FOV for a single grism exposure. - exclude = False - ispartial = False - - # Here we check to ensure that the extraction region `pts` - # has at least two pixels of width in the dispersion - # direction, and one in the cross-dispersed direction when - # placed into the subarray extent. - pts = np.array([[ymin, xmin], [ymax, xmax]]) - subarr_extent = np.array([[0, 0], - [input_model.meta.subarray.ysize - 1, - input_model.meta.subarray.xsize - 1]]) - - if input_model.meta.wcsinfo.dispersion_direction == 1: - # X-axis is dispersion direction - disp_col = 1 - xdisp_col = 0 - else: - # Y-axis is dispersion direction - disp_col = 0 - xdisp_col = 1 - - dispaxis_check = (pts[1, disp_col] - subarr_extent[0, disp_col] > 0) and \ - (subarr_extent[1, disp_col] - pts[0, disp_col] > 0) - xdispaxis_check = (pts[1, xdisp_col] - subarr_extent[0, xdisp_col] >= 0) and \ - (subarr_extent[1, xdisp_col] - pts[0, xdisp_col] >= 0) - - contained = dispaxis_check and xdispaxis_check - - inidx = np.all(np.logical_and(subarr_extent[0] <= pts, pts <= subarr_extent[1]), axis=1) - - if not contained: - exclude = True - log.info("Excluding off-image object: {}, order {}".format(obj.label, order)) - elif contained >= 1: - outbox = pts[np.logical_not(inidx)] - if len(outbox) > 0: - ispartial = True - log.info("Partial order on detector for obj: {} order: {}".format(obj.label, order)) - - if not exclude: - order_bounding[order] = ((ymin, ymax), (xmin, xmax)) - waverange[order] = ((lmin, lmax)) - partial_order[order] = ispartial - - if len(order_bounding) > 0: - grism_objects.append(GrismObject(sid=obj.label, - order_bounding=order_bounding, - sky_centroid=obj.sky_centroid, - partial_order=partial_order, - waverange=waverange, - sky_bbox_ll=obj.sky_bbox_ll, - sky_bbox_lr=obj.sky_bbox_lr, - sky_bbox_ul=obj.sky_bbox_ul, - sky_bbox_ur=obj.sky_bbox_ur, - xcentroid=xcenter, - ycentroid=ycenter, - is_extended=obj.is_extended, - isophotal_abmag=obj.isophotal_abmag)) + if obj.isophotal_abmag is None: + continue + if obj.isophotal_abmag >= mmag_extract: + continue + # could add logic to ignore object if too far off image, + + # save the image frame center of the object + # takes in ra, dec, wavelength, order but wave and order + # don't get used until the detector->grism_detector transform + xcenter, ycenter, _, _ = sky_to_detector( + obj.sky_centroid.icrs.ra.value, obj.sky_centroid.icrs.dec.value, 1, 1 + ) + + order_bounding = {} + waverange = {} + partial_order = {} + for order in wavelength_range: + # range_select = [(x[2], x[3]) for x in wavelengthrange \ + # if (x[0] == order and x[1] == filter_name)] + # The orders of the bounding box in the non-dispersed image + # drive the extraction extent. The location of the min and + # max wavelengths for each order are used to get the + # location of the +/- sides of the bounding box in the + # grism image + lmin, lmax = wavelength_range[order] + ra = np.array( + [ + obj.sky_bbox_ll.ra.value, + obj.sky_bbox_lr.ra.value, + obj.sky_bbox_ul.ra.value, + obj.sky_bbox_ur.ra.value, + ] + ) + dec = np.array( + [ + obj.sky_bbox_ll.dec.value, + obj.sky_bbox_lr.dec.value, + obj.sky_bbox_ul.dec.value, + obj.sky_bbox_ur.dec.value, + ] + ) + x1, y1, _, _, _ = sky_to_grism(ra, dec, [lmin] * 4, [order] * 4) + x2, y2, _, _, _ = sky_to_grism(ra, dec, [lmax] * 4, [order] * 4) + + xstack = np.hstack([x1, x2]) + ystack = np.hstack([y1, y2]) + + # Subarrays are only allowed in nircam tsgrism mode. The polynomial transforms + # only work with the full frame coordinates. + # The code here is called during extract_2d, + # and is creating bounding boxes which should be in the full frame coordinates, + # it just uses the input catalog and the magnitude + # to limit the objects that need bounding boxes. + + # Tsgrism is always supposed to have the source object at the same pixel, and that is + # hardcoded into the transforms. + # At least a while ago, the 2d extraction for tsgrism mode + # didn't call this bounding box code. So I think it's safe to leave the subarray + # subtraction out, i.e. do not subtract x/ystart. + + xmin = np.nanmin(xstack) + xmax = np.nanmax(xstack) + ymin = np.nanmin(ystack) + ymax = np.nanmax(ystack) + + if wfss_extract_half_height is not None and not obj.is_extended: + if input_model.meta.wcsinfo.dispersion_direction == 2: + ra_center, dec_center = ( + obj.sky_centroid.ra.value, + obj.sky_centroid.dec.value, + ) + center, _, _, _, _ = sky_to_grism( + ra_center, dec_center, (lmin + lmax) / 2, order + ) + xmin = center - wfss_extract_half_height + xmax = center + wfss_extract_half_height + elif input_model.meta.wcsinfo.dispersion_direction == 1: + ra_center, dec_center = ( + obj.sky_centroid.ra.value, + obj.sky_centroid.dec.value, + ) + _, center, _, _, _ = sky_to_grism( + ra_center, dec_center, (lmin + lmax) / 2, order + ) + ymin = center - wfss_extract_half_height + ymax = center + wfss_extract_half_height + else: + raise ValueError("Cannot determine dispersion direction.") + + # Convert floating-point corner values to whole pixel indexes + xmin = gwutils._toindex(xmin) # noqa: SLF001 + xmax = gwutils._toindex(xmax) # noqa: SLF001 + ymin = gwutils._toindex(ymin) # noqa: SLF001 + ymax = gwutils._toindex(ymax) # noqa: SLF001 + + # Don't add objects and orders that are entirely off the detector. + # "partial_order" marks objects that are near enough to the detector + # edge to have some spectrum on the detector. + # This is useful because the catalog often is created from a resampled direct + # image that is bigger than the detector FOV for a single grism exposure. + exclude = False + ispartial = False + + # Here we check to ensure that the extraction region `pts` + # has at least two pixels of width in the dispersion + # direction, and one in the cross-dispersed direction when + # placed into the subarray extent. + pts = np.array([[ymin, xmin], [ymax, xmax]]) + subarr_extent = np.array( + [ + [0, 0], + [ + input_model.meta.subarray.ysize - 1, + input_model.meta.subarray.xsize - 1, + ], + ] + ) + + if input_model.meta.wcsinfo.dispersion_direction == 1: + # X-axis is dispersion direction + disp_col = 1 + xdisp_col = 0 + else: + # Y-axis is dispersion direction + disp_col = 0 + xdisp_col = 1 + + dispaxis_check = (pts[1, disp_col] - subarr_extent[0, disp_col] > 0) and ( + subarr_extent[1, disp_col] - pts[0, disp_col] > 0 + ) + xdispaxis_check = (pts[1, xdisp_col] - subarr_extent[0, xdisp_col] >= 0) and ( + subarr_extent[1, xdisp_col] - pts[0, xdisp_col] >= 0 + ) + + contained = dispaxis_check and xdispaxis_check + + inidx = np.all(np.logical_and(subarr_extent[0] <= pts, pts <= subarr_extent[1]), axis=1) + + if not contained: + exclude = True + log.info(f"Excluding off-image object: {obj.label}, order {order}") + elif contained >= 1: + outbox = pts[np.logical_not(inidx)] + if len(outbox) > 0: + ispartial = True + log.info(f"Partial order on detector for obj: {obj.label} order: {order}") + + if not exclude: + order_bounding[order] = ((ymin, ymax), (xmin, xmax)) + waverange[order] = (lmin, lmax) + partial_order[order] = ispartial + + if len(order_bounding) > 0: + grism_objects.append( + GrismObject( + sid=obj.label, + order_bounding=order_bounding, + sky_centroid=obj.sky_centroid, + partial_order=partial_order, + waverange=waverange, + sky_bbox_ll=obj.sky_bbox_ll, + sky_bbox_lr=obj.sky_bbox_lr, + sky_bbox_ul=obj.sky_bbox_ul, + sky_bbox_ur=obj.sky_bbox_ur, + xcentroid=xcenter, + ycentroid=ycenter, + is_extended=obj.is_extended, + isophotal_abmag=obj.isophotal_abmag, + ) + ) # At this point we have a list of grism objects limited to # isophotal_abmag < mmag_extract. We now need to further restrict @@ -714,15 +787,21 @@ def get_num_msa_open_shutters(shutter_state): shutter_state : str ``Slit.shutter_state`` attribute - a combination of ``1`` - open shutter, ``0`` - closed shutter, ``x`` - main shutter. + + Returns + ------- + num : int + The number of open shutters in the slitlet. """ - num = shutter_state.count('1') - if 'x' in shutter_state: + num = shutter_state.count("1") + if "x" in shutter_state: num += 1 return num def transform_bbox_from_shape(shape, order="C"): - """Create a bounding box from the shape of the data. + """ + Create a bounding box from the shape of the data. This is appropriate to attached to a transform. @@ -739,16 +818,17 @@ def transform_bbox_from_shape(shape, order="C"): Bounding box in y, x order if order is "C" (default) Boundsing box in x, y order if order is "F" """ - bbox = ((-0.5, shape[-2] - 0.5), - (-0.5, shape[-1] - 0.5)) + bbox = ((-0.5, shape[-2] - 0.5), (-0.5, shape[-1] - 0.5)) return bbox if order == "C" else bbox[::-1] def wcs_bbox_from_shape(shape): - """Create a bounding box from the shape of the data. + """ + Create a bounding box from the shape of the data. This is appropriate to attach to a wcs object + Parameters ---------- shape : tuple @@ -759,13 +839,13 @@ def wcs_bbox_from_shape(shape): bbox : tuple Bounding box in x, y order. """ - bbox = ((-0.5, shape[-1] - 0.5), - (-0.5, shape[-2] - 0.5)) + bbox = ((-0.5, shape[-1] - 0.5), (-0.5, shape[-2] - 0.5)) return bbox -def bounding_box_from_subarray(input_model, order='C'): - """Create a bounding box from the subarray size. +def bounding_box_from_subarray(input_model, order="C"): + """ + Create a bounding box from the subarray size. Note: The bounding_box assumes full frame coordinates. It is set to ((ystart, ystart + xsize), (xstart, xstart + xsize)). @@ -773,8 +853,8 @@ def bounding_box_from_subarray(input_model, order='C'): Parameters ---------- - input_model : `~jwst.datamodels.JwstDataModel` - The data model. + input_model : JwstDataModel + The input data model. order : str The order of the array. Either "C" or "F". @@ -795,18 +875,58 @@ def bounding_box_from_subarray(input_model, order='C'): bb_yend = input_model.meta.subarray.ysize - 0.5 bbox = ((bb_ystart, bb_yend), (bb_xstart, bb_xend)) - return bbox if order == 'C' else bbox[::-1] + return bbox if order == "C" else bbox[::-1] def update_s_region_imaging(model): - """ - Update the ``S_REGION`` keyword using ``WCS.footprint``. - """ + """Update the ``S_REGION`` keyword using ``WCS.footprint``.""" s_region = compute_s_region_imaging(model.meta.wcs, shape=model.data.shape, center=False) if s_region is not None: model.meta.wcsinfo.s_region = s_region +def update_s_region_lrs(model, reference_files): + """ + Update ``S_REGION`` using V2,V3 of the slit corners from reference file. + + s_region for model is updated in place. + + Parameters + ---------- + model : DataModel + Input model + reference_files : list + List of reference files for assign_wcs. + """ + refmodel = MiriLRSSpecwcsModel(reference_files["specwcs"]) + + v2vert1 = refmodel.meta.v2_vert1 + v2vert2 = refmodel.meta.v2_vert2 + v2vert3 = refmodel.meta.v2_vert3 + v2vert4 = refmodel.meta.v2_vert4 + + v3vert1 = refmodel.meta.v3_vert1 + v3vert2 = refmodel.meta.v3_vert2 + v3vert3 = refmodel.meta.v3_vert3 + v3vert4 = refmodel.meta.v3_vert4 + + refmodel.close() + v2 = [v2vert1, v2vert2, v2vert3, v2vert4] + v3 = [v3vert1, v3vert2, v3vert3, v3vert4] + + if any(elem is None for elem in v2) or any(elem is None for elem in v3): + log.info("The V2,V3 coordinates of the MIRI LRS-Fixed slit contains NaN values.") + log.info("The s_region will not be updated") + + lam = 7.0 # wavelength does not matter for s region so jwst assign a value in range of LRS + s = model.meta.wcs.transform("v2v3", "world", v2, v3, lam) + a = s[0] + b = s[1] + footprint = np.array([[a[0], b[0]], [a[1], b[1]], [a[2], b[2]], [a[3], b[3]]]) + + update_s_region_keyword(model, footprint) + + def compute_footprint_spectral(model): """ Determine spatial footprint for spectral observations using the instrument model. @@ -815,6 +935,13 @@ def compute_footprint_spectral(model): ---------- model : `~jwst.datamodels.IFUImageModel` The output of assign_wcs. + + Returns + ------- + footprint : ndarray + The spatial footprint of the observation. + spectral_region : tuple + The wavelength range for the observation. """ swcs = model.meta.wcs bbox = swcs.bounding_box @@ -835,39 +962,49 @@ def compute_footprint_spectral(model): min_ra = min_ra + 360.0 if max_ra >= 360.0: max_ra = max_ra - 360.0 - footprint = np.array([[min_ra, np.nanmin(dec)], - [max_ra, np.nanmin(dec)], - [max_ra, np.nanmax(dec)], - [min_ra, np.nanmax(dec)]]) + footprint = np.array( + [ + [min_ra, np.nanmin(dec)], + [max_ra, np.nanmin(dec)], + [max_ra, np.nanmax(dec)], + [min_ra, np.nanmax(dec)], + ] + ) lam_min = np.nanmin(lam) lam_max = np.nanmax(lam) return footprint, (lam_min, lam_max) def update_s_region_spectral(model): - """ Update the S_REGION keyword. - """ + """Update the S_REGION keyword.""" footprint, spectral_region = compute_footprint_spectral(model) update_s_region_keyword(model, footprint) model.meta.wcsinfo.spectral_region = spectral_region def compute_footprint_nrs_slit(slit): - """ Compute the footprint of a Nirspec slit using the instrument model. + """ + Compute the footprint of a Nirspec slit using the instrument model. Parameters ---------- slit : `~jwst.datamodels.SlitModel` + The slit model. + + Returns + ------- + footprint : ndarray + The spatial footprint + spectral_region : tuple + The wavelength range for the observation. """ slit2world = slit.meta.wcs.get_transform("slit_frame", "world") # Define the corners of a virtual slit. The center of the slit is (0, 0). - virtual_corners_x = [-.5, -.5, .5, .5] + virtual_corners_x = [-0.5, -0.5, 0.5, 0.5] virtual_corners_y = [slit.slit_ymin, slit.slit_ymax, slit.slit_ymax, slit.slit_ymin] # Use a default wavelength or 2 microns as input to the transform. input_lam = [2e-6] * 4 - ra, dec, lam = slit2world(virtual_corners_x, - virtual_corners_y, - input_lam) + ra, dec, lam = slit2world(virtual_corners_x, virtual_corners_y, input_lam) footprint = np.array([ra, dec]).T lam_min = np.nanmin(lam) lam_max = np.nanmax(lam) @@ -881,8 +1018,7 @@ def update_s_region_nrs_slit(slit): def update_s_region_keyword(model, footprint): - """ Update the S_REGION keyword. - """ + """Update the S_REGION keyword.""" s_region = compute_s_region_keyword(footprint) if s_region is not None: model.meta.wcsinfo.s_region = s_region @@ -901,7 +1037,7 @@ def compute_footprint_nrs_ifu(dmodel, mod): Parameters ---------- - output_model : `~jwst.datamodels.IFUImageModel` + dmodel : `~jwst.datamodels.IFUImageModel` The output of assign_wcs. mod : module The imported ``nirspec`` module. @@ -1002,8 +1138,13 @@ def velocity_correction(velosys): ---------- velosys : float Radial velocity wrt Barycenter [m / s]. + + Returns + ------- + model : `astropy.modeling.Model` + The velocity correction model. """ - correction = (1 / (1 + velosys / c.value)) + correction = 1 / (1 + velosys / c.value) model = astmodels.Identity(1) * astmodels.Const1D(correction, name="velocity_correction") model.inverse = astmodels.Identity(1) / astmodels.Const1D(correction, name="inv_vel_correction") @@ -1011,7 +1152,8 @@ def velocity_correction(velosys): def wrap_ra(ravalues): - """Test for 0/360 wrapping in ra values. + """ + Test for 0/360 wrapping in ra values. If exists it makes it difficult to determine ra range of a region on the sky. This problem is solved by putting them all @@ -1020,13 +1162,13 @@ def wrap_ra(ravalues): Parameters ---------- ravalues : numpy.ndarray - input RA values + The input RA values Returns - ------ - a numpy array of ra values all on "same side" of 0/360 border + ------- + np.ndarray + A numpy array of ra values all on "same side" of 0/360 border """ - ravalues_array = np.array(ravalues) index_good = np.where(np.isfinite(ravalues_array)) ravalues_wrap = ravalues_array[index_good].copy() @@ -1064,25 +1206,34 @@ def in_ifu_slice(slice_wcs, ra, dec, lam): Returns ------- x, y : float, ndarray - x, y locations within the slice. + The x, y locations within the slice. """ - slicer2world = slice_wcs.get_transform('slicer', 'world') + slicer2world = slice_wcs.get_transform("slicer", "world") slx, sly, sllam = slicer2world.inverse(ra, dec, lam) # Compute the slice X coordinate using the center of the slit. - SLX, _, _ = slice_wcs.get_transform('slit_frame', 'slicer')(0, 0, 2e-6) - onslice_ind = np.isclose(slx, SLX, atol=5e-4) + slx_center, _, _ = slice_wcs.get_transform("slit_frame", "slicer")(0, 0, 2e-6) + onslice_ind = np.isclose(slx, slx_center, atol=5e-4) return onslice_ind -def update_fits_wcsinfo(datamodel, max_pix_error=0.01, degree=None, - max_inv_pix_error=0.01, inv_degree=None, - npoints=12, crpix=None, projection='TAN', - imwcs=None, **kwargs): +def update_fits_wcsinfo( + datamodel, + max_pix_error=0.01, + degree=None, + max_inv_pix_error=0.01, + inv_degree=None, + npoints=12, + crpix=None, + projection="TAN", + imwcs=None, + **kwargs, +): """ - Update ``datamodel.meta.wcsinfo`` based on a FITS WCS + SIP approximation - of a GWCS object. By default, this function will approximate + Update ``datamodel.meta.wcsinfo`` based on a FITS WCS + SIP approximation of a GWCS object. + + By default, this function will approximate the datamodel's GWCS object stored in ``datamodel.meta.wcs`` but it can also approximate a user-supplied GWCS object when provided via the ``imwcs`` parameter. @@ -1102,7 +1253,6 @@ def update_fits_wcsinfo(datamodel, max_pix_error=0.01, degree=None, This function modifies input data model's ``datamodel.meta.wcsinfo`` members. - Parameters ---------- datamodel : `ImageModel` @@ -1111,13 +1261,11 @@ def update_fits_wcsinfo(datamodel, max_pix_error=0.01, degree=None, is used to compute FITS WCS + SIP approximation. When ``imwcs`` is not `None` then computed FITS WCS will be an approximation of the WCS provided through the ``imwcs`` parameter. - max_pix_error : float, optional Maximum allowed error over the domain of the pixel array. This error is the equivalent pixel error that corresponds to the maximum error in the output coordinate resulting from the fit based on a nominal plate scale. - degree : int, iterable, None, optional Degree of the SIP polynomial. Default value `None` indicates that all allowed degree values (``[1...6]``) will be considered and @@ -1129,12 +1277,10 @@ def update_fits_wcsinfo(datamodel, max_pix_error=0.01, degree=None, Finally, ``degree`` can be an integer indicating the exact SIP degree to be fit to the WCS transformation. In this case ``max_pixel_error`` is ignored. - max_inv_pix_error : float, None, optional Maximum allowed inverse error over the domain of the pixel array in pixel units. With the default value of `None` no inverse is generated. - inv_degree : int, iterable, None, optional Degree of the SIP polynomial. Default value `None` indicates that all allowed degree values (``[1...6]``) will be considered and @@ -1146,11 +1292,9 @@ def update_fits_wcsinfo(datamodel, max_pix_error=0.01, degree=None, Finally, ``degree`` can be an integer indicating the exact SIP degree to be fit to the WCS transformation. In this case ``max_inv_pixel_error`` is ignored. - npoints : int, optional The number of points in each dimension to sample the bounding box for use in the SIP fit. Minimum number of points is 3. - crpix : list of float, None, optional Coordinates (1-based) of the reference point for the new FITS WCS. When not provided, i.e., when set to `None` (default) the reference @@ -1158,7 +1302,6 @@ def update_fits_wcsinfo(datamodel, max_pix_error=0.01, degree=None, ``wcsinfo`` does not contain ``crpix`` information, then the reference pixel will be chosen near the center of the bounding box for axes corresponding to the celestial frame. - projection : str, `~astropy.modeling.projections.Pix2SkyProjection`, optional Projection to be used for the created FITS WCS. It can be specified as a string of three characters specifying a FITS projection code @@ -1171,7 +1314,6 @@ def update_fits_wcsinfo(datamodel, max_pix_error=0.01, degree=None, reference_api.html#module-astropy.modeling.projections>`_ projection models inherited from :py:class:`~astropy.modeling.projections.Pix2SkyProjection`. - imwcs : `gwcs.WCS`, None, optional Imaging GWCS object for WFSS mode whose FITS WCS approximation should be computed and stored in the ``datamodel.meta.wcsinfo`` field. @@ -1184,19 +1326,21 @@ def update_fits_wcsinfo(datamodel, max_pix_error=0.01, degree=None, WCS from ``datamodel.meta.wcs`` will result in the GWCS and FITS WCS descriptions to diverge. - Other Parameters - ---------------- - bounding_box : tuple, None, optional - A pair of tuples, each consisting of two numbers - Represents the range of pixel values in both dimensions - ((xmin, xmax), (ymin, ymax)) - - verbose : bool, optional - Print progress of fits. + **kwargs : dict, optional + Additional parameters to be passed to + :py:meth:`~gwcs.wcs.WCS.to_fits_sip`. + These may include: + bounding_box : tuple, None, optional + A pair of tuples, each consisting of two numbers + Represents the range of pixel values in both dimensions + ((xmin, xmax), (ymin, ymax)) + verbose : bool, optional + Print progress of fits. Returns ------- - FITS header with all SIP WCS keywords + `~astropy.io.fits.Header` + FITS header with all SIP WCS keywords Raises ------ @@ -1212,7 +1356,6 @@ def update_fits_wcsinfo(datamodel, max_pix_error=0.01, degree=None, to floating point problems that arise with high powers. For more details, see :py:meth:`~gwcs.wcs.WCS.to_fits_sip`. - """ if crpix is None: crpix = [datamodel.meta.wcsinfo.crpix1, datamodel.meta.wcsinfo.crpix2] @@ -1224,9 +1367,6 @@ def update_fits_wcsinfo(datamodel, max_pix_error=0.01, degree=None, if imwcs is None: imwcs = datamodel.meta.wcs - # make a copy of kwargs: - kwargs = {k: v for k, v in kwargs.items()} - # limit default 'degree' ranges to _MAX_SIP_DEGREE: if degree is None: degree = range(1, _MAX_SIP_DEGREE) @@ -1241,22 +1381,33 @@ def update_fits_wcsinfo(datamodel, max_pix_error=0.01, degree=None, npoints=npoints, crpix=crpix, projection=projection, - **kwargs + **kwargs, ) # update meta.wcsinfo with FITS keywords except for naxis* - del hdr['naxis*'] + del hdr["naxis*"] # maintain convention of lowercase keys hdr_dict = {k.lower(): v for k, v in hdr.items()} # delete naxis, cdelt, pc from wcsinfo - rm_keys = ['naxis', 'cdelt1', 'cdelt2', - 'pc1_1', 'pc1_2', 'pc2_1', 'pc2_2', - 'a_order', 'b_order', 'ap_order', 'bp_order'] - - rm_keys.extend(f"{s}_{i}_{j}" for i in range(10) for j in range(10) - for s in ['a', 'b', 'ap', 'bp']) + rm_keys = [ + "naxis", + "cdelt1", + "cdelt2", + "pc1_1", + "pc1_2", + "pc2_1", + "pc2_2", + "a_order", + "b_order", + "ap_order", + "bp_order", + ] + + rm_keys.extend( + f"{s}_{i}_{j}" for i in range(10) for j in range(10) for s in ["a", "b", "ap", "bp"] + ) for key in rm_keys: if key in datamodel.meta.wcsinfo.instance: @@ -1269,11 +1420,11 @@ def update_fits_wcsinfo(datamodel, max_pix_error=0.01, degree=None, def wfss_imaging_wcs(wfss_model, imaging, bbox=None, **kwargs): - """ Add a FITS WCS approximation for imaging mode to WFSS headers. + """ + Add a FITS WCS approximation for imaging mode to WFSS headers. Parameters ---------- - wfss_model : `~ImageModel` Input WFSS model (NRC or NIS). imaging : func, callable @@ -1281,7 +1432,6 @@ def wfss_imaging_wcs(wfss_model, imaging, bbox=None, **kwargs): bbox : tuple or None The bounding box over which to approximate the distortion solution. Typically this is based on the shape of the direct image. - """ xstart = wfss_model.meta.subarray.xstart ystart = wfss_model.meta.subarray.ystart @@ -1295,25 +1445,30 @@ def wfss_imaging_wcs(wfss_model, imaging, bbox=None, **kwargs): else: imwcs.bounding_box = wcs_bbox_from_shape(wfss_model.data.shape) - _ = update_fits_wcsinfo(wfss_model, projection='TAN', imwcs=imwcs, bounding_box=None, **kwargs) + _ = update_fits_wcsinfo(wfss_model, projection="TAN", imwcs=imwcs, bounding_box=None, **kwargs) def get_wcs_reference_files(datamodel): - """Retrieve names of WCS reference files for NIS_WFSS and NRC_WFSS modes. + """ + Retrieve names of WCS reference files for NIS_WFSS and NRC_WFSS modes. Parameters ---------- - - datamodel : `~ImageModel` + datamodel : ImageModel Input WFSS file (NRC or NIS). + Returns + ------- + dict + Mapping between reftype (keys) and reference file name (vals). """ from jwst.assign_wcs import AssignWcsStep + refs = {} step = AssignWcsStep() for reftype in AssignWcsStep.reference_file_types: val = step.get_reference_file(datamodel, reftype) - if val.strip() == 'N/A': + if val.strip() == "N/A": refs[reftype] = None else: refs[reftype] = val From 673ad54f7be39afe93af44d40d4ff314f3a7dd39 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Fri, 7 Mar 2025 13:07:19 -0500 Subject: [PATCH 2/3] Undo deprecation of niriss_bounding_box - separate issue --- jwst/assign_wcs/niriss.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/jwst/assign_wcs/niriss.py b/jwst/assign_wcs/niriss.py index 157e1c2df0..adae6eb14a 100644 --- a/jwst/assign_wcs/niriss.py +++ b/jwst/assign_wcs/niriss.py @@ -1,5 +1,4 @@ import logging -import warnings import asdf from astropy import coordinates as coord @@ -120,10 +119,6 @@ def niriss_bounding_box(input_model): """ Create a bounding box for the NIRISS model. - .. deprecated:: 1.17.2 - :py:func:`niriss_bounding_box` has been deprecated and will be removed - in a future release. - Parameters ---------- input_model : JwstDataModel @@ -134,13 +129,6 @@ def niriss_bounding_box(input_model): CompoundBoundingBox The bounding box for the NIRISS model. """ - warnings.warn( - "'niriss_bounding_bo()' has been deprecated since 1.17.2 and " - "will be removed in a future release. ", - DeprecationWarning, - stacklevel=2, - ) - bbox = {(order,): _niriss_order_bounding_box(input_model, order) for order in [1, 2, 3]} model = input_model.meta.wcs.forward_transform return CompoundBoundingBox.validate( From 7a77c6339c0b1f684a09ef3366eb74feaee7b090 Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Fri, 7 Mar 2025 15:57:14 -0500 Subject: [PATCH 3/3] Fix doc build failure for update_fits_wcsinfo --- jwst/assign_wcs/util.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/jwst/assign_wcs/util.py b/jwst/assign_wcs/util.py index c43b3c785e..9cc9834824 100644 --- a/jwst/assign_wcs/util.py +++ b/jwst/assign_wcs/util.py @@ -1327,14 +1327,14 @@ def update_fits_wcsinfo( FITS WCS descriptions to diverge. **kwargs : dict, optional - Additional parameters to be passed to - :py:meth:`~gwcs.wcs.WCS.to_fits_sip`. + Additional parameters to be passed to :py:meth:`~gwcs.wcs.WCS.to_fits_sip`. These may include: - bounding_box : tuple, None, optional + + * bounding_box : tuple, None, optional A pair of tuples, each consisting of two numbers Represents the range of pixel values in both dimensions ((xmin, xmax), (ymin, ymax)) - verbose : bool, optional + * verbose : bool, optional Print progress of fits. Returns