Skip to content

Commit 93747a7

Browse files
author
kahil_dell
committed
edited plotting tools
1 parent 292b474 commit 93747a7

File tree

6 files changed

+79
-61
lines changed

6 files changed

+79
-61
lines changed

PD.py

-2
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
from functools import partial
66

77

8-
import pyfits
9-
108

119

1210
## function to compute the FT of focused and defocused image

README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ This class allows the user to:
3232
4. Restore blurred images and correct for the PSF choosing between the Wiener filter and the Richardson-Lucy filter
3333

3434
The parameters of this class:
35-
1. the pair of focused and defocused images,
35+
1. the pair of focused and defocused images, the input image should have a format of `2xsize_xxsize_y`.
3636
2. the parameters of the `Telescope` class
3737
3. `cutoff` frequency for the noise filtering
3838
4. `reg`: regularization parameter for the Wiener filter
@@ -56,7 +56,7 @@ The `minimization` class returns the best-fit zernike polynomials, a visualisati
5656
python3 minimization.py -i 'path/input.fits' -s 150 -w 617.3e-6 -a 140 -f 4125.3 -p 0.5 -c 0.5 -r 1e-10 -ap 10 -x1 500 -x2 650 -y1 500 -y2 650 -z 10 -del 0.5 -o path/reduced.fits -fl 'Wiener'
5757
```
5858

59-
The specific description of the parsers can be found inside [the main code](https://github.com/fakahil/PyPD/blob/master/minimization.py). The values given above are for an example PD dataset taken by the PHI/HRT telescope. You can change the values according to your telescope.
59+
The specific description of the parsers and input to the class can be found inside [the main code](https://github.com/fakahil/PyPD/blob/master/minimization.py). The values given above are for an example PD dataset taken by the PHI/HRT telescope. You can change the values according to your telescope.
6060

6161
To use the `patch_pd` class:
6262

deconvolution.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from scipy.fftpack import fftshift, ifftshift, fft2, ifft2
66
from astropy.io import fits
77
import imreg_dft
8-
import pyfits
8+
99
import tools
1010
import aperture
1111
import PD

minimization.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,15 @@
3434

3535
class minimization(object):
3636

37-
def __init__(self,foc_defoc,size,lam, diameter,focal_length,platescale,cut_off,reg,ap,x1,x2,y1,y2,co_num,del_z,output,filterr):
37+
def __init__(self,foc_defoc,lam, diameter,focal_length,platescale,cut_off,reg,ap,x1,x2,y1,y2,co_num,del_z,output,filterr):
3838
self.data = fits.getdata(foc_defoc)
3939

4040

4141
self.foc = self.data[0,:,:]
4242
self.defoc = self.data[1,:,:]
4343
self.cut_off = cut_off
4444
self.reg = reg
45-
self.size = size
45+
self.size =self.y2-self.y1
4646
self.ap = ap
4747
self.co_num = co_num
4848
self.x1 = x1
@@ -112,7 +112,7 @@ def plot_results(self,Z):
112112
ph =wavefront.phase(Z, self.telescope.pupil_size(),self.co_num)
113113

114114

115-
fig = plt.figure(figsize=(10,10))
115+
fig = plt.figure(figsize=(20,20))
116116
ax1 = fig.add_subplot(1,3,1)
117117
im1 = ax1.imshow(ph/(2*np.pi), origin='lower',cmap='gray')
118118
ax1.set_xlabel('[Pixels]',fontsize=18)
@@ -187,7 +187,6 @@ def restored_scene(self,Z,iterations_RL=10):
187187
parser = argparse.ArgumentParser(description='Retrieving wavefront error')
188188
parser.add_argument('-i','--input', help='input')
189189
parser.add_argument('-o','--out', help='out')
190-
parser.add_argument('-s','--size', help='size',default=150)
191190
parser.add_argument('-w','--wavelength', help='wavelength',default=617.3e-6)
192191
parser.add_argument('-a','--aperture', help='aperture', default=140)
193192
parser.add_argument('-f','--focal_length', help='focal_length',default=4125.3)
@@ -207,12 +206,13 @@ def restored_scene(self,Z,iterations_RL=10):
207206

208207

209208

210-
res = minimization(foc_defoc='{0}'.format(parsed['input']), size=int(parsed['size']), lam=float(parsed['wavelength']),diameter=float(parsed['aperture']),focal_length=float(parsed['focal_length']),
209+
res = minimization(foc_defoc='{0}'.format(parsed['input']), lam=float(parsed['wavelength']),diameter=float(parsed['aperture']),focal_length=float(parsed['focal_length']),
211210
platescale=float(parsed['plate_scale']),cut_off=float(parsed['cut_off']),reg=float(parsed['reg']),ap=int(parsed['apod']),x1=int(parsed['x1']),x2=int(parsed['x2']),y1=int(parsed['y1']),y2=int(parsed['y2']),
212211
co_num=int(parsed['Z']),del_z=float(parsed['del']),output='{0}'.format(parsed['out']),filterr=parsed['filter'])
213212

214213
Z = res.fit()
215214
print(Z)
215+
tools.plot_zernike(Z)
216216
res.plot_results(Z)
217217
res.restored_scene(Z,10)
218218

patch_pd.py

+15-43
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import time
2-
import imreg_dft
3-
import pyfits
2+
43
import tools
54
import aperture
65
import PD
@@ -14,8 +13,7 @@
1413
import argparse
1514
import telescope
1615
from telescope import *
17-
18-
from tools import imreg, apo2d
16+
from tools import *
1917
from aperture import *
2018
from PD import *
2119
from noise import *
@@ -144,56 +142,30 @@ def Minimise(coefficients):
144142
hdu = fits.PrimaryHDU(self.output_mtf)
145143
hdu.writeto(self.output_mtf,overwrite=True)
146144
else:
145+
print('Initialising parallel computation')
146+
t0 = time.time()
147147
self.patch = tools.prepare_patches(self.data,self.Del,self.Im0,self.Imk)
148148
n_workers = min(6, os.cpu_count())
149+
print(f'number of workers is {n_workers}')
149150
self.args_list = [i for i in range(len(self.patch))]
150151
self.results_parallel = list(processing.MP.simultaneous(self.run_pd, self.args_list, workers=n_workers))
152+
dt = (time.time() - t0)/60.
153+
print(f'Time spent in fitting the wavefront error is: {dt: .3f}min')
151154

152-
def plot_results(self,output):
155+
def plot_results(self):
153156

154-
# change here the format of the output
155157
if not self.parallel:
156158
data_mtf = self.output_MTF
157159
data_wfe = self.output_WF
158160

159161
if self.parallel:
160-
## call here the stitching function
161162
data_mtf,data_wfe = tools.stitch_patches(self.results_parallel,self.Del)
163+
hdu = fits.PrimaryHDU(data_mtf)
164+
hdu.writeto(self.output_mtf,overwrite=True)
165+
hdu = fits.PrimaryHDU(data_wfe)
166+
hdu.writeto(self.output_wf,overwrite=True)
167+
tools.plot_mtf_wf(data_wfe,data_mtf)
162168

163-
164-
fig, ax = plt.subplots(1,2,figsize=(10,10))
165-
166-
im=ax[1].imshow(data_mtf,vmin=0,vmax=1,origin='lower',cmap='gray')
167-
ax[1].set_title('MTF')
168-
divider = make_axes_locatable(ax[1])
169-
cax = divider.append_axes('right',pad=0.05,size=0.03)
170-
cbar1 = plt.colorbar(im,cax=cax)
171-
cbar1.ax.tick_params(labelsize=16)
172-
cbar1.set_label('MTF',fontsize=16)
173-
ax[1].set_xlabel('[Pixels]')
174-
major_ticks = np.arange(0, 2048,350)
175-
major_ticks_y = np.arange(0, 2048,350)
176-
ax[1].set_xticks(major_ticks)
177-
ax[1].set_yticks(major_ticks_y)
178-
ax[1].tick_params(labelsize=4)
179-
180-
181-
182-
im2=ax[0].imshow(data_wfe/(2*np.pi),vmin=-0.5,vmax=1.4,origin='lower',cmap='gray')
183-
ax[0].set_xlabel('[Pixels]')
184-
ax[0].set_title('WF error[$\lambda$]')
185-
divider2 = make_axes_locatable(ax[0])
186-
cax2 = divider2.append_axes('right',pad=0.05,size=0.03) #size is the width of the color bar and pad is the fraction of the new axis
187-
cbar2=plt.colorbar(im2,cax=cax2)
188-
cbar2.ax.tick_params(labelsize=16)
189-
ax[0].tick_params(labelbottom=False,labelsize=16)
190-
plt.subplots_adjust(wspace=None, hspace=0.1)
191-
ax[0].set_xticks(major_ticks)
192-
ax[0].set_yticks(major_ticks_y)
193-
ax[0].set_ylabel('[Pixels]')
194-
#plt.subplots_adjust(wspace=None, hspace=-0.1)
195-
#plt.axis('off')
196-
plt.savefig(output,dpi=300)
197169

198170
if (__name__ == '__main__'):
199171
parser = argparse.ArgumentParser(description='PD on sub-fields')
@@ -202,10 +174,10 @@ def plot_results(self,output):
202174
parser.add_argument('-d','--Del', help='Del',default=265)
203175
parser.add_argument('-ow','--ow', help='output_WFE')
204176
parser.add_argument('-om','--om', help='output MTF')
205-
parser.add_argument('-r','--res', help='results')
177+
206178
parser.add_argument('-p','--parallel',choices=['True','False'],default=True)
207179
parsed = vars(parser.parse_args())
208180
st = patch_pd(pd_data='{0}'.format(parsed['input']),Del=int(parsed['Del']),co_num=int(parsed['Z']),output_wf='{0}'.format(parsed['ow']),output_mtf='{0}'.format(parsed['om']),parallel=bool(parsed['parallel']))
209181
st.fit_patch()
210-
st.plot_results(output='{0}'.format(parsed['res']))
182+
st.plot_results()
211183

tools.py

+56-8
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
import numpy as np
2+
import pylab
23
import matplotlib.pyplot as plt
34
import scipy
45
from scipy.fftpack import fftshift, ifftshift, fft2, ifft2
56
from mpl_toolkits.axes_grid1 import make_axes_locatable, axes_size
67
from scipy.signal import correlate2d as correlate
78
from scipy.signal import general_gaussian
89
from scipy import ndimage
9-
10-
10+
import imreg_dft as ird
11+
from image_registration import chi2_shift
12+
from image_registration.fft_tools import shift
1113

1214

1315
def GetPSD1D(psd2D):
@@ -72,9 +74,7 @@ def strehl(rms):
7274
return np.exp(-2*(np.pi*rms**2))
7375

7476
def imreg(im0,imk):
75-
import imreg_dft as ird
76-
from image_registration import chi2_shift
77-
from image_registration.fft_tools import shift
77+
7878
xoff, yoff, exoff, eyoff = chi2_shift(im0,imk)
7979
timg = ird.transform_img(imk, tvec=np.array([-yoff,-xoff]))
8080
return timg
@@ -87,7 +87,6 @@ def noise(im):
8787
s = estimate_sigma(im)
8888
return s
8989

90-
9190
def plot_zernike(coeff):
9291
n = coeff.shape[0]
9392
index = np.arange(n)
@@ -100,10 +99,11 @@ def plot_zernike(coeff):
10099
plt.xlabel('Zernike Polynomials',fontsize=18)
101100
plt.ylabel('Coefficient [$\lambda$]',fontsize=18)
102101
plt.title('Zernike Polynomials Coefficients',fontsize=18)
102+
plt.savefig('Zernikes.png',dpi=300)
103103

104104
def prepare_patches(d,Del,Im0,Imk):
105105
n = d.shape[0]
106-
upper = 1700-Del
106+
upper = 1700
107107
lower = 300
108108
Nx = np.arange(lower,upper,Del)
109109
Ny = np.arange(lower,upper,Del)
@@ -141,10 +141,58 @@ def stitch_patches(results,Del):
141141
st_wf[n2:n2+Del,n1:n1+Del] = data1[k]
142142
st_mtf[n2:n2+Del,n1:n1+Del] = data2[k]
143143
k=k+1
144-
return st_wf,st_mtf
144+
return st_mtf,st_wf
145145

146146

147147

148+
def plot_mtf_wf(ph,mtf):
149+
150+
fig=plt.figure(figsize=(20,8))
151+
aspect = 5
152+
pad_fraction = 0.5
153+
ax = fig.add_subplot(1,2,1)
154+
im=ax.imshow(ph/(2*np.pi), cmap=pylab.gray(),origin='lower',vmin=-1.2,vmax=1.2)
155+
156+
ax.set_xlabel('[Pixels]',fontsize=18)
157+
ax.set_ylabel('[Pixels]',fontsize=18)
158+
divider = make_axes_locatable(ax)
159+
cax = divider.append_axes("right", size=0.15, pad=0.05)
160+
cbar = plt.colorbar(im, cax=cax,orientation='vertical')
161+
cbar.set_label('WF error HRT [$\lambda$]',fontsize=20)
162+
cax.tick_params(labelsize=14)
163+
ax2 = fig.add_subplot(1,2,2)
164+
165+
im2=ax2.imshow(mtf,cmap=pylab.gray(),origin='lower',vmin=0,vmax=1)
166+
ax2.set_xlabel('[Pixels]',fontsize=18)
167+
divider = make_axes_locatable(ax2)
168+
cax2 = divider.append_axes("right", size=0.15, pad=0.05)
169+
cbar2 = plt.colorbar(im2, cax=cax2,orientation='vertical')
170+
cax2.tick_params(labelsize=14)
171+
cbar2.set_label('MTF',fontsize=16)
172+
173+
plt.subplots_adjust(wspace=.2, hspace=None)
174+
plt.savefig('WFE+MTF.png',dpi=300)
175+
148176

177+
def compute_residual_shifts(pd_pair,Del):
149178

179+
d = fits.getdata(pd_pair)
180+
xoff, yoff, exoff, eyoff = chi2_shift(d[0,500:1000,500:1000],d[1,500:1000,500:1000])
181+
Imk = ird.transform_img(d[1,:,:], tvec=np.array([-yoff,-xoff]))
182+
Nx = np.arange(200,1800,Del)
183+
Ny = np.arange(200,1800,Del)
184+
shifts_x = np.zeros((2048,2048))
185+
shifts_y = np.zeros((2048,2048))
186+
S_x = []
187+
S_y = []
150188

189+
for n1 in Nx :
190+
for n2 in Ny:
191+
192+
im0 = Im0[n2:n2+Del,n1:n1+Del]
193+
imk = Imk[n2:n2+Del,n1:n1+Del]
194+
xoff, yoff, exoff, eyoff = chi2_shift(im0,imk)
195+
print(xoff, yoff)
196+
shifts_x[n2:n2+Del,n1:n1+Del] = xoff
197+
shifts_y[n2:n2+Del,n1:n1+Del] = yoff
198+
return shifts_x, shifts_y

0 commit comments

Comments
 (0)