Skip to content

Commit

Permalink
Merge pull request #2931 from stan-dev/feature/2845-tuple-fns
Browse files Browse the repository at this point in the history
Add tuple-returning special functions
  • Loading branch information
WardBrian authored Aug 18, 2023
2 parents d4eab27 + 497dc71 commit 5d1fd38
Show file tree
Hide file tree
Showing 53 changed files with 1,194 additions and 71 deletions.
5 changes: 5 additions & 0 deletions stan/math/prim/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
#include <stan/math/prim/fun/cov_matrix_free.hpp>
#include <stan/math/prim/fun/cov_matrix_free_lkj.hpp>
#include <stan/math/prim/fun/crossprod.hpp>
#include <stan/math/prim/fun/csr_extract.hpp>
#include <stan/math/prim/fun/csr_extract_u.hpp>
#include <stan/math/prim/fun/csr_extract_v.hpp>
#include <stan/math/prim/fun/csr_extract_w.hpp>
Expand All @@ -84,6 +85,8 @@
#include <stan/math/prim/fun/dot_product.hpp>
#include <stan/math/prim/fun/dot_self.hpp>
#include <stan/math/prim/fun/eigen_comparisons.hpp>
#include <stan/math/prim/fun/eigendecompose.hpp>
#include <stan/math/prim/fun/eigendecompose_sym.hpp>
#include <stan/math/prim/fun/eigenvalues.hpp>
#include <stan/math/prim/fun/eigenvalues_sym.hpp>
#include <stan/math/prim/fun/eigenvectors.hpp>
Expand Down Expand Up @@ -277,6 +280,7 @@
#include <stan/math/prim/fun/qr.hpp>
#include <stan/math/prim/fun/qr_Q.hpp>
#include <stan/math/prim/fun/qr_R.hpp>
#include <stan/math/prim/fun/qr_thin.hpp>
#include <stan/math/prim/fun/qr_thin_Q.hpp>
#include <stan/math/prim/fun/qr_thin_R.hpp>
#include <stan/math/prim/fun/quad_form.hpp>
Expand Down Expand Up @@ -333,6 +337,7 @@
#include <stan/math/prim/fun/sub_row.hpp>
#include <stan/math/prim/fun/subtract.hpp>
#include <stan/math/prim/fun/sum.hpp>
#include <stan/math/prim/fun/svd.hpp>
#include <stan/math/prim/fun/svd_U.hpp>
#include <stan/math/prim/fun/svd_V.hpp>
#include <stan/math/prim/fun/symmetrize_from_lower_tri.hpp>
Expand Down
42 changes: 36 additions & 6 deletions stan/math/prim/fun/complex_schur_decompose.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ namespace math {
template <typename M, require_eigen_dense_dynamic_t<M>* = nullptr>
inline Eigen::Matrix<complex_return_t<scalar_type_t<M>>, -1, -1>
complex_schur_decompose_u(const M& m) {
if (m.size() == 0)
if (unlikely(m.size() == 0)) {
return m;
}
check_square("complex_schur_decompose_u", "m", m);
using MatType = Eigen::Matrix<scalar_type_t<M>, -1, -1>;
// copy because ComplexSchur requires Eigen::Matrix type
MatType mv = m;
Eigen::ComplexSchur<MatType> cs(mv);
Eigen::ComplexSchur<MatType> cs{MatType(m)};
return cs.matrixU();
}

Expand All @@ -51,16 +51,46 @@ complex_schur_decompose_u(const M& m) {
template <typename M, require_eigen_dense_dynamic_t<M>* = nullptr>
inline Eigen::Matrix<complex_return_t<scalar_type_t<M>>, -1, -1>
complex_schur_decompose_t(const M& m) {
if (m.size() == 0)
if (unlikely(m.size() == 0)) {
return m;
}
check_square("complex_schur_decompose_t", "m", m);
using MatType = Eigen::Matrix<scalar_type_t<M>, -1, -1>;
// copy because ComplexSchur requires Eigen::Matrix type
MatType mv = m;
Eigen::ComplexSchur<MatType> cs(mv, false);
Eigen::ComplexSchur<MatType> cs{MatType(m), false};
return cs.matrixT();
}

/**
* Return the complex Schur decomposition of the
* specified square matrix.
*
* The complex Schur decomposition of a square matrix `A` produces a
* complex unitary matrix `U` and a complex upper-triangular Schur
* form matrix `T` such that `A = U * T * inv(U)`. Further, the
* unitary matrix's inverse is equal to its conjugate transpose,
* `inv(U) = U*`, where `U*(i, j) = conj(U(j, i))`
*
* @tparam M type of matrix
* @param m real matrix to decompose
* @return a tuple (U,T) where U is the complex unitary matrix of the complex
* Schur decomposition of `m` and T is the Schur form matrix of
* the complex Schur decomposition of `m`
*/
template <typename M, require_eigen_dense_dynamic_t<M>* = nullptr>
inline std::tuple<Eigen::Matrix<complex_return_t<scalar_type_t<M>>, -1, -1>,
Eigen::Matrix<complex_return_t<scalar_type_t<M>>, -1, -1>>
complex_schur_decompose(const M& m) {
if (unlikely(m.size() == 0)) {
return std::make_tuple(m, m);
}
check_square("complex_schur_decompose", "m", m);
using MatType = Eigen::Matrix<scalar_type_t<M>, -1, -1>;
// copy because ComplexSchur requires Eigen::Matrix type
Eigen::ComplexSchur<MatType> cs{MatType(m)};
return std::make_tuple(std::move(cs.matrixU()), std::move(cs.matrixT()));
}

} // namespace math
} // namespace stan
#endif
66 changes: 66 additions & 0 deletions stan/math/prim/fun/csr_extract.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#ifndef STAN_MATH_PRIM_FUN_CSR_EXTRACT_HPP
#define STAN_MATH_PRIM_FUN_CSR_EXTRACT_HPP

