Skip to content

Commit b4a253a

Browse files
authored
Merge pull request #298 from lsst/tickets/DM-48928
DM-48928: Fix serialization issue with LombScarglePeriodogramMulti plugin
2 parents 044012f + 13247a8 commit b4a253a

File tree

3 files changed

+79
-25
lines changed

3 files changed

+79
-25
lines changed

python/lsst/meas/base/compensatedGaussian/_compensatedTophat.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class SingleFrameCompensatedTophatFluxConfig(SingleFramePluginConfig):
4444
doc="The aperture radii (in pixels) to measure the top-hats.",
4545
dtype=int,
4646
minLength=1,
47-
default=[12,],
47+
default=[12, ],
4848
)
4949
inner_scale = RangeField(
5050
doc="Inner background annulus scale (relative to aperture).",

python/lsst/meas/base/diaCalculationPlugins.py

+44-17
Original file line numberDiff line numberDiff line change
@@ -316,26 +316,36 @@ def calculate(self,
316316
**kwargs : `dict`
317317
Unused kwargs that are always passed to a plugin.
318318
"""
319-
n_bands = len(diaSources["band"].unique())
319+
320+
bands_arr = diaSources['band'].unique().values
321+
unique_bands = np.unique(np.concatenate(bands_arr))
320322
# Check and initialize output columns in diaObjects.
321323
if (periodCol := "multiPeriod") not in diaObjects.columns:
322324
diaObjects[periodCol] = np.nan
323325
if (powerCol := "multiPower") not in diaObjects.columns:
324326
diaObjects[powerCol] = np.nan
325327
if (fapCol := "multiFap") not in diaObjects.columns:
326328
diaObjects[fapCol] = np.nan
327-
if (ampCol := "multiAmp") not in diaObjects.columns:
328-
diaObjects[ampCol] = pd.Series([np.nan]*n_bands, dtype="object")
329-
if (phaseCol := "multiPhase") not in diaObjects.columns:
330-
diaObjects[phaseCol] = pd.Series([np.nan]*n_bands, dtype="object")
331-
332-
def _calculate_period_multi(df, min_detections=9, oversampling_factor=5, nyquist_factor=100):
329+
ampCol = "multiAmp"
330+
phaseCol = "multiPhase"
331+
for i in range(len(unique_bands)):
332+
ampCol_band = f"{unique_bands[i]}_{ampCol}"
333+
if ampCol_band not in diaObjects.columns:
334+
diaObjects[ampCol_band] = np.nan
335+
phaseCol_band = f"{unique_bands[i]}_{phaseCol}"
336+
if phaseCol_band not in diaObjects.columns:
337+
diaObjects[phaseCol_band] = np.nan
338+
339+
def _calculate_period_multi(df, all_unique_bands,
340+
min_detections=9, oversampling_factor=5, nyquist_factor=100):
333341
"""Calculate the multi-band Lomb-Scargle periodogram.
334342
335343
Parameters
336344
----------
337345
df : `pandas.DataFrame`
338346
The input DataFrame.
347+
all_unique_bands : `list` of `str`
348+
List of all bands present in the diaSource table that is being worked on.
339349
min_detections : `int`, optional
340350
The minimum number of detections, including all bands.
341351
oversampling_factor : `int`, optional
@@ -352,11 +362,14 @@ def _calculate_period_multi(df, min_detections=9, oversampling_factor=5, nyquist
352362
np.isnan(df["midpointMjdTai"]))]
353363

354364
if (len(tmpDf)) < min_detections:
355-
return pd.Series({periodCol: np.nan,
356-
powerCol: np.nan,
357-
fapCol: np.nan,
358-
ampCol: pd.Series([np.nan]*n_bands, dtype="object"),
359-
phaseCol: pd.Series([np.nan]*n_bands, dtype="object")})
365+
pd_tab_nodet = pd.Series({periodCol: np.nan,
366+
powerCol: np.nan,
367+
fapCol: np.nan})
368+
for band in all_unique_bands:
369+
pd_tab_nodet[f"{band}_{ampCol}"] = np.nan
370+
pd_tab_nodet[f"{band}_{phaseCol}"] = np.nan
371+
372+
return pd_tab_nodet
360373

361374
time = tmpDf["midpointMjdTai"].to_numpy()
362375
flux = tmpDf["psfFlux"].to_numpy()
@@ -378,15 +391,29 @@ def _calculate_period_multi(df, min_detections=9, oversampling_factor=5, nyquist
378391

379392
pd_tab = pd.Series({periodCol: period[np.argmax(power)],
380393
powerCol: np.max(power),
381-
fapCol: fap_estimate,
382-
ampCol: params_table_new[0],
383-
phaseCol: params_table_new[1]
394+
fapCol: fap_estimate
384395
})
385396

397+
# Initialize the per-band amplitude/phase columns as NaNs
398+
for band in all_unique_bands:
399+
pd_tab[f"{band}_{ampCol}"] = np.nan
400+
pd_tab[f"{band}_{phaseCol}"] = np.nan
401+
402+
# Populate the values of only the bands that have data for this diaSource
403+
unique_bands = np.unique(bands)
404+
for i in range(len(unique_bands)):
405+
pd_tab[f"{unique_bands[i]}_{ampCol}"] = params_table_new[0][i]
406+
pd_tab[f"{unique_bands[i]}_{phaseCol}"] = params_table_new[1][i]
407+
386408
return pd_tab
387409

388-
diaObjects.loc[:, [periodCol, powerCol, fapCol, ampCol, phaseCol]
389-
] = diaSources.apply(_calculate_period_multi)
410+
columns_list = [periodCol, powerCol, fapCol]
411+
for i in range(len(unique_bands)):
412+
columns_list.append(f"{unique_bands[i]}_{ampCol}")
413+
columns_list.append(f"{unique_bands[i]}_{phaseCol}")
414+
415+
diaObjects.loc[:, columns_list
416+
] = diaSources.apply(_calculate_period_multi, unique_bands)
390417

391418

392419
class MeanDiaPositionConfig(DiaObjectCalculationPluginConfig):

tests/test_diaCalculationPlugins.py

+34-7
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def run_multi_plugin(diaObjectCat, diaSourceCat, band, plugin):
9797
Input object catalog to store data into and read from.
9898
diaSourceCat : `pandas.DataFrame`
9999
DiaSource catalog to read data from and groupby on.
100-
fitlerName : `str`
100+
filterName : `str`
101101
String name of the filter to process.
102102
plugin : `lsst.ap.association.DiaCalculationPlugin`
103103
Plugin to run.
@@ -121,6 +121,33 @@ def run_multi_plugin(diaObjectCat, diaSourceCat, band, plugin):
121121
band=band)
122122

123123

124+
def run_multiband_plugin(diaObjectCat, diaSourceCat, plugin):
125+
"""Wrapper for running multi plugins.
126+
127+
Reproduces some of the behavior of `lsst.ap.association.DiaCalculation.run`
128+
129+
Parameters
130+
----------
131+
diaObjectCat : `pandas.DataFrame`
132+
Input object catalog to store data into and read from.
133+
diaSourceCat : `pandas.DataFrame`
134+
DiaSource catalog to read data from and groupby on.
135+
plugin : `lsst.ap.association.DiaCalculationPlugin`
136+
Plugin to run.
137+
"""
138+
diaObjectCat.set_index("diaObjectId", inplace=True, drop=False)
139+
diaSourceCat.set_index(
140+
["diaObjectId", "band", "diaSourceId"],
141+
inplace=True,
142+
drop=False)
143+
144+
diaSourcesGB = diaSourceCat.groupby(level=0)
145+
146+
plugin.calculate(diaObjects=diaObjectCat,
147+
diaSources=diaSourcesGB,
148+
)
149+
150+
124151
def make_diaObject_table(objId, plugin, default_value=None, band=None):
125152
"""Create a minimal diaObject table with columns required for the plugin
126153
@@ -958,18 +985,18 @@ def testCalculate(self):
958985
"ap_lombScarglePeriodogramMulti",
959986
None)
960987

961-
run_multi_plugin(diaObjects, diaSources, "u", plugin)
988+
run_multiband_plugin(diaObjects, diaSources, plugin)
962989
self.assertAlmostEqual(diaObjects.at[objId, "multiPeriod"], 10, delta=0.04)
963990
self.assertAlmostEqual(diaObjects.at[objId, "multiPower"], 1, delta=1e-2)
964991
# This implementation of LS returns a normalized power < 1.
965992
self.assertLess(diaObjects.at[objId, "multiPower"], 1)
966993
self.assertAlmostEqual(diaObjects.at[objId, "multiFap"], 0, delta=0.04)
967994
# Note: The below values are empirical, but seem reasonable, and
968-
# test that we get an array with one value per band.
969-
self.assertFloatsAlmostEqual(np.array(diaObjects.at[objId, "multiAmp"]),
970-
np.array([0.029, 0.029]), atol=1e-3)
971-
self.assertFloatsAlmostEqual(np.array(diaObjects.at[objId, "multiPhase"]),
972-
np.array([1., -2.]), rtol=6e-2)
995+
# test that we get values for each band.
996+
self.assertAlmostEqual(diaObjects.at[objId, "u_multiAmp"], 0.029, delta=0.01)
997+
self.assertAlmostEqual(diaObjects.at[objId, "g_multiAmp"], 0.029, delta=0.01)
998+
self.assertAlmostEqual(diaObjects.at[objId, "u_multiPhase"], -2.0, delta=0.2)
999+
self.assertAlmostEqual(diaObjects.at[objId, "g_multiPhase"], 1.0, delta=0.1)
9731000

9741001
def testCalculateTwoSources(self):
9751002
"""Test Multiband Lomb Scargle Periodogram with 2 sources (minimum

0 commit comments

Comments
 (0)