Skip to content

Commit 7c3a2a1

Browse files
committed
Refactor read methods to delegate to new Reader interface.
1 parent f79d1ee commit 7c3a2a1

18 files changed

+438
-234
lines changed

piff/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -120,4 +120,5 @@
120120
from . import des
121121
from . import wavefront
122122
from . import meta_data
123+
from . import readers
123124
from . import writers

piff/basis_interp.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -447,12 +447,12 @@ def _finish_write(self, writer):
447447
data['q'] = self.q
448448
writer.write_table('solution', data)
449449

450-
def _finish_read(self, fits, extname):
451-
"""Read the solution from a FITS binary table.
450+
def _finish_read(self, reader):
451+
"""Read the solution.
452452
453-
:param fits: An open fitsio.FITS object.
454-
:param extname: The name of the extension with the interpolator information.
453+
:param reader: A reader object that encapsulates the serialization format.
455454
"""
456-
data = fits[extname + '_solution'].read()
455+
data = reader.read_table('solution')
456+
assert data is not None
457457
self.q = data['q'][0]
458458

piff/convolvepsf.py

+5-10
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import galsim
2121

2222
from .psf import PSF
23-
from .util import read_kwargs
2423
from .star import Star, StarFit
2524
from .outliers import Outliers
2625

@@ -292,24 +291,20 @@ def _finish_write(self, writer, logger):
292291
self.outliers._write(writer, 'outliers')
293292
logger.debug("Wrote the PSF outliers to extension %s", writer.get_full_name('outliers'))
294293

295-
def _finish_read(self, fits, extname, logger):
294+
def _finish_read(self, reader, logger):
296295
"""Finish the reading process with any class-specific steps.
297296
298-
:param fits: An open fitsio.FITS object
299-
:param extname: The base name of the extension to write to.
297+
:param reader: A reader object that encapsulates the serialization format.
300298
:param logger: A logger object for logging debug info.
301299
"""
302-
chisq_dict = read_kwargs(fits, extname + '_chisq')
300+
chisq_dict = reader.read_struct('chisq')
303301
for key in chisq_dict:
304302
setattr(self, key, chisq_dict[key])
305303

306304
ncomponents = self.components
307305
self.components = []
308306
for k in range(ncomponents):
309-
self.components.append(PSF._read(fits, extname + '_' + str(k), logger=logger))
310-
if extname + '_outliers' in fits:
311-
self.outliers = Outliers.read(fits, extname + '_outliers')
312-
else:
313-
self.outliers = None
307+
self.components.append(PSF._read(reader, str(k), logger=logger))
308+
self.outliers = Outliers._read(reader, 'outliers')
314309
# Set up all the num's properly now that everything is constructed.
315310
self.set_num(None)

piff/des/decam_wavefront.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -176,13 +176,13 @@ def _finish_write(self, writer):
176176
# write to fits
177177
writer.write_table('solution', data)
178178

179-
def _finish_read(self, fits, extname):
180-
"""Read the solution from a FITS binary table.
179+
def _finish_read(self, reader):
180+
"""Read the solution.
181181
182-
:param fits: An open fitsio.FITS object.
183-
:param extname: The name of the extension with the interp information.
182+
:param reader: A reader object that encapsulates the serialization format.
184183
"""
185-
data = fits[extname + '_solution'].read()
184+
data = reader.read_table('solution')
185+
assert data is not None
186186

187187
# self.locations and self.targets assigned in _fit
188188
self._fit(data['LOCATIONS'][0], data['TARGETS'][0])

piff/gp_interp.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -321,13 +321,13 @@ def _finish_write(self, writer):
321321

322322
writer.write_table('kernel', data)
323323

324-
def _finish_read(self, fits, extname):
324+
def _finish_read(self, reader):
325325
"""Finish the reading process with any class-specific steps.
326326
327-
:param fits: An open fitsio.FITS object.
328-
:param extname: The base name of the extension.
327+
:param reader: A reader object that encapsulates the serialization format.
329328
"""
330-
data = fits[extname+'_kernel'].read()
329+
data = reader.read_table('kernel')
330+
assert data is not None
331331
# Run fit to set up GP, but don't actually do any hyperparameter optimization. Just
332332
# set the GP up using the current hyperparameters.
333333
# Need to give back average fits files if needed.

piff/interp.py

