Skip to content

Commit

Permalink
Async Processor (#2482)
Browse files Browse the repository at this point in the history
* Async Processor

* WIP

* Update cubecl

* Fix CI

* fix
  • Loading branch information
nathanielsimard authored Nov 14, 2024
1 parent 1cd956e commit bff67dc
Show file tree
Hide file tree
Showing 15 changed files with 158 additions and 38 deletions.
25 changes: 13 additions & 12 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,8 @@ ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.2", features = ["alloc"] }

### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a1471a7ffa089ee2878bb8c140d09f66a2b2b664" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a1471a7ffa089ee2878bb8c140d09f66a2b2b664" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "99df09381aac4e2cd1354a744ec99bbd364bc9ea" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "99df09381aac4e2cd1354a744ec99bbd364bc9ea" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
Expand Down
1 change: 1 addition & 0 deletions crates/burn-train/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ ratatui = { workspace = true, optional = true, features = ["all-widgets", "cross
# Utilities
derive-new = { workspace = true }
serde = { workspace = true, features = ["std", "derive"] }
async-channel = { workspace = true }

[dev-dependencies]
burn-ndarray = { path = "../burn-ndarray", version = "0.16.0" }
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-train/src/checkpoint/strategy/metric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ mod tests {
},
TestBackend,
};
use std::rc::Rc;

use super::*;
use std::sync::Arc;

#[test]
fn always_keep_the_best_epoch() {
Expand All @@ -93,7 +93,7 @@ mod tests {
store.register_logger_train(InMemoryMetricLogger::default());
// Register the loss metric.
metrics.register_train_metric_numeric(LossMetric::<TestBackend>::new());
let store = Rc::new(EventStoreClient::new(store));
let store = Arc::new(EventStoreClient::new(store));
let mut processor = MinimalEventProcessor::new(metrics, store.clone());

// Two points for the first epoch. Mean 0.75
Expand Down
3 changes: 1 addition & 2 deletions crates/burn-train/src/learner/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use burn_core::module::Module;
use burn_core::optim::Optimizer;
use burn_core::tensor::backend::Backend;
use burn_core::tensor::Device;
use std::rc::Rc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

Expand All @@ -27,7 +26,7 @@ pub struct Learner<LC: LearnerComponents> {
pub(crate) interrupter: TrainingInterrupter,
pub(crate) early_stopping: Option<Box<dyn EarlyStoppingStrategy>>,
pub(crate) event_processor: LC::EventProcessor,
pub(crate) event_store: Rc<EventStoreClient>,
pub(crate) event_store: Arc<EventStoreClient>,
pub(crate) summary: Option<LearnerSummaryConfig>,
}

Expand Down
14 changes: 9 additions & 5 deletions crates/burn-train/src/learner/builder.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::collections::HashSet;
use std::path::{Path, PathBuf};
use std::rc::Rc;
use std::sync::Arc;

use super::Learner;
use crate::checkpoint::{
Expand All @@ -11,7 +11,7 @@ use crate::components::LearnerComponentsMarker;
use crate::learner::base::TrainingInterrupter;
use crate::learner::EarlyStoppingStrategy;
use crate::logger::{FileMetricLogger, MetricLogger};
use crate::metric::processor::{FullEventProcessor, Metrics};
use crate::metric::processor::{AsyncProcessor, FullEventProcessor, Metrics};
use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split};
use crate::metric::{Adaptor, LossMetric, Metric};
use crate::renderer::{default_renderer, MetricsRenderer};
Expand Down Expand Up @@ -302,7 +302,7 @@ where
AsyncCheckpointer<M::Record, B>,
AsyncCheckpointer<O::Record, B>,
AsyncCheckpointer<S::Record<B>, B>,
FullEventProcessor<T, V>,
AsyncProcessor<FullEventProcessor<T, V>>,
Box<dyn CheckpointingStrategy>,
>,
>
Expand All @@ -327,8 +327,12 @@ where
.register_logger_valid(FileMetricLogger::new(self.directory.join("valid")));
}

let event_store = Rc::new(EventStoreClient::new(self.event_store));
let event_processor = FullEventProcessor::new(self.metrics, renderer, event_store.clone());
let event_store = Arc::new(EventStoreClient::new(self.event_store));
let event_processor = AsyncProcessor::new(FullEventProcessor::new(
self.metrics,
renderer,
event_store.clone(),
));

let checkpointer = self.checkpointers.map(|(model, optim, scheduler)| {
LearnerCheckpointer::new(model, optim, scheduler, self.checkpointer_strategy)
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-train/src/learner/early_stopping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ impl MetricEarlyStoppingStrategy {

#[cfg(test)]
mod tests {
use std::rc::Rc;
use std::sync::Arc;

use crate::{
logger::InMemoryMetricLogger,
Expand Down Expand Up @@ -197,7 +197,7 @@ mod tests {
store.register_logger_train(InMemoryMetricLogger::default());
metrics.register_train_metric_numeric(LossMetric::<TestBackend>::new());

let store = Rc::new(EventStoreClient::new(store));
let store = Arc::new(EventStoreClient::new(store));
let mut processor = MinimalEventProcessor::new(metrics, store.clone());

let mut epoch = 1;
Expand Down
51 changes: 51 additions & 0 deletions crates/burn-train/src/metric/iteration.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
use super::state::FormatOptions;
use super::state::NumericMetricState;
use super::MetricEntry;
use super::MetricMetadata;
use crate::metric::{Metric, Numeric};

/// Training throughput metric: iterations per second, measured relative
/// to the first recorded update.
#[derive(Default)]
pub struct IterationSpeedMetric {
    // Numeric state holding the latest computed rate and its formatting.
    state: NumericMetricState,
    // Wall-clock reference taken on the first `update`; `None` until then
    // and again after `clear`.
    instant: Option<std::time::Instant>,
}

impl IterationSpeedMetric {
/// Create the metric.
pub fn new() -> Self {
Self::default()
}
}

impl Metric for IterationSpeedMetric {
    const NAME: &'static str = "Iteration Speed";

    type Input = ();

    /// Record one iteration and report the running speed in iter/sec.
    ///
    /// The first call only starts the timer and reports 0.0; later calls
    /// divide the iteration counter by the wall-clock time elapsed since
    /// that first call.
    fn update(&mut self, _: &Self::Input, metadata: &MetricMetadata) -> MetricEntry {
        let raw = if let Some(start) = self.instant {
            metadata.iteration as f64 / start.elapsed().as_secs_f64()
        } else {
            // Lazily start timing on the first update.
            self.instant = Some(std::time::Instant::now());
            0.0
        };

        let format = FormatOptions::new(Self::NAME).unit("iter/sec").precision(2);
        self.state.update(raw, 1, format)
    }

    /// Drop the timing reference so the next update restarts the measurement.
    fn clear(&mut self) {
        self.instant = None;
    }
}

impl Numeric for IterationSpeedMetric {
    /// Latest iterations-per-second value computed by `update`.
    fn value(&self) -> f64 {
        self.state.value()
    }
}
4 changes: 4 additions & 0 deletions crates/burn-train/src/metric/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ mod loss;
#[cfg(feature = "metrics")]
mod memory_use;

#[cfg(feature = "metrics")]
mod iteration;
#[cfg(feature = "metrics")]
mod top_k_acc;

Expand All @@ -29,6 +31,8 @@ pub use cpu_use::*;
#[cfg(feature = "metrics")]
pub use cuda::*;
pub use hamming::*;
#[cfg(feature = "metrics")]
pub use iteration::*;
pub use learning_rate::*;
pub use loss::*;
#[cfg(feature = "metrics")]
Expand Down
54 changes: 54 additions & 0 deletions crates/burn-train/src/metric/processor/async_wrapper.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
use super::{Event, EventProcessor};
use async_channel::{Receiver, Sender};

/// Wraps an [`EventProcessor`] so that events are handled on a dedicated
/// background thread instead of the caller's thread.
pub struct AsyncProcessor<P: EventProcessor> {
    // Sending half of the bounded channel feeding the worker thread.
    sender: Sender<Message<P>>,
}

/// Background worker owning the wrapped processor; drains messages sent by
/// the async wrapper until the channel is closed.
struct Worker<P: EventProcessor> {
    processor: P,
    rec: Receiver<Message<P>>,
}

impl<P: EventProcessor + 'static> Worker<P> {
    /// Spawn a thread that processes incoming messages until every sender
    /// has been dropped, which closes the channel.
    pub fn start(processor: P, rec: Receiver<Message<P>>) {
        let mut worker = Self { processor, rec };

        std::thread::spawn(move || loop {
            let msg = match worker.rec.recv_blocking() {
                Ok(msg) => msg,
                // Channel closed: all senders are gone, shut the worker down.
                Err(_) => break,
            };

            match msg {
                Message::Train(event) => worker.processor.process_train(event),
                Message::Valid(event) => worker.processor.process_valid(event),
            }
        });
    }
}

impl<P: EventProcessor + 'static> AsyncProcessor<P> {
    /// Create the processor and spawn its background worker thread.
    pub fn new(processor: P) -> Self {
        // Capacity of one: the caller can be at most one event ahead of
        // the worker before sends block.
        let (sender, receiver) = async_channel::bounded(1);

        Worker::start(processor, receiver);

        Self { sender }
    }
}

/// Events forwarded from the caller's thread to the worker thread, tagged
/// with the phase (training or validation) they belong to.
enum Message<P: EventProcessor> {
    Train(Event<P::ItemTrain>),
    Valid(Event<P::ItemValid>),
}

impl<P: EventProcessor> EventProcessor for AsyncProcessor<P> {
    type ItemTrain = P::ItemTrain;
    type ItemValid = P::ItemValid;

    /// Forward a training event to the worker thread, blocking while the
    /// bounded channel is full.
    fn process_train(&mut self, event: Event<Self::ItemTrain>) {
        // A send only fails when the receiver side is gone, i.e. the worker
        // thread terminated unexpectedly — state the invariant instead of a
        // bare unwrap so the panic message explains what went wrong.
        self.sender
            .send_blocking(Message::Train(event))
            .expect("the event processor worker thread should be running");
    }

    /// Forward a validation event to the worker thread, blocking while the
    /// bounded channel is full.
    fn process_valid(&mut self, event: Event<Self::ItemValid>) {
        self.sender
            .send_blocking(Message::Valid(event))
            .expect("the event processor worker thread should be running");
    }
}
6 changes: 3 additions & 3 deletions crates/burn-train/src/metric/processor/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ pub enum Event<T> {
}

/// Process events happening during training and validation.
pub trait EventProcessor {
pub trait EventProcessor: Send {
/// The training item.
type ItemTrain;
type ItemTrain: Send;
/// The validation item.
type ItemValid;
type ItemValid: Send;

/// Collect a training event.
fn process_train(&mut self, event: Event<Self::ItemTrain>);
Expand Down
Loading

0 comments on commit bff67dc

Please sign in to comment.