I am attempting to implement Dice loss with burn. However, the following code leads to an error inside libtorch, the one given in the title of this discussion. Also, when I remove …
This may be an issue in my code, but I have not been able to find the cause of this error.

```rust
use burn::backend::Autodiff;
use burn::nn::conv::Conv2dConfig;
use burn::nn::PaddingConfig2d;
use burn::tensor::backend::AutodiffBackend;
use burn::tensor::Tensor;

fn main() {
    type MyBackend = burn::backend::LibTorch;
    type MyAutodiffBackend = Autodiff<MyBackend>;
    let device = burn::backend::libtorch::LibTorchDevice::Cpu;
    test::<MyAutodiffBackend>(&device);
}

fn test<B: AutodiffBackend>(device: &B::Device) {
    // 16 input channels -> 1 output channel, 3x3 kernel, "same" padding.
    let conv2d = Conv2dConfig::new([16, 1], [3, 3])
        .with_padding(PaddingConfig2d::Same)
        .init(device);
    let x: Tensor<B, 4> = Tensor::zeros([1, 16, 512, 512], device);
    let x: Tensor<B, 4> = conv2d.forward(x);
    // Flatten predictions and targets to 1-D for the Dice computation.
    let x: Tensor<B, 1> = x.flatten(0, 3);
    let targets: Tensor<B, 4> = Tensor::zeros([1, 1, 512, 512], device);
    let targets: Tensor<B, 1> = targets.flatten(0, 3);
    // Smoothed Dice coefficient: (2 * sum(x * t) + 1) / (sum(x) + sum(t) + 1).
    let intersection: Tensor<B, 1> = (x.clone() * targets.clone()).sum();
    let dice: Tensor<B, 1> = (intersection * 2.0 + 1) / ((x.clone()).sum() + (targets.clone()).sum() + 1);
    dice.backward();
}
```

Here is the message emitted on panic.
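To cross-check what the burn code should compute, here is a minimal plain-Rust sketch of the same smoothed Dice coefficient (smoothing term 1, as in the snippet above), plus its analytic gradient with respect to each prediction. The names `dice_coefficient`, `dice_loss` and `dice_grad` are hypothetical, and `dice_loss` assumes the common `1 - Dice` convention, whereas the snippet above calls `backward()` on the coefficient itself.

```rust
/// Smoothed Dice coefficient, matching the burn snippet:
/// (2 * sum(x * t) + 1) / (sum(x) + sum(t) + 1).
fn dice_coefficient(x: &[f32], targets: &[f32]) -> f32 {
    assert_eq!(x.len(), targets.len(), "inputs must have the same length");
    let smooth = 1.0;
    // Soft intersection: element-wise product, then sum.
    let intersection: f32 = x.iter().zip(targets).map(|(a, b)| a * b).sum();
    let sum_x: f32 = x.iter().sum();
    let sum_t: f32 = targets.iter().sum();
    (2.0 * intersection + smooth) / (sum_x + sum_t + smooth)
}

/// Dice loss in the common `1 - Dice` form (an assumption, not the
/// thread's code, which backpropagates the coefficient directly).
fn dice_loss(x: &[f32], targets: &[f32]) -> f32 {
    1.0 - dice_coefficient(x, targets)
}

/// Gradient of the Dice coefficient w.r.t. each x_i via the quotient rule:
/// d/dx_i (N / D) = (2 * t_i * D - N) / D^2,
/// with N = 2 * sum(x * t) + 1 and D = sum(x) + sum(t) + 1.
fn dice_grad(x: &[f32], targets: &[f32]) -> Vec<f32> {
    assert_eq!(x.len(), targets.len(), "inputs must have the same length");
    let smooth = 1.0;
    let intersection: f32 = x.iter().zip(targets).map(|(a, b)| a * b).sum();
    let n = 2.0 * intersection + smooth;
    let d = x.iter().sum::<f32>() + targets.iter().sum::<f32>() + smooth;
    targets.iter().map(|&t| (2.0 * t * d - n) / (d * d)).collect()
}
```

For example, with `x = [1, 0, 1, 1]` and `t = [1, 0, 0, 1]` the coefficient is 5/6. With all-zero predictions and targets, the smoothing term makes it exactly 1.0 instead of 0/0.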
Did you try with another backend?
Changing

```rust
let intersection: Tensor<B, 1> = (x.clone() * targets.clone()).sum();
let dice: Tensor<B, 1> = (intersection * 2.0 + 1) / ((x.clone()).sum() + (targets.clone()).sum() + 1);
```

to

```rust
let intersection: Tensor<B, 1> = (x.clone() * targets.clone()).sum();
let dice: Tensor<B, 1> = (intersection * 2.0 + 1).reshape([1])
    / ((x.clone()).sum() + (targets.clone()).sum() + 1).reshape([1]);
```

fixed the issue.
It seems that `sum()` returns a 0-dimensional tensor, which results in a dimension mismatch in https://github.com/tracel-ai/burn/blob/v0.13.2/crates/burn-autodiff/src/ops/tensor.rs#L393.
Maybe related to #1689.
As I found the workaround, I will close the discussion.
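To see why `reshape([1])` is a safe workaround, here is a toy model of the shape bookkeeping, assuming (as the reply suggests) that the mismatch is between the rank fixed in the type `Tensor<B, 1>` and the rank of the data libtorch returns. The helpers `numel`, `can_reshape` and `rank_agrees` are hypothetical illustrations, not burn or libtorch API. A 0-dimensional tensor (shape `[]`) and a shape-`[1]` tensor both hold exactly one element, so the reshape changes only the rank, not the data.

```rust
// Hypothetical toy model of tensor shapes; not burn or libtorch API.

/// Number of elements in a tensor of the given shape.
/// The product over an empty shape is 1, so a 0-dim tensor holds one scalar.
fn numel(shape: &[usize]) -> usize {
    shape.iter().product()
}

/// A reshape keeps the data and is valid only if the element count matches.
fn can_reshape(from: &[usize], to: &[usize]) -> bool {
    numel(from) == numel(to)
}

/// Does a runtime shape agree with the rank `D` fixed by a type like `Tensor<B, D>`?
fn rank_agrees<const D: usize>(runtime_shape: &[usize]) -> bool {
    runtime_shape.len() == D
}
```

In this model, `rank_agrees::<1>(&[])` is false while `rank_agrees::<1>(&[1])` is true, and `can_reshape(&[], &[1])` holds, which mirrors why the explicit `reshape([1])` removes the mismatch without altering the value.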