diff --git a/Cargo.lock b/Cargo.lock
index 236dfe1326..ed186cac6f 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -881,6 +881,7 @@ dependencies = [
 name = "burn-train"
 version = "0.16.0"
 dependencies = [
+ "async-channel",
  "burn-core",
  "burn-ndarray",
  "derive-new 0.7.0",
@@ -1553,7 +1554,7 @@ dependencies = [
 [[package]]
 name = "cubecl"
 version = "0.4.0"
-source = "git+https://github.com/tracel-ai/cubecl?rev=a1471a7ffa089ee2878bb8c140d09f66a2b2b664#a1471a7ffa089ee2878bb8c140d09f66a2b2b664"
+source = "git+https://github.com/tracel-ai/cubecl?rev=99df09381aac4e2cd1354a744ec99bbd364bc9ea#99df09381aac4e2cd1354a744ec99bbd364bc9ea"
 dependencies = [
  "cubecl-core",
  "cubecl-cuda",
@@ -1584,7 +1585,7 @@ dependencies = [
 [[package]]
 name = "cubecl-common"
 version = "0.4.0"
-source = "git+https://github.com/tracel-ai/cubecl?rev=a1471a7ffa089ee2878bb8c140d09f66a2b2b664#a1471a7ffa089ee2878bb8c140d09f66a2b2b664"
+source = "git+https://github.com/tracel-ai/cubecl?rev=99df09381aac4e2cd1354a744ec99bbd364bc9ea#99df09381aac4e2cd1354a744ec99bbd364bc9ea"
 dependencies = [
  "derive-new 0.6.0",
  "embassy-futures",
@@ -1601,7 +1602,7 @@ dependencies = [
 [[package]]
 name = "cubecl-core"
 version = "0.4.0"
-source = "git+https://github.com/tracel-ai/cubecl?rev=a1471a7ffa089ee2878bb8c140d09f66a2b2b664#a1471a7ffa089ee2878bb8c140d09f66a2b2b664"
+source = "git+https://github.com/tracel-ai/cubecl?rev=99df09381aac4e2cd1354a744ec99bbd364bc9ea#99df09381aac4e2cd1354a744ec99bbd364bc9ea"
 dependencies = [
  "bytemuck",
  "cubecl-common 0.4.0",
@@ -1619,7 +1620,7 @@ dependencies = [
 [[package]]
 name = "cubecl-cpp"
 version = "0.4.0"
-source = "git+https://github.com/tracel-ai/cubecl?rev=a1471a7ffa089ee2878bb8c140d09f66a2b2b664#a1471a7ffa089ee2878bb8c140d09f66a2b2b664"
+source = "git+https://github.com/tracel-ai/cubecl?rev=99df09381aac4e2cd1354a744ec99bbd364bc9ea#99df09381aac4e2cd1354a744ec99bbd364bc9ea"
 dependencies = [
  "bytemuck",
  "cubecl-common 0.4.0",
@@ -1633,7 +1634,7 @@ dependencies = [
 [[package]]
 name = "cubecl-cuda"
 version = "0.4.0"
-source = "git+https://github.com/tracel-ai/cubecl?rev=a1471a7ffa089ee2878bb8c140d09f66a2b2b664#a1471a7ffa089ee2878bb8c140d09f66a2b2b664"
+source = "git+https://github.com/tracel-ai/cubecl?rev=99df09381aac4e2cd1354a744ec99bbd364bc9ea#99df09381aac4e2cd1354a744ec99bbd364bc9ea"
 dependencies = [
  "bytemuck",
  "cubecl-common 0.4.0",
@@ -1649,7 +1650,7 @@ dependencies = [
 [[package]]
 name = "cubecl-hip"
 version = "0.4.0"
-source = "git+https://github.com/tracel-ai/cubecl?rev=a1471a7ffa089ee2878bb8c140d09f66a2b2b664#a1471a7ffa089ee2878bb8c140d09f66a2b2b664"
+source = "git+https://github.com/tracel-ai/cubecl?rev=99df09381aac4e2cd1354a744ec99bbd364bc9ea#99df09381aac4e2cd1354a744ec99bbd364bc9ea"
 dependencies = [
  "bytemuck",
  "cubecl-common 0.4.0",
@@ -1674,7 +1675,7 @@ dependencies = [
 [[package]]
 name = "cubecl-linalg"
 version = "0.4.0"
-source = "git+https://github.com/tracel-ai/cubecl?rev=a1471a7ffa089ee2878bb8c140d09f66a2b2b664#a1471a7ffa089ee2878bb8c140d09f66a2b2b664"
+source = "git+https://github.com/tracel-ai/cubecl?rev=99df09381aac4e2cd1354a744ec99bbd364bc9ea#99df09381aac4e2cd1354a744ec99bbd364bc9ea"
 dependencies = [
  "bytemuck",
  "cubecl-core",
@@ -1685,7 +1686,7 @@ dependencies = [
 [[package]]
 name = "cubecl-macros"
 version = "0.4.0"
-source = "git+https://github.com/tracel-ai/cubecl?rev=a1471a7ffa089ee2878bb8c140d09f66a2b2b664#a1471a7ffa089ee2878bb8c140d09f66a2b2b664"
+source = "git+https://github.com/tracel-ai/cubecl?rev=99df09381aac4e2cd1354a744ec99bbd364bc9ea#99df09381aac4e2cd1354a744ec99bbd364bc9ea"
 dependencies = [
  "cubecl-common 0.4.0",
  "darling",
@@ -1700,7 +1701,7 @@ dependencies = [
 [[package]]
 name = "cubecl-opt"
 version = "0.4.0"
-source = "git+https://github.com/tracel-ai/cubecl?rev=a1471a7ffa089ee2878bb8c140d09f66a2b2b664#a1471a7ffa089ee2878bb8c140d09f66a2b2b664"
+source = "git+https://github.com/tracel-ai/cubecl?rev=99df09381aac4e2cd1354a744ec99bbd364bc9ea#99df09381aac4e2cd1354a744ec99bbd364bc9ea"
 dependencies = [
  "cubecl-common 0.4.0",
  "cubecl-core",
@@ -1737,7 +1738,7 @@ dependencies = [
 [[package]]
 name = "cubecl-runtime"
 version = "0.4.0"
-source = "git+https://github.com/tracel-ai/cubecl?rev=a1471a7ffa089ee2878bb8c140d09f66a2b2b664#a1471a7ffa089ee2878bb8c140d09f66a2b2b664"
+source = "git+https://github.com/tracel-ai/cubecl?rev=99df09381aac4e2cd1354a744ec99bbd364bc9ea#99df09381aac4e2cd1354a744ec99bbd364bc9ea"
 dependencies = [
  "async-channel",
  "async-lock",
@@ -1758,7 +1759,7 @@ dependencies = [
 [[package]]
 name = "cubecl-spirv"
 version = "0.4.0"
-source = "git+https://github.com/tracel-ai/cubecl?rev=a1471a7ffa089ee2878bb8c140d09f66a2b2b664#a1471a7ffa089ee2878bb8c140d09f66a2b2b664"
+source = "git+https://github.com/tracel-ai/cubecl?rev=99df09381aac4e2cd1354a744ec99bbd364bc9ea#99df09381aac4e2cd1354a744ec99bbd364bc9ea"
 dependencies = [
  "cubecl-common 0.4.0",
  "cubecl-core",
@@ -1772,7 +1773,7 @@ dependencies = [
 [[package]]
 name = "cubecl-wgpu"
 version = "0.4.0"
-source = "git+https://github.com/tracel-ai/cubecl?rev=a1471a7ffa089ee2878bb8c140d09f66a2b2b664#a1471a7ffa089ee2878bb8c140d09f66a2b2b664"
+source = "git+https://github.com/tracel-ai/cubecl?rev=99df09381aac4e2cd1354a744ec99bbd364bc9ea#99df09381aac4e2cd1354a744ec99bbd364bc9ea"
 dependencies = [
  "ash",
  "async-channel",
diff --git a/Cargo.toml b/Cargo.toml
index 1bd7ab429c..c092a8177b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -154,8 +154,8 @@ ahash = { version = "0.8.11", default-features = false }
 portable-atomic-util = { version = "0.2.2", features = ["alloc"] }
 
 ### For the main burn branch. ###
-cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a1471a7ffa089ee2878bb8c140d09f66a2b2b664" }
-cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a1471a7ffa089ee2878bb8c140d09f66a2b2b664" }
+cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "99df09381aac4e2cd1354a744ec99bbd364bc9ea" }
+cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "99df09381aac4e2cd1354a744ec99bbd364bc9ea" }
 ### For local development. ###
 # cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
 # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
diff --git a/crates/burn-train/Cargo.toml b/crates/burn-train/Cargo.toml
index 88ef13eb75..04ddb66f2c 100644
--- a/crates/burn-train/Cargo.toml
+++ b/crates/burn-train/Cargo.toml
@@ -39,6 +39,7 @@ ratatui = { workspace = true, optional = true, features = ["all-widgets", "cross
 # Utilities
 derive-new = { workspace = true }
 serde = { workspace = true, features = ["std", "derive"] }
+async-channel = { workspace = true }
 
 [dev-dependencies]
 burn-ndarray = { path = "../burn-ndarray", version = "0.16.0" }
diff --git a/crates/burn-train/src/checkpoint/strategy/metric.rs b/crates/burn-train/src/checkpoint/strategy/metric.rs
index 7a1cd6085e..4efcf14028 100644
--- a/crates/burn-train/src/checkpoint/strategy/metric.rs
+++ b/crates/burn-train/src/checkpoint/strategy/metric.rs
@@ -76,9 +76,9 @@ mod tests {
         },
         TestBackend,
     };
-    use std::rc::Rc;
 
     use super::*;
+    use std::sync::Arc;
 
     #[test]
     fn always_keep_the_best_epoch() {
@@ -93,7 +93,7 @@ mod tests {
         store.register_logger_train(InMemoryMetricLogger::default());
         // Register the loss metric.
         metrics.register_train_metric_numeric(LossMetric::<TestBackend>::new());
-        let store = Rc::new(EventStoreClient::new(store));
+        let store = Arc::new(EventStoreClient::new(store));
         let mut processor = MinimalEventProcessor::new(metrics, store.clone());
 
         // Two points for the first epoch. Mean 0.75
diff --git a/crates/burn-train/src/learner/base.rs b/crates/burn-train/src/learner/base.rs
index bd6128681e..0534b0f4d0 100644
--- a/crates/burn-train/src/learner/base.rs
+++ b/crates/burn-train/src/learner/base.rs
@@ -8,7 +8,6 @@ use burn_core::module::Module;
 use burn_core::optim::Optimizer;
 use burn_core::tensor::backend::Backend;
 use burn_core::tensor::Device;
-use std::rc::Rc;
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
 
@@ -27,7 +26,7 @@ pub struct Learner<LC: LearnerComponents> {
     pub(crate) interrupter: TrainingInterrupter,
     pub(crate) early_stopping: Option<Box<dyn EarlyStoppingStrategy>>,
     pub(crate) event_processor: LC::EventProcessor,
-    pub(crate) event_store: Rc<EventStoreClient>,
+    pub(crate) event_store: Arc<EventStoreClient>,
     pub(crate) summary: Option<LearnerSummaryConfig>,
 }
 
diff --git a/crates/burn-train/src/learner/builder.rs b/crates/burn-train/src/learner/builder.rs
index 2298a41ee7..f6f90c4c5a 100644
--- a/crates/burn-train/src/learner/builder.rs
+++ b/crates/burn-train/src/learner/builder.rs
@@ -1,6 +1,6 @@
 use std::collections::HashSet;
 use std::path::{Path, PathBuf};
-use std::rc::Rc;
+use std::sync::Arc;
 
 use super::Learner;
 use crate::checkpoint::{
@@ -11,7 +11,7 @@ use crate::components::LearnerComponentsMarker;
 use crate::learner::base::TrainingInterrupter;
 use crate::learner::EarlyStoppingStrategy;
 use crate::logger::{FileMetricLogger, MetricLogger};
-use crate::metric::processor::{FullEventProcessor, Metrics};
+use crate::metric::processor::{AsyncProcessor, FullEventProcessor, Metrics};
 use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split};
 use crate::metric::{Adaptor, LossMetric, Metric};
 use crate::renderer::{default_renderer, MetricsRenderer};
@@ -302,7 +302,7 @@ where
             AsyncCheckpointer<M::Record, B>,
             AsyncCheckpointer<O::Record, B>,
             AsyncCheckpointer<S::Record<B>, B>,
-            FullEventProcessor<T, V>,
+            AsyncProcessor<FullEventProcessor<T, V>>,
             Box<dyn CheckpointingStrategy>,
         >,
     >
@@ -327,8 +327,12 @@ where
                 .register_logger_valid(FileMetricLogger::new(self.directory.join("valid")));
         }
 
-        let event_store = Rc::new(EventStoreClient::new(self.event_store));
-        let event_processor = FullEventProcessor::new(self.metrics, renderer, event_store.clone());
+        let event_store = Arc::new(EventStoreClient::new(self.event_store));
+        let event_processor = AsyncProcessor::new(FullEventProcessor::new(
+            self.metrics,
+            renderer,
+            event_store.clone(),
+        ));
 
         let checkpointer = self.checkpointers.map(|(model, optim, scheduler)| {
             LearnerCheckpointer::new(model, optim, scheduler, self.checkpointer_strategy)
diff --git a/crates/burn-train/src/learner/early_stopping.rs b/crates/burn-train/src/learner/early_stopping.rs
index db3dc478a9..c66ea9eedd 100644
--- a/crates/burn-train/src/learner/early_stopping.rs
+++ b/crates/burn-train/src/learner/early_stopping.rs
@@ -113,7 +113,7 @@ impl MetricEarlyStoppingStrategy {
 
 #[cfg(test)]
 mod tests {
-    use std::rc::Rc;
+    use std::sync::Arc;
 
     use crate::{
         logger::InMemoryMetricLogger,
@@ -197,7 +197,7 @@ mod tests {
         store.register_logger_train(InMemoryMetricLogger::default());
         metrics.register_train_metric_numeric(LossMetric::<TestBackend>::new());
 
-        let store = Rc::new(EventStoreClient::new(store));
+        let store = Arc::new(EventStoreClient::new(store));
         let mut processor = MinimalEventProcessor::new(metrics, store.clone());
 
         let mut epoch = 1;
diff --git a/crates/burn-train/src/metric/iteration.rs b/crates/burn-train/src/metric/iteration.rs
new file mode 100644
index 0000000000..f053b8ee47
--- /dev/null
+++ b/crates/burn-train/src/metric/iteration.rs
@@ -0,0 +1,51 @@
+use super::state::FormatOptions;
+use super::state::NumericMetricState;
+use super::MetricEntry;
+use super::MetricMetadata;
+use crate::metric::{Metric, Numeric};
+
+/// The iteration speed metric.
+#[derive(Default)]
+pub struct IterationSpeedMetric {
+    state: NumericMetricState,
+    instant: Option<std::time::Instant>,
+}
+
+impl IterationSpeedMetric {
+    /// Create the metric.
+    pub fn new() -> Self {
+        Self::default()
+    }
+}
+
+impl Metric for IterationSpeedMetric {
+    const NAME: &'static str = "Iteration Speed";
+
+    type Input = ();
+
+    fn update(&mut self, _: &Self::Input, metadata: &MetricMetadata) -> MetricEntry {
+        let raw = match self.instant {
+            Some(val) => metadata.iteration as f64 / val.elapsed().as_secs_f64(),
+            None => {
+                self.instant = Some(std::time::Instant::now());
+                0.0
+            }
+        };
+
+        self.state.update(
+            raw,
+            1,
+            FormatOptions::new(Self::NAME).unit("iter/sec").precision(2),
+        )
+    }
+
+    fn clear(&mut self) {
+        self.instant = None;
+    }
+}
+
+impl Numeric for IterationSpeedMetric {
+    fn value(&self) -> f64 {
+        self.state.value()
+    }
+}
diff --git a/crates/burn-train/src/metric/mod.rs b/crates/burn-train/src/metric/mod.rs
index ebbd35132b..2187734807 100644
--- a/crates/burn-train/src/metric/mod.rs
+++ b/crates/burn-train/src/metric/mod.rs
@@ -16,6 +16,8 @@ mod loss;
 #[cfg(feature = "metrics")]
 mod memory_use;
 
+#[cfg(feature = "metrics")]
+mod iteration;
 #[cfg(feature = "metrics")]
 mod top_k_acc;
 
@@ -29,6 +31,8 @@ pub use cpu_use::*;
 #[cfg(feature = "metrics")]
 pub use cuda::*;
 pub use hamming::*;
+#[cfg(feature = "metrics")]
+pub use iteration::*;
 pub use learning_rate::*;
 pub use loss::*;
 #[cfg(feature = "metrics")]
diff --git a/crates/burn-train/src/metric/processor/async_wrapper.rs b/crates/burn-train/src/metric/processor/async_wrapper.rs
new file mode 100644
index 0000000000..82c358fc73
--- /dev/null
+++ b/crates/burn-train/src/metric/processor/async_wrapper.rs
@@ -0,0 +1,54 @@
+use super::{Event, EventProcessor};
+use async_channel::{Receiver, Sender};
+
+pub struct AsyncProcessor<P: EventProcessor> {
+    sender: Sender<Message<P>>,
+}
+
+struct Worker<P: EventProcessor> {
+    processor: P,
+    rec: Receiver<Message<P>>,
+}
+
+impl<P: EventProcessor + 'static> Worker<P> {
+    pub fn start(processor: P, rec: Receiver<Message<P>>) {
+        let mut worker = Self { processor, rec };
+
+        std::thread::spawn(move || {
+            while let Ok(msg) = worker.rec.recv_blocking() {
+                match msg {
+                    Message::Train(event) => worker.processor.process_train(event),
+                    Message::Valid(event) => worker.processor.process_valid(event),
+                }
+            }
+        });
+    }
+}
+
+impl<P: EventProcessor + 'static> AsyncProcessor<P> {
+    pub fn new(processor: P) -> Self {
+        let (sender, rec) = async_channel::bounded(1);
+
+        Worker::start(processor, rec);
+
+        Self { sender }
+    }
+}
+
+enum Message<P: EventProcessor> {
+    Train(Event<P::ItemTrain>),
+    Valid(Event<P::ItemValid>),
+}
+
+impl<P: EventProcessor> EventProcessor for AsyncProcessor<P> {
+    type ItemTrain = P::ItemTrain;
+    type ItemValid = P::ItemValid;
+
+    fn process_train(&mut self, event: Event<Self::ItemTrain>) {
+        self.sender.send_blocking(Message::Train(event)).unwrap();
+    }
+
+    fn process_valid(&mut self, event: Event<Self::ItemValid>) {
+        self.sender.send_blocking(Message::Valid(event)).unwrap();
+    }
+}
diff --git a/crates/burn-train/src/metric/processor/base.rs b/crates/burn-train/src/metric/processor/base.rs
index 9093d26457..587548d299 100644
--- a/crates/burn-train/src/metric/processor/base.rs
+++ b/crates/burn-train/src/metric/processor/base.rs
@@ -10,11 +10,11 @@ pub enum Event<T> {
 }
 
 /// Process events happening during training and validation.
-pub trait EventProcessor {
+pub trait EventProcessor: Send {
     /// The training item.
-    type ItemTrain;
+    type ItemTrain: Send;
     /// The validation item.
-    type ItemValid;
+    type ItemValid: Send;
 
     /// Collect a training event.
     fn process_train(&mut self, event: Event<Self::ItemTrain>);
diff --git a/crates/burn-train/src/metric/processor/full.rs b/crates/burn-train/src/metric/processor/full.rs
index 9f76588c88..a9e77b9e33 100644
--- a/crates/burn-train/src/metric/processor/full.rs
+++ b/crates/burn-train/src/metric/processor/full.rs
@@ -1,7 +1,7 @@
 use super::{Event, EventProcessor, Metrics};
 use crate::metric::store::EventStoreClient;
 use crate::renderer::{MetricState, MetricsRenderer};
-use std::rc::Rc;
+use std::sync::Arc;
 
 /// An [event processor](EventProcessor) that handles:
 ///   - Computing and storing metrics in an [event store](crate::metric::store::EventStore).
@@ -9,14 +9,14 @@ use std::rc::Rc;
 pub struct FullEventProcessor<T, V> {
     metrics: Metrics<T, V>,
     renderer: Box<dyn MetricsRenderer>,
-    store: Rc<EventStoreClient>,
+    store: Arc<EventStoreClient>,
 }
 
 impl<T, V> FullEventProcessor<T, V> {
     pub(crate) fn new(
         metrics: Metrics<T, V>,
         renderer: Box<dyn MetricsRenderer>,
-        store: Rc<EventStoreClient>,
+        store: Arc<EventStoreClient>,
     ) -> Self {
         Self {
             metrics,
@@ -26,7 +26,7 @@ impl<T, V> FullEventProcessor<T, V> {
     }
 }
 
-impl<T, V> EventProcessor for FullEventProcessor<T, V> {
+impl<T: Send, V: Send> EventProcessor for FullEventProcessor<T, V> {
     type ItemTrain = T;
     type ItemValid = V;
 
diff --git a/crates/burn-train/src/metric/processor/minimal.rs b/crates/burn-train/src/metric/processor/minimal.rs
index e95d2e8b46..3e5d34ef44 100644
--- a/crates/burn-train/src/metric/processor/minimal.rs
+++ b/crates/burn-train/src/metric/processor/minimal.rs
@@ -1,16 +1,16 @@
 use super::{Event, EventProcessor, Metrics};
 use crate::metric::store::EventStoreClient;
-use std::rc::Rc;
+use std::sync::Arc;
 
 /// An [event processor](EventProcessor) that handles:
 ///   - Computing and storing metrics in an [event store](crate::metric::store::EventStore).
 #[derive(new)]
 pub(crate) struct MinimalEventProcessor<T, V> {
     metrics: Metrics<T, V>,
-    store: Rc<EventStoreClient>,
+    store: Arc<EventStoreClient>,
 }
 
-impl<T, V> EventProcessor for MinimalEventProcessor<T, V> {
+impl<T: Send, V: Send> EventProcessor for MinimalEventProcessor<T, V> {
     type ItemTrain = T;
     type ItemValid = V;
 
diff --git a/crates/burn-train/src/metric/processor/mod.rs b/crates/burn-train/src/metric/processor/mod.rs
index a025106e95..1a5c5a8957 100644
--- a/crates/burn-train/src/metric/processor/mod.rs
+++ b/crates/burn-train/src/metric/processor/mod.rs
@@ -1,3 +1,4 @@
+mod async_wrapper;
 mod base;
 mod full;
 mod metrics;
@@ -10,6 +11,8 @@ pub(crate) use metrics::*;
 #[cfg(test)]
 pub(crate) use minimal::*;
 
+pub use async_wrapper::AsyncProcessor;
+
 #[cfg(test)]
 pub(crate) mod test_utils {
     use crate::metric::{
diff --git a/examples/text-classification/src/training.rs b/examples/text-classification/src/training.rs
index fa07f09158..7c139ea06a 100644
--- a/examples/text-classification/src/training.rs
+++ b/examples/text-classification/src/training.rs
@@ -18,7 +18,9 @@ use burn::{
     record::{CompactRecorder, Recorder},
     tensor::backend::AutodiffBackend,
     train::{
-        metric::{AccuracyMetric, CudaMetric, LearningRateMetric, LossMetric},
+        metric::{
+            AccuracyMetric, CudaMetric, IterationSpeedMetric, LearningRateMetric, LossMetric,
+        },
         LearnerBuilder,
     },
 };
@@ -92,10 +94,11 @@ pub fn train<B: AutodiffBackend, D: TextClassificationDataset + 'static>(
     let learner = LearnerBuilder::new(artifact_dir)
         .metric_train(CudaMetric::new())
         .metric_valid(CudaMetric::new())
-        .metric_train_numeric(AccuracyMetric::new())
-        .metric_valid_numeric(AccuracyMetric::new())
+        .metric_train(IterationSpeedMetric::new())
         .metric_train_numeric(LossMetric::new())
         .metric_valid_numeric(LossMetric::new())
+        .metric_train_numeric(AccuracyMetric::new())
+        .metric_valid_numeric(AccuracyMetric::new())
         .metric_train_numeric(LearningRateMetric::new())
         .with_file_checkpointer(CompactRecorder::new())
         .devices(devices)
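Note on the async_wrapper.rs design above: the event processor is moved onto a dedicated worker thread and fed through a bounded async-channel, so the training loop only pays for a blocking send, and back-pressure kicks in once a single message is already queued. The standalone sketch below reproduces that pattern outside of burn-train, under the assumption of a toy Processor trait; the names Processor, ConsoleProcessor, and AsyncWrapper are illustrative only, while bounded, send_blocking, and recv_blocking are the real async-channel 2.x calls used by the diff.

// Standalone illustration of the worker-thread pattern used by AsyncProcessor.
// Toy trait and types; only async_channel is real (add `async-channel = "2"`).
use async_channel::bounded;

trait Processor: Send + 'static {
    type Item: Send;
    fn process(&mut self, item: Self::Item);
}

struct ConsoleProcessor;

impl Processor for ConsoleProcessor {
    type Item = String;
    fn process(&mut self, item: String) {
        println!("processed: {item}");
    }
}

struct AsyncWrapper<P: Processor> {
    sender: async_channel::Sender<P::Item>,
}

impl<P: Processor> AsyncWrapper<P> {
    fn new(mut processor: P) -> Self {
        // Capacity 1 gives back-pressure: the caller only blocks when the
        // worker is still busy with the previous item.
        let (sender, rec) = bounded::<P::Item>(1);
        std::thread::spawn(move || {
            // The loop ends when the sender side is dropped and the channel closes.
            while let Ok(item) = rec.recv_blocking() {
                processor.process(item);
            }
        });
        Self { sender }
    }

    fn process(&self, item: P::Item) {
        // Mirrors AsyncProcessor::process_train/process_valid: a blocking send
        // onto the bounded channel.
        self.sender.send_blocking(item).unwrap();
    }
}

fn main() {
    let wrapper = AsyncWrapper::new(ConsoleProcessor);
    for i in 0..3 {
        wrapper.process(format!("event {i}"));
    }
    // Give the worker a moment before the process exits; a production
    // implementation would join the worker instead.
    std::thread::sleep(std::time::Duration::from_millis(50));
}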