diff --git a/jwst/assign_wcs/nircam.py b/jwst/assign_wcs/nircam.py
index 1c8c7a2fea..d484142605 100644
--- a/jwst/assign_wcs/nircam.py
+++ b/jwst/assign_wcs/nircam.py
@@ -212,12 +212,12 @@ def tsgrism(input_model, reference_files):
     # Get the disperser parameters which are defined as a model for each
     # spectral order
     with NIRCAMGrismModel(reference_files['specwcs']) as f:
-        displ = f.displ
-        dispx = f.dispx
-        dispy = f.dispy
-        invdispx = f.invdispx
-        invdispl = f.invdispl
-        orders = f.orders
+        displ = f.displ.instance
+        dispx = f.dispx.instance
+        dispy = f.dispy.instance
+        invdispx = f.invdispx.instance
+        invdispl = f.invdispl.instance
+        orders = f.orders.instance

     # now create the appropriate model for the grismr
     det2det = NIRCAMForwardRowGrismDispersion(orders,
@@ -377,13 +377,13 @@ def wfss(input_model, reference_files):
     # Get the disperser parameters which are defined as a model for each
     # spectral order
     with NIRCAMGrismModel(reference_files['specwcs']) as f:
-        displ = f.displ
-        dispx = f.dispx
-        dispy = f.dispy
-        invdispx = f.invdispx
-        invdispy = f.invdispy
-        invdispl = f.invdispl
-        orders = f.orders
+        displ = f.displ.instance
+        dispx = f.dispx.instance
+        dispy = f.dispy.instance
+        invdispx = f.invdispx.instance
+        invdispy = f.invdispy.instance
+        invdispl = f.invdispl.instance
+        orders = f.orders.instance

     # now create the appropriate model for the grism[R/C]
     if "GRISMR" in input_model.meta.instrument.pupil:
diff --git a/jwst/assign_wcs/niriss.py b/jwst/assign_wcs/niriss.py
index 754cb50f06..ca35937775 100644
--- a/jwst/assign_wcs/niriss.py
+++ b/jwst/assign_wcs/niriss.py
@@ -407,11 +407,11 @@ def wfss(input_model, reference_files):
     # Get the disperser parameters which are defined as a model for each
     # spectral order
     with NIRISSGrismModel(reference_files['specwcs']) as f:
-        dispx = f.dispx
-        dispy = f.dispy
-        displ = f.displ
-        invdispl = f.invdispl
-        orders = f.orders
+        dispx = f.dispx.instance
+        dispy = f.dispy.instance
+        displ = f.displ.instance
+        invdispl = f.invdispl.instance
+        orders = f.orders.instance
         fwcpos_ref = f.fwcpos_ref

     # This is the actual rotation from the input model
diff --git a/jwst/wfss_contam/disperse.py b/jwst/wfss_contam/disperse.py
index 72e5ad7770..e454c0ea43 100644
--- a/jwst/wfss_contam/disperse.py
+++ b/jwst/wfss_contam/disperse.py
@@ -1,3 +1,4 @@
+from functools import partial
 import numpy as np
 import warnings

@@ -7,6 +8,86 @@
 from .sens1d import create_1d_sens


+def _flat_lam(fluxes, _lams) -> np.ndarray:
+    return fluxes[0]
+
+
+def flux_interpolator_injector(
+    lams,
+    flxs,
+    extrapolate_sed,
+):
+    """
+    Create a function that returns the flux at a given wavelength.
+
+    If only one direct image is in use, this function will always return the same value.
+
+    Parameters
+    ----------
+    lams : np.ndarray[float]
+        Array of wavelengths corresponding to the fluxes (flxs) for each pixel.
+        One wavelength per direct image, so can be a single value.
+    flxs : np.ndarray[float]
+        Array of fluxes (flam) for the pixels contained in x0, y0. If a single
+        direct image is in use, this will be a single value.
+    extrapolate_sed : bool
+        Whether to allow for the SED of the object to be extrapolated
+        when it does not fully cover the needed wavelength range. Default is False.
+
+    Returns
+    -------
+    flux : function
+        Function that returns the flux at a given wavelength.
+        If only one direct image is in use, this function will always return the same value.
+    """
+    if len(lams) > 1:
+        # If we have direct image flux values from more than one filter (lams),
+        # we have the option to extrapolate the fluxes outside the
+        # wavelength range of the direct images
+        if extrapolate_sed is False:
+            return interp1d(lams, flxs, fill_value=0.0, bounds_error=False)
+        else:
+            return interp1d(lams, flxs, fill_value="extrapolate", bounds_error=False)
+    else:
+        # If we only have flux from one wavelength, just use that
+        # single flux value at all wavelengths
+        return partial(_flat_lam, flxs)
+
+
+def determine_wl_spacing(
+    dw,
+    lams,
+    oversample_factor,
+):
+    """
+    Determine the wavelength spacing to use for the dispersed pixels.
+
+    Use a natural wavelength scale or the wavelength scale of the input SED/spectrum,
+    whichever is smaller, divided by the requested oversampling factor.
+
+    Parameters
+    ----------
+    dw : float
+        The natural wavelength scale of the grism image
+    lams : np.ndarray[float]
+        Array of wavelengths corresponding to the fluxes (flxs) for each pixel.
+        One wavelength per direct image, so can be a single value.
+    oversample_factor : int
+        The amount of oversampling
+
+    Returns
+    -------
+    dlam : float
+        The wavelength spacing to use for the dispersed pixels
+    """
+    if len(lams) > 1:
+        input_dlam = np.median(lams[1:] - lams[:-1])
+        if input_dlam < dw:
+            return input_dlam / oversample_factor
+    return dw / oversample_factor
+
+
 def dispersed_pixel(
     x0,
     y0,
@@ -95,25 +176,17 @@
         1D array of the wavelengths of each dispersed pixel
     counts : array
         1D array of counts for each dispersed pixel
+    source_id : int
+        The source ID of the source being processed. Returned in the output unmodified;
+        used only for bookkeeping. TODO this is not implemented properly right now and
+        should probably just be removed.
""" # Setup the transforms we need from the input WCS objects sky_to_imgxy = grism_wcs.get_transform("world", "detector") imgxy_to_grismxy = grism_wcs.get_transform("detector", "grism_detector") - # Setup function for retrieving flux values at each dispersed wavelength - if len(lams) > 1: - # If we have direct image flux values from more than one filter (lambda), - # we have the option to extrapolate the fluxes outside the - # wavelength range of the direct images - if extrapolate_sed is False: - flux = interp1d(lams, flxs, fill_value=0.0, bounds_error=False) - else: - flux = interp1d(lams, flxs, fill_value="extrapolate", bounds_error=False) - else: - # If we only have flux from one lambda, just use that - # single flux value at all wavelengths - def flux(_x): - return flxs[0] + # Set up function for retrieving flux values at each dispersed wavelength + flux_interpolator = flux_interpolator_injector(lams, flxs, extrapolate_sed) # Get x/y positions in the grism image corresponding to wmin and wmax: # Start with RA/Dec of the input pixel position in segmentation map, @@ -127,20 +200,9 @@ def flux(_x): dxw = xwmax - xwmin dyw = ywmax - ywmin - # Compute the delta-wave per pixel - dw = np.abs((wmax - wmin) / (dyw - dxw)) - - # Use a natural wavelength scale or the wavelength scale of the input SED/spectrum, - # whichever is smaller, divided by oversampling requested - if len(lams) > 1: - input_dlam = np.median(lams[1:] - lams[:-1]) - if input_dlam < dw: - dlam = input_dlam / oversample_factor - else: - # this value gets used when we only have 1 direct image wavelength - dlam = dw / oversample_factor - # Create list of wavelengths on which to compute dispersed pixels + dw = np.abs((wmax - wmin) / (dyw - dxw)) + dlam = determine_wl_spacing(dw, lams, oversample_factor) lambdas = np.arange(wmin, wmax + dlam, dlam) n_lam = len(lambdas) @@ -174,9 +236,11 @@ def flux(_x): # values are naturally in units of physical fluxes, so we divide out # the sensitivity (flux calibration) values to convert to units of # countrate (DN/s). + # flux_interpolator(lams) is either single-valued (for a single direct image) + # or an array of the same length as lams (for multiple direct images in different filters) with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=RuntimeWarning, message="divide by zero") - counts = flux(lams) * areas / (sens * oversample_factor) + counts = flux_interpolator(lams) * areas / (sens * oversample_factor) counts[no_cal] = 0.0 # set to zero where no flux cal info available return xs, ys, areas, lams, counts, source_id diff --git a/jwst/wfss_contam/observations.py b/jwst/wfss_contam/observations.py index 8b975b4c85..249b29829d 100644 --- a/jwst/wfss_contam/observations.py +++ b/jwst/wfss_contam/observations.py @@ -1,6 +1,6 @@ import time -import multiprocessing import numpy as np +import multiprocessing as mp from scipy import sparse @@ -17,8 +17,20 @@ log.setLevel(logging.DEBUG) +def _disperse_multiprocess(pars, max_cpu): + ctx = mp.get_context("forkserver") + with ctx.Pool(max_cpu) as mypool: + all_res = mypool.starmap(dispersed_pixel, pars) + + return all_res + + def background_subtract( - data, box_size=None, filter_size=(3, 3), sigma=3.0, exclude_percentile=30.0 + data, + box_size=None, + filter_size=(3, 3), + sigma=3.0, + exclude_percentile=30.0, ): """ Apply a simple astropy background subtraction. 
@@ -64,6 +76,39 @@ def background_subtract(
     return data - bkg.background


+def _select_ids(source_id, all_ids):
+    """
+    Select the source IDs to be processed based on the input ID parameter.
+
+    Parameters
+    ----------
+    source_id : int or list-like
+        ID(s) of source to process. If None, all sources processed.
+    all_ids : np.ndarray
+        Array of all source IDs in the segmentation map
+
+    Returns
+    -------
+    selected_ids : list
+        List of selected source IDs
+    """
+    if source_id is None:
+        log.info(f"Loading all {len(all_ids)} sources from segmentation map")
+        return all_ids
+
+    elif isinstance(source_id, int):
+        log.info(f"Loading single source {source_id} from segmentation map")
+        return [source_id]
+
+    elif isinstance(source_id, (list, np.ndarray)):
+        log.info(
+            f"Loading {len(source_id)} of {len(all_ids)} selected sources from segmentation map"
+        )
+        return list(source_id)
+    else:
+        raise ValueError("source_id must be an integer or a list of integers")
+
+
 class Observation:
     """Define an observation leading to a single grism image."""

@@ -73,7 +118,7 @@ def __init__(
         segmap_model,
         grism_wcs,
         filter_name,
-        source_id=0,
+        source_id=None,
         sed_file=None,
         extrapolate_sed=False,
         boundaries=None,
@@ -121,9 +166,10 @@ def __init__(
         self.seg_wcs = segmap_model.meta.wcs
         self.grism_wcs = grism_wcs
         self.source_id = source_id
-        self.source_ids = []
         self.dir_image_names = direct_images
         self.seg = segmap_model.data
+        # exclude the background (ID 0) so it is never dispersed as a source
+        all_ids = np.array(sorted(set(np.ravel(self.seg)) - {0}))
+        self.source_ids = _select_ids(source_id, all_ids)
         self.filter = filter_name
         self.sed_file = sed_file  # should always be NONE for baseline pipeline (use flat SED)
         self.cache = False
@@ -136,13 +182,13 @@ def __init__(
         if len(boundaries) == 0:
             log.debug("No boundaries passed.")
             self.xstart = 0
-            self.xend = self.xstart + self.dims[0] - 1
+            self.xend = self.xstart + self.seg.shape[0] - 1
             self.ystart = 0
-            self.yend = self.ystart + self.dims[1] - 1
+            self.yend = self.ystart + self.seg.shape[1] - 1
         else:
             self.xstart, self.xend, self.ystart, self.yend = boundaries
         self.dims = (self.yend - self.ystart + 1, self.xend - self.xstart + 1)
-        log.debug(f"Using simulated image size of {self.dims[1]} {self.dims[0]}")
+        log.debug(f"Using simulated image size of ({self.dims[1]}, {self.dims[0]}).")

         # Allow for SED extrapolation
         self.extrapolate_sed = extrapolate_sed
@@ -152,32 +198,24 @@ def __init__(
         # Create pixel lists for sources labeled in segmentation map
         self.create_pixel_list()

+        # Initialize the list of slits
+        self.simul_slits = datamodels.MultiSlitModel()
+        self.simul_slits_order = []
+        self.simul_slits_sid = []
+
     def create_pixel_list(self):
-        """Create a list of pixels to be dispersed, grouped per object ID."""
-        if self.source_id == 0:
-            # When source_id=0, all sources in the segmentation map are processed.
-            # This creates a huge list of all x,y pixel indices that have non-zero values
-            # in the seg map, sorted by those indices belonging to a particular source ID.
-            self.xs = []
-            self.ys = []
-            all_ids = np.array(list(set(np.ravel(self.seg))))
-            all_ids = all_ids[all_ids > 0]
-            self.source_ids = all_ids
-            log.info(f"Loading {len(all_ids)} sources from segmentation map")
-            for source_id in all_ids:
-                ys, xs = np.nonzero(self.seg == source_id)
-                if len(xs) > 0 and len(ys) > 0:
-                    self.xs.append(xs)
-                    self.ys.append(ys)
+        """
+        Create a list of pixels to be dispersed, grouped per object ID.
-        else:
-            # Process only the given source ID
-            log.info(f"Loading source {self.source_id} from segmentation map")
-            ys, xs = np.nonzero(self.seg == self.source_id)
+        When ID is None, all sources in the segmentation map are processed.
+        """
+        self.xs = []
+        self.ys = []
+        for source_id in self.source_ids:
+            ys, xs = np.nonzero(self.seg == source_id)
             if len(xs) > 0 and len(ys) > 0:
-                self.xs = [xs]
-                self.ys = [ys]
-                self.source_ids = [self.source_id]
+                self.xs.append(xs)
+                self.ys.append(ys)

         # Populate lists of direct image flux values for the sources.
         self.fluxes = {}
@@ -216,7 +254,15 @@
         for i in range(len(self.source_ids)):
             self.fluxes["sed"].append(dnew[self.ys[i], self.xs[i]])

-    def disperse_all(self, order, wmin, wmax, sens_waves, sens_resp, cache=False):
+    def disperse_all(
+        self,
+        order,
+        wmin,
+        wmax,
+        sens_waves,
+        sens_resp,
+        cache=False,
+    ):
         """
         Compute dispersed pixel values for all sources identified in the segmentation map.

@@ -241,7 +287,8 @@
         # Initialize the simulated dispersed image
         self.simulated_image = np.zeros(self.dims, float)

         # Loop over all source IDs from segmentation map
+        pool_args = []
         for i in range(len(self.source_ids)):
             if self.cache:
                 self.cached_object[i] = {}
@@ -254,9 +301,46 @@
                 self.cached_object[i]["miny"] = []
                 self.cached_object[i]["maxy"] = []

-            self.disperse_chunk(i, order, wmin, wmax, sens_waves, sens_resp)
+            disperse_chunk_args = [
+                i,
+                order,
+                wmin,
+                wmax,
+                sens_waves,
+                sens_resp,
+            ]
+            pool_args.append(disperse_chunk_args)
+
+        t0 = time.time()
+        if self.max_cpu > 1:
+            # put this log message here to avoid printing it for every chunk
+            log.info(f"Using multiprocessing with {self.max_cpu} cores to compute dispersion")

-    def disperse_chunk(self, c, order, wmin, wmax, sens_waves, sens_resp):
+        disperse_chunk_output = []
+        for i in range(len(self.source_ids)):
+            disperse_chunk_output.append(self.disperse_chunk(*pool_args[i]))
+        t1 = time.time()
+        log.info(f"Wall clock time for disperse_chunk order {order}: {(t1 - t0):.1f} sec")
+
+        # Collect results into simulated image and slit models
+        for this_output in disperse_chunk_output:
+            [this_image, this_bounds, this_sid, this_order] = this_output
+            slit = self.construct_slitmodel_for_chunk(this_image, this_bounds, this_sid, this_order)
+            self.simulated_image += this_image
+            if slit is not None:
+                self.simul_slits.slits.append(slit)
+                self.simul_slits_order.append(this_order)
+                self.simul_slits_sid.append(this_sid)
+
+    def disperse_chunk(
+        self,
+        c,
+        order,
+        wmin,
+        wmax,
+        sens_waves,
+        sens_resp,
+    ):
         """
         Compute dispersion for a single source; to be called after create_pixel_list().
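For reference, _select_ids above resolves the source_id argument purely by type; a short sketch with hypothetical IDs (note that None returns the input array as-is, not a list):

    import numpy as np

    all_ids = np.array([1, 2, 3, 4])
    _select_ids(None, all_ids)            # all four sources, returned unchanged
    _select_ids(2, all_ids)               # [2]
    _select_ids([1, 3], all_ids)          # [1, 3]
    _select_ids(np.array([4]), all_ids)   # [4]
    _select_ids("2", all_ids)             # raises ValueError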
@@ -277,8 +361,14 @@ def disperse_chunk(self, c, order, wmin, wmax, sens_waves, sens_resp): Returns ------- - np.ndarray - 2D dispersed image for this source + this_object : np.ndarray + 2D array of dispersed pixel values for the source + thisobj_bounds : list + [minx, maxx, miny, maxy] bounds of the object + sid : int + Source ID + order : int + Spectral order number """ sid = int(self.source_ids[c]) self.order = order @@ -286,26 +376,26 @@ def disperse_chunk(self, c, order, wmin, wmax, sens_waves, sens_resp): self.wmax = wmax self.sens_waves = sens_waves self.sens_resp = sens_resp - log.info(f"Dispersing source {sid}, order {self.order}") - pars = [] # initialize params for this object + log.info( + f"Dispersing source {sid}, order {self.order}. " + f"Source contains {len(self.xs[c])} pixels." + ) # Loop over all pixels in list for object "c" - log.debug(f"source contains {len(self.xs[c])} pixels") + pars = [] # initialize params for this object for i in range(len(self.xs[c])): - # Here "i" just indexes the pixel list for the object being processed - + # Here "i" is just an index into the pixel list for the object + # being processed, as opposed to the ID number of the object itself # xc, yc are the coordinates of the central pixel of the group # of pixels surrounding the direct image pixel index width = 1.0 height = 1.0 xc = self.xs[c][i] + 0.5 * width yc = self.ys[c][i] + 0.5 * height - # "lams" is the array of wavelengths previously stored in flux list # and correspond to the central wavelengths of the filters used in # the input direct image(s). For the simple case of 1 combined direct image, # this contains a single value (e.g. 4.44 for F444W). - # "fluxes" is the array of pixel values from the direct image(s). # For the simple case of 1 combined direct image, this contains a # a single value (just like "lams"). 
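The next hunk changes the contract of disperse_chunk: instead of accumulating into self.simulated_image and returning only a 2D image, it now returns an (image, bounds, source_id, order) tuple. A sketch of how a caller consumes it, mirroring the collection loop in disperse_all above (obs, c, and the wavelength/sensitivity arguments are assumed to be set up as in that method):

    this_image, this_bounds, sid, order = obs.disperse_chunk(
        c, order, wmin, wmax, sens_waves, sens_resp
    )
    if this_bounds is None:
        pass  # source dispersed entirely off the image; no slit model is made
    else:
        minx, maxx, miny, maxy = this_bounds
        cutout = this_image[miny : maxy + 1, minx : maxx + 1]  # becomes SlitModel.data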
@@ -320,8 +410,7 @@ def disperse_chunk(self, c, order, wmin, wmax, sens_waves, sens_resp):
                     strict=True,
                 ),
             )
-
-            pars_i = (
+            pars_i = [
                 xc,
                 yc,
                 width,
                 height,
                 lams,
                 flxs,
                 self.order,
                 self.wmin,
                 self.wmax,
                 self.sens_waves,
                 self.sens_resp,
                 self.seg_wcs,
                 self.grism_wcs,
-                i,  # TODO: this is not the source_id as the docstring to dispersed_pixel says
+                i,
                 self.dims[::-1],
                 2,
                 self.extrapolate_sed,
                 self.xoffset,
                 self.yoffset,
-            )
-
+            ]
             pars.append(pars_i)
-        # now have full pars list for all pixels for this object

+        # Note: every element of each pars_i entry must be pickle-able, since
+        # pars may be handed to a multiprocessing pool below.
+        # pass parameters into dispersed_pixel, either using multiprocessing or not
         time1 = time.time()
         if self.max_cpu > 1:
-            ctx = multiprocessing.get_context("forkserver")
-            mypool = ctx.Pool(self.max_cpu)  # Create the pool
-            all_res = mypool.imap_unordered(dispersed_pixel, pars)  # Fill the pool
-            mypool.close()  # Drain the pool
+            all_res = _disperse_multiprocess(pars, self.max_cpu)
         else:
             all_res = []
             for i in range(len(pars)):
@@ -359,8 +446,8 @@

         # Initialize blank image for this source
         this_object = np.zeros(self.dims, float)
-        nres = 0
+
+        bounds = []
         for pp in all_res:
             if pp is None:
                 continue
@@ -381,8 +468,8 @@
             ).toarray()

             # Accumulate results into simulated images
-            self.simulated_image[miny : maxy + 1, minx : maxx + 1] += a
             this_object[miny : maxy + 1, minx : maxx + 1] += a
+            bounds.append([minx, maxx, miny, maxy])

             if self.cache:
                 self.cached_object[c]["x"].append(x)
@@ -396,93 +483,120 @@
         time2 = time.time()
         log.debug(f"Elapsed time {time2 - time1} sec")
-
-        return this_object
-
-    def disperse_all_from_cache(self, trans=None):
+        # figure out global bounds of object
+        if len(bounds) > 0:
+            bounds = np.array(bounds)
+            thisobj_minx = int(np.min(bounds[:, 0]))
+            thisobj_maxx = int(np.max(bounds[:, 1]))
+            thisobj_miny = int(np.min(bounds[:, 2]))
+            thisobj_maxy = int(np.max(bounds[:, 3]))
+            thisobj_bounds = [thisobj_minx, thisobj_maxx, thisobj_miny, thisobj_maxy]
+            return (this_object, thisobj_bounds, sid, order)
+        return (this_object, None, sid, order)
+
+    @staticmethod
+    def construct_slitmodel_for_chunk(
+        chunk_data,
+        bounds,
+        sid,
+        order,
+    ):
         """
-        Compute dispersed pixel values for all sources identified in the segmentation map.
-
-        Load data from cache where available. Currently not used.
+        Turn output image from a chunk into a slit model.

         Parameters
         ----------
-        trans : function
-            Transmission function to apply to the flux values
+        chunk_data : np.ndarray
+            Dispersed model of segmentation map source
+        bounds : list
+            The bounds of the object
+        sid : int
+            The source ID
+        order : int
+            The spectral order number

         Returns
         -------
-        np.ndarray
-            2D dispersed image for this source
-
-        Notes
-        -----
-        The return value of `this_object` appears to be a bug.
-        However, this is currently not used, and if the INS team wants to re-enable
-        caching, all functions here need updating anyway, so not fixing at this time.
+ slit : `jwst.datamodels.SlitModel` + Slit model containing the dispersed pixel values """ - if not self.cache: - return + if bounds is None: + return None - self.simulated_image = np.zeros(self.dims, float) + [thisobj_minx, thisobj_maxx, thisobj_miny, thisobj_maxy] = bounds + slit = datamodels.SlitModel() + slit.source_id = sid + slit.name = f"source_{sid}" + slit.xstart = thisobj_minx + slit.xsize = thisobj_maxx - thisobj_minx + 1 + slit.ystart = thisobj_miny + slit.ysize = thisobj_maxy - thisobj_miny + 1 + slit.meta.wcsinfo.spectral_order = order + slit.data = chunk_data[thisobj_miny : thisobj_maxy + 1, thisobj_minx : thisobj_maxx + 1] - for i in range(len(self.source_ids)): - this_object = self.disperse_chunk_from_cache(i, trans=trans) + return slit - return this_object - def disperse_chunk_from_cache(self, c, trans=None): - """ - Compute dispersion for a single source; to be called after create_pixel_list(). +# class ObservationFromCache: +# ''' +# this isn't how it should work. If we're going to use a cache, we should +# be checking if a pixel is in the cache before dispersing it. Then load it if it's there, +# otherwise calculate it. The functions below need to be refactored. +# ''' - Load data from cache where available. Currently not used. +# def __init__(self, cache, dims): +# self.cache = cache +# self.dims = dims +# self.simulated_image = np.zeros(self.dims, float) +# self.cached_object = {} - Parameters - ---------- - c : int - Chunk (source) number to process - trans : function - Transmission function to apply to the flux values +# def disperse_all_from_cache(self, trans=None): +# if not self.cache: +# return - Returns - ------- - np.ndarray - 2D dispersed image for this source - """ - if not self.cache: - return +# self.simulated_image = np.zeros(self.dims, float) - time1 = time.time() +# for i in range(len(self.source_ids)): +# this_object = self.disperse_chunk_from_cache(i, trans=trans) - # Initialize blank image for this object - this_object = np.zeros(self.dims, float) +# return this_object - if trans is not None: - log.debug("Applying a transmission function...") +# def disperse_chunk_from_cache(self, c, trans=None): +# """Method that handles the dispersion. To be called after create_pixel_list()""" - for i in range(len(self.cached_object[c]["x"])): - x = self.cached_object[c]["x"][i] - y = self.cached_object[c]["y"][i] - f = self.cached_object[c]["f"][i] * 1.0 - w = self.cached_object[c]["w"][i] +# if not self.cache: +# return - if trans is not None: - f *= trans(w) +# time1 = time.time() - minx = self.cached_object[c]["minx"][i] - maxx = self.cached_object[c]["maxx"][i] - miny = self.cached_object[c]["miny"][i] - maxy = self.cached_object[c]["maxy"][i] +# # Initialize blank image for this object +# this_object = np.zeros(self.dims, float) - a = sparse.coo_matrix( - (f, (y - miny, x - minx)), shape=(maxy - miny + 1, maxx - minx + 1) - ).toarray() +# if trans is not None: +# log.debug("Applying a transmission function...") - # Accumulate the results into the simulated images - self.simulated_image[miny : maxy + 1, minx : maxx + 1] += a - this_object[miny : maxy + 1, minx : maxx + 1] += a +# for i in range(len(self.cached_object[c]['x'])): +# x = self.cached_object[c]['x'][i] +# y = self.cached_object[c]['y'][i] +# f = self.cached_object[c]['f'][i] * 1. 
+# w = self.cached_object[c]['w'][i] - time2 = time.time() - log.debug(f"Elapsed time {time2 - time1} sec") +# if trans is not None: +# f *= trans(w) + +# minx = self.cached_object[c]['minx'][i] +# maxx = self.cached_object[c]['maxx'][i] +# miny = self.cached_object[c]['miny'][i] +# maxy = self.cached_object[c]['maxy'][i] + +# a = sparse.coo_matrix((f, (y - miny, x - minx)), +# shape=(maxy - miny + 1, maxx - minx + 1)).toarray() + +# # Accumulate the results into the simulated images +# self.simulated_image[miny:maxy + 1, minx:maxx + 1] += a +# this_object[miny:maxy + 1, minx:maxx + 1] += a + +# time2 = time.time() +# log.debug(f"Elapsed time {time2-time1} sec") - return this_object +# return this_object diff --git a/jwst/wfss_contam/tests/test_disperse.py b/jwst/wfss_contam/tests/test_disperse.py new file mode 100644 index 0000000000..4ed558fade --- /dev/null +++ b/jwst/wfss_contam/tests/test_disperse.py @@ -0,0 +1,31 @@ +import pytest +import numpy as np +from jwst.wfss_contam.disperse import flux_interpolator_injector, determine_wl_spacing + +''' +Note that main disperse.py call is tested in test_observations.py because +it requires all the fixtures defined there. +''' + +@pytest.mark.parametrize("lams, flxs, extrapolate_sed, expected_outside_bounds", + [([1, 3], [1, 3], False, 0), + ([2], [2], False, 2), + ([1, 3], [1, 3], True, 4)]) +def test_interpolate_fluxes(lams, flxs, extrapolate_sed, expected_outside_bounds): + + flux_interpf = flux_interpolator_injector(lams, flxs, extrapolate_sed) + assert flux_interpf(2.0) == 2.0 + assert flux_interpf(4.0) == expected_outside_bounds + + +@pytest.mark.parametrize("lams, expected_dw", + [([1, 1.2, 1.4], 0.05), + ([1, 1.02, 1.04], 0.01) + ]) +def test_determine_wl_spacing(lams, expected_dw): + + dw = 0.1 + oversample_factor = 2 + dw_out = determine_wl_spacing(dw, np.array(lams), oversample_factor) + + assert np.isclose(dw_out, expected_dw, atol=1e-8) diff --git a/jwst/wfss_contam/tests/test_observations.py b/jwst/wfss_contam/tests/test_observations.py index f4ee4dc4e0..7fa7c409f9 100644 --- a/jwst/wfss_contam/tests/test_observations.py +++ b/jwst/wfss_contam/tests/test_observations.py @@ -9,14 +9,18 @@ from photutils.datasets import make_100gaussians_image from photutils.segmentation import SourceFinder -from jwst.wfss_contam.observations import background_subtract +from jwst.wfss_contam.observations import background_subtract, _select_ids, Observation from jwst.wfss_contam.disperse import dispersed_pixel from jwst.wfss_contam.tests import data -from jwst.datamodels import SegmentationMapModel, ImageModel # type: ignore[attr-defined] +from stdatamodels.jwst.datamodels import SegmentationMapModel, ImageModel, MultiSlitModel data_path = os.path.split(os.path.abspath(data.__file__))[0] DIR_IMAGE = "direct_image.fits" +#@pytest.fixture(scope='module') +#def create_source_catalog(): +# '''Mock source catalog''' +# pass @pytest.fixture(scope='module') def direct_image(): @@ -49,8 +53,8 @@ def segmentation_map(direct_image): # turn this into a jwst datamodel model = SegmentationMapModel(data=segm.data) - asdf_file = asdf.open(os.path.join(data_path, "segmentation_wcs.asdf")) - wcsobj = asdf_file.tree['wcs'] + with asdf.open(os.path.join(data_path, "segmentation_wcs.asdf")) as asdf_file: + wcsobj = asdf_file.tree['wcs'] model.meta.wcs = wcsobj return model @@ -58,18 +62,161 @@ def segmentation_map(direct_image): @pytest.fixture(scope='module') def grism_wcs(): - asdf_file = asdf.open(os.path.join(data_path, "grism_wcs.asdf")) - wcsobj = 
asdf_file.tree['wcs'] + with asdf.open(os.path.join(data_path, "grism_wcs.asdf")) as asdf_file: + wcsobj = asdf_file.tree['wcs'] return wcsobj +@pytest.fixture(scope='module') +def observation(direct_image_with_gradient, segmentation_map, grism_wcs): + ''' + set up observation object with mock data. + direct_image_with_gradient still needs to be run to produce the file, + even though it is not called directly + ''' + filter_name = "F200W" + seg = segmentation_map.data + all_ids = np.array(list(set(np.ravel(seg)))) + source_ids = all_ids[50:52] + obs = Observation([DIR_IMAGE], segmentation_map, grism_wcs, filter_name, source_id=source_ids, + sed_file=None, extrapolate_sed=False, + boundaries=[], offsets=[0, 0], renormalize=True, max_cpu=1) + return obs + + +@pytest.mark.parametrize("source_id, expected", [(None, [1,2,3]), (2, [2]), (np.array([1,3]), [1,3])]) +def test_select_ids(source_id, expected): + all_ids = [1, 2, 3] + assert _select_ids(source_id, all_ids) == expected + assert isinstance(_select_ids(source_id, all_ids), list) + + +def test_select_ids_expected_raises(): + with pytest.raises(ValueError): + _select_ids("all", [1, 2, 3]) + + def test_background_subtract(direct_image_with_gradient): + data = direct_image_with_gradient.data subtracted_data = background_subtract(data) mean, median, stddev = sigma_clipped_stats(subtracted_data, sigma=3.0) assert np.isclose(mean, 0.0, atol=0.2*stddev) +def test_create_pixel_list(observation, segmentation_map): + ''' + create_pixel_list is called on initialization + compare the pixel lists to values determined directly + from segmentation map + + Note: still need test coverage for flux dictionary + ''' + seg = segmentation_map.data + all_ids = np.array(list(set(np.ravel(seg)))) + source_ids = all_ids[50:52] + for i, source_id in enumerate(source_ids): + pixels_y, pixels_x = np.where(seg == source_id) + assert np.all(np.isin(observation.xs[i], pixels_x)) + assert np.all(np.isin(observation.ys[i], pixels_y)) + assert len(observation.fluxes[2.0][i]) == pixels_x.size + + +def test_disperse_chunk(observation): + ''' + Note: it's not obvious how to get a trivial flux example from first principles + even setting all input fluxes in dict to 1, because transforms change + pixel areas in nontrivial ways. seems a bad idea to write a test that + asserts the answer as it is currently, in case step gets updated slightly + in the future + ''' + obs = observation + i = 1 + order = 1 + sens_waves = np.linspace(1.708, 2.28, 100) + wmin, wmax = np.min(sens_waves), np.max(sens_waves) + sens_resp = np.ones(100) + + # manually change x,y offset because took transform from a real direct image, with different + # pixel 0,0 than the mock data. 
This puts i=1, order 1 onto the real grism image
+    obs.xoffset = 2200
+    obs.yoffset = 1000
+
+    # set all fluxes to unity to try to make a trivial example
+    obs.fluxes[2.0][i] = np.ones(obs.fluxes[2.0][i].shape)
+
+    disperse_chunk_args = [i, order, wmin, wmax, sens_waves, sens_resp]
+    (chunk, chunk_bounds, sid, order_out) = obs.disperse_chunk(*disperse_chunk_args)
+
+    # trivial bookkeeping
+    assert sid == obs.source_ids[i]
+    assert order == order_out
+
+    # check size of object is same as input dims
+    assert chunk.shape == obs.dims
+
+    # check that the chunk is zero outside the bounds
+    assert np.all(chunk[:chunk_bounds[2]-1, :] == 0)
+    assert np.all(chunk[chunk_bounds[3]+1:, :] == 0)
+    assert np.all(chunk[:, :chunk_bounds[0]-1] == 0)
+    assert np.all(chunk[:, chunk_bounds[1]+1:] == 0)
+
+
+def test_disperse_chunk_null(observation):
+    '''
+    ensure bounds return None when all dispersion is off image
+    '''
+    obs = observation
+    i = 0  # the i==0 source happens to be far left on the image
+    order = 3
+    sens_waves = np.linspace(1.708, 2.28, 100)
+    wmin, wmax = np.min(sens_waves), np.max(sens_waves)
+    sens_resp = np.ones(100)
+
+    # manually change x,y offset because took transform from a real direct image, with different
+    # pixel 0,0 than the mock data. With these offsets the order-3 trace of source i=0
+    # falls entirely off the grism image
+    obs.xoffset = 2200
+    obs.yoffset = 1000
+
+    disperse_chunk_args = [i, order, wmin, wmax, sens_waves, sens_resp]
+
+    (chunk, chunk_bounds, sid, order_out) = obs.disperse_chunk(*disperse_chunk_args)
+
+    assert chunk_bounds is None
+    assert np.all(chunk == 0)
+
+
+def test_disperse_all(observation):
+
+    obs = observation
+    order = 1
+    sens_waves = np.linspace(1.708, 2.28, 100)
+    wmin, wmax = np.min(sens_waves), np.max(sens_waves)
+    sens_resp = np.ones(100)
+
+    # manually change x,y offset because took transform from a real direct image, with different
+    # pixel 0,0 than the mock data. This puts the sources onto the real grism image
+    obs.xoffset = 2200
+    obs.yoffset = 1000
+
+    # shorten pixel list to make this test take less time
+    obs.xs = obs.xs[:3]
+    obs.ys = obs.ys[:3]
+    obs.fluxes[2.0] = obs.fluxes[2.0][:3]
+    obs.disperse_all(order, wmin, wmax, sens_waves, sens_resp, cache=False)
+
+    # test simulated image.
should be mostly but not all zeros + assert obs.simulated_image.shape == obs.dims + assert not np.allclose(obs.simulated_image, 0.0) + assert np.median(obs.simulated_image) == 0.0 + + # test simulated slits and their associated metadata + # only the second of the two obs ids is in the simulated image + assert obs.simul_slits_order == [order,]*1 + assert obs.simul_slits_sid == obs.source_ids[-1:] + assert type(obs.simul_slits) == MultiSlitModel + + def test_disperse_oversample_same_result(grism_wcs, segmentation_map): ''' Coverage for bug where wavelength oversampling led to double-counted fluxes @@ -102,11 +249,45 @@ def test_disperse_oversample_same_result(grism_wcs, segmentation_map): sens_waves, sens_resp, seg_wcs, grism_wcs, source_id, naxis, oversample_factor=1, extrapolate_sed=False, xoffset=xoffset, yoffset=yoffset) - + xs, ys, areas, lams_out, counts_3, source_id = dispersed_pixel( x0, y0, width, height, lams, flxs, order, wmin, wmax, sens_waves, sens_resp, seg_wcs, grism_wcs, source_id, naxis, oversample_factor=3, extrapolate_sed=False, xoffset=xoffset, yoffset=yoffset) - assert np.isclose(np.sum(counts_1), np.sum(counts_3), rtol=1/sens_waves.size) + assert np.isclose(np.sum(counts_1), np.sum(counts_3), rtol=1e-2) + + +def test_construct_slitmodel_for_chunk(observation): + ''' + test that the chunk is constructed correctly + ''' + obs = observation + i = 1 + order = 1 + sens_waves = np.linspace(1.708, 2.28, 100) + wmin, wmax = np.min(sens_waves), np.max(sens_waves) + sens_resp = np.ones(100) + + # manually change x,y offset because took transform from a real direct image, with different + # pixel 0,0 than the mock data. This puts i=1, order 1 onto the real grism image + obs.xoffset = 2200 + obs.yoffset = 1000 + + # set all fluxes to unity to try to make a trivial example + obs.fluxes[2.0][i] = np.ones(obs.fluxes[2.0][i].shape) + + disperse_chunk_args = [i, order, wmin, wmax, sens_waves, sens_resp] + (chunk, chunk_bounds, sid, order_out) = obs.disperse_chunk(*disperse_chunk_args) + + slit = obs.construct_slitmodel_for_chunk(chunk, chunk_bounds, sid, order_out) + + # check that the metadata is correct + assert slit.xstart == chunk_bounds[0] + assert slit.xsize == chunk_bounds[1] - chunk_bounds[0] + 1 + assert slit.ystart == chunk_bounds[2] + assert slit.ysize == chunk_bounds[3] - chunk_bounds[2] + 1 + assert slit.source_id == sid + assert slit.meta.wcsinfo.spectral_order == order_out + assert np.allclose(slit.data, chunk[chunk_bounds[2]:chunk_bounds[3]+1, chunk_bounds[0]:chunk_bounds[1]+1]) diff --git a/jwst/wfss_contam/tests/test_wfss_contam.py b/jwst/wfss_contam/tests/test_wfss_contam.py new file mode 100644 index 0000000000..cb3b8cf25b --- /dev/null +++ b/jwst/wfss_contam/tests/test_wfss_contam.py @@ -0,0 +1,125 @@ +import pytest +from jwst.wfss_contam.wfss_contam import ( + match_backplane_prefer_first, + match_backplane_encompass_both, + SlitOverlapError, + UnmatchedSlitIDError, + determine_multiprocessing_ncores, + _cut_frame_to_match_slit, + _find_matching_simul_slit +) +from stdatamodels.jwst.datamodels import SlitModel +import numpy as np + + +@pytest.mark.parametrize("max_cores, num_cores, expected", + [("none", 4, 1), + ("quarter", 4, 1), + ("half", 4, 2), + ("all", 4, 4), + ("none", 1, 1), + (None, 1, 1,), + (3, 5, 3), + (100, 5, 5)]) +def test_determine_multiprocessing_ncores(max_cores, num_cores, expected): + assert determine_multiprocessing_ncores(max_cores, num_cores) == expected + + +@pytest.fixture(scope="module") +def contam(): + return np.ones((10, 10))*0.1 + 
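For orientation in the slit fixtures that follow: slit0 covers x = 2..4, y = 3..7 (5x3 pixels of 1.0) and slit1 covers x = 3..6, y = 2..5 (4x4 pixels of 0.5), so on a common backplane they overlap in a 2-column by 3-row region; that is the overlap count asserted in test_common_slit_encompass below:

    x_overlap = min(2 + 3, 3 + 4) - max(2, 3)  # xstart / xstart + xsize bounds -> 2 columns
    y_overlap = min(3 + 5, 2 + 4) - max(3, 2)  # ystart / ystart + ysize bounds -> 3 rows
    assert x_overlap * y_overlap == 6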
+@pytest.fixture(scope="module") +def slit0(): + slit = SlitModel(data=np.ones((5, 3))) + slit.xstart = 2 + slit.ystart = 3 + slit.xsize = 3 + slit.ysize = 5 + slit.meta.wcsinfo.spectral_order = 1 + slit.source_id = 1 + return slit + + +@pytest.fixture(scope="module") +def slit1(): + slit = SlitModel(data=np.ones((4, 4))*0.5) + slit.xstart = 3 + slit.ystart = 2 + slit.xsize = 4 + slit.ysize = 4 + return slit + + +@pytest.fixture(scope="module") +def slit2(): + slit = SlitModel(data=np.ones((3, 5))*0.1) + slit.xstart = 300 + slit.ystart = 200 + slit.xsize = 5 + slit.ysize = 3 + return slit + + +def test_find_matching_simul_slit(slit0): + sids = [0, 1, 1] + orders = [1, 1, 2] + idx = _find_matching_simul_slit(slit0, sids, orders) + assert idx == 1 + + +def test_find_matching_simul_slit_no_match(slit0): + sids = [0, 1, 1] + orders = [1, 2, 2] + with pytest.raises(UnmatchedSlitIDError): + _find_matching_simul_slit(slit0, sids, orders) + + +def test_cut_frame_to_match_slit(slit0, contam): + cut_contam = _cut_frame_to_match_slit(contam, slit0) + assert cut_contam.shape == (5, 3) + assert np.all(cut_contam == 0.1) + + +def test_common_slit_encompass(slit0, slit1): + slit0_final, slit1_final = match_backplane_encompass_both(slit0.copy(), slit1.copy()) + + # check indexing in metadata + assert slit0_final.xstart == slit1_final.xstart + assert slit0_final.ystart == slit1_final.ystart + assert slit0_final.xsize == slit1_final.xsize + assert slit0_final.ysize == slit1_final.ysize + assert slit0_final.data.shape == slit1_final.data.shape + + # check data overlap + assert np.count_nonzero(slit0_final.data) == 15 + assert np.count_nonzero(slit1_final.data) == 16 + assert np.count_nonzero(slit0_final.data * slit1_final.data) == 6 + + # check data values + assert np.all(slit0_final.data[1:6, 0:3] == 1) + assert np.all(slit1_final.data[0:4, 1:5] == 0.5) + + +def test_common_slit_prefer(slit0, slit1): + + slit0_final, slit1_final = match_backplane_prefer_first(slit0.copy(), slit1.copy()) + assert slit0_final.xstart == slit0.xstart + assert slit0_final.ystart == slit0.ystart + assert slit0_final.xsize == slit0.xsize + assert slit0_final.ysize == slit0.ysize + assert slit0_final.data.shape == slit0.data.shape + assert np.all(slit0_final.data == slit0.data) + + assert slit1_final.xstart == slit0.xstart + assert slit1_final.ystart == slit0.ystart + assert slit1_final.xsize == slit0.xsize + assert slit1_final.ysize == slit0.ysize + assert slit1_final.data.shape == slit0.data.shape + assert np.count_nonzero(slit1_final.data) == 6 + + +def test_common_slit_prefer_expected_raise(slit0, slit2): + + with pytest.raises(SlitOverlapError): + match_backplane_prefer_first(slit0.copy(), slit2.copy()) \ No newline at end of file diff --git a/jwst/wfss_contam/wfss_contam.py b/jwst/wfss_contam/wfss_contam.py index 6fae4b53e9..d5639c598a 100644 --- a/jwst/wfss_contam/wfss_contam.py +++ b/jwst/wfss_contam/wfss_contam.py @@ -5,6 +5,8 @@ import numpy as np from stdatamodels.jwst import datamodels +from astropy.table import Table +import copy from .observations import Observation from .sens1d import get_photom_data @@ -13,7 +15,228 @@ log.setLevel(logging.DEBUG) -def contam_corr(input_model, waverange, photom, max_cores): +def determine_multiprocessing_ncores(max_cores, num_cores): + """ + Determine the number of cores to use for multiprocessing. + + Parameters + ---------- + max_cores : str or int + Number of cores to use for multiprocessing. If set to 'none' + (the default), then no multiprocessing will be done. 
The other
+        allowable string values are 'quarter', 'half', and 'all', which indicate
+        the fraction of cores to use for multi-proc. The total number of
+        cores includes the SMT cores (Hyper Threading for Intel).
+        If an integer is provided, it will be the exact number of cores used.
+    num_cores : int
+        Number of cores available on the machine
+
+    Returns
+    -------
+    ncpus : int
+        Number of cores to use for multiprocessing
+    """
+    match max_cores:
+        case "none":
+            return 1
+        case None:
+            return 1
+        case "quarter":
+            return num_cores // 4 or 1
+        case "half":
+            return num_cores // 2 or 1
+        case "all":
+            return num_cores
+        case int():
+            if 0 < max_cores <= num_cores:
+                return max_cores
+            log.warning(
+                f"Requested {max_cores} cores exceeds the number of cores available "
+                f"on this machine ({num_cores}). Using all available cores."
+            )
+            return num_cores
+        case _:
+            raise ValueError(f"Invalid value for max_cores: {max_cores}")
+
+
+class UnmatchedSlitIDError(Exception):
+    """Exception raised when a slit ID is not found in the list of simulated slits."""
+
+    pass
+
+
+def _find_matching_simul_slit(slit, simul_slit_sids, simul_slit_orders):
+    """
+    Find the index of the matching simulated slit in the list of simulated slits.
+
+    Parameters
+    ----------
+    slit : `~jwst.datamodels.SlitModel`
+        Source slit model
+    simul_slit_sids : list
+        List of source IDs for simulated slits
+    simul_slit_orders : list
+        List of spectral orders for simulated slits
+
+    Returns
+    -------
+    good_idx : int
+        Index of the matching simulated slit in the list of simulated slits
+    """
+    sid = slit.source_id
+    order = slit.meta.wcsinfo.spectral_order
+    good = (np.array(simul_slit_sids) == sid) * (np.array(simul_slit_orders) == order)
+    if not any(good):
+        raise UnmatchedSlitIDError(
+            f"Source ID {sid} order {order} requested by input slit model "
+            "but not found in simulated slits. "
+            "Setting contamination correction to zero for that slit."
+        )
+    return np.where(good)[0][0]
+
+
+def _cut_frame_to_match_slit(contam, slit):
+    """
+    Cut out the contamination image to match the extent of the source slit.
+
+    Parameters
+    ----------
+    contam : 2D array
+        Contamination image for the full grism exposure
+    slit : `~jwst.datamodels.SlitModel`
+        Source slit model
+
+    Returns
+    -------
+    cutout : 2D array
+        Contamination image cutout that matches the extent of the source slit
+    """
+    x1 = slit.xstart
+    y1 = slit.ystart
+    xf = x1 + slit.xsize
+    yf = y1 + slit.ysize
+
+    # zero-pad the contamination image if the slit extends beyond the contamination image
+    # fixes an off-by-one bug when sources extend to the edge of the contamination image
+    if xf > contam.shape[1]:
+        contam = np.pad(contam, ((0, 0), (0, xf - contam.shape[1])), mode="constant")
+    if yf > contam.shape[0]:
+        contam = np.pad(contam, ((0, yf - contam.shape[0]), (0, 0)), mode="constant")
+
+    return contam[y1 : y1 + slit.ysize, x1 : x1 + slit.xsize]
+
+
+class SlitOverlapError(Exception):
+    """Exception raised when there is no overlap between data and model for a slit."""
+
+    pass
+
+
+def match_backplane_prefer_first(slit0, slit1):
+    """
+    Reshape slit1 to the backplane of slit0.
+
+    Parameters
+    ----------
+    slit0 : `~jwst.datamodels.SlitModel`
+        Slit model for the first slit, which is used as reference.
+    slit1 : `~jwst.datamodels.SlitModel`
+        Slit model for the second slit, which is reshaped to match slit0.
+
+    Returns
+    -------
+    slit0, slit1 : `~jwst.datamodels.SlitModel`
+        Reshaped slit models slit0, slit1.
+    """
+    data0 = slit0.data
+    data1 = slit1.data
+
+    x1 = slit1.xstart - slit0.xstart
+    y1 = slit1.ystart - slit0.ystart
+    backplane1 = np.zeros_like(data0)
+
+    i0 = max([y1, 0])
+    i1 = min([y1 + data1.shape[0], data0.shape[0], data1.shape[0]])
+    j0 = max([x1, 0])
+    j1 = min([x1 + data1.shape[1], data0.shape[1], data1.shape[1]])
+    if i0 >= i1 or j0 >= j1:
+        raise SlitOverlapError(
+            f"No overlap region between data and model for slit {slit0.source_id}, "
+            f"order {slit0.meta.wcsinfo.spectral_order}. "
+            "Setting contamination correction to zero for that slit."
+        )
+
+    backplane1[i0:i1, j0:j1] = data1[i0:i1, j0:j1]
+
+    slit1.data = backplane1
+    slit1.xstart = slit0.xstart
+    slit1.ystart = slit0.ystart
+    slit1.xsize = slit0.xsize
+    slit1.ysize = slit0.ysize
+
+    return slit0, slit1
+
+
+def match_backplane_encompass_both(slit0, slit1):
+    """
+    Put data from the two slits into a common backplane, encompassing both.
+
+    Slits are zero-padded where their new extent does not overlap with the original data.
+
+    Parameters
+    ----------
+    slit0, slit1 : `~jwst.datamodels.SlitModel`
+        Slit model for the first and second slit.
+
+    Returns
+    -------
+    slit0, slit1 : `~jwst.datamodels.SlitModel`
+        Reshaped slit models slit0, slit1.
+    """
+    data0 = slit0.data
+    data1 = slit1.data
+
+    xmin = min(slit0.xstart, slit1.xstart)
+    ymin = min(slit0.ystart, slit1.ystart)
+    shape = (
+        max(
+            slit0.xsize + slit0.xstart - xmin,
+            slit1.xsize + slit1.xstart - xmin,
+        ),
+        max(
+            slit0.ysize + slit0.ystart - ymin,
+            slit1.ysize + slit1.ystart - ymin,
+        ),
+    )
+    x0 = slit0.xstart - xmin
+    y0 = slit0.ystart - ymin
+    x1 = slit1.xstart - xmin
+    y1 = slit1.ystart - ymin
+
+    backplane0 = np.zeros(shape).T
+    backplane0[y0 : y0 + data0.shape[0], x0 : x0 + data0.shape[1]] = data0
+    backplane1 = np.zeros(shape).T
+    backplane1[y1 : y1 + data1.shape[0], x1 : x1 + data1.shape[1]] = data1
+
+    slit0.data = backplane0
+    slit1.data = backplane1
+    for slit in [slit0, slit1]:
+        slit.xstart = xmin
+        slit.ystart = ymin
+        slit.xsize = shape[0]
+        slit.ysize = shape[1]
+
+    return slit0, slit1
+
+
+def contam_corr(
+    input_model,
+    waverange,
+    photom,
+    max_cores,
+    brightest_n,
+):
     """
     Correct contamination in WFSS spectral cutouts.

@@ -25,12 +248,19 @@
         Wavelength range reference file model
     photom : `~jwst.datamodels.NrcWfssPhotomModel` or `~jwst.datamodels.NisWfssPhotomModel`
         Photom (flux cal) reference file model
-    max_cores : str
+    max_cores : str or int
         Number of cores to use for multiprocessing. If set to 'none'
         (the default), then no multiprocessing will be done. The other
-        allowable values are 'quarter', 'half', and 'all', which indicate
+        allowable string values are 'quarter', 'half', and 'all', which indicate
         the fraction of cores to use for multi-proc. The total number of
         cores includes the SMT cores (Hyper Threading for Intel).
+        If an integer is provided, it will be the exact number of cores used.
+    brightest_n : int
+        Number of sources to simulate. If None, then all sources in the
+        input model will be simulated. Requires loading the source catalog
+        file if not None. Note that runtime scales non-linearly with this number
+        because the brightest (and therefore typically largest) sources are
+        simulated first.
Returns ------- @@ -41,28 +271,14 @@ def contam_corr(input_model, waverange, photom, max_cores): contam_model : `~jwst.datamodels.MultiSlitModel` Contamination estimate images for each source slit """ - # Determine number of cpu's to use for multi-processing - if max_cores == "none": - ncpus = 1 - else: - num_cores = multiprocessing.cpu_count() - if max_cores == "quarter": - ncpus = num_cores // 4 or 1 - elif max_cores == "half": - ncpus = num_cores // 2 or 1 - elif max_cores == "all": - ncpus = num_cores - else: - ncpus = 1 - log.debug(f"Found {num_cores} cores; using {ncpus}") + num_cores = multiprocessing.cpu_count() + ncpus = determine_multiprocessing_ncores(max_cores, num_cores) # Initialize output model output_model = input_model.copy() - # Get the segmentation map for this grism exposure + # Get the segmentation map, direct image for this grism exposure seg_model = datamodels.open(input_model.meta.segmentation_map) - - # Get the direct image from which the segmentation map was constructed direct_file = input_model.meta.direct_image image_names = [direct_file] log.debug(f"Direct image names={image_names}") @@ -96,23 +312,16 @@ def contam_corr(input_model, waverange, photom, max_cores): else: filter_name = filter_kwd - # Load lists of wavelength ranges and flux cal info for all orders - wmin = {} - wmax = {} - sens_waves = {} - sens_response = {} - for order in spec_orders: - wavelength_range = waverange.get_wfss_wavelength_range(filter_name, [order]) - wmin[order] = wavelength_range[order][0] - wmax[order] = wavelength_range[order][1] - # Load the sensitivity (inverse flux cal) data for this mode and order - sens_waves[order], sens_response[order] = get_photom_data( - photom, filter_kwd, pupil_kwd, order - ) - log.debug(f"wmin={wmin}, wmax={wmax}") + # select a subset of the brightest sources using source catalog + if brightest_n is not None: + log.info(f"Simulating only the brightest {brightest_n} sources") + source_catalog = Table.read(input_model.meta.source_catalog, format="ascii.ecsv") + # magnitudes in ascending order, since brighter is smaller mag number + source_catalog.sort("isophotal_abmag", reverse=False) + selected_ids = list(source_catalog["label"])[:brightest_n] + else: + selected_ids = None - # Initialize the simulated image object - simul_all = None obs = Observation( image_names, seg_model, @@ -121,85 +330,66 @@ def contam_corr(input_model, waverange, photom, max_cores): boundaries=[0, 2047, 0, 2047], offsets=[xoffset, yoffset], max_cpu=ncpus, + source_id=selected_ids, ) - # Create simulated grism image for each order and sum them up + # Initialize output multislitmodel + good_slits = [slit for slit in output_model.slits if slit.source_id in obs.source_ids] + output_model = datamodels.MultiSlitModel() + output_model.update(input_model) + output_model.slits.extend(good_slits) + + simul_all = None for order in spec_orders: - log.info(f"Creating full simulated grism image for order {order}") - obs.disperse_all(order, wmin[order], wmax[order], sens_waves[order], sens_response[order]) + # Load lists of wavelength ranges and flux cal info + wavelength_range = waverange.get_wfss_wavelength_range(filter_name, [order]) + wmin = wavelength_range[order][0] + wmax = wavelength_range[order][1] + log.debug(f"wmin={wmin}, wmax={wmax} for order {order}") + sens_waves, sens_response = get_photom_data(photom, filter_kwd, pupil_kwd, order) - # Accumulate result for this order into the combined image + # Create simulated grism image for each order and sum them up + log.info(f"Creating 
full simulated grism image for order {order}") + obs.disperse_all(order, wmin, wmax, sens_waves, sens_response) if simul_all is None: simul_all = obs.simulated_image else: simul_all += obs.simulated_image - # Save the full-frame simulated grism image + # Make the full-frame simulated grism image simul_model = datamodels.ImageModel(data=simul_all) simul_model.update(input_model, only="PRIMARY") + simul_slit_sids = np.array(obs.simul_slits_sid) + simul_slit_orders = np.array(obs.simul_slits_order) + # Loop over all slits/sources to subtract contaminating spectra log.info("Creating contamination image for each individual source") contam_model = datamodels.MultiSlitModel() contam_model.update(input_model) slits = [] for slit in output_model.slits: - # Create simulated spectrum for this source only - sid = slit.source_id - order = slit.meta.wcsinfo.spectral_order - chunk = np.where(obs.source_ids == sid)[0][0] # find chunk for this source - - obs.simulated_image = np.zeros(obs.dims) - obs.disperse_chunk( - chunk, order, wmin[order], wmax[order], sens_waves[order], sens_response[order] - ) - this_source = obs.simulated_image + try: + good_idx = _find_matching_simul_slit(slit, simul_slit_sids, simul_slit_orders) + this_simul = obs.simul_slits.slits[good_idx] + slit, this_simul = match_backplane_prefer_first(slit, this_simul) + simul_all_cut = _cut_frame_to_match_slit(simul_all, slit) + contam_cut = simul_all_cut - this_simul.data - # Contamination estimate is full simulated image minus this source - contam = simul_all - this_source + except (UnmatchedSlitIDError, SlitOverlapError) as e: + log.warning(e) + contam_cut = np.zeros_like(slit.data) - # Create a cutout of the contam image that matches the extent - # of the source slit - x1 = slit.xstart - 1 - y1 = slit.ystart - 1 - cutout = contam[y1 : y1 + slit.ysize, x1 : x1 + slit.xsize] - new_slit = datamodels.SlitModel(data=cutout) - copy_slit_info(slit, new_slit) - slits.append(new_slit) + contam_slit = copy.copy(slit) + contam_slit.data = contam_cut + slits.append(contam_slit) - # Subtract the cutout from the source slit - slit.data -= cutout + # Subtract the contamination from the source slit + slit.data -= contam_cut # Save the contamination estimates for all slits contam_model.slits.extend(slits) - # Set the step status to COMPLETE output_model.meta.cal_step.wfss_contam = "COMPLETE" - return output_model, simul_model, contam_model - - -def copy_slit_info(input_slit, output_slit): - """ - Copy meta info from one slit to another. 
- - Parameters - ---------- - input_slit : SlitModel - Input slit model from which slit-specific info will be copied - - output_slit : SlitModel - Output slit model to which slit-specific info will be copied - """ - output_slit.name = input_slit.name - output_slit.xstart = input_slit.xstart - output_slit.ystart = input_slit.ystart - output_slit.xsize = input_slit.xsize - output_slit.ysize = input_slit.ysize - output_slit.source_id = input_slit.source_id - output_slit.source_type = input_slit.source_type - output_slit.source_xpos = input_slit.source_xpos - output_slit.source_ypos = input_slit.source_ypos - output_slit.meta.wcsinfo.spectral_order = input_slit.meta.wcsinfo.spectral_order - output_slit.meta.wcsinfo.dispersion_direction = input_slit.meta.wcsinfo.dispersion_direction - output_slit.meta.wcs = input_slit.meta.wcs + return output_model, simul_model, contam_model, obs.simul_slits diff --git a/jwst/wfss_contam/wfss_contam_step.py b/jwst/wfss_contam/wfss_contam_step.py index 7f925ea134..e286d3e5d9 100755 --- a/jwst/wfss_contam/wfss_contam_step.py +++ b/jwst/wfss_contam/wfss_contam_step.py @@ -18,6 +18,7 @@ class WfssContamStep(Step): save_contam_images = boolean(default=False) # Save source contam estimates maximum_cores = option('none', 'quarter', 'half', 'all', default='none') skip = boolean(default=True) + brightest_n = integer(default=None) """ # noqa: E501 reference_file_types = ["photom", "wavelengthrange"] @@ -37,8 +38,6 @@ def process(self, input_model): A copy of the input_model with contamination removed """ with datamodels.open(input_model) as dm: - max_cores = self.maximum_cores - # Get the wavelengthrange ref file waverange_ref = self.get_reference_file(dm, "wavelengthrange") self.log.info(f"Using WAVELENGTHRANGE reference file {waverange_ref}") @@ -49,17 +48,18 @@ def process(self, input_model): self.log.info(f"Using PHOTOM reference file {photom_ref}") photom_model = datamodels.open(photom_ref) - result, simul, contam = wfss_contam.contam_corr( - dm, waverange_model, photom_model, max_cores + result, simul, contam, simul_slits = wfss_contam.contam_corr( + dm, waverange_model, photom_model, self.maximum_cores, brightest_n=self.brightest_n ) # Save intermediate results, if requested if self.save_simulated_image: simul_path = self.save_model(simul, suffix="simul", force=True) self.log.info(f'Full-frame simulated grism image saved to "{simul_path}"') + simul_slits_path = self.save_model(simul_slits, suffix="simul_slits", force=True) + self.log.info(f'Simulated slits saved to "{simul_slits_path}"') if self.save_contam_images: contam_path = self.save_model(contam, suffix="contam", force=True) self.log.info(f'Contamination estimates saved to "{contam_path}"') - # Return the corrected data return result
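Putting the step changes together, a minimal invocation sketch (input file name hypothetical; maximum_cores still accepts only the string options declared in the spec above, even though contam_corr itself now also handles an integer count):

    from jwst.wfss_contam.wfss_contam_step import WfssContamStep

    result = WfssContamStep.call(
        "jw01234-o001_t001_nis_f200w_cal.fits",  # hypothetical WFSS exposure
        skip=False,                  # the step is skipped by default
        brightest_n=10,              # simulate only the 10 brightest catalog sources
        maximum_cores="half",
        save_simulated_image=True,   # also writes the *_simul_slits product
        save_contam_images=True,
    )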