Skip to content

Commit

Permalink
feat: decouple input_ids and output_ids
Browse files — browse the repository at this point in the history
  • Loading branch information
zhyncs committed Jul 10, 2024
1 parent e53fa70 commit 0311ef1
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 3 deletions.
14 changes: 12 additions & 2 deletions src/turbomind/models/llama/LlamaBatch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -269,14 +269,20 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
const auto output_ids_base = state.output_ids + session_len_ * idx;
auto output_ids = output_ids_base;

auto incoming_input_ids = state.input_ids + session_len_ * idx;

// copy history tokens
if (!seq.tokens.empty()) {
output_ids = Copy(seq.tokens.data(), seq.tokens.size(), output_ids);

incoming_input_ids = Copy(seq.tokens.data(), seq.tokens.size(), incoming_input_ids);
}

// copy input tokens
if (input_length) {
output_ids = Copy(input_ids, input_length, output_ids);

Copy(input_ids, input_length, incoming_input_ids);
}

// copy input tokens to prompt for prefix matching
Expand Down Expand Up @@ -682,7 +688,8 @@ void LlamaBatch<T>::CopyState(const std::vector<std::tuple<BatchState*, BatchSta
IndexedCopy(s_idx,
d_idx,
std::tuple{s->output_ids, d->output_ids, session_len_},
std::tuple{s->curand_state, d->curand_state, 1});
std::tuple{s->curand_state, d->curand_state, 1},
std::tuple{s->input_ids, d->input_ids, session_len_});
}

for (const auto& [s, d, si, di] : desc) {
Expand Down Expand Up @@ -801,6 +808,7 @@ void LlamaBatch<T>::AllocatePersistantBuffer(size_t max_batch_size, int cache_bl

for (auto& s : states_) {
s.output_ids = (int*)allocator_->reMalloc(s.output_ids, sizeof(int) * max_batch_size * session_len_, true);
s.input_ids = (int*)allocator_->reMalloc(s.input_ids, sizeof(int) * max_batch_size * session_len_, true);
s.curand_state =
(curandState_t*)allocator_->reMalloc(s.curand_state, sizeof(curandState_t) * max_batch_size, true);
}
Expand Down Expand Up @@ -911,6 +919,7 @@ void LlamaBatch<T>::FreeBuffer()
allocator_->free((void**)&s.h_finished, true);
allocator_->free((void**)&s.h_rope_theta, true);
allocator_->free((void**)&s.output_ids);
allocator_->free((void**)&s.input_ids);
allocator_->free((void**)&s.curand_state);
}
allocator_->free((void**)&h_cu_block_counts_, true);
Expand Down Expand Up @@ -1216,6 +1225,7 @@ auto LlamaBatch<T>::Finish(GenerationState& g) -> std::vector<Signal>

// [s,b] -> [b,s] and skip padding in [context_len, max_context_len)
invokeGatherOutput(state_->output_ids,
state_->input_ids,
token_ids_buf_,
init_context_length_,
g.max_init_ctx_len,
Expand Down Expand Up @@ -1558,7 +1568,7 @@ bool LlamaBatch<T>::Forward(GenerationState& g, int iter)
// const int missing = state_->h_context_length[i] - seq.cache_len;
FT_CHECK(seq.input_length >= 1);
h_input_length_buf_[i] = seq.input_length;
input_d_ptrs[i] = state_->output_ids + i * session_len_ + seq.cache_len;
input_d_ptrs[i] = state_->input_ids + i * session_len_;
if (seq.input_length > 1 && pf_offset < 0) {
pf_offset = i;
}
Expand Down
1 change: 1 addition & 0 deletions src/turbomind/models/llama/LlamaBatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ struct BatchState {

curandState_t* curand_state;
int* output_ids; // output ids in [B, S]
int* input_ids;

float* h_rope_theta;

Expand Down
8 changes: 7 additions & 1 deletion src/turbomind/models/llama/llama_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ __global__ void KernelWrapper(Params params)
} // namespace

__global__ void gatherOutput(int* output_ids,
int* input_ids,
const int* ids,
const int* context_length,
int max_context_len,
Expand All @@ -250,6 +251,7 @@ __global__ void gatherOutput(int* output_ids,
const int batch_id = blockIdx.x;
const int context_len = context_length[batch_id];
output_ids += batch_id * max_output_len;
input_ids += batch_id * max_output_len;
for (int src_idx = threadIdx.x; src_idx < max_gen_step; src_idx += blockDim.x) {
// skip padding for src
if (context_len <= src_idx && src_idx < max_context_len) {
Expand All @@ -259,11 +261,15 @@ __global__ void gatherOutput(int* output_ids,
const int dst_idx = src_idx < context_len ? src_idx : src_idx - (max_context_len - context_len);
if (dst_idx < max_output_len) {
output_ids[dst_idx] = ids[src_idx * batch_size + batch_id];
if (src_idx == max_gen_step - 1) {
input_ids[0] = ids[src_idx * batch_size + batch_id];
}
}
}
}

void invokeGatherOutput(int* output_ids,
int* input_ids,
const int* ids,
const int* context_length,
int max_context_len,
Expand All @@ -275,7 +281,7 @@ void invokeGatherOutput(int* output_ids,
int block_size = 128;
int grid_size = batch_size;
gatherOutput<<<grid_size, block_size, 0, stream>>>(
output_ids, ids, context_length, max_context_len, max_gen_step, max_output_len, batch_size);
output_ids, input_ids, ids, context_length, max_context_len, max_gen_step, max_output_len, batch_size);
}

__global__ void updateOutput(int** request_output_ids_ptrs,
Expand Down
1 change: 1 addition & 0 deletions src/turbomind/models/llama/llama_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ void invokeTransposeKVCache(T* key_cache_trans,
const float* kv_scale);

void invokeGatherOutput(int* output_ids,
int* input_ids,
const int* ids,
const int* context_length,
int max_context_len,
Expand Down

0 comments on commit 0311ef1

Please sign in to comment.