-
Notifications
You must be signed in to change notification settings - Fork 442
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Async Processor * WIP * Update cubecl * Fix CI * fix
- Loading branch information
1 parent
1cd956e
commit bff67dc
Showing
15 changed files
with
158 additions
and
38 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
use super::state::FormatOptions; | ||
use super::state::NumericMetricState; | ||
use super::MetricEntry; | ||
use super::MetricMetadata; | ||
use crate::metric::{Metric, Numeric}; | ||
|
||
/// The loss metric. | ||
#[derive(Default)] | ||
pub struct IterationSpeedMetric { | ||
state: NumericMetricState, | ||
instant: Option<std::time::Instant>, | ||
} | ||
|
||
impl IterationSpeedMetric { | ||
/// Create the metric. | ||
pub fn new() -> Self { | ||
Self::default() | ||
} | ||
} | ||
|
||
impl Metric for IterationSpeedMetric { | ||
const NAME: &'static str = "Iteration Speed"; | ||
|
||
type Input = (); | ||
|
||
fn update(&mut self, _: &Self::Input, metadata: &MetricMetadata) -> MetricEntry { | ||
let raw = match self.instant { | ||
Some(val) => metadata.iteration as f64 / val.elapsed().as_secs_f64(), | ||
None => { | ||
self.instant = Some(std::time::Instant::now()); | ||
0.0 | ||
} | ||
}; | ||
|
||
self.state.update( | ||
raw, | ||
1, | ||
FormatOptions::new(Self::NAME).unit("iter/sec").precision(2), | ||
) | ||
} | ||
|
||
fn clear(&mut self) { | ||
self.instant = None; | ||
} | ||
} | ||
|
||
impl Numeric for IterationSpeedMetric { | ||
fn value(&self) -> f64 { | ||
self.state.value() | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
use super::{Event, EventProcessor}; | ||
use async_channel::{Receiver, Sender}; | ||
|
||
pub struct AsyncProcessor<P: EventProcessor> { | ||
sender: Sender<Message<P>>, | ||
} | ||
|
||
struct Worker<P: EventProcessor> { | ||
processor: P, | ||
rec: Receiver<Message<P>>, | ||
} | ||
|
||
impl<P: EventProcessor + 'static> Worker<P> { | ||
pub fn start(processor: P, rec: Receiver<Message<P>>) { | ||
let mut worker = Self { processor, rec }; | ||
|
||
std::thread::spawn(move || { | ||
while let Ok(msg) = worker.rec.recv_blocking() { | ||
match msg { | ||
Message::Train(event) => worker.processor.process_train(event), | ||
Message::Valid(event) => worker.processor.process_valid(event), | ||
} | ||
} | ||
}); | ||
} | ||
} | ||
|
||
impl<P: EventProcessor + 'static> AsyncProcessor<P> { | ||
pub fn new(processor: P) -> Self { | ||
let (sender, rec) = async_channel::bounded(1); | ||
|
||
Worker::start(processor, rec); | ||
|
||
Self { sender } | ||
} | ||
} | ||
|
||
enum Message<P: EventProcessor> { | ||
Train(Event<P::ItemTrain>), | ||
Valid(Event<P::ItemValid>), | ||
} | ||
|
||
impl<P: EventProcessor> EventProcessor for AsyncProcessor<P> { | ||
type ItemTrain = P::ItemTrain; | ||
type ItemValid = P::ItemValid; | ||
|
||
fn process_train(&mut self, event: Event<Self::ItemTrain>) { | ||
self.sender.send_blocking(Message::Train(event)).unwrap(); | ||
} | ||
|
||
fn process_valid(&mut self, event: Event<Self::ItemValid>) { | ||
self.sender.send_blocking(Message::Valid(event)).unwrap(); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.