
Fista algorithm : accuracy problem #114
Open
aTrotier opened this issue Oct 5, 2022 · 30 comments

@aTrotier (Contributor) commented Oct 5, 2022

Compressed sensing is now really fast with FISTA: #102 (comment)

But there is still an accuracy issue (#102 (reply in thread)). We observed it on a phantom: https://atrotier.github.io/MRIRecoVsBART_Benchmark/test_bart.html

And I also have an issue on real MP2RAGE data: https://atrotier.github.io/EDUC_JULIA_CS_MP2RAGE/
I am not able to reach the same image quality as BART.

If I use ADMM, it works well!

@andrewwmao commented Oct 5, 2022

Hi aTrotier,
I looked at the notebooks you linked, and it seems that the second notebook, using the in vivo MP2RAGE data, relies on an old version of RegularizedLeastSquares. I contributed a bugfix to the FISTA algorithm (found by comparing against BART) that should have fixed the accuracy problem. Can you try updating to v0.8.7?

For the first notebook, my feeling is that there may be an issue with the regularization strength or the implementation of the regularizer. Lambdas won't be comparable between Julia and BART, especially when the two packages scale lambda in different ways. Perhaps you can try running both the BART and MRIReco FISTA implementations with lambda close to zero to see whether they give the same results.
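
A concrete version of that sanity check (a sketch only: `acq`, `sens`, and `k_bart` are assumed to be set up as in the full script later in this thread, and the parameter names mirror that script):

```julia
# With λ ≈ 0 both FISTA implementations solve (almost) the same unregularized
# least-squares problem, so the reconstructions should nearly agree; any
# remaining difference is then unrelated to the lambda scaling.
params = Dict{Symbol,Any}(
    :reco => "multiCoil", :reconSize => acq.encodingSize, :senseMaps => sens,
    :solver => "fista", :regularization => "L1", :sparseTrafoName => "Wavelet",
    :λ => 1e-8, :iterations => 30)
I_julia = reconstruction(acq, params)
I_bart  = bart(1, "pics -S -i 30 -R W:7:0:1e-8", k_bart, sens)
```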

@tknopp (Member) commented Oct 6, 2022

I don't know if it's just the different scaling, but @migrosser also reported that he got better results by simply choosing a different regularization parameter. It would be great if someone could take a deeper look. But we should probably keep this on hold and first verify that RegularizedLeastSquares itself is consistent.

If somebody wants to give this a go: MRIRecoBenchmarks now includes the example script: https://github.com/MagneticResonanceImaging/MRIRecoBenchmarks/tree/master/benchmark2
It generates timings and output images. Errors are currently not calculated, but that would be easy to add.

Regarding the scaling: in my opinion this is primarily a documentation issue. We need to document somewhere how to translate a regularization parameter scaled for BART to MRIReco.

@JakobAsslaender (Member)

So if I understand it correctly, the return value of BART depends on whether it terminates based on the maximum number of iterations, in which case it returns after the gradient step, i.e., halfway through Eq. 4.1 in the FISTA paper:

https://github.com/mrirecon/bart/blob/5428c0ae9f6cdb1667b549323802682ce1171bd9/src/iter/italgos.c#L252

but if the residual is smaller than the termination threshold, it returns after adding the momentum (ravine) step, i.e., after Eq. 4.3 in the FISTA paper:

https://github.com/mrirecon/bart/blob/5428c0ae9f6cdb1667b549323802682ce1171bd9/src/iter/italgos.c#L241

Not sure why they distinguish between these two cases, but maybe @JeffFessler has thoughts?

We currently return after applying the proximal operator, i.e., after Eq. 4.1. I think it would be easy to change the behavior to terminate after the gradient step by moving line

https://github.com/tknopp/RegularizedLeastSquares.jl/blob/22d58db104e374232a4e7d99b8863cd5a3ac36af/src/FISTA.jl#L138

below

https://github.com/tknopp/RegularizedLeastSquares.jl/blob/22d58db104e374232a4e7d99b8863cd5a3ac36af/src/FISTA.jl#L150

or after the momentum step by moving it below

https://github.com/tknopp/RegularizedLeastSquares.jl/blob/22d58db104e374232a4e7d99b8863cd5a3ac36af/src/FISTA.jl#L146

but I am not sure that the done function supports a distinction based on why we are terminating.

Do you, @JeffFessler or @andrewwmao, have comments on what the "right" termination point is?
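
For reference, here is a minimal sketch of one FISTA iteration (plain Julia, not the actual RegularizedLeastSquares.jl code) with the three candidate return points marked; `A`, `b`, a proximal map `prox`, and the Lipschitz constant `L` are assumed given:

```julia
function fista_sketch(A, b, prox, L; iterations = 50)
    x = zeros(ComplexF32, size(A, 2))   # previous iterate x_{k-1}
    y = copy(x)                         # extrapolated point y_k
    t = 1.0
    for _ in 1:iterations
        # (1) gradient step on the extrapolated point (first half of Eq. 4.1);
        #     BART returns here when it hits the maximum number of iterations
        g = y - (1 / L) * (A' * (A * y - b))
        # (2) proximal step (second half of Eq. 4.1);
        #     RegularizedLeastSquares.jl currently returns here
        xnew = prox(g, 1 / L)
        # (3) momentum (ravine) step (Eqs. 4.2 and 4.3);
        #     BART returns here when the residual drops below the tolerance
        tnew = (1 + sqrt(1 + 4t^2)) / 2
        y = xnew + ((t - 1) / tnew) * (xnew - x)
        x, t = xnew, tnew
    end
    return x   # the output of the prox update
end
```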

@tknopp (Member) commented Oct 6, 2022

Intuitively, I would have said after the prox step, because only then does the solution lie in the desired subspace (as Jeff has said). Making it dependent on maxiter seems suboptimal. We should not replicate that.

@JakobAsslaender (Member)

Well, in this case the current implementation is correct :). @aTrotier: Have you played with lambda and maxiter to see if either of those resolves the issue? Note that there is a factor-2 difference in lambda between ADMM and FISTA, as @andrewwmao pointed out to me. We should probably resolve this to make the algorithms more comparable.
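
For concreteness, one plausible source of that factor of 2 (my assumption about the respective cost formulations, not verified against the code) is whether the conventional 1/2 is included on the data-fidelity term:

$$
\min_x \tfrac{1}{2}\lVert Ax-b\rVert_2^2 + \lambda\,\lVert \Psi x\rVert_1
\quad\Longleftrightarrow\quad
\min_x \lVert Ax-b\rVert_2^2 + 2\lambda\,\lVert \Psi x\rVert_1 ,
$$

so the same solution requires doubling lambda in the formulation without the 1/2.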

@JeffFessler (Member)

I have trouble understanding the BART code, but if @JakobAsslaender's reading of it is correct, then I think it is a "bug" in BART. For FISTA, the proper thing to return (whenever one stops) is the output of the prox update. (Whether this is "x" or "y" depends on the paper's notation, BTW.)

I've double-checked that we do this properly in MIRT:
https://github.com/JeffFessler/MIRT.jl/blob/26cfbc2a26b26814fa85739dbe01b5f1b8be5e21/src/algorithm/general/pogm_restart.jl#L262

BTW, I'd recommend POGM instead of FISTA for the problem at hand.
See Fig. 3 of my survey paper on optimization for MRI: http://doi.org/10.1109/MSP.2019.2943645

There is code for reproducing that figure here:
https://github.com/JeffFessler/mirt-demo/blob/main/isbi-19/01-recon.ipynb
and a Documenter-type example of it here:
https://juliaimagerecon.github.io/Examples/generated/mri/2-cs-wl-l1-2d/

The main issue is that currently that POGM code is buried in MIRT.jl, which is too large, though I am working on paring it down into smaller packages. In the long run it seems I should contribute POGM to RLS, or to some optimization package somewhere. It would benefit from some optimizations, like in-place ops, that I haven't had time to do...

@JakobAsslaender (Member)

@JeffFessler: I would loooooovvvvveeeee to throw POGM at our data! But so far I have been hesitant because of the different interface (compared to RLS.jl) and the lack of in-place operations, etc. Do I understand correctly that the algorithm builds on FISTA? Maybe the easiest way to do those optimizations would be to copy the RLS.jl implementation of FISTA and turn it into POGM? We can, of course, also think about converting the existing FISTA implementation into a super-function similar to yours, but I am not sure about the best compromise between code duplication and speed/readability. Let me know if I can be of help; I'm just not sure that I have enough knowledge about POGM to do the job...

@aTrotier (Contributor, Author) commented Oct 7, 2022

@JakobAsslaender @JeffFessler After playing around, I think you are right. I am able to get something close to the BART implementation of FISTA, with a bit more noise, which might then be related to #102 (reply in thread).
In my first tests, I wanted to reduce the noise by increasing the lambda value, which creates the thresholding effect that is supposed to happen (BART misled me in this case).

Maybe @uecker can give some advice about the BART implementation and why they don't return the image after soft-thresholding.

Something to mention: with BART, most of the time I don't have to play much with the parameters to make FISTA work (lambda is generally close to 0.01). I guess the pre-scaling operation helps.

@tknopp (Member) commented Oct 7, 2022

I guess the pre-scaling operation helps.
The option params[:normalizeReg] = true is actually supposed to make things independent of the input data, but it is probably not enough. So there is certainly a TODO item left.

@aTrotier (Contributor, Author) commented Oct 7, 2022

OK, I think BART does it the other way around: it scales the input data (#92) rather than the lambda value.

edit: results after playing with the parameters: https://atrotier.github.io/EDUC_JULIA_CS_MP2RAGE/
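
To make the two conventions concrete, here is a hypothetical sketch; the exact scaling rule BART applies is not reproduced here, and `fista` and the RMS heuristic are placeholders:

```julia
using LinearAlgebra

# Placeholder scale estimate; BART derives its factor from the data as well,
# but its exact rule may differ from this RMS heuristic.
scale = norm(kdata) / sqrt(length(kdata))

# BART-style: normalize the data once, keep λ fixed (e.g. 0.01), undo at the end.
x_bart_style = scale .* fista(A, kdata ./ scale; λ = 0.01)

# normalizeReg-style: leave the data alone and scale λ instead.
x_norm_reg = fista(A, kdata; λ = 0.01 * scale)
```

For the L1-regularized least-squares problem these two are mathematically equivalent, so documenting the convention (as suggested above) would make the packages' lambdas translatable.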

@JeffFessler (Member)

throw POGM

Thanks for the encouragement. I will first write a streamlined version of POGM (putting POGM, FISTA, and PGM in one function is too messy) that uses in-place ops, etc. It will be for a general composite cost function, and I will illustrate how to use it on a regularized LS problem. Then we can decide if/how to make a wrapper in RLS.jl so it can be called just as easily as FISTA.

@aTrotier (Contributor, Author)

Just a remark: the accuracy issue could also be linked to the wavelet implementation. BART does a full decomposition along each axis, whereas Wavelets.jl determines the minimum decomposition level across the axes and uses that for all of them.

@tknopp (Member) commented Oct 10, 2022

Then we probably want options to refine the wavelet transform. I don't know how we should approach this, but one would first make the WaveletOp more general and then introduce some new high-level parameters.

@aTrotier (Contributor, Author)

Actually, it cannot be the issue in this benchmark, because the 3D dataset's dimensions are the same along each axis. But it is something that might impact image quality for non-square matrices, as in this example: https://github.com/MagneticResonanceImaging/MRIReco.jl/blob/master/examples/mridataorg/example.jl
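
A small sketch of the suspected difference, based on my reading of the discussion above rather than on the two codebases; `maxtransformlevels` is from Wavelets.jl:

```julia
using Wavelets

# For a non-cube array such as 96 × 128 × 128, the deepest dyadic level usable
# on *all* axes is limited by the shortest axis:
levels = maxtransformlevels.((96, 128, 128))   # (5, 7, 7)
common = minimum(levels)                       # 5 levels applied to every axis
# BART presumably decomposes each axis to its own full depth (7 on the 128
# axes), so the sparsifying transforms, and hence the optimal λ, differ.
```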

@andrewwmao

@aTrotier have you run this notebook recently with the latest version of RLS? And also the FISTA recon with params2[:ρ] = 0.95? I am having trouble getting your binder to work.

@aTrotier (Contributor, Author)

No, I think it is outdated (it predates the splitting of the package). I will update it.

@aTrotier (Contributor, Author) commented Jun 14, 2023

I created a quick test with your pogm branch of RegularizedLeastSquares.jl and my PR for MRIReco, which gives these results:
compare.pdf

The results from BART and MRIReco are close. However, if I increase the number of iterations, BART seems to converge, whereas MRIReco (FISTA and even ADMM) increases the noise level. For FISTA, the noise amplification is really fast with respect to the number of iterations.

By the way @andrewwmao, pogm and optista work, but they also give the same results (noise amplification) when increasing the number of iterations.

Maybe we should work on the benchmark test rather than my real MP2RAGE datasets for quantitative metrics: https://github.com/MagneticResonanceImaging/MRIRecoBenchmarks/tree/master/benchmark2

Just a remark: the accuracy issue could also be linked to the wavelet implementation. BART does a full decomposition along each axis, whereas Wavelets.jl determines the minimum decomposition level across the axes and uses that for all of them.

At least it does not seem really related to the wavelet implementation.

@tknopp @JeffFessler Do you have any thoughts on this?

using MRIReco, MRIFiles, MRICoilSensitivities
using BartIO, QuantitativeMRI
using CairoMakie
include("utils_MP2RAGE.jl")
## Setup BartIO and Global variable
set_bart_path("/usr/local/bin/bart")

slice = 25 # slice to show


## load data
b = BrukerFile("data/LR_3T_CS4")
raw = RawAcquisitionData_MP2RAGE_CS(b); # create an object with function in utils_MP2RAGE.jl
acq = AcquisitionData(raw,OffsetBruker = true)


## plot the mask 

begin # check mask
	mask = zeros(acq.encodingSize[1], acq.encodingSize[2], acq.encodingSize[3])
	for i in 1:length(acq.subsampleIndices[1])
		mask[acq.subsampleIndices[1][i]] = 1
	end
	heatmap!(Axis(Figure()[1,1], aspect=1), mask[64,:,:], colormap = :grays)
	current_figure()
end

## CoilSensitivities
sens = espirit(acq)

imMP_MRIReco_fista = Vector{Array{Float32,3}}()
imMP_MRIReco_admm = Vector{Array{Float32,3}}()
imMP_pics = Vector{Array{Float32,3}}()
iter_vec = (1,5,10,15,20,30,50)
for iter in iter_vec 
	# FISTA + Wavelet/L1 parameters
	params2 = Dict{Symbol, Any}()
	params2[:reco] = "multiCoil"
	params2[:reconSize] = acq.encodingSize
	params2[:senseMaps] = sens;

	params2[:solver] = "fista"
	params2[:sparseTrafoName] = "Wavelet"
	params2[:regularization] = "L1"
	params2[:λ] = 0.01 # 5.e-2
	params2[:iterations] = iter
	params2[:normalize_ρ] = true
	params2[:ρ] = 0.95
	params2[:normalizeReg] = true
	
	I_wav = reconstruction(acq, params2);
  push!(imMP_MRIReco_fista,mp2rage_comb(I_wav.data[:,:,:,:,1,1]))
  
  params2[:solver] = "admm"
  I_wav = reconstruction(acq, params2);
  push!(imMP_MRIReco_admm,mp2rage_comb(I_wav.data[:,:,:,:,1,1]))
  #heatmap(imMP_wav[:,:,slice],colormap=:grays,axis= (;title="MRIReco wav, iter = $iter"))

  ## compare to bart

  k_bart = kDataCart(acq)
  k_bart = permutedims(k_bart,(1,2,3,4,6,5))
  size(k_bart)

  im_pics = bart(1,"pics -e -S -i $iter -R W:7:0:0.01",k_bart,sens);
  im_pics = permutedims(im_pics,(1,2,3,6,4,5));
  im_pics = im_pics[:,:,:,:,:,1];
  push!(imMP_pics,mp2rage_comb(im_pics[:,:,:,:,1]))
end

f = Figure(resolution=(400,600))
ga = f[1,1] = GridLayout()
asp = 128/96
for i in 1:length(imMP_pics)
  
  ax1 = Axis(ga[i,1],aspect=asp)
  hidedecorations!(ax1)
  heatmap!(ax1,imMP_MRIReco_fista[i][:,:,slice],colormap=:grays)
 

  ax2 = Axis(ga[i,2],aspect=asp)
  hidedecorations!(ax2)
  heatmap!(ax2,imMP_MRIReco_admm[i][:,:,slice],colormap=:grays)

  ax3 = Axis(ga[i,3],aspect=asp)
  hidedecorations!(ax3)
  heatmap!(ax3,imMP_pics[i][:,:,slice],colormap=:grays)

  Label(ga[i,0],"iter = $(iter_vec[i])",tellheight = false)

  if i == 1
    ax1.title = "MRIReco \n fista"
    ax2.title = "MRIReco \n admm"
    ax3.title = "bart \n fista"
  end
  rowsize!(ga,i,75)
end
rowgap!(ga,0)
f

save("compare.pdf",f)

@JakobAsslaender (Member) commented Jun 14, 2023

To what degree did you fine-tune lambda? As discussed earlier, I don't think we can assume that the same lambda is optimal for BART and Julia. But it would maybe be nice to match the implementations, i.e., re-create the BART normalization in Julia instead of the current norm_reg implementation. Thoughts, @tknopp?

Regarding the comparison between FISTA and ADMM: what does "iteration" mean in the MRIReco interface (sorry, I never use that interface)? When you run both FISTA and ADMM long enough with the right lambda, they should converge to the same solution, assuming that norm_reg is doing the same thing in both algorithms.

Lastly, I would suggest benchmarking with a different regularization (e.g., TV) to avoid the known difference in the wavelet implementations.

@andrewwmao commented Jun 14, 2023

My guess is that something is probably wrong with params2[:normalize_ρ] = true, i.e., in the calculation of the Lipschitz constant. That would explain why ISTA/POGM both seem to fail whereas ADMM seems to give a good result. But since I am also not using the high-level interface, it is difficult for me to say where this problem occurs.
This option is also false in the above-mentioned benchmark, where FISTA appears to be working fine.

For ADMM it's difficult to say at a glance what's going on. Certainly there is a factor-2 difference in the appropriate lambda to use w.r.t. FISTA/POGM, and the parameter rho also has a different meaning there. This could probably be fixed with some appropriate tuning.

@aTrotier (Contributor, Author) commented Jun 14, 2023 via email

@aTrotier (Contributor, Author)

For the scaling issue, I can also force BART to scale the data by 1.

@JeffFessler (Member)

I was also going to guess that the issue is either a mismatched regularizer or an incorrect or inconsistent Lipschitz constant somewhere.

@aTrotier (Contributor, Author) commented Jun 15, 2023

For the scaling issue, I can also force BART to scale the data by 1.

Indeed, it was the scaling / fine-tuning of the lambda value: if I force the data scaling to 1 in BART (option -w 1), I get similar noise amplification:
compare.pdf

edit: Weirdly, if I inverse-scale acq.kdata by the value calculated by BART, the results converge for BART but not for MRIReco. I will have to try the low-level interface.

@JakobAsslaender (Member)

I was also going to guess that the issue is either a mismatched regularizer or an incorrect or inconsistent Lipschitz constant somewhere.

@JeffFessler: I didn't have the Lipschitz constant on my radar in this context. Would you mind double-checking that it is calculated correctly?

https://github.com/tknopp/RegularizedLeastSquares.jl/blob/19e50e83a85bdf9006a6433ce3512aec230422a2/src/FISTA.jl#L68

And the called function can be found here:

https://github.com/tknopp/RegularizedLeastSquares.jl/blob/19e50e83a85bdf9006a6433ce3512aec230422a2/src/Utils.jl#L294
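
For readers following along, the linked routine amounts to a standard power iteration; here is a generic sketch (not the actual Utils.jl code) where `AHA` applies the normal operator x ↦ Aᴴ(Ax):

```julia
using LinearAlgebra

function power_iterations_sketch(AHA, n; iterations = 30)
    b = randn(ComplexF32, n)   # random start, as in the linked code
    λ = 0.0
    for _ in 1:iterations
        b = AHA(b)             # one application of the normal operator
        λ = norm(b)            # eigenvalue estimate (b was unit-norm)
        b ./= λ                # re-normalize
    end
    return λ   # ≈ ‖AᴴA‖₂, the Lipschitz constant of the gradient of ½‖Ax-b‖²
end

# For an explicit matrix A: power_iterations_sketch(x -> A' * (A * x), size(A, 2))
```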

@tknopp (Member) commented Jun 15, 2023

So, I am not really sure how to move forward here. I am not sure if this is possible, but I wonder if it makes sense to first improve our test cases in RegularizedLeastSquares to be sure that there are no issues at that level. Furthermore, we might want to improve the documentation and define the semantics of the algorithms more clearly, so that we state precisely which optimization problem is being solved and which normalizations are applied. Does that make sense?

On the other hand, it might be worth translating @aTrotier's MRIReco.jl reconstruction code to the low-level interface so that it becomes clearer for @JakobAsslaender and @andrewwmao.

These are just ideas. Unfortunately, I don't have much coding capacity right now.

@aTrotier (Contributor, Author)

I have a benchmark with a Shepp-Logan phantom for the MRIReco high-level/low-level and BART reconstructions here: https://github.com/aTrotier/MRIReco_Accuray_fista, which gives the following results:

From the metrics and a qualitative evaluation, the high- and low-level interfaces give approximately the same results: there are still some residual artifacts compared to the BART reconstruction (visible in the images but not in the RMSE metrics).
compare_metrics.pdf
compare_img.pdf

@JeffFessler (Member)

you mind double checking

If the data is (possibly under-sampled) Cartesian, if the encoding matrix uses the unitary DFT (with no B0 correction), and if the sensitivity maps are normalized so that the SSoS = 1, then the Lipschitz constant is 1 and there is no need to run the power iteration. I have seen situations where code set it to 1 but one of those three "ifs" was not satisfied, leading to problems. I didn't realize that here we are always (?) using the power iteration, so that way should always be safe.

We could probably use a smarter initial guess than randn to reduce the number of iterations a bit, because often the principal eigenvector is quite smooth (e.g., when low frequencies are heavily sampled, as in radial), but otherwise the code looks fine.
https://github.com/tknopp/RegularizedLeastSquares.jl/blob/19e50e83a85bdf9006a6433ce3512aec230422a2/src/Utils.jl#L295
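
Spelling out the standard argument behind those three conditions (my notation, not from the thread): with sampling mask $P$, unitary DFT $F$, and SSoS-normalized coil maps $S$, the encoding matrix $E = PFS$ satisfies

$$
E^{\mathsf H}E = S^{\mathsf H}F^{\mathsf H}P^{\mathsf H}PFS \preceq S^{\mathsf H}F^{\mathsf H}FS = S^{\mathsf H}S = I,
$$

so $\lVert E^{\mathsf H}E\rVert_2 \le 1$ (with equality for fully sampled data), and the Lipschitz constant can safely be set to 1.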

@JakobAsslaender (Member)

Thanks for looking into this! To answer your question: we run the power iteration if the flag normalize_ρ=true, which it is by default. @aTrotier: can you check that the high-level wrappers don't overwrite the default?

Would ones be a better initialization? Or did you have something else in mind?

@JeffFessler (Member)

Would ones be a better initialization? Or did you have something else in mind?

I was thinking ones, but I've never done any serious testing of it, so caveat emptor...

@aTrotier (Contributor, Author)

@aTrotier: Can you check that the high-level wrappers don't overwrite the default?

It does not. Anyway, I forced it to true: https://github.com/aTrotier/MRIReco_Accuray_fista/blob/c6ed10cbe97477086f8e214c8a127f7b6a5d73bb/Accuracy_fista.jl#L81