#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/fun/to_ref.hpp>

namespace stan {
namespace math {

/** \addtogroup csr_format
* @{
*/

/**
* Extract the non-zero values, column indexes for non-zero values, and
* the NZE index for each entry from a sparse matrix.
*
* @tparam T type of elements in the matrix
* @param[in] A sparse matrix.
* @return a tuple W,V,U.
*/
template <typename T>
const std::tuple<Eigen::Matrix<T, Eigen::Dynamic, 1>, std::vector<int>,
std::vector<int>>
csr_extract(const Eigen::SparseMatrix<T, Eigen::RowMajor>& A) {
auto a_nonzeros = A.nonZeros();
Eigen::Matrix<T, Eigen::Dynamic, 1> w
= Eigen::Matrix<T, Eigen::Dynamic, 1>::Zero(a_nonzeros);
std::vector<int> v(a_nonzeros);
for (int nze = 0; nze < a_nonzeros; ++nze) {
w[nze] = *(A.valuePtr() + nze);
v[nze] = *(A.innerIndexPtr() + nze) + stan::error_index::value;
}
std::vector<int> u(A.outerSize() + 1); // last entry is garbage.
for (int nze = 0; nze <= A.outerSize(); ++nze) {
u[nze] = *(A.outerIndexPtr() + nze) + stan::error_index::value;
}
return std::make_tuple(std::move(w), std::move(v), std::move(u));
}

/* Extract the non-zero values from a dense matrix by converting
* to sparse and calling the sparse matrix extractor.
*
* @tparam T type of elements in the matrix
* @tparam R number of rows, can be Eigen::Dynamic
* @tparam C number of columns, can be Eigen::Dynamic
*
* @param[in] A dense matrix.
* @return a tuple W,V,U.
*/
template <typename T, require_eigen_dense_base_t<T>* = nullptr>
const std::tuple<Eigen::Matrix<scalar_type_t<T>, Eigen::Dynamic, 1>,
std::vector<int>, std::vector<int>>
csr_extract(const T& A) {
// conversion to sparse seems to touch data twice, so we need to call to_ref
Eigen::SparseMatrix<scalar_type_t<T>, Eigen::RowMajor> B
= to_ref(A).sparseView();
return csr_extract(B);
}

/** @} */ // end of csr_format group

} // namespace math
} // namespace stan

