diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index 995cc937bdb..3cf3b4aa9f9 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -5223,6 +5223,7 @@ class GroupedMatmulFunctor { Maybe operator()(const TensorTuple& xs, const TensorTuple& weights) const { const int64_t input_size = xs.size(); const int64_t weight_size = weights.size(); + CHECK_LT_OR_RETURN(input_size, kMaxInputCount); CHECK_GE_OR_RETURN(input_size, 1) << Error::RuntimeError() << "The number of xs should be greater equal than 1."; CHECK_EQ_OR_RETURN(weight_size, input_size) diff --git a/oneflow/user/kernels/grouped_matmul_bias.cu b/oneflow/user/kernels/grouped_matmul_bias.cu index c23d9c925b8..2022fbec012 100644 --- a/oneflow/user/kernels/grouped_matmul_bias.cu +++ b/oneflow/user/kernels/grouped_matmul_bias.cu @@ -190,7 +190,13 @@ class GroupedMatmulBiasKernel final : public user_op::OpKernel, public user_op:: } void* workspace = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0)->mut_dptr(); for (const auto& group : groups) { - ApplyGroup(group.first, group.second, has_biases, workspace, ctx->stream()); + for (size_t i = 0; i < group.second.size(); i += kMaxProblemBatch) { + std::vector> ptrs( + {group.second.begin() + i, + group.second.begin() + i + + std::min(group.second.size() - i, kMaxProblemBatch)}); + ApplyGroup(group.first, ptrs, has_biases, workspace, ctx->stream()); + } } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }