Skip to content

Commit

Permalink
修改check_tensor的实现
Browse files Browse the repository at this point in the history
  • Loading branch information
zjhellofss committed Mar 18, 2024
1 parent 5c29569 commit e834b5b
Showing 1 changed file with 7 additions and 16 deletions.
23 changes: 7 additions & 16 deletions source/runtime/runtime_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,25 +80,16 @@ static sftensor CreateTensor(const std::vector<int32_t>& operand_shapes) {

static void CheckAndReshapeTensor(sftensor& output_tensor,
const std::vector<int32_t>& operand_shapes) {
const std::vector<uint32_t>& tensor_shapes = output_tensor->shapes();
switch (operand_shapes.size()) {
case 4:
if (tensor_shapes[0] != operand_shapes[1] || tensor_shapes[1] != operand_shapes[2] ||
tensor_shapes[2] != operand_shapes[3]) {
output_tensor->Reshape({(uint32_t)operand_shapes[1], (uint32_t)operand_shapes[2],
(uint32_t)operand_shapes[3]});
}
output_tensor->Reshape(
{(uint32_t)operand_shapes[1], (uint32_t)operand_shapes[2], (uint32_t)operand_shapes[3]});
break;
case 3:
if (tensor_shapes[0] != 1 || tensor_shapes[1] != operand_shapes[1] ||
tensor_shapes[2] != operand_shapes[2]) {
output_tensor->Reshape({(uint32_t)operand_shapes[1], (uint32_t)operand_shapes[2]});
}
output_tensor->Reshape({(uint32_t)operand_shapes[1], (uint32_t)operand_shapes[2]});
break;
case 2:
if (tensor_shapes[0] != 1 || tensor_shapes[1] != operand_shapes[1] || tensor_shapes[2] != 1) {
output_tensor->Reshape({(uint32_t)operand_shapes[1]});
}
output_tensor->Reshape({(uint32_t)operand_shapes[1]});
break;
default:
LOG(FATAL) << "Unknown output operand shape length: " << operand_shapes.size();
Expand Down Expand Up @@ -160,9 +151,9 @@ void RuntimeOperatorUtils<float>::InitOperatorOutput(
runtime_op->output_operands->type = RuntimeDataType::kTypeFloat32;
const auto& prev_runtime_op_tensors = prev_runtime_op->output_operands->datas;
for (uint32_t b = 0; b < batch; ++b) {
sftensor output_tensor = prev_runtime_op_tensors.at(b);
output_tensors->datas[b] =
std::make_shared<ftensor>(output_tensor->raw_ptr(), output_tensor->shapes());
sftensor prev_output_tensor = prev_runtime_op_tensors.at(b);
output_tensors->datas[b] = std::make_shared<ftensor>(prev_output_tensor->raw_ptr(),
prev_output_tensor->shapes());
CheckAndReshapeTensor(output_tensors->datas[b], operand_shapes);
}
prev_runtime_op->occur_end_time = runtime_op->end_time;
Expand Down

0 comments on commit e834b5b

Please sign in to comment.