Skip to content

Commit

Permalink
[fix,test] fix bugs in poisson sampling and adding tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pblischak committed May 13, 2024
1 parent 97b3dfe commit c487fff
Showing 1 changed file with 110 additions and 43 deletions.
153 changes: 110 additions & 43 deletions src/poisson.zig
Original file line number Diff line number Diff line change
@@ -1,22 +1,28 @@
//! Poisson distribution with parameter `lambda`.

// zig fmt: off
//! Poisson distribution
//!
//! [https://en.wikipedia.org/wiki/Poisson_distribution](https://en.wikipedia.org/wiki/Poisson_distribution)

const std = @import("std");
const math = std.math;
const Random = std.rand.Random;
const Allocator = std.mem.Allocator;
const Random = std.Random;

const spec_fn = @import("special_functions.zig");
const utils = @import("utils.zig");

/// Poisson distribution with parameter `lambda`.
pub fn Poisson(comptime I: type, comptime F: type) type {
_ = utils.ensureIntegerType(I);
_ = utils.ensureFloatType(F);

return struct {
const Self = @This();

prng: *Random,
rand: *Random,

pub fn init(prng: *Random) Self {
pub fn init(rand: *Random) Self {
return Self{
.prng = prng,
.rand = rand,
};
}

Expand All @@ -43,13 +49,26 @@ pub fn Poisson(comptime I: type, comptime F: type) type {
}
}

pub fn sampleSlice(
self: Self,
size: usize,
lambda: F,
allocator: Allocator,
) ![]I {
var res = try allocator.alloc(I, size);
for (0..size) |i| {
res[i] = self.sample(lambda);
}
return res;
}

fn low(self: Self, lambda: F) I {
const d: F = @sqrt(lambda);
if (self.prng.float(F) >= d) {
if (@as(F, @floatCast(self.rand.float(f64))) >= d) {
return 0;
}

const r = self.prng.float(F) * d;
const r: F = @as(F, @floatCast(self.rand.float(f64))) * d;
if (r > lambda * (1.0 - lambda)) {
return 0;
}
Expand All @@ -68,7 +87,7 @@ pub fn Poisson(comptime I: type, comptime F: type) type {
var f: F = p_f0;

while (true) {
r = self.prng.float(F);
r = @floatCast(self.rand.float(f64));
x = 0;
f = p_f0;

Expand All @@ -78,17 +97,17 @@ pub fn Poisson(comptime I: type, comptime F: type) type {
return x;
}
x += 1;
f += lambda;
f += @as(F, @floatFromInt(x));
f *= lambda;
r *= @as(F, @floatFromInt(x));

while (x <= bound) {
r -= f;
if (r <= 0.0) {
return x;
}
x += 1;
f += lambda;
f += @as(F, @floatFromInt(x));
f *= lambda;
r *= @as(F, @floatFromInt(x));
}
}
}
Expand All @@ -99,20 +118,20 @@ pub fn Poisson(comptime I: type, comptime F: type) type {
var x: F = undefined;
var k: I = undefined;

var p_a = lambda + 0.5;
var mode = @as(I, @intFromFloat(lambda));
var p_g = @log(lambda);
var p_q = @as(F, @floatFromInt(mode)) * p_g - spec_fn.lnFactorial(I, F, mode);
var p_h = @sqrt(2.943035529371538573 * (lambda + 0.5)) + 0.8989161620588987408;
var p_bound = @as(I, @intFromFloat(p_a + 6.0 * p_h));
const p_a = lambda + 0.5;
const mode = @as(I, @intFromFloat(lambda));
const p_g = @log(lambda);
const p_q = @as(F, @floatFromInt(mode)) * p_g - spec_fn.lnFactorial(I, F, mode);
const p_h = @sqrt(2.943035529371538573 * (lambda + 0.5)) + 0.8989161620588987408;
const p_bound = @as(I, @intFromFloat(p_a + 6.0 * p_h));

while (true) {
u = self.prng.float(F);
u = @floatCast(self.rand.float(f64));
if (u == 0) {
continue;
}

x = p_a + p_h * (self.prng.float(F) - 0.5) / u;
x = p_a + p_h * (@as(F, @floatCast(self.rand.float(f64))) - 0.5) / u;
if (x < 0.0 or x >= @as(F, @floatFromInt(p_bound))) {
continue;
}
Expand All @@ -132,32 +151,80 @@ pub fn Poisson(comptime I: type, comptime F: type) type {
return k;
}

pub fn pmf(self: Self, k: I, lambda: F) I {
return @exp(self.lnPmf(I, F, k, lambda));
pub fn pmf(self: Self, k: I, lambda: F) F {
return @exp(self.lnPmf(k, lambda));
}

pub fn lnPmf(k: I, lambda: F) I {
return @as(F, @floatFromInt(k)) * @log(lambda) - lambda + spec_fn.lnFactorial(k);
pub fn lnPmf(self: Self, k: I, lambda: F) F {
_ = self;
return @as(
F,
@floatFromInt(k),
) * @log(lambda) - lambda + spec_fn.lnFactorial(I, F, k);
}
};
}

test "Poisson API" {
const DefaultPrng = std.rand.Xoshiro256;
test "Sample Poisson" {
const seed: u64 = @intCast(std.time.microTimestamp());
var prng = std.Random.DefaultPrng.init(seed);
var rand = prng.random();

var poisson = Poisson(u32, f64).init(&rand);
const val = poisson.sample(20.0);
std.debug.print("\n{}\n", .{val});
}

test "Sample Poisson Slice" {
const seed: u64 = @intCast(std.time.microTimestamp());
var prng = std.Random.DefaultPrng.init(seed);
var rand = prng.random();

var poisson = Poisson(u32, f64).init(&rand);
const allocator = std.testing.allocator;
const sample = try poisson.sampleSlice(100, 20.0, allocator);
defer allocator.free(sample);
std.debug.print("\n{any}\n", .{sample});
}

test "Poisson Mean" {
const seed: u64 = @intCast(std.time.milliTimestamp());
var prng = DefaultPrng.init(seed);
var rng = prng.random();
var poisson = Poisson(u32, f64).init(&rng);
var sum: f64 = 0.0;
const lambda: f64 = 20.0;
for (0..10_000) |_| {
const samp = poisson.sample(lambda);
sum += @as(f64, @floatFromInt(samp));
var prng = std.Random.DefaultPrng.init(seed);
var rand = prng.random();
var poisson = Poisson(u32, f64).init(&rand);

const lambda_vec = [_]f64{ 0.5, 1.0, 2.0, 5.0, 10.0, 20.0, 50.0 };

std.debug.print("\n", .{});
for (lambda_vec) |lambda| {
var sum: f64 = 0.0;
for (0..10_000) |_| {
const samp = poisson.sample(lambda);
sum += @as(f64, @floatFromInt(samp));
}
const avg: f64 = sum / 10_000.0;
std.debug.print("Mean: {}\tAvg: {}\tStdDev: {}\n", .{ lambda, avg, @sqrt(lambda) });
try std.testing.expectApproxEqAbs(lambda, avg, @sqrt(lambda));
}
}

test "Poisson with Different Types" {
const seed: u64 = @intCast(std.time.microTimestamp());
var prng = std.Random.DefaultPrng.init(seed);
var rand = prng.random();

const int_types = [_]type{ u8, u16, u32, u64, u128, i16, i32, i64, i128 };
const float_types = [_]type{ f16, f32, f64, f128 };

std.debug.print("\n", .{});
inline for (int_types) |i| {
inline for (float_types) |f| {
var poisson = Poisson(i, f).init(&rand);
const val = poisson.sample(20.0);
std.debug.print(
"Poisson({any}, {any}): {}\n",
.{ i, f, val },
);
}
}
const avg: f64 = sum / 10_000.0;
const mean: f64 = lambda;
const variance: f64 = lambda;
try std.testing.expectApproxEqAbs(
mean, avg, variance
);
}
}

0 comments on commit c487fff

Please sign in to comment.