diff --git a/REQUIRE b/REQUIRE
index 080e1af8..79a4e614 100644
--- a/REQUIRE
+++ b/REQUIRE
@@ -13,3 +13,4 @@ MacroTools 0.3.6
 AutoHashEquals 0.1.0
 MLDatasets 0.3.0
 SpecialFunctions 0.7.0
+Optim 0.17.0
diff --git a/deps/build.jl b/deps/build.jl
index 1a104aed..caabe43d 100644
--- a/deps/build.jl
+++ b/deps/build.jl
@@ -1,8 +1,8 @@
 using PyCall
 using Conda
 
-const cur_version = "1.10.0"
-const cur_py_version = "1.10.0"
+const cur_version = "1.12.0"
+const cur_py_version = "1.12.0"
 
 ############################
diff --git a/deps/default_imports.txt b/deps/default_imports.txt
index 9368e056..839cf659 100644
--- a/deps/default_imports.txt
+++ b/deps/default_imports.txt
@@ -166,3 +166,6 @@ Rank
 Conv2DBackpropInput
 Svd
 Cross
+FFT
+ComplexAbs
+MatrixSolve
diff --git a/src/TensorFlow.jl b/src/TensorFlow.jl
index 69b76c45..228c93ac 100644
--- a/src/TensorFlow.jl
+++ b/src/TensorFlow.jl
@@ -128,6 +128,7 @@ tf_versioninfo
 
 using Distributed
+using Optim
 
 const pyproc = Ref(0)
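
For reviewers unfamiliar with the new Optim.jl dependency: the wrapper added to `src/train.jl` below ultimately delegates to `Optim.optimize(f, g!, x0, method, options)`. A minimal standalone sketch of that API on the classic Rosenbrock function (nothing here is specific to this PR):

```
using Optim

# Rosenbrock test function and its analytic gradient.
f(x) = (1.0 - x[1])^2 + 100.0 * (x[2] - x[1]^2)^2
function g!(G, x)
    G[1] = -2.0 * (1.0 - x[1]) - 400.0 * x[1] * (x[2] - x[1]^2)
    G[2] = 200.0 * (x[2] - x[1]^2)
end

result = optimize(f, g!, zeros(2), LBFGS(), Optim.Options(iterations = 1000))
Optim.minimizer(result)  # ≈ [1.0, 1.0]
```
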
diff --git a/src/train.jl b/src/train.jl
index 0514ce99..d05ae6c2 100644
--- a/src/train.jl
+++ b/src/train.jl
@@ -7,6 +7,7 @@ apply_gradients,
 GradientDescentOptimizer,
 MomentumOptimizer,
 AdamOptimizer,
+optim_minimize,
 Saver,
 save,
 restore,
@@ -25,8 +26,10 @@ using Compat
 using JLD2
 using FileIO
 using ProtoBuf
+using Optim
 import Printf
+
 import ..TensorFlow: Graph, Operation, get_def_graph, extend_graph, gradients, variable_scope, ConstantInitializer, node_name, get_variable, get_shape, get_collection, Session, placeholder, Tensor, Variable, cast, group, @not_implemented, AbstractQueue, tensorflow, add_to_collection, get_proto, get_def, @op
 import TensorFlow
@@ -183,6 +186,132 @@ function apply_gradients(optimizer::AdamOptimizer, grads_and_vars; global_step=n
     return group(ops...)
 end
+mutable struct OptimOptimizer
+    indices::Array{Array{Int64}}   # shape of each trainable variable
+    segments::Array{Array{Int64}}  # [start, stop] range of each variable in the flat vector
+    sess::Session
+    vars::Array{Tuple{Any,Any},1}  # (gradient, variable) pairs
+    dtype::Type
+    feed_dict::Dict
+end
+
+function OptimOptimizer(dtype::Type, loss::Tensor, sess::Session, feed_dict::Dict=Dict())
+    var_list = get_def_graph().collections[:TrainableVariables]
+    vars = zip(gradients(loss, var_list), var_list) |> collect
+    filter!(x->x[1]!==nothing, vars)  # drop variables the loss does not depend on
+
+    indices = Array{Int64}[]
+    segments = Array{Int64}[]
+    idx = 1
+    for i = 1:length(vars)
+        W = run(sess, vars[i][2], feed_dict)
+        push!(indices, [i for i in size(W)])
+        push!(segments, [idx; idx+length(W)-1])
+        idx += length(W)
+    end
+    OptimOptimizer(indices, segments, sess, vars, dtype, feed_dict)
+end
+
+function update_values(opt::OptimOptimizer, x)  # write the flat vector x back into the variables
+    for i = 1:length(opt.indices)
+        x0 = reshape(x[opt.segments[i][1]:opt.segments[i][2]], opt.indices[i]...)
+        run(opt.sess, tf.assign(opt.vars[i][2], x0))
+    end
+end
+
+function compute_grads(opt::OptimOptimizer)  # evaluate all gradients into one flat vector
+    grads = zeros(opt.dtype, opt.segments[end][2])
+    for i = 1:length(opt.indices)
+        grads[opt.segments[i][1]:opt.segments[i][2]] = run(opt.sess, opt.vars[i][1], opt.feed_dict)
+    end
+    return grads
+end
+
+function compute_init(opt::OptimOptimizer)  # read the current variable values as the starting point
+    x0 = zeros(opt.dtype, opt.segments[end][2])
+    for i = 1:length(opt.indices)
+        x0[opt.segments[i][1]:opt.segments[i][2]] = run(opt.sess, opt.vars[i][2], opt.feed_dict)
+    end
+    return x0
+end
+
+"""
+    optim_minimize(sess::Session, loss::Tensor;
+                   dtype::Type = Float64, feed_dict::Dict = Dict(), method::String = "LBFGS", options = nothing)
+
+`optim_minimize` calls first-order optimization solvers from the Optim.jl package (https://github.com/JuliaNLSolvers/Optim.jl).
+
+`sess`: the current session
+`loss`: the loss tensor to minimize
+`dtype`: the value type used for the computation (default `Float64`)
+`feed_dict`: a dictionary for placeholders
+`method`: four methods are supported: `LBFGS` (default), `BFGS`, `AGD` (AcceleratedGradientDescent) and `CG` (ConjugateGradient)
+`options`: an `Optim.Options` instance; see the Optim.jl documentation for details
+
+Example
+=======
+```
+function mycallback(handle)
+    res = run(sess, Loss, Dict(X=>x, Y_obs=>y))
+    println("iter \$(handle.iteration): \$(res)")
+    return false  # returning false lets the optimization continue
+end
+
+options = Optim.Options(show_trace = false, iterations = 1000, callback = mycallback, allow_f_increases = true)
+optim_minimize(sess, Loss, feed_dict = Dict(X=>x, Y_obs=>y), options = options, method = "AGD")
+```
+
+Note
+=======
+
+This optimizer is not built as part of the graph. Instead, it constructs an objective function and a gradient function
+that call `run(sess, ...)` on every iteration. This approach has two drawbacks: (1) stochastic gradient descent is not easy to
+implement; (2) there is some per-iteration overhead. In exchange, it calls the solvers from Optim.jl directly and leverages
+their robustness and fine-grained parameter control options.
+"""
+function optim_minimize(sess::Session, loss::Tensor;
+        dtype::Type = Float64, feed_dict::Dict = Dict(), method::String = "LBFGS", options::Union{Nothing, Optim.Options}=nothing)
+    opt = OptimOptimizer(dtype, loss, sess, feed_dict)
+    function f(x)  # objective: push x into the variables, then evaluate the loss
+        update_values(opt, x)
+        res = run(sess, loss, feed_dict)
+        return res
+    end
+
+    function g!(G, x)  # in-place gradient, as Optim.jl expects
+        update_values(opt, x)
+        G[:] = compute_grads(opt)
+    end
+
+    x0 = compute_init(opt)
+
+    optimizer = nothing
+    if method == "LBFGS"
+        optimizer = LBFGS()
+    elseif method == "BFGS"
+        optimizer = BFGS()
+    elseif method == "AGD"
+        optimizer = AcceleratedGradientDescent()
+    elseif method == "CG"
+        optimizer = ConjugateGradient()
+    else
+        error("""
+Available optimizers:
+* LBFGS
+* BFGS
+* AGD (AcceleratedGradientDescent)
+* CG (ConjugateGradient)
+""")
+    end
+    if options === nothing
+        return optimize(f, g!, x0, optimizer)
+    else
+        return optimize(f, g!, x0, optimizer, options)
+    end
+end
+
+
+
 
 mutable struct Saver
     var_list
     max_to_keep
@@ -417,4 +546,4 @@ function SummaryWriter(args...; kwargs...)
     TensorFlow.summary.FileWriter(args...; kwargs...)
 end
-end
+end
\ No newline at end of file
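
The `indices`/`segments` bookkeeping above packs all trainable variables into a single flat vector for Optim.jl and unpacks it again on every update. A standalone sketch of that indexing scheme with hypothetical shapes (plain Julia, no session required):

```
# Two hypothetical variables: W of size (2, 3) and b of size (3,).
function flatten_layout(shapes)
    indices = [collect(s) for s in shapes]   # per-variable shapes, e.g. [[2, 3], [3]]
    segments = Vector{Int}[]
    idx = 1
    for s in shapes
        push!(segments, [idx, idx + prod(s) - 1])  # [start, stop] in the flat vector
        idx += prod(s)
    end
    return indices, segments
end

indices, segments = flatten_layout([(2, 3), (3,)])  # segments == [[1, 6], [7, 9]]
x = randn(segments[end][2])                         # flat parameter vector of length 9
W = reshape(x[segments[1][1]:segments[1][2]], indices[1]...)  # recover the 2×3 W
b = x[segments[2][1]:segments[2][2]]                          # recover the length-3 b
```
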
diff --git a/test/train.jl b/test/train.jl
index 7a52bf50..ea5a6a26 100644
--- a/test/train.jl
+++ b/test/train.jl
@@ -1,5 +1,6 @@
 using TensorFlow
 using Test
+using Optim
 
 @testset "save and resore" begin
     try
@@ -78,3 +79,55 @@ end
         end
     end
 end
+
+
+@testset "optimizers" begin
+    using Distributions
+    # Generate some synthetic data
+    x = randn(100, 50)
+    w = randn(50, 10)
+    y_prob = exp.(x*w)
+    y_prob ./= sum(y_prob, dims=2)
+
+    function draw(probs)
+        y = zeros(size(probs))
+        for i in 1:size(probs, 1)
+            idx = rand(Categorical(probs[i, :]))
+            y[i, idx] = 1
+        end
+        return y
+    end
+
+    y = draw(y_prob)
+
+    # Build the model
+    sess = Session(Graph())
+
+    X = placeholder(Float64, shape=[-1, 50])
+    Y_obs = placeholder(Float64, shape=[-1, 10])
+
+    variable_scope("logistic_model"; initializer=Normal(0, .001)) do
+        global W = get_variable("W", [50, 10], Float64)
+        global B = get_variable("B", [10], Float64)
+    end
+
+    Y = nn.softmax(X*W + B)
+    Loss = -reduce_sum(log(Y).*Y_obs)
+
+    function mycallback(handle)
+        res = run(sess, Loss, Dict(X=>x, Y_obs=>y))
+        println("iter $(handle.iteration): $(res)")
+        if isnan(res) || isinf(res)
+            return true   # stop early if the loss diverges
+        else
+            return false  # returning false lets the optimization continue
+        end
+    end
+
+    for m in ["AGD", "CG", "BFGS", "LBFGS"]
+        run(sess, global_variables_initializer())
+        options = Optim.Options(show_trace = false, iterations = 10, callback = mycallback, allow_f_increases = true)
+        train.optim_minimize(sess, Loss, feed_dict = Dict(X=>x, Y_obs=>y), options = options, method = m)
+    end
+
+end
\ No newline at end of file
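
Finally, an end-to-end usage sketch of the new API on a toy least-squares problem. This is a hypothetical script, not part of the PR; it assumes the standard TensorFlow.jl graph-building calls (`placeholder`, `Variable`, `reduce_sum`) shown elsewhere in this diff:

```
using TensorFlow
using Optim

sess = Session(Graph())
X = placeholder(Float64, shape=[-1, 1])
Y = placeholder(Float64, shape=[-1, 1])
W = Variable(zeros(1, 1))
b = Variable(zeros(1))
err = X*W + b - Y
loss = reduce_sum(err .* err)

# Perfect linear data: y = 3x + 2, so the fit should recover W ≈ 3, b ≈ 2.
xdata = reshape(collect(1.0:10.0), 10, 1)
ydata = 3 .* xdata .+ 2

run(sess, global_variables_initializer())
res = train.optim_minimize(sess, loss, feed_dict = Dict(X=>xdata, Y=>ydata),
                           options = Optim.Options(iterations = 100), method = "LBFGS")
println(run(sess, [W, b]))  # ≈ [3.0] and [2.0]
```
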