Commit 650a149

ebezzam
authored
Add learnable multipsf and reorganize config files (#159)
* Small updates.
* Change defaults.
* Add options to set PSF from file.
* Change defaults.
* Start off multi psf.
* Add support for svdeconvnet in training.
* reformat.
* reformat.
* reformat.
* Update notebook.
* Update CHANGELOG.
* More consistent learned PSFs.
* Remove single psf restriction as it is worse.
* Remove PSF normalization.
* Add option for PSF correction.
* Organize configs into subfolder, easier API for set GPUs for training.
* Add instructions for training.
* Update readme.
* Enhance training README with dataset links and clarify architecture options
* Update README.
* Update README.

---------

Co-authored-by: root <[email protected]>
Co-authored-by: root <[email protected]>
1 parent a8b7702 commit 650a149

File tree: 106 files changed, +962 −832 lines


CHANGELOG.rst (+1)

@@ -37,6 +37,7 @@ Added
 - Add parameterize and peturb to evaluate model adaptation.
 - PSF correction network.
 - Option to add noise to input image or PSF (for robustness experiments).
+- Learnable shift-variant forward model similar to PhoCoLens: https://phocolens.github.io/


 Changed
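The "learnable shift-variant forward model" entry above can be pictured as a weighted sum of per-region convolutions: each region of the scene is blurred by its own PSF, and smooth masks blend the results. The sketch below is a minimal illustration of that idea, assuming a stack of K centered PSFs and blending masks; the function name and shapes are hypothetical, not the toolkit's actual implementation.

```python
import numpy as np

def shift_variant_forward(scene, psfs, weights):
    """Shift-variant imaging as a weighted sum of shift-invariant blurs.

    scene:   (H, W) scene intensity
    psfs:    (K, H, W) stack of centered PSFs, one per region
    weights: (K, H, W) blending masks that sum to 1 at every pixel
    (hypothetical sketch; names and shapes are illustrative)
    """
    out = np.zeros_like(scene, dtype=float)
    for psf, w in zip(psfs, weights):
        # weight the scene by this region's mask, then blur with
        # circular convolution via the FFT
        blurred = np.fft.ifft2(np.fft.fft2(w * scene) * np.fft.fft2(np.fft.ifftshift(psf)))
        out += np.real(blurred)
    return out
```

With K = 1 and a uniform mask this collapses to the usual shift-invariant model; making the PSF stack (and optionally the masks) trainable parameters gives a PhoCoLens-style learnable forward model as referenced in the entry.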

README.rst (+1 −1)

@@ -47,11 +47,11 @@ As **modularity** is a key feature of this toolkit, we try to support different

 The toolkit includes:

+* Training scripts/configuration for various learnable, physics-informed reconstruction approaches, as shown `here <https://github.com/LCAV/LenslessPiCam/blob/main/configs/train#training-configurations>`__.
 * Camera assembly tutorials (`link <https://lensless.readthedocs.io/en/latest/building.html>`__).
 * Measurement scripts (`link <https://lensless.readthedocs.io/en/latest/measurement.html>`__).
 * Dataset preparation and loading tools, with `Hugging Face <https://huggingface.co/bezzam>`__ integration (`slides <https://docs.google.com/presentation/d/18h7jTcp20jeoiF8dJIEcc7wHgjpgFgVxZ_bJ04W55lg/edit?usp=sharing>`__ on uploading a dataset to Hugging Face with `this script <https://github.com/LCAV/LenslessPiCam/blob/main/scripts/data/upload_dataset_huggingface.py>`__).
 * `Reconstruction algorithms <https://lensless.readthedocs.io/en/latest/reconstruction.html>`__ (e.g. FISTA, ADMM, unrolled algorithms, trainable inversion, multi-Wiener deconvolution network, pre- and post-processors).
-* `Training script <https://github.com/LCAV/LenslessPiCam/blob/main/scripts/recon/train_learning_based.py>`__ for learning-based reconstruction.
 * `Pre-trained models <https://github.com/LCAV/LenslessPiCam/blob/main/lensless/recon/model_dict.py>`__ that can be loaded from `Hugging Face <https://huggingface.co/bezzam>`__, for example in `this script <https://github.com/LCAV/LenslessPiCam/blob/main/scripts/recon/diffusercam_mirflickr.py>`__.
 * Mask `design <https://lensless.readthedocs.io/en/latest/mask.html>`__ and `fabrication <https://lensless.readthedocs.io/en/latest/fabrication.html>`__ tools.
 * `Simulation tools <https://lensless.readthedocs.io/en/latest/simulation.html>`__.
File renamed without changes.

configs/benchmark_diffusercam_mirflickr.yaml → configs/benchmark/diffusercam.yaml (+10 −10)

@@ -1,6 +1,6 @@
-# python scripts/eval/benchmark_recon.py -cn benchmark_diffusercam_mirflickr
+# python scripts/eval/benchmark_recon.py -cn diffusercam
 defaults:
-  - benchmark
+  - defaults
   - _self_

 dataset: HFDataset
@@ -24,15 +24,15 @@ algorithms: [

     ## -- reconstructions trained on DiffuserCam measured
     "hf:diffusercam:mirflickr:U5+Unet8M",
-    "hf:diffusercam:mirflickr:Unet8M+U5",
-    "hf:diffusercam:mirflickr:TrainInv+Unet8M",
-    "hf:diffusercam:mirflickr:MMCN4M+Unet4M",
-    "hf:diffusercam:mirflickr:MWDN8M",
+    # "hf:diffusercam:mirflickr:Unet8M+U5",
+    # "hf:diffusercam:mirflickr:TrainInv+Unet8M",
+    # "hf:diffusercam:mirflickr:MMCN4M+Unet4M",
+    # "hf:diffusercam:mirflickr:MWDN8M",
     "hf:diffusercam:mirflickr:Unet4M+U5+Unet4M",
-    "hf:diffusercam:mirflickr:Unet4M+TrainInv+Unet4M",
-    "hf:diffusercam:mirflickr:Unet2M+MMCN+Unet2M",
-    "hf:diffusercam:mirflickr:Unet2M+MWDN6M",
-    "hf:diffusercam:mirflickr:Unet4M+U10+Unet4M",
+    # "hf:diffusercam:mirflickr:Unet4M+TrainInv+Unet4M",
+    # "hf:diffusercam:mirflickr:Unet2M+MMCN+Unet2M",
+    # "hf:diffusercam:mirflickr:Unet2M+MWDN6M",
+    # "hf:diffusercam:mirflickr:Unet4M+U10+Unet4M",
     "hf:diffusercam:mirflickr:Unet4M+U5+Unet4M_psfNN",

     # # -- benchmark PSF error
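The `- benchmark` → `- defaults` change above reflects the reorganization: benchmark configs now live under `configs/benchmark/` and compose, via Hydra's defaults list, from a shared base config in the same folder. A minimal sketch of what a new config in that folder would look like (the filename and field values here are illustrative, not from the commit):

```yaml
# hypothetical configs/benchmark/my_system.yaml
# run with: python scripts/eval/benchmark_recon.py -cn my_system
defaults:
  - defaults   # shared base settings (assumed configs/benchmark/defaults.yaml)
  - _self_     # apply this file's overrides last

dataset: HFDataset
batchsize: 4
device: "cuda:0"
algorithms: ["ADMM"]
```

Because `_self_` comes after `defaults`, any field set here overrides the shared base, which is why the renamed configs only need to list what differs.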

configs/benchmark_digicam_mirflickr_single.yaml → configs/benchmark/digicam.yaml (+13 −13)

@@ -1,6 +1,6 @@
-# python scripts/eval/benchmark_recon.py -cn benchmark_digicam_mirflickr_single
+# python scripts/eval/benchmark_recon.py -cn digicam
 defaults:
-  - benchmark
+  - defaults
   - _self_

 dataset: HFDataset
@@ -26,19 +26,19 @@ algorithms: [

     # # -- reconstructions trained on measured data
     "hf:digicam:mirflickr_single_25k:U5+Unet8M_wave",
-    "hf:digicam:mirflickr_single_25k:Unet8M+U5_wave",
-    "hf:digicam:mirflickr_single_25k:TrainInv+Unet8M_wave",
-    "hf:digicam:mirflickr_single_25k:MMCN4M+Unet4M_wave",
-    "hf:digicam:mirflickr_single_25k:MWDN8M_wave",
-    "hf:digicam:mirflickr_single_25k:Unet4M+TrainInv+Unet4M_wave",
+    # "hf:digicam:mirflickr_single_25k:Unet8M+U5_wave",
+    # "hf:digicam:mirflickr_single_25k:TrainInv+Unet8M_wave",
+    # "hf:digicam:mirflickr_single_25k:MMCN4M+Unet4M_wave",
+    # "hf:digicam:mirflickr_single_25k:MWDN8M_wave",
+    # "hf:digicam:mirflickr_single_25k:Unet4M+TrainInv+Unet4M_wave",
     "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave",
-    "hf:digicam:mirflickr_single_25k:Unet2M+MMCN+Unet2M_wave",
-    "hf:digicam:mirflickr_single_25k:Unet2M+MWDN6M_wave",
-    "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_wave",
+    # "hf:digicam:mirflickr_single_25k:Unet2M+MMCN+Unet2M_wave",
+    # "hf:digicam:mirflickr_single_25k:Unet2M+MWDN6M_wave",
+    # "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_wave",
     "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave_psfNN",
-    "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave_flips",
-    "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave_flips_rotate10",
-    "hf:digicam:mirflickr_single_25k:Unet8M_wave_v2",
+    # "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave_flips",
+    # "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave_flips_rotate10",
+    # "hf:digicam:mirflickr_single_25k:Unet8M_wave_v2",

     # ## -- reconstructions trained on other datasets/systems
     # "hf:diffusercam:mirflickr:Unet4M+U10+Unet4M",

configs/benchmark_digicam_celeba.yaml → configs/benchmark/digicam_celeba.yaml (+10 −10)

@@ -1,28 +1,28 @@
 # python scripts/eval/benchmark_recon.py -cn benchmark_digicam_celeba
 defaults:
-  - benchmark
+  - defaults
   - _self_


 dataset: HFDataset
 batchsize: 10
-device: "cuda:1"
+device: "cuda"

 algorithms: [
     # "ADMM",

     ## -- reconstructions trained on measured data
     "hf:digicam:celeba_26k:U5+Unet8M_wave",
-    "hf:digicam:celeba_26k:Unet8M+U5_wave",
-    "hf:digicam:celeba_26k:TrainInv+Unet8M_wave",
-    "hf:digicam:celeba_26k:MWDN8M_wave",
-    "hf:digicam:celeba_26k:MMCN4M+Unet4M_wave",
-    "hf:digicam:celeba_26k:Unet2M+MWDN6M_wave",
-    "hf:digicam:celeba_26k:Unet4M+TrainInv+Unet4M_wave",
-    "hf:digicam:celeba_26k:Unet2M+MMCN+Unet2M_wave",
+    # "hf:digicam:celeba_26k:Unet8M+U5_wave",
+    # "hf:digicam:celeba_26k:TrainInv+Unet8M_wave",
+    # "hf:digicam:celeba_26k:MWDN8M_wave",
+    # "hf:digicam:celeba_26k:MMCN4M+Unet4M_wave",
+    # "hf:digicam:celeba_26k:Unet2M+MWDN6M_wave",
+    # "hf:digicam:celeba_26k:Unet4M+TrainInv+Unet4M_wave",
+    # "hf:digicam:celeba_26k:Unet2M+MMCN+Unet2M_wave",
     "hf:digicam:celeba_26k:Unet4M+U5+Unet4M_wave",
     "hf:digicam:celeba_26k:Unet4M+U5+Unet4M_wave_psfNN",
-    "hf:digicam:celeba_26k:Unet4M+U10+Unet4M_wave",
+    # "hf:digicam:celeba_26k:Unet4M+U10+Unet4M_wave",

     # # -- reconstructions trained on other datasets/systems
     # "hf:diffusercam:mirflickr:Unet4M+U10+Unet4M",

configs/benchmark_digicam_mirflickr_multi.yaml → configs/benchmark/digicam_multimask.yaml (+7 −8)

@@ -1,9 +1,8 @@
-# python scripts/eval/benchmark_recon.py -cn benchmark_digicam_mirflickr_multi
+# python scripts/eval/benchmark_recon.py -cn digicam_multimask
 defaults:
-  - benchmark
+  - defaults
   - _self_

-
 dataset: HFDataset
 batchsize: 4
 device: "cuda:0"
@@ -21,15 +20,15 @@ huggingface:
   downsample: 1

 algorithms: [
-    "ADMM",
+    # "ADMM",

     ## -- reconstructions trained on measured data
     "hf:digicam:mirflickr_multi_25k:Unet4M+U5+Unet4M_wave",
     "hf:digicam:mirflickr_multi_25k:Unet4M+U5+Unet4M_wave_psfNN",
-    "hf:digicam:mirflickr_multi_25k:Unet4M+U10+Unet4M_wave",
-    "hf:digicam:mirflickr_multi_25k:Unet4M+U5+Unet4M_wave_aux1",
-    "hf:digicam:mirflickr_multi_25k:Unet4M+U5+Unet4M_wave_flips",
-    "hf:digicam:mirflickr_multi_25k:Unet8M_wave_v2",
+    # "hf:digicam:mirflickr_multi_25k:Unet4M+U10+Unet4M_wave",
+    # "hf:digicam:mirflickr_multi_25k:Unet4M+U5+Unet4M_wave_aux1",
+    # "hf:digicam:mirflickr_multi_25k:Unet4M+U5+Unet4M_wave_flips",
+    # "hf:digicam:mirflickr_multi_25k:Unet8M_wave_v2",

     # ## -- reconstructions trained on other datasets/systems
     # "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave_psfNN",

configs/benchmark_digicam_mirflickr_pnp.yaml → configs/benchmark/digicam_parameter_and_perturb.yaml (+2 −2)

@@ -1,6 +1,6 @@
-# python scripts/eval/benchmark_recon.py -cn benchmark_digicam_mirflickr_pnp
+# python scripts/eval/benchmark_recon.py -cn digicam_parameter_and_perturb
 defaults:
-  - benchmark
+  - defaults
   - _self_


configs/benchmark/multilens_ambient.yaml (new file, +45)

@@ -0,0 +1,45 @@
+# python scripts/eval/benchmark_recon.py -cn multilens_ambient
+defaults:
+  - defaults
+  - _self_
+
+dataset: HFDataset
+batchsize: 8
+device: "cuda:0"
+
+huggingface:
+  repo: Lensless/MultiLens-Mirflickr-Ambient
+  cache_dir: /dev/shm
+  psf: psf.png
+  image_res: [600, 600] # used during measurement
+  rotate: False # if measurement is upside-down
+  alignment:
+    top_left: [118, 220] # height, width
+    height: 123
+  use_background: True
+
+## -- reconstructions trained with same dataset/system
+algorithms: [
+    # "ADMM",
+    "hf:multilens:mirflickr_ambient:U5+Unet8M",
+    # "hf:multilens:mirflickr_ambient:U5+Unet8M_direct_sub",
+    # "hf:multilens:mirflickr_ambient:U5+Unet8M_learned_sub",
+    "hf:multilens:mirflickr_ambient:Unet4M+U5+Unet4M",
+    # "hf:multilens:mirflickr_ambient:Unet4M+U5+Unet4M_direct_sub",
+    # "hf:multilens:mirflickr_ambient:Unet4M+U5+Unet4M_learned_sub",
+    "hf:multilens:mirflickr_ambient:Unet4M+U5+Unet4M_concat",
+    # "hf:multilens:mirflickr_ambient:TrainInv+Unet8M",
+    # "hf:multilens:mirflickr_ambient:TrainInv+Unet8M_learned_sub",
+    # "hf:multilens:mirflickr_ambient:Unet4M+TrainInv+Unet4M",
+    # "hf:multilens:mirflickr_ambient:Unet4M+TrainInv+Unet4M_learned_sub",
+    # "hf:multilens:mirflickr_ambient:Unet4M+TrainInv+Unet4M_concat",
+    # "hf:multilens:mirflickr_ambient:TrainInv+Unet8M_direct_sub",
+    # "hf:multilens:mirflickr_ambient:Unet4M+TrainInv+Unet4M_direct_sub",
+]
+
+save_idx: [
+    1, 2, 4, 5, 9, 64, # bottom right
+    # 2141, 2155, 2162, 2225, 2502, 2602, # top right (door, flower, cookies, wolf, plush, sky)
+    # 3262, 3304, 3438, 3451, 3644, 3667 # bottom left (pancakes, flower, grapes, pencils, bird, sign)
+]
+n_iter_range: [100] # for ADMM
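The `use_background: True` field and the `_direct_sub` / `_learned_sub` model variants above deal with ambient light: a background measurement is subtracted from the lensless capture, either directly or via a learned module. A minimal sketch of the direct variant, assuming NumPy arrays (the function name is hypothetical; the toolkit's actual helper may differ in name and normalization):

```python
import numpy as np

def direct_background_subtraction(lensless, background):
    """Direct ambient-light removal: subtract a background capture from
    the lensless measurement and clip negatives to zero.
    (Hypothetical sketch of the '_direct_sub' idea, not the toolkit's API.)"""
    return np.clip(lensless - background, 0.0, None)
```

The `_learned_sub` and `_concat` variants instead let the network handle the background, either through a learned subtraction module or by concatenating the background as an extra input channel.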

configs/benchmark_tapecam_mirflickr.yaml → configs/benchmark/tapecam.yaml (+13 −13)

@@ -1,11 +1,11 @@
-# python scripts/eval/benchmark_recon.py -cn benchmark_tapecam_mirflickr
+# python scripts/eval/benchmark_recon.py -cn tapecam
 defaults:
-  - benchmark
+  - defaults
   - _self_

 dataset: HFDataset
 batchsize: 4
-device: "cuda:1"
+device: "cuda:0"

 huggingface:
   repo: "bezzam/TapeCam-Mirflickr-25K"
@@ -27,17 +27,17 @@ algorithms: [

     # -- reconstructions trained on measured data
     "hf:tapecam:mirflickr:U5+Unet8M",
-    "hf:tapecam:mirflickr:Unet8M+U5",
-    "hf:tapecam:mirflickr:TrainInv+Unet8M",
-    "hf:tapecam:mirflickr:MMCN4M+Unet4M",
+    # "hf:tapecam:mirflickr:Unet8M+U5",
+    # "hf:tapecam:mirflickr:TrainInv+Unet8M",
+    # "hf:tapecam:mirflickr:MMCN4M+Unet4M",
     "hf:tapecam:mirflickr:Unet4M+U5+Unet4M",
-    "hf:tapecam:mirflickr:Unet4M+TrainInv+Unet4M",
-    "hf:tapecam:mirflickr:Unet2M+MMCN+Unet2M",
-    "hf:tapecam:mirflickr:Unet4M+U10+Unet4M",
-    "hf:tapecam:mirflickr:Unet4M+U5+Unet4M_flips_rotate10",
-    "hf:tapecam:mirflickr:Unet4M+U5+Unet4M_aux1",
-    "hf:tapecam:mirflickr:Unet4M+U5+Unet4M_flips",
-    "hf:tapecam:mirflickr:Unet4M+U5+Unet4M_flips_rotate10",
+    # "hf:tapecam:mirflickr:Unet4M+TrainInv+Unet4M",
+    # "hf:tapecam:mirflickr:Unet2M+MMCN+Unet2M",
+    # "hf:tapecam:mirflickr:Unet4M+U10+Unet4M",
+    # "hf:tapecam:mirflickr:Unet4M+U5+Unet4M_flips_rotate10",
+    # "hf:tapecam:mirflickr:Unet4M+U5+Unet4M_aux1",
+    # "hf:tapecam:mirflickr:Unet4M+U5+Unet4M_flips",
+    # "hf:tapecam:mirflickr:Unet4M+U5+Unet4M_flips_rotate10",
     "hf:tapecam:mirflickr:Unet4M+U5+Unet4M_psfNN",

     # # below models need `single_channel_psf = True`

configs/benchmark_multilens_mirflickr_ambient.yaml (−45)

This file was deleted.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

configs/demo.yaml (+12 −10)

@@ -16,10 +16,10 @@ rpi:
 display:
   # default to this screen: https://www.dell.com/en-us/work/shop/dell-ultrasharp-usb-c-hub-monitor-u2421e/apd/210-axmg/monitors-monitor-accessories#techspecs_section
   screen_res: [1920, 1200] # width, height
-  image_res: null
+  image_res: [900, 1200]
   pad: 0
   hshift: 0
-  vshift: -10
+  vshift: -30
   brightness: 100
   rot90: 3

@@ -57,7 +57,7 @@ capture:
   # remote script returns RGB data
   rgb: True
   down: 4
-  awb_gains: [1.6, 1.2]
+  awb_gains: [1.7, 1.3]


 camera:
@@ -77,7 +77,7 @@ camera:


 recon:
-  gamma: 2.2 # for visualization
+  gamma: null
   downsample: 4
   dtype: float32
   use_torch: True
@@ -93,7 +93,7 @@ recon:

   # -- admm
   admm:
-    n_iter: 100
+    n_iter: 10
    disp_iter: null
    mu1: 1e-6
    mu2: 1e-5
@@ -115,8 +115,10 @@ recon:
    # model_path: models/wallerlab_unet_inversion.pb
    input_shape: [1, 270, 480, 3]

-  postproc:
-    # crop in percent to extract region of interest
-    # set to null to skip
-    crop_hor: [0.22, 0.71]
-    crop_vert: [0., 0.86]
+  postproc:
+    # crop_hor: null
+    # crop_vert: null
+    # # crop in percent to extract region of interest
+    # # set to null to skip
+    crop_hor: [0.28, 0.75]
+    crop_vert: [0.2, 0.82]

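The `postproc` crops in `demo.yaml` are expressed as fractions of the image size ("crop in percent to extract region of interest"), so the same bounds stay valid across different `downsample` settings. A sketch of how such fractional bounds map to pixel indices (hypothetical helper name, assuming a NumPy image; the toolkit's actual post-processing code may differ):

```python
import numpy as np

def crop_percent(img, crop_hor, crop_vert):
    """Extract a region of interest given fractional bounds.

    crop_hor:  [start, end] as fractions of image width
    crop_vert: [start, end] as fractions of image height
    (Hypothetical helper illustrating demo.yaml's postproc crop fields.)"""
    H, W = img.shape[:2]
    r0, r1 = int(crop_vert[0] * H), int(crop_vert[1] * H)
    c0, c1 = int(crop_hor[0] * W), int(crop_hor[1] * W)
    return img[r0:r1, c0:c1]
```

With the config's values, `crop_hor: [0.28, 0.75]` and `crop_vert: [0.2, 0.82]` on a 100×100 image keep rows 20-82 and columns 28-75.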