|
1 |
| -using ArrayLayouts: ArrayLayouts, MatMulMatAdd |
| 1 | +using ArrayLayouts: ArrayLayouts, Dot, MatMulMatAdd, MatMulVecAdd, MulAdd |
2 | 2 | using BlockArrays: BlockLayout
|
3 | 3 | using ..SparseArrayInterface: SparseLayout
|
4 |
| -using LinearAlgebra: mul! |
| 4 | +using LinearAlgebra: dot, mul! |
5 | 5 |
|
6 | 6 | function blocksparse_muladd!(
|
7 |
| - α::Number, a1::AbstractMatrix, a2::AbstractMatrix, β::Number, a_dest::AbstractMatrix |
| 7 | + α::Number, a1::AbstractArray, a2::AbstractArray, β::Number, a_dest::AbstractArray |
8 | 8 | )
|
9 | 9 | mul!(blocks(a_dest), blocks(a1), blocks(a2), α, β)
|
10 | 10 | return a_dest
|
11 | 11 | end
|
12 | 12 |
|
| 13 | +function blocksparse_matmul!(m::MulAdd) |
| 14 | + α, a1, a2, β, a_dest = m.α, m.A, m.B, m.β, m.C |
| 15 | + blocksparse_muladd!(α, a1, a2, β, a_dest) |
| 16 | + return a_dest |
| 17 | +end |
| 18 | + |
13 | 19 | function ArrayLayouts.materialize!(
|
14 | 20 | m::MatMulMatAdd{
|
15 | 21 | <:BlockLayout{<:SparseLayout},
|
16 | 22 | <:BlockLayout{<:SparseLayout},
|
17 | 23 | <:BlockLayout{<:SparseLayout},
|
18 | 24 | },
|
19 | 25 | )
|
20 |
| - α, a1, a2, β, a_dest = m.α, m.A, m.B, m.β, m.C |
21 |
| - blocksparse_muladd!(α, a1, a2, β, a_dest) |
22 |
| - return a_dest |
| 26 | + blocksparse_matmul!(m) |
| 27 | + return m.C |
| 28 | +end |
| 29 | +function ArrayLayouts.materialize!( |
| 30 | + m::MatMulVecAdd{ |
| 31 | + <:BlockLayout{<:SparseLayout}, |
| 32 | + <:BlockLayout{<:SparseLayout}, |
| 33 | + <:BlockLayout{<:SparseLayout}, |
| 34 | + }, |
| 35 | +) |
| 36 | + blocksparse_matmul!(m) |
| 37 | + return m.C |
| 38 | +end |
| 39 | + |
| 40 | +function blocksparse_dot(a1::AbstractArray, a2::AbstractArray) |
| 41 | + # TODO: Add a check that the blocking of `a1` and `a2` are |
| 42 | + # the same, or the same up to a reshape. |
| 43 | + return dot(blocks(a1), blocks(a2)) |
| 44 | +end |
| 45 | + |
| 46 | +function Base.copy(d::Dot{<:BlockLayout{<:SparseLayout},<:BlockLayout{<:SparseLayout}}) |
| 47 | + return blocksparse_dot(d.A, d.B) |
23 | 48 | end
|
0 commit comments