diff --git a/src/combinators/power.jl b/src/combinators/power.jl index 1bc17866..37797d85 100644 --- a/src/combinators/power.jl +++ b/src/combinators/power.jl @@ -12,9 +12,14 @@ using FillArrays: Fill export PowerMeasure -struct PowerMeasure{M,A} <: AbstractProductMeasure +struct PowerMeasure{M,N,A} <: AbstractProductMeasure{Fill{M,N,A}} parent::M axes::A + + function PowerMeasure(parent::M, axes::A) where {M,A} + N = length(axes) + new{M,N,A}(parent, axes) + end end function Pretty.tile(μ::PowerMeasure) @@ -42,8 +47,7 @@ end @inline function powermeasure(x::T, sz::Tuple{Vararg{<:Any,N}}) where {T,N} a = axes(Fill{T,N}(x, sz)) - A = typeof(a) - PowerMeasure{T,A}(x, a) + PowerMeasure(x, a) end marginals(d::PowerMeasure) = Fill(d.parent, d.axes) @@ -76,7 +80,14 @@ end end @inline function logdensity_def( - d::PowerMeasure{M,Tuple{Base.OneTo{StaticInt{N}}}}, + d::PowerMeasure{M,1,Tuple{Base.OneTo{StaticInt{0}}}}, + x, +) where {M} + static(0.0) +end + +@inline function logdensity_def( + d::PowerMeasure{M,1,Tuple{Base.OneTo{StaticInt{N}}}}, x, ) where {M,N} parent = d.parent @@ -86,7 +97,7 @@ end end @inline function logdensity_def( - d::PowerMeasure{M,NTuple{N,Base.OneTo{StaticInt{0}}}}, + d::PowerMeasure{M,N,NTuple{N,Base.OneTo{StaticInt{0}}}}, x, ) where {M,N} static(0.0) @@ -108,7 +119,8 @@ end end end -@inline getdof(μ::PowerMeasure) = getdof(μ.parent) * prod(map(length, μ.axes)) +# `prod` isn't static-friendly +@inline getdof(μ::PowerMeasure) = getdof(μ.parent) * (*(map(length, μ.axes)...)) @inline function getdof(::PowerMeasure{<:Any,NTuple{N,Base.OneTo{StaticInt{0}}}}) where {N} static(0) diff --git a/src/combinators/product.jl b/src/combinators/product.jl index cb7a0aaf..bbe96513 100644 --- a/src/combinators/product.jl +++ b/src/combinators/product.jl @@ -8,7 +8,7 @@ using FillArrays export AbstractProductMeasure -abstract type AbstractProductMeasure <: AbstractMeasure end +abstract type AbstractProductMeasure{M} <: AbstractMeasure end function Pretty.tile(μ::AbstractProductMeasure) result = Pretty.literal("ProductMeasure(") @@ -76,7 +76,7 @@ end mapreduce(logdensity_def, +, marginals(d), x) end -struct ProductMeasure{M} <: AbstractProductMeasure +struct ProductMeasure{M} <: AbstractProductMeasure{M} marginals::M end diff --git a/src/standard/stdmeasure.jl b/src/standard/stdmeasure.jl index 2b50c1e5..34d7e4d8 100644 --- a/src/standard/stdmeasure.jl +++ b/src/standard/stdmeasure.jl @@ -17,16 +17,16 @@ function transport_def(ν::PowerMeasure{<:StdMeasure}, μ::StdMeasure, x) end function transport_def( - ν::PowerMeasure{<:StdMeasure,<:NTuple{1,Base.OneTo}}, - μ::PowerMeasure{<:StdMeasure,<:NTuple{1,Base.OneTo}}, + ν::PowerMeasure{<:StdMeasure,1,<:NTuple{1,Base.OneTo}}, + μ::PowerMeasure{<:StdMeasure,1,<:NTuple{1,Base.OneTo}}, x, ) return transport_to(ν.parent, μ.parent).(x) end function transport_def( - ν::PowerMeasure{<:StdMeasure,<:NTuple{N,Base.OneTo}}, - μ::PowerMeasure{<:StdMeasure,<:NTuple{M,Base.OneTo}}, + ν::PowerMeasure{<:StdMeasure,N,<:NTuple{N,Base.OneTo}}, + μ::PowerMeasure{<:StdMeasure,M,<:NTuple{M,Base.OneTo}}, x, ) where {N,M} return reshape(transport_to(ν.parent, μ.parent).(x), map(length, ν.axes)...) @@ -72,7 +72,7 @@ end function transport_def( ν::PowerMeasure{NU}, - μ::ProductMeasure{<:Tuple}, + μ::AbstractProductMeasure{<:Tuple}, x, ) where {NU<:StdMeasure} _tuple_transport_def(ν, marginals(μ), x) @@ -80,7 +80,7 @@ end function transport_def( ν::PowerMeasure{NU}, - μ::ProductMeasure{<:NamedTuple{names}}, + μ::AbstractProductMeasure{<:NamedTuple{names}}, x, ) where {NU<:StdMeasure,names} _tuple_transport_def(ν, values(marginals(μ)), values(x)) @@ -107,7 +107,7 @@ function _tuple_transport_def( end function transport_def( - ν::ProductMeasure{<:Tuple}, + ν::AbstractProductMeasure{<:Tuple}, μ::PowerMeasure{MU}, x, ) where {MU<:StdMeasure} @@ -115,7 +115,7 @@ function transport_def( end function transport_def( - ν::ProductMeasure{<:NamedTuple{names}}, + ν::AbstractProductMeasure{<:NamedTuple{names}}, μ::PowerMeasure{MU}, x, ) where {MU<:StdMeasure,names}