diff --git a/Project.toml b/Project.toml index 417cd10..173bd5c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,16 +1,20 @@ name = "AbstractFFTs" uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" -version = "1.0.1" +version = "1.1.0" [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [compat] +ChainRulesCore = "1" julia = "^1.0" [extras] +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [targets] -test = ["Test", "Unitful"] +test = ["ChainRulesTestUtils", "Random", "Test", "Unitful"] diff --git a/src/AbstractFFTs.jl b/src/AbstractFFTs.jl index d31cde3..7b0902f 100644 --- a/src/AbstractFFTs.jl +++ b/src/AbstractFFTs.jl @@ -1,10 +1,13 @@ module AbstractFFTs +import ChainRulesCore + export fft, ifft, bfft, fft!, ifft!, bfft!, plan_fft, plan_ifft, plan_bfft, plan_fft!, plan_ifft!, plan_bfft!, rfft, irfft, brfft, plan_rfft, plan_irfft, plan_brfft, fftshift, ifftshift, Frequencies, fftfreq, rfftfreq include("definitions.jl") +include("chainrules.jl") end # module diff --git a/src/chainrules.jl b/src/chainrules.jl new file mode 100644 index 0000000..97d4d22 --- /dev/null +++ b/src/chainrules.jl @@ -0,0 +1,152 @@ +# ffts +function ChainRulesCore.frule((_, Δx, _), ::typeof(fft), x::AbstractArray, dims) + y = fft(x, dims) + Δy = fft(Δx, dims) + return y, Δy +end +function ChainRulesCore.rrule(::typeof(fft), x::AbstractArray, dims) + y = fft(x, dims) + project_x = ChainRulesCore.ProjectTo(x) + function fft_pullback(ȳ) + x̄ = project_x(bfft(ChainRulesCore.unthunk(ȳ), dims)) + return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() + end + return y, fft_pullback +end + +function ChainRulesCore.frule((_, Δx, _), ::typeof(rfft), x::AbstractArray{<:Real}, dims) + y = rfft(x, dims) + Δy = rfft(Δx, dims) + return y, Δy +end +function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims) + y = rfft(x, dims) + + # compute scaling factors + halfdim = first(dims) + d = size(x, halfdim) + n = size(y, halfdim) + scale = reshape( + [i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n], + ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))), + ) + + project_x = ChainRulesCore.ProjectTo(x) + function rfft_pullback(ȳ) + x̄ = project_x(brfft(ChainRulesCore.unthunk(ȳ) ./ scale, d, dims)) + return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() + end + return y, rfft_pullback +end + +function ChainRulesCore.frule((_, Δx, _), ::typeof(ifft), x::AbstractArray, dims) + y = ifft(x, dims) + Δy = ifft(Δx, dims) + return y, Δy +end +function ChainRulesCore.rrule(::typeof(ifft), x::AbstractArray, dims) + y = ifft(x, dims) + invN = normalization(y, dims) + project_x = ChainRulesCore.ProjectTo(x) + function ifft_pullback(ȳ) + x̄ = project_x(invN .* fft(ChainRulesCore.unthunk(ȳ), dims)) + return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() + end + return y, ifft_pullback +end + +function ChainRulesCore.frule((_, Δx, _, _), ::typeof(irfft), x::AbstractArray, d::Int, dims) + y = irfft(x, d, dims) + Δy = irfft(Δx, d, dims) + return y, Δy +end +function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims) + y = irfft(x, d, dims) + + # compute scaling factors + halfdim = first(dims) + n = size(x, halfdim) + invN = normalization(y, dims) + twoinvN = 2 * invN + scale = reshape( + [i == 1 || (i == n && 2 * (i - 1) == d) ? invN : twoinvN for i in 1:n], + ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))), + ) + + project_x = ChainRulesCore.ProjectTo(x) + function irfft_pullback(ȳ) + x̄ = project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims)) + return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent() + end + return y, irfft_pullback +end + +function ChainRulesCore.frule((_, Δx, _), ::typeof(bfft), x::AbstractArray, dims) + y = bfft(x, dims) + Δy = bfft(Δx, dims) + return y, Δy +end +function ChainRulesCore.rrule(::typeof(bfft), x::AbstractArray, dims) + y = bfft(x, dims) + project_x = ChainRulesCore.ProjectTo(x) + function bfft_pullback(ȳ) + x̄ = project_x(fft(ChainRulesCore.unthunk(ȳ), dims)) + return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() + end + return y, bfft_pullback +end + +function ChainRulesCore.frule((_, Δx, _, _), ::typeof(brfft), x::AbstractArray, d::Int, dims) + y = brfft(x, d, dims) + Δy = brfft(Δx, d, dims) + return y, Δy +end +function ChainRulesCore.rrule(::typeof(brfft), x::AbstractArray, d::Int, dims) + y = brfft(x, d, dims) + + # compute scaling factors + halfdim = first(dims) + n = size(x, halfdim) + scale = reshape( + [i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n], + ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))), + ) + + project_x = ChainRulesCore.ProjectTo(x) + function brfft_pullback(ȳ) + x̄ = project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims)) + return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent() + end + return y, brfft_pullback +end + +# shift functions +function ChainRulesCore.frule((_, Δx, _), ::typeof(fftshift), x::AbstractArray, dims) + y = fftshift(x, dims) + Δy = fftshift(Δx, dims) + return y, Δy +end +function ChainRulesCore.rrule(::typeof(fftshift), x::AbstractArray, dims) + y = fftshift(x, dims) + project_x = ChainRulesCore.ProjectTo(x) + function fftshift_pullback(ȳ) + x̄ = project_x(ifftshift(ChainRulesCore.unthunk(ȳ), dims)) + return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() + end + return y, fftshift_pullback +end + +function ChainRulesCore.frule((_, Δx, _), ::typeof(ifftshift), x::AbstractArray, dims) + y = ifftshift(x, dims) + Δy = ifftshift(Δx, dims) + return y, Δy +end +function ChainRulesCore.rrule(::typeof(ifftshift), x::AbstractArray, dims) + y = ifftshift(x, dims) + project_x = ChainRulesCore.ProjectTo(x) + function ifftshift_pullback(ȳ) + x̄ = project_x(fftshift(ChainRulesCore.unthunk(ȳ), dims)) + return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() + end + return y, ifftshift_pullback +end diff --git a/src/definitions.jl b/src/definitions.jl index e8bda0a..80a6656 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -256,7 +256,7 @@ summary(p::ScaledPlan) = string(p.scale, " * ", summary(p.p)) *(p::Plan, I::UniformScaling) = ScaledPlan(p, I.λ) # Normalization for ifft, given unscaled bfft, is 1/prod(dimensions) -normalization(::Type{T}, sz, region) where T = one(T) / Int(prod([sz...][[region...]]))::Int +normalization(::Type{T}, sz, region) where T = one(T) / Int(prod(sz[r] for r in region))::Int normalization(X, region) = normalization(real(eltype(X)), size(X), region) plan_ifft(x::AbstractArray, region; kws...) = @@ -360,7 +360,7 @@ If `dim` is not given then the signal is shifted along each dimension. fftshift function fftshift(x, dim = 1:ndims(x)) - s = ntuple(d -> d in dim ? div(size(x,d),2) : 0, ndims(x)) + s = ntuple(d -> d in dim ? div(size(x,d),2) : 0, Val(ndims(x))) circshift(x, s) end @@ -380,7 +380,7 @@ If `dim` is not given then the signal is shifted along each dimension. ifftshift function ifftshift(x, dim = 1:ndims(x)) - s = ntuple(d -> d in dim ? -div(size(x,d),2) : 0, ndims(x)) + s = ntuple(d -> d in dim ? -div(size(x,d),2) : 0, Val(ndims(x))) circshift(x, s) end diff --git a/test/runtests.jl b/test/runtests.jl index afa6d0d..de0304d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,11 +2,18 @@ using AbstractFFTs using AbstractFFTs: Plan +using ChainRulesTestUtils + using LinearAlgebra +using Random using Test import Unitful +Random.seed!(1234) + +include("testplans.jl") + @testset "rfft sizes" begin A = rand(11, 10) @test @inferred(AbstractFFTs.rfft_output_size(A, 1)) == (6, 10) @@ -17,80 +24,99 @@ import Unitful @test_throws AssertionError AbstractFFTs.brfft_output_size(A1, 10, 2) end -mutable struct TestPlan{T} <: Plan{T} - region - pinv::Plan{T} - TestPlan{T}(region) where {T} = new{T}(region) -end - -mutable struct InverseTestPlan{T} <: Plan{T} - region - pinv::Plan{T} - InverseTestPlan{T}(region) where {T} = new{T}(region) -end - -AbstractFFTs.plan_fft(x::Vector{T}, region; kwargs...) where {T} = TestPlan{T}(region) -AbstractFFTs.plan_bfft(x::Vector{T}, region; kwargs...) where {T} = InverseTestPlan{T}(region) -AbstractFFTs.plan_inv(p::TestPlan{T}) where {T} = InverseTestPlan{T} - -# Just a helper function since forward and backward are nearly identical -function dft!(y::Vector, x::Vector, sign::Int) - n = length(x) - length(y) == n || throw(DimensionMismatch()) - fill!(y, zero(complex(float(eltype(x))))) - c = sign * 2π / n - @inbounds for j = 0:n-1, k = 0:n-1 - y[k+1] += x[j+1] * cis(c*j*k) - end - return y -end - -mul!(y::Vector, p::TestPlan, x::Vector) = dft!(y, x, -1) -mul!(y::Vector, p::InverseTestPlan, x::Vector) = dft!(y, x, 1) - -Base.:*(p::TestPlan, x::Vector) = mul!(copy(x), p, x) -Base.:*(p::InverseTestPlan, x::Vector) = mul!(copy(x), p, x) - @testset "Custom Plan" begin - x = AbstractFFTs.fft(collect(1:8)) - # Result computed using FFTW - fftw_fft = [36.0 + 0.0im, - -4.0 + 9.65685424949238im, - -4.0 + 4.0im, - -4.0 + 1.6568542494923806im, - -4.0 + 0.0im, - -4.0 - 1.6568542494923806im, - -4.0 - 4.0im, - -4.0 - 9.65685424949238im] - @test x ≈ fftw_fft - - fftw_bfft = [Complex{Float64}(8i, 0) for i in 1:8] - @test AbstractFFTs.bfft(x) ≈ fftw_bfft - - fftw_ifft = [Complex{Float64}(i, 0) for i in 1:8] - @test AbstractFFTs.ifft(x) ≈ fftw_ifft - - @test eltype(plan_fft(collect(1:8))) == Int + # DFT along last dimension, results computed using FFTW + for (x, fftw_fft) in ( + (collect(1:7), + [28.0 + 0.0im, + -3.5 + 7.267824888003178im, + -3.5 + 2.7911568610884143im, + -3.5 + 0.7988521603655248im, + -3.5 - 0.7988521603655248im, + -3.5 - 2.7911568610884143im, + -3.5 - 7.267824888003178im]), + (collect(1:8), + [36.0 + 0.0im, + -4.0 + 9.65685424949238im, + -4.0 + 4.0im, + -4.0 + 1.6568542494923806im, + -4.0 + 0.0im, + -4.0 - 1.6568542494923806im, + -4.0 - 4.0im, + -4.0 - 9.65685424949238im]), + (collect(reshape(1:8, 2, 4)), + [16.0+0.0im -4.0+4.0im -4.0+0.0im -4.0-4.0im; + 20.0+0.0im -4.0+4.0im -4.0+0.0im -4.0-4.0im]), + (collect(reshape(1:9, 3, 3)), + [12.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im; + 15.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im; + 18.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im]), + ) + # FFT + dims = ndims(x) + y = AbstractFFTs.fft(x, dims) + @test y ≈ fftw_fft + P = plan_fft(x, dims) + @test eltype(P) === ComplexF64 + @test P * x ≈ fftw_fft + @test P \ (P * x) ≈ x + + fftw_bfft = complex.(size(x, dims) .* x) + @test AbstractFFTs.bfft(y, dims) ≈ fftw_bfft + P = plan_bfft(x, dims) + @test P * y ≈ fftw_bfft + @test P \ (P * y) ≈ y + + fftw_ifft = complex.(x) + @test AbstractFFTs.ifft(y, dims) ≈ fftw_ifft + P = plan_ifft(x, dims) + @test P * y ≈ fftw_ifft + @test P \ (P * y) ≈ y + + # real FFT + fftw_rfft = fftw_fft[ + (Colon() for _ in 1:(ndims(fftw_fft) - 1))..., + 1:(size(fftw_fft, ndims(fftw_fft)) ÷ 2 + 1) + ] + ry = AbstractFFTs.rfft(x, dims) + @test ry ≈ fftw_rfft + P = plan_rfft(x, dims) + @test eltype(P) === Int + @test P * x ≈ fftw_rfft + @test P \ (P * x) ≈ x + + fftw_brfft = complex.(size(x, dims) .* x) + @test AbstractFFTs.brfft(ry, size(x, dims), dims) ≈ fftw_brfft + P = plan_brfft(ry, size(x, dims), dims) + @test P * ry ≈ fftw_brfft + @test P \ (P * ry) ≈ ry + + fftw_irfft = complex.(x) + @test AbstractFFTs.irfft(ry, size(x, dims), dims) ≈ fftw_irfft + P = plan_irfft(ry, size(x, dims), dims) + @test P * ry ≈ fftw_irfft + @test P \ (P * ry) ≈ ry + end end @testset "Shift functions" begin - @test AbstractFFTs.fftshift([1 2 3]) == [3 1 2] - @test AbstractFFTs.fftshift([1, 2, 3]) == [3, 1, 2] - @test AbstractFFTs.fftshift([1 2 3; 4 5 6]) == [6 4 5; 3 1 2] - - @test AbstractFFTs.fftshift([1 2 3; 4 5 6], 1) == [4 5 6; 1 2 3] - @test AbstractFFTs.fftshift([1 2 3; 4 5 6], ()) == [1 2 3; 4 5 6] - @test AbstractFFTs.fftshift([1 2 3; 4 5 6], (1,2)) == [6 4 5; 3 1 2] - @test AbstractFFTs.fftshift([1 2 3; 4 5 6], 1:2) == [6 4 5; 3 1 2] - - @test AbstractFFTs.ifftshift([1 2 3]) == [2 3 1] - @test AbstractFFTs.ifftshift([1, 2, 3]) == [2, 3, 1] - @test AbstractFFTs.ifftshift([1 2 3; 4 5 6]) == [5 6 4; 2 3 1] - - @test AbstractFFTs.ifftshift([1 2 3; 4 5 6], 1) == [4 5 6; 1 2 3] - @test AbstractFFTs.ifftshift([1 2 3; 4 5 6], ()) == [1 2 3; 4 5 6] - @test AbstractFFTs.ifftshift([1 2 3; 4 5 6], (1,2)) == [5 6 4; 2 3 1] - @test AbstractFFTs.ifftshift([1 2 3; 4 5 6], 1:2) == [5 6 4; 2 3 1] + @test @inferred(AbstractFFTs.fftshift([1 2 3])) == [3 1 2] + @test @inferred(AbstractFFTs.fftshift([1, 2, 3])) == [3, 1, 2] + @test @inferred(AbstractFFTs.fftshift([1 2 3; 4 5 6])) == [6 4 5; 3 1 2] + + @test @inferred(AbstractFFTs.fftshift([1 2 3; 4 5 6], 1)) == [4 5 6; 1 2 3] + @test @inferred(AbstractFFTs.fftshift([1 2 3; 4 5 6], ())) == [1 2 3; 4 5 6] + @test @inferred(AbstractFFTs.fftshift([1 2 3; 4 5 6], (1,2))) == [6 4 5; 3 1 2] + @test @inferred(AbstractFFTs.fftshift([1 2 3; 4 5 6], 1:2)) == [6 4 5; 3 1 2] + + @test @inferred(AbstractFFTs.ifftshift([1 2 3])) == [2 3 1] + @test @inferred(AbstractFFTs.ifftshift([1, 2, 3])) == [2, 3, 1] + @test @inferred(AbstractFFTs.ifftshift([1 2 3; 4 5 6])) == [5 6 4; 2 3 1] + + @test @inferred(AbstractFFTs.ifftshift([1 2 3; 4 5 6], 1)) == [4 5 6; 1 2 3] + @test @inferred(AbstractFFTs.ifftshift([1 2 3; 4 5 6], ())) == [1 2 3; 4 5 6] + @test @inferred(AbstractFFTs.ifftshift([1 2 3; 4 5 6], (1,2))) == [5 6 4; 2 3 1] + @test @inferred(AbstractFFTs.ifftshift([1 2 3; 4 5 6], 1:2)) == [5 6 4; 2 3 1] end @testset "FFT Frequencies" begin @@ -147,3 +173,51 @@ end f9(p::Plan{T}, sz) where {T} = AbstractFFTs.normalization(real(T), sz, p.region) @test @inferred(f9(plan_fft(zeros(10), 1), 10)) == 1/10 end + +@testset "ChainRules" begin + @testset "shift functions" begin + for x in (randn(3), randn(3, 4), randn(3, 4, 5)) + for dims in ((), 1, 2, (1,2), 1:2) + any(d > ndims(x) for d in dims) && continue + + # type inference checks of `rrule` fail on old Julia versions + # for higher-dimensional arrays: + # https://github.com/JuliaMath/AbstractFFTs.jl/pull/58#issuecomment-916530016 + check_inferred = ndims(x) < 3 || VERSION >= v"1.6" + + test_frule(AbstractFFTs.fftshift, x, dims) + test_rrule(AbstractFFTs.fftshift, x, dims; check_inferred=check_inferred) + + test_frule(AbstractFFTs.ifftshift, x, dims) + test_rrule(AbstractFFTs.ifftshift, x, dims; check_inferred=check_inferred) + end + end + end + + @testset "fft" begin + for x in (randn(3), randn(3, 4), randn(3, 4, 5)) + N = ndims(x) + complex_x = complex.(x) + for dims in unique((1, 1:N, N)) + for f in (fft, ifft, bfft) + test_frule(f, x, dims) + test_rrule(f, x, dims) + test_frule(f, complex_x, dims) + test_rrule(f, complex_x, dims) + end + + test_frule(rfft, x, dims) + test_rrule(rfft, x, dims) + + for f in (irfft, brfft) + for d in (2 * size(x, first(dims)) - 1, 2 * size(x, first(dims)) - 2) + test_frule(f, x, d, dims) + test_rrule(f, x, d, dims) + test_frule(f, complex_x, d, dims) + test_rrule(f, complex_x, d, dims) + end + end + end + end + end +end diff --git a/test/testplans.jl b/test/testplans.jl new file mode 100644 index 0000000..5949da2 --- /dev/null +++ b/test/testplans.jl @@ -0,0 +1,228 @@ +mutable struct TestPlan{T,N} <: Plan{T} + region + sz::NTuple{N,Int} + pinv::Plan{T} + function TestPlan{T}(region, sz::NTuple{N,Int}) where {T,N} + return new{T,N}(region, sz) + end +end + +mutable struct InverseTestPlan{T,N} <: Plan{T} + region + sz::NTuple{N,Int} + pinv::Plan{T} + function InverseTestPlan{T}(region, sz::NTuple{N,Int}) where {T,N} + return new{T,N}(region, sz) + end +end + +Base.size(p::TestPlan) = p.sz +Base.ndims(::TestPlan{T,N}) where {T,N} = N +Base.size(p::InverseTestPlan) = p.sz +Base.ndims(::InverseTestPlan{T,N}) where {T,N} = N + +function AbstractFFTs.plan_fft(x::AbstractArray{T}, region; kwargs...) where {T} + return TestPlan{T}(region, size(x)) +end +function AbstractFFTs.plan_bfft(x::AbstractArray{T}, region; kwargs...) where {T} + return InverseTestPlan{T}(region, size(x)) +end +function AbstractFFTs.plan_inv(p::TestPlan{T}) where {T} + unscaled_pinv = InverseTestPlan{T}(p.region, p.sz) + unscaled_pinv.pinv = p + pinv = AbstractFFTs.ScaledPlan( + unscaled_pinv, AbstractFFTs.normalization(T, p.sz, p.region), + ) + return pinv +end +function AbstractFFTs.plan_inv(p::InverseTestPlan{T}) where {T} + unscaled_pinv = TestPlan{T}(p.region, p.sz) + unscaled_pinv.pinv = p + pinv = AbstractFFTs.ScaledPlan( + unscaled_pinv, AbstractFFTs.normalization(T, p.sz, p.region), + ) + return pinv +end + +# Just a helper function since forward and backward are nearly identical +# The function does not check if the size of `y` and `x` are compatible, this +# is done in the function where `dft!` is called since the check differs for FFTs +# with complex and real-valued signals +function dft!( + y::AbstractArray{<:Complex,N}, + x::AbstractArray{<:Union{Complex,Real},N}, + dims, + sign::Int +) where {N} + # check that dimensions that are transformed are unique + allunique(dims) || error("dimensions have to be unique") + + T = eltype(y) + # we use `size(x, d)` since for real-valued signals + # `size(y, first(dims)) = size(x, first(dims)) ÷ 2 + 1` + cs = map(d -> T(sign * 2π / size(x, d)), dims) + fill!(y, zero(T)) + for yidx in CartesianIndices(y) + # set of indices of `x` on which `y[yidx]` depends + xindices = CartesianIndices( + ntuple(i -> i in dims ? axes(x, i) : yidx[i]:yidx[i], Val(N)) + ) + for xidx in xindices + y[yidx] += x[xidx] * cis(sum(c * (yidx[d] - 1) * (xidx[d] - 1) for (c, d) in zip(cs, dims))) + end + end + return y +end + +function mul!( + y::AbstractArray{<:Complex,N}, p::TestPlan, x::AbstractArray{<:Union{Complex,Real},N} +) where {N} + size(y) == size(p) == size(x) || throw(DimensionMismatch()) + dft!(y, x, p.region, -1) +end +function mul!( + y::AbstractArray{<:Complex,N}, p::InverseTestPlan, x::AbstractArray{<:Union{Complex,Real},N} +) where {N} + size(y) == size(p) == size(x) || throw(DimensionMismatch()) + dft!(y, x, p.region, 1) +end + +Base.:*(p::TestPlan, x::AbstractArray) = mul!(similar(x, complex(float(eltype(x)))), p, x) +Base.:*(p::InverseTestPlan, x::AbstractArray) = mul!(similar(x, complex(float(eltype(x)))), p, x) + +mutable struct TestRPlan{T,N} <: Plan{T} + region + sz::NTuple{N,Int} + pinv::Plan{T} + TestRPlan{T}(region, sz::NTuple{N,Int}) where {T,N} = new{T,N}(region, sz) +end + +mutable struct InverseTestRPlan{T,N} <: Plan{T} + d::Int + region + sz::NTuple{N,Int} + pinv::Plan{T} + function InverseTestRPlan{T}(d::Int, region, sz::NTuple{N,Int}) where {T,N} + sz[first(region)::Int] == d ÷ 2 + 1 || error("incompatible dimensions") + return new{T,N}(d, region, sz) + end +end + +function AbstractFFTs.plan_rfft(x::AbstractArray{T}, region; kwargs...) where {T} + return TestRPlan{T}(region, size(x)) +end +function AbstractFFTs.plan_brfft(x::AbstractArray{T}, d, region; kwargs...) where {T} + return InverseTestRPlan{T}(d, region, size(x)) +end +function AbstractFFTs.plan_inv(p::TestRPlan{T,N}) where {T,N} + firstdim = first(p.region)::Int + d = p.sz[firstdim] + sz = ntuple(i -> i == firstdim ? d ÷ 2 + 1 : p.sz[i], Val(N)) + unscaled_pinv = InverseTestRPlan{T}(d, p.region, sz) + unscaled_pinv.pinv = p + pinv = AbstractFFTs.ScaledPlan( + unscaled_pinv, AbstractFFTs.normalization(T, p.sz, p.region), + ) + return pinv +end +function AbstractFFTs.plan_inv(p::InverseTestRPlan{T,N}) where {T,N} + firstdim = first(p.region)::Int + sz = ntuple(i -> i == firstdim ? p.d : p.sz[i], Val(N)) + unscaled_pinv = TestRPlan{T}(p.region, sz) + unscaled_pinv.pinv = p + pinv = AbstractFFTs.ScaledPlan( + unscaled_pinv, AbstractFFTs.normalization(T, sz, p.region), + ) + return pinv +end + +Base.size(p::TestRPlan) = p.sz +Base.ndims(::TestRPlan{T,N}) where {T,N} = N +Base.size(p::InverseTestRPlan) = p.sz +Base.ndims(::InverseTestRPlan{T,N}) where {T,N} = N + +function real_invdft!( + y::AbstractArray{<:Real,N}, + x::AbstractArray{<:Union{Complex,Real},N}, + dims, +) where {N} + # check that dimensions that are transformed are unique + allunique(dims) || error("dimensions have to be unique") + + firstdim = first(dims) + size_x_firstdim = size(x, firstdim) + iseven_firstdim = iseven(size(y, firstdim)) + # we do not check that the input corresponds to a real-valued signal + # (i.e., that the first and, if `iseven_firstdim`, the last value in dimension + # `haldim` of `x` are real values) due to numerical inaccuracies + # instead we just use the real part of these entries + + T = eltype(y) + # we use `size(y, d)` since `size(x, first(dims)) = size(y, first(dims)) ÷ 2 + 1` + cs = map(d -> T(2π / size(y, d)), dims) + fill!(y, zero(T)) + for yidx in CartesianIndices(y) + # set of indices of `x` on which `y[yidx]` depends + xindices = CartesianIndices( + ntuple(i -> i in dims ? axes(x, i) : yidx[i]:yidx[i], Val(N)) + ) + for xidx in xindices + coeffimag, coeffreal = sincos( + sum(c * (yidx[d] - 1) * (xidx[d] - 1) for (c, d) in zip(cs, dims)) + ) + + # the first and, if `iseven_firstdim`, the last term of the DFT are scaled + # with 1 instead of 2 and only the real part is used (see note above) + xidx_firstdim = xidx[firstdim] + if xidx_firstdim == 1 || (iseven_firstdim && xidx_firstdim == size_x_firstdim) + y[yidx] += coeffreal * real(x[xidx]) + else + xreal, ximag = reim(x[xidx]) + y[yidx] += 2 * (coeffreal * xreal - coeffimag * ximag) + end + end + end + + return y +end + +to_real!(x::AbstractArray) = map!(real, x, x) + +function Base.:*(p::TestRPlan, x::AbstractArray) + size(p) == size(x) || error("array and plan are not consistent") + + # create output array + firstdim = first(p.region)::Int + d = size(x, firstdim) + firstdim_size = d ÷ 2 + 1 + T = complex(float(eltype(x))) + sz = ntuple(i -> i == firstdim ? firstdim_size : size(x, i), Val(ndims(x))) + y = similar(x, T, sz) + + # compute DFT + dft!(y, x, p.region, -1) + + # we clean the output a bit to make sure that we return real values + # whenever the output is mathematically guaranteed to be a real number + to_real!(selectdim(y, firstdim, 1)) + if iseven(d) + to_real!(selectdim(y, firstdim, firstdim_size)) + end + + return y +end + +function Base.:*(p::InverseTestRPlan, x::AbstractArray) + size(p) == size(x) || error("array and plan are not consistent") + + # create output array + firstdim = first(p.region)::Int + d = p.d + sz = ntuple(i -> i == firstdim ? d : size(x, i), Val(ndims(x))) + y = similar(x, real(float(eltype(x))), sz) + + # compute DFT + real_invdft!(y, x, p.region) + + return y +end