Skip to content

Commit

Permalink
Add ChainRules definitions (#58)
Browse files Browse the repository at this point in the history
* Add ChainRules definitions

* Add tests

* Move test plans to separate file

* Fix type inference of `fftshift` and `ifftshift` on Julia 1.0

* Disable type inference checks for `fftshift` and `ifftshift` in old Julia versions

* Bump version
  • Loading branch information
devmotion authored Jan 10, 2022
1 parent d007201 commit 2bae074
Show file tree
Hide file tree
Showing 6 changed files with 536 additions and 75 deletions.
8 changes: 6 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
name = "AbstractFFTs"
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
version = "1.0.1"
version = "1.1.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[compat]
ChainRulesCore = "1"
julia = "^1.0"

[extras]
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[targets]
test = ["Test", "Unitful"]
test = ["ChainRulesTestUtils", "Random", "Test", "Unitful"]
3 changes: 3 additions & 0 deletions src/AbstractFFTs.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
module AbstractFFTs

import ChainRulesCore

export fft, ifft, bfft, fft!, ifft!, bfft!,
plan_fft, plan_ifft, plan_bfft, plan_fft!, plan_ifft!, plan_bfft!,
rfft, irfft, brfft, plan_rfft, plan_irfft, plan_brfft,
fftshift, ifftshift, Frequencies, fftfreq, rfftfreq

include("definitions.jl")
include("chainrules.jl")

end # module
152 changes: 152 additions & 0 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# ffts
function ChainRulesCore.frule((_, Δx, _), ::typeof(fft), x::AbstractArray, dims)
y = fft(x, dims)
Δy = fft(Δx, dims)
return y, Δy
end
function ChainRulesCore.rrule(::typeof(fft), x::AbstractArray, dims)
y = fft(x, dims)
project_x = ChainRulesCore.ProjectTo(x)
function fft_pullback(ȳ)
= project_x(bfft(ChainRulesCore.unthunk(ȳ), dims))
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
end
return y, fft_pullback
end

function ChainRulesCore.frule((_, Δx, _), ::typeof(rfft), x::AbstractArray{<:Real}, dims)
y = rfft(x, dims)
Δy = rfft(Δx, dims)
return y, Δy
end
function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims)
y = rfft(x, dims)

# compute scaling factors
halfdim = first(dims)
d = size(x, halfdim)
n = size(y, halfdim)
scale = reshape(
[i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n],
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))),
)

project_x = ChainRulesCore.ProjectTo(x)
function rfft_pullback(ȳ)
= project_x(brfft(ChainRulesCore.unthunk(ȳ) ./ scale, d, dims))
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
end
return y, rfft_pullback
end

function ChainRulesCore.frule((_, Δx, _), ::typeof(ifft), x::AbstractArray, dims)
y = ifft(x, dims)
Δy = ifft(Δx, dims)
return y, Δy
end
function ChainRulesCore.rrule(::typeof(ifft), x::AbstractArray, dims)
y = ifft(x, dims)
invN = normalization(y, dims)
project_x = ChainRulesCore.ProjectTo(x)
function ifft_pullback(ȳ)
= project_x(invN .* fft(ChainRulesCore.unthunk(ȳ), dims))
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
end
return y, ifft_pullback
end

function ChainRulesCore.frule((_, Δx, _, _), ::typeof(irfft), x::AbstractArray, d::Int, dims)
y = irfft(x, d, dims)
Δy = irfft(Δx, d, dims)
return y, Δy
end
function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims)
y = irfft(x, d, dims)

# compute scaling factors
halfdim = first(dims)
n = size(x, halfdim)
invN = normalization(y, dims)
twoinvN = 2 * invN
scale = reshape(
[i == 1 || (i == n && 2 * (i - 1) == d) ? invN : twoinvN for i in 1:n],
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))),
)

project_x = ChainRulesCore.ProjectTo(x)
function irfft_pullback(ȳ)
= project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims))
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent()
end
return y, irfft_pullback
end

