Skip to content

Commit

Permalink
Fix bug for #2433
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaying committed Aug 16, 2023
1 parent 1115d2f commit d82a349
Show file tree
Hide file tree
Showing 3 changed files with 1 addition and 8 deletions.
1 change: 0 additions & 1 deletion source/backend/vulkan/image/backend/VulkanTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ std::array<int, 4> VulkanTensor::tensorShapeFormat(const Tensor *input) {
iW = 1;
iC = input->buffer().dim[0].extent;
}

#ifdef LOG_VERBOSE
MNN_PRINT("tensorShapeFormat : [%d, %d, %d, %d] \n", iN, iH, iW, iC);
#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,6 @@ void VulkanImageConverter::encodeTensorToBuffer(const Tensor* srcTensor, VkBuffe
const VulkanCommandPool::Buffer* cmdBuffer) {
auto sourceFormat = TensorUtils::getDescribe(srcTensor)->dimensionFormat;
auto destFormat = destBufferFormat;
if (sourceFormat == MNN_DATA_FORMAT_NC4HW4 && 1 >= srcTensor->width() && 1 >= srcTensor->height() &&
srcTensor->channel() % 4 == 0) {
destFormat = MNN_DATA_FORMAT_NC4HW4;
}

auto vkTensor = (VulkanTensor*)(srcTensor->deviceId());
auto tensor = srcTensor;
Expand Down
4 changes: 1 addition & 3 deletions source/backend/vulkan/image/execution/VulkanRaster.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,6 @@ ErrorCode VulkanRaster::onEncode(const std::vector<Tensor *> &___inputs, const s
for (int i=0; i<origin->dimensions(); ++i) {
bufferSize *= origin->length(i);
}
ivec4 dims;
writeNCHW(dims, origin);
ConvertInfo info;
info.buffer.reset(new VulkanBuffer(vkBn->getDynamicMemoryPool(),
false, bufferSize, nullptr, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT));
Expand Down Expand Up @@ -212,6 +210,7 @@ ErrorCode VulkanRaster::onEncode(const std::vector<Tensor *> &___inputs, const s
for (auto& iter : mInputBuffers) {
auto& info = iter.second;
info.convert->encodeTensorToBuffer(iter.first, info.buffer->buffer(), info.buffer->size(), 0, VulkanImageConverter::getTensorLinearFormat(iter.first), cmdBuffer);
cmdBuffer->barrierSource(info.buffer->buffer(), 0, info.buffer->size());
}
//Blit
if (needZero) {
Expand All @@ -223,7 +222,6 @@ ErrorCode VulkanRaster::onEncode(const std::vector<Tensor *> &___inputs, const s
info.describe->writeBuffer(info.srcBuffer, 1, info.srcBufferSize);
info.describe->writeBuffer(info.uniform->buffer(), 2, info.uniform->size());
info.pipeline->bind(cmdBuffer->get(), info.describe->get());
cmdBuffer->barrierSource(info.srcBuffer, 0, info.srcBufferSize);
vkCmdDispatch(cmdBuffer->get(), info.workGroup[0], info.workGroup[1], info.workGroup[2]);
}

Expand Down

0 comments on commit d82a349

Please sign in to comment.