From 53c67ca3dacdbe31449c648069b2663c38730d04 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 28 Feb 2022 13:45:18 +0100 Subject: [PATCH 1/7] Add `region(::Plan)` for accessing transformed region --- Project.toml | 2 +- src/definitions.jl | 14 ++++++++++++++ test/runtests.jl | 8 +++++++- 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 173bd5c..9dcaad1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "AbstractFFTs" uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" -version = "1.1.0" +version = "1.2.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/definitions.jl b/src/definitions.jl index 80a6656..0900ace 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -15,6 +15,18 @@ size(p::Plan, d) = size(p)[d] ndims(p::Plan) = length(size(p)) length(p::Plan) = prod(size(p))::Int +""" + region(p::Plan) + +Return an iterable of the dimensions that are transformed by the FFT plan `p`. + +# Implementation + +The default definition of `region` returns `p.region`. +Hence this method should be implemented only for types of `Plan`s that do not store the transformed region in a field of name `region`. +""" +region(p::Plan) = p.region + fftfloat(x) = _fftfloat(float(x)) _fftfloat(::Type{T}) where {T<:BlasReal} = T _fftfloat(::Type{Float16}) = Float32 @@ -243,6 +255,8 @@ ScaledPlan(p::ScaledPlan, α::Number) = ScaledPlan(p.p, p.scale * α) size(p::ScaledPlan) = size(p.p) +region(p::ScaledPlan) = region(p.p) + show(io::IO, p::ScaledPlan) = print(io, p.scale, " * ", p.p) summary(p::ScaledPlan) = string(p.scale, " * ", summary(p.p)) diff --git a/test/runtests.jl b/test/runtests.jl index de0304d..c0e5752 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -60,18 +60,21 @@ end @test eltype(P) === ComplexF64 @test P * x ≈ fftw_fft @test P \ (P * x) ≈ x + @test AbstractFFTs.region(P) == dims fftw_bfft = complex.(size(x, dims) .* x) @test AbstractFFTs.bfft(y, dims) ≈ fftw_bfft P = plan_bfft(x, dims) @test P * y ≈ fftw_bfft @test P \ (P * y) ≈ y + @test AbstractFFTs.region(P) == dims fftw_ifft = complex.(x) @test AbstractFFTs.ifft(y, dims) ≈ fftw_ifft P = plan_ifft(x, dims) @test P * y ≈ fftw_ifft @test P \ (P * y) ≈ y + @test AbstractFFTs.region(P) == dims # real FFT fftw_rfft = fftw_fft[ @@ -84,18 +87,21 @@ end @test eltype(P) === Int @test P * x ≈ fftw_rfft @test P \ (P * x) ≈ x + @test AbstractFFTs.region(P) == dims fftw_brfft = complex.(size(x, dims) .* x) @test AbstractFFTs.brfft(ry, size(x, dims), dims) ≈ fftw_brfft P = plan_brfft(ry, size(x, dims), dims) @test P * ry ≈ fftw_brfft @test P \ (P * ry) ≈ ry + @test AbstractFFTs.region(P) == dims fftw_irfft = complex.(x) @test AbstractFFTs.irfft(ry, size(x, dims), dims) ≈ fftw_irfft P = plan_irfft(ry, size(x, dims), dims) @test P * ry ≈ fftw_irfft @test P \ (P * ry) ≈ ry + @test AbstractFFTs.region(P) == dims end end @@ -170,7 +176,7 @@ end # normalization should be inferable even if region is only inferred as ::Any, # need to wrap in another function to test this (note that p.region::Any for # p::TestPlan) - f9(p::Plan{T}, sz) where {T} = AbstractFFTs.normalization(real(T), sz, p.region) + f9(p::Plan{T}, sz) where {T} = AbstractFFTs.normalization(real(T), sz, region(p)) @test @inferred(f9(plan_fft(zeros(10), 1), 10)) == 1/10 end From 5fd7d337071c3f1f9b1fa18954dbf543d511c4a6 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 28 Feb 2022 13:57:37 +0100 Subject: [PATCH 2/7] Update test/runtests.jl --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index c0e5752..1541770 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -176,7 +176,7 @@ end # normalization should be inferable even if region is only inferred as ::Any, # need to wrap in another function to test this (note that p.region::Any for # p::TestPlan) - f9(p::Plan{T}, sz) where {T} = AbstractFFTs.normalization(real(T), sz, region(p)) + f9(p::Plan{T}, sz) where {T} = AbstractFFTs.normalization(real(T), sz, AbstractFFTs.region(p)) @test @inferred(f9(plan_fft(zeros(10), 1), 10)) == 1/10 end From 68345ddf305cb6195f2fe7835a08cc6eac1bfb8d Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 28 Feb 2022 13:59:25 +0100 Subject: [PATCH 3/7] Update documentation --- docs/src/api.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/src/api.md b/docs/src/api.md index c9b7b98..fd3e173 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -19,6 +19,7 @@ AbstractFFTs.brfft AbstractFFTs.plan_rfft AbstractFFTs.plan_brfft AbstractFFTs.plan_irfft +AbstractFFTs.region AbstractFFTs.fftshift AbstractFFTs.ifftshift AbstractFFTs.fftfreq From cd29d0d2b977957582752e4addeef957d42d30d7 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 28 Feb 2022 15:51:55 +0100 Subject: [PATCH 4/7] Export region --- src/AbstractFFTs.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/AbstractFFTs.jl b/src/AbstractFFTs.jl index 7b0902f..145c652 100644 --- a/src/AbstractFFTs.jl +++ b/src/AbstractFFTs.jl @@ -5,7 +5,7 @@ import ChainRulesCore export fft, ifft, bfft, fft!, ifft!, bfft!, plan_fft, plan_ifft, plan_bfft, plan_fft!, plan_ifft!, plan_bfft!, rfft, irfft, brfft, plan_rfft, plan_irfft, plan_brfft, - fftshift, ifftshift, Frequencies, fftfreq, rfftfreq + region, fftshift, ifftshift, Frequencies, fftfreq, rfftfreq include("definitions.jl") include("chainrules.jl") From a70ff8709fea5f9bacedfe92814daf1347c5c718 Mon Sep 17 00:00:00 2001 From: "Steven G. Johnson" Date: Thu, 30 Jun 2022 11:05:02 -0400 Subject: [PATCH 5/7] note region in README for devs --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 89f7d48..4fed05a 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ To define a new FFT implementation in your own module, you should inverse plan. * Define a new method `AbstractFFTs.plan_fft(x, region; kws...)` that returns a `MyPlan` for at least some types of - `x` and some set of dimensions `region`. + `x` and some set of dimensions `region`. The `region` (or a copy thereof) should be accessible via `region(p::MyPlan)` (which defaults to `p.region`). * Define a method of `LinearAlgebra.mul!(y, p::MyPlan, x)` (or `A_mul_B!(y, p::MyPlan, x)` on Julia prior to 0.7.0-DEV.3204) that computes the transform `p` of `x` and stores the result in `y`. From b5d3920fb6575c808d1a5d14e43e6bcc1a63164a Mon Sep 17 00:00:00 2001 From: "Steven G. Johnson" Date: Thu, 30 Jun 2022 11:09:02 -0400 Subject: [PATCH 6/7] don't export region --- src/AbstractFFTs.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/AbstractFFTs.jl b/src/AbstractFFTs.jl index 4315132..734c7d4 100644 --- a/src/AbstractFFTs.jl +++ b/src/AbstractFFTs.jl @@ -5,7 +5,7 @@ import ChainRulesCore export fft, ifft, bfft, fft!, ifft!, bfft!, plan_fft, plan_ifft, plan_bfft, plan_fft!, plan_ifft!, plan_bfft!, rfft, irfft, brfft, plan_rfft, plan_irfft, plan_brfft, - region, fftshift, ifftshift, fftshift!, ifftshift!, Frequencies, fftfreq, rfftfreq + fftshift, ifftshift, fftshift!, ifftshift!, Frequencies, fftfreq, rfftfreq include("definitions.jl") include("chainrules.jl") From 8c4dcd9e73d956439c4a3296de2fb3efcd603127 Mon Sep 17 00:00:00 2001 From: "Steven G. Johnson" Date: Thu, 30 Jun 2022 11:16:17 -0400 Subject: [PATCH 7/7] region(p) -> fftdims(p) --- README.md | 2 +- docs/src/api.md | 2 +- src/AbstractFFTs.jl | 2 +- src/definitions.jl | 10 +++++----- test/runtests.jl | 16 ++++++++-------- 5 files changed, 16 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 4fed05a..5b33c59 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ To define a new FFT implementation in your own module, you should inverse plan. * Define a new method `AbstractFFTs.plan_fft(x, region; kws...)` that returns a `MyPlan` for at least some types of - `x` and some set of dimensions `region`. The `region` (or a copy thereof) should be accessible via `region(p::MyPlan)` (which defaults to `p.region`). + `x` and some set of dimensions `region`. The `region` (or a copy thereof) should be accessible via `fftdims(p::MyPlan)` (which defaults to `p.region`). * Define a method of `LinearAlgebra.mul!(y, p::MyPlan, x)` (or `A_mul_B!(y, p::MyPlan, x)` on Julia prior to 0.7.0-DEV.3204) that computes the transform `p` of `x` and stores the result in `y`. diff --git a/docs/src/api.md b/docs/src/api.md index fd3e173..1ed416b 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -19,7 +19,7 @@ AbstractFFTs.brfft AbstractFFTs.plan_rfft AbstractFFTs.plan_brfft AbstractFFTs.plan_irfft -AbstractFFTs.region +AbstractFFTs.fftdims AbstractFFTs.fftshift AbstractFFTs.ifftshift AbstractFFTs.fftfreq diff --git a/src/AbstractFFTs.jl b/src/AbstractFFTs.jl index 734c7d4..56d7123 100644 --- a/src/AbstractFFTs.jl +++ b/src/AbstractFFTs.jl @@ -5,7 +5,7 @@ import ChainRulesCore export fft, ifft, bfft, fft!, ifft!, bfft!, plan_fft, plan_ifft, plan_bfft, plan_fft!, plan_ifft!, plan_bfft!, rfft, irfft, brfft, plan_rfft, plan_irfft, plan_brfft, - fftshift, ifftshift, fftshift!, ifftshift!, Frequencies, fftfreq, rfftfreq + fftdims, fftshift, ifftshift, fftshift!, ifftshift!, Frequencies, fftfreq, rfftfreq include("definitions.jl") include("chainrules.jl") diff --git a/src/definitions.jl b/src/definitions.jl index 642cf53..7901966 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -16,16 +16,16 @@ ndims(p::Plan) = length(size(p)) length(p::Plan) = prod(size(p))::Int """ - region(p::Plan) + fftdims(p::Plan) Return an iterable of the dimensions that are transformed by the FFT plan `p`. # Implementation -The default definition of `region` returns `p.region`. -Hence this method should be implemented only for types of `Plan`s that do not store the transformed region in a field of name `region`. +For legacy reasons, the default definition of `fftdims` returns `p.region`. +Hence this method should be implemented only for `Plan` subtypes that do not store the transformed dimensions in a field named `region`. """ -region(p::Plan) = p.region +fftdims(p::Plan) = p.region fftfloat(x) = _fftfloat(float(x)) _fftfloat(::Type{T}) where {T<:BlasReal} = T @@ -255,7 +255,7 @@ ScaledPlan(p::ScaledPlan, α::Number) = ScaledPlan(p.p, p.scale * α) size(p::ScaledPlan) = size(p.p) -region(p::ScaledPlan) = region(p.p) +fftdims(p::ScaledPlan) = fftdims(p.p) show(io::IO, p::ScaledPlan) = print(io, p.scale, " * ", p.p) summary(p::ScaledPlan) = string(p.scale, " * ", summary(p.p)) diff --git a/test/runtests.jl b/test/runtests.jl index 9a13516..623d625 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -60,21 +60,21 @@ end @test eltype(P) === ComplexF64 @test P * x ≈ fftw_fft @test P \ (P * x) ≈ x - @test AbstractFFTs.region(P) == dims + @test fftdims(P) == dims fftw_bfft = complex.(size(x, dims) .* x) @test AbstractFFTs.bfft(y, dims) ≈ fftw_bfft P = plan_bfft(x, dims) @test P * y ≈ fftw_bfft @test P \ (P * y) ≈ y - @test AbstractFFTs.region(P) == dims + @test fftdims(P) == dims fftw_ifft = complex.(x) @test AbstractFFTs.ifft(y, dims) ≈ fftw_ifft P = plan_ifft(x, dims) @test P * y ≈ fftw_ifft @test P \ (P * y) ≈ y - @test AbstractFFTs.region(P) == dims + @test fftdims(P) == dims # real FFT fftw_rfft = fftw_fft[ @@ -87,21 +87,21 @@ end @test eltype(P) === Int @test P * x ≈ fftw_rfft @test P \ (P * x) ≈ x - @test AbstractFFTs.region(P) == dims + @test fftdims(P) == dims fftw_brfft = complex.(size(x, dims) .* x) @test AbstractFFTs.brfft(ry, size(x, dims), dims) ≈ fftw_brfft P = plan_brfft(ry, size(x, dims), dims) @test P * ry ≈ fftw_brfft @test P \ (P * ry) ≈ ry - @test AbstractFFTs.region(P) == dims - + @test fftdims(P) == dims + fftw_irfft = complex.(x) @test AbstractFFTs.irfft(ry, size(x, dims), dims) ≈ fftw_irfft P = plan_irfft(ry, size(x, dims), dims) @test P * ry ≈ fftw_irfft @test P \ (P * ry) ≈ ry - @test AbstractFFTs.region(P) == dims + @test fftdims(P) == dims end end @@ -193,7 +193,7 @@ end # normalization should be inferable even if region is only inferred as ::Any, # need to wrap in another function to test this (note that p.region::Any for # p::TestPlan) - f9(p::Plan{T}, sz) where {T} = AbstractFFTs.normalization(real(T), sz, AbstractFFTs.region(p)) + f9(p::Plan{T}, sz) where {T} = AbstractFFTs.normalization(real(T), sz, fftdims(p)) @test @inferred(f9(plan_fft(zeros(10), 1), 10)) == 1/10 end