Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to include average star/model in StarImages plots #167

Merged
merged 3 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ Output file changes
API Changes
-----------

- Changed the default behavior of the StarImages plot to include the average star and model.
To recover the old version without these images, use ``include_ave = False``. (#167)


Performance improvements
Expand All @@ -18,6 +20,7 @@ Performance improvements
New features
------------

- Added an image of the average star and model in the StarImages output plot. (#167)


Bug fixes
Expand Down
88 changes: 63 additions & 25 deletions piff/star_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,21 +43,23 @@ class StarStats(Stats):
starfit center and flux to match observed star. [default: False]
:param include_reserve: Whether to inlude reserve stars. [default: True]
:param only_reserve: Whether to skip plotting non-reserve stars. [default: False]
:param include_flaggede: Whether to include plotting flagged stars. [default: False]
:param include_flagged: Whether to include plotting flagged stars. [default: False]
:param include_ave: Whether to inlude the average image. [default: True]
:param file_name: Name of the file to output to. [default: None]
:param logger: A logger object for logging debug info. [default: None]
"""
_type_name = 'StarImages'

def __init__(self, nplot=10, adjust_stars=False,
include_reserve=True, only_reserve=False, include_flagged=False,
file_name=None, logger=None):
include_ave=True, file_name=None, logger=None):
self.nplot = nplot
self.file_name = file_name
self.adjust_stars = adjust_stars
self.include_reserve = include_reserve
self.only_reserve = only_reserve
self.include_flagged = include_flagged
self.include_ave = include_ave

def compute(self, psf, stars, logger=None):
"""
Expand All @@ -82,16 +84,48 @@ def compute(self, psf, stars, logger=None):
else:
self.indices = np.random.choice(possible_indices, self.nplot, replace=False)

logger.info("Making {0} Model Stars".format(len(self.indices)))
self.stars = []
for index in self.indices:
star = stars[index]
if self.adjust_stars:
# Do 2 passes, since we sometimes start pretty far from the right values.
star = psf.reflux(star, logger=logger)
star = psf.reflux(star, logger=logger)
self.stars.append(star)
self.models = psf.drawStarList(self.stars)
# If we need to compute the average image, then we need to reflux and drawStar for all
# possible_indices. Otherwise, only do those steps for the stars we will plot.
if self.include_ave:
calculate_indices = possible_indices
else:
calculate_indices = self.indices

logger.info("Making {0} model stars".format(len(calculate_indices)))
calculated_stars = []
calculated_models = []
for i, star in enumerate(stars):
if i in calculate_indices:
if self.adjust_stars:
# Do 2 passes, since we sometimes start pretty far from the right values.
star = psf.reflux(star, logger=logger)
star = psf.reflux(star, logger=logger)
calculated_stars.append(star)
calculated_models.append(psf.drawStar(star))
else:
calculated_stars.append(None)
calculated_models.append(None)

# if including the average image, put that first.
logger.info("Making average star and model")
if self.include_ave:
ave_star_image = np.mean([s.image.array for s in calculated_stars if s is not None],
axis=0)
ave_model_image = np.mean([s.image.array for s in calculated_models if s is not None],
axis=0)
ave_star_image = galsim.Image(ave_star_image)
ave_model_image = galsim.Image(ave_model_image)
ave_star = Star(stars[0].data.withNew(image=ave_star_image), None)
ave_model = Star(stars[0].data.withNew(image=ave_model_image), None)
self.stars = [ave_star]
self.models = [ave_model]
self.stars.extend([calculated_stars[i] for i in self.indices])
self.models.extend([calculated_models[i] for i in self.indices])
self.indices = [-1] + self.indices
else:
self.stars = [calculated_stars[i] for i in self.indices]
self.models = [calculated_models[i] for i in self.indices]


def plot(self, logger=None, **kwargs):
r"""Make the plots.
Expand All @@ -115,25 +149,29 @@ def plot(self, logger=None, **kwargs):

logger.info("Creating %d Star plots", self.nplot)

for i in range(len(self.indices)):
for i in range(nplot):
star = self.stars[i]
model = self.models[i]

# get index, u, v coordinates to put in title
u = star.data.properties['u']
v = star.data.properties['v']
index = self.indices[i]

ii = i // 2
jj = (i % 2) * 3

title = f'Star {index}'
if star.is_reserve:
title = 'Reserve ' + title
if star.is_flagged:
title = 'Flagged ' + title
axs[ii][jj+0].set_title(title)
axs[ii][jj+1].set_title(f'PSF at (u,v) = \n ({u:+.02e}, {v:+.02e})')
if self.include_ave and i == 0:
axs[ii][jj+0].set_title('Average Star')
axs[ii][jj+1].set_title('Average PSF')
else:
# get index, u, v coordinates to put in title
index = self.indices[i]
u = star.data.properties['u']
v = star.data.properties['v']

title = f'Star {index}'
if star.is_reserve:
title = 'Reserve ' + title
if star.is_flagged:
title = 'Flagged ' + title
axs[ii][jj+0].set_title(title)
axs[ii][jj+1].set_title(f'PSF at (u,v) = \n ({u:+.02e}, {v:+.02e})')
axs[ii][jj+2].set_title('Star - PSF')

star_image = star.image
Expand Down
30 changes: 22 additions & 8 deletions tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,7 @@ def test_starstats_config():
'file_name': star_file,
'nplot': 5,
'adjust_stars': True,
'include_ave': False,
}
]
}
Expand All @@ -550,7 +551,7 @@ def test_starstats_config():

# check default nplot
psf = piff.read(psf_file)
starStats = piff.StarStats()
starStats = piff.StarStats(include_ave=False)
orig_stars, wcs, pointing = piff.Input.process(config['input'], logger=logger)
orig_stars = piff.Select.process(config['select'], orig_stars, logger=logger)
with np.testing.assert_raises(RuntimeError):
Expand All @@ -563,12 +564,20 @@ def test_starstats_config():
orig_stars[starStats.indices[2]].image.array)

# check nplot = 6
starStats = piff.StarStats(nplot=6)
starStats = piff.StarStats(nplot=6, include_ave=False)
starStats.compute(psf, orig_stars)
assert len(starStats.stars) == 6

starStats = piff.StarStats(nplot=6, include_ave=True)
starStats.compute(psf, orig_stars)
assert len(starStats.stars) == 7

starStats = piff.StarStats(nplot=6) # include_ave=True is the default
starStats.compute(psf, orig_stars)
assert len(starStats.stars) == 7

# check nplot >> len(stars)
starStats = piff.StarStats(nplot=1000000)
starStats = piff.StarStats(nplot=1000000, include_ave=False)
starStats.compute(psf, orig_stars)
assert len(starStats.stars) == len(orig_stars)
# if use all stars, no randomness
Expand All @@ -577,7 +586,7 @@ def test_starstats_config():
starStats.plot() # Make sure this runs without error and in finite time.

# check nplot = 0
starStats = piff.StarStats(nplot=0)
starStats = piff.StarStats(nplot=0, include_ave=False)
starStats.compute(psf, orig_stars)
assert len(starStats.stars) == len(orig_stars)
# if use all stars, no randomness
Expand All @@ -588,20 +597,25 @@ def test_starstats_config():
# With include_reserve=False, only 8 stars
print('All stars: n=',len(starStats.stars)) # 10 stars total
assert len(starStats.stars) == 10
starStats = piff.StarStats(nplot=0, include_reserve=False)
starStats = piff.StarStats(nplot=0, include_reserve=False, include_ave=False)
starStats.compute(psf, orig_stars)
assert len(starStats.stars) == 8
starStats.plot() # Make sure this runs without error.

# With only_reserve=True, only 2 stars
starStats = piff.StarStats(nplot=0, only_reserve=True)
starStats = piff.StarStats(nplot=0, only_reserve=True, include_ave=False)
starStats.compute(psf, orig_stars)
assert len(starStats.stars) == 2
starStats.plot() # Make sure this runs without error.

starStats = piff.StarStats(nplot=0, only_reserve=True)
starStats.compute(psf, orig_stars)
assert len(starStats.stars) == 3
starStats.plot() # Make sure this runs without error.

# rerun with adjust stars and see if it did the right thing
# first with adjust_stars == False
starStats = piff.StarStats(nplot=0, adjust_stars=False)
starStats = piff.StarStats(nplot=0, adjust_stars=False, include_ave=False)
starStats.compute(psf, orig_stars, logger=logger)
fluxs_noadjust = np.array([s.fit.flux for s in starStats.stars])
ds_noadjust = np.array([s.fit.center for s in starStats.stars])
Expand All @@ -611,7 +625,7 @@ def test_starstats_config():
np.testing.assert_array_equal(ds_noadjust, 0)

# now with adjust_stars == True
starStats = piff.StarStats(nplot=0, adjust_stars=True)
starStats = piff.StarStats(nplot=0, adjust_stars=True, include_ave=False)
starStats.compute(psf, orig_stars, logger=logger)
fluxs_adjust = np.array([s.fit.flux for s in starStats.stars])
ds_adjust = np.array([s.fit.center for s in starStats.stars])
Expand Down
Loading