Skip to content

Commit 2d017c9

Browse files
committed
STASH bind
1 parent 56157f9 commit 2d017c9

File tree

3 files changed

+24
-22
lines changed

3 files changed

+24
-22
lines changed

Diff for: src/MeasureBase.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import Base.iterate
2828
import ConstructionBase
2929
using ConstructionBase: constructorof
3030
using IntervalSets
31-
using OneTwoMany: secondarg
31+
using OneTwoMany: firstarg, secondarg
3232

3333
using PrettyPrinting
3434
const Pretty = PrettyPrinting

Diff for: src/combinators/bind.jl

+21-18
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,6 @@ See also [`mbind`](@ref).
3030
function mkernel end
3131
export mkernel
3232

33-
@inline mkernel(f_β::MKernel) = f_β
34-
@inline mkernel(f_β, f_c = secondarg) = _generic_mkernel_impl(f_β, f_c)
35-
36-
@inline _generic_mkernel_impl(f_β, f_c) = MKernel(f_β, f_c)
37-
@inline _generic_mkernel_impl(f_β::MKernel, ::typeof(secondarg)) = f_β
38-
3933

4034
"""
4135
struct MeasureBase.MKernel <: Function
@@ -45,12 +39,20 @@ Represents a generalized monatic transition kernel.
4539
User code should not create instances of `MKernel` directly, but should call
4640
[`mkernel`](@ref) instead.
4741
"""
48-
struct MKernel
49-
f_β::FK
42+
struct MKernel{FT,FC} <: Function
43+
f_β::FT
5044
f_c::FC
5145
end
5246

5347

48+
@inline mkernel(f_β::MKernel) = f_β
49+
@inline mkernel(f_β, f_c = secondarg) = _generic_mkernel_impl(f_β, f_c)
50+
51+
@inline _generic_mkernel_impl(f_β, f_c) = MKernel(f_β, f_c)
52+
@inline _generic_mkernel_impl(f_β::MKernel, ::typeof(secondarg)) = f_β
53+
54+
55+
5456
@doc raw"""
5557
mbind(f_β, α::AbstractMeasure, f_c = OneTwoMany.secondarg)
5658
mbind(f_β::MeasureBase.MKernel, α::AbstractMeasure)
@@ -102,7 +104,7 @@ The measure `α` that went into the bind can be retrieved via
102104
103105
Densities on hierarchical measures can only be evaluated if `ab = f_c(a, b)`
104106
can be unambiguously split into `a` and `b` again, knowing `α`. This is
105-
currently implemented for `f_c` that is either tuple or `=>`/`Pair` (these
107+
currently implemented for `f_c` that is either `tuple` or `=>`/`Pair` (these
106108
work for any combination of variate types), `vcat` (for tuple- or vector-like
107109
variates) and `merge` (`NamedTuple` variates).
108110
[`MeasureBase.split_point(::typeof(f_c), α)`](@ref) can be specialized to
@@ -152,19 +154,20 @@ export mbind
152154

153155
@inline mbind(f_β) = Base.Fix1(mbind, f_β)
154156

155-
@inline mbind(f_k::MKernel, α::AbstractMeasure) = mbind(f_k.f_β, α, f_k.f_c)
156-
157-
#@inline mbind(f_β, α::AbstractMeasure, f_c = getsecond) = _generic_mbind_impl(f_β, α, f_c) --- temporary ---
158-
@inline mbind(f_β, α::AbstractMeasure, f_c = secondarg) = _generic_mbind_impl(f_β, α, f_c)
157+
@inline mbind(f_β, α::AbstractMeasure, f_c = secondarg) = _generic_mbind_impl(f_β, asmeasure(α), f_c)
159158

160159
@inline function _generic_mbind_impl(f_β, α::AbstractMeasure, f_c)
161160
F, M, G = Core.Typeof(f_β), Core.Typeof(α), Core.Typeof(f_c)
162161
Bind{F,M,G}(f_β, α, f_c)
163162
end
164163

165-
function _generic_mbind_impl(f_β, α::Dirac, f_c)
166-
mcombine(f_c, α, f_β.x))
167-
end
164+
@inline _generic_mbind_impl(f_β, α::Dirac, f_c) = mcombine(f_c, α, f_β.x))
165+
166+
@inline _generic_mbind_impl(@nospecialize(f_β), α::AbstractMeasure, ::typeof(firstarg)) = α
167+
@inline _generic_mbind_impl(@nospecialize(f_β), α::Dirac, ::typeof(firstarg)) = α
168+
169+
@inline _generic_mbind_impl(f_k::MKernel, α::AbstractMeasure, ::typeof(secondarg)) = mbind(f_k.f_β, α, f_k.f_c)
170+
@inline _generic_mbind_impl(f_k::MKernel, α::Dirac, ::typeof(secondarg)) = mbind(f_k.f_β, α, f_k.f_c)
168171

169172

170173
"""
@@ -175,8 +178,8 @@ Represents a monatic bind resp. a mbind in general.
175178
User code should not create instances of `Bind` directly, but should call
176179
[`mbind`](@ref) instead.
177180
"""
178-
struct Bind{FK,M<:AbstractMeasure,FC} <: AbstractMeasure
179-
f_β::FK
181+
struct Bind{FT,M<:AbstractMeasure,FC} <: AbstractMeasure
182+
f_β::FT
180183
α::M
181184
f_c::FC
182185
end

Diff for: src/combinators/combined.jl

+2-3
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,8 @@ export mcombine
6262

6363
@inline mcombine(f_c, α::AbstractMeasure, β::AbstractMeasure) = _generic_mcombine_impl_stage1(f_c, α, β)
6464

65-
@inline _generic_mcombine_impl_stage1(::typeof(first), α::AbstractMeasure, β::AbstractMeasure) = α
66-
@inline _generic_mcombine_impl_stage1(::typeof(getsecond), α::AbstractMeasure, β::AbstractMeasure) = β
67-
@inline _generic_mcombine_impl_stage1(::typeof(last), α::AbstractMeasure, β::AbstractMeasure) = β
65+
@inline _generic_mcombine_impl_stage1(::typeof(firstarg), α::AbstractMeasure, β::AbstractMeasure) = α
66+
@inline _generic_mcombine_impl_stage1(::typeof(secondarg), α::AbstractMeasure, β::AbstractMeasure) = β
6867

6968
@inline function _generic_mcombine_impl_stage1(::typeof(tuple), α::AbstractMeasure, β::AbstractMeasure)
7069
productmeasure((α, β))

0 commit comments

Comments
 (0)