From 793b720f878fa5b654eb4c880ff5a430f987b19e Mon Sep 17 00:00:00 2001 From: "chung@molgen.mpg.de" Date: Wed, 20 Jan 2016 17:24:16 +0100 Subject: [PATCH] check normalization for sampling the latent sources --- src/rmn.cpp | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/src/rmn.cpp b/src/rmn.cpp index 7e914e3..b02f52e 100644 --- a/src/rmn.cpp +++ b/src/rmn.cpp @@ -147,14 +147,21 @@ void sampleU( arma::uword i; arma::uword L = s.size(); arma::vec p = arma::zeros(w + 1); - + double norm; for (i = 0; i < (w - 1) && i < L; i++){ if (t(i) > 0){ p(w) = empty(i) * mu(w, s(i)); p(arma::span(w - 1 - i, w - 1)) = z(arma::span(0, i)) % mu(arma::span(w - 1 - i, w - 1), s(i)); - p /= arma::sum(p); - gsl_ran_multinomial(r, p.size(), t(i), p.begin(), u.colptr(i)); - meanU.col(i) = p * t(i); + norm = arma::sum(p); + if (norm > 0){ + p /= arma::sum(p); + + gsl_ran_multinomial(r, p.size(), t(i), p.begin(), u.colptr(i)); + meanU.col(i) = p * t(i); + } + else{ + std::cout << "norm p = 0::" << p << std::endl; + } } else{ u.col(i).zeros(); @@ -166,9 +173,16 @@ void sampleU( if (t(i) > 0){ p(w) = empty(i) * mu(w, s(i)); p(arma::span(0, w - 1)) = z(arma::span(i - w + 1, i)) % mu(arma::span(0, w - 1), s(i)); - p /= arma::sum(p); - gsl_ran_multinomial(r, p.size(), t(i), p.begin(), u.colptr(i)); - meanU.col(i) = p * t(i); + norm = arma::sum(p); + if (norm > 0){ + p /= arma::sum(p); + + gsl_ran_multinomial(r, p.size(), t(i), p.begin(), u.colptr(i)); + meanU.col(i) = p * t(i); + } + else{ + std::cout << "norm p = 0::" << p << std::endl; + } } else{ u.col(i).zeros();