Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JP-2546: Refactor wfss_contam to support multiprocessing and decrease runtime #9220

Draft
wants to merge 19 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions jwst/assign_wcs/nircam.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,12 +212,12 @@ def tsgrism(input_model, reference_files):
# Get the disperser parameters which are defined as a model for each
# spectral order
with NIRCAMGrismModel(reference_files['specwcs']) as f:
displ = f.displ
dispx = f.dispx
dispy = f.dispy
invdispx = f.invdispx
invdispl = f.invdispl
orders = f.orders
displ = f.displ.instance
dispx = f.dispx.instance
dispy = f.dispy.instance
invdispx = f.invdispx.instance
invdispl = f.invdispl.instance
orders = f.orders.instance

# now create the appropriate model for the grismr
det2det = NIRCAMForwardRowGrismDispersion(orders,
Expand Down Expand Up @@ -377,13 +377,13 @@ def wfss(input_model, reference_files):
# Get the disperser parameters which are defined as a model for each
# spectral order
with NIRCAMGrismModel(reference_files['specwcs']) as f:
displ = f.displ
dispx = f.dispx
dispy = f.dispy
invdispx = f.invdispx
invdispy = f.invdispy
invdispl = f.invdispl
orders = f.orders
displ = f.displ.instance
dispx = f.dispx.instance
dispy = f.dispy.instance
invdispx = f.invdispx.instance
invdispy = f.invdispy.instance
invdispl = f.invdispl.instance
orders = f.orders.instance

# now create the appropriate model for the grism[R/C]
if "GRISMR" in input_model.meta.instrument.pupil:
Expand Down
10 changes: 5 additions & 5 deletions jwst/assign_wcs/niriss.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,11 +407,11 @@ def wfss(input_model, reference_files):
# Get the disperser parameters which are defined as a model for each
# spectral order
with NIRISSGrismModel(reference_files['specwcs']) as f:
dispx = f.dispx
dispy = f.dispy
displ = f.displ
invdispl = f.invdispl
orders = f.orders
dispx = f.dispx.instance
dispy = f.dispy.instance
displ = f.displ.instance
invdispl = f.invdispl.instance
orders = f.orders.instance
fwcpos_ref = f.fwcpos_ref

# This is the actual rotation from the input model
Expand Down
120 changes: 92 additions & 28 deletions jwst/wfss_contam/disperse.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import partial
import numpy as np
import warnings

Expand All @@ -7,6 +8,86 @@
from .sens1d import create_1d_sens


def _flat_lam(fluxes, _lams) -> np.ndarray:
return fluxes[0]


def flux_interpolator_injector(
lams,
flxs,
extrapolate_sed,
):
"""
Create a function that returns the flux at a given wavelength.

If only one direct image is in use, this function will always return the same value.

Parameters
----------
lams : np.ndarray[float]
Array of wavelengths corresponding to the fluxes (flxs) for each pixel.
One wavelength per direct image, so can be a single value.
flxs : np.ndarray[float]
Array of fluxes (flam) for the pixels contained in x0, y0. If a single
direct image is in use, this will be a single value.
extrapolate_sed : bool
Whether to allow for the SED of the object to be extrapolated
when it does not fully cover the needed wavelength range. Default if False.

Returns
-------
flux : function
Function that returns the flux at a given wavelength.
If only one direct image is in use, this function will always return the same value.
"""
if len(lams) > 1:
# If we have direct image flux values from more than one filter (lams),
# we have the option to extrapolate the fluxes outside the
# wavelength range of the direct images
if extrapolate_sed is False:
return interp1d(lams, flxs, fill_value=0.0, bounds_error=False)
else:
return interp1d(lams, flxs, fill_value="extrapolate", bounds_error=False)
else:
# If we only have flux from one wavelength, just use that
# single flux value at all wavelengths
return partial(_flat_lam, flxs)


def determine_wl_spacing(
dw,
lams,
oversample_factor,
):
"""
Determine the wavelength spacing to use for the dispersed pixels.

Use a natural wavelength scale or the wavelength scale of the input SED/spectrum,
whichever is smaller, divided by oversampling requested

Parameters
----------
dw : float
The natural wavelength scale of the grism image
lams : np.ndarray[float]
Array of wavelengths corresponding to the fluxes (flxs) for each pixel.
One wavelength per direct image, so can be a single value.
oversample_factor : int
The amount of oversampling

Returns
-------
dlam : float
The wavelength spacing to use for the dispersed pixels
"""
#
if len(lams) > 1:
input_dlam = np.median(lams[1:] - lams[:-1])
if input_dlam < dw:
return input_dlam / oversample_factor
return dw / oversample_factor


def dispersed_pixel(
x0,
y0,
Expand Down Expand Up @@ -95,25 +176,17 @@ def dispersed_pixel(
1D array of the wavelengths of each dispersed pixel
counts : array
1D array of counts for each dispersed pixel
source_id : int
The source ID of the source being processed. Returned in the output unmodified;
used only for bookkeeping. TODO this is not implemented properly right now and
should probably just be removed.
"""
# Setup the transforms we need from the input WCS objects
sky_to_imgxy = grism_wcs.get_transform("world", "detector")
imgxy_to_grismxy = grism_wcs.get_transform("detector", "grism_detector")

# Setup function for retrieving flux values at each dispersed wavelength
if len(lams) > 1:
# If we have direct image flux values from more than one filter (lambda),
# we have the option to extrapolate the fluxes outside the
# wavelength range of the direct images
if extrapolate_sed is False:
flux = interp1d(lams, flxs, fill_value=0.0, bounds_error=False)
else:
flux = interp1d(lams, flxs, fill_value="extrapolate", bounds_error=False)
else:
# If we only have flux from one lambda, just use that
# single flux value at all wavelengths
def flux(_x):
return flxs[0]
# Set up function for retrieving flux values at each dispersed wavelength
flux_interpolator = flux_interpolator_injector(lams, flxs, extrapolate_sed)

# Get x/y positions in the grism image corresponding to wmin and wmax:
# Start with RA/Dec of the input pixel position in segmentation map,
Expand All @@ -127,20 +200,9 @@ def flux(_x):
dxw = xwmax - xwmin
dyw = ywmax - ywmin

# Compute the delta-wave per pixel
dw = np.abs((wmax - wmin) / (dyw - dxw))

# Use a natural wavelength scale or the wavelength scale of the input SED/spectrum,
# whichever is smaller, divided by oversampling requested
if len(lams) > 1:
input_dlam = np.median(lams[1:] - lams[:-1])
if input_dlam < dw:
dlam = input_dlam / oversample_factor
else:
# this value gets used when we only have 1 direct image wavelength
dlam = dw / oversample_factor

# Create list of wavelengths on which to compute dispersed pixels
dw = np.abs((wmax - wmin) / (dyw - dxw))
dlam = determine_wl_spacing(dw, lams, oversample_factor)
lambdas = np.arange(wmin, wmax + dlam, dlam)
n_lam = len(lambdas)

Expand Down Expand Up @@ -174,9 +236,11 @@ def flux(_x):
# values are naturally in units of physical fluxes, so we divide out
# the sensitivity (flux calibration) values to convert to units of
# countrate (DN/s).
# flux_interpolator(lams) is either single-valued (for a single direct image)
# or an array of the same length as lams (for multiple direct images in different filters)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=RuntimeWarning, message="divide by zero")
counts = flux(lams) * areas / (sens * oversample_factor)
counts = flux_interpolator(lams) * areas / (sens * oversample_factor)
counts[no_cal] = 0.0 # set to zero where no flux cal info available

return xs, ys, areas, lams, counts, source_id
Loading
Loading