Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding progress bars, on the fly plotting/saving and density matrix Observables #17

Merged
merged 11 commits into from
Feb 21, 2025
Merged
235 changes: 99 additions & 136 deletions Manifest.toml

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"

[compat]
Documenter = "1.8"
Documenter = "1.8"
4 changes: 2 additions & 2 deletions docs/src/examples/puredephasing.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ This initial state is a product state between the system and the chain. It is co
# MPO and initial state MPS
#---------------------------

H = puredephasingmpo(ΔE, d, N, cpars)
H = puredephasingmpo(ω0, d, N, cpars)

# Initial electronic system in a superposition of 1 and 2
ψ = zeros(2)
Expand Down Expand Up @@ -142,7 +142,7 @@ A, dat = runsim(dt, tfinal, A, H, prec=1E-4;
method = method,
obs = [ob1],
convobs = [ob1],
params = @LogParams(ΔE, N, d, α, s),
params = @LogParams(ω0, N, d, α, s),
convparams = D,
reduceddensity=true,
verbose = false,
Expand Down
4 changes: 2 additions & 2 deletions docs/src/examples/timedep.md
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ ob3 = TwoSiteObservable("SXdisp", sx, disp(d), [1], collect(2:N+1))
A, dat = runsim(dt, tfinal, A, H;
name = "Driving field on ohmic spin boson model",
method = method,
obs = [ob1],
obs = [ob1, ob2, ob3],
convobs = [ob1],
params = @LogParams(N, d, α, Δ, ω0, s),
convparams = D,
Expand All @@ -176,7 +176,7 @@ A, dat = runsim(dt, tfinal, A, H;
plot = true,
);
```
Eventually, the stored observables can be represented. For more information about the chain observables, see [Inspecting the bath by undoing the chain mapping]
Eventually, the stored observables can be represented. For more information about the chain observables, see [Inspecting the bath by undoing the chain mapping](@ref)
```julia

#----------
Expand Down
2 changes: 1 addition & 1 deletion docs/src/methods.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,5 @@ Pages = ["MPSDynamics.jl", "chain2TDVP.jl", "chainA1TDVP.jl","chainDMRG.jl","cha
## Advanced
```@autodocs
Modules = [MPSDynamics]
Pages = ["flattendict.jl", "logging.jl", "logiter.jl","machines.jl"]
Pages = ["flattendict.jl", "logging.jl", "logiter.jl","machines.jl","utilities.jl"]
```
11 changes: 9 additions & 2 deletions src/MPSDynamics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ include("run_DTDVP.jl")
include("chainA1TDVP.jl")
include("switchmpo.jl")
include("finitetemperature.jl")
include("utilities.jl")

"""
runsim(dt, tmax, A, H;
Expand Down Expand Up @@ -87,12 +88,16 @@ function runsim(dt, tmax, A, H;
convparams = typeof(convparams) <: Vector ? only(convparams) : convparams
end

if save || plot
if save || plot || (:onthefly in keys(kwargs) && !isempty(kwargs[:onthefly][:save_obs]) && kwargs[:onthefly][:savedir] == "auto")
if savedir[end] != '/'
savedir = string(savedir,"/")
end
isdir(savedir) || mkdir(savedir)
open_log(dt, tmax, convparams, method, machine, savedir, unid, name, params, obs, convobs, convcheck, kwargs...)
if :onthefly in keys(kwargs)
mkdir(string(savedir, unid, "/tmp/"))
kwargs[:onthefly][:savedir] = string(savedir, unid, "/tmp/")
end
end

paramdict = Dict([[(par[1], par[2]) for par in params]...,
Expand Down Expand Up @@ -155,7 +160,7 @@ export chaincoeffs_ohmic, spinbosonmpo, methylbluempo, methylbluempo_correlated,

export productstatemps, physdims, randmps, bonddims, elementmps

export measure, measurempo, OneSiteObservable, TwoSiteObservable, FockError, errorbar
export measure, measurempo, OneSiteObservable, TwoSiteObservable, RhoReduced, FockError, errorbar

export runsim, run_all

Expand All @@ -175,5 +180,7 @@ export rhoreduced_2sites, rhoreduced_1site, protontransfermpo

export chaincoeffs_finiteT, chaincoeffs_fermionic, fermionicspectraldensity_finiteT, chaincoeffs_finiteT_discrete

export onthefly, mergetmp

end

12 changes: 12 additions & 0 deletions src/measure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,17 @@ struct TwoSiteObservable <: Observable
allsites::Bool
end

"""
RhoReduced(name,sites)

Computes the reduced density matrix on the sites `sites` which can be either a single site or a tuple of two sites. Used to define
reduced density matrices that are obs and convobs parameters for the `runsim` function.
"""
struct RhoReduced <: Observable
name::String
sites::Union{Int, Tuple{Int, Int}}
end

struct CdagCup <: Observable
name::String
sites::Tuple{Int,Int}
Expand Down Expand Up @@ -183,6 +194,7 @@ measure(A, O::TwoSiteObservable, ::Nothing) =
measure2siteoperator(A, O.op1, O.op2, O.sites1, O.sites2)
measure(A, O::TwoSiteObservable, ρ::Vector) =
measure2siteoperator(A, O.op1, O.op2, O.sites1, O.sites2, ρ)
measure(A, O::RhoReduced; kwargs...) = O.sites isa Int ? rhoreduced_1site(A, O.sites) : rhoreduced_2sites(A, O.sites)

"""
measure1siteoperator(A::Vector, O, sites::Vector{Int})
Expand Down
17 changes: 13 additions & 4 deletions src/run_1TDVP.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function run_1TDVP(dt, tmax, A, H, Dmax; obs=[], timed=false, reduceddensity=false, timedep=false, kwargs...)
function run_1TDVP(dt, tmax, A, H, Dmax; obs=[], timed=false, reduceddensity=false, timedep=false, progressbar=true, onthefly=Dict(), kwargs...)
A0=deepcopy(A)
H0=deepcopy(H)
data = Dict{String,Any}()
Expand All @@ -19,12 +19,19 @@ function run_1TDVP(dt, tmax, A, H, Dmax; obs=[], timed=false, reduceddensity=fal
end

timed && (ttdvp = Vector{Float64}(undef, numsteps))
onthefly = copy(onthefly)
ontheflyplotbool = !isempty(onthefly) && !isnothing(onthefly[:plot_obs])
ontheflysavebool = !isempty(onthefly) && !isempty(onthefly[:save_obs])
ontheflysavebool && (onthefly[:save_obs] = intersect(onthefly[:save_obs], [ob.name for ob in obs]))

F=nothing
mpsembed!(A0, Dmax)
for tstep=1:numsteps
@printf("%i/%i, t = %.3f ", tstep, numsteps, times[tstep])
println()
iter = progressbar ? ProgressBar(numsteps; ETA=true) : 1:numsteps
for tstep in iter
if !progressbar
@printf("%i/%i, t = %.3f ", tstep, numsteps, times[tstep])
println()
end
if timedep
Ndrive = kwargs[:Ndrive]
Htime = kwargs[:Htime]
Expand Down Expand Up @@ -55,6 +62,8 @@ function run_1TDVP(dt, tmax, A, H, Dmax; obs=[], timed=false, reduceddensity=fal
exprho = rhoreduced_1site(A0,Nrho)
data["Reduced ρ"] = cat(data["Reduced ρ"], exprho; dims=ndims(exprho)+1)
end
ontheflyplotbool && tstep%onthefly[:step] == 0 && ontheflyplot(onthefly, tstep, times, data)
ontheflysavebool && tstep%onthefly[:step] == 0 && ontheflysave(onthefly, tstep, times, data)
end
timed && push!(data, "deltat"=>ttdvp)
push!(data, "times" => times)
Expand Down
21 changes: 16 additions & 5 deletions src/run_2TDVP.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function run_2TDVP(dt, tmax, A, H, truncerr; obs=[], Dlim=50, savebonddims=false, timed=false, reduceddensity=false, timedep=false, kwargs...)
function run_2TDVP(dt, tmax, A, H, truncerr; obs=[], Dlim=50, savebonddims=false, timed=false, reduceddensity=false, timedep=false, progressbar=true, onthefly=Dict(), kwargs...)
A0=deepcopy(A)
H0=deepcopy(H)
data = Dict{String,Any}()
Expand All @@ -22,12 +22,19 @@ function run_2TDVP(dt, tmax, A, H, truncerr; obs=[], Dlim=50, savebonddims=false
savebonddims && push!(data, "bonddims" => reshape([bonds...], length(bonds), 1))

timed && (ttdvp = Vector{Float64}(undef, numsteps))
onthefly = copy(onthefly)
ontheflyplotbool = !isempty(onthefly) && !isnothing(onthefly[:plot_obs])
ontheflysavebool = !isempty(onthefly) && !isempty(onthefly[:save_obs])
ontheflysavebool && (onthefly[:save_obs] = intersect(onthefly[:save_obs], [ob.name for ob in obs]))

F=nothing
for tstep=1:numsteps
maxbond = max(bonds...)
@printf("%i/%i, t = %.3f, Dmax = %i ", tstep, numsteps, times[tstep], maxbond)
println()
iter = progressbar ? ProgressBar(numsteps; ETA=false) : 1:numsteps
for tstep in iter
if !progressbar
maxbond = max(bonds...)
@printf("%i/%i, t = %.3f, Dmax = %i ", tstep, numsteps, times[tstep], maxbond)
println()
end
if timedep
Ndrive = kwargs[:Ndrive]
Htime = kwargs[:Htime]
Expand All @@ -52,6 +59,8 @@ function run_2TDVP(dt, tmax, A, H, truncerr; obs=[], Dlim=50, savebonddims=false
A0, F = tdvp2sweep!(dt, A0, H0, F; truncerr=truncerr, truncdim=Dlim, kwargs...)
end
bonds = bonddims(A0)
progressbar && (maxbond = max(bonds...))
progressbar && (iter.Dmax = maxbond)
exp = measure(A0, obs; t=times[tstep])
for (i, ob) in enumerate(obs)
data[ob.name] = cat(data[ob.name], exp[i], dims=ndims(exp[i])+1)
Expand All @@ -63,6 +72,8 @@ function run_2TDVP(dt, tmax, A, H, truncerr; obs=[], Dlim=50, savebonddims=false
if savebonddims
data["bonddims"] = cat(data["bonddims"], [bonds...], dims=2)
end
ontheflyplotbool && tstep%onthefly[:step] == 0 && ontheflyplot(onthefly, tstep, times, data)
ontheflysavebool && tstep%onthefly[:step] == 0 && ontheflysave(onthefly, tstep, times, data)
end
timed && push!(data, "deltat"=>ttdvp)
push!(data, "times" => times)
Expand Down
20 changes: 16 additions & 4 deletions src/run_DTDVP.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function run_DTDVP(dt, tmax, A, H, prec; obs=[], effects=false, error=false, timed=false, savebonddims=false, Dplusmax=nothing, Dlim=50, reduceddensity=false, timedep=false, kwargs...)
function run_DTDVP(dt, tmax, A, H, prec; obs=[], effects=false, error=false, timed=false, savebonddims=false, Dplusmax=nothing, Dlim=50, reduceddensity=false, timedep=false, progressbar=true, onthefly=Dict(), kwargs...)
A0=deepcopy(A)
H0=deepcopy(H)
data = Dict{String,Any}()
Expand All @@ -25,12 +25,19 @@ function run_DTDVP(dt, tmax, A, H, prec; obs=[], effects=false, error=false, tim
timed && (ttdvp = Vector{Float64}(undef, numsteps))
timed && (tproj = Vector{Float64}(undef, numsteps))
effects && (efft = Vector{Any}(undef, numsteps))
onthefly = copy(onthefly)
ontheflyplotbool = !isempty(onthefly) && !isnothing(onthefly[:plot_obs])
ontheflysavebool = !isempty(onthefly) && !isempty(onthefly[:save_obs])
ontheflysavebool && (onthefly[:save_obs] = intersect(onthefly[:save_obs], [ob.name for ob in obs]))

F=nothing
Afull=nothing
for tstep=1:numsteps
maxbond = max(bonds...)
@printf("%i/%i, t = %.3f, Dmax = %i \n", tstep, numsteps, times[tstep], maxbond)
iter = progressbar ? ProgressBar(numsteps; ETA=false) : 1:numsteps
for tstep in iter
if !progressbar
maxbond = max(bonds...)
@printf("%i/%i, t = %.3f, Dmax = %i \n", tstep, numsteps, times[tstep], maxbond)
end
if timedep
Ndrive = kwargs[:Ndrive]
Htime = kwargs[:Htime]
Expand Down Expand Up @@ -58,6 +65,9 @@ function run_DTDVP(dt, tmax, A, H, prec; obs=[], effects=false, error=false, tim

exp = info["obs"]
bonds = info["dims"]
progressbar && (maxbond = max(bonds...))
progressbar && (iter.Dmax = maxbond)

effects && (efft[tstep] = reduce(hcat, info["effect"]))
error && (errs[tstep] = info["err"])
timed && (ttdvp[tstep] = info["t2"] + info["t3"])
Expand All @@ -71,6 +81,8 @@ function run_DTDVP(dt, tmax, A, H, prec; obs=[], effects=false, error=false, tim
exprho = rhoreduced_1site(A0,Nrho)
data["Reduced ρ"] = cat(data["Reduced ρ"], exprho; dims=ndims(exprho)+1)
end
ontheflyplotbool && (tstep-1)%onthefly[:step] == 0 && ontheflyplot(onthefly, tstep-1, times, data)
ontheflysavebool && (tstep-1)%onthefly[:step] == 0 && ontheflysave(onthefly, tstep-1, times, data)
end
if savebonddims
data["bonddims"] = cat(data["bonddims"], bonds, dims=2)
Expand Down
130 changes: 130 additions & 0 deletions src/utilities.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
mutable struct ProgressBar
numsteps::Int
ETA::Bool
times::Vector{<:Float64}
Dmax::Int
length::UInt
end

"""
ProgressBar(numsteps::Int; ETA=false, last=10)

An iterable returning values from 1 to `numsteps`. Displays a progress bar for the for loop where it has been called.
If ETA is true then displays an estimation of the remaining time calculated based on the time spent computing the last `last` values.
The progress bar can be activated or deactivated by setting the `progressbar` keyword argument in `runsim` to true or false."""
ProgressBar(numsteps::Int; ETA=false, last=10) = ProgressBar(numsteps, ETA, fill(0., (ETA ? last : 1)+1), 0, displaysize(stdout)[2] > 54 ? min(displaysize(stdout)[2]-54, 50) : error("Error : Terminal window too narrow"))

function Base.iterate(bar::ProgressBar, state=1)
Ntimes = length(bar.times)-1
if state > bar.numsteps
println()
return nothing
elseif state == 1
printstyled("\nCompiling..."; color=:red, bold=true)
bar.times = fill(time(), Ntimes+1)
else
tnow = time()
dtelapsed = tnow - bar.times[1]
dtETA = (tnow - bar.times[2+state%Ntimes])*(bar.numsteps - state)/min(state-1, Ntimes)
dtiter = tnow - bar.times[2+(state-1)%Ntimes]
bar.times[2+state%Ntimes] = tnow
elapsedstr = Dates.format(Time(0)+Second(floor(Int, dtelapsed)), dtelapsed>3600 ? "HH:MM:SS" : "MM:SS")
ETAstr = Dates.format(Time(0)+Second(floor(Int, dtETA)), dtETA>3600 ? "HH:MM:SS" : "MM:SS")
iterstr = Dates.format(Time(0)+Millisecond(floor(Int, 1000*dtiter)), dtiter>60 ? "MM:SS.sss" : "SS.sss")
print("\r")
printstyled("$(round(100*state/bar.numsteps, digits=1))% "; color = :green, bold=true)
print("┣"*"#"^(round(Int, state/bar.numsteps*bar.length))*" "^(bar.length-round(Int, state/bar.numsteps*bar.length))*"┫")
printstyled(" $state/$(bar.numsteps) [$(elapsedstr)s"*(bar.ETA ? "; ETA:$(ETAstr)s" : "")*"; $(iterstr)s/it"*(bar.Dmax > 0 ? "; Dmax=$(bar.Dmax)" : "")*"]"; color = :green, bold=true)
end
return (state, state+1)
end

"""
onthefly(;plot_obs=nothing::Union{<:Observable, Nothing}, save_obs=Vector{Observable}(undef, 0)::Union{<:Observable, Vector{<:Observable}}, savedir="auto", step=10::Int, func=identity<:Function, compare=nothing::Union{Tuple{Vector{Float64}, Vector{Float64}}, Nothing}, clear=identity<:Function)

Helper function returning a dictionnary containing the necessary arguments for on the fly plotting or saving in the `runsim` function.

# Arguments

* `plot_obs` : Observable to plot
* `save_obs` : List of Observable(s) to save
* `savedir` : Used to specify the path where temporary files are stored, default is `"auto"` which saves in a "tmp" folder in the run folder (generally located at "~/MPSDynamics/<unid>/tmp/").
* `step` : Number of time steps every which the function plots or saves the data
* `func` : Function to apply to the result of measurement of plot_obs (for example `real` or `abs` to make a complex result possible to plot)
* `compare` : Tuple `(times, data)` of previous results to compare against the plot_obs results,
* `clear` : Function used to clear the output before each attempt to plot, helpful for working in Jupyter notebooks where clear=IJulia.clear_output allows to reduce the size of the cell output

# Examples

For example in the [Spin-Boson model example](@ref "The Spin Boson Model"), adding the following argument to the `runsim` function allows to save
the "sz" observable `ob1` in the directory "~/MPSDynamics/<unid>/tmp/" and plot its real part during the simulation:
```julia
runsim(..., onthefly=onthefly(plot_obs=ob1, save_obs=[ob1], savedir="auto", step=10, func=real))
```
To merge the temporary files in one usable file, one can then use [`MPSDynamics.mergetmp`](@ref).
"""
function onthefly(;plot_obs=nothing::Union{<:Observable, Nothing}, save_obs=Vector{Observable}(undef, 0)::Union{<:Observable, Vector{<:Observable}}, savedir="auto", step=10::Int, func=identity::Function, compare=nothing::Union{Tuple{Vector{Float64}, Vector{Float64}}, Nothing}, clear=nothing)
if isnothing(plot_obs) && isempty(save_obs)
error("Must provide an observable to plot/save")
end
# !isempty(save_obs) && (isdir(savedir) || mkdir(savedir))
plt = isnothing(plot_obs) ? nothing : plot(title="Intermediate Results", xlabel="t", ylabel=plot_obs.name)
!isnothing(compare) && plot!(compare)
println("On the fly mode activated")
return Dict(:plot_obs => isnothing(plot_obs) ? nothing : plot_obs.name, :save_obs => [ob.name for ob in save_obs], :savedir => savedir, :step => step, :func => func, :clear => clear, :compare => compare, :plot => plt)
end

"""
ontheflyplot(onthefly, tstep, times, data)

Plots data according to the arguments of the `onthefly` dictionnary."""
function ontheflyplot(onthefly, tstep, times, data)
tstep < onthefly[:step]+2 && plot!(onthefly[:plot], [], [])
N = ndims(data[onthefly[:plot_obs]])
N == 1 ? slicefunc(arr, N, i) = arr[i] : slicefunc = selectdim
append!(onthefly[:plot].series_list[end], times[tstep-onthefly[:step]+1:tstep+1], map(onthefly[:func], (slicefunc(data[onthefly[:plot_obs]], N, i) for i in (tstep-onthefly[:step]+1):(tstep+1))))
!isnothing(onthefly[:clear]) && onthefly[:clear](true)
sleep(0.05);display(onthefly[:plot])
end

"""
ontheflysave(onthefly, tstep, times, data)

Saves data according to the arguments of the `onthefly` dictionnary."""
function ontheflysave(onthefly, tstep, times, data)
jldopen(onthefly[:savedir]*"tmp$(tstep÷onthefly[:step]).jld", "w") do file
write(file, "times", times[tstep-onthefly[:step]+1:tstep+1])
for name in onthefly[:save_obs]
write(file, name, data[name][tstep-onthefly[:step]+1:tstep+1])
end
end
end

"""
mergetmp(tmpdir; fields=[], overwrite=true)

Merges the temporary files created by the `ontheflysave` function at the directory `tmpdir` and returns a dictionnary containing the resulting data.
By default all fields are present but one can select the fields of interest with a list of names in `fields`."""
function mergetmp(tmpdir; fields=[], overwrite=true)
tmpdir[end] != '/' && (tmpdir *= '/')
!isdir(tmpdir) && error("Choose a valid directory")
files = [walkdir(tmpdir)...][1][3]
files = filter(x->!isnothing(match(r"tmp(\d+).jld", x)), files)
files = sort(files, by=(x -> parse(Int, match(r"(\d+)", x).captures[1])))
isempty(fields) && (fields = keys(JLD.load(tmpdir*files[1])))
merged_data = Dict(ob => [] for ob in fields)
i = 1
for file in files
for ob in fields
append!(merged_data[ob], JLD.load(tmpdir*file, ob))
end
i += 1
end
if overwrite
for file in files
rm(tmpdir*file)
end
JLD.save(tmpdir*files[end], Iterators.flatten(merged_data)...)
end
return merged_data
end