Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 292b474

Browse files
author
kahil_dell
committedApr 4, 2023
added parallel computation
1 parent 7ef02c8 commit 292b474

10 files changed

+283
-14
lines changed
 

‎README.md

+3-1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ The modules of this class:
4848

4949
## The `patch_pd` class:
5050
For fitting the wavefront error is sub-regions of the full FOV of the PD dataset. The user can enter the size of the subregion (it has to be quadratic) plus the number of Zernike polynomials to be fit and the names of the output files (one for the wavefront error and one for the 2D MTF).
51+
52+
The class offers the option to run parallel computation by setting the parser -p to True (see below).
5153
# How to use the code?
5254
The `minimization` class returns the best-fit zernike polynomials, a visualisation of the results (wavefront error+MTF), and the restored scene (in this case the focused image of the PD dataset). To get these specific results, type in the shell terminal:
5355
```
@@ -59,7 +61,7 @@ The specific description of the parsers can be found inside [the main code](http
5961
To use the `patch_pd` class:
6062

6163
```
62-
python3 patch_pd.py -i 'path/input.fits' -z 10 -d 265 -ow 'path/output_wf.fits' -om 'path/output_wf.fits'
64+
python3 patch_pd.py -i 'path/input.fits' -z 10 -d 265 -ow 'path/output_wf.fits' -om 'path/output_wf.fits -p True
6365
6466
```
6567
The parsers description can be found in the [main code](https://github.com/fakahil/PyPD/blob/master/patch_pd.py)

‎__pycache__/cost_func.cpython-38.pyc

0 Bytes
Binary file not shown.

‎__pycache__/logging.cpython-38.pyc

2.58 KB
Binary file not shown.

‎__pycache__/processing.cpython-38.pyc

5.12 KB
Binary file not shown.

‎__pycache__/tools.cpython-38.pyc

1.27 KB
Binary file not shown.

‎cost_func.py

+2
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,6 @@ def Minimize_res(self,coefficients):
5151

5252

5353
return t0,ph
54+
55+
5456

‎patch_pd.py

+109-13
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,12 @@
2222
from wavefront import *
2323
from deconvolution import *
2424
from cost_func import *
25+
import os
26+
import processing
2527

2628

2729
class patch_pd(object):
28-
def __init__(self,pd_data,Del,co_num,output_wf,output_mtf):
30+
def __init__(self,pd_data,Del,co_num,output_wf,output_mtf,parallel=True):
2931

3032

3133
self.data= fits.getdata(pd_data)
@@ -42,15 +44,56 @@ def __init__(self,pd_data,Del,co_num,output_wf,output_mtf):
4244
self.Imk = self.data[1,:,:]/self.mean_imk
4345

4446
self.output_WF = np.zeros((2048,2048))
45-
self.output_mtf = np.zeros((2048,2048))
46-
47-
47+
self.output_MTF = np.zeros((2048,2048))
48+
self.parallel = parallel
49+
#@staticmethod
50+
def run_pd(self,k):
51+
52+
im0 = self.patch[k,:,:,0]
53+
imk =self.patch[k,:,:,1]
54+
im0 = im0/im0.mean()
55+
imk = imk/imk.mean()
56+
imk = imreg(im0,imk)
57+
58+
59+
im0 = apo2d(im0,10)
60+
imk = apo2d(imk,10)
61+
62+
d0,dk = FT(im0,imk)
63+
gam =1# Gamma(d0,dk,M_gamma)
64+
p0 = np.zeros(self.co_num)
65+
fit = cost_func(self.Del,0.5,1e-10,10,self.co_num,0.5,617.3e-6, 140,4125.3,0.5)
66+
Mask = aperture.mask_pupil(fit.telescope.pupil_size(),fit.size)
67+
noise_temp = noise_mask_high(fit.size,fit.cut_off)
68+
noise_filter = fftshift(noise_temp)
69+
def Minimise(coefficients):
70+
A_f = wavefront.pupil_foc(coefficients,fit.size,fit.telescope.pupil_size(),self.co_num)
71+
A_def = wavefront.pupil_defocus(coefficients,fit.size,fit.del_z,fit.telescope.pupil_size(),self.co_num)
72+
psf_foc = wavefront.PSF(Mask,A_f,False)
73+
psf_defoc = wavefront.PSF(Mask,A_def,False)
74+
t0 = wavefront.OTF(psf_foc)
75+
tk = wavefront.OTF(psf_defoc)
76+
q,q2 = PD.Q_matrix(t0,tk,fit.reg,gam)
77+
F_m = PD.F_M(q2,d0, dk,t0,tk,noise_filter,gam)
78+
E_metric = PD.Error_metric(t0,tk,d0,dk,q,noise_filter)
79+
L_m = PD.L_M(E_metric,fit.size)
80+
return L_m
81+
82+
83+
84+
Minimize_partial = partial(Minimise)
85+
mini = minimize(Minimize_partial,p0,method='L-BFGS-B')
86+
result = fit.Minimize_res(mini.x)
87+
patch_wfe = result[1]
88+
patch_mtf = MTF(fftshift(result[0]))
89+
return patch_wfe,patch_mtf
4890

4991
def fit_patch(self):
5092

51-
upper = 1700
52-
Nx = np.arange(300,upper,self.Del)
53-
Ny = np.arange(300,upper,self.Del)
93+
upper = 1700
94+
Nx = np.arange(300,upper,self.Del)
95+
Ny = np.arange(300,upper,self.Del)
96+
if not self.parallel:
5497
for n1 in Nx :
5598
for n2 in Ny:
5699

@@ -60,9 +103,7 @@ def fit_patch(self):
60103
im0 = im0/im0.mean()
61104
imk = imk/imk.mean()
62105
imk = imreg(im0,imk)
63-
64-
65-
106+
66107
im0 = apo2d(im0,10)
67108
imk = apo2d(imk,10)
68109

@@ -91,16 +132,68 @@ def Minimise(coefficients):
91132
L_m = PD.L_M(E_metric,fit.size)
92133
return L_m
93134

94-
135+
95136
Minimise_partial = partial(Minimise)
96137
mini = scipy.optimize.minimize(Minimise_partial,p0,method= 'L-BFGS-B')
97138

98139
result = fit.Minimize_res(mini.x)
99140
self.output_WF[n2:n2+self.Del,n1:n1+self.Del] = result[1]
100-
self.output_mtf[n2:n2+self.Del,n1:n1+self.Del] = MTF(fftshift(result[0]))
141+
self.output_MTF[n2:n2+self.Del,n1:n1+self.Del] = MTF(fftshift(result[0]))
101142
hdu = fits.PrimaryHDU(self.output_WF)
102143
hdu.writeto(self.output_wf,overwrite=True)
144+
hdu = fits.PrimaryHDU(self.output_mtf)
145+
hdu.writeto(self.output_mtf,overwrite=True)
146+
else:
147+
self.patch = tools.prepare_patches(self.data,self.Del,self.Im0,self.Imk)
148+
n_workers = min(6, os.cpu_count())
149+
self.args_list = [i for i in range(len(self.patch))]
150+
self.results_parallel = list(processing.MP.simultaneous(self.run_pd, self.args_list, workers=n_workers))
151+
152+
def plot_results(self,output):
153+
154+
# change here the format of the output
155+
if not self.parallel:
156+
data_mtf = self.output_MTF
157+
data_wfe = self.output_WF
103158

159+
if self.parallel:
160+
## call here the stitching function
161+
data_mtf,data_wfe = tools.stitch_patches(self.results_parallel,self.Del)
162+
163+
164+
fig, ax = plt.subplots(1,2,figsize=(10,10))
165+
166+
im=ax[1].imshow(data_mtf,vmin=0,vmax=1,origin='lower',cmap='gray')
167+
ax[1].set_title('MTF')
168+
divider = make_axes_locatable(ax[1])
169+
cax = divider.append_axes('right',pad=0.05,size=0.03)
170+
cbar1 = plt.colorbar(im,cax=cax)
171+
cbar1.ax.tick_params(labelsize=16)
172+
cbar1.set_label('MTF',fontsize=16)
173+
ax[1].set_xlabel('[Pixels]')
174+
major_ticks = np.arange(0, 2048,350)
175+
major_ticks_y = np.arange(0, 2048,350)
176+
ax[1].set_xticks(major_ticks)
177+
ax[1].set_yticks(major_ticks_y)
178+
ax[1].tick_params(labelsize=4)
179+
180+
181+
182+
im2=ax[0].imshow(data_wfe/(2*np.pi),vmin=-0.5,vmax=1.4,origin='lower',cmap='gray')
183+
ax[0].set_xlabel('[Pixels]')
184+
ax[0].set_title('WF error[$\lambda$]')
185+
divider2 = make_axes_locatable(ax[0])
186+
cax2 = divider2.append_axes('right',pad=0.05,size=0.03) #size is the width of the color bar and pad is the fraction of the new axis
187+
cbar2=plt.colorbar(im2,cax=cax2)
188+
cbar2.ax.tick_params(labelsize=16)
189+
ax[0].tick_params(labelbottom=False,labelsize=16)
190+
plt.subplots_adjust(wspace=None, hspace=0.1)
191+
ax[0].set_xticks(major_ticks)
192+
ax[0].set_yticks(major_ticks_y)
193+
ax[0].set_ylabel('[Pixels]')
194+
#plt.subplots_adjust(wspace=None, hspace=-0.1)
195+
#plt.axis('off')
196+
plt.savefig(output,dpi=300)
104197

105198
if (__name__ == '__main__'):
106199
parser = argparse.ArgumentParser(description='PD on sub-fields')
@@ -109,7 +202,10 @@ def Minimise(coefficients):
109202
parser.add_argument('-d','--Del', help='Del',default=265)
110203
parser.add_argument('-ow','--ow', help='output_WFE')
111204
parser.add_argument('-om','--om', help='output MTF')
205+
parser.add_argument('-r','--res', help='results')
206+
parser.add_argument('-p','--parallel',choices=['True','False'],default=True)
112207
parsed = vars(parser.parse_args())
113-
st = patch_pd(pd_data='{0}'.format(parsed['input']),Del=int(parsed['Del']),co_num=int(parsed['Z']),output_wf='{0}'.format(parsed['ow']),output_mtf='{0}'.format(parsed['om']))
208+
st = patch_pd(pd_data='{0}'.format(parsed['input']),Del=int(parsed['Del']),co_num=int(parsed['Z']),output_wf='{0}'.format(parsed['ow']),output_mtf='{0}'.format(parsed['om']),parallel=bool(parsed['parallel']))
114209
st.fit_patch()
210+
st.plot_results(output='{0}'.format(parsed['res']))
115211

‎processing.py

+124
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
import os
2+
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, Executor
3+
class Thread:
4+
"""
5+
# Thread
6+
The threads holds the information on the function to execute in a thread or process.
7+
Provides an interface to the `future` object once submitted to an executer.
8+
"""
9+
10+
def __init__(self, func, args):
11+
self.function = func
12+
self.arguments = args
13+
self.future = None
14+
15+
def submit(self, executor: Executor):
16+
"""Start execution via executor"""
17+
if not self.is_submitted():
18+
self.future = executor.submit(self.function, self.arguments)
19+
return self
20+
21+
def is_submitted(self) -> bool:
22+
return self.future is not None
23+
24+
def is_done(self):
25+
return self.is_submitted() and self.future.done()
26+
27+
def exception(self):
28+
if not self.is_done():
29+
return None
30+
return self.future.exception()
31+
32+
def result(self):
33+
if not self.is_submitted():
34+
return None
35+
return self.future.result()
36+
37+
class MP:
38+
"""
39+
## MP Multi-Processing
40+
Class provides housekeeping / setup methods to reduce the programming overhead of
41+
spawning threads or processes.
42+
"""
43+
44+
#: Number of CPUs of the current machine
45+
NUM_CPUs = round(os.cpu_count() * 0.8)
46+
47+
@staticmethod
48+
def threaded(func, args, workers=10, raise_exception=True):
49+
"""
50+
Calls the given function in multiple threads for the set of given arguments
51+
Note that this does not spawn processes, but threads. Use this for non CPU
52+
CPU dependent tasks, i.e. I/O
53+
Method returns once all calls are done.
54+
55+
### Params
56+
- func: [Function] the function to call
57+
- args: [Iterable] the 'list' of arguments for each call
58+
- workers: [Integer] the number of concurrent threads to use
59+
- raise_exception: [Bool] Flag if an exception in a thread shall be raised or just logged
60+
61+
### Returns
62+
Results from all `Threads` as list
63+
"""
64+
if len(args) == 1:
65+
return list(func(arg) for arg in args)
66+
67+
with ThreadPoolExecutor(workers) as ex:
68+
threads = [Thread(func, arg).submit(ex) for arg in args]
69+
return MP.collect_results(threads, raise_exception)
70+
71+
@staticmethod
72+
def simultaneous(func, args, workers=None, raise_exception=True):
73+
"""
74+
Calls the given function in multiple processes for the set of given arguments
75+
Note that this does spawn processes, not threads. Use this for task that
76+
depend heavily on CPU and can be done in parallel.
77+
Method returns once all calls are done.
78+
79+
### Params
80+
- func: [Function] the function to call
81+
- args: [Iterable] the 'list' of arguments for each call
82+
- workers: [Integer] the number of concurrent threads to use (Default: NUM_CPUs)
83+
- raise_exception: [Bool] Flag if an exception in a thread shall be raised or just logged
84+
85+
### Returns
86+
Results from all `Threads` as list
87+
"""
88+
if len(args) == 1:
89+
return list(func(arg) for arg in args)
90+
91+
if workers is None:
92+
workers = MP.NUM_CPUs
93+
with ProcessPoolExecutor(workers) as ex:
94+
threads = [Thread(func, arg).submit(ex) for arg in args]
95+
return MP.collect_results(threads, raise_exception)
96+
97+
@staticmethod
98+
def collect_results(threads: list, raise_exception: bool = True) -> list:
99+
"""
100+
Takes a list of threads and waits for them to be executed. Collects results.
101+
102+
### Params
103+
- threads: [List<Thread>] a list of submitted threads
104+
- raise_exception: [Bool] Flag if an exception in a thread shall be raised or just logged
105+
106+
### Returns
107+
Results from all `Threads` as list
108+
"""
109+
result = []
110+
while len(threads) > 0:
111+
for thread in threads:
112+
if not thread.is_submitted():
113+
threads.remove(thread)
114+
if not thread.is_done():
115+
continue
116+
117+
if thread.exception() is not None:
118+
MP.__exception_handling(threads, thread, raise_exception)
119+
else:
120+
result.append(thread.result())
121+
threads.remove(thread)
122+
return result
123+
124+

‎resutls.png

-78 KB
Binary file not shown.

‎tools.py

+45
Original file line numberDiff line numberDiff line change
@@ -101,5 +101,50 @@ def plot_zernike(coeff):
101101
plt.ylabel('Coefficient [$\lambda$]',fontsize=18)
102102
plt.title('Zernike Polynomials Coefficients',fontsize=18)
103103

104+
def prepare_patches(d,Del,Im0,Imk):
105+
n = d.shape[0]
106+
upper = 1700-Del
107+
lower = 300
108+
Nx = np.arange(lower,upper,Del)
109+
Ny = np.arange(lower,upper,Del)
110+
i_max = np.floor((upper-lower)/Del)+1
111+
patches = np.zeros((int(i_max**2),Del,Del,n))
112+
#output_WF =np.zeros((int(i_max**2),Del,Del))
113+
#output_mtf = np.zeros((int(i_max**2),Del,Del))
114+
k=0
115+
for n1 in Nx :
116+
for n2 in Ny:
117+
patches[k,:,:,0]=Im0[n2:n2+Del,n1:n1+Del]
118+
patches[k,:,:,1]=Imk[n2:n2+Del,n1:n1+Del]
119+
k = k+1
120+
return patches
121+
122+
123+
def stitch_patches(results,Del):
124+
data1 = [r[0] for r in results]
125+
data2 = [r[1] for r in results]
126+
upper = 1700-Del
127+
lower = 300
128+
Nx = np.arange(lower,upper,Del)
129+
Ny = np.arange(lower,upper,Del)
130+
i_max = np.floor((upper-lower)/Del)+1
131+
k = 0
132+
if len(data1)==(np.floor((upper-lower)/Del)+1)**2:
133+
134+
st_wf = np.zeros((2048,2048))
135+
st_mtf = np.zeros((2048,2048))
136+
else:
137+
raise TypeError('Check dimensions!')
138+
139+
for n1 in Nx :
140+
for n2 in Ny:
141+
st_wf[n2:n2+Del,n1:n1+Del] = data1[k]
142+
st_mtf[n2:n2+Del,n1:n1+Del] = data2[k]
143+
k=k+1
144+
return st_wf,st_mtf
145+
146+
147+
148+
104149

105150

0 commit comments

Comments
 (0)
Please sign in to comment.