Skip to content

Commit

Permalink
change public func decode's behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
eastriverlee committed Mar 7, 2024
1 parent f20dd03 commit d123859
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
8 changes: 6 additions & 2 deletions Sources/LLM/LLM.swift
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ open class LLM: ObservableObject {
let processorCount = UInt32(ProcessInfo().processorCount)
self.maxTokenCount = Int(min(maxTokenCount, llama_n_ctx_train(model)))
params.seed = seed
params.n_ctx = UInt32(maxTokenCount) + (maxTokenCount % 2 == 1 ? 1 : 2)
params.n_ctx = UInt32(self.maxTokenCount)
params.n_batch = params.n_ctx
params.n_threads = processorCount
params.n_threads_batch = processorCount
Expand Down Expand Up @@ -359,10 +359,14 @@ open class LLM: ObservableObject {
}

private var multibyteCharacter: [CUnsignedChar] = []
public func decode(_ token: Token) -> String {
private func decode(_ token: Token) -> String {
return model.decode(token, with: &multibyteCharacter)
}

public func decode(_ tokens: [Token]) -> String {
return tokens.map({model.decodeOnly($0)}).joined()
}

@inlinable
public func encode(_ text: borrowing String) -> [Token] {
model.encode(text)
Expand Down
9 changes: 9 additions & 0 deletions Tests/LLMTests/LLMTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -219,4 +219,13 @@ final class LLMTests: XCTestCase {
await bot.respond(to: input)
#assert(bot.output == "tl;dr")
}

func testEncodingDecodingFromHuggingFaceModel() async throws {
let bot = try await LLM(from: model)
let input = "have you heard of this so-called LLM.swift library?"
let tokens = bot.encode(input)
let decoded = bot.decode(tokens).trimmingCharacters(in: .whitespacesAndNewlines)
#assert(!tokens.isEmpty)
#assert(decoded == input)
}
}

0 comments on commit d123859

Please sign in to comment.