Skip to content

Commit

Permalink
adding back opimization for nsclean
Browse files Browse the repository at this point in the history
  • Loading branch information
nden committed Sep 20, 2024
1 parent aeb44ef commit 2c21b8d
Showing 1 changed file with 28 additions and 18 deletions.
46 changes: 28 additions & 18 deletions jwst/clean_flicker_noise/clean_flicker_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import warnings

import gwcs
from gwcs.utils import _toindex
import numpy as np
from astropy.stats import sigma_clipped_stats, SigmaClip
from astropy.utils.exceptions import AstropyUserWarning
Expand Down Expand Up @@ -166,11 +167,14 @@ def mask_ifu_slices(input_model, mask):
# Initialize global DQ map to all zero (OK to use)
dqmap = np.zeros_like(input_model.dq)

# Get the wcs objects for all IFU slices
list_of_wcs = nirspec.nrs_ifu_wcs(input_model)
# Note: 30 in the line below is hardcoded in nirspec.nrs.ifu_wcs, which
# the line below replaces.
wcsobj, tr1, tr2, tr3 = nirspec._get_transforms(input_model, np.arange(30))

Check warning on line 172 in jwst/clean_flicker_noise/clean_flicker_noise.py

View check run for this annotation

Codecov / codecov/patch

jwst/clean_flicker_noise/clean_flicker_noise.py#L172

Added line #L172 was not covered by tests

# Loop over the IFU slices, finding the valid region for each
for (k, ifu_wcs) in enumerate(list_of_wcs):
for k in range(len(tr2)):
ifu_wcs = nirspec._nrs_wcs_set_input_lite(input_model, wcsobj, k,

Check warning on line 176 in jwst/clean_flicker_noise/clean_flicker_noise.py

View check run for this annotation

Codecov / codecov/patch

jwst/clean_flicker_noise/clean_flicker_noise.py#L175-L176

Added lines #L175 - L176 were not covered by tests
[tr1, tr2[k], tr3[k]])

# Construct array indexes for pixels in this slice
x, y = gwcs.wcstools.grid_from_bounding_box(
Expand Down Expand Up @@ -220,8 +224,6 @@ def mask_slits(input_model, mask):
2D output mask with additional flags for slit pixels
"""

from jwst.extract_2d.nirspec import offset_wcs

log.info("Finding slit/slitlet pixels")

# Get the slit-to-msa frame transform from the WCS object
Expand All @@ -230,9 +232,17 @@ def mask_slits(input_model, mask):
# Loop over the slits, marking all the pixels within each bounding
# box as False (do not use) in the mask.
# Note that for 3D masks (TSO mode), all planes will be set to the same value.
for slit in slit2msa.slits:
slit_wcs = nirspec.nrs_wcs_set_input(input_model, slit.name)
xlo, xhi, ylo, yhi = offset_wcs(slit_wcs)

slits = [s.name for s in slit2msa.slits]
wcsobj, tr1, tr2, tr3, open_slits = nirspec._get_transforms(input_model, slits, return_slits=True)

for k in range(len(tr2)):
slit_wcs = nirspec._nrs_wcs_set_input_lite(input_model, wcsobj, slits[k],
[tr1, tr2[k], tr3[k]],
open_slits=open_slits)

xlo, xhi = _toindex(slit_wcs.bounding_box[0])
ylo, yhi = _toindex(slit_wcs.bounding_box[1])
mask[..., ylo:yhi, xlo:xhi] = False

return mask
Expand Down Expand Up @@ -691,14 +701,14 @@ def fft_clean_subarray(image, mask, detector, npix_iter=512,
# basically copied from lib.py. Use a robust estimator for
# standard deviation, then exclude discrepant pixels and their
# four nearest neighbors from the fit.

if exclude_outliers:
med = np.median(image[mask])
std = 1.4825 * np.median(np.abs((image - med)[mask]))
outlier = mask & (np.abs(image - med) > sigrej * std)

mask = mask & (~outlier)

# also get four nearest neighbors of flagged pixels
mask[1:] = mask[1:] & (~outlier[:-1])
mask[:-1] = mask[:-1] & (~outlier[1:])
Expand All @@ -713,7 +723,7 @@ def fft_clean_subarray(image, mask, detector, npix_iter=512,

# i1 will be the first row with a nonzero element in the mask
# imax will be the last row with a nonzero element in the mask

nonzero_mask_element = np.sum(mask, axis=1) > 0

if np.sum(nonzero_mask_element) == 0:
Expand All @@ -722,7 +732,7 @@ def fft_clean_subarray(image, mask, detector, npix_iter=512,

i1 = np.amin(np.arange(mask.shape[0])[nonzero_mask_element])
imax = np.amax(np.arange(mask.shape[0])[nonzero_mask_element])

i1_vals = []
di_list = []
models = []
Expand All @@ -736,7 +746,7 @@ def fft_clean_subarray(image, mask, detector, npix_iter=512,
if (sum_mask[k] - sum_mask[i1] > npix_iter
and sum_mask[-1] - sum_mask[i1] > 1.5 * npix_iter):
break

di = k - i1

i1_vals += [i1]
Expand All @@ -747,7 +757,7 @@ def fft_clean_subarray(image, mask, detector, npix_iter=512,
# outliers section-by-section; we have to do that earlier
# over the full array to get reliable values for the mean
# and standard deviation.

if np.mean(mask[i1:i1 + di]) > minfrac:
cleaner = NSCleanSubarray(image[i1:i1 + di], mask[i1:i1 + di],
fc=fc, exclude_outliers=False)
Expand All @@ -767,9 +777,9 @@ def fft_clean_subarray(image, mask, detector, npix_iter=512,

# Step forward by half an interval so that we have
# overlapping fitting regions.

i1 += max(int(np.round(di/2)), 1)

model = np.zeros(image.shape)
tot_wgt = np.zeros(image.shape)

Expand All @@ -779,7 +789,7 @@ def fft_clean_subarray(image, mask, detector, npix_iter=512,
# Use nonzero weights everywhere so that if only one
# correction is available it gets unit weight when we
# normalize.

for i in range(len(models)):
wgt = 1.001 - np.abs(np.linspace(-1, 1, di_list[i]))[:, np.newaxis]
model[i1_vals[i]:i1_vals[i] + di_list[i]] += wgt*models[i]
Expand Down

0 comments on commit 2c21b8d

Please sign in to comment.