+20-18
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
"""
1818

1919
import numpy as np
20-
from .util import read_kwargs
2120

2221
class Interp(object):
2322
"""The base class for interpolating a set of data vectors across the field of view.
@@ -227,36 +226,39 @@ def read(cls, fits, extname):
227226
228227
:returns: an interpolator built with a information in the FITS file.
229228
"""
230-
assert extname in fits
231-
assert 'type' in fits[extname].get_colnames()
232-
assert 'type' in fits[extname].read().dtype.names
233-
# interp_type = fits[extname].read_column('type')
234-
interp_type = fits[extname].read()['type']
235-
assert len(interp_type) == 1
236-
try:
237-
interp_type = str(interp_type[0].decode())
238-
except AttributeError:
239-
# fitsio 1.0 returns strings
240-
interp_type = interp_type[0]
229+
from .readers import FitsReader
230+
return cls._read(FitsReader(fits, None), extname)
231+
232+
@classmethod
233+
def _read(cls, reader, name):
234+
"""Read an Interp via a Reader object.
235+
236+
:param reader: A reader object that encapsulates the serialization format.
237+
:param name: Name associated with this interpolator in the serialized output.
238+
239+
:returns: an interpolator built from serialized information.
240+
"""
241+
kwargs = reader.read_struct(name)
242+
assert kwargs is not None
243+
assert 'type' in kwargs
244+
interp_type = kwargs.pop('type')
241245

242246
# Check that interp_type is a valid Interp type.
243247
if interp_type not in Interp.valid_interp_types:
244248
raise ValueError("interp type %s is not a valid Piff Interpolation"%interp_type)
245249
interp_cls = Interp.valid_interp_types[interp_type]
246250

247-
kwargs = read_kwargs(fits, extname)
248-
kwargs.pop('type',None)
249251
interp = interp_cls(**kwargs)
250-
interp._finish_read(fits, extname)
252+
with reader.nested(name) as r:
253+
interp._finish_read(r)
251254
return interp
252255

253-
def _finish_read(self, fits, extname):
256+
def _finish_read(self, reader):
254257
"""Finish the reading process with any class-specific steps.
255258
256259
The base class implementation doesn't do anything, but this will probably always be
257260
overridden by the derived class.
258261
259-
:param fits: An open fitsio.FITS object.
260-
:param extname: The base name of the extension.
262+
:param reader: A reader object that encapsulates the serialization format.
261263
"""
262264
raise NotImplementedError("Derived classes must define the _finish_read method.")

piff/knn_interp.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -174,13 +174,12 @@ def _finish_write(self, writer):
174174

175175
writer.write_table('solution', data)
176176

177-
def _finish_read(self, fits, extname):
178-
"""Read the solution from a FITS binary table.
177+
def _finish_read(self, reader):
178+
"""Read the solution.
179179
180-
:param fits: An open fitsio.FITS object.
181-
:param extname: The base name of the extension with the interp information.
180+
:param reader: A reader object that encapsulates the serialization format.
182181
"""
183-
data = fits[extname + '_solution'].read()
182+
data = reader.read_table('solution')
184183

185184
# self.locations and self.targets assigned in _fit
186185
self._fit(data['LOCATIONS'][0], data['TARGETS'][0])

piff/mean_interp.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,11 @@ def _finish_write(self, writer):
7070
data = np.array(list(zip(*cols)), dtype=dtypes)
7171
writer.write_table('solution', data)
7272

73-
def _finish_read(self, fits, extname):
74-
"""Read the solution from a FITS binary table.
73+
def _finish_read(self, reader):
74+
"""Read the solution.
7575
76-
:param fits: An open fitsio.FITS object.
77-
:param extname: The base name of the extension
76+
:param reader: A reader object that encapsulates the serialization format.
7877
"""
79-
data = fits[extname + '_solution'].read()
78+
data = reader.read_table('solution')
79+
assert data is not None
8080
self.mean = data['mean']

piff/model.py

+22-16
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
.. module:: model
1717
"""
1818

19-
from .util import read_kwargs
2019
from .star import Star
2120

2221

@@ -204,29 +203,37 @@ def read(cls, fits, extname):
204203
205204
:returns: a model built with a information in the FITS file.
206205
"""
207-
assert extname in fits
208-
assert 'type' in fits[extname].get_colnames()
209-
model_type = fits[extname].read()['type']
210-
assert len(model_type) == 1
211-
try:
212-
model_type = str(model_type[0].decode())
213-
except AttributeError:
214-
# fitsio 1.0 returns strings
215-
model_type = model_type[0]
206+
from .readers import FitsReader
207+
return cls._read(FitsReader(fits, None), extname)
208+
209+
@classmethod
210+
def _read(cls, reader, name):
211+
"""Read a Model from a FITS file.
212+
213+
Note: the returned Model will not have its parameters set. This just initializes a fresh
214+
model that can be used to interpret interpolated vectors.
215+
216+
:param reader: A reader object that encapsulates the serialization format.
217+
:param name: Name associated with this model in the serialized output.
218+
219+
:returns: a model built with a information in the FITS file
220+
"""
221+
kwargs = reader.read_struct(name)
222+
assert kwargs is not None
223+
assert 'type' in kwargs
224+
model_type = kwargs.pop('type')
216225

217226
# Check that model_type is a valid Model type.
218227
if model_type not in Model.valid_model_types:
219228
raise ValueError("model type %s is not a valid Piff Model"%model_type)
220229
model_cls = Model.valid_model_types[model_type]
221230

222-
kwargs = read_kwargs(fits, extname)
223-
kwargs.pop('type',None)
224231
if 'force_model_center' in kwargs: # pragma: no cover
225232
# old version of this parameter name.
226233
kwargs['centered'] = kwargs.pop('force_model_center')
227234
model_cls._fix_kwargs(kwargs)
228235
model = model_cls(**kwargs)
229-
model._finish_read(fits, extname)
236+
model._finish_read(reader)
230237
return model
231238

232239
@classmethod
@@ -243,14 +250,13 @@ def _fix_kwargs(cls, kwargs):
243250
"""
244251
pass
245252