function ChainRulesCore.frule((_, Δx, _), ::typeof(bfft), x::AbstractArray, dims)
y = bfft(x, dims)
Δy = bfft(Δx, dims)
return y, Δy
end
function ChainRulesCore.rrule(::typeof(bfft), x::AbstractArray, dims)
y = bfft(x, dims)
project_x = ChainRulesCore.ProjectTo(x)
function bfft_pullback(ȳ)
= project_x(fft(ChainRulesCore.unthunk(ȳ), dims))
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
end
return y, bfft_pullback
end

function ChainRulesCore.frule((_, Δx, _, _), ::typeof(brfft), x::AbstractArray, d::Int, dims)
y = brfft(x, d, dims)
Δy = brfft(Δx, d, dims)
return y, Δy
end
function ChainRulesCore.rrule(::typeof(brfft), x::AbstractArray, d::Int, dims)
y = brfft(x, d, dims)

# compute scaling factors
halfdim = first(dims)
n = size(x, halfdim)
scale = reshape(
[i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n],
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))),
)

project_x = ChainRulesCore.ProjectTo(x)
function brfft_pullback(ȳ)
= project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims))
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent()
end
return y, brfft_pullback
end

# shift functions
function ChainRulesCore.frule((_, Δx, _), ::typeof(fftshift), x::AbstractArray, dims)
y = fftshift(x, dims)
Δy = fftshift(Δx, dims)
return y, Δy
end
function ChainRulesCore.rrule(::typeof(fftshift), x::AbstractArray, dims)
y = fftshift(x, dims)
project_x = ChainRulesCore.ProjectTo(x)
function fftshift_pullback(ȳ)
= project_x(ifftshift(ChainRulesCore.unthunk(ȳ), dims))
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
end
return y, fftshift_pullback
end

function ChainRulesCore.frule((_, Δx, _), ::typeof(ifftshift), x::AbstractArray, dims)
y = ifftshift(x, dims)
Δy = ifftshift(Δx, dims)
return y, Δy
end
function ChainRulesCore.rrule(::typeof(ifftshift), x::AbstractArray, dims)
y = ifftshift(x, dims)
project_x = ChainRulesCore.ProjectTo(x)
function ifftshift_pullback(ȳ)
= project_x(fftshift(ChainRulesCore.unthunk(ȳ), dims))
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
end
return y, ifftshift_pullback
end
6 changes: 3 additions & 3 deletions src/definitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ summary(p::ScaledPlan) = string(p.scale, " * ", summary(p.p))
*(p::Plan, I::UniformScaling) = ScaledPlan(p, I.λ)

# Normalization for ifft, given unscaled bfft, is 1/prod(dimensions)
normalization(::Type{T}, sz, region) where T = one(T) / Int(prod([sz...][[region...]]))::Int
normalization(::Type{T}, sz, region) where T = one(T) / Int(prod(sz[r] for r in region))::Int
normalization(X, region) = normalization(real(eltype(X)), size(X), region)

plan_ifft(x::AbstractArray, region; kws...) =
Expand Down Expand Up @@ -360,7 +360,7 @@ If `dim` is not given then the signal is shifted along each dimension.
fftshift

function fftshift(x, dim = 1:ndims(x))
s = ntuple(d -> d in dim ? div(size(x,d),2) : 0, ndims(x))
s = ntuple(d -> d in dim ? div(size(x,d),2) : 0, Val(ndims(x)))
circshift(x, s)
end

Expand All @@ -380,7 +380,7 @@ If `dim` is not given then the signal is shifted along each dimension.
ifftshift

function ifftshift(x, dim = 1:ndims(x))
s = ntuple(d -> d in dim ? -div(size(x,d),2) : 0, ndims(x))
s = ntuple(d -> d in dim ? -div(size(x,d),2) : 0, Val(ndims(x)))
circshift(x, s)
end

Expand Down
Loading

2 comments on commit 2bae074

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/52087

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.1.0 -m "<description of version>" 2bae07440ec4e069ddb9d086fc5c0f9a964c2ac2
git push origin v1.1.0

Please sign in to comment.