|
| 1 | +// Copyright (C) 2019 Piotr (Peter) Beben <[email protected]> |
| 2 | +// See LICENSE included. |
| 3 | + |
| 4 | +#define EIGEN_NO_MALLOC |
| 5 | + |
| 6 | +#include "OrthogonalPursuit.h" |
| 7 | +#include "ensure_buffer_size.h" |
| 8 | +#include "constants.h" |
| 9 | +//#include <functional> |
| 10 | + |
| 11 | +using Eigen::Matrix; |
| 12 | +using Eigen::Map; |
| 13 | +using Eigen::Dynamic; |
| 14 | +using Eigen::LDLT; |
| 15 | +using Eigen::Aligned16; |
| 16 | + |
| 17 | + |
| 18 | + |
| 19 | +//------------------------------------------------------------------------- |
| 20 | +void OrthogonalPursuit::ensure(Index nd, Index na, Index lm) |
| 21 | +{ |
| 22 | + if(nd != ndim || na != natm || lm != lmax){ |
| 23 | + ndim = nd; |
| 24 | + natm = na; |
| 25 | + lmax = (na < lm) ? na : lm; |
| 26 | + ensureWorkspace(); |
| 27 | + } |
| 28 | +} |
| 29 | + |
| 30 | +//------------------------------------------------------------------------- |
| 31 | + |
| 32 | +void OrthogonalPursuit::ensureWorkspace() |
| 33 | +{ |
| 34 | + // Allocate more workspace as necessary. |
| 35 | + //std::function< size_t(Index) > align_padded = |
| 36 | + // [=](Index n) ->size_t { return ALIGNEDX*(1+(n/ALIGNEDX)); }; |
| 37 | + size_t paddednd = align_padded(ndim); |
| 38 | + size_t paddedlm = align_padded(lmax); |
| 39 | + size_t paddedndlm = align_padded(ndim*lmax); |
| 40 | + size_t paddedlmsq = align_padded(lmax*lmax); |
| 41 | + |
| 42 | + size_t ndworkNeed = paddednd + 3*paddedlm + 2*paddedndlm + 2*paddedlmsq; |
| 43 | + ensure_buffer_size(ndworkNeed+ndworkNeed/2, dwork); |
| 44 | + |
| 45 | + // Map to pre-allocated workspace. |
| 46 | + size_t p = 0; |
| 47 | + new (&U) MapVectf(&dwork[0],ndim); |
| 48 | + p = p + paddednd; |
| 49 | + new (&V) MapVectf(&dwork[p],lmax); |
| 50 | + p = p + paddedlm; |
| 51 | + new (&W) MapVectf(&dwork[p],lmax); |
| 52 | + p = p + paddedlm; |
| 53 | + new (&XI) MapVectf(&dwork[p],lmax); |
| 54 | + p = p + paddedlm; |
| 55 | + new (&E) MapMtrxf(&dwork[p],ndim,lmax); |
| 56 | + p = p + paddedndlm; |
| 57 | + new (&F) MapMtrxf(&dwork[p],lmax,lmax); |
| 58 | + p = p + paddedlmsq; |
| 59 | + dworkOffset = p; |
| 60 | + |
| 61 | + size_t niworkNeed = lmax; |
| 62 | + ensure_buffer_size(niworkNeed+niworkNeed/2, iwork); |
| 63 | + |
| 64 | + // Map to pre-allocated workspace. |
| 65 | + new (&I) Map<Matrix<Index,Dynamic,1>>(&iwork[0],lmax); |
| 66 | + |
| 67 | + if( lmax*lmax > ldltSize ){ |
| 68 | + ldltSize = lmax*lmax; |
| 69 | + delete ldlt; |
| 70 | + ldlt = new LDLT<MatrixXf>(ldltSize); |
| 71 | + } |
| 72 | +} |
| 73 | + |
| 74 | +//------------------------------------------------------------------------- |
| 75 | +/** |
| 76 | + Orthogonal Matching Pursuit: |
| 77 | +
|
| 78 | + Similar to matching pursuit, but provides a better approximation |
| 79 | + at the expense of significantly more computation. Namely, updates |
| 80 | + all the coefficients of the current code vector X' at each iteration |
| 81 | + so that DX' is an orthogonal projection of the signal vector Y onto |
| 82 | + the subspace spanned by the dictionary atoms corresponding to the |
| 83 | + nonzero entries of X'. |
| 84 | +
|
| 85 | + @param[in] Y: size n vector. |
| 86 | + @param[in] D: n x m dictionary matrix of unit column vectors. |
| 87 | + @param[in] latm: Sparsity constraint. |
| 88 | + @param[out] X: size m code vector. |
| 89 | + @param[out] R: size n residual vector. |
| 90 | +
|
| 91 | +*/ |
| 92 | + |
| 93 | + |
| 94 | + |
| 95 | +void OrthogonalPursuit::operator() ( |
| 96 | + const VectorXf& Y, const MatrixXf& D, Index latm, |
| 97 | + VectorXf& X, VectorXf& R) |
| 98 | +{ |
| 99 | + assert(D.rows() == Y.rows()); |
| 100 | + assert(D.cols() == X.rows()); |
| 101 | + ensure(D.rows(),D.cols(),latm); |
| 102 | + |
| 103 | + X.setZero(); |
| 104 | + R = Y; |
| 105 | + |
| 106 | + for(Index j = 1; j <= latm; ++j){ |
| 107 | + // Find the next 'nearest' atom to current residual R. |
| 108 | + float absprojmax = -float_infinity; |
| 109 | + float projmax = 0; |
| 110 | + Index imax = 0; |
| 111 | + for(Index i = 0; i < natm; ++i){ |
| 112 | + if( X(i) != 0.0f ) continue; |
| 113 | + float proj = R.dot(D.col(i)); |
| 114 | + float absproj = abs(proj); |
| 115 | + if( absproj > absprojmax ){ |
| 116 | + projmax = proj; |
| 117 | + absprojmax = absproj; |
| 118 | + imax = i; |
| 119 | + } |
| 120 | + } |
| 121 | + U = D.col(imax); // Dictionary atom U 'nearest' to R |
| 122 | + E.col(j-1) = U; // ...save it in j^th column of E |
| 123 | + I(j-1) = imax; // ...and save column index of U |
| 124 | + X(imax) = 1.0f; // Set temporarily 1.0 to mark traversed. |
| 125 | + |
| 126 | + // Map to pre-allocated workspace. |
| 127 | + Index p = dworkOffset; |
| 128 | + new (&ETblk) MapMtrxf(&dwork[p],j,ndim); |
| 129 | + p = p + align_padded(j*ndim); |
| 130 | + new (&Fblk) MapMtrxf(&dwork[p],j,j); |
| 131 | + |
| 132 | + // With U added to the current set E of j nearest atoms, |
| 133 | + // optimise the coefficients of XI w.r.t this E. This is |
| 134 | + // done by projecting Y onto the subspace spanned by E. |
| 135 | + if( j > 1 ) { |
| 136 | + // Compute the product E^T(:,1:j) * E(:,1:j), |
| 137 | + // This can be done quicker by reusing the product |
| 138 | + // E^T(:,1:j-1) * E(:,1:j-1) from the previous |
| 139 | + // iteration. |
| 140 | + |
| 141 | + V.segment(0,j-1).noalias() = U.transpose() * E.block(0,0,ndim,j-1); |
| 142 | + F.col(j-1).segment(0,j-1) = V.segment(0,j-1); |
| 143 | + F.row(j-1).segment(0,j-1) = V.segment(0,j-1); |
| 144 | + F(j-1,j-1) = 1.0f; |
| 145 | + |
| 146 | + Fblk = F.block(0,0,j,j); |
| 147 | + ETblk = E.block(0,0,ndim,j).transpose(); |
| 148 | + W.segment(0,j).noalias() = ETblk * Y; |
| 149 | + // Solve (E^T*E)*XI = (E^T)*Y |
| 150 | + ldlt->compute(Fblk); |
| 151 | + XI.segment(0,j) = ldlt->solve(W.segment(0,j)); |
| 152 | + |
| 153 | + //Update residual R |
| 154 | + R = Y; |
| 155 | + R.noalias() -= (E.block(0,0,ndim,j))*(XI.segment(0,j)); |
| 156 | + } |
| 157 | + else{ |
| 158 | + F(0,0) = 1.0f; |
| 159 | + XI(0) = Y.dot(U); |
| 160 | + //Update residual R |
| 161 | + R = Y; |
| 162 | + R.noalias() -= XI(0)*U; |
| 163 | + } |
| 164 | + } |
| 165 | + |
| 166 | + // Map back to code vector. |
| 167 | + for(Index i = 0; i < latm; ++i){ |
| 168 | + X(I(i)) = XI(i); |
| 169 | + } |
| 170 | + |
| 171 | + |
| 172 | +} |
| 173 | +//------------------------------------------------------------------------- |
0 commit comments