Commit d978d3f

Permalink: https://github.com/vllm-project/vllm/pull/4015/files
cthiriet committed Apr 23, 2024
1 parent c2d6f97 commit d978d3f
Showing 3 changed files with 21 additions and 10 deletions.
13 changes: 12 additions & 1 deletion csrc/punica/bgmv/bgmv_config.h
@@ -14,6 +14,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
   f(in_T, out_T, W_T, narrow, 128) \
   f(in_T, out_T, W_T, narrow, 256) \
   f(in_T, out_T, W_T, narrow, 512) \
+  f(in_T, out_T, W_T, narrow, 640) \
   f(in_T, out_T, W_T, narrow, 768) \
   f(in_T, out_T, W_T, narrow, 1024) \
   f(in_T, out_T, W_T, narrow, 1152) \
@@ -59,7 +60,17 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
   f(in_T, out_T, W_T, narrow, 33024) \
   f(in_T, out_T, W_T, narrow, 36864) \
   f(in_T, out_T, W_T, narrow, 49152) \
-// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
+  f(in_T, out_T, W_T, narrow, 64000) \
+  f(in_T, out_T, W_T, narrow, 64256) \
+  f(in_T, out_T, W_T, narrow, 64512) \
+  f(in_T, out_T, W_T, narrow, 102400) \
+  f(in_T, out_T, W_T, narrow, 102656) \
+  f(in_T, out_T, W_T, narrow, 102912) \
+  f(in_T, out_T, W_T, narrow, 128000) \
+  f(in_T, out_T, W_T, narrow, 128256) \
+  f(in_T, out_T, W_T, narrow, 128512) \
+// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
+// and vllm/tests/lora/test_punica.py

 // Keep this in sync with vllm/config::LoRAConfig
 #define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
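For context on why these widths must be enumerated: the FOR_BGMV_WIDE list is an X-macro, so each f(..., width) entry is expanded into an explicit kernel instantiation (and a matching dispatch case in punica_ops.cc). A width that is missing from the list, such as a new model's vocabulary size, therefore has no compiled kernel at all. Below is a minimal, self-contained sketch of the pattern; the names FOR_WIDE and bgmv_stub are illustrative, not vLLM's.

#include <cstdio>

// Hypothetical X-macro list; the real FOR_BGMV_WIDE in bgmv_config.h
// enumerates every supported width, including the new vocab sizes.
#define FOR_WIDE(f) \
  f(640)            \
  f(128256)

template <int feat_out>
void bgmv_stub() { std::printf("kernel instantiated for feat_out=%d\n", feat_out); }

// Each f(width) entry becomes one explicit instantiation, so widths
// missing from the list never get compiled into the binary.
#define INSTANTIATE(width) template void bgmv_stub<width>();
FOR_WIDE(INSTANTIATE)
#undef INSTANTIATE

int main() {
  bgmv_stub<640>();     // in the list, so an instantiation exists
  bgmv_stub<128256>();  // likewise, after the larger-vocab entries are added
  return 0;
}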
14 changes: 7 additions & 7 deletions csrc/punica/punica_ops.cc
@@ -20,8 +20,8 @@ inline void check_shape(const torch::Tensor &a, const torch::Tensor &b,
   }
 }

-inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
-  return (uint32_t(a) << 16) | uint32_t(b);
+inline constexpr uint64_t pack_u32(uint32_t a, uint32_t b) {
+  return (uint64_t(a) << 32) | uint64_t(b);
 }

 #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
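The motivation for the pack_u16 → pack_u32 change: the new vocabulary widths (up to 128512) do not fit in a uint16_t, so packing (in_features, out_features) into a 32-bit key would silently truncate them. Widening the key to 64 bits keeps the switch-based dispatch exact. A small self-contained sketch follows; the specific widths 4096 and 128512 are illustrative, not taken from a particular model.

#include <cstdint>
#include <cstdio>

// Same packing as the new helper in punica_ops.cc: two 32-bit widths
// combined into one 64-bit dispatch key.
inline constexpr uint64_t pack_u32(uint32_t a, uint32_t b) {
  return (uint64_t(a) << 32) | uint64_t(b);
}

int main() {
  // With the old uint16_t-based key, 128512 would have wrapped to
  // 62976 (128512 % 65536) and no longer matched the intended case.
  constexpr uint64_t key = pack_u32(4096, 128512);
  switch (key) {
    case pack_u32(4096, 128512):
      std::printf("dispatch to bgmv_kernel<4096, 128512>\n");
      break;
    default:
      std::printf("unsupported (in_features, out_features) pair\n");
  }
  return 0;
}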
@@ -46,13 +46,13 @@ inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
 template <typename in_T, typename out_T, typename W_T>
 inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
                                const int64_t *lora_indices,
-                               uint16_t in_features, uint16_t out_features,
+                               uint32_t in_features, uint32_t out_features,
                                int64_t y_offset, int64_t full_y_size,
                                int64_t batch_size, int64_t num_layers,
                                int64_t layer_idx, float scale) {
-  switch (pack_u16(in_features, out_features)) {
+  switch (pack_u32(in_features, out_features)) {
 #define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \
-  case pack_u16(feat_in, feat_out): \
+  case pack_u32(feat_in, feat_out): \
     bgmv_kernel<feat_in, feat_out>(Y, X, W, lora_indices, y_offset, \
                                    full_y_size, batch_size, num_layers, \
                                    layer_idx, scale); \
@@ -93,7 +93,7 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
   CHECK_EQ(y.size(0), x.size(0));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
   bool ok = false;
-  if (h_in < 65536 && h_out < 65536) {
+  if (h_in <= 128512 && h_out <= 128512) {
     // TODO: See if we can get rid of this massive nested switch
     switch (x.scalar_type()) {
       case at::ScalarType::Half:
@@ -325,7 +325,7 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
   CHECK_EQ(y.size(0), x.size(0));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
   bool ok = false;
-  if (h_in < 65536 && h_out < 65536) {
+  if (h_in <= 128512 && h_out <= 128512) {
     // TODO: See if we can get rid of this massive nested switch
     switch (x.scalar_type()) {
       case at::ScalarType::Half:
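The widened guard in dispatch_bgmv and dispatch_bgmv_low_level is only a cheap pre-filter: 128512 is the largest width now enumerated in the config header, so larger dimensions can never match a dispatch case and ok stays false. A hedged sketch of that control flow, with made-up names (kMaxConfiguredDim, try_dispatch) standing in for the real launch path:

#include <cstdint>
#include <cstdio>

// Assumed to mirror the largest width listed in bgmv_config.h after this commit.
constexpr uint32_t kMaxConfiguredDim = 128512;

// Simplified stand-in for the launch path: reject shapes that cannot possibly
// be among the compiled-in (feat_in, feat_out) pairs before the big switch.
bool try_dispatch(uint32_t h_in, uint32_t h_out) {
  if (h_in > kMaxConfiguredDim || h_out > kMaxConfiguredDim) return false;
  // ... the real code switches on pack_u32(h_in, h_out) here ...
  return true;
}

int main() {
  std::printf("%d\n", try_dispatch(4096, 128256));  // within bounds
  std::printf("%d\n", try_dispatch(4096, 152064));  // too wide, rejected early
  return 0;
}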
4 changes: 2 additions & 2 deletions vllm/lora/layers.py
@@ -952,9 +952,9 @@ def create_lora_weights(
         model_config: Optional[PretrainedConfig] = None,
     ) -> None:
         # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
-        if 32000 < self.base_layer.vocab_size > 33024:
+        if 32000 < self.base_layer.vocab_size > 128512:
             raise ValueError("When using LoRA, vocab size must be "
-                             "32000 >= vocab_size <= 33024")
+                             "32000 >= vocab_size <= 128512")
         self.lora_a_stacked = torch.zeros(
             (
                 max_loras,
