-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathtest_isPrimaryFlag.py
executable file
·315 lines (260 loc) · 11.8 KB
/
test_isPrimaryFlag.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
# This file is part of pipe_tasks.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import os
import unittest
import numpy as np
from lsst.geom import Point2I, Box2I, Extent2I
from lsst.skymap import TractInfo
from lsst.skymap.patchInfo import PatchInfo
import lsst.afw.image as afwImage
import lsst.utils.tests
from lsst.pipe.tasks.characterizeImage import CharacterizeImageTask, CharacterizeImageConfig
from lsst.pipe.tasks.calibrate import CalibrateTask, CalibrateConfig
from lsst.meas.algorithms import SourceDetectionTask, SkyObjectsTask, SetPrimaryFlagsTask
import lsst.meas.extensions.scarlet as mes
from lsst.meas.extensions.scarlet.scarletDeblendTask import ScarletDeblendTask
from lsst.meas.extensions.scarlet.deconvolveExposureTask import DeconvolveExposureTask
from lsst.meas.base import SingleFrameMeasurementTask
from lsst.afw.table import SourceCatalog
# Absolute path of the directory that contains this test file; the sample
# input exposure is looked up in its "data" subdirectory (see setUp below).
TESTDIR = os.path.abspath(os.path.dirname(__file__))
class NullTract(TractInfo):
    """Stand-in tract for sky positions not covered by any mock tract.

    ``BaseSkyMap.findTract(coord)`` always returns *some* tract, even when
    ``coord`` is not actually inside it.  To mimic that behaviour,
    ``MockSkyMap.findTract`` returns an instance of this class whenever a
    coordinate falls outside every tract it defines.
    """

    def __init__(self):
        # Deliberately do not call TractInfo.__init__: this placeholder has
        # no geometry and only needs to answer getId().
        return

    def getId(self):
        """Return ``None``, meaning "not part of any mock tract"."""
        missingId = None
        return missingId
class MockTractInfo:
    """A Tract based on a bounding box and WCS.

    Testing is made easier when we can specifically define a Tract in terms
    of its bounding box in pixel coordinates along with a WCS for the
    exposure.  Only the methods of `TractInfo` that the tests need are
    implemented here.  Since this is just for testing it is not
    sophisticated: the size of the bounding box must be evenly divisible by
    the number of patches in the Tract along each axis.

    Parameters
    ----------
    name :
        Value returned by `getId`.
    bbox : `lsst.geom.Box2I`
        Pixel bounding box of the tract.
    wcs :
        WCS used to map sky coordinates to pixels in `contains` and passed
        to each `PatchInfo`.
    numPatches : `tuple` [`int`, `int`]
        Number of patches along x and y.
    """

    def __init__(self, name, bbox, wcs, numPatches):
        self.name = name
        self.bbox = bbox
        self.wcs = wcs
        self._numPatches = numPatches
        # The simple patch layout in getPatchInfo requires an exact split.
        assert bbox.getWidth() % numPatches[0] == 0, \
            "tract width must be divisible by the number of patches in x"
        assert bbox.getHeight() % numPatches[1] == 0, \
            "tract height must be divisible by the number of patches in y"
        self.patchWidth = bbox.getWidth()//numPatches[0]
        self.patchHeight = bbox.getHeight()//numPatches[1]

    def contains(self, coord):
        """Return whether ``coord`` maps (through ``self.wcs``) to a pixel
        inside ``self.bbox``.
        """
        pixel = self.wcs.skyToPixel(coord)
        return self.bbox.contains(Point2I(pixel))

    def getId(self):
        """Return the tract name given at construction."""
        return self.name

    def getNumPatches(self):
        """Return the ``(nx, ny)`` patch counts given at construction."""
        return self._numPatches

    def getPatchInfo(self, index):
        """Build the `PatchInfo` for patch ``index = (ix, iy)``.

        The inner and outer bounding boxes are identical (no overlap
        between patches).
        """
        ix, iy = index
        nx, _ = self._numPatches
        width = self.patchWidth
        height = self.patchHeight

        # NOTE(review): patch boxes are laid out starting at pixel (0, 0),
        # not at self.bbox's minimum corner; confirm whether they should be
        # offset into the tract.
        bbox = Box2I(Point2I(ix*width, iy*height), Extent2I(width, height))

        # Row-major sequential index, computed from the *patch* indices
        # (not from the pixel offsets of the patch corner).
        sequentialIndex = nx*iy + ix

        patchInfo = PatchInfo(
            index=index,
            innerBBox=bbox,
            outerBBox=bbox,
            sequentialIndex=sequentialIndex,
            tractWcs=self.wcs
        )
        return patchInfo

    def __getitem__(self, index):
        return self.getPatchInfo(index)

    def __iter__(self):
        """Yield every patch, iterating x fastest within each row y."""
        xNum, yNum = self.getNumPatches()
        for y in range(yNum):
            for x in range(xNum):
                yield self.getPatchInfo((x, y))
