Skip to content

Commit

Permalink
Remove select broadcast hack
Browse files Browse the repository at this point in the history
  • Loading branch information
andrjohns committed Aug 16, 2023
1 parent 8f5cdb8 commit cabafd6
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions stan/math/prim/prob/bernoulli_lccdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,15 @@ return_type_t<T_prob> bernoulli_lccdf(const T_n& n, const T_prob& theta) {
return ops_partials.build(NEGATIVE_INFTY);
}

// Use select() to broadcast theta values & gradients if necessary
size_t theta_size = math::size(theta_arr);
size_t n_size = math::size(n_arr);
double broadcast_n = theta_size == n_size ? 1 : std::fmax(theta_size, n_size);

if (!is_constant_all<T_prob>::value) {
partials<0>(ops_partials) = select(true, inv(theta_arr), n_arr);
partials<0>(ops_partials) = inv(theta_arr) * broadcast_n;
}

return ops_partials.build(sum(select(true, log(theta_arr), n_arr)));
return ops_partials.build(sum(log(theta_arr)) * broadcast_n);
}

} // namespace math
Expand Down

0 comments on commit cabafd6

Please sign in to comment.