Skip to content

Commit

Permalink
Do not upstream - Revert "[Unity] Bump fpA_intB_gemm (#16244)"
Browse files Browse the repository at this point in the history
This reverts commit e98fdea.
  • Loading branch information
CharlieFRuan authored and MasterJH5574 committed Dec 25, 2023
1 parent 7f9127f commit 72a7644
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 11 deletions.
15 changes: 6 additions & 9 deletions src/runtime/contrib/cutlass/weight_preprocess.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,21 +37,18 @@ namespace runtime {
// The preprocessing functions are defined in C++, so we need to copy the input weight to CPU.
TVM_REGISTER_GLOBAL("cutlass.ft_preprocess_weight")
.set_body_typed([](NDArray packed_weight, int sm, bool is_int4) {
bool is_2d = packed_weight->ndim == 2;
int num_experts = is_2d ? 1 : packed_weight->shape[0];
int rows = packed_weight->shape[is_2d ? 0 : 1];
int cols = packed_weight->shape[is_2d ? 1 : 2];

std::vector<int8_t> input_cpu(num_experts * rows * cols);
std::vector<int8_t> output_cpu(num_experts * rows * cols);
int rows = packed_weight->shape[0];
int cols = packed_weight->shape[1];
std::vector<int8_t> input_cpu(rows * cols);
std::vector<int8_t> output_cpu(rows * cols);
packed_weight.CopyToBytes(input_cpu.data(), input_cpu.size());
// Multiply cols by 2, since the "col" param in preprocess_weights refers to the columns of
// the unpacked weight.
if (is_int4) {
cols *= 2;
}
fastertransformer::preprocess_weights(output_cpu.data(), input_cpu.data(), num_experts, rows,
cols, is_int4, sm);
fastertransformer::preprocess_weights(output_cpu.data(), input_cpu.data(), rows, cols,
is_int4, sm);
auto out = NDArray::Empty(packed_weight.Shape(), packed_weight->dtype, packed_weight->device);
out.CopyFromBytes(output_cpu.data(), output_cpu.size());
return out;
Expand Down
1 change: 0 additions & 1 deletion tests/scripts/task_config_build_gpu.sh
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,3 @@ echo set\(USE_PIPELINE_EXECUTOR ON\) >> config.cmake
echo set\(USE_CUTLASS ON\) >> config.cmake
echo set\(USE_CMSISNN ON\) >> config.cmake
echo set\(USE_MSC ON\) >> config.cmake
echo set\(CMAKE_CUDA_ARCHITECTURES 75\) >> config.cmake

0 comments on commit 72a7644

Please sign in to comment.