From 82265387dcae86df0b6527609cef2b0bad43f998 Mon Sep 17 00:00:00 2001 From: kailaix <klxu@pku.edu.cn> Date: Thu, 13 Dec 2018 14:11:34 -0800 Subject: [PATCH 1/2] conv1d --- REQUIRE | 1 + deps/default_imports.txt | 3 +++ src/TensorFlow.jl | 1 + src/ops/nn.jl | 19 +++++++++++++++++++ test/nn.jl | 20 ++++++++++++++++++++ 5 files changed, 44 insertions(+) diff --git a/REQUIRE b/REQUIRE index 080e1af8..79a4e614 100644 --- a/REQUIRE +++ b/REQUIRE @@ -13,3 +13,4 @@ MacroTools 0.3.6 AutoHashEquals 0.1.0 MLDatasets 0.3.0 SpecialFunctions 0.7.0 +Optim 0.17.0 diff --git a/deps/default_imports.txt b/deps/default_imports.txt index 9368e056..839cf659 100644 --- a/deps/default_imports.txt +++ b/deps/default_imports.txt @@ -166,3 +166,6 @@ Rank Conv2DBackpropInput Svd Cross +FFT +ComplexAbs +MatrixSolve diff --git a/src/TensorFlow.jl b/src/TensorFlow.jl index 27b62e23..2ba141f3 100644 --- a/src/TensorFlow.jl +++ b/src/TensorFlow.jl @@ -135,6 +135,7 @@ with_tape using Distributed +using Optim # Load these packages here so they are available to the additional # process spawned in 'load_python_process. Arslan thinks that will diff --git a/src/ops/nn.jl b/src/ops/nn.jl index dcd47a3c..cb808a50 100644 --- a/src/ops/nn.jl +++ b/src/ops/nn.jl @@ -36,6 +36,25 @@ import .rnn_cell: zero_state, output_size, state_size conv2d(input, filter; padding=padding, strides=strides, kwargs...) end +@tf.op function conv1d(input, filter_, strides_::Int64, padding::String; data_format="NHWC", kwargs...) + spatial_start_dim = 0 + if data_format=="NHWC" + strides_ = [1,1,strides_,1] + spatial_start_dim = 2 + elseif data_format == "NCHW" || data_format == "NCW" + data_format = "NCHW" + spatial_start_dim = 3 + strides_ = [1,1,1,strides_] + else + @error "data_format must be NHWC or NCHW or NCW" + end + input = Ops.expand_dims(input, spatial_start_dim) + filter_ = Ops.expand_dims(filter_, 1) + result = Ops.conv2d(input, filter_; strides = strides_, padding = padding, data_format=data_format, kwargs...) + result = Ops.squeeze(result, squeeze_dims=[spatial_start_dim-1]) + return result +end + # Same for max pool @tf.op function max_pool(input, ksize, strides, padding; kwargs...) max_pool(input; ksize=ksize, strides=strides, padding=padding, kwargs...) diff --git a/test/nn.jl b/test/nn.jl index 6ef29e52..2bf1e64c 100644 --- a/test/nn.jl +++ b/test/nn.jl @@ -4,6 +4,26 @@ using StatsFuns using Random import LinearAlgebra +@testset "conv1d" begin + let + sess = Session(Graph()) + F = zeros(Float32, 2, 3, 4) # batch_size = 2, dimension = 3, channle = 4 + for i = 1:2 + for j = 1:3 + for k = 1:4 + F[i,j,k] = Float32(i+j+k-3) + end + end + end + input = constant(F) + filter_ = constant(ones(Float32, 3, 4, 1)) # width = 3, input channel = 4 output channel = 1 + output = nn.conv1d(input, filter_, 2, "VALID") + output_val = run(sess, output) + ref_val = reshape(Float32[30.0;42.0], 2, 1, 1) + @test ref_val ≈ output_val + end +end + @testset "conv2d_transpose" begin let sess = Session(Graph()) From 485d3adef8aa9e9dc9095175dcc9fbc03a312236 Mon Sep 17 00:00:00 2001 From: Jon Malmaud <malmaud@gmail.com> Date: Fri, 17 May 2019 14:06:05 -0400 Subject: [PATCH 2/2] Add Optim dependency --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index 12abda20..778c3737 100644 --- a/Project.toml +++ b/Project.toml @@ -19,6 +19,7 @@ Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +Optim = "429524aa-4258-5aef-a3af-852621145aeb" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" ProtoBuf = "3349acd9-ac6a-5e09-bcdb-63829b23a429"