Skip to content

Commit

Permalink
sampling v2 update
Browse files: browse the repository at this point in the history
  • Loading branch information
eastriverlee committed Oct 19, 2024
1 parent 4689414 commit d1a528b
Showing 1 changed file with 14 additions and 18 deletions.
32 changes: 14 additions & 18 deletions Sources/LLM/LLM.swift
Original file line number | Diff line number | Diff line change
Expand Up @@ -34,6 +34,7 @@ open class LLM: ObservableObject {
}
}

public var seed: UInt32
public var topK: Int32
public var topP: Float
public var temp: Float
Expand Down Expand Up @@ -74,13 +75,13 @@ open class LLM: ObservableObject {
#endif
let model = llama_load_model_from_file(self.path, modelParams)!
params = llama_context_default_params()
let processorCount = UInt32(ProcessInfo().processorCount)
let processorCount = Int32(ProcessInfo().processorCount)
self.maxTokenCount = Int(min(maxTokenCount, llama_n_ctx_train(model)))
params.seed = seed
params.n_ctx = UInt32(self.maxTokenCount)
params.n_batch = params.n_ctx
params.n_threads = processorCount
params.n_threads_batch = processorCount
self.seed = seed
self.topK = topK
self.topP = topP
self.temp = temp
Expand Down Expand Up @@ -186,22 +187,17 @@ open class LLM: ObservableObject {
@InferenceActor
private func predictNextToken() async -> Token {
guard shouldContinuePredicting else { return model.endToken }
let logits = llama_get_logits_ith(context.pointer, batch.n_tokens - 1)!
var candidates = (0..<totalTokenCount).map { token in
llama_token_data(id: Int32(token), logit: logits[token], p: 0.0)
}
var token: llama_token!
candidates.withUnsafeMutableBufferPointer { pointer in
var candidates = llama_token_data_array(
data: pointer.baseAddress,
size: totalTokenCount,
sorted: false
)
llama_sample_top_k(context.pointer, &candidates, topK, 1)
llama_sample_top_p(context.pointer, &candidates, topP, 1)
llama_sample_temp(context.pointer, &candidates, temp)
token = llama_sample_token(context.pointer, &candidates)
}
let samplerParams = llama_sampler_chain_default_params()
let sampler = llama_sampler_chain_init(samplerParams)

llama_sampler_chain_add(sampler, llama_sampler_init_top_k(topK))
llama_sampler_chain_add(sampler, llama_sampler_init_top_p(topP, 1))
llama_sampler_chain_add(sampler, llama_sampler_init_temp(temp))
llama_sampler_chain_add(sampler, llama_sampler_init_dist(seed))

let i = batch.n_tokens - 1
let token = llama_sampler_sample(sampler, context.pointer, i)

batch.clear()
batch.add(token, currentCount, [0], true)
context.decode(batch)
Expand Down

0 comments on commit d1a528b

Please sign in to comment.