Error Code 4:Internal Error (/encoder/layers.10/Conv: number of kernel weights does not match tensor dimensions) #3720

Closed
Liupei1101 opened this issue Mar 18, 2024 · 9 comments
Assignees
Labels
triaged Issue has been triaged by maintainers

Comments

@Liupei1101

Converting the ONNX model to a TensorRT engine failed.
Command: ./trtexec --onnx=/RAID5/user/liupei/ProPainter/E2FGVI/release_model/E2FGVI-HQ-CVPR22_LIUPEI.onnx --saveEngine=/RAID5/user/liupei/ProPainter/E2FGVI/release_model/E2FGVI-HQ-CVPR22_LIUPEI.trt --explicitBatch

/encoder/layers.10/Conv /encoder/layers.10/Conv:kernel weights has count 1474560 but 737280 was expected /encoder/layers.10/Conv: count of 1474560 weights in kernel, but kernel dimensions (3,3) with 320 input channels, 512 output channels and 2 groups were specified. Expected Weights count is 320 * 3*3 * 512 / 2 = 737280
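
The arithmetic in the error message can be checked with a short sketch. The helper name `expected_conv_weight_count` is hypothetical and used only for illustration; the numbers come from the log above. The stored kernel holds exactly twice the expected count, which is what a 640-channel input would need, so the channel count that TensorRT derived (320) appears to be half of what the weights were built for.

```python
def expected_conv_weight_count(in_channels, out_channels, kernel_size, groups):
    """Number of kernel weights in a grouped 2-D convolution (bias excluded)."""
    if in_channels % groups or out_channels % groups:
        raise ValueError("channel counts must be divisible by groups")
    kh, kw = kernel_size
    # Each output channel only sees in_channels / groups input channels.
    return out_channels * (in_channels // groups) * kh * kw


# Values reported by TensorRT in the error message.
expected = expected_conv_weight_count(320, 512, (3, 3), groups=2)
actual = 1474560

# Input-channel count for which the stored weights would be consistent.
implied_in_channels = actual * 2 // (512 * 3 * 3)

print(expected, actual // expected, implied_in_channels)
```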

@Liupei1101
Author

download

@zerollzeng
Collaborator

  1. does it work with onnx runtime?
  2. I guess you are feeding incompatible input shapes.

@zerollzeng zerollzeng self-assigned this Mar 22, 2024
@zerollzeng zerollzeng added the triaged Issue has been triaged by maintainers label Mar 22, 2024
@Liupei1101
Author

  1. does it work with onnx runtime?
  2. I guess you are feeding incompatible input shapes.

image
Converting the ONNX model exported with static axes to a TensorRT model works well, but converting the ONNX model exported with dynamic_axes produces the same error again. I do not know why.

@Liupei1101
Author

Liupei1101 commented Apr 9, 2024

image

If the .pth checkpoint is exported to ONNX without dynamic_axes, trtexec converts it to a TensorRT engine without problems. With dynamic_axes, trtexec fails with the error above. I do not know why.

@zerollzeng
Collaborator

Could you please share the onnx here? Maybe a bug.

@zerollzeng
Collaborator

I quickly created a minimal reproduction but could not reproduce the error:

$ ls
conv.onnx  make_conv.py
(base) zeroz @ ipp2-2325 2024/04/13-02:31:32 ~/scratch.zeroz_sw/github_bug/3720
$ cat make_conv.py
import torch

class TinyModel(torch.nn.Module):

    def __init__(self):
        super(TinyModel, self).__init__()
        self.conv = torch.nn.Conv2d(640, 512, (3,3), stride=1, padding=1, groups=2)

    def forward(self, x):
        x = self.conv(x)
        return x

dummy_input = torch.randn(1, 640, 14, 14, device="cpu")
model = TinyModel()

input_names = [ "input" ]
output_names = [ "output" ]

# dynamic_axes={'input' : {0 : 'bs'}, 'output' : {0 : 'bs'}}
# export(..., dynamic_axes = dynamic_axes)
torch.onnx.export(model, dummy_input, "conv.onnx", opset_version=13, verbose=True, input_names=input_names, output_names=output_names)

@Liupei1101
Author

Liupei1101 commented Apr 15, 2024

I found the cause:
image
When groups > 1, the dimension given as -1 in the Reshape node cannot be resolved. After I changed it to
image
the problem was solved. Thanks.
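
The screenshots are not preserved, so the exact change is unknown. The following is only a simplified sketch of why a `-1` in a reshape can leave the channel dimension undetermined once another axis is dynamic. The function `resolve_reshape` is hypothetical, `None` stands for a dimension that is not known at build time, and the shapes are borrowed from the minimal reproduction above. It is a toy model of shape inference, not TensorRT's actual logic.

```python
from math import prod


def resolve_reshape(input_shape, target_shape):
    """Replace a -1 in target_shape with a concrete size when possible.

    None marks a dimension unknown at build time. The -1 stays unknown (None)
    unless every other dimension involved is a known integer.
    """
    others = [d for d in target_shape if d != -1]
    all_known = all(d is not None for d in list(input_shape) + others)
    resolved = []
    for d in target_shape:
        if d == -1:
            resolved.append(prod(input_shape) // prod(others) if all_known else None)
        else:
            resolved.append(d)
    return resolved


# Static export: every dimension is known, so -1 resolves to 640 channels.
static = resolve_reshape([1, 640, 14, 14], [1, -1, 14, 14])

# Dynamic batch: in this toy model the -1 can no longer be resolved, so the
# channel count feeding the grouped convolution is unknown at build time.
dynamic = resolve_reshape([None, 640, 14, 14], [None, -1, 14, 14])

# Writing the channel count explicitly keeps it known despite the dynamic batch.
explicit = resolve_reshape([None, 640, 14, 14], [None, 640, 14, 14])

print(static, dynamic, explicit)
```

Under these assumptions, replacing the `-1` with an explicit channel count in the reshape before a grouped convolution keeps the channel dimension static, which is consistent with the fix described above.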

@Liupei1101
Author

Thanks, I solved it.

@zerollzeng
Collaborator

Ok, I'm closing this.
