Cask (shader run failed) #3459

Open
DaraOrange opened this issue Nov 15, 2023 · 12 comments
Labels
duplicate: This issue or pull request already exists
internal-bug-tracked: Tracked internally, will be fixed in a future release
triaged: Issue has been triaged by maintainers

Comments

@DaraOrange

Hello! I'm trying to convert a model to INT8. trtexec converts it successfully, but when I convert it with the C++ API and my own calibrator, I get the following error:
1: [softMaxV2Runner.cpp::execute::226] Error Code 1: Cask (shader run failed)
Running polygraphy convert with my own dataloader.py gives the same error.
Do you have any idea what could be wrong?

@zerollzeng
Collaborator

This looks like a duplicate of #3339, which I've already filed internally for tracking.

zerollzeng self-assigned this Nov 18, 2023
zerollzeng added the triaged label Nov 18, 2023
@zerollzeng
Collaborator

It would be great if you could also provide a reproduction.

zerollzeng added the duplicate and internal-bug-tracked labels Nov 18, 2023
@DaraOrange
Author

I can't provide the code, but I have narrowed the cause of the crash down to NmsOp from MMDet. However, I don't know how to rewrite it yet.
Has that bug been fixed?

@DaraOrange
Author

Basically, is it possible to use NMS in TensorRT with INT8? Could you suggest a working option?

@zerollzeng
Collaborator

NmsOp from MMDet

Is it a standard ONNX operator or a custom op?

@DaraOrange
Author

It is an operator from the MMDet library. I have now rewritten it in PyTorch and am trying to convert it to ONNX:

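from typing import Tuple, Union

import numpy as np
import torch
from torch import Tensor

# Assumed alias for the annotations below; it is not shown in the original snippet.
array_like_type = Union[Tensor, np.ndarray]

def nms(boxes: array_like_type,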
        scores: array_like_type,
        iou_threshold: float,
        offset: int = 0,
        score_threshold: float = 0,
        max_num: int = -1) -> Tuple[array_like_type, array_like_type]:
    assert isinstance(boxes, Tensor)
    assert isinstance(scores, Tensor)
    assert boxes.size(1) == 4
    assert boxes.size(0) == scores.size(0)
    assert offset in (0, 1)

    if score_threshold > 0:
        boxes = boxes[scores > score_threshold]
        scores = scores[scores > score_threshold]
   
    N = len(boxes)
    max_l = torch.max(boxes[:, 0].unsqueeze(-1).repeat(1, N).flatten(), boxes[:, 0].repeat(N))
    min_r = torch.min(boxes[:, 2].unsqueeze(-1).repeat(1, N).flatten(), boxes[:, 2].repeat(N))
    max_u = torch.max(boxes[:, 1].unsqueeze(-1).repeat(1, N).flatten(), boxes[:, 1].repeat(N))
    min_d = torch.min(boxes[:, 3].unsqueeze(-1).repeat(1, N).flatten(), boxes[:, 3].repeat(N))
    
    diff_l = (min_r - max_l).reshape((N, N))
    diff_r = (min_d - max_u).reshape((N, N))
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    sum_areas = (areas.unsqueeze(-1).repeat(1, N).flatten() +
                 areas.repeat(N)).reshape((N, N))
    iou = (diff_l * diff_r) / (sum_areas - diff_l * diff_r)

    mask = (diff_l > 0).float() * (diff_r > 0).float() * (iou > iou_threshold).float()
    inds = torch.nonzero(mask)
    bad_inds = inds[inds[:,1] < inds[:,0]][:,0].flatten()#.unique()
    mask_inds = torch.ones(N)

    mask_inds[bad_inds] = 0
    inds = torch.nonzero(mask_inds).flatten()
    
    dets = torch.cat((boxes[inds], scores[inds].reshape(-1, 1)), dim=1)
    return dets, inds

I get the same error at dets = torch.cat((boxes[inds], scores[inds].reshape(-1, 1)), dim=1), or even at just boxes[inds].
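
For reference, the pairwise IoU matrix that the snippet above builds with repeat and flatten can also be written with broadcasting. The sketch below uses NumPy so that it runs without a GPU. The function name pairwise_iou is hypothetical, and the code clamps negative overlaps to zero instead of masking them afterwards, so it illustrates the computation and is not a drop-in replacement or a fix for the TensorRT error.

```python
import numpy as np

def pairwise_iou(boxes: np.ndarray) -> np.ndarray:
    """Return the (N, N) IoU matrix for boxes given as (x1, y1, x2, y2)."""
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    # Broadcasting (N, 1) against (1, N) yields all pairs without repeat/flatten.
    inter_w = np.minimum(x2[:, None], x2[None, :]) - np.maximum(x1[:, None], x1[None, :])
    inter_h = np.minimum(y2[:, None], y2[None, :]) - np.maximum(y1[:, None], y1[None, :])
    # Clamp at zero so that non-overlapping pairs contribute no intersection.
    inter = np.clip(inter_w, 0, None) * np.clip(inter_h, 0, None)
    areas = (x2 - x1) * (y2 - y1)
    union = areas[:, None] + areas[None, :] - inter
    # Assumes every box has a positive area, so the union is never zero.
    return inter / union

boxes = np.array([[0, 0, 2, 2], [1, 1, 3, 3], [5, 5, 6, 6]], dtype=np.float32)
iou = pairwise_iou(boxes)
```

The same indexing pattern (x[:, None] against x[None, :]) is also valid for PyTorch tensors, with torch.minimum, torch.maximum and torch.clamp in place of the NumPy calls.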

@DaraOrange
Author

DaraOrange commented Nov 22, 2023

I prepared a toy example that reproduces a failure in this function.

  1. Code that creates the ONNX model:
import torch
import numpy as np

def nms(boxes,
        scores,
        iou_threshold: float,
        offset: int = 0,
        score_threshold: float = 0,
        max_num: int = -1):
    assert boxes.size(1) == 4
    assert boxes.size(0) == scores.size(0)
    assert offset in (0, 1)
    
    if score_threshold > 0:
        boxes = boxes[scores > score_threshold]
        scores = scores[scores > score_threshold]

    N = len(boxes)
    max_l = torch.max(boxes[:, 0].unsqueeze(-1).repeat(1, N).flatten(), boxes[:, 0].repeat(N))
    min_r = torch.min(boxes[:, 2].unsqueeze(-1).repeat(1, N).flatten(), boxes[:, 2].repeat(N))
    max_u = torch.max(boxes[:, 1].unsqueeze(-1).repeat(1, N).flatten(), boxes[:, 1].repeat(N))
    min_d = torch.min(boxes[:, 3].unsqueeze(-1).repeat(1, N).flatten(), boxes[:, 3].repeat(N))
    
    diff_l = (min_r - max_l).reshape((N, N))
    diff_r = (min_d - max_u).reshape((N, N))
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    sum_areas = (areas.unsqueeze(-1).repeat(1, N).flatten() +
                 areas.repeat(N)).reshape((N, N))
    iou = (diff_l * diff_r) / (sum_areas - diff_l * diff_r)

    mask = (diff_l > 0).float() * (diff_r > 0).float() * (iou > iou_threshold).float()
    inds = torch.nonzero(mask)
    bad_inds = inds[inds[:,1] < inds[:,0]][:,0].flatten()#.unique()
    mask_inds = torch.ones(N)

    mask_inds[bad_inds] = 0
    inds = torch.nonzero(mask_inds).flatten()
    
    res = torch.zeros(100, 4)
    boxes_cnt = torch.LongTensor([inds.shape[0], 100]).min()
    res[:boxes_cnt] = boxes[inds][:boxes_cnt]
    return res

class DummyModel(torch.nn.Module):
    def __init__(self):
        super(DummyModel, self).__init__()
        self.w_b = torch.nn.Linear(16, 4)
        self.w_s = torch.nn.Linear(16, 1)

    def forward(self, x):
        x_b = self.w_b(x)
        x_s = self.w_s(x)
        x_s = x_s.sigmoid()
        x_b[:,2] += x_b[:,0]
        x_b[:,3] += x_b[:,1]
        return nms(x_b[0], x_s[0], 0.7, 0, 0)

dummy_boxes = torch.randn((1, 4300, 16)).cuda()
dummy_model = DummyModel().cuda()
torch.onnx.export(dummy_model, dummy_boxes, "model.onnx", verbose=True, 
                  input_names=["x_in"], 
                  output_names=["x_out"], 
                  dynamic_axes=None, opset_version=11)
  2. Then I run the Polygraphy tool with the following dummy data loader:
import numpy as np

def load_data():
    for i in range(100):
        img = np.random.randn(1, 4300, 16).astype(np.float32)
        yield {"x_in": img}
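
Because the loader above draws fresh random data on every run, two calibration runs never see the same batches. A seeded variant, sketched below, makes a reproduction deterministic. The constants INPUT_NAME and INPUT_SHAPE and the num_batches and seed parameters are names introduced here for illustration. Their values are copied from the export call above, and the function keeps the name load_data used by the original script.

```python
import numpy as np

INPUT_NAME = "x_in"           # matches input_names in the export call
INPUT_SHAPE = (1, 4300, 16)   # matches the dummy input used for export

def load_data(num_batches: int = 100, seed: int = 0):
    """Yield the same sequence of float32 calibration batches on every run."""
    rng = np.random.default_rng(seed)
    for _ in range(num_batches):
        batch = rng.standard_normal(INPUT_SHAPE, dtype=np.float32)
        yield {INPUT_NAME: batch}
```

Seeding only fixes the calibration inputs. It does not by itself change whether the engine build succeeds.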

Polygraphy command:

polygraphy convert model.onnx --int8  --data-loader-script ./dummy_dataloader.py     --calibration-cache calib_dummy.cache -o dummy.engine --pool-limit workspace:2G --verbose

I get the following error (it differs from the original one, but it still indicates a bug in TensorRT when converting even such simple NMS code):

[E] 1: [executionContext.cpp::commonEmitDebugTensor::1821] Error Code 1: Cuda Runtime (invalid argument)
[E] 3: [engine.cpp::~Engine::289] Error Code 3: API Usage Error (Parameter check failed at: runtime/api/engine.cpp::~Engine::289, condition: mExecutionContextCounter.use_count() == 1. Destroying an engine object before destroying the IExecutionContext objects it created leads to undefined behavior.
    )
[E] 2: [calibrator.cpp::calibrateEngine::1181] Error Code 2: Internal Error (Assertion context->executeV2(&bindings[0]) failed. )
[!] Invalid Engine. Please ensure the engine was built correctly

@Data-Iab

@Steven-stars

I got the same error and still can't solve it, please help me!

@Steven-stars

Have you solved this problem now?

@DaraOrange
Author

I just removed the NMS layer from my model. I still don't know of a good solution :(
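
One way to follow the workaround above is to keep NMS out of the engine and run it on the host over the raw boxes and scores that the model outputs. The sketch below is a minimal greedy NMS in NumPy, not the MMDet implementation. The function name nms_host and the (x1, y1, x2, y2) box layout are assumptions for illustration.

```python
import numpy as np

def nms_host(boxes: np.ndarray, scores: np.ndarray, iou_threshold: float) -> np.ndarray:
    """Greedy NMS on the host; returns indices of kept boxes, highest score first."""
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    areas = (x2 - x1) * (y2 - y1)
    order = np.argsort(-scores, kind="stable")  # candidates sorted by descending score
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        rest = order[1:]
        # Intersection of the current best box with every remaining candidate.
        w = np.clip(np.minimum(x2[i], x2[rest]) - np.maximum(x1[i], x1[rest]), 0, None)
        h = np.clip(np.minimum(y2[i], y2[rest]) - np.maximum(y1[i], y1[rest]), 0, None)
        inter = w * h
        iou = inter / (areas[i] + areas[rest] - inter)
        # Drop candidates that overlap the kept box by more than the threshold.
        order = rest[iou <= iou_threshold]
    return np.array(keep, dtype=np.int64)

boxes = np.array([[0, 0, 10, 10], [1, 1, 11, 11], [20, 20, 30, 30]], dtype=np.float32)
scores = np.array([0.9, 0.8, 0.7], dtype=np.float32)
keep = nms_host(boxes, scores, iou_threshold=0.5)
```

In this example the second box overlaps the first with an IoU of 81/119 (about 0.68), so it is suppressed at a threshold of 0.5 and the kept indices are 0 and 2.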

@Egorundel

@DaraOrange Hello, can you help me, please?

#4053