#endif
7 changes: 4 additions & 3 deletions stan/math/prim/fun/csr_extract_w.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ namespace math {
template <typename T>
const Eigen::Matrix<T, Eigen::Dynamic, 1> csr_extract_w(
const Eigen::SparseMatrix<T, Eigen::RowMajor>& A) {
Eigen::Matrix<T, Eigen::Dynamic, 1> w(A.nonZeros());
w.setZero();
for (int nze = 0; nze < A.nonZeros(); ++nze) {
auto a_nonzeros = A.nonZeros();
Eigen::Matrix<T, Eigen::Dynamic, 1> w
= Eigen::Matrix<T, Eigen::Dynamic, 1>::Zero(a_nonzeros);
for (int nze = 0; nze < a_nonzeros; ++nze) {
w[nze] = *(A.valuePtr() + nze);
}
return w;
Expand Down
74 changes: 74 additions & 0 deletions stan/math/prim/fun/eigendecompose.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
#ifndef STAN_MATH_PRIM_FUN_EIGENDECOMPOSE_HPP
#define STAN_MATH_PRIM_FUN_EIGENDECOMPOSE_HPP

#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/err.hpp>

namespace stan {
namespace math {

/**
* Return the eigendecomposition of a (real-valued) matrix
*
* @tparam EigMat type of real matrix argument
* @param[in] m matrix to find the eigendecomposition of. Must be square and
* have a non-zero size.
* @return A tuple V,D where V is a matrix where the columns are the
* complex-valued eigenvectors of `m` and D is a complex-valued column vector
* with entries the eigenvectors of `m`
*/
template <typename EigMat, require_eigen_matrix_dynamic_t<EigMat>* = nullptr,
require_not_vt_complex<EigMat>* = nullptr>
inline std::tuple<Eigen::Matrix<complex_return_t<value_type_t<EigMat>>, -1, -1>,
Eigen::Matrix<complex_return_t<value_type_t<EigMat>>, -1, 1>>
eigendecompose(const EigMat& m) {
if (unlikely(m.size() == 0)) {
return std::make_tuple(
Eigen::Matrix<complex_return_t<value_type_t<EigMat>>, -1, -1>(0, 0),
Eigen::Matrix<complex_return_t<value_type_t<EigMat>>, -1, 1>(0, 1));
}
check_square("eigendecompose", "m", m);

using PlainMat = plain_type_t<EigMat>;
const PlainMat& m_eval = m;

Eigen::EigenSolver<PlainMat> solver(m_eval);
return std::make_tuple(std::move(solver.eigenvectors()),
std::move(solver.eigenvalues()));
}

/**
* Return the eigendecomposition of a (complex-valued) matrix
*
* @tparam EigCplxMat type of complex matrix argument
* @param[in] m matrix to find the eigendecomposition of. Must be square and
* have a non-zero size.
* @return A tuple V,D where V is a matrix where the columns are the
* complex-valued eigenvectors of `m` and D is a complex-valued column vector
* with entries the eigenvectors of `m`
*/
template <typename EigCplxMat,
require_eigen_matrix_dynamic_vt<is_complex, EigCplxMat>* = nullptr>
inline std::tuple<
Eigen::Matrix<complex_return_t<value_type_t<EigCplxMat>>, -1, -1>,
Eigen::Matrix<complex_return_t<value_type_t<EigCplxMat>>, -1, 1>>
eigendecompose(const EigCplxMat& m) {
if (unlikely(m.size() == 0)) {
return std::make_tuple(
Eigen::Matrix<complex_return_t<value_type_t<EigCplxMat>>, -1, -1>(0, 0),
Eigen::Matrix<complex_return_t<value_type_t<EigCplxMat>>, -1, 1>(0, 1));
}
check_square("eigendecompose", "m", m);

using PlainMat = Eigen::Matrix<scalar_type_t<EigCplxMat>, -1, -1>;
const PlainMat& m_eval = m;

Eigen::ComplexEigenSolver<PlainMat> solver(m_eval);

return std::make_tuple(std::move(solver.eigenvectors()),
std::move(solver.eigenvalues()));
}

} // namespace math
} // namespace stan
#endif
41 changes: 41 additions & 0 deletions stan/math/prim/fun/eigendecompose_sym.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#ifndef STAN_MATH_PRIM_FUN_EIGENDECOMPOSE_SYM_HPP
#define STAN_MATH_PRIM_FUN_EIGENDECOMPOSE_SYM_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/Eigen.hpp>

namespace stan {
namespace math {

/**
* Return the eigendecomposition of the specified symmetric matrix.
*
* @tparam EigMat type of the matrix
* @param m Specified matrix.
* @return A tuple V,D where V is a matrix where the columns are the
* eigenvectors of m, and D is a column vector of the eigenvalues of m.
* The eigenvalues are in ascending order of magnitude, with the eigenvectors
* provided in the same order.
*/
template <typename EigMat, require_eigen_t<EigMat>* = nullptr,
require_not_st_var<EigMat>* = nullptr>
std::tuple<Eigen::Matrix<value_type_t<EigMat>, -1, -1>,
Eigen::Matrix<value_type_t<EigMat>, -1, 1>>
eigendecompose_sym(const EigMat& m) {
if (unlikely(m.size() == 0)) {
return std::make_tuple(Eigen::Matrix<value_type_t<EigMat>, -1, -1>(0, 0),
Eigen::Matrix<value_type_t<EigMat>, -1, 1>(0, 1));
}
using PlainMat = plain_type_t<EigMat>;
const PlainMat& m_eval = m;
check_symmetric("eigendecompose_sym", "m", m_eval);

Eigen::SelfAdjointEigenSolver<PlainMat> solver(m_eval);
return std::make_tuple(std::move(solver.eigenvectors()),
std::move(solver.eigenvalues()));
}

} // namespace math
} // namespace stan
#endif
16 changes: 10 additions & 6 deletions stan/math/prim/fun/eigenvalues.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ template <typename EigMat, require_eigen_matrix_dynamic_t<EigMat>* = nullptr,
require_not_vt_complex<EigMat>* = nullptr>
inline Eigen::Matrix<complex_return_t<value_type_t<EigMat>>, -1, 1> eigenvalues(
const EigMat& m) {
if (unlikely(m.size() == 0)) {
return Eigen::Matrix<complex_return_t<value_type_t<EigMat>>, -1, 1>(0, 1);
}
check_square("eigenvalues", "m", m);
using PlainMat = plain_type_t<EigMat>;
const PlainMat& m_eval = m;
check_nonzero_size("eigenvalues", "m", m_eval);
check_square("eigenvalues", "m", m_eval);

Eigen::EigenSolver<PlainMat> solver(m_eval, false);
return solver.eigenvalues();
Expand All @@ -37,14 +39,16 @@ inline Eigen::Matrix<complex_return_t<value_type_t<EigMat>>, -1, 1> eigenvalues(
* @return a complex-valued column vector with entries the eigenvectors of `m`
*/
template <typename EigCplxMat,
require_eigen_matrix_dynamic_t<EigCplxMat>* = nullptr,
require_vt_complex<EigCplxMat>* = nullptr>
require_eigen_matrix_dynamic_vt<is_complex, EigCplxMat>* = nullptr>
inline Eigen::Matrix<complex_return_t<value_type_t<EigCplxMat>>, -1, 1>
eigenvalues(const EigCplxMat& m) {
if (unlikely(m.size() == 0)) {
return Eigen::Matrix<complex_return_t<value_type_t<EigCplxMat>>, -1, 1>(0,
1);
}
check_square("eigenvalues", "m", m);
using PlainMat = Eigen::Matrix<scalar_type_t<EigCplxMat>, -1, -1>;
const PlainMat& m_eval = m;
check_nonzero_size("eigenvalues", "m", m_eval);
check_square("eigenvalues", "m", m_eval);

Eigen::ComplexEigenSolver<PlainMat> solver(m_eval, false);

Expand Down
7 changes: 4 additions & 3 deletions stan/math/prim/fun/eigenvalues_sym.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,9 @@ namespace math {

/**
* Return the eigenvalues of the specified symmetric matrix
* in descending order of magnitude. This function is more
* in ascending order of magnitude. This function is more
* efficient than the general eigenvalues function for symmetric
* matrices.
* <p>See <code>eigen_decompose()</code> for more information.
*
* @tparam EigMat type of the matrix
* @param m Specified matrix.
Expand All @@ -22,9 +21,11 @@ namespace math {
template <typename EigMat, require_eigen_matrix_dynamic_t<EigMat>* = nullptr,
require_not_st_var<EigMat>* = nullptr>
Eigen::Matrix<value_type_t<EigMat>, -1, 1> eigenvalues_sym(const EigMat& m) {
if (unlikely(m.size() == 0)) {
return Eigen::Matrix<value_type_t<EigMat>, -1, 1>(0, 1);
}
using PlainMat = plain_type_t<EigMat>;
const PlainMat& m_eval = m;
check_nonzero_size("eigenvalues_sym", "m", m_eval);
check_symmetric("eigenvalues_sym", "m", m_eval);

Eigen::SelfAdjointEigenSolver<PlainMat> solver(m_eval,
Expand Down
16 changes: 10 additions & 6 deletions stan/math/prim/fun/eigenvectors.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ template <typename EigMat, require_eigen_matrix_dynamic_t<EigMat>* = nullptr,
require_not_vt_complex<EigMat>* = nullptr>
inline Eigen::Matrix<complex_return_t<value_type_t<EigMat>>, -1, -1>
eigenvectors(const EigMat& m) {
if (unlikely(m.size() == 0)) {
return Eigen::Matrix<complex_return_t<value_type_t<EigMat>>, -1, -1>(0, 0);
}
check_square("eigenvectors", "m", m);
using PlainMat = plain_type_t<EigMat>;
const PlainMat& m_eval = m;
check_nonzero_size("eigenvectors", "m", m_eval);
check_square("eigenvectors", "m", m_eval);

Eigen::EigenSolver<PlainMat> solver(m_eval);
return solver.eigenvectors();
Expand All @@ -39,14 +41,16 @@ eigenvectors(const EigMat& m) {
* `m`
*/
template <typename EigCplxMat,
require_eigen_matrix_dynamic_t<EigCplxMat>* = nullptr,
require_vt_complex<EigCplxMat>* = nullptr>
require_eigen_matrix_dynamic_vt<is_complex, EigCplxMat>* = nullptr>
inline Eigen::Matrix<complex_return_t<value_type_t<EigCplxMat>>, -1, -1>
eigenvectors(const EigCplxMat& m) {
if (unlikely(m.size() == 0)) {
return Eigen::Matrix<complex_return_t<value_type_t<EigCplxMat>>, -1, -1>(0,
0);
}
check_square("eigenvectors", "m", m);
using PlainMat = Eigen::Matrix<scalar_type_t<EigCplxMat>, -1, -1>;
const PlainMat& m_eval = m;
check_nonzero_size("eigenvectors", "m", m_eval);
check_square("eigenvectors", "m", m_eval);

Eigen::ComplexEigenSolver<PlainMat> solver(m_eval);
return solver.eigenvectors();
Expand Down
Loading

0 comments on commit 5d1fd38

Please sign in to comment.