Skip to content

Commit

Permalink
Merge branch 'add-kMaxInputCount-in-GroupedMatmulFunctor' of https://…
Browse files: browse the repository at this point in the history
  • Loading branch information
strint committed Aug 25, 2023
2 parents 354cf21 + 8347ed8 commit 6705366
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
1 change: 1 addition & 0 deletions oneflow/core/functional/impl/nn_functor.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -5223,6 +5223,7 @@ class GroupedMatmulFunctor {
Maybe<TensorTuple> operator()(const TensorTuple& xs, const TensorTuple& weights) const {
const int64_t input_size = xs.size();
const int64_t weight_size = weights.size();
CHECK_LT_OR_RETURN(input_size, kMaxInputCount);
CHECK_GE_OR_RETURN(input_size, 1)
<< Error::RuntimeError() << "The number of xs should be greater equal than 1.";
CHECK_EQ_OR_RETURN(weight_size, input_size)
Expand Down
8 changes: 7 additions & 1 deletion oneflow/user/kernels/grouped_matmul_bias.cu
Original file line number | Diff line number | Diff line change
Expand Up @@ -190,7 +190,13 @@ class GroupedMatmulBiasKernel final : public user_op::OpKernel, public user_op::
}
void* workspace = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0)->mut_dptr();
for (const auto& group : groups) {
ApplyGroup<T>(group.first, group.second, has_biases, workspace, ctx->stream());
for (size_t i = 0; i < group.second.size(); i += kMaxProblemBatch) {
std::vector<Buffer<T>> ptrs(
{group.second.begin() + i,
group.second.begin() + i
+ std::min<size_t>(group.second.size() - i, kMaxProblemBatch)});
ApplyGroup<T>(group.first, ptrs, has_biases, workspace, ctx->stream());
}
}
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
Expand Down

0 comments on commit 6705366

Please sign in to comment.