class MockSkyMap:
    """A minimal SkyMap whose tracts are given as pixel bounding boxes.

    Each entry of ``bboxes`` becomes one tract (identified by its position
    in the list) that shares the single ``wcs`` and the same ``numPatches``
    layout.  Defining tracts directly in pixel coordinates of the test
    exposure keeps the tests easy to reason about.
    """

    def __init__(self, bboxes, wcs, numPatches):
        self.bboxes = bboxes
        self.wcs = wcs
        self.numPatches = numPatches

    def __iter__(self):
        """Yield a freshly generated tract for every bounding box, in order."""
        yield from (self.generateTract(tractIndex)
                    for tractIndex, _unused in enumerate(self.bboxes))

    def __getitem__(self, index):
        return self.generateTract(index)

    def generateTract(self, index):
        """Create the `MockTractInfo` for the ``index``-th bounding box."""
        return MockTractInfo(
            name=index,
            bbox=self.bboxes[index],
            wcs=self.wcs,
            numPatches=self.numPatches,
        )

    def findTract(self, coord):
        """Return the first tract containing ``coord``, else a `NullTract`."""
        containing = (tract for tract in self if tract.contains(coord))
        hit = next(containing, None)
        if hit is not None:
            return hit
        return NullTract()
class IsPrimaryTestCase(lsst.utils.tests.TestCase):
    """Check the ``detect_is*`` flags written by `SetPrimaryFlagsTask`.

    ``setUp`` loads a sample exposure from ``TESTDIR/data`` and runs
    `CharacterizeImageTask` on it; the tests then exercise the single-frame
    (`CalibrateTask`) path and a scarlet-deblended path.
    """

    def setUp(self):
        # Load sample input from disk
        expPath = os.path.join(TESTDIR, "data", "v695833-e0-c000-a00.sci.fits")
        self.exposure = afwImage.ExposureF(expPath)

        # Characterize the image (create PSF, etc.)
        charImConfig = CharacterizeImageConfig()
        charImConfig.measureApCorr.sourceSelector["science"].doSignalToNoise = False
        charImTask = CharacterizeImageTask(config=charImConfig)
        self.charImResults = charImTask.run(self.exposure)

    def tearDown(self):
        # Release both objects created in setUp.
        del self.exposure
        del self.charImResults

    def testIsSinglePrimaryFlag(self):
        """Tests detect_isPrimary column gets added when run, and that sources
        labelled as detect_isPrimary are not sky sources and have no children.

        Also checks that the ``detect_isDeblendedModelSource`` column (which
        the scarlet test below relies on) is not added in single-frame mode.
        """
        calibConfig = CalibrateConfig()
        calibConfig.doAstrometry = False
        calibConfig.doPhotoCal = False
        calibConfig.doComputeSummaryStats = False
        calibTask = CalibrateTask(config=calibConfig)
        calibResults = calibTask.run(self.charImResults.exposure)
        outputCat = calibResults.outputCat
        self.assertIn("detect_isPrimary", outputCat.schema.getNames())
        # make sure all sky sources are flagged as not primary
        self.assertEqual(np.sum(outputCat["detect_isPrimary"] & outputCat["sky_source"]), 0)
        # make sure all parent sources are flagged as not primary
        self.assertEqual(
            np.sum(outputCat["detect_isPrimary"] & (outputCat["deblend_nChild"] > 0)), 0)

        # The column name must be spelled correctly; a misspelled name would
        # always raise KeyError and make this check vacuous.
        with self.assertRaises(KeyError):
            outputCat.getSchema().find("detect_isDeblendedModelSource")

    def testIsScarletPrimaryFlag(self):
        """Test detect_isPrimary column when scarlet is used as the deblender.

        Builds a one-band "coadd", a mock skymap with a single 3x3-patch
        tract, injects a few sky sources, deblends with scarlet, measures
        centroids, and checks the tract/patch/primary flags.
        """
        # We need a multiband coadd for scarlet,
        # even though there is only one band
        coadds = afwImage.MultibandExposure.fromExposures(["test"], [self.exposure])

        # Create a SkyMap with a tract that contains a portion of the image,
        # subdivided into 3x3 patches
        wcs = self.exposure.getWcs()
        tractBBox = Box2I(Point2I(100, 100), Extent2I(900, 900))
        skyMap = MockSkyMap([tractBBox], wcs, (3, 3))
        tractInfo = skyMap[0]
        patchInfo = tractInfo[0, 0]
        patchBBox = patchInfo.getInnerBBox()

        schema = SourceCatalog.Table.makeMinimalSchema()
        # Initialize the detection task
        detectionTask = SourceDetectionTask(schema=schema)

        # Initialize the fake source injection task
        skyConfig = SkyObjectsTask.ConfigClass()
        skySourcesTask = SkyObjectsTask(name="skySources", config=skyConfig)
        schema.addField("merge_peak_sky", type="Flag")

        # Initialize the deconvolution task
        deconvolveConfig = DeconvolveExposureTask.ConfigClass()
        deconvolveTask = DeconvolveExposureTask(config=deconvolveConfig)

        # Initialize the deblender task
        scarletConfig = ScarletDeblendTask.ConfigClass()
        scarletConfig.maxIter = 20
        # Propagate the sky-source flag from parents to their children.
        scarletConfig.columnInheritance["merge_peak_sky"] = "merge_peak_sky"
        deblendTask = ScarletDeblendTask(schema=schema, config=scarletConfig)

        # We'll customize the configuration of measurement to just run the
        # minimal number of plugins to make setPrimaryFlags work.
        measureConfig = SingleFrameMeasurementTask.ConfigClass()
        measureConfig.plugins.names = ["base_SdssCentroid", "base_SkyCoord"]
        measureConfig.slots.psfFlux = None
        measureConfig.slots.apFlux = None
        measureConfig.slots.shape = None
        measureConfig.slots.modelFlux = None
        measureConfig.slots.calibFlux = None
        measureConfig.slots.gaussianFlux = None
        measureTask = SingleFrameMeasurementTask(config=measureConfig, schema=schema)
        setPrimaryTask = SetPrimaryFlagsTask(schema=schema, isSingleFrame=False)

        table = SourceCatalog.Table.make(schema)
        # detect sources
        detectionResult = detectionTask.run(table, coadds["test"])
        catalog = detectionResult.sources

        # add fake sources
        skySources = skySourcesTask.run(mask=self.exposure.mask, seed=0)
        for foot in skySources[:5]:
            src = catalog.addNew()
            src.setFootprint(foot)
            src.set("merge_peak_sky", True)

        # deconvolve the images
        deconvolved = deconvolveTask.run(coadds["test"], catalog).deconvolved
        mDeconvolved = afwImage.MultibandExposure.fromExposures(["test"], [deconvolved])

        # deblend
        # This is a hack because the variance is not calibrated properly
        # (it is 3 orders of magnitude too high), which causes the deblender
        # to improperly deblend most sources due to the sparsity constraint.
        coadds.variance.array[:] = 2e-1
        mDeconvolved.variance.array[:] = 2e-1
        catalog, modelData = deblendTask.run(coadds, mDeconvolved, catalog)

        # Attach footprints to the catalog
        mes.io.updateCatalogFootprints(
            modelData=modelData,
            catalog=catalog,
            band="test",
            imageForRedistribution=coadds["test"],
            removeScarletData=True,
            updateFluxColumns=True,
        )

        # measure
        measureTask.run(catalog, self.exposure)
        outputCat = catalog
        # Set the primary flags
        setPrimaryTask.run(outputCat, skyMap=skyMap, tractInfo=tractInfo, patchInfo=patchInfo)

        # There should be the same number of deblendedPrimary and
        # deblendedModelPrimary sources,
        # since they both have the same blended sources and only differ
        # over which model to use for the isolated sources.
        isPseudo = outputCat["merge_peak_sky"]
        self.assertEqual(
            np.sum(outputCat["detect_isDeblendedSource"] & ~isPseudo),
            np.sum(outputCat["detect_isDeblendedModelSource"]))

        # Check that the sources contained in a tract are all marked appropriately
        x = outputCat["slot_Centroid_x"]
        y = outputCat["slot_Centroid_y"]
        tractInner = tractBBox.contains(x, y)
        np.testing.assert_array_equal(outputCat["detect_isTractInner"], tractInner)

        # Check that the sources contained in a patch are all marked appropriately
        patchInner = patchBBox.contains(x, y)
        np.testing.assert_array_equal(outputCat["detect_isPatchInner"], patchInner)

        # make sure all sky sources are flagged as not primary
        self.assertEqual(np.sum(outputCat["detect_isPrimary"] & outputCat["merge_peak_sky"]), 0)

        # Check that sky objects have not been deblended
        np.testing.assert_array_equal(
            isPseudo,
            isPseudo & (outputCat["deblend_nChild"] == 0)
        )
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Inherit the module-wide checks of `lsst.utils.tests.MemoryTestCase`.

    Nothing is added or overridden here; the class only exists so that the
    inherited tests are collected for this module.
    """
def setup_module(module):
    """Module-level setup hook; presumably invoked by pytest before this
    module's tests run.

    Parameters
    ----------
    module : module
        The test module object (not used here).
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Running the file directly: initialize the LSST test utilities, then
    # hand control to the standard unittest runner.
    lsst.utils.tests.init()
    unittest.main()