-
-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Bayesian Optimization #102
Open
Ryo0731
wants to merge
23
commits into
main
Choose a base branch
from
bayesopt
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 6 commits
Commits
Show all changes
23 commits
Select commit
Hold shift + click to select a range
4e25bab
feat: bayesopt
kimurayu45z 442ee96
feat: ucb
Ryo0731 30d1512
feat: ucb
Ryo0731 91c6d3f
feat: ei
Ryo0731 ed3d8d6
feat: sampling
Ryo0731 8c874c5
feat: maximization
Ryo0731 29b70f7
Merge branch 'main' of https://github.com/OpenSRDK/probability-rs int…
kimurayu45z 94055bc
Merge branch 'bayesopt' of https://github.com/OpenSRDK/probability-rs…
kimurayu45z 4d3cc29
fix: error
Ryo0731 fa35a8e
Merge branch 'bayesopt' of https://github.com/OpenSRDK/probability-rs…
Ryo0731 bf7922b
fix: sqrt
Ryo0731 4f7a5b3
fix: max for vecf64
Ryo0731 ae550e7
fix: gp
Ryo0731 12de912
feat: cmaes
Ryo0731 367a3b5
feat: calc_aquisition_function
Ryo0731 78b93df
fix: max inVecf64
Ryo0731 60e117a
feat: mod
Ryo0731 eb1a4eb
Merge branch 'main' of https://github.com/OpenSRDK/probability-rs int…
Ryo0731 8dc5828
feat: test
Ryo0731 6e6090c
feat: initial_mean
Ryo0731 eb4bc49
feat: delete
Ryo0731 ec0863b
feat: pub
Ryo0731 920d226
feat: test
Ryo0731 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
43 changes: 43 additions & 0 deletions
43
src/nonparametric/elliptical_process/bayesian_optimization/expected_improvement.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
use super::AcquisitionFunctions; | ||
// use crate::{Normal, NormalParams}; | ||
|
||
pub struct ExpectedImprovement { | ||
f_vec: Vec<f64>, | ||
} | ||
|
||
impl AcquisitionFunctions for ExpectedImprovement { | ||
fn value(&self, theta: &crate::NormalParams) -> f64 { | ||
let mu = theta.mu(); | ||
let sigma = theta.sigma(); | ||
let tau = self.f_vec.iter().max().unwrap(); | ||
let t = (mu - tau) / sigma; | ||
// let n = Normal; | ||
// let phi_large = n.p_kernel(n, t, &NormalParams::new(0.0, 1.0).unwrap()); | ||
let phi_large = pdf(t); | ||
let phi_small = cdf(t); | ||
|
||
(mu - tau) * phi_large + sigma * phi_small | ||
} | ||
} | ||
|
||
// Abramowitz and Stegun (1964) formula 26.2.17 | ||
// precision: abs(err) < 7.5e-8 | ||
|
||
fn pdf(x: f64) -> f64 { | ||
((-x * x) / 2.0).exp() / (2.0 * std::f64::consts::PI) | ||
} | ||
fn cdf(x: f64) -> f64 { | ||
// constants | ||
const p: f64 = 0.2316419; | ||
const b1: f64 = 0.31938153; | ||
const b2: f64 = -0.356563782; | ||
const b3: f64 = 1.781477937; | ||
const b4: f64 = -1.821255978; | ||
const b5: f64 = 1.330274429; | ||
|
||
let t = 1.0 / (1.0 + p * x.abs()); | ||
let z = (-x * x / 2.0).exp() / (2.0 * std::f64::consts::PI).sqrt(); | ||
let y = 1.0 - z * ((((b5 * t + b4) * t + b3) * t + b2) * t + b1) * t; | ||
|
||
return if x > 0.0 { y } else { 1.0 - y }; | ||
} |
120 changes: 120 additions & 0 deletions
120
src/nonparametric/elliptical_process/bayesian_optimization/mod.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
pub mod expected_improvement; | ||
pub mod upper_confidence_bound; | ||
|
||
pub use expected_improvement::*; | ||
use ndarray::{Array, ArrayView1}; | ||
use opensrdk_kernel_method::{Periodic, RBF}; | ||
use rand::rngs::StdRng; | ||
pub use upper_confidence_bound::*; | ||
|
||
use crate::{nonparametric::GaussianProcessRegressor, NormalParams}; | ||
use optimize::NelderMeadBuilder; | ||
|
||
use super::BaseEllipticalProcessParams; | ||
|
||
pub trait AcquisitionFunctions { | ||
fn value(&self, theta: &NormalParams) -> f64; | ||
} | ||
|
||
struct Data { | ||
x_data: Vec<f64>, | ||
y_data: Vec<f64>, | ||
} | ||
|
||
#[test] | ||
fn test_main() { | ||
let mut n: usize = 0; | ||
let mut data = Data { | ||
x_data: vec![], | ||
y_data: vec![], | ||
}; | ||
|
||
loop { | ||
let mut rng = StdRng::from_seed([1; 32]); | ||
let mut x: f64 = rng.gen(); | ||
|
||
sampling(&data, &x); | ||
|
||
n += 1; | ||
|
||
if n == 20 { | ||
break; | ||
} | ||
} | ||
|
||
loop { | ||
let xs = maximize_ucb(&data, n); | ||
// let xs = maximize_ei(&data); | ||
|
||
sampling(&data, &xs); | ||
|
||
n += 1; | ||
} | ||
} | ||
|
||
fn objective(x: &f64) -> f64 { | ||
x + x ^ 2.0 | ||
} | ||
|
||
fn sampling(mut data: &Data, x: &f64) { | ||
let y = objective(x); | ||
data.x_data.push(x); | ||
data.y_data.push(y); | ||
} | ||
|
||
fn gp_regression(x: &Vec<f64>, y: &Vec<f64>, xs: f64) -> NormalParams { | ||
let kernel = RBF + Periodic; | ||
let theta = vec![1.0; kernel.params_len()]; | ||
let sigma = 1.0; | ||
|
||
let base_params = BaseEllipticalProcessParams::new(kernel, x, theta, sigma).unwrap(); | ||
let params_y = base_params.exact(&y).unwrap(); | ||
let mu = params_y.gp_predict(xs).unwrap().mu(); | ||
let sigma = params_y.gp_predict(xs).unwrap().sigma(); | ||
|
||
[mu, sigma] | ||
} | ||
|
||
fn maximize_ucb(data: &Data, n: usize) -> f64 { | ||
let func_to_minimize = |xs: ArrayView1<f64>| { | ||
let theta: NormalParams = gp_regression(&data.x_data, &data.y_data, xs); | ||
let ucb = UpperConfidenceBound { trial: n }; | ||
-ucb.value(&theta) | ||
}; | ||
|
||
let minimizer = NelderMeadBuilder::default() | ||
.xtol(1e-6f64) | ||
.ftol(1e-6f64) | ||
.maxiter(50000) | ||
.build() | ||
.unwrap(); | ||
|
||
// Set the starting guess | ||
let args = Array::from_vec(vec![3.0, -8.3]); | ||
let xs = minimizer.minimize(&func_to_minimize, args.view()); | ||
|
||
xs | ||
} | ||
|
||
fn maximize_ei(data: &Data) -> f64 { | ||
let func_to_minimize = |xs: ArrayView1<f64>| { | ||
let theta: NormalParams = gp_regression(&data.x_data, &data.y_data, xs); | ||
let ei = ExpectedImprovement { | ||
f_vec: &data.y_data, | ||
}; | ||
-ei.value(&theta) | ||
}; | ||
|
||
let minimizer = NelderMeadBuilder::default() | ||
.xtol(1e-6f64) | ||
.ftol(1e-6f64) | ||
.maxiter(50000) | ||
.build() | ||
.unwrap(); | ||
|
||
// Set the starting guess | ||
let args = Array::from_vec(vec![3.0, -8.3]); | ||
let xs = minimizer.minimize(&func_to_minimize, args.view()); | ||
|
||
xs | ||
} |
18 changes: 18 additions & 0 deletions
18
src/nonparametric/elliptical_process/bayesian_optimization/upper_confidence_bound.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
use num_integer::sqrt; | ||
|
||
use super::AcquisitionFunctions; | ||
|
||
pub struct UpperConfidenceBound { | ||
trial: f64, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 明らかにusizeで保持して計算の時に都度キャストすべき There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 最初適当にf64で設計してたのが残ってますね、そのように修正しておきます。 |
||
} | ||
|
||
impl AcquisitionFunctions for UpperConfidenceBound { | ||
fn value(&self, theta: &crate::NormalParams) -> f64 { | ||
let mu = theta.mu(); | ||
let sigma = theta.sigma(); | ||
let n = self.trial; | ||
let k = sqrt(n.ln() / n); | ||
|
||
mu + k * sigma | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
なんでそうなるー
いらない
ndarrayもいらない
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
外部ライブラリ使うことがそもそも不要
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
そうなるとここの最適化ってどうやればいいでしょう?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
まずはグリッドサーチって言ってなかったっけ
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
最適化よりむしろlinear algebraとndarrayを同居させるほうがよくない
それにどうしてもグリッドサーチいやだっていうならなんで直近使っててndarrayいらないcmaesを使わないんだ
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
グリッドサーチの件は多分初耳ですけど覚えときます。
cmaesも考えたんですけど、あれってブラックボックス関数の最適化用だと思っていたので、今回はブラックボックス関数じゃないからcmaesより速いのがあるかなって勝手に考えちゃいました