diff --git a/jwst/assign_wcs/niriss.py b/jwst/assign_wcs/niriss.py index 2da04a5751..9d7aab559a 100644 --- a/jwst/assign_wcs/niriss.py +++ b/jwst/assign_wcs/niriss.py @@ -398,11 +398,11 @@ def wfss(input_model, reference_files): # Get the disperser parameters which are defined as a model for each # spectral order with NIRISSGrismModel(reference_files['specwcs']) as f: - dispx = f.dispx - dispy = f.dispy - displ = f.displ - invdispl = f.invdispl - orders = f.orders + dispx = f.dispx.instance + dispy = f.dispy.instance + displ = f.displ.instance + invdispl = f.invdispl.instance + orders = f.orders.instance fwcpos_ref = f.fwcpos_ref # This is the actual rotation from the input model diff --git a/jwst/wfss_contam/disperse.py b/jwst/wfss_contam/disperse.py index db05c81a95..1145dc8423 100644 --- a/jwst/wfss_contam/disperse.py +++ b/jwst/wfss_contam/disperse.py @@ -1,15 +1,123 @@ +from functools import partial import numpy as np +from typing import Callable, Sequence +from astropy.wcs import WCS from scipy.interpolate import interp1d +import warnings from ..lib.winclip import get_clipped_pixels from .sens1d import create_1d_sens -def dispersed_pixel(x0, y0, width, height, lams, flxs, order, wmin, wmax, - sens_waves, sens_resp, seg_wcs, grism_wcs, ID, naxis, - oversample_factor=2, extrapolate_sed=False, xoffset=0, - yoffset=0): +def flat_lam(fluxes: np.ndarray, lams: np.ndarray) -> np.ndarray: + ''' + Parameters + ---------- + x : float + x-coordinate of the pixel. + lams : float array + Array of wavelengths corresponding to the fluxes (flxs) for each pixel. + One wavelength per direct image, so can be a single value. + + Returns + ------- + lams : float array + Array of wavelengths corresponding to the fluxes (flxs) for each pixel. + One wavelength per direct image, so can be a single value. + ''' + return fluxes[0] + + +def flux_interpolator_injector(lams: np.ndarray, + flxs: np.ndarray, + extrapolate_sed: bool, + ) -> Callable[[float], float]: + ''' + Parameters + ---------- + lams : float array + Array of wavelengths corresponding to the fluxes (flxs) for each pixel. + One wavelength per direct image, so can be a single value. + flxs : float array + Array of fluxes (flam) for the pixels contained in x0, y0. If a single + direct image is in use, this will be a single value. + extrapolate_sed : bool + Whether to allow for the SED of the object to be extrapolated when it does not fully cover the + needed wavelength range. Default if False. + + Returns + ------- + flux : function + Function that returns the flux at a given wavelength. If only one direct image is in use, this + function will always return the same value + ''' + + if len(lams) > 1: + # If we have direct image flux values from more than one filter (lams), + # we have the option to extrapolate the fluxes outside the + # wavelength range of the direct images + if extrapolate_sed is False: + return interp1d(lams, flxs, fill_value=0., bounds_error=False) + else: + return interp1d(lams, flxs, fill_value="extrapolate", bounds_error=False) + else: + # If we only have flux from one wavelength, just use that + # single flux value at all wavelengths + return partial(flat_lam, flxs) + + +def determine_wl_spacing(dw: float, + lams: np.ndarray, + oversample_factor: int, + ) -> float: + ''' + Use a natural wavelength scale or the wavelength scale of the input SED/spectrum, + whichever is smaller, divided by oversampling requested + + Parameters + ---------- + dw : float + The natural wavelength scale of the grism image + lams : float array + Array of wavelengths corresponding to the fluxes (flxs) for each pixel. + One wavelength per direct image, so can be a single value. + oversample_factor : int + The amount of oversampling + + Returns + ------- + dlam : float + The wavelength spacing to use for the dispersed pixels + ''' + # + if len(lams) > 1: + input_dlam = np.median(lams[1:] - lams[:-1]) + if input_dlam < dw: + return input_dlam / oversample_factor + return dw / oversample_factor + + +def dispersed_pixel(x0: np.ndarray, + y0: np.ndarray, + width: float, + height: float, + lams: np.ndarray, + flxs: np.ndarray, + order: int, + wmin: float, + wmax: float, + sens_waves: np.ndarray, + sens_resp: np.ndarray, + seg_wcs: WCS, + grism_wcs: WCS, + ID: int, + naxis: Sequence[int], + oversample_factor: int = 2, + extrapolate_sed: bool = False, + xoffset: float = 0, + yoffset: float = 0, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, int]: """ This function take a list of pixels and disperses them using the information contained in the grism image WCS object and returns a list of dispersed pixels and fluxes. @@ -83,20 +191,8 @@ def dispersed_pixel(x0, y0, width, height, lams, flxs, order, wmin, wmax, sky_to_imgxy = grism_wcs.get_transform('world', 'detector') imgxy_to_grismxy = grism_wcs.get_transform('detector', 'grism_detector') - # Setup function for retrieving flux values at each dispersed wavelength - if len(lams) > 1: - # If we have direct image flux values from more than one filter (lambda), - # we have the option to extrapolate the fluxes outside the - # wavelength range of the direct images - if extrapolate_sed is False: - flux = interp1d(lams, flxs, fill_value=0., bounds_error=False) - else: - flux = interp1d(lams, flxs, fill_value="extrapolate", bounds_error=False) - else: - # If we only have flux from one lambda, just use that - # single flux value at all wavelengths - def flux(x): - return flxs[0] + # Set up function for retrieving flux values at each dispersed wavelength + flux_interpolator = flux_interpolator_injector(lams, flxs, extrapolate_sed) # Get x/y positions in the grism image corresponding to wmin and wmax: # Start with RA/Dec of the input pixel position in segmentation map, @@ -110,19 +206,9 @@ def flux(x): dxw = xwmax - xwmin dyw = ywmax - ywmin - # Compute the delta-wave per pixel - dw = np.abs((wmax - wmin) / (dyw - dxw)) - - # Use a natural wavelength scale or the wavelength scale of the input SED/spectrum, - # whichever is smaller, divided by oversampling requested - input_dlam = np.median(lams[1:] - lams[:-1]) - if input_dlam < dw: - dlam = input_dlam / oversample_factor - else: - # this value gets used when we only have 1 direct image wavelength - dlam = dw / oversample_factor - # Create list of wavelengths on which to compute dispersed pixels + dw = np.abs((wmax - wmin) / (dyw - dxw)) + dlam = determine_wl_spacing(dw, lams, oversample_factor) lambdas = np.arange(wmin, wmax + dlam, dlam) n_lam = len(lambdas) @@ -161,7 +247,11 @@ def flux(x): # values are naturally in units of physical fluxes, so we divide out # the sensitivity (flux calibration) values to convert to units of # countrate (DN/s). - counts = flux(lams) * areas / sens + # flux_interpolator(lams) is either single-valued (for a single direct image) + # or an array of the same length as lams (for multiple direct images in different filters) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=RuntimeWarning, message="divide by zero") + counts = flux_interpolator(lams) * areas / (sens * oversample_factor) counts[no_cal] = 0. # set to zero where no flux cal info available return xs, ys, areas, lams, counts, ID diff --git a/jwst/wfss_contam/observations.py b/jwst/wfss_contam/observations.py index 62b66dfcf6..0f02a9d21d 100644 --- a/jwst/wfss_contam/observations.py +++ b/jwst/wfss_contam/observations.py @@ -1,25 +1,126 @@ +import copy import time -import multiprocessing import numpy as np +import multiprocessing as mp from scipy import sparse from stdatamodels.jwst import datamodels +from astropy.wcs import WCS +from typing import Sequence from .disperse import dispersed_pixel import logging +from photutils.background import Background2D, MedianBackground +from astropy.stats import SigmaClip + log = logging.getLogger(__name__) log.setLevel(logging.DEBUG) +def disperse_multiprocess(pars, max_cpu): + + pars = copy.deepcopy(pars) + ctx = mp.get_context("forkserver") + with ctx.Pool(max_cpu) as mypool: + all_res = mypool.starmap(dispersed_pixel, pars) + + return all_res + + +def background_subtract(data: np.ndarray, + box_size: tuple = None, + filter_size: tuple = (3,3), + sigma: float = 3.0, + exclude_percentile: float = 30.0, + ) -> np.ndarray: + """ + Simple astropy background subtraction + + Parameters + ---------- + data : np.ndarray + 2D array of pixel values + box_size : tuple + Size of box in pixels to use for background estimation. + If not set, defaults to 1/5 of the image size. + filter_size : tuple + Size of filter to use for background estimation + sigma : float + Sigma threshold for background clipping + exclude_percentile : float + Percentage of masked pixels above which box is excluded from background estimation + + Returns + ------- + data : np.ndarray + 2D array of pixel values with background subtracted + + Notes + ----- + Improper background subtraction in input _i2d image leads to extra flux + in the simulated dispersed image, and was one cause of flux scaling issues + in a previous version. + """ + if box_size is None: + box_size = (int(data.shape[0]/5), int(data.shape[1]/5)) + sigma_clip = SigmaClip(sigma=sigma) + bkg_estimator = MedianBackground() + bkg = Background2D(data, box_size, filter_size=filter_size, + sigma_clip=sigma_clip, bkg_estimator=bkg_estimator, + exclude_percentile=exclude_percentile) + + return data - bkg.background + + +def _select_ids(ID: int, all_IDs: list[int]) -> list[int]: + ''' + Select the source IDs to be processed based on the input ID parameter. + + Parameters + ---------- + ID : int or list-like + ID(s) of source to process. If None, all sources processed. + all_IDs : np.ndarray + Array of all source IDs in the segmentation map + + Returns + ------- + selected_IDs : list + List of selected source IDs + ''' + if ID is None: + log.info(f"Loading all {len(all_IDs)} sources from segmentation map") + return all_IDs + + elif isinstance(ID, int): + log.info(f"Loading single source {ID} from segmentation map") + return [ID] + + elif isinstance(ID, list) or isinstance(ID, np.ndarray): + log.info(f"Loading {len(ID)} of {len(all_IDs)} selected sources from segmentation map") + return list(ID) + else: + raise ValueError("ID must be an integer or a list of integers") + class Observation: """This class defines an actual observation. It is tied to a single grism image.""" - def __init__(self, direct_images, segmap_model, grism_wcs, filter, ID=0, - sed_file=None, extrapolate_sed=False, - boundaries=[], offsets=[0, 0], renormalize=True, max_cpu=1): + def __init__(self, + direct_images: list[str], + segmap_model: datamodels.SegmentationMapModel, + grism_wcs: WCS, + filter: str, + ID: int = None, + sed_file: str = None, + extrapolate_sed: bool = False, + boundaries: Sequence = [], + offsets: Sequence = [0, 0], + renormalize: bool = True, + max_cpu: int = 1, + ) -> None: """ Initialize all data and metadata for a given observation. Creates lists of @@ -35,8 +136,8 @@ def __init__(self, direct_images, segmap_model, grism_wcs, filter, ID=0, WCS object from grism image filter : str Filter name - ID : int - ID of source to process. If zero, all sources processed. + ID : int or list-like, optional + ID(s) of source to process. If zero, all sources processed. sed_file : str Name of Spectral Energy Distribution (SED) file containing datasets matching the ID in the segmentation file and each consisting of a [[lambda],[flux]] array. @@ -54,9 +155,10 @@ def __init__(self, direct_images, segmap_model, grism_wcs, filter, ID=0, self.seg_wcs = segmap_model.meta.wcs self.grism_wcs = grism_wcs self.ID = ID - self.IDs = [] self.dir_image_names = direct_images self.seg = segmap_model.data + all_ids = np.array(list(set(np.ravel(self.seg)))) + self.IDs = _select_ids(ID, all_ids) self.filter = filter self.sed_file = sed_file # should always be NONE for baseline pipeline (use flat SED) self.cache = False @@ -69,9 +171,9 @@ def __init__(self, direct_images, segmap_model, grism_wcs, filter, ID=0, if len(boundaries) == 0: log.debug("No boundaries passed.") self.xstart = 0 - self.xend = self.xstart + self.dims[0] - 1 + self.xend = self.xstart + self.seg.shape[0] - 1 self.ystart = 0 - self.yend = self.ystart + self.dims[1] - 1 + self.yend = self.ystart + self.seg.shape[1] - 1 else: self.xstart, self.xend, self.ystart, self.yend = boundaries self.dims = (self.yend - self.ystart + 1, self.xend - self.xstart + 1) @@ -85,33 +187,24 @@ def __init__(self, direct_images, segmap_model, grism_wcs, filter, ID=0, # Create pixel lists for sources labeled in segmentation map self.create_pixel_list() - def create_pixel_list(self): - # Create a list of pixels to be dispersed, grouped per object ID. - - if self.ID == 0: - # When ID=0, all sources in the segmentation map are processed. - # This creates a huge list of all x,y pixel indices that have non-zero values - # in the seg map, sorted by those indices belonging to a particular source ID. - self.xs = [] - self.ys = [] - all_IDs = np.array(list(set(np.ravel(self.seg)))) - all_IDs = all_IDs[all_IDs > 0] - self.IDs = all_IDs - log.info(f"Loading {len(all_IDs)} sources from segmentation map") - for ID in all_IDs: - ys, xs = np.nonzero(self.seg == ID) - if len(xs) > 0 and len(ys) > 0: - self.xs.append(xs) - self.ys.append(ys) + # Initialize the list of slits + self.simul_slits = datamodels.MultiSlitModel() + self.simul_slits_order = [] + self.simul_slits_sid = [] - else: - # Process only the given source ID - log.info(f"Loading source {self.ID} from segmentation map") - ys, xs = np.nonzero(self.seg == self.ID) + def create_pixel_list(self): + ''' + Create a list of pixels to be dispersed, grouped per object ID. + When ID is None, all sources in the segmentation map are processed. + ''' + + self.xs = [] + self.ys = [] + for ID in self.IDs: + ys, xs = np.nonzero(self.seg == ID) if len(xs) > 0 and len(ys) > 0: - self.xs = [xs] - self.ys = [ys] - self.IDs = [self.ID] + self.xs.append(xs) + self.ys.append(ys) # Populate lists of direct image flux values for the sources. self.fluxes = {} @@ -120,6 +213,7 @@ def create_pixel_list(self): log.info(f"Using direct image {dir_image_name}") with datamodels.open(dir_image_name) as model: dimage = model.data + dimage = background_subtract(dimage) if self.sed_file is None: # Default pipeline will use sed_file=None, so we need to compute @@ -150,7 +244,13 @@ def create_pixel_list(self): for i in range(len(self.IDs)): self.fluxes["sed"].append(dnew[self.ys[i], self.xs[i]]) - def disperse_all(self, order, wmin, wmax, sens_waves, sens_resp, cache=False): + def disperse_all(self, + order: int, + wmin: float, + wmax: float, + sens_waves: np.ndarray, + sens_resp:np.ndarray, + cache=False): """ Compute dispersed pixel values for all sources identified in the segmentation map. @@ -177,7 +277,9 @@ def disperse_all(self, order, wmin, wmax, sens_waves, sens_resp, cache=False): self.simulated_image = np.zeros(self.dims, float) # Loop over all source ID's from segmentation map + pool_args = [] for i in range(len(self.IDs)): + if self.cache: self.cached_object[i] = {} self.cached_object[i]['x'] = [] @@ -189,13 +291,42 @@ def disperse_all(self, order, wmin, wmax, sens_waves, sens_resp, cache=False): self.cached_object[i]['miny'] = [] self.cached_object[i]['maxy'] = [] - self.disperse_chunk(i, order, wmin, wmax, sens_waves, sens_resp) + disperse_chunk_args = [i, order, wmin, wmax, sens_waves, sens_resp,] + pool_args.append(disperse_chunk_args) + + t0 = time.time() + if self.max_cpu > 1: + # put this log message here to avoid printing it for every chunk + log.info(f"Using multiprocessing with {self.max_cpu} cores to compute dispersion") - def disperse_chunk(self, c, order, wmin, wmax, sens_waves, sens_resp): + disperse_chunk_output = [] + for i in range(len(self.IDs)): + disperse_chunk_output.append(self.disperse_chunk(*pool_args[i])) + t1 = time.time() + log.info(f"Wall clock time for disperse_chunk order {order}: {(t1-t0):.1f} sec") + + # Collect results into simulated image and slit models + for i, this_output in enumerate(disperse_chunk_output): + [this_image, this_bounds, this_sid, this_order] = this_output + slit = self.construct_slitmodel_for_chunk(this_image, this_bounds, this_sid, this_order) + self.simulated_image += this_image + if slit is not None: + self.simul_slits.slits.append(slit) + self.simul_slits_order.append(this_order) + self.simul_slits_sid.append(this_sid) + + + def disperse_chunk(self, + c: int, + order: int, + wmin: float, + wmax: float, + sens_waves: np.ndarray, + sens_resp: np.ndarray, + ) -> tuple[np.ndarray, list, int, int]: """ Method that computes dispersion for a single source. To be called after create_pixel_list(). - Parameters ---------- c : int @@ -210,6 +341,17 @@ def disperse_chunk(self, c, order, wmin, wmax, sens_waves, sens_resp): Wavelength array from photom reference file sens_resp : float array Response (flux calibration) array from photom reference file + + Returns + ------- + this_object : np.ndarray + 2D array of dispersed pixel values for the source + thisobj_bounds : list + [minx, maxx, miny, maxy] bounds of the object + sid : int + Source ID + order : int + Spectral order number """ sid = int(self.IDs[c]) @@ -224,23 +366,18 @@ def disperse_chunk(self, c, order, wmin, wmax, sens_waves, sens_resp): # Loop over all pixels in list for object "c" log.debug(f"source contains {len(self.xs[c])} pixels") for i in range(len(self.xs[c])): - - # Here "i" and "ID" are just indexes into the pixel list for the object + # Here "i" is just an index into the pixel list for the object # being processed, as opposed to the ID number of the object itself - ID = i - # xc, yc are the coordinates of the central pixel of the group # of pixels surrounding the direct image pixel index width = 1.0 height = 1.0 xc = self.xs[c][i] + 0.5 * width yc = self.ys[c][i] + 0.5 * height - # "lams" is the array of wavelengths previously stored in flux list # and correspond to the central wavelengths of the filters used in # the input direct image(s). For the simple case of 1 combined direct image, # this contains a single value (e.g. 4.44 for F444W). - # "fluxes" is the array of pixel values from the direct image(s). # For the simple case of 1 combined direct image, this contains a # a single value (just like "lams"). @@ -248,21 +385,18 @@ def disperse_chunk(self, c, order, wmin, wmax, sens_waves, sens_resp): (self.fluxes[lm][c][i], lm) for lm in sorted(self.fluxes.keys()) if self.fluxes[lm][c][i] != 0 ])) - - pars_i = (xc, yc, width, height, lams, fluxes, self.order, + pars_i = [xc, yc, width, height, lams, fluxes, self.order, self.wmin, self.wmax, self.sens_waves, self.sens_resp, - self.seg_wcs, self.grism_wcs, ID, self.dims[::-1], 2, - self.extrapolate_sed, self.xoffset, self.yoffset) - + self.seg_wcs, self.grism_wcs, i, self.dims[::-1], 2, + self.extrapolate_sed, self.xoffset, self.yoffset] pars.append(pars_i) - # now have full pars list for all pixels for this object + #if i == 0: + # print([type(arg) for arg in pars_i]) #all these need to be pickle-able + # pass parameters into dispersed_pixel, either using multiprocessing or not time1 = time.time() if self.max_cpu > 1: - ctx = multiprocessing.get_context("forkserver") - mypool = ctx.Pool(self.max_cpu) # Create the pool - all_res = mypool.imap_unordered(dispersed_pixel, pars) # Fill the pool - mypool.close() # Drain the pool + all_res = disperse_multiprocess(pars, self.max_cpu) else: all_res = [] for i in range(len(pars)): @@ -270,8 +404,8 @@ def disperse_chunk(self, c, order, wmin, wmax, sens_waves, sens_resp): # Initialize blank image for this source this_object = np.zeros(self.dims, float) - nres = 0 + bounds = [] for pp in all_res: if pp is None: continue @@ -289,10 +423,10 @@ def disperse_chunk(self, c, order, wmin, wmax, sens_waves, sens_resp): maxy = int(max(y)) a = sparse.coo_matrix((f, (y - miny, x - minx)), shape=(maxy - miny + 1, maxx - minx + 1)).toarray() - + # Accumulate results into simulated images - self.simulated_image[miny:maxy + 1, minx:maxx + 1] += a this_object[miny:maxy + 1, minx:maxx + 1] += a + bounds.append([minx, maxx, miny, maxy]) if self.cache: self.cached_object[c]['x'].append(x) @@ -306,56 +440,119 @@ def disperse_chunk(self, c, order, wmin, wmax, sens_waves, sens_resp): time2 = time.time() log.debug(f"Elapsed time {time2-time1} sec") + # figure out global bounds of object + if len(bounds) > 0: + bounds = np.array(bounds) + thisobj_minx = int(np.min(bounds[:, 0])) + thisobj_maxx = int(np.max(bounds[:, 1])) + thisobj_miny = int(np.min(bounds[:, 2])) + thisobj_maxy = int(np.max(bounds[:, 3])) + thisobj_bounds = [thisobj_minx, thisobj_maxx, thisobj_miny, thisobj_maxy] + return (this_object, thisobj_bounds, sid, order) + return (this_object, None, sid, order) + + + @staticmethod + def construct_slitmodel_for_chunk(chunk_data: np.ndarray, + bounds: list, + sid: int, + order: int, + ) -> datamodels.SlitModel: + ''' + Parameters + ---------- + chunk_data : np.ndarray + Dispersed model of segmentation map source + bounds : list + The bounds of the object + sid : int + The source ID + order : int + The spectral order number - return this_object + Returns + ------- + slit : `jwst.datamodels.SlitModel` + Slit model containing the dispersed pixel values + ''' + if bounds is None: + return None + + [thisobj_minx, thisobj_maxx, thisobj_miny, thisobj_maxy] = bounds + slit = datamodels.SlitModel() + slit.source_id = sid + slit.name = f"source_{sid}" + slit.xstart = thisobj_minx + slit.xsize = thisobj_maxx - thisobj_minx + 1 + slit.ystart = thisobj_miny + slit.ysize = thisobj_maxy - thisobj_miny + 1 + slit.meta.wcsinfo.spectral_order = order + slit.data = chunk_data[thisobj_miny:thisobj_maxy + 1, thisobj_minx:thisobj_maxx + 1] - def disperse_all_from_cache(self, trans=None): - if not self.cache: - return + return slit - self.simulated_image = np.zeros(self.dims, float) - for i in range(len(self.IDs)): - this_object = self.disperse_chunk_from_cache(i, trans=trans) - return this_object +# class ObservationFromCache: +# ''' +# this isn't how it should work. If we're going to use a cache, we should +# be checking if a pixel is in the cache before dispersing it. Then load it if it's there, +# otherwise calculate it. The functions below need to be refactored. +# ''' - def disperse_chunk_from_cache(self, c, trans=None): - """Method that handles the dispersion. To be called after create_pixel_list()""" +# def __init__(self, cache, dims): +# self.cache = cache +# self.dims = dims +# self.simulated_image = np.zeros(self.dims, float) +# self.cached_object = {} - if not self.cache: - return +# def disperse_all_from_cache(self, trans=None): +# if not self.cache: +# return - time1 = time.time() +# self.simulated_image = np.zeros(self.dims, float) - # Initialize blank image for this object - this_object = np.zeros(self.dims, float) +# for i in range(len(self.IDs)): +# this_object = self.disperse_chunk_from_cache(i, trans=trans) - if trans is not None: - log.debug("Applying a transmission function...") +# return this_object - for i in range(len(self.cached_object[c]['x'])): - x = self.cached_object[c]['x'][i] - y = self.cached_object[c]['y'][i] - f = self.cached_object[c]['f'][i] * 1. - w = self.cached_object[c]['w'][i] +# def disperse_chunk_from_cache(self, c, trans=None): +# """Method that handles the dispersion. To be called after create_pixel_list()""" - if trans is not None: - f *= trans(w) +# if not self.cache: +# return - minx = self.cached_object[c]['minx'][i] - maxx = self.cached_object[c]['maxx'][i] - miny = self.cached_object[c]['miny'][i] - maxy = self.cached_object[c]['maxy'][i] +# time1 = time.time() - a = sparse.coo_matrix((f, (y - miny, x - minx)), - shape=(maxy - miny + 1, maxx - minx + 1)).toarray() +# # Initialize blank image for this object +# this_object = np.zeros(self.dims, float) - # Accumulate the results into the simulated images - self.simulated_image[miny:maxy + 1, minx:maxx + 1] += a - this_object[miny:maxy + 1, minx:maxx + 1] += a +# if trans is not None: +# log.debug("Applying a transmission function...") - time2 = time.time() - log.debug(f"Elapsed time {time2-time1} sec") +# for i in range(len(self.cached_object[c]['x'])): +# x = self.cached_object[c]['x'][i] +# y = self.cached_object[c]['y'][i] +# f = self.cached_object[c]['f'][i] * 1. +# w = self.cached_object[c]['w'][i] + +# if trans is not None: +# f *= trans(w) + +# minx = self.cached_object[c]['minx'][i] +# maxx = self.cached_object[c]['maxx'][i] +# miny = self.cached_object[c]['miny'][i] +# maxy = self.cached_object[c]['maxy'][i] + +# a = sparse.coo_matrix((f, (y - miny, x - minx)), +# shape=(maxy - miny + 1, maxx - minx + 1)).toarray() + +# # Accumulate the results into the simulated images +# self.simulated_image[miny:maxy + 1, minx:maxx + 1] += a +# this_object[miny:maxy + 1, minx:maxx + 1] += a + +# time2 = time.time() +# log.debug(f"Elapsed time {time2-time1} sec") - return this_object +# return this_object diff --git a/jwst/wfss_contam/tests/__init__.py b/jwst/wfss_contam/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/jwst/wfss_contam/tests/data/__init__.py b/jwst/wfss_contam/tests/data/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/jwst/wfss_contam/tests/data/grism_wcs.asdf b/jwst/wfss_contam/tests/data/grism_wcs.asdf new file mode 100644 index 0000000000..8334394f81 --- /dev/null +++ b/jwst/wfss_contam/tests/data/grism_wcs.asdf @@ -0,0 +1,1180 @@ +#ASDF 1.0.0 +#ASDF_STANDARD 1.5.0 +%YAML 1.1 +%TAG ! tag:stsci.edu:asdf/ +--- !core/asdf-1.1.0 +asdf_library: !core/software-1.0.0 {author: The ASDF Developers, homepage: 'http://github.com/asdf-format/asdf', + name: asdf, version: 3.1.1.dev2+g15e830d} +history: + extensions: + - !core/extension_metadata-1.0.0 + extension_class: asdf.extension._manifest.ManifestExtension + extension_uri: asdf://asdf-format.org/astronomy/gwcs/extensions/gwcs-1.2.0 + software: !core/software-1.0.0 {name: gwcs, version: 0.21.0} + - !core/extension_metadata-1.0.0 + extension_class: asdf.extension._manifest.ManifestExtension + extension_uri: asdf://asdf-format.org/astronomy/coordinates/extensions/coordinates-1.0.0 + software: !core/software-1.0.0 {name: asdf-astropy, version: 0.5.0} + - !core/extension_metadata-1.0.0 + extension_class: asdf.extension._manifest.ManifestExtension + extension_uri: asdf://asdf-format.org/core/extensions/core-1.5.0 + software: !core/software-1.0.0 {name: asdf, version: 3.1.1.dev2+g15e830d} + - !core/extension_metadata-1.0.0 + extension_class: asdf.extension._manifest.ManifestExtension + extension_uri: asdf://asdf-format.org/transform/extensions/transform-1.5.0 + software: !core/software-1.0.0 {name: asdf-astropy, version: 0.5.0} + - !core/extension_metadata-1.0.0 + extension_class: asdf_astropy._manifest.CompoundManifestExtension + extension_uri: asdf://astropy.org/core/extensions/core-1.5.0 + software: !core/software-1.0.0 {name: asdf-astropy, version: 0.5.0} + - !core/extension_metadata-1.0.0 + extension_class: asdf.extension._manifest.ManifestExtension + extension_uri: asdf://stsci.edu/jwst_pipeline/extensions/jwst_transforms-1.0.0 + software: !core/software-1.0.0 {name: stdatamodels, version: 1.10.1} +wcs: ! + name: '' + pixel_shape: null + steps: + - ! + frame: ! + axes_names: [x_grism, y_grism] + axes_order: [0, 1] + axis_physical_types: ['custom:x_grism', 'custom:y_grism'] + name: grism_detector + unit: [!unit/unit-1.0.0 pixel, !unit/unit-1.0.0 pixel] + transform: !transform/compose-1.2.0 + bounding_box: !transform/property/bounding_box-1.0.0 + ignore: [] + intervals: + x0: [-0.5, 319.5] + x1: [-0.5, 340.5] + order: C + forward: + - !transform/compose-1.2.0 + forward: + - !transform/remap_axes-1.3.0 + inputs: [x0, x1] + mapping: [0, 1, 0, 0, 0] + outputs: [x0, x1, x2, x3, x4] + - !transform/concatenate-1.2.0 + forward: + - !transform/concatenate-1.2.0 + forward: + - !transform/concatenate-1.2.0 + forward: + - !transform/concatenate-1.2.0 + forward: + - !transform/shift-1.2.0 + inputs: [x] + offset: 122.0 + outputs: [y] + - !transform/shift-1.2.0 + inputs: [x] + offset: 1031.0 + outputs: [y] + inputs: [x0, x1] + outputs: [y0, y1] + - !transform/constant-1.4.0 + dimensions: 1 + inputs: [x] + inverse: !transform/constant-1.4.0 + dimensions: 1 + inputs: [x] + outputs: [y] + value: 482.26565028841355 + outputs: [y] + value: 482.26565028841355 + inputs: [x0, x1, x] + outputs: [y0, y1, y] + - !transform/constant-1.4.0 + dimensions: 1 + inputs: [x] + inverse: !transform/constant-1.4.0 + dimensions: 1 + inputs: [x] + outputs: [y] + value: 1205.025009833007 + outputs: [y] + value: 1205.025009833007 + inputs: [x00, x10, x0, x1] + outputs: [y00, y10, y0, y1] + - !transform/constant-1.4.0 + dimensions: 1 + inputs: [x] + inverse: !transform/constant-1.4.0 + dimensions: 1 + inputs: [x] + outputs: [y] + value: 1.0 + outputs: [y] + value: 1.0 + inputs: [x00, x10, x0, x1, x] + outputs: [y00, y10, y0, y1, y] + inputs: [x0, x1] + outputs: [y00, y10, y0, y1, y] + - !transform/compose-1.2.0 + forward: + - !transform/compose-1.2.0 + forward: + - ! + inputs: [x, y, x0, y0, order] + inverse: ! + inputs: [x, y, wavelength, order] + lmodels: + - !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 35 + datatype: float64 + byteorder: little + shape: [2] + domain: [-1, 1] + inputs: [x] + outputs: [y] + window: [-1, 1] + - !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 36 + datatype: float64 + byteorder: little + shape: [2] + domain: [-1, 1] + inputs: [x] + outputs: [y] + window: [-1, 1] + - !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 37 + datatype: float64 + byteorder: little + shape: [2] + domain: [-1, 1] + inputs: [x] + outputs: [y] + window: [-1, 1] + - !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 38 + datatype: float64 + byteorder: little + shape: [2] + domain: [-1, 1] + inputs: [x] + outputs: [y] + window: [-1, 1] + - !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 39 + datatype: float64 + byteorder: little + shape: [2] + domain: [-1, 1] + inputs: [x] + outputs: [y] + window: [-1, 1] + model_type: NIRISSBackwardGrismDispersion + name: niriss_backward_grism_dispersion + orders: [1, 2, 3, -1, 0] + outputs: [x, y, x0, y0, order] + theta: 0.006000000000653927 + xmodels: + - - &id001 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 0 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id002 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 1 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id003 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 2 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - - &id004 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 3 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id005 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 4 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id006 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 5 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - - &id007 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 6 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id008 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 7 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id009 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 8 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - - &id010 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 9 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id011 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 10 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id012 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 11 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - - &id013 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 12 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id014 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 13 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id015 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 14 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + ymodels: + - - &id016 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 15 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id017 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 16 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id018 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 17 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - - &id019 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 18 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id020 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 19 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id021 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 20 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - - &id022 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 21 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id023 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 22 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id024 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 23 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - - &id025 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 24 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id026 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 25 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id027 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 26 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - - &id028 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 27 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id029 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 28 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - &id030 !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 29 + datatype: float64 + byteorder: little + shape: [3, 3] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + lmodels: + - !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 30 + datatype: float64 + byteorder: little + shape: [2] + domain: [-1, 1] + inputs: [x] + outputs: [y] + window: [-1, 1] + - !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 31 + datatype: float64 + byteorder: little + shape: [2] + domain: [-1, 1] + inputs: [x] + outputs: [y] + window: [-1, 1] + - !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 32 + datatype: float64 + byteorder: little + shape: [2] + domain: [-1, 1] + inputs: [x] + outputs: [y] + window: [-1, 1] + - !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 33 + datatype: float64 + byteorder: little + shape: [2] + domain: [-1, 1] + inputs: [x] + outputs: [y] + window: [-1, 1] + - !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 34 + datatype: float64 + byteorder: little + shape: [2] + domain: [-1, 1] + inputs: [x] + outputs: [y] + window: [-1, 1] + model_type: NIRISSForwardRowGrismDispersion + name: niriss_forward_row_grism_dispersion + orders: [1, 2, 3, -1, 0] + outputs: [x, y, wavelength, order] + theta: -0.006000000000653927 + xmodels: + - - *id001 + - *id002 + - *id003 + - - *id004 + - *id005 + - *id006 + - - *id007 + - *id008 + - *id009 + - - *id010 + - *id011 + - *id012 + - - *id013 + - *id014 + - *id015 + ymodels: + - - *id016 + - *id017 + - *id018 + - - *id019 + - *id020 + - *id021 + - - *id022 + - *id023 + - *id024 + - - *id025 + - *id026 + - *id027 + - - *id028 + - *id029 + - *id030 + - !transform/remap_axes-1.3.0 + inputs: [x0, x1, x2, x3] + mapping: [0, 1, 2, 3] + outputs: [x0, x1, x2, x3] + inputs: [x, y, x0, y0, order] + outputs: [x0, x1, x2, x3] + - !transform/concatenate-1.2.0 + forward: + - !transform/concatenate-1.2.0 + forward: + - !transform/identity-1.2.0 + inputs: [x0, x1] + n_dims: 2 + outputs: [x0, x1] + - !transform/multiply-1.2.0 + forward: + - !transform/identity-1.2.0 + inputs: [x0] + outputs: [x0] + - !transform/constant-1.4.0 + dimensions: 1 + inputs: [x] + name: velocity_correction + outputs: [y] + value: 0.99999381609348 + inputs: [x0] + inverse: !transform/divide-1.2.0 + forward: + - !transform/identity-1.2.0 + inputs: [x0] + outputs: [x0] + - !transform/constant-1.4.0 + dimensions: 1 + inputs: [x] + name: inv_vel_correction + outputs: [y] + value: 0.99999381609348 + inputs: [x0] + outputs: [x0] + outputs: [x0] + inputs: [x00, x10, x01] + outputs: [x00, x10, x01] + - !transform/identity-1.2.0 + inputs: [x0] + outputs: [x0] + inputs: [x00, x10, x01, x0] + outputs: [x00, x10, x01, x0] + inputs: [x, y, x0, y0, order] + outputs: [x00, x10, x01, x0] + inputs: [x0, x1] + outputs: [x00, x10, x01, x0] + - ! + frame: ! + frames: + - ! + axes_names: [x, y] + axes_order: [0, 1] + axis_physical_types: ['custom:x', 'custom:y'] + name: detectorspatial + unit: [!unit/unit-1.0.0 pixel, !unit/unit-1.0.0 pixel] + - &id033 ! + axes_names: [wavelength] + axes_order: [2] + axis_physical_types: [em.wl] + name: spectral + unit: [!unit/unit-1.0.0 um] + name: detector + transform: !transform/concatenate-1.2.0 + forward: + - !transform/compose-1.2.0 + bounding_box: !transform/property/bounding_box-1.0.0 + ignore: [] + intervals: + x0: [-0.5, 2047.5] + x1: [-0.5, 2047.5] + order: C + forward: + - !transform/concatenate-1.2.0 + forward: + - !transform/shift-1.2.0 + inputs: [x] + offset: 2.119 + outputs: [y] + - !transform/shift-1.2.0 + inputs: [x] + offset: -1.0476 + outputs: [y] + inputs: [x0, x1] + outputs: [y0, y1] + - !transform/compose-1.2.0 + forward: + - !transform/compose-1.2.0 + forward: + - !transform/compose-1.2.0 + forward: + - !transform/compose-1.2.0 + forward: + - !transform/compose-1.2.0 + forward: + - !transform/compose-1.2.0 + forward: + - !transform/compose-1.2.0 + forward: + - !transform/concatenate-1.2.0 + forward: + - !transform/shift-1.2.0 + inputs: [x] + offset: -1023.5 + outputs: [y] + - !transform/shift-1.2.0 + inputs: [x] + offset: -1023.5 + outputs: [y] + inputs: [x0, x1] + outputs: [y0, y1] + - &id031 !transform/remap_axes-1.3.0 + inputs: [x0, x1] + inverse: !transform/identity-1.2.0 + inputs: [x0, x1] + n_dims: 2 + outputs: [x0, x1] + mapping: [0, 1, 0, 1] + outputs: [x0, x1, x2, x3] + inputs: [x0, x1] + outputs: [x0, x1, x2, x3] + - !transform/concatenate-1.2.0 + forward: + - !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 40 + datatype: float64 + byteorder: little + shape: [6, 6] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 41 + datatype: float64 + byteorder: little + shape: [6, 6] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + inputs: [x0, y0, x1, y1] + inverse: !transform/concatenate-1.2.0 + forward: + - !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 42 + datatype: float64 + byteorder: little + shape: [6, 6] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 43 + datatype: float64 + byteorder: little + shape: [6, 6] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + inputs: [x0, y0, x1, y1] + outputs: [z0, z1] + outputs: [z0, z1] + inputs: [x0, x1] + outputs: [z0, z1] + - &id032 !transform/identity-1.2.0 + inputs: [x0, x1] + inverse: !transform/remap_axes-1.3.0 + inputs: [x0, x1] + mapping: [0, 1, 0, 1] + outputs: [x0, x1, x2, x3] + n_dims: 2 + outputs: [x0, x1] + inputs: [x0, x1] + outputs: [x0, x1] + - *id031 + inputs: [x0, x1] + outputs: [x0, x1, x2, x3] + - !transform/concatenate-1.2.0 + forward: + - !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 44 + datatype: float64 + byteorder: little + shape: [2, 2] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 45 + datatype: float64 + byteorder: little + shape: [2, 2] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + inputs: [x0, y0, x1, y1] + inverse: !transform/concatenate-1.2.0 + forward: + - !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 46 + datatype: float64 + byteorder: little + shape: [2, 2] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + - !transform/polynomial-1.2.0 + coefficients: !core/ndarray-1.0.0 + source: 47 + datatype: float64 + byteorder: little + shape: [2, 2] + domain: + - [-1, 1] + - [-1, 1] + inputs: [x, y] + outputs: [z] + window: + - [-1, 1] + - [-1, 1] + inputs: [x0, y0, x1, y1] + outputs: [z0, z1] + outputs: [z0, z1] + inputs: [x0, x1] + outputs: [z0, z1] + - *id032 + inputs: [x0, x1] + outputs: [x0, x1] + - !transform/concatenate-1.2.0 + forward: + - !transform/shift-1.2.0 + inputs: [x] + offset: -291.141 + outputs: [y] + - !transform/shift-1.2.0 + inputs: [x] + offset: -698.015 + outputs: [y] + inputs: [x0, x1] + outputs: [y0, y1] + inputs: [x0, x1] + outputs: [y0, y1] + inputs: [x0, x1] + outputs: [y0, y1] + - !transform/identity-1.2.0 + inputs: [x0, x1] + n_dims: 2 + outputs: [x0, x1] + inputs: [x00, x10, x01, x11] + outputs: [y0, y1, x0, x1] + - ! + frame: ! + frames: + - ! + axes_names: [v2, v3] + axes_order: [0, 1] + axis_physical_types: ['custom:v2', 'custom:v3'] + name: v2v3spatial + unit: [!unit/unit-1.0.0 arcsec, !unit/unit-1.0.0 arcsec] + - *id033 + name: v2v3 + transform: !transform/concatenate-1.2.0 + forward: + - !transform/compose-1.2.0 + forward: + - !transform/concatenate-1.2.0 + forward: + - !transform/scale-1.2.0 + factor: 0.9999939001894596 + inputs: [x] + name: dva_scale_v2 + outputs: [y] + - !transform/scale-1.2.0 + factor: 0.9999939001894596 + inputs: [x] + name: dva_scale_v3 + outputs: [y] + inputs: [x0, x1] + outputs: [y0, y1] + - !transform/concatenate-1.2.0 + forward: + - !transform/shift-1.2.0 + inputs: [x] + name: dva_v2_shift + offset: -0.0017759049405570732 + outputs: [y] + - !transform/shift-1.2.0 + inputs: [x] + name: dva_v3_shift + offset: -0.004257759254392014 + outputs: [y] + inputs: [x0, x1] + outputs: [y0, y1] + inputs: [x0, x1] + name: DVA_Correction + outputs: [y0, y1] + - !transform/identity-1.2.0 + inputs: [x0, x1] + n_dims: 2 + outputs: [x0, x1] + inputs: [x00, x10, x01, x11] + outputs: [y0, y1, x0, x1] + - ! + frame: ! + frames: + - ! + axes_names: [v2, v3] + axes_order: [0, 1] + axis_physical_types: ['custom:v2', 'custom:v3'] + name: v2v3vacorrspatial + unit: [!unit/unit-1.0.0 arcsec, !unit/unit-1.0.0 arcsec] + - *id033 + name: v2v3vacorr + transform: !transform/concatenate-1.2.0 + forward: + - !transform/compose-1.2.0 + forward: + - !transform/compose-1.2.0 + forward: + - !transform/compose-1.2.0 + forward: + - !transform/concatenate-1.2.0 + forward: + - !transform/scale-1.2.0 + factor: 0.0002777777777777778 + inputs: [x] + outputs: [y] + - !transform/scale-1.2.0 + factor: 0.0002777777777777778 + inputs: [x] + outputs: [y] + inputs: [x0, x1] + outputs: [y0, y1] + - ! + inputs: [lon, lat] + outputs: [x, y, z] + transform_type: spherical_to_cartesian + wrap_lon_at: 180 + inputs: [x0, x1] + outputs: [x, y, z] + - !transform/rotate_sequence_3d-1.0.0 + angles: [-0.0808725, 0.19389305555555555, 196.1037680531535, 65.83768326569226, + -260.82488865179715] + axes_order: zyxyz + inputs: [x, y, z] + outputs: [x, y, z] + rotation_type: cartesian + inputs: [x0, x1] + outputs: [x, y, z] + - ! + inputs: [x, y, z] + outputs: [lon, lat] + transform_type: cartesian_to_spherical + wrap_lon_at: 360 + inputs: [x0, x1] + name: v23tosky + outputs: [lon, lat] + - !transform/identity-1.2.0 + inputs: [x0, x1] + n_dims: 2 + outputs: [x0, x1] + inputs: [x00, x10, x01, x11] + outputs: [lon, lat, x0, x1] + - ! + frame: ! + frames: + - ! + axes_names: [lon, lat] + axes_order: [0, 1] + axis_physical_types: [pos.eq.ra, pos.eq.dec] + name: sky + reference_frame: ! + frame_attributes: {} + unit: [!unit/unit-1.0.0 deg, !unit/unit-1.0.0 deg] + - *id033 + name: world + transform: null +... +BLK0HHH_nHLdc=yO@uQ?*,U pf=)c{>uF.Z>BLK0HHH.+`1˛XMVOjMts+i>ʳD rr?b>JžBLK0HHHZĬF$}OX +7?= AH?VStouf%'\tqܙ˘Z>BLK0HHHx"~q'oF 'kC4V<-擪;QTc?Mb>BLK0HHH80$zݗYxJSD@Psjx ՞l>EapV}+þȉ + >BLK0HHHҏ98VIik-Skn@q?,zþDL.z?G:# ϾBLK0HHH`l^vCq`rc]Fj?a;ݾoBA2ԣ]?BLK0HHHЂ#`x#Dl]쳿_2>Y3?tU>Ҽs"BLK0HHHy˳y~ZQ@?h$]~7ʿcħ[\D?BLK0HHHO8,nH& O3faw@J_VO?)M¦>Qs?Q"ȗsqPT3BLK0HHHkh. )]A$t@P0Jm5-?i6m۾}x9 V> +$~ך>BLK0HHH/l{k5ҕy4-cV3c31U>&JJ?%"Q>2\BLK0HHHpzkcPtY\RSp?jXNu>uXY$!ѾI?BLK0HHH Q4NfRyW},?T ! + ?]1᭿)&?BLK0HHHRq3:IWw7F5K@CR~Ό?6ep ? Q'LǦ_'?BLK0HHHMy0g? nSXjyoоJL>eY(r>X 7X4BLK0HHH,d3.DMP5?@>hy&TU񾀣d#GAD>BLK0HHH^lH.\2MWy +ɾEY=q>H9e >h1BLK0HHH$ďAR# kQ?vļkն˭BLK0HHHeNSZҘGG-To+6{ 2|ilw>GɰE??i>' + (X͕BLK0HHH-&?Šͅr@?O;(?thoo6 :o>%hB>BLK0HHH8 =w/D"gz,avN?HA@@%E8V$>GfqQ>BLK0HHHxyDk +8&ajOnq>iI2?ף摯9lBLK0HHH+OKV/_s%߶~ץ?. +K? S|iY$v,3׬NI^Q}>BLK0HHHbX"\M#RQ9lCs-4=b!|jF=P.|iDP$BLK0HHHX* 0d`Pf7%ֹdhU=JC= -Ǽ6EBLK0HHHGRn؏X6p%%կ!H$U=5d=z/ǼʟREBLK0_ėV˱+L+??BLK0_ėV˱+L+??BLK0_ėV˱+L+??BLK0_ėV˱+L+??BLK0_ėV˱+L+??BLK0 + "ZU`Zc{޿RJ)?BLK0 + "ZU`Zc{޿RJ)?BLK0 + "ZU`Zc{޿RJ)?BLK0 + "ZU`Zc{޿RJ)?BLK0 + "ZU`Zc{޿RJ)?BLK0   n ^߉HfqDZ;{@ƥz&jIN09Fu?ث0o*\Tvܽ.s& ^q2+SA)B<1[p=jТ%WjE`<ؼļ5?Uk*=BLK0    1|k1Lϐl>.%jGH>X +YfT-.@IBC5R>ϖ~ʽ\,W <->i^ v\sdd>^K_ + d"=ñyN&>zj|Ƚv:=BLK0   \fpMOck+o(>Uc.@R0D^5>@Skb>n=J9$ұ?Ӽ?f>&wRĝCdvb'=A]L*6U>ěX>P7]l=ijq.!&>TԄDؽN%=oq=ez[=BLK0 b"I,%܂x?`BLK0 <6Y_zջ)J&`?܂x?BLK0 b"I,%܂x?`BLK0 <6Y_zջ)J&`?܂x?#ASDF BLOCK INDEX +%YAML 1.1 +--- +- 38436 +- 38562 +- 38688 +- 38814 +- 38940 +- 39066 +- 39192 +- 39318 +- 39444 +- 39570 +- 39696 +- 39822 +- 39948 +- 40074 +- 40200 +- 40326 +- 40452 +- 40578 +- 40704 +- 40830 +- 40956 +- 41082 +- 41208 +- 41334 +- 41460 +- 41586 +- 41712 +- 41838 +- 41964 +- 42090 +- 42216 +- 42286 +- 42356 +- 42426 +- 42496 +- 42566 +- 42636 +- 42706 +- 42776 +- 42846 +- 42916 +- 43258 +- 43600 +- 43942 +- 44284 +- 44370 +- 44456 +- 44542 +... diff --git a/jwst/wfss_contam/tests/data/segmentation_wcs.asdf b/jwst/wfss_contam/tests/data/segmentation_wcs.asdf new file mode 100644 index 0000000000..48c6218c8c Binary files /dev/null and b/jwst/wfss_contam/tests/data/segmentation_wcs.asdf differ diff --git a/jwst/wfss_contam/tests/test_disperse.py b/jwst/wfss_contam/tests/test_disperse.py new file mode 100644 index 0000000000..4ed558fade --- /dev/null +++ b/jwst/wfss_contam/tests/test_disperse.py @@ -0,0 +1,31 @@ +import pytest +import numpy as np +from jwst.wfss_contam.disperse import flux_interpolator_injector, determine_wl_spacing + +''' +Note that main disperse.py call is tested in test_observations.py because +it requires all the fixtures defined there. +''' + +@pytest.mark.parametrize("lams, flxs, extrapolate_sed, expected_outside_bounds", + [([1, 3], [1, 3], False, 0), + ([2], [2], False, 2), + ([1, 3], [1, 3], True, 4)]) +def test_interpolate_fluxes(lams, flxs, extrapolate_sed, expected_outside_bounds): + + flux_interpf = flux_interpolator_injector(lams, flxs, extrapolate_sed) + assert flux_interpf(2.0) == 2.0 + assert flux_interpf(4.0) == expected_outside_bounds + + +@pytest.mark.parametrize("lams, expected_dw", + [([1, 1.2, 1.4], 0.05), + ([1, 1.02, 1.04], 0.01) + ]) +def test_determine_wl_spacing(lams, expected_dw): + + dw = 0.1 + oversample_factor = 2 + dw_out = determine_wl_spacing(dw, np.array(lams), oversample_factor) + + assert np.isclose(dw_out, expected_dw, atol=1e-8) diff --git a/jwst/wfss_contam/tests/test_observations.py b/jwst/wfss_contam/tests/test_observations.py new file mode 100644 index 0000000000..99e21d763e --- /dev/null +++ b/jwst/wfss_contam/tests/test_observations.py @@ -0,0 +1,288 @@ +import pytest +import numpy as np +import asdf +import os + +from astropy.convolution import convolve +from photutils.segmentation import make_2dgaussian_kernel +from astropy.stats import sigma_clipped_stats +from photutils.datasets import make_100gaussians_image +from photutils.segmentation import SourceFinder + +from jwst.wfss_contam.observations import background_subtract, _select_ids, Observation +from jwst.wfss_contam.disperse import dispersed_pixel +from jwst.wfss_contam.tests import data +from jwst.datamodels import SegmentationMapModel, ImageModel, MultiSlitModel + +data_path = os.path.split(os.path.abspath(data.__file__))[0] +DIR_IMAGE = "direct_image.fits" + +#@pytest.fixture(scope='module') +#def create_source_catalog(): +# '''Mock source catalog''' +# pass + +@pytest.fixture(scope='module') +def direct_image(): + data = make_100gaussians_image() + kernel = make_2dgaussian_kernel(3, size=5) + data = convolve(data, kernel) + return data + + +@pytest.fixture(scope='module') +def direct_image_with_gradient(tmp_cwd_module, direct_image): + ny, nx = direct_image.shape + y, x = np.mgrid[:ny, :nx] + gradient = x * y / 5000.0 + data = direct_image + gradient + + # obs expects input list of direct image filenames + model = ImageModel(data=data) + model.save(DIR_IMAGE) + + return model + + +@pytest.fixture(scope='module') +def segmentation_map(direct_image): + mean, median, stddev = sigma_clipped_stats(direct_image, sigma=3.0) + threshold = median+3*stddev + finder = SourceFinder(npixels=10) + segm = finder(direct_image, threshold) + + # turn this into a jwst datamodel + model = SegmentationMapModel(data=segm.data) + with asdf.open(os.path.join(data_path, "segmentation_wcs.asdf")) as asdf_file: + wcsobj = asdf_file.tree['wcs'] + model.meta.wcs = wcsobj + + return model + + +@pytest.fixture(scope='module') +def grism_wcs(): + with asdf.open(os.path.join(data_path, "grism_wcs.asdf")) as asdf_file: + wcsobj = asdf_file.tree['wcs'] + return wcsobj + + +@pytest.fixture(scope='module') +def observation(direct_image_with_gradient, segmentation_map, grism_wcs): + ''' + set up observation object with mock data. + direct_image_with_gradient still needs to be run to produce the file, + even though it is not called directly + ''' + filter_name = "F200W" + seg = segmentation_map.data + all_IDs = np.array(list(set(np.ravel(seg)))) + IDs = all_IDs[50:52] + obs = Observation([DIR_IMAGE], segmentation_map, grism_wcs, filter_name, ID=IDs, + sed_file=None, extrapolate_sed=False, + boundaries=[], offsets=[0, 0], renormalize=True, max_cpu=1) + return obs + + +@pytest.mark.parametrize("ID, expected", [(None, [1,2,3]), (2, [2]), (np.array([1,3]), [1,3])]) +def test_select_ids(ID, expected): + all_ids = [1, 2, 3] + assert _select_ids(ID, all_ids) == expected + assert isinstance(_select_ids(ID, all_ids), list) + + +def test_background_subtract(direct_image_with_gradient): + + data = direct_image_with_gradient.data + subtracted_data = background_subtract(data) + mean, median, stddev = sigma_clipped_stats(subtracted_data, sigma=3.0) + assert np.isclose(mean, 0.0, atol=0.2*stddev) + + +def test_create_pixel_list(observation, segmentation_map): + ''' + create_pixel_list is called on initialization + compare the pixel lists to values determined directly + from segmentation map + + Note: still need test coverage for flux dictionary + ''' + seg = segmentation_map.data + all_IDs = np.array(list(set(np.ravel(seg)))) + IDs = all_IDs[50:52] + for i, ID in enumerate(IDs): + pixels_y, pixels_x = np.where(seg == ID) + assert np.all(np.isin(observation.xs[i], pixels_x)) + assert np.all(np.isin(observation.ys[i], pixels_y)) + assert len(observation.fluxes[2.0][i]) == pixels_x.size + + +def test_disperse_chunk(observation): + ''' + Note: it's not obvious how to get a trivial flux example from first principles + even setting all input fluxes in dict to 1, because transforms change + pixel areas in nontrivial ways. seems a bad idea to write a test that + asserts the answer as it is currently, in case step gets updated slightly + in the future + ''' + obs = observation + i = 1 + order = 1 + sens_waves = np.linspace(1.708, 2.28, 100) + wmin, wmax = np.min(sens_waves), np.max(sens_waves) + sens_resp = np.ones(100) + + # manually change x,y offset because took transform from a real direct image, with different + # pixel 0,0 than the mock data. This puts i=1, order 1 onto the real grism image + obs.xoffset = 2200 + obs.yoffset = 1000 + + # set all fluxes to unity to try to make a trivial example + obs.fluxes[2.0][i] = np.ones(obs.fluxes[2.0][i].shape) + + disperse_chunk_args = [i, order, wmin, wmax, sens_waves, sens_resp] + (chunk, chunk_bounds, sid, order_out) = obs.disperse_chunk(*disperse_chunk_args) + + #trivial bookkeeping + assert sid == obs.IDs[i] + assert order == order_out + + # check size of object is same as input dims + assert chunk.shape == obs.dims + + #check that the chunk is zero outside the bounds + assert np.all(chunk[:chunk_bounds[2]-1,:] == 0) + assert np.all(chunk[chunk_bounds[3]+1:,] == 0) + assert np.all(chunk[:,:chunk_bounds[0]-1] == 0) + assert np.all(chunk[:,chunk_bounds[1]+1:] == 0) + + +def test_disperse_chunk_null(observation): + ''' + ensure bounds return None when all dispersion is off image + ''' + obs = observation + i = 0 #i==0 source happens to be far left on image + order = 3 + sens_waves = np.linspace(1.708, 2.28, 100) + wmin, wmax = np.min(sens_waves), np.max(sens_waves) + sens_resp = np.ones(100) + + # manually change x,y offset because took transform from a real direct image, with different + # pixel 0,0 than the mock data. This puts i=1, order 1 onto the real grism image + obs.xoffset = 2200 + obs.yoffset = 1000 + + disperse_chunk_args = [i, order, wmin, wmax, sens_waves, sens_resp] + + (chunk, chunk_bounds, sid, order_out) = obs.disperse_chunk(*disperse_chunk_args) + + assert chunk_bounds is None + assert np.all(chunk == 0) + + +def test_disperse_all(observation): + + obs = observation + order = 1 + sens_waves = np.linspace(1.708, 2.28, 100) + wmin, wmax = np.min(sens_waves), np.max(sens_waves) + sens_resp = np.ones(100) + + # manually change x,y offset because took transform from a real direct image, with different + # pixel 0,0 than the mock data. This puts i=1, order 1 onto the real grism image + obs.xoffset = 2200 + obs.yoffset = 1000 + + # shorten pixel list to make this test take less time + obs.xs = obs.xs[:3] + obs.ys = obs.ys[:3] + obs.fluxes[2.0] = obs.fluxes[2.0][:3] + obs.disperse_all(order, wmin, wmax, sens_waves, sens_resp, cache=False) + + # test simulated image. should be mostly but not all zeros + assert obs.simulated_image.shape == obs.dims + assert not np.allclose(obs.simulated_image, 0.0) + assert np.median(obs.simulated_image) == 0.0 + + # test simulated slits and their associated metadata + # only the second of the two obs IDs is in the simulated image + assert obs.simul_slits_order == [order,]*1 + assert obs.simul_slits_sid == obs.IDs[-1:] + assert type(obs.simul_slits) == MultiSlitModel + + +def test_disperse_oversample_same_result(grism_wcs, segmentation_map): + ''' + Coverage for bug where wavelength oversampling led to double-counted fluxes + + note: segmentation_map fixture needs to be able to find module-scoped direct_image + fixture, so it must be imported here + ''' + + # manual input of input params set the same as test_observations.py + x0 = 300.5 + y0 = 300.5 + order = 1 + width = 1.0 + height = 1.0 + lams = [2.0] + flxs = [1.0] + ID = 0 + naxis = (300, 500) + sens_waves = np.linspace(1.708, 2.28, 100) + wmin, wmax = np.min(sens_waves), np.max(sens_waves) + sens_resp = np.ones(100) + seg_wcs = segmentation_map.meta.wcs + 0, (300, 500), 2, False, + xoffset = 2200 + yoffset = 1000 + + + xs, ys, areas, lams_out, counts_1, ID = dispersed_pixel( + x0, y0, width, height, lams, flxs, order, wmin, wmax, + sens_waves, sens_resp, seg_wcs, grism_wcs, ID, naxis, + oversample_factor=1, extrapolate_sed=False, xoffset=xoffset, + yoffset=yoffset) + + xs, ys, areas, lams_out, counts_3, ID = dispersed_pixel( + x0, y0, width, height, lams, flxs, order, wmin, wmax, + sens_waves, sens_resp, seg_wcs, grism_wcs, ID, naxis, + oversample_factor=3, extrapolate_sed=False, xoffset=xoffset, + yoffset=yoffset) + + assert np.isclose(np.sum(counts_1), np.sum(counts_3), rtol=1e-2) + + +def test_construct_slitmodel_for_chunk(observation): + ''' + test that the chunk is constructed correctly + ''' + obs = observation + i = 1 + order = 1 + sens_waves = np.linspace(1.708, 2.28, 100) + wmin, wmax = np.min(sens_waves), np.max(sens_waves) + sens_resp = np.ones(100) + + # manually change x,y offset because took transform from a real direct image, with different + # pixel 0,0 than the mock data. This puts i=1, order 1 onto the real grism image + obs.xoffset = 2200 + obs.yoffset = 1000 + + # set all fluxes to unity to try to make a trivial example + obs.fluxes[2.0][i] = np.ones(obs.fluxes[2.0][i].shape) + + disperse_chunk_args = [i, order, wmin, wmax, sens_waves, sens_resp] + (chunk, chunk_bounds, sid, order_out) = obs.disperse_chunk(*disperse_chunk_args) + + slit = obs.construct_slitmodel_for_chunk(chunk, chunk_bounds, sid, order_out) + + # check that the metadata is correct + assert slit.xstart == chunk_bounds[0] + assert slit.xsize == chunk_bounds[1] - chunk_bounds[0] + 1 + assert slit.ystart == chunk_bounds[2] + assert slit.ysize == chunk_bounds[3] - chunk_bounds[2] + 1 + assert slit.source_id == sid + assert slit.meta.wcsinfo.spectral_order == order_out + assert np.allclose(slit.data, chunk[chunk_bounds[2]:chunk_bounds[3]+1, chunk_bounds[0]:chunk_bounds[1]+1]) diff --git a/jwst/wfss_contam/tests/test_wfss_contam.py b/jwst/wfss_contam/tests/test_wfss_contam.py new file mode 100644 index 0000000000..b1418b77d1 --- /dev/null +++ b/jwst/wfss_contam/tests/test_wfss_contam.py @@ -0,0 +1,116 @@ +import pytest +from jwst.wfss_contam.wfss_contam import CommonSlitEncompass, CommonSlitPreferFirst, SlitOverlapError, UnmatchedSlitIDError, determine_multiprocessing_ncores, _cut_frame_to_match_slit, _find_matching_simul_slit +from jwst.datamodels import SlitModel +import numpy as np + + +@pytest.mark.parametrize("max_cores, num_cores, expected", + [("none", 4, 1), + ("quarter", 4, 1), + ("half", 4, 2), + ("all", 4, 4), + ("none", 1, 1), + (None, 1, 1,), + (3, 5, 3)]) +def test_determine_multiprocessing_ncores(max_cores, num_cores, expected): + assert determine_multiprocessing_ncores(max_cores, num_cores) == expected + + +@pytest.fixture(scope="module") +def contam(): + return np.ones((10, 10))*0.1 + +@pytest.fixture(scope="module") +def slit0(): + slit = SlitModel(data=np.ones((5, 3))) + slit.xstart = 2 + slit.ystart = 3 + slit.xsize = 3 + slit.ysize = 5 + slit.meta.wcsinfo.spectral_order = 1 + slit.source_id = 1 + return slit + + +@pytest.fixture(scope="module") +def slit1(): + slit = SlitModel(data=np.ones((4, 4))*0.5) + slit.xstart = 3 + slit.ystart = 2 + slit.xsize = 4 + slit.ysize = 4 + return slit + + +@pytest.fixture(scope="module") +def slit2(): + slit = SlitModel(data=np.ones((3, 5))*0.1) + slit.xstart = 300 + slit.ystart = 200 + slit.xsize = 5 + slit.ysize = 3 + return slit + + +def test_find_matching_simul_slit(slit0): + sids = [0, 1, 1] + orders = [1, 1, 2] + idx = _find_matching_simul_slit(slit0, sids, orders) + assert idx == 1 + + +def test_find_matching_simul_slit_no_match(slit0): + sids = [0, 1, 1] + orders = [1, 2, 2] + with pytest.raises(UnmatchedSlitIDError): + _find_matching_simul_slit(slit0, sids, orders) + + +def test_cut_frame_to_match_slit(slit0, contam): + cut_contam = _cut_frame_to_match_slit(contam, slit0) + assert cut_contam.shape == (5, 3) + assert np.all(cut_contam == 0.1) + + +def test_common_slit_encompass(slit0, slit1): + slit0_final, slit1_final = CommonSlitEncompass(slit0.copy(), slit1.copy()).match_backplane() + + # check indexing in metadata + assert slit0_final.xstart == slit1_final.xstart + assert slit0_final.ystart == slit1_final.ystart + assert slit0_final.xsize == slit1_final.xsize + assert slit0_final.ysize == slit1_final.ysize + assert slit0_final.data.shape == slit1_final.data.shape + + # check data overlap + assert np.count_nonzero(slit0_final.data) == 15 + assert np.count_nonzero(slit1_final.data) == 16 + assert np.count_nonzero(slit0_final.data * slit1_final.data) == 6 + + # check data values + assert np.all(slit0_final.data[1:6, 0:3] == 1) + assert np.all(slit1_final.data[0:4, 1:5] == 0.5) + + +def test_common_slit_prefer(slit0, slit1): + + slit0_final, slit1_final = CommonSlitPreferFirst(slit0.copy(), slit1.copy()).match_backplane() + assert slit0_final.xstart == slit0.xstart + assert slit0_final.ystart == slit0.ystart + assert slit0_final.xsize == slit0.xsize + assert slit0_final.ysize == slit0.ysize + assert slit0_final.data.shape == slit0.data.shape + assert np.all(slit0_final.data == slit0.data) + + assert slit1_final.xstart == slit0.xstart + assert slit1_final.ystart == slit0.ystart + assert slit1_final.xsize == slit0.xsize + assert slit1_final.ysize == slit0.ysize + assert slit1_final.data.shape == slit0.data.shape + assert np.count_nonzero(slit1_final.data) == 6 + + +def test_common_slit_prefer_expected_raise(slit0, slit2): + + with pytest.raises(SlitOverlapError): + CommonSlitPreferFirst(slit0.copy(), slit2.copy()).match_backplane() \ No newline at end of file diff --git a/jwst/wfss_contam/wfss_contam.py b/jwst/wfss_contam/wfss_contam.py index 841538442f..77b3446475 100644 --- a/jwst/wfss_contam/wfss_contam.py +++ b/jwst/wfss_contam/wfss_contam.py @@ -1,11 +1,11 @@ -# -# Top level module for WFSS contamination correction. -# import logging import multiprocessing +from typing import Protocol, Union import numpy as np from stdatamodels.jwst import datamodels +from astropy.table import Table +import copy from .observations import Observation from .sens1d import get_photom_data @@ -14,7 +14,222 @@ log.setLevel(logging.DEBUG) -def contam_corr(input_model, waverange, photom, max_cores): +def determine_multiprocessing_ncores(max_cores: Union[str, int], num_cores) -> int: + + """Determine the number of cores to use for multiprocessing. + + Parameters + ---------- + max_cores : string or int + Number of cores to use for multiprocessing. If set to 'none' + (the default), then no multiprocessing will be done. The other + allowable string values are 'quarter', 'half', and 'all', which indicate + the fraction of cores to use for multi-proc. The total number of + cores includes the SMT cores (Hyper Threading for Intel). + If an integer is provided, it will be the exact number of cores used. + num_cores : int + Number of cores available on the machine + + Returns + ------- + ncpus : int + Number of cores to use for multiprocessing + """ + match max_cores: + case 'none': + return 1 + case None: + return 1 + case 'quarter': + return num_cores // 4 or 1 + case 'half': + return num_cores // 2 or 1 + case 'all': + return num_cores + case int(): + if max_cores <= num_cores and max_cores > 0: + return max_cores + log.warning(f"Requested {max_cores} cores exceeds the number of cores available on this machine ({num_cores}). Using all available cores.") + return max_cores + case _: + raise ValueError(f"Invalid value for max_cores: {max_cores}") + + +class UnmatchedSlitIDError(Exception): + pass + + +def _find_matching_simul_slit(slit: datamodels.SlitModel, + simul_slit_sids: list[int], + simul_slit_orders: list[int], + ) -> int: + """ + Parameters + ---------- + slit : `~jwst.datamodels.SlitModel` + Source slit model + simul_slit_sids : list + List of source IDs for simulated slits + simul_slit_orders : list + List of spectral orders for simulated slits + + Returns + ------- + good_idx : int + Index of the matching simulated slit in the list of simulated slits + """ + + # Retrieve simulated slit for this source only + sid = slit.source_id + order = slit.meta.wcsinfo.spectral_order + good = (np.array(simul_slit_sids) == sid) * (np.array(simul_slit_orders) == order) + if not any(good): + raise UnmatchedSlitIDError(f"Source ID {sid} order {order} requested by input slit model \ + but not found in simulated slits. Setting contamination correction to zero for that slit.") + return np.where(good)[0][0] + + +def _cut_frame_to_match_slit(contam: np.ndarray, slit: datamodels.SlitModel) -> np.ndarray: + + """Cut out the contamination image to match the extent of the source slit. + + Parameters + ---------- + contam : 2D array + Contamination image for the full grism exposure + slit : `~jwst.datamodels.SlitModel` + Source slit model + + Returns + ------- + cutout : 2D array + Contamination image cutout that matches the extent of the source slit + + """ + x1 = slit.xstart + y1 = slit.ystart + cutout = contam[y1:y1 + slit.ysize, x1:x1 + slit.xsize] + + return cutout + + +class SlitOverlapError(Exception): + pass + +class CommonSlit(Protocol): + ''' + class protocol for two slits that represent the same source and order, e.g. data and model + ''' + slit0: datamodels.SlitModel + slit1: datamodels.SlitModel + + def match_backplane(self) -> tuple[datamodels.SlitModel, datamodels.SlitModel]: + ... + + +class CommonSlitPreferFirst(CommonSlit): + ''' + Treat slit0 as the reference slit, and match attributes of slit1 to it + ''' + def __init__(self, slit0: datamodels.SlitModel, slit1: datamodels.SlitModel): + self.slit0 = slit0 + self.slit1 = slit1 + + def match_backplane(self) -> tuple[datamodels.SlitModel, datamodels.SlitModel]: + + data0 = self.slit0.data + data1 = self.slit1.data + + x1 = self.slit1.xstart - self.slit0.xstart + y1 = self.slit1.ystart - self.slit0.ystart + backplane1 = np.zeros_like(data0) + + i0 = max([y1,0]) + i1 = min([y1+data1.shape[0], data0.shape[0], data1.shape[0]]) + j0 = max([x1,0]) + j1 = min([x1+data1.shape[1], data0.shape[1], data1.shape[1]]) + if i0 >= i1 or j0 >= j1: + raise SlitOverlapError(f"No overlap region between data and model for slit {self.slit0.source_id}, \ + order {self.slit0.meta.wcsinfo.spectral_order}. \ + Setting contamination correction to zero for that slit.") + + backplane1[i0:i1, j0:j1] = data1[i0:i1, j0:j1] + + self.slit1.data = backplane1 + self.slit1.xstart = self.slit0.xstart + self.slit1.ystart = self.slit0.ystart + self.slit1.xsize = self.slit0.xsize + self.slit1.ysize = self.slit0.ysize + + return self.slit0, self.slit1 + + +class CommonSlitEncompass(CommonSlit): + ''' + Encompass the data from both slits in a common backplane + ''' + def __init__(self, slit0: datamodels.SlitModel, slit1: datamodels.SlitModel): + self.slit0 = slit0 + self.slit1 = slit1 + + def match_backplane(self) -> tuple[datamodels.SlitModel, datamodels.SlitModel]: + ''' + put data from the two slits into a common backplane + so outputs have the same dimensions + and alignment is based on slit.xstart, slit.ystart + + Parameters + ---------- + slit0 : SlitModel + First slit model + slit1 : SlitModel + Second slit model + + Returns + ------- + slit0 : SlitModel + First slit model with data updated to common backplane + slit1 : SlitModel + Second slit model with data updated to common backplane + ''' + + data0 = self.slit0.data + data1 = self.slit1.data + + shape = (max(data0.shape[0], data1.shape[0]), max(data0.shape[1], data1.shape[1])) + xmin = min(self.slit0.xstart, self.slit1.xstart) + ymin = min(self.slit0.ystart, self.slit1.ystart) + shape = max(self.slit0.xsize + self.slit0.xstart - xmin, + self.slit1.xsize + self.slit1.xstart - xmin), \ + max(self.slit0.ysize + self.slit0.ystart - ymin, + self.slit1.ysize + self.slit1.ystart - ymin) + x0 = self.slit0.xstart - xmin + y0 = self.slit0.ystart - ymin + x1 = self.slit1.xstart - xmin + y1 = self.slit1.ystart - ymin + + backplane0 = np.zeros(shape).T + backplane0[y0:y0+data0.shape[0], x0:x0+data0.shape[1]] = data0 + backplane1 = np.zeros(shape).T + backplane1[y1:y1+data1.shape[0], x1:x1+data1.shape[1]] = data1 + + self.slit0.data = backplane0 + self.slit1.data = backplane1 + for slit in [self.slit0, self.slit1]: + slit.xstart = xmin + slit.ystart = ymin + slit.xsize = shape[0] + slit.ysize = shape[1] + + return self.slit0, self.slit1 + + +def contam_corr(input_model: datamodels.MultiSlitModel, + waverange: datamodels.WavelengthrangeModel, + photom: datamodels.NrcWfssPhotomModel | datamodels.NisWfssPhotomModel, + max_cores: str | int = "none", + brightest_n: int = None, + ) -> tuple[datamodels.MultiSlitModel, datamodels.ImageModel, datamodels.MultiSlitModel, datamodels.MultiSlitModel]: """ The main WFSS contamination correction function @@ -25,13 +240,20 @@ def contam_corr(input_model, waverange, photom, max_cores): waverange : `~jwst.datamodels.WavelengthrangeModel` Wavelength range reference file model photom : `~jwst.datamodels.NrcWfssPhotomModel` or `~jwst.datamodels.NisWfssPhotomModel` - Photom (flux cal) reference file model - max_cores : string + Photom (flux cal) reference file model + max_cores : string or int Number of cores to use for multiprocessing. If set to 'none' (the default), then no multiprocessing will be done. The other - allowable values are 'quarter', 'half', and 'all', which indicate + allowable string values are 'quarter', 'half', and 'all', which indicate the fraction of cores to use for multi-proc. The total number of cores includes the SMT cores (Hyper Threading for Intel). + If an integer is provided, it will be the exact number of cores used. + brightest_n : int + Number of sources to simulate. If None, then all sources in the + input model will be simulated. Requires loading the source catalog + file if not None. Note runtime scales non-linearly with this number + because brightest (and therefore typically largest) sources are + simulated first. Returns ------- @@ -43,28 +265,15 @@ def contam_corr(input_model, waverange, photom, max_cores): Contamination estimate images for each source slit """ - # Determine number of cpu's to use for multi-processing - if max_cores == 'none': - ncpus = 1 - else: - num_cores = multiprocessing.cpu_count() - if max_cores == 'quarter': - ncpus = num_cores // 4 or 1 - elif max_cores == 'half': - ncpus = num_cores // 2 or 1 - elif max_cores == 'all': - ncpus = num_cores - else: - ncpus = 1 - log.debug(f"Found {num_cores} cores; using {ncpus}") + + num_cores = multiprocessing.cpu_count() + ncpus = determine_multiprocessing_ncores(max_cores, num_cores) # Initialize output model output_model = input_model.copy() - # Get the segmentation map for this grism exposure + # Get the segmentation map, direct image for this grism exposure seg_model = datamodels.open(input_model.meta.segmentation_map) - - # Get the direct image from which the segmentation map was constructed direct_file = input_model.meta.direct_image image_names = [direct_file] log.debug(f"Direct image names={image_names}") @@ -98,32 +307,38 @@ def contam_corr(input_model, waverange, photom, max_cores): else: filter_name = filter_kwd - # Load lists of wavelength ranges and flux cal info for all orders - wmin = {} - wmax = {} - sens_waves = {} - sens_response = {} - for order in spec_orders: - wavelength_range = waverange.get_wfss_wavelength_range(filter_name, [order]) - wmin[order] = wavelength_range[order][0] - wmax[order] = wavelength_range[order][1] - # Load the sensitivity (inverse flux cal) data for this mode and order - sens_waves[order], sens_response[order] = get_photom_data(photom, filter_kwd, pupil_kwd, order) - log.debug(f"wmin={wmin}, wmax={wmax}") + # select a subset of the brightest sources using source catalog + if brightest_n is not None: + source_catalog = Table.read(input_model.meta.source_catalog, format='ascii.ecsv') + source_catalog.sort("isophotal_abmag", reverse=False) #magnitudes in ascending order, since brighter is smaller mag number + selected_IDs = list(source_catalog["label"])[:brightest_n] + else: + selected_IDs = None - # Initialize the simulated image object - simul_all = None obs = Observation(image_names, seg_model, grism_wcs, filter_name, - boundaries=[0, 2047, 0, 2047], offsets=[xoffset, yoffset], max_cpu=ncpus) + boundaries=[0, 2047, 0, 2047], offsets=[xoffset, yoffset], max_cpu=ncpus, + ID=selected_IDs) + + good_slits = [slit for slit in output_model.slits if slit.source_id in obs.IDs] + output_model = datamodels.MultiSlitModel() + output_model.update(input_model) + output_model.slits.extend(good_slits) + log.info(f"Simulating only the brightest {brightest_n} sources") + - # Create simulated grism image for each order and sum them up + simul_all = None for order in spec_orders: - log.info(f"Creating full simulated grism image for order {order}") - obs.disperse_all(order, wmin[order], wmax[order], sens_waves[order], - sens_response[order]) + # Load lists of wavelength ranges and flux cal info + wavelength_range = waverange.get_wfss_wavelength_range(filter_name, [order]) + wmin = wavelength_range[order][0] + wmax = wavelength_range[order][1] + log.debug(f"wmin={wmin}, wmax={wmax} for order {order}") + sens_waves, sens_response = get_photom_data(photom, filter_kwd, pupil_kwd, order) - # Accumulate result for this order into the combined image + # Create simulated grism image for each order and sum them up + log.info(f"Creating full simulated grism image for order {order}") + obs.disperse_all(order, wmin, wmax, sens_waves, sens_response) if simul_all is None: simul_all = obs.simulated_image else: @@ -133,6 +348,9 @@ def contam_corr(input_model, waverange, photom, max_cores): simul_model = datamodels.ImageModel(data=simul_all) simul_model.update(input_model, only="PRIMARY") + simul_slit_sids = np.array(obs.simul_slits_sid) + simul_slit_orders = np.array(obs.simul_slits_order) + # Loop over all slits/sources to subtract contaminating spectra log.info("Creating contamination image for each individual source") contam_model = datamodels.MultiSlitModel() @@ -140,62 +358,27 @@ def contam_corr(input_model, waverange, photom, max_cores): slits = [] for slit in output_model.slits: - # Create simulated spectrum for this source only - sid = slit.source_id - order = slit.meta.wcsinfo.spectral_order - chunk = np.where(obs.IDs == sid)[0][0] # find chunk for this source - - obs.simulated_image = np.zeros(obs.dims) - obs.disperse_chunk(chunk, order, wmin[order], wmax[order], - sens_waves[order], sens_response[order]) - this_source = obs.simulated_image - - # Contamination estimate is full simulated image minus this source - contam = simul_all - this_source + try: + good_idx = _find_matching_simul_slit(slit, simul_slit_sids, simul_slit_orders) + this_simul = obs.simul_slits.slits[good_idx] + slit, this_simul = CommonSlitPreferFirst(slit, this_simul).match_backplane() + simul_all_cut = _cut_frame_to_match_slit(simul_all, slit) + contam_cut = simul_all_cut - this_simul.data - # Create a cutout of the contam image that matches the extent - # of the source slit - x1 = slit.xstart - 1 - y1 = slit.ystart - 1 - cutout = contam[y1:y1 + slit.ysize, x1:x1 + slit.xsize] - new_slit = datamodels.SlitModel(data=cutout) - copy_slit_info(slit, new_slit) - slits.append(new_slit) + except (UnmatchedSlitIDError, SlitOverlapError) as e: + log.warning(e) + contam_cut = np.zeros_like(slit.data) + + contam_slit = copy.copy(slit) + contam_slit.data = contam_cut + slits.append(contam_slit) - # Subtract the cutout from the source slit - slit.data -= cutout + # Subtract the contamination from the source slit + slit.data -= contam_cut # Save the contamination estimates for all slits contam_model.slits.extend(slits) - # Set the step status to COMPLETE output_model.meta.cal_step.wfss_contam = 'COMPLETE' - return output_model, simul_model, contam_model - - -def copy_slit_info(input_slit, output_slit): - - """Copy meta info from one slit to another. - - Parameters - ---------- - input_slit : SlitModel - Input slit model from which slit-specific info will be copied - - output_slit : SlitModel - Output slit model to which slit-specific info will be copied - - """ - output_slit.name = input_slit.name - output_slit.xstart = input_slit.xstart - output_slit.ystart = input_slit.ystart - output_slit.xsize = input_slit.xsize - output_slit.ysize = input_slit.ysize - output_slit.source_id = input_slit.source_id - output_slit.source_type = input_slit.source_type - output_slit.source_xpos = input_slit.source_xpos - output_slit.source_ypos = input_slit.source_ypos - output_slit.meta.wcsinfo.spectral_order = input_slit.meta.wcsinfo.spectral_order - output_slit.meta.wcsinfo.dispersion_direction = input_slit.meta.wcsinfo.dispersion_direction - output_slit.meta.wcs = input_slit.meta.wcs + return output_model, simul_model, contam_model, obs.simul_slits diff --git a/jwst/wfss_contam/wfss_contam_step.py b/jwst/wfss_contam/wfss_contam_step.py index 91ee451195..52944538d0 100755 --- a/jwst/wfss_contam/wfss_contam_step.py +++ b/jwst/wfss_contam/wfss_contam_step.py @@ -20,16 +20,17 @@ class WfssContamStep(Step): save_contam_images = boolean(default=False) # Save source contam estimates maximum_cores = option('none', 'quarter', 'half', 'all', default='none') skip = boolean(default=True) + brightest_n = integer(default=None) """ reference_file_types = ['photom', 'wavelengthrange'] - def process(self, input_model, *args, **kwargs): + def process(self, + input_model: str | datamodels.MultiSlitModel, + ) -> datamodels.MultiSlitModel: with datamodels.open(input_model) as dm: - max_cores = self.maximum_cores - # Get the wavelengthrange ref file waverange_ref = self.get_reference_file(dm, 'wavelengthrange') self.log.info(f'Using WAVELENGTHRANGE reference file {waverange_ref}') @@ -40,18 +41,20 @@ def process(self, input_model, *args, **kwargs): self.log.info(f'Using PHOTOM reference file {photom_ref}') photom_model = datamodels.open(photom_ref) - result, simul, contam = wfss_contam.contam_corr(dm, + result, simul, contam, simul_slits = wfss_contam.contam_corr(dm, waverange_model, photom_model, - max_cores) + self.maximum_cores, + brightest_n=self.brightest_n) # Save intermediate results, if requested if self.save_simulated_image: simul_path = self.save_model(simul, suffix="simul", force=True) self.log.info(f'Full-frame simulated grism image saved to "{simul_path}"') + simul_slits_path = self.save_model(simul_slits, suffix="simul_slits", force=True) + self.log.info(f'Simulated slits saved to "{simul_slits_path}"') if self.save_contam_images: contam_path = self.save_model(contam, suffix="contam", force=True) self.log.info(f'Contamination estimates saved to "{contam_path}"') - # Return the corrected data return result