diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7c94b3010e..f5faf71d3d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -56,7 +56,6 @@ repos: jwst/persistence/.* | jwst/photom/.* | jwst/refpix/.* | - jwst/resample/.* | jwst/reset/.* | jwst/residual_fringe/.* | jwst/rscd/.* | diff --git a/.ruff.toml b/.ruff.toml index 80b37194ba..89a29ca8bf 100644 --- a/.ruff.toml +++ b/.ruff.toml @@ -38,7 +38,6 @@ exclude = [ "jwst/persistence/**.py", "jwst/photom/**.py", "jwst/refpix/**.py", - "jwst/resample/**.py", "jwst/reset/**.py", "jwst/residual_fringe/**.py", "jwst/rscd/**.py", @@ -137,7 +136,6 @@ ignore-fully-untyped = true # Turn off annotation checking for fully untyped co "jwst/persistence/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"] "jwst/photom/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"] "jwst/refpix/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"] -"jwst/resample/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"] "jwst/reset/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"] "jwst/residual_fringe/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"] "jwst/rscd/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"] diff --git a/jwst/resample/__init__.py b/jwst/resample/__init__.py index 15e44b81a8..061621b8c7 100644 --- a/jwst/resample/__init__.py +++ b/jwst/resample/__init__.py @@ -1,4 +1,6 @@ +"""Apply resampling to JWST data.""" + from .resample_step import ResampleStep from .resample_spec_step import ResampleSpecStep -__all__ = ['ResampleStep', 'ResampleSpecStep'] +__all__ = ["ResampleStep", "ResampleSpecStep"] diff --git a/jwst/resample/resample.py b/jwst/resample/resample.py index 38cdf6ceab..720718bc7b 100644 --- a/jwst/resample/resample.py +++ b/jwst/resample/resample.py @@ -1,6 +1,6 @@ import logging import json -import os +from pathlib import Path import re import numpy as np @@ -32,12 +32,12 @@ ] _SUPPORTED_CUSTOM_WCS_PARS = [ - 'pixel_scale_ratio', - 'pixel_scale', - 'output_shape', - 'crpix', - 'crval', - 'rotation', + "pixel_scale_ratio", + "pixel_scale", + "output_shape", + "crpix", + "crval", + "rotation", ] log = logging.getLogger(__name__) @@ -45,14 +45,30 @@ class ResampleImage(Resample): + """Resample imaging data.""" + dq_flag_name_map = pixel - def __init__(self, input_models, pixfrac=1.0, kernel="square", - fillval="NAN", weight_type="ivm", good_bits=0, - blendheaders=True, output_wcs=None, wcs_pars=None, - output=None, enable_ctx=True, enable_var=True, - compute_err=None, asn_id=None): + def __init__( + self, + input_models, + pixfrac=1.0, + kernel="square", + fillval="NAN", + weight_type="ivm", + good_bits=0, + blendheaders=True, + output_wcs=None, + wcs_pars=None, + output=None, + enable_ctx=True, + enable_var=True, + compute_err=None, + asn_id=None, + ): """ + Initialize the ResampleImage object. + Parameters ---------- input_models : ModelLibrary @@ -254,25 +270,24 @@ def __init__(self, input_models, pixfrac=1.0, kernel="square", asn_id : str, None, optional The association id. The id is what appears in the :ref:`asn-jwst-naming`. - """ self.input_models = input_models self.output_jwst_model = None self.output_dir = None self.output_filename = output - if output is not None and '.fits' not in str(output): + if output is not None and ".fits" not in str(output): self.output_dir = output self.output_filename = None - self.intermediate_suffix = 'outlier_i2d' + self.intermediate_suffix = "outlier_i2d" self.blendheaders = blendheaders if blendheaders: self._blender = ModelBlender( blend_ignore_attrs=[ - 'meta.photometry.pixelarea_steradians', - 'meta.photometry.pixelarea_arcsecsq', - 'meta.filename', + "meta.photometry.pixelarea_steradians", + "meta.photometry.pixelarea_arcsecsq", + "meta.filename", ] ) @@ -285,10 +300,7 @@ def __init__(self, input_models, pixfrac=1.0, kernel="square", unsup = [] unsup = set(wcs_pars.keys()).difference(_SUPPORTED_CUSTOM_WCS_PARS) if unsup: - raise KeyError( - "Unsupported custom WCS parameters: " - f"{','.join(map(repr, unsup))}." - ) + raise KeyError(f"Unsupported custom WCS parameters: {','.join(map(repr, unsup))}.") if output_wcs is None: # determine output WCS: @@ -313,13 +325,10 @@ def __init__(self, input_models, pixfrac=1.0, kernel="square", else: if wcs_pars: - log.warning( - "Ignoring 'wcs_pars' since 'output_wcs' is not None." - ) + log.warning("Ignoring 'wcs_pars' since 'output_wcs' is not None.") if output_wcs["wcs"].array_shape is None: raise ValueError( - "Custom WCS objects must have the 'array_shape' " - "attribute set (defined)." + "Custom WCS objects must have the 'array_shape' attribute set (defined)." ) super().__init__( @@ -336,27 +345,43 @@ def __init__(self, input_models, pixfrac=1.0, kernel="square", ) def input_model_to_dict(self, model, weight_type, enable_var, compute_err): - """ Converts a data model to a dictionary of keywords and values - expected by `stcal.resample`. Input parameters are the same as used - when initializing `ResampleImage`. + """ + Convert a data model to a dictionary of keywords and values expected by `stcal.resample`. - .. note:: Subclasses can override this method to add additional fields - to the dictionary as needed. + Parameters + ---------- + model : DataModel + A JWST data model. + weight_type : str + The weighting type for adding models' data. + enable_var : bool + Indicates whether to resample variance arrays. + compute_err : str + The method to compute the output model's error array. Returns ------- - model_dict : dict - + dict + A dictionary of keywords and values expected by `stcal.resample`. """ return input_jwst_model_to_dict( - model=model, - weight_type=weight_type, - enable_var=enable_var, - compute_err=compute_err + model=model, weight_type=weight_type, enable_var=enable_var, compute_err=compute_err ) def create_output_jwst_model(self, ref_input_model=None): - """ Create a new blank model and update it's meta with info from ``ref_input_model``. """ + """ + Create a new blank model and update its meta with info from ``ref_input_model``. + + Parameters + ---------- + ref_input_model : `~jwst.datamodels.JwstDataModel`, optional + The reference input model from which to copy meta data. + + Returns + ------- + ImageModel + A new blank model with updated meta data. + """ output_model = datamodels.ImageModel(None) # tuple(self.output_wcs.array_shape)) # update meta data and wcs @@ -366,6 +391,16 @@ def create_output_jwst_model(self, ref_input_model=None): return output_model def update_output_model(self, model, info_dict): + """ + Add meta information to the output model. + + Parameters + ---------- + model : ImageModel + The output model to be updated. + info_dict : dict + A dictionary containing information about the resampling process. + """ model.data = info_dict["data"] model.wht = info_dict["wht"] if self._enable_ctx: @@ -397,25 +432,13 @@ def update_output_model(self, model, info_dict): model.meta.exposure.elapsed_exposure_time = info_dict["elapsed_exposure_time"] def add_model(self, model): - """ Resamples model image and either variance data (if ``enable_var`` - was `True`) or error data (if ``enable_err`` was `True`) and adds - them using appropriate weighting to the corresponding - arrays of the output model. It also updates resampled data weight, - the context array (if ``enable_ctx`` is `True`), relevant output - model's values. - - Whenever ``model`` has a unique group ID that was never processed - before, the "pointings" value of the output model is incremented and - the "group_id" attribute is updated. Also, time counters are updated - with new values from the input ``model`` by calling - :py:meth:`~Resample.update_time`. + """ + Add a single input model to the resampling. Parameters ---------- - model : dict - A dictionary containing data arrays and other meta attributes - and values of actual models used by pipelines. - + model : ImageModel + A JWST data model to be resampled. """ super().add_model( self.input_model_to_dict( @@ -426,28 +449,12 @@ def add_model(self, model): ) ) if self.output_jwst_model is None: - self.output_jwst_model = self.create_output_jwst_model( - ref_input_model=model - ) + self.output_jwst_model = self.create_output_jwst_model(ref_input_model=model) if self.blendheaders: self._blender.accumulate(model) - def finalize(self, free_memory=True): - """ Performs final computations from any intermediate values, - sets output model values, and optionally frees temporary/intermediate - objects. - - ``finalize`` calls :py:meth:`~Resample.finalize_resample_variance` and - :py:meth:`~Resample.finalize_time_info`. - - .. warning:: - If ``enable_var=True`` then intermediate arrays holding variance - weights will be lost and so continuing adding new models after - a call to :py:meth:`~Resample.finalize` will result in incorrect - variance. In this case `finalize` will set the finalized flag to - `True`. - - """ + def finalize(self): + """Perform final computations and set output model values and metadata.""" if self.blendheaders: self._blender.finalize_model(self.output_jwst_model) super().finalize() @@ -462,10 +469,13 @@ def finalize(self, free_memory=True): self.update_fits_wcsinfo(self.output_jwst_model) assign_wcs_util.update_s_region_imaging(self.output_jwst_model) - self.output_jwst_model.meta.cal_step.resample = 'COMPLETE' + self.output_jwst_model.meta.cal_step.resample = "COMPLETE" def reset_arrays(self, n_input_models=None): - """ Initialize/reset `Drizzle` objects, `ModelBlender`, output model + """ + Initialize/reset between finalize() and add_model() calls. + + Resets or re-initializes `Drizzle` objects, `ModelBlender`, output model and arrays, and time counters. Output WCS and shape are not modified from `Resample` object initialization. This method needs to be called before calling :py:meth:`add_model` for the first time after @@ -477,22 +487,23 @@ def reset_arrays(self, n_input_models=None): Number of input models expected to be resampled. When provided, this is used to estimate memory requirements and optimize memory allocation for the context array. - """ super().reset_arrays(n_input_models=n_input_models) if self.blendheaders: self._blender = ModelBlender( blend_ignore_attrs=[ - 'meta.photometry.pixelarea_steradians', - 'meta.photometry.pixelarea_arcsecsq', - 'meta.filename', + "meta.photometry.pixelarea_steradians", + "meta.photometry.pixelarea_arcsecsq", + "meta.filename", ] ) self.output_jwst_model = None def resample_group(self, indices): - """ Resample multiple input images that belong to a single - ``group_id`` as specified by ``indices``. If ``output_jwst_model`` + """ + Resample multiple input images belonging to a single ``group_id``. + + If ``output_jwst_model`` was created by a previous call to this method, ``output_jwst_model`` as well as other arrays (weights, context, etc.) will be cleared. Upon completion, this method calls :py:meth:`finalize` to compute @@ -511,12 +522,11 @@ def resample_group(self, indices): output_jwst_model Resampled model with populated data, weights, error arrays and other attributes. - """ if self.output_jwst_model is not None: self.reset_arrays(n_input_models=len(indices)) - output_model_filename = '' + output_model_filename = "" log.info(f"{len(indices)} exposures to drizzle together") first = True @@ -527,16 +537,12 @@ def resample_group(self, indices): if self.output_jwst_model is None: # Determine output file type from input exposure filenames # Use this for defining the output filename - indx = model.meta.filename.rfind('.') + indx = model.meta.filename.rfind(".") output_type = model.meta.filename[indx:] - output_root = '_'.join(model.meta.filename.replace( - output_type, - '' - ).split('_')[:-1]) - output_model_filename = ( - f'{output_root}_' - f'{self.intermediate_suffix}{output_type}' + output_root = "_".join( + model.meta.filename.replace(output_type, "").split("_")[:-1] ) + output_model_filename = f"{output_root}_{self.intermediate_suffix}{output_type}" if isinstance(model, datamodels.SlitModel): # must call this explicitly to populate area extension @@ -558,11 +564,17 @@ def resample_group(self, indices): return self.output_jwst_model def resample_many_to_many(self, in_memory=True): - """Resample many inputs to many outputs where outputs have a common frame. + """ + Resample many inputs to many outputs where outputs have a common frame. + + Coadd only different detectors of the same exposure, i.e. map NRCA5 and + NRCB5 onto the same output image, as they image different areas of the + sky. + + Used for outlier detection. Parameters ---------- - in_memory : bool, optional Indicates whether to return a `ModelLibrary` with resampled models loaded in memory or whether to serialize resampled models to @@ -570,26 +582,21 @@ def resample_many_to_many(self, in_memory=True): info. See https://stpipe.readthedocs.io/en/latest/model_library.html#on-disk-mode for more details. - Coadd only different detectors of the same exposure, i.e. map NRCA5 and - NRCB5 onto the same output image, as they image different areas of the - sky. - - Used for outlier detection + Returns + ------- + ModelLibrary + A library of resampled models. """ output_models = [] - for group_id, indices in self.input_models.group_indices.items(): - + for _group_id, indices in self.input_models.group_indices.items(): output_model = self.resample_group(indices) if not in_memory: # Write out model to disk, then return filename output_name = output_model.meta.filename if self.output_dir is not None: - output_name = os.path.join( - self.output_dir, - output_name - ) + output_name = Path(self.output_dir) / output_name output_model.save(output_name) log.info(f"Saved model in {output_name}") output_models.append(output_name) @@ -602,14 +609,20 @@ def resample_many_to_many(self, in_memory=True): else: # build ModelLibrary as an association from the output files # this saves memory if there are multiple groups - asn = asn_from_list(output_models, product_name='outlier_i2d', asn_id="abcdefg") + asn = asn_from_list(output_models, product_name="outlier_i2d", asn_id="abcdefg") asn_dict = json.loads(asn.dump()[1]) # serializes the asn and converts to dict return ModelLibrary(asn_dict, on_disk=True) def resample_many_to_one(self): - """Resample and coadd many inputs to a single output. + """ + Resample and coadd many inputs to a single output. + + Used for stage 3 resampling. - Used for stage 3 resampling + Returns + ------- + ImageModel + The resampled and coadded image. """ log.info("Resampling science and variance data") @@ -628,6 +641,11 @@ def resample_many_to_one(self): def update_fits_wcsinfo(model): """ Update FITS WCS keywords of the resampled image. + + Parameters + ---------- + model : ImageModel + The resampled image """ # Delete any SIP-related keywords first pattern = r"^(cd[12]_[12]|[ab]p?_\d_\d|[ab]p?_order)$" @@ -656,48 +674,51 @@ def update_fits_wcsinfo(model): model.meta.wcsinfo.ctype2 = "DEC--TAN" # Remove no longer relevant WCS keywords - rm_keys = ['v2_ref', 'v3_ref', 'ra_ref', 'dec_ref', 'roll_ref', - 'v3yangle', 'vparity'] + rm_keys = ["v2_ref", "v3_ref", "ra_ref", "dec_ref", "roll_ref", "v3yangle", "vparity"] for key in rm_keys: if key in model.meta.wcsinfo.instance: del model.meta.wcsinfo.instance[key] def input_jwst_model_to_dict(model, weight_type, enable_var, compute_err): - """ Converts a data model to a dictionary of keywords and values - expected by `stcal.resample`. Input parameters are the same as used - when initializing `ResampleImage`. + """ + Convert a data model to a dictionary of keywords and values expected by `stcal.resample`. + + Parameters + ---------- + model : DataModel + A JWST data model. + weight_type : str + The weighting type for adding models' data. + enable_var : bool + Indicates whether to resample variance arrays. + compute_err : str + The method to compute the output model's error array. Returns ------- - model_dict : dict - + dict + A dictionary of keywords and values expected by `stcal.resample`. """ - model_dict = { # arrays: "data": model.data, "dq": model.dq, - # meta: "filename": model.meta.filename, "group_id": model.meta.group_id, "wcs": model.meta.wcs, "wcsinfo": model.meta.wcsinfo, "bunit_data": model.meta.bunit_data, - "exposure_time": model.meta.exposure.exposure_time, "start_time": model.meta.exposure.start_time, "end_time": model.meta.exposure.end_time, "duration": model.meta.exposure.duration, "measurement_time": model.meta.exposure.measurement_time, - "pixelarea_steradians": model.meta.photometry.pixelarea_steradians, "pixelarea_arcsecsq": model.meta.photometry.pixelarea_arcsecsq, - "level": model.meta.background.level, # sky level "subtracted": model.meta.background.subtracted, - # spectroscopy-specific: "instrument_name": model.meta.instrument.name, "exposure_type": model.meta.exposure.type, @@ -708,8 +729,7 @@ def input_jwst_model_to_dict(model, weight_type, enable_var, compute_err): model_dict["var_rnoise"] = model.var_rnoise model_dict["var_poisson"] = model.var_poisson - elif (weight_type is not None and - weight_type.startswith('ivm')): + elif weight_type is not None and weight_type.startswith("ivm"): model_dict["var_rnoise"] = model.var_rnoise if compute_err == "driz_err": @@ -720,11 +740,27 @@ def input_jwst_model_to_dict(model, weight_type, enable_var, compute_err): def _get_boundary_points(xmin, xmax, ymin, ymax, dx=None, dy=None, shrink=0): """ - xmin, xmax, ymin, ymax - integer coordinates of pixel boundaries - step - distance between points along an edge - shrink - number of pixels by which to reduce `shape` + Compute list of boundary points for a rectangle. + + Parameters + ---------- + xmin, xmax, ymin, ymax : int + Coordinates of pixel boundaries. + dx, dy : int + Distance between points along an edge in the X and Y directions, respectively. + shrink : int + Number of pixels by which to reduce `shape` - Returns a list of points and the area of the rectangle + Returns + ------- + x, y : numpy.ndarray + Arrays of X and Y coordinates of the boundary points. + area : float + Area of the rectangle. + center : tuple + Center of the rectangle. + b, r, t, l : slice + Slices for the bottom, right, top, and left edges, respectively. """ nx = xmax - xmin + 1 ny = ymax - ymin + 1 @@ -750,9 +786,9 @@ def _get_boundary_points(xmin, xmax, ymin, ymax, dx=None, dy=None, shrink=0): y = np.empty(size) b = np.s_[0:sx] # bottom edge - r = np.s_[sx:sx + sy] # right edge - t = np.s_[sx + sy:2 * sx + sy] # top edge - l = np.s_[2 * sx + sy:2 * sx + 2 * sy] # left + r = np.s_[sx : sx + sy] # right edge + t = np.s_[sx + sy : 2 * sx + sy] # top edge + l = np.s_[2 * sx + sy : 2 * sx + 2 * sy] # left x[b] = np.linspace(xmin, xmax, sx, False) y[b] = ymin @@ -770,13 +806,24 @@ def _get_boundary_points(xmin, xmax, ymin, ymax, dx=None, dy=None, shrink=0): def compute_image_pixel_area(wcs): - """ Computes pixel area in steradians. + """ + Compute pixel area in steradians from a WCS. + + Parameters + ---------- + wcs : gwcs.WCS + A WCS object. + + Returns + ------- + float + Pixel area in steradians. """ if wcs.array_shape is None: raise ValueError("WCS must have array_shape attribute set.") valid_polygon = False - spatial_idx = np.where(np.array(wcs.output_frame.axes_type) == 'SPATIAL')[0] + spatial_idx = np.where(np.array(wcs.output_frame.axes_type) == "SPATIAL")[0] ny, nx = wcs.array_shape ((xmin, xmax), (ymin, ymax)) = wcs.bounding_box @@ -801,7 +848,7 @@ def compute_image_pixel_area(wcs): ymin=ymin, ymax=ymax, dx=min((xmax - xmin) // 4, 15), - dy=min((ymax - ymin) // 4, 15) + dy=min((ymax - ymin) // 4, 15), ) except ValueError: return None @@ -812,10 +859,9 @@ def compute_image_pixel_area(wcs): limits = [ymin, xmax, ymax, xmin] - for j in range(4): + for _ in range(4): sl = [b, r, t, l][k] - if not (np.all(np.isfinite(ra[sl])) and - np.all(np.isfinite(dec[sl]))): + if not (np.all(np.isfinite(ra[sl])) and np.all(np.isfinite(dec[sl]))): limits[k] += dxy[k] k = (k + 1) % 4 break @@ -835,8 +881,7 @@ def compute_image_pixel_area(wcs): sky_area = SphericalPolygon.from_radec(ra, dec, center=wcenter).area() if sky_area > 2 * np.pi: log.warning( - "Unexpectedly large computed sky area for an image. " - "Setting area to: 4*Pi - area" + "Unexpectedly large computed sky area for an image. Setting area to: 4*Pi - area" ) sky_area = 4 * np.pi - sky_area pix_area = sky_area / image_area @@ -863,7 +908,5 @@ def copy_asn_info_from_library(library, output_model): return if (asn_pool := library.asn.get("asn_pool", None)) is not None: output_model.meta.asn.pool_name = asn_pool - if ( - asn_table_name := library.asn.get("table_name", None) - ) is not None: + if (asn_table_name := library.asn.get("table_name", None)) is not None: output_model.meta.asn.table_name = asn_table_name diff --git a/jwst/resample/resample_spec.py b/jwst/resample/resample_spec.py index b19ed16ad5..c485923014 100644 --- a/jwst/resample/resample_spec.py +++ b/jwst/resample/resample_spec.py @@ -6,7 +6,12 @@ from astropy import coordinates as coord from astropy import units as u from astropy.modeling.models import ( - Const1D, Linear1D, Mapping, Pix2Sky_TAN, RotateNative2Celestial, Tabular1D + Const1D, + Linear1D, + Mapping, + Pix2Sky_TAN, + RotateNative2Celestial, + Tabular1D, ) from astropy.modeling.fitting import LinearLSQFitter from astropy.stats import sigma_clip @@ -17,8 +22,7 @@ from stdatamodels.jwst import datamodels -from jwst.assign_wcs.util import compute_scale, wcs_bbox_from_shape,\ - wrap_ra +from jwst.assign_wcs.util import compute_scale, wcs_bbox_from_shape, wrap_ra from jwst.resample import resample_utils from jwst.resample.resample import ResampleImage from jwst.datamodels import ModelLibrary @@ -32,7 +36,7 @@ class ResampleSpec(ResampleImage): """ - This is the controlling routine for the resampling process for spectral data. + Resample spectral data. Notes ----- @@ -46,22 +50,24 @@ class ResampleSpec(ResampleImage): 3. Updates output data model with output arrays from drizzle, including a record of metadata from all input models. """ - def __init__(self, input_models, pixfrac=1.0, kernel="square", - fillval="NAN", weight_type="ivm", good_bits=0, - blendheaders=True, output_wcs=None, wcs_pars=None, - output=None, enable_ctx=True, enable_var=True, - compute_err=None, asn_id=None, in_memory=True): + + def __init__(self, input_models, good_bits=0, output_wcs=None, wcs_pars=None, **kwargs): """ + Initialize the ResampleSpec object. + Parameters ---------- - input_models : list of objects - list of data models, one for each input image - - output : str - filename for output - - kwargs : dict - Other parameters + input_models : list + List of data models, one for each input image + good_bits : int + Bit values that should be considered good when creating a mask + output_wcs : dict + Output WCS parameters + wcs_pars : dict + Additional parameters for WCS + **kwargs : dict + Additional parameters to be passed into `ResampleImage.__init__()`. + See the docstring of that method for more details. """ shape = None pixel_scale = None @@ -69,9 +75,7 @@ def __init__(self, input_models, pixfrac=1.0, kernel="square", pixel_scale_ratio = 1.0 if isinstance(output_wcs, dict): - output_wcs_dict = { - k: v for k, v in output_wcs.items() if k != "wcs" - } + output_wcs_dict = {k: v for k, v in output_wcs.items() if k != "wcs"} output_wcs = output_wcs["wcs"] pixel_scale = output_wcs_dict.get("pixel_scale") pixel_area = output_wcs_dict.get("pixel_area") @@ -92,9 +96,10 @@ def __init__(self, input_models, pixfrac=1.0, kernel="square", # Get an average input pixel scale for parameter calculations disp_axis = input_models[0].meta.wcsinfo.dispersion_direction input_pixscale0 = 3600.0 * compute_spectral_pixel_scale( - input_models[0].meta.wcs, disp_axis=disp_axis) + input_models[0].meta.wcs, disp_axis=disp_axis + ) if np.isnan(input_pixscale0): - log.warning('Input pixel scale could not be determined.') + log.warning("Input pixel scale could not be determined.") if pixel_scale is not None: log.warning( "Output pixel scale setting is not supported without an " @@ -105,7 +110,7 @@ def __init__(self, input_models, pixfrac=1.0, kernel="square", nominal_area = input_models[0].meta.photometry.pixelarea_steradians if nominal_area is None: - log.warning('Nominal pixel area not set in input data.') + log.warning("Nominal pixel area not set in input data.") log.warning( "Setting output pixel scale is not supported without an " "input pixel scale. Setting pixel_scale=None." @@ -117,24 +122,25 @@ def __init__(self, input_models, pixfrac=1.0, kernel="square", # Use user-supplied reference WCS for the resampled image: if pixel_area is None: if nominal_area is None: - log.warning("Unable to compute output pixel area " - "from 'output_wcs'.") + log.warning("Unable to compute output pixel area from 'output_wcs'.") output_pix_area = None else: # Compare input and output spatial scale to update nominal area output_pscale = 3600.0 * compute_spectral_pixel_scale( - output_wcs, disp_axis=disp_axis) + output_wcs, disp_axis=disp_axis + ) if np.isnan(output_pscale) or np.isnan(input_pixscale0): - log.warning('Output pixel scale could not be determined.') + log.warning("Output pixel scale could not be determined.") output_pix_area = None else: - log.debug(f'Setting output pixel area from the approximate ' - f'output spatial scale: {output_pscale}') - output_pix_area = (output_pscale * nominal_area - / input_pixscale0) + log.debug( + f"Setting output pixel area from the approximate " + f"output spatial scale: {output_pscale}" + ) + output_pix_area = output_pscale * nominal_area / input_pixscale0 else: - log.debug(f'Using output pixel area: {pixel_area}') + log.debug(f"Using output pixel area: {pixel_area}") output_pix_area = pixel_area # Set the pixel scale ratio for scaling reasons @@ -148,15 +154,16 @@ def __init__(self, input_models, pixfrac=1.0, kernel="square", output_wcs.array_shape = shape else: if pixel_scale is not None and nominal_area is not None: - log.info(f'Specified output pixel scale: {pixel_scale} arcsec.') + log.info(f"Specified output pixel scale: {pixel_scale} arcsec.") # Set the pscale ratio from the input pixel scale # (pixel scale ratio is output / input) if pixel_scale_ratio != 1.0: - log.warning('Ignoring input pixel_scale_ratio in favor ' - 'of explicit pixel_scale.') + log.warning( + "Ignoring input pixel_scale_ratio in favor of explicit pixel_scale." + ) pixel_scale_ratio = input_pixscale0 / pixel_scale - log.info(f'Computed output pixel scale ratio: {pixel_scale_ratio:.5g}') + log.info(f"Computed output pixel scale ratio: {pixel_scale_ratio:.5g}") # Define output WCS based on all inputs, including a reference WCS. # These functions internally use pixel_scale_ratio to accommodate @@ -164,21 +171,16 @@ def __init__(self, input_models, pixfrac=1.0, kernel="square", # Any other customizations (crpix, crval, rotation) are ignored. if resample_utils.is_sky_like(input_models[0].meta.wcs.output_frame): if input_models[0].meta.instrument.name != "NIRSPEC": - output_wcs = self.build_interpolated_output_wcs( - input_models, - pixel_scale_ratio=pixel_scale_ratio + input_models, pixel_scale_ratio=pixel_scale_ratio ) else: output_wcs = self.build_nirspec_output_wcs( - input_models, - good_bits=good_bits, - pixel_scale_ratio=pixel_scale_ratio + input_models, good_bits=good_bits, pixel_scale_ratio=pixel_scale_ratio ) else: output_wcs = self.build_nirspec_lamp_output_wcs( - input_models, - pixel_scale_ratio=pixel_scale_ratio + input_models, pixel_scale_ratio=pixel_scale_ratio ) # Use the nominal output pixel area in sr if available, @@ -193,10 +195,9 @@ def __init__(self, input_models, pixfrac=1.0, kernel="square", self._spec_output_pix_area = output_pix_area if pixel_scale is None: - log.info(f'Specified output pixel scale ratio: {pixel_scale_ratio}.') - pixel_scale = 3600.0 * compute_spectral_pixel_scale( - output_wcs, disp_axis=disp_axis) - log.info(f'Computed output pixel scale: {pixel_scale:.5g} arcsec.') + log.info(f"Specified output pixel scale ratio: {pixel_scale_ratio}.") + pixel_scale = 3600.0 * compute_spectral_pixel_scale(output_wcs, disp_axis=disp_axis) + log.info(f"Computed output pixel scale: {pixel_scale:.5g} arcsec.") if output_wcs_dict is None: output_wcs_dict = {} @@ -208,25 +209,24 @@ def __init__(self, input_models, pixfrac=1.0, kernel="square", library = ModelLibrary(input_models, on_disk=False) super().__init__( - input_models=library, - pixfrac=pixfrac, - kernel=kernel, - fillval=fillval, - weight_type=weight_type, - good_bits=good_bits, - blendheaders=blendheaders, - output_wcs=output_wcs_dict, - wcs_pars=None, - output=output, - enable_ctx=enable_ctx, - enable_var=enable_var, - compute_err=compute_err, - asn_id=asn_id + library, good_bits=good_bits, output_wcs=output_wcs_dict, wcs_pars=None, **kwargs ) - self.intermediate_suffix = 'outlier_s2d' + self.intermediate_suffix = "outlier_s2d" def create_output_jwst_model(self, ref_input_model=None): - """ Create a new blank model and update its meta with info from ``ref_input_model``. """ + """ + Create a new blank model and update its meta with info from ``ref_input_model``. + + Parameters + ---------- + ref_input_model : `~jwst.datamodels.JwstDataModel`, optional + The reference input model from which to copy meta data. + + Returns + ------- + SlitModel + A new blank model with updated meta data. + """ output_model = datamodels.SlitModel(None) # update meta data and wcs if ref_input_model is not None: @@ -235,6 +235,16 @@ def create_output_jwst_model(self, ref_input_model=None): return output_model def update_output_model(self, model, info_dict): + """ + Add spectroscopy-specific meta information to the output model. + + Parameters + ---------- + model : SlitModel + The output model to be updated. + info_dict : dict + A dictionary containing information about the resampling process. + """ super().update_output_model(model, info_dict) if self._spec_output_pix_area is None: model.meta.photometry.pixelarea_steradians = None @@ -242,7 +252,7 @@ def update_output_model(self, model, info_dict): else: model.meta.photometry.pixelarea_steradians = self._spec_output_pix_area model.meta.photometry.pixelarea_arcsecsq = ( - self._spec_output_pix_area * np.rad2deg(3600)**2 + self._spec_output_pix_area * np.rad2deg(3600) ** 2 ) # TODO: this is helpful info that should be stored in products. @@ -254,8 +264,9 @@ def update_output_model(self, model, info_dict): # model.meta.resample.pointings # model.meta.cal_step.resample - def build_nirspec_output_wcs(self, input_models, refmodel=None, - good_bits=None, pixel_scale_ratio=1.0): + def build_nirspec_output_wcs( + self, input_models, refmodel=None, good_bits=None, pixel_scale_ratio=1.0 + ): """ Create a spatial/spectral WCS covering the footprint of the input. @@ -319,14 +330,14 @@ def build_nirspec_output_wcs(self, input_models, refmodel=None, refwcs = refmodel.meta.wcs # Set up the transforms that are needed - s2d = refwcs.get_transform('slit_frame', 'detector') - d2s = refwcs.get_transform('detector', 'slit_frame') - if 'moving_target' in refwcs.available_frames: - s2w = refwcs.get_transform('slit_frame', 'moving_target') - w2s = refwcs.get_transform('moving_target', 'slit_frame') + s2d = refwcs.get_transform("slit_frame", "detector") + d2s = refwcs.get_transform("detector", "slit_frame") + if "moving_target" in refwcs.available_frames: + s2w = refwcs.get_transform("slit_frame", "moving_target") + w2s = refwcs.get_transform("moving_target", "slit_frame") else: - s2w = refwcs.get_transform('slit_frame', 'world') - w2s = refwcs.get_transform('world', 'slit_frame') + s2w = refwcs.get_transform("slit_frame", "world") + w2s = refwcs.get_transform("world", "slit_frame") # Estimate position of the target without relying on the meta.target: # compute the mean spatial and wavelength coords weighted @@ -341,8 +352,9 @@ def build_nirspec_output_wcs(self, input_models, refmodel=None, # Reject the worst outliers in the data with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=AstropyUserWarning, - message=".*automatically clipped.*") + warnings.filterwarnings( + "ignore", category=AstropyUserWarning, message=".*automatically clipped.*" + ) weights = sigma_clip(refmodel_data, masked=True, sigma=100.0) weights = np.ma.filled(weights, fill_value=0.0) if not np.all(weights == 0.0): @@ -355,14 +367,13 @@ def build_nirspec_output_wcs(self, input_models, refmodel=None, # (at the center of the slit in x) targ_ra, targ_dec, _ = s2w(0, wmean_s, wmean_l) sx, sy = s2d(0, wmean_s, wmean_l) - log.debug(f'Fiducial RA, Dec, wavelength: ' - f'{targ_ra}, {targ_dec}, {wmean_l}') - log.debug(f'Index at fiducial center: x={sx}, y={sy}') + log.debug(f"Fiducial RA, Dec, wavelength: {targ_ra}, {targ_dec}, {wmean_l}") + log.debug(f"Index at fiducial center: x={sx}, y={sy}") # Estimate spatial sampling from the reference model # at the center of the array lam_center_idx = int(np.mean(bbox, axis=1)[0]) - log.debug(f'Center of dispersion axis: {lam_center_idx}') + log.debug(f"Center of dispersion axis: {lam_center_idx}") grid_center = grid[0][:, lam_center_idx], grid[1][:, lam_center_idx] ra_ref, dec_ref, _ = np.array(refwcs(*grid_center)) @@ -389,7 +400,7 @@ def build_nirspec_output_wcs(self, input_models, refmodel=None, # Check whether sampling is more along RA or along Dec swap_xy = abs(pix_to_xtan.slope) < abs(pix_to_ytan.slope) - log.debug(f'Swap xy: {swap_xy}') + log.debug(f"Swap xy: {swap_xy}") # Get output wavelengths from all data ref_lam = _find_nirspec_output_sampling_wavelengths(all_wcs) @@ -399,7 +410,8 @@ def build_nirspec_output_wcs(self, input_models, refmodel=None, # Find the spatial extent in x/y tangent min_tan_x, max_tan_x, min_tan_y, max_tan_y = self._max_spatial_extent( - all_wcs, undist2sky.inverse) + all_wcs, undist2sky.inverse + ) diff_y = np.abs(max_tan_y - min_tan_y) diff_x = np.abs(max_tan_x - min_tan_x) @@ -453,9 +465,9 @@ def build_nirspec_output_wcs(self, input_models, refmodel=None, else: ref_lam = 3 * ref_lam pixel_coord = [-0.5, 0, 0.5] - wavelength_transform = Tabular1D(points=pixel_coord, - lookup_table=ref_lam, - bounds_error=False, fill_value=np.nan) + wavelength_transform = Tabular1D( + points=pixel_coord, lookup_table=ref_lam, bounds_error=False, fill_value=np.nan + ) # For spatial coordinates, map detector pixels to tangent offset, # then to world coordinates (RA, Dec, wavelength in um). @@ -477,8 +489,8 @@ def build_nirspec_output_wcs(self, input_models, refmodel=None, # Make a 1D lookup table for all ny. # Allow linear extrapolation at the edges. slit_transform = Tabular1D( - points=np.arange(ny), lookup_table=slit_center, - bounds_error=False, fill_value=None) + points=np.arange(ny), lookup_table=slit_center, bounds_error=False, fill_value=None + ) # In the transform, the first slit coordinate is always set to 0 # to represent the "horizontal" center of the slit @@ -499,15 +511,16 @@ def build_nirspec_output_wcs(self, input_models, refmodel=None, slit2world = det2slit.inverse | pix2world # Create coordinate frames: detector, slit_frame, and world - det = cf.Frame2D(name='detector', axes_order=(0, 1)) - slit_spatial = cf.Frame2D(name='slit_spatial', axes_order=(0, 1), - unit=("", ""), axes_names=('x_slit', 'y_slit')) - spec = cf.SpectralFrame(name='spectral', axes_order=(2,), - unit=(u.micron,), axes_names=('wavelength',)) - slit_frame = cf.CompositeFrame([slit_spatial, spec], name='slit_frame') - sky = cf.CelestialFrame(name='sky', axes_order=(0, 1), - reference_frame=coord.ICRS()) - world = cf.CompositeFrame([sky, spec], name='world') + det = cf.Frame2D(name="detector", axes_order=(0, 1)) + slit_spatial = cf.Frame2D( + name="slit_spatial", axes_order=(0, 1), unit=("", ""), axes_names=("x_slit", "y_slit") + ) + spec = cf.SpectralFrame( + name="spectral", axes_order=(2,), unit=(u.micron,), axes_names=("wavelength",) + ) + slit_frame = cf.CompositeFrame([slit_spatial, spec], name="slit_frame") + sky = cf.CelestialFrame(name="sky", axes_order=(0, 1), reference_frame=coord.ICRS()) + world = cf.CompositeFrame([sky, spec], name="world") pipeline = [(det, det2slit), (slit_frame, slit2world), (world, None)] output_wcs = WCS(pipeline) @@ -522,6 +535,20 @@ def build_nirspec_output_wcs(self, input_models, refmodel=None, def _max_spatial_extent(self, wcs_list, transform): """ Compute spatial coordinate limits for all nods in the tangent plane. + + Parameters + ---------- + wcs_list : list + List of WCS objects for all nods. + transform : callable + Function to convert RA, Dec to tangent plane coordinates. + + Returns + ------- + limits_x : tuple + Minimum and maximum x values. + limits_y : tuple + Minimum and maximum y values. """ limits_x = [np.inf, -np.inf] limits_y = [np.inf, -np.inf] @@ -534,7 +561,7 @@ def _max_spatial_extent(self, wcs_list, transform): dec = dec[good] xtan, ytan = transform(ra, dec) - for tan_all, limits in zip([xtan, ytan], [limits_x, limits_y]): + for tan_all, limits in zip([xtan, ytan], [limits_x, limits_y], strict=True): min_tan = np.min(tan_all) max_tan = np.max(tan_all) @@ -558,12 +585,18 @@ def build_interpolated_output_wcs(self, input_models, pixel_scale_ratio=1.0): - `detector`: image x, y - `world`: RA, Dec, wavelength + Parameters + ---------- + input_models : list + List of data models, one for each input image + pixel_scale_ratio : float + The ratio of the output pixel scale to the input pixel scale + Returns ------- output_wcs : `~gwcs.WCS` object A gwcs WCS object defining the output frame WCS """ - # for each input model convert slit x,y to ra,dec,lam # use first input model to set spatial scale # use center of appended ra and dec arrays to set up @@ -615,8 +648,7 @@ def build_interpolated_output_wcs(self, input_models, pixel_scale_ratio=1.0): all_wavelength = np.append(all_wavelength, wavelength_array) # find the center ra and dec for this slit at central wavelength - lam_center_index = int((bbox[spectral_axis][1] - - bbox[spectral_axis][0]) / 2) + lam_center_index = int((bbox[spectral_axis][1] - bbox[spectral_axis][0]) / 2) if spatial_axis == 0: # MIRI LRS spectral = 1, the spatial axis = 0 ra_slice = ra[lam_center_index, :] @@ -636,15 +668,15 @@ def build_interpolated_output_wcs(self, input_models, pixel_scale_ratio=1.0): # Filter out RuntimeWarnings due to computed NaNs in the WCS with warnings.catch_warnings(): - warnings.simplefilter("ignore", RuntimeWarning) # was ignore. need to make more specific - # at this center of slit find x,y tangent projection - x_tan, y_tan - x_tan, y_tan = undist2sky1.inverse(ra, dec) + warnings.simplefilter("ignore", RuntimeWarning) + # at this center of slit find x,y tangent projection - x_tan, y_tan + x_tan, y_tan = undist2sky1.inverse(ra, dec) # pull out data from center if spectral_axis == 0: x_tan_array = x_tan.T[lam_center_index] y_tan_array = y_tan.T[lam_center_index] - else: # MIRI LRS Spectral Axis = 1, the WCS x axis is spatial + else: # MIRI LRS Spectral Axis = 1, the WCS x axis is spatial x_tan_array = x_tan[lam_center_index] y_tan_array = y_tan[lam_center_index] @@ -692,16 +724,19 @@ def build_interpolated_output_wcs(self, input_models, pixel_scale_ratio=1.0): # Check if the data is MIRI LRS FIXED Slit. If it is then # the wavelength array needs to be flipped so that the resampled # dispersion direction matches the dispersion direction on the detector. - if input_models[0].meta.exposure.type == 'MIR_LRS-FIXEDSLIT': + if input_models[0].meta.exposure.type == "MIR_LRS-FIXEDSLIT": wavelength_array = np.flip(wavelength_array, axis=None) step = 1 stop = wavelength_array.shape[0] points = np.arange(0, stop, step) - pix_to_wavelength = Tabular1D(points=points, - lookup_table=wavelength_array, - bounds_error=False, fill_value=None, - name='pix2wavelength') + pix_to_wavelength = Tabular1D( + points=points, + lookup_table=wavelength_array, + bounds_error=False, + fill_value=None, + name="pix2wavelength", + ) # Tabular models need an inverse explicitly defined. # If the wavelength array is descending instead of ascending, both @@ -713,10 +748,13 @@ def build_interpolated_output_wcs(self, input_models, pixel_scale_ratio=1.0): if not np.all(np.diff(wavelength_array) > 0): points = points[::-1] lookup_table = lookup_table[::-1] - pix_to_wavelength.inverse = Tabular1D(points=points, - lookup_table=lookup_table, - bounds_error=False, fill_value=None, - name='wavelength2pix') + pix_to_wavelength.inverse = Tabular1D( + points=points, + lookup_table=lookup_table, + bounds_error=False, + fill_value=None, + name="wavelength2pix", + ) # For the input mapping, duplicate the spatial coordinate mapping = Mapping((spatial_axis, spatial_axis, spectral_axis)) @@ -757,7 +795,8 @@ def build_interpolated_output_wcs(self, input_models, pixel_scale_ratio=1.0): ## Use all the wcs min_tan_x, max_tan_x, min_tan_y, max_tan_y = self._max_spatial_extent( - all_wcs, undist2sky.inverse) + all_wcs, undist2sky.inverse + ) diff_y = np.abs(max_tan_y - min_tan_y) diff_x = np.abs(max_tan_x - min_tan_x) pix_to_tan_slope_y = np.abs(pix_to_ytan.slope) @@ -770,23 +809,22 @@ def build_interpolated_output_wcs(self, input_models, pixel_scale_ratio=1.0): else: ny = int(np.ceil(diff_x / pix_to_tan_slope_x)) - offset_y = (ny)/2 * pix_to_tan_slope_y - offset_x = (ny)/2 * pix_to_tan_slope_x - pix_to_ytan.intercept = - slope_sign_y * offset_y - pix_to_xtan.intercept = - slope_sign_x * offset_x + offset_y = (ny) / 2 * pix_to_tan_slope_y + offset_x = (ny) / 2 * pix_to_tan_slope_x + pix_to_ytan.intercept = -slope_sign_y * offset_y + pix_to_xtan.intercept = -slope_sign_x * offset_x # define the output wcs transform = mapping | (pix_to_xtan & pix_to_ytan | undist2sky) & pix_to_wavelength - det = cf.Frame2D(name='detector', axes_order=(0, 1)) - sky = cf.CelestialFrame(name='sky', axes_order=(0, 1), - reference_frame=coord.ICRS()) - spec = cf.SpectralFrame(name='spectral', axes_order=(2,), - unit=(u.micron,), axes_names=('wavelength',)) - world = cf.CompositeFrame([sky, spec], name='world') + det = cf.Frame2D(name="detector", axes_order=(0, 1)) + sky = cf.CelestialFrame(name="sky", axes_order=(0, 1), reference_frame=coord.ICRS()) + spec = cf.SpectralFrame( + name="spectral", axes_order=(2,), unit=(u.micron,), axes_names=("wavelength",) + ) + world = cf.CompositeFrame([sky, spec], name="world") - pipeline = [(det, transform), - (world, None)] + pipeline = [(det, transform), (world, None)] output_wcs = WCS(pipeline) @@ -804,7 +842,7 @@ def build_interpolated_output_wcs(self, input_models, pixel_scale_ratio=1.0): def build_nirspec_lamp_output_wcs(self, input_models, pixel_scale_ratio): """ - Create a spatial/spectral WCS output frame for NIRSpec lamp mode + Create a spatial/spectral WCS output frame for NIRSpec lamp mode. Creates output frame by linearly fitting x_msa, y_msa along the slit and producing a lookup table to interpolate wavelengths in the dispersion @@ -815,6 +853,13 @@ def build_nirspec_lamp_output_wcs(self, input_models, pixel_scale_ratio): - `detector`: image x, y - `world`: MSA x, MSA y, wavelength + Parameters + ---------- + input_models : list + List of data models, one for each input image + pixel_scale_ratio : float + The ratio of the output pixel scale to the input pixel scale + Returns ------- output_wcs : `~gwcs.WCS` object @@ -838,8 +883,7 @@ def build_nirspec_lamp_output_wcs(self, input_models, pixel_scale_ratio): wavelength_array = wavelength_array[~np.isnan(wavelength_array)] # Find the center ra and dec for this slit at central wavelength - lam_center_index = int((bbox[spectral_axis][1] - - bbox[spectral_axis][0]) / 2) + lam_center_index = int((bbox[spectral_axis][1] - bbox[spectral_axis][0]) / 2) x_msa_array = x_msa.T[lam_center_index] y_msa_array = y_msa.T[lam_center_index] x_msa_array = x_msa_array[~np.isnan(x_msa_array)] @@ -858,10 +902,13 @@ def build_nirspec_lamp_output_wcs(self, input_models, pixel_scale_ratio): step = 1 stop = wavelength_array.shape[0] points = np.arange(0, stop, step) - pix_to_wavelength = Tabular1D(points=points, - lookup_table=wavelength_array, - bounds_error=False, fill_value=None, - name='pix2wavelength') + pix_to_wavelength = Tabular1D( + points=points, + lookup_table=wavelength_array, + bounds_error=False, + fill_value=None, + name="pix2wavelength", + ) # Tabular models need an inverse explicitly defined. # If the wavelength array is descending instead of ascending, both @@ -873,10 +920,13 @@ def build_nirspec_lamp_output_wcs(self, input_models, pixel_scale_ratio): if not np.all(np.diff(wavelength_array) > 0): points = points[::-1] lookup_table = lookup_table[::-1] - pix_to_wavelength.inverse = Tabular1D(points=points, - lookup_table=lookup_table, - bounds_error=False, fill_value=None, - name='wavelength2pix') + pix_to_wavelength.inverse = Tabular1D( + points=points, + lookup_table=lookup_table, + bounds_error=False, + fill_value=None, + name="wavelength2pix", + ) # For the input mapping, duplicate the spatial coordinate mapping = Mapping((spatial_axis, spatial_axis, spectral_axis)) @@ -886,14 +936,14 @@ def build_nirspec_lamp_output_wcs(self, input_models, pixel_scale_ratio): # define the output wcs transform = mapping | pix_to_x_msa & pix_to_y_msa & pix_to_wavelength - det = cf.Frame2D(name='detector', axes_order=(0, 1)) - sky = cf.Frame2D(name=f'resampled_{model.meta.wcs.output_frame.name}', axes_order=(0, 1)) - spec = cf.SpectralFrame(name='spectral', axes_order=(2,), - unit=(u.micron,), axes_names=('wavelength',)) - world = cf.CompositeFrame([sky, spec], name='world') + det = cf.Frame2D(name="detector", axes_order=(0, 1)) + sky = cf.Frame2D(name=f"resampled_{model.meta.wcs.output_frame.name}", axes_order=(0, 1)) + spec = cf.SpectralFrame( + name="spectral", axes_order=(2,), unit=(u.micron,), axes_names=("wavelength",) + ) + world = cf.CompositeFrame([sky, spec], name="world") - pipeline = [(det, transform), - (world, None)] + pipeline = [(det, transform), (world, None)] output_wcs = WCS(pipeline) @@ -914,7 +964,17 @@ def build_nirspec_lamp_output_wcs(self, input_models, pixel_scale_ratio): def find_dispersion_axis(refmodel): """ - Find the dispersion axis (0-indexed) of the given 2D wavelength array + Find the dispersion axis (0-indexed) of the given 2D wavelength array. + + Parameters + ---------- + refmodel : `~jwst.datamodels.DataModel` + The input data model. + + Returns + ------- + dispaxis : int + The dispersion axis (0-indexed). """ dispaxis = refmodel.meta.wcsinfo.dispersion_direction # Change from 1 --> X and 2 --> Y to 0 --> X and 1 --> Y. @@ -953,7 +1013,7 @@ def _find_nirspec_output_sampling_wavelengths(wcs_list): while image_lam: best_overlap = -np.inf best_wcs = 0 - for k, (lam, lmin, lmax) in enumerate(image_lam): + for k, (_lam, lmin, lmax) in enumerate(image_lam): overlap = min(lam2, lmax) - max(lam1, lmin) if best_overlap < overlap: best_overlap = overlap @@ -989,7 +1049,8 @@ def _find_nirspec_output_sampling_wavelengths(wcs_list): def compute_spectral_pixel_scale(wcs, fiducial=None, disp_axis=1): - """Compute an approximate spatial pixel scale for spectral data. + """ + Compute an approximate spatial pixel scale for spectral data. Parameters ---------- @@ -1014,4 +1075,3 @@ def compute_spectral_pixel_scale(wcs, fiducial=None, disp_axis=1): pixel_scale = compute_scale(wcs, fiducial, disp_axis=disp_axis) return float(pixel_scale) - diff --git a/jwst/resample/resample_spec_step.py b/jwst/resample/resample_spec_step.py index 12a6d1cd1b..ac1f64f29e 100755 --- a/jwst/resample/resample_spec_step.py +++ b/jwst/resample/resample_spec_step.py @@ -15,19 +15,11 @@ # Force use of all DQ flagged data except for DO_NOT_USE and NON_SCIENCE -GOOD_BITS = '~DO_NOT_USE+NON_SCIENCE' +GOOD_BITS = "~DO_NOT_USE+NON_SCIENCE" class ResampleSpecStep(Step): - """ - ResampleSpecStep: Resample input data onto a regular grid using the - drizzle algorithm. - - Parameters - ----------- - input : `~jwst.datamodels.MultiSlitModel`, `~jwst.datamodels.ModelContainer`, Association - A single datamodel, a container of datamodels, or an association file - """ + """Resample spectral data onto a regular grid using the drizzle algorithm.""" class_alias = "resample_spec" @@ -43,18 +35,31 @@ class ResampleSpecStep(Step): single = boolean(default=False) # Resample each input to its own output grid blendheaders = boolean(default=True) # Blend metadata from inputs into output in_memory = boolean(default=True) # Keep images in memory - """ # noqa: E501 + """ # noqa: E501 + + def process(self, input_data): + """ + Run the resample step on the input data. - def process(self, input): - input_new = datamodels.open(input) + Parameters + ---------- + input_data : MultiSlitModel, ModelContainer, str + A single datamodel, a container of datamodels, or an association file. + + Returns + ------- + SlitModel or MultiSlitModel + The resampled output, one slit per source. + """ + input_new = datamodels.open(input_data) # Check if input_new is a MultiSlitModel model_is_msm = isinstance(input_new, MultiSlitModel) # If input is a 3D rateints MultiSlitModel (unsupported) skip the step if model_is_msm and len((input_new[0]).shape) == 3: - self.log.warning('Resample spec step will be skipped') - input_new.meta.cal_step.resample = 'SKIPPED' + self.log.warning("Resample spec step will be skipped") + input_new.meta.cal_step.resample = "SKIPPED" return input_new @@ -81,7 +86,7 @@ def process(self, input): # Setup drizzle-related parameters kwargs = self.get_drizpars() - kwargs['output'] = output + kwargs["output"] = output self.drizpars = kwargs # Call resampling @@ -90,7 +95,7 @@ def process(self, input): elif len(input_models[0].data.shape) != 2: # resample can only handle 2D images, not 3D cubes, etc - raise RuntimeError("Input {} is not a 2D image.".format(input_models[0])) + raise TypeError(f"Input {input_models[0]} is not a 2D image.") else: # result is a SlitModel @@ -103,7 +108,7 @@ def process(self, input): # populate the result wavelength attribute for MultiSlitModel if isinstance(result, MultiSlitModel): - for slit_idx, slit in enumerate(result.slits): + for slit_idx, _slit in enumerate(result.slits): wl_array = get_wavelengths(result.slits[slit_idx]) result.slits[slit_idx].wavelength = wl_array else: @@ -115,11 +120,11 @@ def process(self, input): def _process_multislit(self, input_models): """ - Resample MultiSlit data + Resample MultiSlit data. Parameters ---------- - input : `~jwst.datamodels.ModelContainer` + input_models : `~jwst.datamodels.ModelContainer` A container of `~jwst.datamodels.MultiSlitModel` Returns @@ -141,23 +146,20 @@ def _process_multislit(self, input_models): # Call the resampling routine if self.single: resamp = resample_spec.ResampleSpec( - container, - enable_var=False, - compute_err="driz_err", - **self.drizpars - ) - drizzled_library = resamp.resample_many_to_many( - in_memory=self.in_memory + container, enable_var=False, compute_err="driz_err", **self.drizpars ) + drizzled_library = resamp.resample_many_to_many(in_memory=self.in_memory) else: resamp = resample_spec.ResampleSpec( - container, - enable_var=True, - compute_err="from_var", - **self.drizpars + container, enable_var=True, compute_err="from_var", **self.drizpars ) drizzled_library = resamp.resample_many_to_one() - drizzled_library = ModelLibrary([drizzled_library,], on_disk=False) + drizzled_library = ModelLibrary( + [ + drizzled_library, + ], + on_disk=False, + ) with drizzled_library: for i, model in enumerate(drizzled_library): @@ -182,34 +184,34 @@ def _process_multislit(self, input_models): def get_drizpars(self): """ Load all drizzle-related parameter values into kwargs list. + + Returns + ------- + kwargs : dict + Dictionary of drizzle parameters """ # Define the keys pulled from step parameters - kwargs = dict( - pixfrac=self.pixfrac, - kernel=self.kernel, - fillval=self.fillval, - weight_type=self.weight_type, - good_bits=GOOD_BITS, - blendheaders=self.blendheaders, - ) + kwargs = { + "pixfrac": self.pixfrac, + "kernel": self.kernel, + "fillval": self.fillval, + "weight_type": self.weight_type, + "good_bits": GOOD_BITS, + "blendheaders": self.blendheaders, + } # Custom output WCS parameters wcs_pars = {} - wcs_pars['output_shape'] = ResampleStep.check_list_pars( - self.output_shape, - 'output_shape', - min_vals=[1, 1] + wcs_pars["output_shape"] = ResampleStep.check_list_pars( + self.output_shape, "output_shape", min_vals=[1, 1] ) - kwargs['output_wcs'] = load_custom_wcs( - self.output_wcs, - wcs_pars['output_shape'] - ) - wcs_pars['pixel_scale'] = self.pixel_scale - wcs_pars['pixel_scale_ratio'] = self.pixel_scale_ratio + kwargs["output_wcs"] = load_custom_wcs(self.output_wcs, wcs_pars["output_shape"]) + wcs_pars["pixel_scale"] = self.pixel_scale + wcs_pars["pixel_scale_ratio"] = self.pixel_scale_ratio # Report values to processing log for k, v in kwargs.items(): - self.log.debug(' {}={}'.format(k, v)) + self.log.debug(f" {k}={v}") kwargs["wcs_pars"] = wcs_pars @@ -217,11 +219,11 @@ def get_drizpars(self): def _process_slit(self, input_models): """ - Resample Slit data + Resample Slit data. Parameters ---------- - input : `~jwst.datamodels.ModelContainer` + input_models : `~jwst.datamodels.ModelContainer` A container of `~jwst.datamodels.ImageModel` or `~jwst.datamodels.SlitModel` @@ -237,14 +239,9 @@ def _process_slit(self, input_models): # Call the resampling routine if self.single: resamp = resample_spec.ResampleSpec( - input_models, - enable_var=False, - compute_err="driz_err", - **self.drizpars - ) - drizzled_library = resamp.resample_many_to_many( - in_memory=self.in_memory + input_models, enable_var=False, compute_err="driz_err", **self.drizpars ) + drizzled_library = resamp.resample_many_to_many(in_memory=self.in_memory) with drizzled_library: result = drizzled_library.borrow(0) drizzled_library.shelve(result, 0, modify=False) @@ -252,10 +249,7 @@ def _process_slit(self, input_models): else: resamp = resample_spec.ResampleSpec( - input_models, - enable_var=True, - compute_err="from_var", - **self.drizpars + input_models, enable_var=True, compute_err="from_var", **self.drizpars ) result = resamp.resample_many_to_one() @@ -267,11 +261,11 @@ def _process_slit(self, input_models): result.meta.resample.pixfrac = self.pixfrac self.update_slit_metadata(result) - if result.meta.exposure.type.lower() == 'mir_lrs-fixedslit': + if result.meta.exposure.type.lower() == "mir_lrs-fixedslit": s_region_model1 = input_models[0].meta.wcsinfo.s_region s_region = find_miri_lrs_sregion(s_region_model1, result.meta.wcs) result.meta.wcsinfo.s_region = s_region - self.log.info(f'Updating S_REGION: {s_region}.') + self.log.info(f"Updating S_REGION: {s_region}.") else: update_s_region_spectral(result) return result @@ -284,10 +278,23 @@ def update_slit_metadata(self, model): the normal update() method doesn't work with them. Updates output_model in-place. """ - for attr in ['name', 'xstart', 'xsize', 'ystart', 'ysize', - 'slitlet_id', 'source_id', 'source_name', 'source_alias', - 'stellarity', 'source_type', 'source_xpos', 'source_ypos', - 'dispersion_direction', 'shutter_state']: + for attr in [ + "name", + "xstart", + "xsize", + "ystart", + "ysize", + "slitlet_id", + "source_id", + "source_name", + "source_alias", + "stellarity", + "source_type", + "source_xpos", + "source_ypos", + "dispersion_direction", + "shutter_state", + ]: try: val = getattr(self.input_models[-1], attr) except AttributeError: diff --git a/jwst/resample/resample_step.py b/jwst/resample/resample_step.py index a654bc8698..878c36ed61 100755 --- a/jwst/resample/resample_step.py +++ b/jwst/resample/resample_step.py @@ -16,22 +16,17 @@ # Force use of all DQ flagged data except for DO_NOT_USE and NON_SCIENCE -GOOD_BITS = '~DO_NOT_USE+NON_SCIENCE' +GOOD_BITS = "~DO_NOT_USE+NON_SCIENCE" class ResampleStep(Step): """ - Resample input data onto a regular grid using the drizzle algorithm. + Resample imaging data onto a regular grid using the drizzle algorithm. .. note:: When supplied via ``output_wcs``, a custom WCS overrides other custom WCS parameters such as ``output_shape`` (now computed from by ``output_wcs.bounding_box``), ``crpix`` - - Parameters - ----------- - input : ~jwst.datamodels.JwstDataModel or ~jwst.associations.Association - Single filename for either a single image or an association table. """ class_alias = "resample" @@ -51,26 +46,40 @@ class ResampleStep(Step): single = boolean(default=False) # Resample each input to its own output grid blendheaders = boolean(default=True) # Blend metadata from inputs into output in_memory = boolean(default=True) # Keep images in memory - """ # noqa: E501 + """ # noqa: E501 reference_file_types: list = [] - def process(self, input): + def process(self, input_data): + """ + Run the resample step on the input data. - if isinstance(input, str): - ext = filetype.check(input) + Parameters + ---------- + input_data : str, ImageModel, or any asn-type input loadable into ModelLibrary + Filename pointing to an ImageModel or an association, or the ImageModel or + association itself. + + Returns + ------- + ModelLibrary or ImageModel + The final output data. If the `single` parameter is set to True, then this + is a single ImageModel; otherwise, it is a ModelLibrary. + """ + if isinstance(input_data, str): + ext = filetype.check(input_data) if ext in ("fits", "asdf"): - input = dm.open(input) - if isinstance(input, ModelLibrary): - input_models = input - elif isinstance(input, (str, dict, list)): - input_models = ModelLibrary(input, on_disk=not self.in_memory) - elif isinstance(input, ImageModel): - input_models = ModelLibrary([input], on_disk=not self.in_memory) - output = input.meta.filename + input_data = dm.open(input_data) + if isinstance(input_data, ModelLibrary): + input_models = input_data + elif isinstance(input_data, (str, dict, list)): + input_models = ModelLibrary(input_data, on_disk=not self.in_memory) + elif isinstance(input_data, ImageModel): + input_models = ModelLibrary([input_data], on_disk=not self.in_memory) + output = input_data.meta.filename self.blendheaders = False else: - raise RuntimeError(f"Input {input} is not a 2D image.") + raise TypeError(f"Input {input_data} is not a 2D image.") try: output = input_models.asn["products"][0]["name"] @@ -102,23 +111,13 @@ def process(self, input): # Call the resampling routine if self.single: resamp = resample.ResampleImage( - input_models, - output=output, - enable_var=False, - compute_err="driz_err", - **kwargs - ) - result = resamp.resample_many_to_many( - in_memory=self.in_memory + input_models, output=output, enable_var=False, compute_err="driz_err", **kwargs ) + result = resamp.resample_many_to_many(in_memory=self.in_memory) else: resamp = resample.ResampleImage( - input_models, - output=output, - enable_var=True, - compute_err="from_var", - **kwargs + input_models, output=output, enable_var=True, compute_err="from_var", **kwargs ) result = resamp.resample_many_to_one() @@ -157,7 +156,7 @@ def check_list_pars(vals, name, min_vals=None): if n == 2: return None elif n == 0: - if min_vals and sum(x >= y for x, y in zip(vals, min_vals)) != 2: + if min_vals and sum(x >= y for x, y in zip(vals, min_vals, strict=True)) != 2: raise ValueError(f"'{name}' values must be larger or equal to {list(min_vals)}") return list(vals) else: @@ -166,41 +165,39 @@ def check_list_pars(vals, name, min_vals=None): def get_drizpars(self): """ Load all drizzle-related parameter values into kwargs list. + + Returns + ------- + kwargs : dict + Dictionary of drizzle parameters """ # Define the keys pulled from step parameters - kwargs = dict( - pixfrac=self.pixfrac, - kernel=self.kernel, - fillval=self.fillval, - weight_type=self.weight_type, - good_bits=GOOD_BITS, - blendheaders=self.blendheaders, - ) + kwargs = { + "pixfrac": self.pixfrac, + "kernel": self.kernel, + "fillval": self.fillval, + "weight_type": self.weight_type, + "good_bits": GOOD_BITS, + "blendheaders": self.blendheaders, + } # Custom output WCS parameters. - output_shape = self.check_list_pars( - self.output_shape, - 'output_shape', - min_vals=[1, 1] - ) - kwargs['output_wcs'] = load_custom_wcs( - self.output_wcs, - output_shape - ) + output_shape = self.check_list_pars(self.output_shape, "output_shape", min_vals=[1, 1]) + kwargs["output_wcs"] = load_custom_wcs(self.output_wcs, output_shape) wcs_pars = { - 'crpix': self.check_list_pars(self.crpix, 'crpix'), - 'crval': self.check_list_pars(self.crval, 'crval'), - 'rotation': self.rotation, - 'pixel_scale': self.pixel_scale, - 'pixel_scale_ratio': self.pixel_scale_ratio, - 'output_shape': None if output_shape is None else output_shape[::-1], + "crpix": self.check_list_pars(self.crpix, "crpix"), + "crval": self.check_list_pars(self.crval, "crval"), + "rotation": self.rotation, + "pixel_scale": self.pixel_scale, + "pixel_scale_ratio": self.pixel_scale_ratio, + "output_shape": None if output_shape is None else output_shape[::-1], } - kwargs['wcs_pars'] = wcs_pars + kwargs["wcs_pars"] = wcs_pars # Report values to processing log for k, v in kwargs.items(): - self.log.debug(' {}={}'.format(k, v)) + self.log.debug(f" {k}={v}") return kwargs diff --git a/jwst/resample/resample_utils.py b/jwst/resample/resample_utils.py index b6e95ff7ca..c96ee123a5 100644 --- a/jwst/resample/resample_utils.py +++ b/jwst/resample/resample_utils.py @@ -26,53 +26,42 @@ ) -__all__ = [ - "build_mask", - "decode_context", - "make_output_wcs", - "resampled_wcs_from_models" -] +__all__ = ["build_mask", "decode_context", "make_output_wcs", "resampled_wcs_from_models"] log = logging.getLogger(__name__) log.setLevel(logging.DEBUG) def resampled_wcs_from_models( - input_models, - pixel_scale_ratio=1.0, - pixel_scale=None, - output_shape=None, - rotation=None, - crpix=None, - crval=None, + input_models, + pixel_scale_ratio=1.0, + pixel_scale=None, + output_shape=None, + rotation=None, + crpix=None, + crval=None, ): """ - Computes the WCS of the resampled image from input models and - specified WCS parameters. + Compute the WCS of the resampled image from input models and specified WCS parameters. Parameters ---------- - input_models : `~jwst.datamodel.ModelLibrary` Each datamodel must have a ``model.meta.wcs`` set to a ~gwcs.WCS object. - pixel_scale_ratio : float, optional Desired pixel scale ratio defined as the ratio of the desired output pixel scale to the first input model's pixel scale computed from this model's WCS at the fiducial point (taken as the ``ref_ra`` and ``ref_dec`` from the ``wcsinfo`` meta attribute of the first input image). Ignored when ``pixel_scale`` is specified. - pixel_scale : float, None, optional Desired pixel scale (in degrees) of the output WCS. When provided, overrides ``pixel_scale_ratio``. - output_shape : tuple of two integers (int, int), None, optional Shape of the image (data array) using ``np.ndarray`` convention (``ny`` first and ``nx`` second). This value will be assigned to ``pixel_shape`` and ``array_shape`` properties of the returned WCS object. - rotation : float, None, optional Position angle of output image's Y-axis relative to North. A value of 0.0 would orient the final output image to be North up. @@ -81,12 +70,10 @@ def resampled_wcs_from_models( camera with the x and y axes of the resampled image corresponding approximately to the detector axes. Ignored when ``transform`` is provided. - crpix : tuple of float, None, optional Position of the reference pixel in the resampled image array. If ``crpix`` is not specified, it will be set to the center of the bounding box of the returned WCS object. - crval : tuple of float, None, optional Right ascension and declination of the reference pixel. Automatically computed if not provided. @@ -95,16 +82,12 @@ def resampled_wcs_from_models( ------- wcs : ~gwcs.wcs.WCS The WCS object corresponding to the combined input footprints. - pscale_in : float Computed pixel scale (in degrees) of the first input image. - pscale_out : float Computed pixel scale (in degrees) of the output image. - pixel_scale_ratio : float Pixel scale ratio (output to input). - """ # build a list of WCS of all input models: sregion_list = [] @@ -126,33 +109,24 @@ def resampled_wcs_from_models( naxes = ref_wcs.output_frame.naxes if naxes != 2: raise UnsupportedWCSError( - "Output WCS needs 2 coordinate axes but the " - f"supplied WCS has {naxes} axes." + f"Output WCS needs 2 coordinate axes but the supplied WCS has {naxes} axes." ) if pixel_scale is None: # TODO: at some point we should switch to compute_mean_pixel_area # instead of compute_scale. pscale_in0 = compute_scale( - ref_wcs, - fiducial=np.array([ref_wcsinfo["ra_ref"], ref_wcsinfo["dec_ref"]]) + ref_wcs, fiducial=np.array([ref_wcsinfo["ra_ref"], ref_wcsinfo["dec_ref"]]) ) pixel_scale = pscale_in0 * pixel_scale_ratio - log.info( - f"Pixel scale ratio (pscale_out/pscale_in): {pixel_scale_ratio}" - ) + log.info(f"Pixel scale ratio (pscale_out/pscale_in): {pixel_scale_ratio}") log.info(f"Computed output pixel scale: {3600 * pixel_scale} arcsec.") else: - pscale_in0 = np.rad2deg( - math.sqrt(compute_mean_pixel_area(ref_wcs, shape=shape)) - ) + pscale_in0 = np.rad2deg(math.sqrt(compute_mean_pixel_area(ref_wcs, shape=shape))) pixel_scale_ratio = pixel_scale / pscale_in0 log.info(f"Output pixel scale: {3600 * pixel_scale} arcsec.") - log.info( - "Computed pixel scale ratio (pscale_out/pscale_in): " - f"{pixel_scale_ratio}." - ) + log.info(f"Computed pixel scale ratio (pscale_out/pscale_in): {pixel_scale_ratio}.") wcs = wcs_from_sregions( sregion_list, @@ -163,42 +137,44 @@ def resampled_wcs_from_models( rotation=rotation, shape=output_shape, crpix=crpix, - crval=crval + crval=crval, ) return wcs, pscale_in0, pixel_scale, pixel_scale_ratio -def make_output_wcs(input_models, ref_wcs=None, - pscale_ratio=None, pscale=None, rotation=None, shape=None, - crpix=None, crval=None): +def make_output_wcs( + input_models, + ref_wcs=None, + pscale_ratio=None, + pscale=None, + rotation=None, + shape=None, + crpix=None, + crval=None, +): """ + Generate output WCS here based on footprints of all input WCS objects. + .. deprecated:: 1.17.2 :py:func:`make_output_wcs` has been deprecated and will be removed in a future release. Use :py:func:`resampled_wcs_from_models` instead. - - Generate output WCS here based on footprints of all input WCS objects. - Parameters ---------- input_models : `~jwst.datamodel.ModelLibrary` The datamodels to combine into a single output WCS. Each datamodel must have a ``meta.wcs.s_region`` attribute. - ref_wcs : gwcs.WCS, None, optional Custom WCS to use as the output WCS. If not provided, the reference WCS will be taken as the WCS of the first input model, with its bounding box adjusted to encompass all input frames. - pscale_ratio : float, None, optional Ratio of input to output pixel scale. Ignored when ``pscale`` is provided. - pscale : float, None, optional Absolute pixel scale in degrees. When provided, overrides ``pscale_ratio``. - rotation : float, None, optional Position angle of output image Y-axis relative to North. A value of 0.0 would orient the final output image to be North up. @@ -206,18 +182,15 @@ def make_output_wcs(input_models, ref_wcs=None, but will instead be resampled in the default orientation for the camera with the x and y axes of the resampled image corresponding approximately to the detector axes. - shape : tuple of int, None, optional Shape of the image (data array) using ``numpy.ndarray`` convention (``ny`` first and ``nx`` second). This value will be assigned to ``pixel_shape`` and ``array_shape`` properties of the returned WCS object. - crpix : tuple of float, None, optional Position of the reference pixel in the image array. If ``crpix`` is not specified, it will be set to the center of the bounding box of the returned WCS object. - crval : tuple of float, None, optional Right ascension and declination of the reference pixel. Automatically computed if not provided. @@ -232,7 +205,7 @@ def make_output_wcs(input_models, ref_wcs=None, "will be removed in a future release. " "Use 'resampled_wcs_from_models()' instead.", DeprecationWarning, - stacklevel=2 + stacklevel=2, ) if ref_wcs is None: @@ -248,8 +221,7 @@ def make_output_wcs(input_models, ref_wcs=None, naxes = ref_wcs.output_frame.naxes if naxes != 2: - msg = ("Output WCS needs 2 spatial axes " - f"but the supplied WCS has {naxes} axes.") + msg = f"Output WCS needs 2 spatial axes but the supplied WCS has {naxes} axes." raise RuntimeError(msg) output_wcs = wcs_from_sregions( @@ -261,15 +233,14 @@ def make_output_wcs(input_models, ref_wcs=None, rotation=rotation, shape=shape, crpix=crpix, - crval=crval + crval=crval, ) del example_model else: naxes = ref_wcs.output_frame.naxes if naxes != 2: - msg = ("Output WCS needs 2 spatial axes " - f"but the supplied WCS has {naxes} axes.") + msg = f"Output WCS needs 2 spatial axes but the supplied WCS has {naxes} axes." raise RuntimeError(msg) output_wcs = deepcopy(ref_wcs) if shape is not None: @@ -285,75 +256,86 @@ def make_output_wcs(input_models, ref_wcs=None, def build_driz_weight(model, weight_type=None, good_bits=None): """ + Create a weight map for use by drizzle. + .. deprecated:: 1.17.2 :py:func:`build_driz_weight` has been deprecated and will be removed in a future release. Use :py:func:`stcal.utils.build_driz_weight` instead. - - - Create a weight map for use by drizzle - """ + """ # numpydoc ignore=RT01 warnings.warn( "'build_driz_weight()' has been deprecated since 1.17.2 and " "will be removed in a future release. " "Use 'stcal.utils.build_driz_weight()' instead.", DeprecationWarning, - stacklevel=2 - ) - return _stcal_build_driz_weight( - model=model, - weight_type=weight_type, - good_bits=good_bits + stacklevel=2, ) + return _stcal_build_driz_weight(model=model, weight_type=weight_type, good_bits=good_bits) def build_mask(dqarr, bitvalue): - """Build a bit mask from an input DQ array and a bitvalue flag + """ + Build a bit mask from an input DQ array and a bitvalue flag. + + Parameters + ---------- + dqarr : numpy.ndarray + Data quality array. + bitvalue : int + Bit value to be used for flagging good pixels. - In the returned bit mask, 1 is good, 0 is bad + Returns + ------- + numpy.ndarray + Bit mask, where 1 is good and 0 is bad. """ - return _stcal_build_mask( - dqarr=dqarr, - good_bits=bitvalue, - flag_name_map=pixel - ) + return _stcal_build_mask(dqarr=dqarr, good_bits=bitvalue, flag_name_map=pixel) def is_sky_like(frame): - """ Checks that a frame is a sky-like frame by looking at its output units. + """ + Check that a frame is a sky-like frame by looking at its output units. + If output units are either ``deg`` or ``arcsec`` the frame is considered a sky-like frame (as opposite to, e.g., a Cartesian frame.) + + Parameters + ---------- + frame : gwcs.WCS + WCS object to check. + + Returns + ------- + bool + ``True`` if the frame is sky-like, ``False`` otherwise. """ return u.Unit("deg") in frame.unit or u.Unit("arcsec") in frame.unit def decode_context(context, x, y): """ + Get 0-based indices of input images that contributed to (resampled) output pixel. + .. deprecated:: 1.17.2 :py:func:`decode_context` has been deprecated and will be removed in a future release. Use :py:func:`drizzle.utils.decode_context` instead. - - Get 0-based indices of input images that contributed to (resampled) - output pixel with coordinates ``x`` and ``y``. - Parameters ---------- - context: numpy.ndarray + context : numpy.ndarray A 3D `~numpy.ndarray` of integral data type. - - x: int, list of integers, numpy.ndarray of integers + x : int, list of integers, numpy.ndarray of integers X-coordinate of pixels to decode (3rd index into the ``context`` array) - - y: int, list of integers, numpy.ndarray of integers + y : int, list of integers, numpy.ndarray of integers Y-coordinate of pixels to decode (2nd index into the ``context`` array) Returns ------- - A list of `numpy.ndarray` objects each containing indices of input images - that have contributed to an output pixel with coordinates ``x`` and ``y``. - The length of returned list is equal to the number of input coordinate - arrays ``x`` and ``y``. + list + A list of `numpy.ndarray` objects each containing indices of input images + that have contributed to an output pixel with coordinates ``x`` and ``y``. + The length of returned list is equal to the number of input coordinate + arrays ``x`` and ``y``. Examples -------- @@ -363,34 +345,48 @@ def decode_context(context, x, y): >>> import numpy as np >>> from jwst.resample.resample_utils import decode_context >>> con = np.array( - ... [[[0, 0, 0, 0, 0, 0], - ... [0, 0, 0, 36196864, 0, 0], - ... [0, 0, 0, 0, 0, 0], - ... [0, 0, 0, 0, 0, 0], - ... [0, 0, 537920000, 0, 0, 0]], - ... [[0, 0, 0, 0, 0, 0,], - ... [0, 0, 0, 67125536, 0, 0], - ... [0, 0, 0, 0, 0, 0], - ... [0, 0, 0, 0, 0, 0], - ... [0, 0, 163856, 0, 0, 0]], - ... [[0, 0, 0, 0, 0, 0], - ... [0, 0, 0, 8203, 0, 0], - ... [0, 0, 0, 0, 0, 0], - ... [0, 0, 0, 0, 0, 0], - ... [0, 0, 32865, 0, 0, 0]]], - ... dtype=np.int32 + ... [ + ... [ + ... [0, 0, 0, 0, 0, 0], + ... [0, 0, 0, 36196864, 0, 0], + ... [0, 0, 0, 0, 0, 0], + ... [0, 0, 0, 0, 0, 0], + ... [0, 0, 537920000, 0, 0, 0], + ... ], + ... [ + ... [ + ... 0, + ... 0, + ... 0, + ... 0, + ... 0, + ... 0, + ... ], + ... [0, 0, 0, 67125536, 0, 0], + ... [0, 0, 0, 0, 0, 0], + ... [0, 0, 0, 0, 0, 0], + ... [0, 0, 163856, 0, 0, 0], + ... ], + ... [ + ... [0, 0, 0, 0, 0, 0], + ... [0, 0, 0, 8203, 0, 0], + ... [0, 0, 0, 0, 0, 0], + ... [0, 0, 0, 0, 0, 0], + ... [0, 0, 32865, 0, 0, 0], + ... ], + ... ], + ... dtype=np.int32, ... ) >>> decode_context(con, [3, 2], [1, 4]) [array([ 9, 12, 14, 19, 21, 25, 37, 40, 46, 58, 64, 65, 67, 77]), array([ 9, 20, 29, 36, 47, 49, 64, 69, 70, 79])] - """ warnings.warn( "'decode_context()' has been deprecated since 1.17.2 and " "will be removed in a future release. " "Use 'drizzle.utils.decode_context()' instead.", DeprecationWarning, - stacklevel=2 + stacklevel=2, ) return _drizzle_decode_context(context, x, y) @@ -407,10 +403,8 @@ def load_custom_wcs(asdf_wcs_file, output_shape=None): :py:func:`load_custom_wcs` are: ``"pixel_area"``, ``"pixel_scale"``, ``"pixel_shape"``, and ``"array_shape"``. The latter two are used only when the WCS object does not have the corresponding attributes set. - Pixel scale and pixel area should be provided in units of ``arcsec`` and ``arcsec**2``. - output_shape : tuple of int, optional Array shape for the output data. If not provided, the custom WCS must specify one of (in order of priority): @@ -425,7 +419,6 @@ def load_custom_wcs(asdf_wcs_file, output_shape=None): the ASDF file. If ``"pixel_area"`` is provided but ``"pixel_scale"`` is not then pixel scale will be computed from pixel area assuming square pixels: ``pixel_scale = sqrt(pixel_area)``. - """ if not asdf_wcs_file: return None @@ -439,8 +432,7 @@ def load_custom_wcs(asdf_wcs_file, output_shape=None): user_pixel_shape = af.tree.get("pixel_shape", None) user_array_shape = af.tree.get( - "array_shape", - None if user_pixel_shape is None else user_pixel_shape[::-1] + "array_shape", None if user_pixel_shape is None else user_pixel_shape[::-1] ) if output_shape is not None: @@ -450,8 +442,7 @@ def load_custom_wcs(asdf_wcs_file, output_shape=None): wcs.array_shape = user_array_shape elif getattr(wcs, "bounding_box", None) is not None: wcs.array_shape = tuple( - int(axs[1] + 0.5) - for axs in wcs.bounding_box.bounding_box(order="C") + int(axs[1] + 0.5) for axs in wcs.bounding_box.bounding_box(order="C") ) else: raise ValueError( @@ -469,23 +460,24 @@ def load_custom_wcs(asdf_wcs_file, output_shape=None): def find_miri_lrs_sregion(sregion_model1, wcs): - """ Find s region for MIRI LRS resampled data. + """ + Find s region for MIRI LRS resampled data. Parameters ---------- - sregion_model1 : string - s_regions of the first input model + sregion_model1 : str + The s_regions of the first input model wcs : gwcs.WCS Spatial/spectral WCS. Returns ------- - sregion : string - s_region for the resample data. + sregion : str + The s_region for the resample data. """ # use the first sregion to set the width of the slit spatial_box = sregion_model1 - s = spatial_box.split(' ') + s = spatial_box.split(" ") a1 = float(s[3]) b1 = float(s[4]) a2 = float(s[5]) @@ -496,13 +488,13 @@ def find_miri_lrs_sregion(sregion_model1, wcs): b4 = float(s[10]) # convert each corner to SkyCoord - coord1 = SkyCoord(a1, b1, unit='deg') - coord2 = SkyCoord(a2, b2, unit='deg') - coord3 = SkyCoord(a3, b3, unit='deg') - coord4 = SkyCoord(a4, b4, unit='deg') + coord1 = SkyCoord(a1, b1, unit="deg") + coord2 = SkyCoord(a2, b2, unit="deg") + coord3 = SkyCoord(a3, b3, unit="deg") + coord4 = SkyCoord(a4, b4, unit="deg") # Find the distance between the corners - # corners are counterclockwize from 1,2,3,4 + # corners are counter clockwize from 1,2,3,4 sep1 = coord1.separation(coord2) sep2 = coord2.separation(coord3) sep3 = coord3.separation(coord4) @@ -513,9 +505,9 @@ def find_miri_lrs_sregion(sregion_model1, wcs): # the minimum separation is the slit width min_sep = np.min(sep) - min_sep = min_sep* u.deg # set the units to degrees + min_sep = min_sep * u.deg # set the units to degrees - log.info(f'Estimated MIRI LRS slit width: {min_sep*3600} arcsec.') + log.info(f"Estimated MIRI LRS slit width: {min_sep * 3600} arcsec.") # now use the combined WCS to map all pixels to the slit center bbox = wcs.bounding_box grid = wcstools.grid_from_bounding_box(bbox) @@ -524,27 +516,30 @@ def find_miri_lrs_sregion(sregion_model1, wcs): dec = dec.flatten() # ra and dec are the values along the output resampled slit center # using the first point and last point find the position angle - star1 = SkyCoord(ra[0]*u.deg, dec[0]*u.deg, frame='icrs') - star2 = SkyCoord(ra[-1]*u.deg, dec[-1]*u.deg, frame='icrs') + star1 = SkyCoord(ra[0] * u.deg, dec[0] * u.deg, frame="icrs") + star2 = SkyCoord(ra[-1] * u.deg, dec[-1] * u.deg, frame="icrs") position_angle = star1.position_angle(star2).to(u.deg) # 90 degrees to the position angle of the slit will define s_region - pos_angle = position_angle - 90.0*u.deg + pos_angle = position_angle - 90.0 * u.deg - star_c1 = star1.directional_offset_by(pos_angle, min_sep/2) - star_c2 = star1.directional_offset_by(pos_angle, -min_sep/2) - star_c3 = star2.directional_offset_by(pos_angle, min_sep/2) - star_c4 = star2.directional_offset_by(pos_angle, -min_sep/2) + star_c1 = star1.directional_offset_by(pos_angle, min_sep / 2) + star_c2 = star1.directional_offset_by(pos_angle, -min_sep / 2) + star_c3 = star2.directional_offset_by(pos_angle, min_sep / 2) + star_c4 = star2.directional_offset_by(pos_angle, -min_sep / 2) # set these values to footprint # ra,dec corners are in counter-clockwise direction - footprint = [star_c1.ra.value, star_c1.dec.value, - star_c3.ra.value, star_c3.dec.value, - star_c4.ra.value, star_c4.dec.value, - star_c2.ra.value, star_c2.dec.value] + footprint = [ + star_c1.ra.value, + star_c1.dec.value, + star_c3.ra.value, + star_c3.dec.value, + star_c4.ra.value, + star_c4.dec.value, + star_c2.ra.value, + star_c2.dec.value, + ] footprint = np.array(footprint) s_region = compute_s_region_keyword(footprint) return s_region - - - diff --git a/jwst/resample/tests/test_interface.py b/jwst/resample/tests/test_interface.py index c17b326127..5f2fdb960e 100644 --- a/jwst/resample/tests/test_interface.py +++ b/jwst/resample/tests/test_interface.py @@ -13,5 +13,5 @@ def test_multi_integration_input(resample_class): cube.meta.observation.time = '10:32:20.181' # Resample can't handle cubes, so it should fail - with pytest.raises(RuntimeError): + with pytest.raises(TypeError): resample_class().call(cube)