Skip to content

Commit

Permalink
Merge pull request #90 from OpenSRDK/instant-condition-diffarentiable…
Browse files Browse the repository at this point in the history
…-distributiuon

feat: instant distribution にcondition微分を追加するstruct
  • Loading branch information
Senna46 authored Apr 1, 2022
2 parents c65f8fe + 1c8d927 commit 145569f
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 0 deletions.
137 changes: 137 additions & 0 deletions src/distribution/instant_condition_differentiable.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
use crate::{
ConditionDifferentiableDistribution, DependentJoint, Distribution, DistributionError,
IndependentJoint, InstantDistribution, RandomVariable,
};
use rand::prelude::*;
use std::{
fmt::Debug,
marker::PhantomData,
ops::{BitAnd, Mul},
};

#[derive(Clone)]
pub struct ConditionDifferentiableInstantDistribution<T, U, FF, FS, G>
where
T: RandomVariable,
U: RandomVariable,
FF: Fn(&T, &U) -> Result<f64, DistributionError> + Clone + Send + Sync,
FS: Fn(&U, &mut dyn RngCore) -> Result<T, DistributionError> + Clone + Send + Sync,
G: Fn(&U) -> Result<Vec<f64>, DistributionError> + Clone + Send + Sync,
{
instant_distribution: InstantDistribution<T, U, FF, FS>,
condition_diff: G,
phantom: PhantomData<U>,
}

impl<T, U, FF, FS, G> ConditionDifferentiableInstantDistribution<T, U, FF, FS, G>
where
T: RandomVariable,
U: RandomVariable,
FF: Fn(&T, &U) -> Result<f64, DistributionError> + Clone + Send + Sync,
FS: Fn(&U, &mut dyn RngCore) -> Result<T, DistributionError> + Clone + Send + Sync,
G: Fn(&U) -> Result<Vec<f64>, DistributionError> + Clone + Send + Sync,
{
pub fn new(instant_distribution: InstantDistribution<T, U, FF, FS>, condition_diff: G) -> Self {
Self {
instant_distribution,
condition_diff,
phantom: PhantomData,
}
}
}

impl<T, U, FF, FS, G> Debug for ConditionDifferentiableInstantDistribution<T, U, FF, FS, G>
where
T: RandomVariable,
U: RandomVariable,
FF: Fn(&T, &U) -> Result<f64, DistributionError> + Clone + Send + Sync,
FS: Fn(&U, &mut dyn RngCore) -> Result<T, DistributionError> + Clone + Send + Sync,
G: Fn(&U) -> Result<Vec<f64>, DistributionError> + Clone + Send + Sync,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "InstantDistribution")
}
}

impl<T, U, FF, FS, G> Distribution for ConditionDifferentiableInstantDistribution<T, U, FF, FS, G>
where
T: RandomVariable,
U: RandomVariable,
FF: Fn(&T, &U) -> Result<f64, DistributionError> + Clone + Send + Sync,
FS: Fn(&U, &mut dyn RngCore) -> Result<T, DistributionError> + Clone + Send + Sync,
G: Fn(&U) -> Result<Vec<f64>, DistributionError> + Clone + Send + Sync,
{
type Value = T;
type Condition = U;

fn fk(
&self,
x: &Self::Value,
theta: &Self::Condition,
) -> Result<f64, crate::DistributionError> {
self.instant_distribution.fk(x, theta)
}

fn sample(
&self,
theta: &Self::Condition,
rng: &mut dyn RngCore,
) -> Result<Self::Value, crate::DistributionError> {
self.instant_distribution.sample(theta, rng)
}
}

impl<T, U, Rhs, TRhs, FF, FS, G> Mul<Rhs>
for ConditionDifferentiableInstantDistribution<T, U, FF, FS, G>
where
T: RandomVariable,
U: RandomVariable,
Rhs: Distribution<Value = TRhs, Condition = U>,
TRhs: RandomVariable,
FF: Fn(&T, &U) -> Result<f64, DistributionError> + Clone + Send + Sync,
FS: Fn(&U, &mut dyn RngCore) -> Result<T, DistributionError> + Clone + Send + Sync,
G: Fn(&U) -> Result<Vec<f64>, DistributionError> + Clone + Send + Sync,
{
type Output = IndependentJoint<Self, Rhs, T, TRhs, U>;

fn mul(self, rhs: Rhs) -> Self::Output {
IndependentJoint::new(self, rhs)
}
}

impl<T, U, Rhs, URhs, FF, FS, G> BitAnd<Rhs>
for ConditionDifferentiableInstantDistribution<T, U, FF, FS, G>
where
T: RandomVariable,
U: RandomVariable,
Rhs: Distribution<Value = U, Condition = URhs>,
URhs: RandomVariable,
FF: Fn(&T, &U) -> Result<f64, DistributionError> + Clone + Send + Sync,
FS: Fn(&U, &mut dyn RngCore) -> Result<T, DistributionError> + Clone + Send + Sync,
G: Fn(&U) -> Result<Vec<f64>, DistributionError> + Clone + Send + Sync,
{
type Output = DependentJoint<Self, Rhs, T, U, URhs>;

fn bitand(self, rhs: Rhs) -> Self::Output {
DependentJoint::new(self, rhs)
}
}

impl<T, U, FF, FS, G> ConditionDifferentiableDistribution
for ConditionDifferentiableInstantDistribution<T, U, FF, FS, G>
where
T: RandomVariable,
U: RandomVariable,
FF: Fn(&T, &U) -> Result<f64, DistributionError> + Clone + Send + Sync,
FS: Fn(&U, &mut dyn RngCore) -> Result<T, DistributionError> + Clone + Send + Sync,
G: Fn(&U) -> Result<Vec<f64>, DistributionError> + Clone + Send + Sync,
{
fn ln_diff_condition(
&self,
_x: &Self::Value,
theta: &Self::Condition,
) -> Result<Vec<f64>, DistributionError> {
let g = (self.condition_diff)(theta)?;
Ok(g)
}
}
1 change: 1 addition & 0 deletions src/distribution/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pub mod independent_array_joint;
pub mod independent_joint;
pub mod independent_value_array_joint;
pub mod instant;
pub mod instant_condition_differentiable;
pub mod random_variable;
pub mod switched;
pub mod transformed;
Expand Down

0 comments on commit 145569f

Please sign in to comment.