Skip to content

Commit

Permalink
Merge pull request #3140 from stan-dev/std_normal-vec
Browse files Browse the repository at this point in the history
Simplify vectorisation of std_normal_lpdf
  • Loading branch information
andrjohns authored Jan 16, 2025
2 parents 42d94c4 + a929e7e commit ac977e4
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 14 deletions.
5 changes: 5 additions & 0 deletions stan/math/prim/fun/dot_self.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
namespace stan {
namespace math {

template <typename T, require_stan_scalar_t<T>* = nullptr>
inline T dot_self(const T& x) {
return x * x;
}

inline double dot_self(const std::vector<double>& x) {
double sum = 0.0;
for (double i : x) {
Expand Down
22 changes: 8 additions & 14 deletions stan/math/prim/prob/std_normal_lpdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/scalar_seq_view.hpp>
#include <stan/math/prim/fun/size.hpp>
#include <stan/math/prim/fun/size_zero.hpp>
#include <stan/math/prim/fun/value_of.hpp>
#include <stan/math/prim/fun/dot_self.hpp>
#include <stan/math/prim/fun/as_value_column_vector_or_scalar.hpp>
#include <stan/math/prim/functor/partials_propagator.hpp>

namespace stan {
Expand Down Expand Up @@ -43,22 +43,16 @@ return_type_t<T_y> std_normal_lpdf(const T_y& y) {
return 0.0;
}

T_partials_return logp(0.0);
const auto& y_val = as_value_column_vector_or_scalar(y_ref);
T_partials_return logp = -dot_self(y_val) / 2.0;
auto ops_partials = make_partials_propagator(y_ref);

scalar_seq_view<T_y_ref> y_vec(y_ref);
size_t N = stan::math::size(y);

for (size_t n = 0; n < N; n++) {
const T_partials_return y_val = y_vec.val(n);
logp += y_val * y_val;
if (!is_constant_all<T_y>::value) {
partials<0>(ops_partials)[n] -= y_val;
}
if (!is_constant_all<T_y>::value) {
partials<0>(ops_partials) = -y_val;
}
logp *= -0.5;

if (include_summand<propto>::value) {
logp += NEG_LOG_SQRT_TWO_PI * N;
logp += NEG_LOG_SQRT_TWO_PI * math::size(y);
}

return ops_partials.build(logp);
Expand Down
8 changes: 8 additions & 0 deletions test/unit/math/mix/prob/std_normal_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,12 @@ TEST_F(AgradRev, mathMixScalFun_std_normal) {
stan::test::expect_ad(f, -0.3);
stan::test::expect_ad(f, 0.0);
stan::test::expect_ad(f, 1.7);

Eigen::VectorXd x(3);
x << -0.3, 0.0, 1.7;
std::vector<double> x2{0.0, 1.7};

stan::test::expect_ad(f, x);
stan::test::expect_ad(f, x.transpose().eval());
stan::test::expect_ad(f, x2);
}

0 comments on commit ac977e4

Please sign in to comment.