-
-
Notifications
You must be signed in to change notification settings - Fork 282
Open
Description
I have a simple struct for setting some parameters and creating an SVR model:
use linfa::prelude::*;
use linfa_svm::{Svm, SvmParams};
use ndarray::Array;
/// Support-vector regression model over `f64` data, backed by `linfa_svm`.
///
/// Hyperparameters are fixed at construction (see [`SVRModel::new`]); the fitted
/// model is stored after a call to [`SVRModel::train`].
struct SVRModel {
    /// Hyperparameters passed to `fit` when training.
    params: SvmParams<f64, f64>,
    /// The fitted model; `None` until `train` has completed.
    model: Option<Svm<f64, f64>>,
}

impl SVRModel {
    /// Creates an untrained model configured with `nu_eps(0.5, 0.01)` and
    /// `gaussian_kernel(95.0)`.
    fn new() -> Self {
        Self {
            params: Svm::<f64, _>::params()
                .nu_eps(0.5, 0.01)
                .gaussian_kernel(95.0),
            model: None,
        }
    }

    /// Fits the model on `x_train` (one slice per sample) against `y_train`
    /// (one target per sample) and stores the result in `self.model`.
    ///
    /// # Panics
    ///
    /// Panics if `x_train` and `y_train` differ in length, if the rows of
    /// `x_train` do not all have the same length, or if fitting fails.
    fn train(&mut self, x_train: &[&[f64]], y_train: &[f64]) {
        assert_eq!(
            x_train.len(),
            y_train.len(),
            "x_train and y_train must contain the same number of samples"
        );
        let n_samples = x_train.len();
        // The matrix width is the number of features per sample, NOT the total
        // number of flattened values.
        let n_features = x_train.first().map_or(0, |row| row.len());
        // Checked explicitly: ragged rows could otherwise happen to flatten to
        // exactly n_samples * n_features values and be silently misaligned.
        assert!(
            x_train.iter().all(|row| row.len() == n_features),
            "all rows of x_train must have the same number of features"
        );

        let records = Array::from_shape_vec([n_samples, n_features], x_train.concat())
            .expect("row lengths were validated, so the flat buffer fits the shape");
        let targets = Array::from_vec(y_train.to_vec());
        let dataset = DatasetBase::new(records, targets);

        self.model = Some(self.params.fit(&dataset).expect("SVR fitting failed"));
    }
}
The above works fine, but changing the element type to a generic `F: Float`, like this:
use linfa::prelude::*;
use linfa_svm::{Svm, SvmParams};
use ndarray::Array;
struct SVRModel<F: Float> {
params: SvmParams<F, F>,
model: Option<Svm<F, F>>,
}
impl<F> SVRModel<F>
where
F: linfa::Float,
{
fn new() -> Self {
Self {
params: Svm::<F, F>::params()
.nu_eps(F::from_f64(0.5).unwrap(), F::from_f64(0.01).unwrap())
.gaussian_kernel(F::from_f64(95.0).unwrap()),
model: None,
}
}
fn train(&mut self, x_train: &[&[F]], y_train: &[F]) {
let x_train = x_train
.iter()
.map(|x| x.to_vec())
.flatten()
.collect::<Vec<_>>();
let targets = y_train.iter().cloned().collect::<Vec<_>>();
let dataset = DatasetBase::new(
Array::from_shape_vec([targets.len(), x_train.len()], x_train).unwrap(),
Array::from_shape_vec([targets.len()], targets).unwrap(),
);
self.model = Some(self.params.fit(&dataset).unwrap());
}
}
fails to compile with:
the method `fit` exists for struct `SvmParams<F, F>`, but its trait bounds were not satisfied
the following trait bounds were not satisfied:
`SvmValidParams<F, F>: linfa::prelude::Fit<_, _, _>`
which is required by `SvmParams<F, F>: linfa::prelude::Fit<_, _, _>`rustc[Click for full compiler diagnostic](rust-analyzer-diagnostics-view:/diagnostic message [15]?15#file:///d%3A/muCapital/systems/src/arti_xg.rs)
hyperparams.rs(37, 1): doesn't satisfy `SvmValidParams<F, F>: linfa::prelude::Fit<_, _, _>`
hyperparams.rs(69, 1): doesn't satisfy `SvmParams<F, F>: linfa::prelude::Fit<_, _, _>`
How do I express the missing trait bounds, or otherwise get this to work with generic numeric types?
EDIT: Example rewritten as a minimal, self-contained sample.
Metadata
Metadata
Assignees
Labels
No labels