Skip to content

Commit

Permalink
Merge pull request #165 from promised-ai/feature/variance
Browse files Browse the repository at this point in the history
Add Oracle::variability
  • Loading branch information
BaxterEaves authored Jan 31, 2024
2 parents 66e5a67 + bfa8416 commit 43c1c8c
Show file tree
Hide file tree
Showing 15 changed files with 449 additions and 231 deletions.
316 changes: 134 additions & 182 deletions cli/Cargo.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions lace/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion lace/lace_consts/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ repository = "https://github.com/promised-ai/lace"
description = "Default constants for Lace"

[dependencies]
rv = { version = "0.16.2", features = ["serde1", "arraydist"] }
rv = { version = "0.16.3", features = ["serde1", "arraydist"] }
2 changes: 1 addition & 1 deletion lace/src/interface/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub use oracle::utils;

pub use oracle::{
ConditionalEntropyType, DatalessOracle, MiComponents, MiType, Oracle,
OracleT, RowSimilarityVariant,
OracleT, RowSimilarityVariant, Variability,
};

pub use given::Given;
Expand Down
15 changes: 13 additions & 2 deletions lace/src/interface/oracle/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,17 @@ pub enum PredictError {
GivenError(#[from] GivenError),
}

/// Describes errors that can occur from bad inputs to `Oracle::variability`
#[derive(Debug, Clone, PartialEq, Error)]
pub enum VariabilityError {
/// The target column index is out of bounds
#[error("Target index error in predict query: {0}")]
IndexError(#[from] IndexError),
/// The Given is invalid
#[error("Invalid predict 'given' argument: {0}")]
GivenError(#[from] GivenError),
}

/// Describes errors that arise from invalid predict uncertainty arguments
#[derive(Debug, Clone, PartialEq, Error)]
pub enum PredictUncertaintyError {
Expand All @@ -192,7 +203,7 @@ pub enum PredictUncertaintyError {

/// Describes errors from incompatible `col_max_logp` caches
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum ColumnMaxiumLogPError {
pub enum ColumnMaximumLogPError {
/// The state indices used to compute the cache do not match those passed to the function.
#[error("The state indices used to compute the cache do not match those passed to the function.")]
InvalidStateIndices,
Expand Down Expand Up @@ -247,7 +258,7 @@ pub enum LogpError {
#[error("Invalid logp 'given' argument: {0}")]
GivenError(#[from] GivenError),
#[error("Invalid `col_max_logps` argument: {0}")]
ColumnMaxiumLogPError(#[from] ColumnMaxiumLogPError),
ColumnMaximumLogPError(#[from] ColumnMaximumLogPError),
}

/// Describes errors from bad inputs to Oracle::simulate
Expand Down
2 changes: 1 addition & 1 deletion lace/src/interface/oracle/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ pub mod utils;
mod validation;

pub use dataless::DatalessOracle;
pub use traits::OracleT;
pub use traits::{OracleT, Variability};

use std::path::Path;

Expand Down
139 changes: 139 additions & 0 deletions lace/src/interface/oracle/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use lace_stats::rv::traits::Rv;
use lace_stats::SampleError;
use rand::Rng;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::BTreeSet;

macro_rules! col_indices_ok {
Expand All @@ -41,6 +42,25 @@ macro_rules! state_indices_ok {
}}
}

/// Represents different formalizations of variability in distributions
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Variability {
/// The variance of a univariate distribution
Variance(f64),
/// The entropy of a distribution
Entropy(f64),
}

impl From<Variability> for f64 {
fn from(value: Variability) -> Self {
match value {
Variability::Variance(x) => x,
Variability::Entropy(x) => x,
}
}
}

pub trait OracleT: CanOracle {
/// Returns the diagnostics for each state
fn state_diagnostics(&self) -> Vec<StateDiagnostics> {
Expand Down Expand Up @@ -2046,6 +2066,125 @@ pub trait OracleT: CanOracle {
}
}

/// Compute the variability of a conditional distribution
///
/// # Notes
/// - Returns variance for Continuous and Count columns
/// - Returns Entropy for Categorical columns
///
/// # Arguments
/// - col_ix: the index of the column for which to compute the variability
/// - given: optional observations by which to constrain the prediction
/// - state_ixs_opt: Optional vector of state indices from which to compute,
/// if None, use all states.
fn variability<Ix: ColumnIndex, GIx: ColumnIndex>(
&self,
col_ix: Ix,
given: &Given<GIx>,
state_ixs_opt: Option<&[usize]>,
) -> Result<Variability, error::VariabilityError> {
use crate::stats::rv::traits::{Entropy, Variance};
use crate::stats::MixtureType;

let states: Vec<&State> = if let Some(state_ixs) = state_ixs_opt {
state_ixs.iter().map(|&ix| &self.states()[ix]).collect()
} else {
self.states().iter().collect()
};

let given =
given.clone().canonical(self.codebook()).map_err(|err| {
error::VariabilityError::GivenError(
error::GivenError::IndexError(err),
)
})?;

let col_ix = col_ix.col_ix(self.codebook())?;

// Get the mixture weights for each state
let mut mixture_types: Vec<MixtureType> = states
.iter()
.map(|state| {
let view_ix = state.asgn.asgn[col_ix];
let weights =
&utils::given_weights(&[state], &[col_ix], &given)[0];

// combine the state weights with the given weights
let mut mm_weights: Vec<f64> = state.views[view_ix]
.weights
.iter()
.zip(weights[&view_ix].iter())
.map(|(&w1, &w2)| w1 + w2)
.collect();

let z: f64 = logsumexp(&mm_weights);
mm_weights.iter_mut().for_each(|w| *w = (*w - z).exp());

state.views[view_ix].ftrs[&col_ix].to_mixture(mm_weights)
})
.collect();

enum MType {
Gaussian,
Categorical,
Count,
Unsupported,
}

let mtype = match mixture_types[0] {
MixtureType::Gaussian(_) => MType::Gaussian,
MixtureType::Poisson(_) => MType::Count,
MixtureType::Categorical(_) => MType::Categorical,
_ => MType::Unsupported,
};

match mtype {
MType::Gaussian => {
let mms: Vec<_> = mixture_types
.drain(..)
.map(|mt| {
if let MixtureType::Gaussian(mm) = mt {
mm
} else {
panic!("Expected Gaussian Mixture Type")
}
})
.collect();
let mm = Mixture::combine(mms);
Ok(Variability::Variance(mm.variance().unwrap()))
}
MType::Count => {
let mms: Vec<_> = mixture_types
.drain(..)
.map(|mt| {
if let MixtureType::Poisson(mm) = mt {
mm
} else {
panic!("Expected Poisson Mixture Type")
}
})
.collect();
let mm = Mixture::combine(mms);
Ok(Variability::Variance(mm.variance().unwrap()))
}
MType::Categorical => {
let mms: Vec<_> = mixture_types
.drain(..)
.map(|mt| {
if let MixtureType::Categorical(mm) = mt {
mm
} else {
panic!("Expected Categorical Mixture Type")
}
})
.collect();
let mm = Mixture::combine(mms);
Ok(Variability::Entropy(mm.entropy()))
}
_ => panic!("Unsupported MType"),
}
}

/// Compute the error between the observed data in a feature and the feature
/// model.
///
Expand Down
2 changes: 2 additions & 0 deletions lace/src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ pub use crate::{
RowSimilarityVariant, SupportExtension, Value, WriteMode,
};

pub use crate::interface::Variability;

pub use crate::data::DataSource;

pub use lace_cc::{
Expand Down
4 changes: 2 additions & 2 deletions pylace/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 20 additions & 0 deletions pylace/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,23 @@ engine.update(10_000)
engine.predict('Class_of_Orbit', given={'Period_minutes': 1436.0})
# ('GEO', 0.13583714831550336)
```

## Tests

To run tests, use `pytest`

```console
$ pytest -x
```

To run doctets:

```console
$ python tests/test_docs.py
```

To prevent plotly from displaying

```console
$ LACE_DOCTEST_NOPLOT=1 python tests/test_docs.py
```
30 changes: 12 additions & 18 deletions pylace/lace/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def held_out_neglogp(
│ ["Apogee_km"] ┆ 5.106627 ┆ 1 │
│ ["Apogee_km", "Eccentricity"] ┆ 2.951662 ┆ 2 │
│ ["Apogee_km", "Country_of_Operat… ┆ 2.951254 ┆ 3 │
… ┆ … ┆ …
["Apogee_km", "Country_of_Operat… ┆ 2.952801 ┆ 4
│ ["Apogee_km", "Country_of_Contra… ┆ 2.956224 ┆ 5 │
│ ["Apogee_km", "Country_of_Contra… ┆ 2.96479 ┆ 6 │
│ ["Apogee_km", "Country_of_Contra… ┆ 2.992173 ┆ 7 │
Expand Down Expand Up @@ -415,7 +415,7 @@ def held_out_inconsistency(
│ ["Apogee_km"] ┆ 1.290609 ┆ 1 │
│ ["Apogee_km", "Eccentricity"] ┆ 0.74598 ┆ 2 │
│ ["Apogee_km", "Country_of_Operat… ┆ 0.745877 ┆ 3 │
… ┆ … ┆ …
["Apogee_km", "Country_of_Operat… ┆ 0.746268 ┆ 4
│ ["Apogee_km", "Country_of_Contra… ┆ 0.747133 ┆ 5 │
│ ["Apogee_km", "Country_of_Contra… ┆ 0.749297 ┆ 6 │
│ ["Apogee_km", "Country_of_Contra… ┆ 0.756218 ┆ 7 │
Expand Down Expand Up @@ -525,7 +525,7 @@ def held_out_uncertainty(
│ ["Expected_Lifetime"] ┆ 0.437647 ┆ 1 │
│ ["Apogee_km", "Eccentricity"] ┆ 0.05561 ┆ 2 │
│ ["Apogee_km", "Country_of_Operat… ┆ 0.055283 ┆ 3 │
… ┆ … ┆ …
["Apogee_km", "Country_of_Operat… ┆ 0.056185 ┆ 4
│ ["Apogee_km", "Country_of_Operat… ┆ 0.057624 ┆ 5 │
│ ["Apogee_km", "Country_of_Contra… ┆ 0.0595 ┆ 6 │
│ ["Apogee_km", "Country_of_Contra… ┆ 0.077359 ┆ 7 │
Expand Down Expand Up @@ -945,15 +945,15 @@ def explain_prediction(
│ --- ┆ --- │
│ str ┆ f64 │
╞══════════════════════════════╪═════════════╡
│ Country_of_Operator ┆ 3.5216e-16 │
│ Users ┆ -3.1668e-14
│ Purpose ┆ -9.5636e-14
│ Class_of_Orbit ┆ -1.8263e-15 │
│ Country_of_Operator ┆ 2.4617e-16 │
│ Users ┆ -2.1412e-15
│ Purpose ┆ -8.0193e-15
│ Class_of_Orbit ┆ -2.2727e-15 │
│ … ┆ … │
│ Launch_Site ┆ -2.8416e-15
│ Launch_Vehicle ┆ 1.0704e-14
│ Source_Used_for_Orbital_Data ┆ -3.9301e-15 │
│ Inclination_radians ┆ -9.6259e-15 │
│ Launch_Site ┆ -5.8214e-16
│ Launch_Vehicle ┆ -9.6101e-16
│ Source_Used_for_Orbital_Data ┆ -9.1997e-15 │
│ Inclination_radians ┆ -1.5407e-15 │
└──────────────────────────────┴─────────────┘
Get the importances using the 'ablative-dist' method, which measures how
Expand All @@ -975,7 +975,7 @@ def explain_prediction(
│ Country_of_Operator ┆ -0.000109 │
│ Users ┆ 0.081289 │
│ Purpose ┆ 0.18938 │
│ Class_of_Orbit ┆ 0.000133
│ Class_of_Orbit ┆ 0.000119
│ … ┆ … │
│ Launch_Site ┆ 0.003411 │
│ Launch_Vehicle ┆ -0.018817 │
Expand All @@ -994,9 +994,3 @@ def explain_prediction(
raise ValueError(
f"Invalid method `{method}`, valid methods are {PRED_EXPLAIN_METHODS}"
)


if __name__ == "__main__":
import doctest

doctest.testmod()
Loading

0 comments on commit 43c1c8c

Please sign in to comment.