Skip to content

Commit

Permalink
Merge pull request #2930 from stan-dev/issue-2924-vectorised-log_sum_exp
Browse files Browse the repository at this point in the history
Vectorise binary log_sum_exp
  • Loading branch information
andrjohns authored Aug 12, 2023
2 parents 38289cd + b32ca34 commit d4eab27
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 1 deletion.
20 changes: 19 additions & 1 deletion stan/math/prim/fun/log_sum_exp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/fun/log1p_exp.hpp>
#include <stan/math/prim/functor/apply_scalar_binary.hpp>
#include <cmath>
#include <vector>

Expand Down Expand Up @@ -47,7 +48,8 @@ namespace math {
* @param a the first variable
* @param b the second variable
*/
template <typename T1, typename T2, require_all_not_st_var<T1, T2>* = nullptr>
template <typename T1, typename T2, require_all_not_st_var<T1, T2>* = nullptr,
require_all_stan_scalar_t<T1, T2>* = nullptr>
inline return_type_t<T1, T2> log_sum_exp(const T2& a, const T1& b) {
if (a == NEGATIVE_INFTY) {
return b;
Expand Down Expand Up @@ -91,6 +93,22 @@ inline auto log_sum_exp(const T& x) {
});
}

/**
* Enables the vectorized application of the log_sum_exp function,
* when the first and/or second arguments are containers.
*
* @tparam T1 type of first input
* @tparam T2 type of second input
* @param a First input
* @param b Second input
* @return log_sum_exp function applied to the two inputs.
*/
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr>
inline auto log_sum_exp(const T1& a, const T2& b) {
return apply_scalar_binary(
a, b, [](const auto& c, const auto& d) { return log_sum_exp(c, d); });
}

} // namespace math
} // namespace stan

Expand Down
13 changes: 13 additions & 0 deletions test/unit/math/mix/fun/log_sum_exp_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,16 @@ TEST(MathMixMatFun, logSumExp) {
std::vector<double>(x2c.data(), x2c.data() + x2c.size())};
stan::test::expect_ad(tols, f, ststx);
}

TEST(mathMixScalFun, logSumExp_vec) {
auto f = [](const auto& x1, const auto& x2) {
using stan::math::log_sum_exp;
return log_sum_exp(x1, x2);
};

Eigen::VectorXd in1(2);
in1 << 3, 1;
Eigen::VectorXd in2(2);
in2 << 0.5, 3.4;
stan::test::expect_ad_vectorized_binary(f, in1, in2);
}
14 changes: 14 additions & 0 deletions test/unit/math/prim/fun/log_sum_exp_test.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <stan/math/prim.hpp>
#include <test/unit/math/prim/fun/binary_scalar_tester.hpp>
#include <gtest/gtest.h>
#include <cmath>
#include <limits>
Expand Down Expand Up @@ -129,3 +130,16 @@ TEST(MathFunctions, log_sum_exp_mat) {
ii << -std::numeric_limits<double>::infinity();
test_log_sum_exp(ii);
}

TEST(MathFunctions, log_sum_exp_vec) {
auto f = [](const auto& x1, const auto& x2) {
using stan::math::log_sum_exp;
return log_sum_exp(x1, x2);
};

Eigen::VectorXd in1(3);
in1 << 4.1, 3.24, 6.8;
Eigen::VectorXd in2(3);
in2 << 2.8, 1.7, 3.1;
stan::test::binary_scalar_tester(f, in1, in2);
}

0 comments on commit d4eab27

Please sign in to comment.