Skip to content

Commit

Permalink
Nevermind
Browse files Browse the repository at this point in the history
  • Loading branch information
Baxter Eaves committed Mar 21, 2024
1 parent 45f1cc4 commit 04f1273
Showing 1 changed file with 17 additions and 44 deletions.
61 changes: 17 additions & 44 deletions lace/lace_stats/src/prior_process.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,55 +100,28 @@ impl PriorProcessT for Dirichlet {
}

fn draw_assignment<R: Rng>(&self, n: usize, rng: &mut R) -> Assignment {
// if n == 0 {
// return Assignment::empty();
// }
// let mut counts = vec![1];
// let mut ps = vec![1.0, self.alpha];
// let mut zs = vec![0; n];

// for z in zs.iter_mut().take(n).skip(1) {
// let zi = pflip(&ps, 1, rng)[0];
// *z = zi;
// if zi < counts.len() {
// ps[zi] += 1.0;
// counts[zi] += 1;
// } else {
// ps[zi] = 1.0;
// ps.push(self.alpha);
// counts.push(1);
// };
// }
//
// Assignment {
// asgn: zs,
// n_cats: counts.len(),
// counts,
// }
let mut n_cats = 0;
let mut weights: Vec<f64> = vec![];
let mut asgn: Vec<usize> = Vec::with_capacity(n);

for _ in 0..n {
weights.push(self.alpha);
let k = pflip(&weights, 1, rng)[0];
asgn.push(k);
if n == 0 {
return Assignment::empty();
}
let mut counts = vec![1];
let mut ps = vec![1.0, self.alpha];
let mut zs = vec![0; n];

if k == n_cats {
weights[n_cats] = 1.0;
n_cats += 1;
for z in zs.iter_mut().take(n).skip(1) {
let zi = pflip(&ps, 1, rng)[0];
*z = zi;
if zi < counts.len() {
ps[zi] += 1.0;
counts[zi] += 1;
} else {
weights.truncate(n_cats);
weights[k] += 1.0;
}
ps[zi] = 1.0;
ps.push(self.alpha);
counts.push(1);
};
}
// convert weights to counts, correcting for possible floating point
// errors
let counts: Vec<usize> =
weights.iter().map(|w| (w + 0.5) as usize).collect();

Assignment {
asgn,
asgn: zs,
n_cats: counts.len(),
counts,
}
Expand Down

0 comments on commit 04f1273

Please sign in to comment.