Skip to content

Commit

Permalink
Merge pull request #65 from JuliaMath/dw/region
Browse files Browse the repository at this point in the history
Add `region(::Plan)` for accessing transformed region
  • Loading branch information
stevengj authored Jun 30, 2022
2 parents 4733cd1 + 8c4dcd9 commit 3e7d412
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "AbstractFFTs"
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
version = "1.1.0"
version = "1.2.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ To define a new FFT implementation in your own module, you should
inverse plan.

* Define a new method `AbstractFFTs.plan_fft(x, region; kws...)` that returns a `MyPlan` for at least some types of
`x` and some set of dimensions `region`.
`x` and some set of dimensions `region`. The `region` (or a copy thereof) should be accessible via `fftdims(p::MyPlan)` (which defaults to `p.region`).

* Define a method of `LinearAlgebra.mul!(y, p::MyPlan, x)` (or `A_mul_B!(y, p::MyPlan, x)` on Julia prior to
0.7.0-DEV.3204) that computes the transform `p` of `x` and stores the result in `y`.
Expand Down
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ AbstractFFTs.brfft
AbstractFFTs.plan_rfft
AbstractFFTs.plan_brfft
AbstractFFTs.plan_irfft
AbstractFFTs.fftdims
AbstractFFTs.fftshift
AbstractFFTs.fftshift!
AbstractFFTs.ifftshift
Expand Down
2 changes: 1 addition & 1 deletion src/AbstractFFTs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import ChainRulesCore
export fft, ifft, bfft, fft!, ifft!, bfft!,
plan_fft, plan_ifft, plan_bfft, plan_fft!, plan_ifft!, plan_bfft!,
rfft, irfft, brfft, plan_rfft, plan_irfft, plan_brfft,
fftshift, ifftshift, fftshift!, ifftshift!, Frequencies, fftfreq, rfftfreq
fftdims, fftshift, ifftshift, fftshift!, ifftshift!, Frequencies, fftfreq, rfftfreq

include("definitions.jl")
include("chainrules.jl")
Expand Down
14 changes: 14 additions & 0 deletions src/definitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,18 @@ size(p::Plan, d) = size(p)[d]
ndims(p::Plan) = length(size(p))
length(p::Plan) = prod(size(p))::Int

"""
fftdims(p::Plan)
Return an iterable of the dimensions that are transformed by the FFT plan `p`.
# Implementation
For legacy reasons, the default definition of `fftdims` returns `p.region`.
Hence this method should be implemented only for `Plan` subtypes that do not store the transformed dimensions in a field named `region`.
"""
fftdims(p::Plan) = p.region

fftfloat(x) = _fftfloat(float(x))
_fftfloat(::Type{T}) where {T<:BlasReal} = T
_fftfloat(::Type{Float16}) = Float32
Expand Down Expand Up @@ -243,6 +255,8 @@ ScaledPlan(p::ScaledPlan, α::Number) = ScaledPlan(p.p, p.scale * α)

size(p::ScaledPlan) = size(p.p)

fftdims(p::ScaledPlan) = fftdims(p.p)

show(io::IO, p::ScaledPlan) = print(io, p.scale, " * ", p.p)
summary(p::ScaledPlan) = string(p.scale, " * ", summary(p.p))

Expand Down
10 changes: 8 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,21 @@ end
@test eltype(P) === ComplexF64
@test P * x fftw_fft
@test P \ (P * x) x
@test fftdims(P) == dims

fftw_bfft = complex.(size(x, dims) .* x)
@test AbstractFFTs.bfft(y, dims) fftw_bfft
P = plan_bfft(x, dims)
@test P * y fftw_bfft
@test P \ (P * y) y
@test fftdims(P) == dims

fftw_ifft = complex.(x)
@test AbstractFFTs.ifft(y, dims) fftw_ifft
P = plan_ifft(x, dims)
@test P * y fftw_ifft
@test P \ (P * y) y
@test fftdims(P) == dims

# real FFT
fftw_rfft = fftw_fft[
Expand All @@ -84,18 +87,21 @@ end
@test eltype(P) === Int
@test P * x fftw_rfft
@test P \ (P * x) x
@test fftdims(P) == dims

fftw_brfft = complex.(size(x, dims) .* x)
@test AbstractFFTs.brfft(ry, size(x, dims), dims) fftw_brfft
P = plan_brfft(ry, size(x, dims), dims)
@test P * ry fftw_brfft
@test P \ (P * ry) ry

@test fftdims(P) == dims

fftw_irfft = complex.(x)
@test AbstractFFTs.irfft(ry, size(x, dims), dims) fftw_irfft
P = plan_irfft(ry, size(x, dims), dims)
@test P * ry fftw_irfft
@test P \ (P * ry) ry
@test fftdims(P) == dims
end
end

Expand Down Expand Up @@ -187,7 +193,7 @@ end
# normalization should be inferable even if region is only inferred as ::Any,
# need to wrap in another function to test this (note that p.region::Any for
# p::TestPlan)
f9(p::Plan{T}, sz) where {T} = AbstractFFTs.normalization(real(T), sz, p.region)
f9(p::Plan{T}, sz) where {T} = AbstractFFTs.normalization(real(T), sz, fftdims(p))
@test @inferred(f9(plan_fft(zeros(10), 1), 10)) == 1/10
end

Expand Down

2 comments on commit 3e7d412

@devmotion
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/63431

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.2.0 -m "<description of version>" 3e7d412231a84cab1d3fe4ab2ff171b8ca8a0054
git push origin v1.2.0

Please sign in to comment.