Skip to content

Commit

Permalink
Update cubecl
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard committed Nov 14, 2024
1 parent 34b761f commit 10d9fa2
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 17 deletions.
25 changes: 13 additions & 12 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,8 @@ ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.2", features = ["alloc"] }

### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a1471a7ffa089ee2878bb8c140d09f66a2b2b664" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a1471a7ffa089ee2878bb8c140d09f66a2b2b664" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "99df09381aac4e2cd1354a744ec99bbd364bc9ea" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "99df09381aac4e2cd1354a744ec99bbd364bc9ea" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
Expand Down
7 changes: 4 additions & 3 deletions examples/text-classification/src/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use burn::{
record::{CompactRecorder, Recorder},
tensor::backend::AutodiffBackend,
train::{
metric::{AccuracyMetric, CudaMetric, LearningRateMetric, LossMetric},
metric::{AccuracyMetric, CudaMetric, IterationSpeedMetric, LearningRateMetric, LossMetric},
LearnerBuilder,
},
};
Expand Down Expand Up @@ -92,10 +92,11 @@ pub fn train<B: AutodiffBackend, D: TextClassificationDataset + 'static>(
let learner = LearnerBuilder::new(artifact_dir)
.metric_train(CudaMetric::new())
.metric_valid(CudaMetric::new())
.metric_train_numeric(AccuracyMetric::new())
.metric_valid_numeric(AccuracyMetric::new())
.metric_train(IterationSpeedMetric::new())
.metric_train_numeric(LossMetric::new())
.metric_valid_numeric(LossMetric::new())
.metric_train_numeric(AccuracyMetric::new())
.metric_valid_numeric(AccuracyMetric::new())
.metric_train_numeric(LearningRateMetric::new())
.with_file_checkpointer(CompactRecorder::new())
.devices(devices)
Expand Down

0 comments on commit 10d9fa2

Please sign in to comment.