Skip to content

Commit 1877839

Browse files
committed
Add instructions for training.
1 parent 9f8c21f commit 1877839

File tree

3 files changed

+90
-3
lines changed

3 files changed

+90
-3
lines changed

configs/train/README.md

+87
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# Training configurations
2+
3+
The commands below train with the corresponding dataset, which will be downloaded from Hugging Face.
4+
5+
By commenting/uncommenting relevant sections in the configuration file, you can train the models with different architectures. By default, the model architecture uses five unrolleed iterations of ADMM for camera inversion, and UNetRes models for the pre-processor post-processor, and PSF correction.
6+
7+
With DiffuserCam, we show how to set different camera inversion methods.
8+
9+
## DiffuserCam
10+
11+
The commands below show how to train different camera inversion methods on the DiffuserCam dataset (downsampled by a factor of 2 along each dimension).For a fair comparison, all models use around 8.1M parameters.
12+
13+
```bash
14+
# unrolled ADMM
15+
python scripts/recon/train_learning_based.py -cn diffusercam
16+
17+
# Trainable inversion (FlatNet but with out adversarial loss)
18+
# -- need to set PSF as trainable
19+
python scripts/recon/train_learning_based.py -cn diffusercam \
20+
reconstruction.method=trainable_inv \
21+
reconstruction.psf_network=False \
22+
trainable_mask.mask_type=TrainablePSF \
23+
trainable_mask.L1_strength=False
24+
25+
# Unrolled ADMM with compensation branch
26+
# - adjust shapes of pre and post processors
27+
python scripts/recon/train_learning_based.py -cn diffusercam \
28+
reconstruction.psf_network=False \
29+
reconstruction.pre_process.nc=[16,32,64,128] \
30+
reconstruction.post_process.nc=[16,32,64,128] \
31+
reconstruction.compensation=[24,64,128,256,400]
32+
33+
# Multi wiener deconvolution network (MWDN)
34+
# with PSF correction built into the network
35+
python scripts/recon/train_learning_based.py -cn diffusercam \
36+
reconstruction.method=multi_wiener \
37+
reconstruction.multi_wiener.nc=[32,64,128,256,436] \
38+
reconstruction.pre_process.network=null \
39+
reconstruction.post_process.network=null \
40+
reconstruction.psf_network=False
41+
```
42+
43+
Similar to [PhoCoLens](https://phocolens.github.io/), we can train a camera inversion model that learns multiple PSFs. The training below uses the DiffuserCam dataset with its full resolution, and the number of model parameters is around 11.6M.
44+
```bash
45+
python scripts/recon/train_learning_based.py -cn diffusercam \
46+
reconstruction.method=svdeconvnet \
47+
reconstruction.pre_process.nc=[32,64,116,128] \
48+
reconstruction.psf_network=False \
49+
trainable_mask.mask_type=TrainablePSF \
50+
trainable_mask.L1_strength=False \
51+
files.downsample=1 files.downsample_lensed=1
52+
```
53+
54+
## TapeCam
55+
56+
```bash
57+
# unrolled ADMM
58+
python scripts/recon/train_learning_based.py -cn tapecam
59+
```
60+
61+
## DigiCam (Single Mask)
62+
63+
```bash
64+
# unrolled ADMM
65+
python scripts/recon/train_learning_based.py -cn digicam
66+
```
67+
68+
## DigiCam (Multiple Masks)
69+
70+
```bash
71+
# unrolled ADMM
72+
python scripts/recon/train_learning_based.py -cn digicam_multimask
73+
```
74+
75+
## DigiCam CelebA
76+
77+
```bash
78+
# unrolled ADMM
79+
python scripts/recon/train_learning_based.py -cn digicam_celeba
80+
```
81+
82+
## MultiPSF under External Illumination
83+
84+
```bash
85+
# unrolled ADMM
86+
python scripts/recon/train_learning_based.py -cn multilens_ambient
87+
```

lensless/recon/multi_wiener.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def apply(self, **kwargs):
279279
# apply to data
280280
return self.forward(self._data, **kwargs)
281281

282-
def reconstruction_error(self, prediction, lensless):
282+
def reconstruction_error(self, prediction, lensless, **kwargs):
283283
convolver = self._convolver
284284
if not convolver.pad:
285285
prediction = convolver._pad(prediction)

scripts/recon/train_learning_based.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -670,8 +670,8 @@ def train_learned(config):
670670
psf_channels = 3
671671

672672
assert config.reconstruction.direct_background_subtraction is False, "Not supported"
673-
assert config.reconstruction.learned_background_subtraction is None, "Not supported"
674-
assert config.reconstruction.integrated_background_subtraction is None, "Not supported"
673+
assert config.reconstruction.learned_background_subtraction is False, "Not supported"
674+
assert config.reconstruction.integrated_background_subtraction is False, "Not supported"
675675
assert psf_network is None, "Not supported"
676676

677677
recon = MultiWiener(

0 commit comments

Comments
 (0)