diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 8d43117..1444ca5 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -15,7 +15,7 @@ jobs: version: - '1.0' - '1' -# - 'nightly' + - 'nightly' os: - ubuntu-latest - macOS-latest diff --git a/Project.toml b/Project.toml index a639c5d..572c7a2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,20 +1,27 @@ name = "AbstractFFTs" uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" -version = "1.2.1" +version = "1.3.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +[weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + +[extensions] +AbstractFFTsChainRulesCoreExt = "ChainRulesCore" + [compat] ChainRulesCore = "1" julia = "^1.0" [extras] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [targets] -test = ["ChainRulesTestUtils", "Random", "Test", "Unitful"] +test = ["ChainRulesCore", "ChainRulesTestUtils", "Random", "Test", "Unitful"] diff --git a/src/chainrules.jl b/ext/AbstractFFTsChainRulesCoreExt.jl similarity index 96% rename from src/chainrules.jl rename to ext/AbstractFFTsChainRulesCoreExt.jl index 97d4d22..f0c788e 100644 --- a/src/chainrules.jl +++ b/ext/AbstractFFTsChainRulesCoreExt.jl @@ -1,4 +1,8 @@ -# ffts +module AbstractFFTsChainRulesCoreExt + +using AbstractFFTs +import ChainRulesCore + function ChainRulesCore.frule((_, Δx, _), ::typeof(fft), x::AbstractArray, dims) y = fft(x, dims) Δy = fft(Δx, dims) @@ -46,7 +50,7 @@ function ChainRulesCore.frule((_, Δx, _), ::typeof(ifft), x::AbstractArray, dim end function ChainRulesCore.rrule(::typeof(ifft), x::AbstractArray, dims) y = ifft(x, dims) - invN = normalization(y, dims) + invN = AbstractFFTs.normalization(y, dims) project_x = ChainRulesCore.ProjectTo(x) function ifft_pullback(ȳ) x̄ = project_x(invN .* fft(ChainRulesCore.unthunk(ȳ), dims)) @@ -66,7 +70,7 @@ function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims) # compute scaling factors halfdim = first(dims) n = size(x, halfdim) - invN = normalization(y, dims) + invN = AbstractFFTs.normalization(y, dims) twoinvN = 2 * invN scale = reshape( [i == 1 || (i == n && 2 * (i - 1) == d) ? invN : twoinvN for i in 1:n], @@ -150,3 +154,5 @@ function ChainRulesCore.rrule(::typeof(ifftshift), x::AbstractArray, dims) end return y, ifftshift_pullback end + +end # module diff --git a/src/AbstractFFTs.jl b/src/AbstractFFTs.jl index 56d7123..00f6dc2 100644 --- a/src/AbstractFFTs.jl +++ b/src/AbstractFFTs.jl @@ -1,13 +1,14 @@ module AbstractFFTs -import ChainRulesCore - export fft, ifft, bfft, fft!, ifft!, bfft!, plan_fft, plan_ifft, plan_bfft, plan_fft!, plan_ifft!, plan_bfft!, rfft, irfft, brfft, plan_rfft, plan_irfft, plan_brfft, fftdims, fftshift, ifftshift, fftshift!, ifftshift!, Frequencies, fftfreq, rfftfreq include("definitions.jl") -include("chainrules.jl") + +if !isdefined(Base, :get_extension) + include("../ext/AbstractFFTsChainRulesCoreExt.jl") +end end # module