246-
def _finish_read(self, fits, extname):
253+
def _finish_read(self, reader):
247254
"""Finish the reading process with any class-specific steps.
248255
249256
The base class implementation doesn't do anything, which is often appropriate, but
250257
this hook exists in case any Model classes need to read extra information from the
251258
fits file.
252259
253-
:param fits: An open fitsio.FITS object.
254-
:param extname: The base name of the extension.
260+
:param reader: A reader object that encapsulates the serialization format.
255261
"""
256262
pass

piff/outliers.py

+23-16
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import galsim
2323
from scipy.stats import chi2
2424

25-
from .util import read_kwargs
2625

2726
class Outliers(object):
2827
"""The base class for handling outliers.
@@ -135,15 +134,25 @@ def read(cls, fits, extname):
135134
136135
:returns: an Outliers handler
137136
"""
138-
assert extname in fits
139-
assert 'type' in fits[extname].get_colnames()
140-
outliers_type = fits[extname].read()['type']
141-
assert len(outliers_type) == 1
142-
try:
143-
outliers_type = str(outliers_type[0].decode())
144-
except AttributeError:
145-
# fitsio 1.0 returns strings
146-
outliers_type = outliers_type[0]
137+
from .readers import FitsReader
138+
result = cls._read(FitsReader(fits, None), extname)
139+
assert result is not None
140+
return result
141+
142+
@classmethod
143+
def _read(cls, reader, name):
144+
"""Read a Outliers from a FITS file.
145+
146+
:param reader: A reader object that encapsulates the serialization format.
147+
:param name: Name associated with the outliers in the serialized output.
148+
149+
:returns: an Outliers handler, or None if there isn't one.
150+
"""
151+
kwargs = reader.read_struct(name)
152+
if kwargs is None:
153+
return None
154+
assert 'type' in kwargs
155+
outliers_type = kwargs.pop('type')
147156

148157
# Old output files had the full name. Fix it if necessary.
149158
if (outliers_type.endswith('Outliers')
@@ -155,21 +164,19 @@ def read(cls, fits, extname):
155164
raise ValueError("outliers type %s is not a valid Piff Outliers"%outliers_type)
156165
outliers_cls = Outliers.valid_outliers_types[outliers_type]
157166

158-
kwargs = read_kwargs(fits, extname)
159-
kwargs.pop('type',None)
160167
outliers = outliers_cls(**kwargs)
161-
outliers._finish_read(fits, extname)
168+
with reader.nested(name) as r:
169+
outliers._finish_read(r)
162170
return outliers
163171

164-
def _finish_read(self, fits, extname):
172+
def _finish_read(self, reader):
165173
"""Finish the reading process with any class-specific steps.
166174
167175
The base class implementation doesn't do anything, which is often appropriate, but
168176
this hook exists in case any Outliers classes need to read extra information from the
169177
fits file.
170178
171-
:param fits: An open fitsio.FITS object.
172-
:param extname: The base name of the extension.
179+
:param reader: A reader object that encapsulates the serialization format.
173180
"""
174181
pass
175182

piff/polynomial_interp.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -370,17 +370,16 @@ def _finish_write(self, writer):
370370
writer.write_table('solution', data, metadata=header)
371371

372372

373-
def _finish_read(self, fits, extname):
374-
"""Read the solution from a fits file.
373+
def _finish_read(self, reader):
374+
"""Read the solution.
375375
376-
:param fits: An open fitsio.FITS object.
377-
:param extname: The base name of the extension
376+
:param reader: A reader object that encapsulates the serialization format.
378377
"""
379378
# Read the solution extension.
380-
data = fits[extname + '_solution'].read()
381-
header = fits[extname + '_solution'].read_header()
382-
383-
self.nparam = header['NPARAM']
379+
metadata = {}
380+
data = reader.read_table('solution', metadata=metadata)
381+
assert data is not None
382+
self.nparam = metadata['NPARAM']
384383

385384
# Run setup functions to get these values right.
386385
self._set_function(self.poly_type)

0 commit comments

Comments
 (0)