Skip to content

Commit

Permalink
Use radix sort for samplers (#137)
Browse files Browse the repository at this point in the history
* Bump kbnf version

* Use radix sort in samplers

* Make `cargo fmt` happy
  • Loading branch information
Dan-wanna-M authored Jul 21, 2024
1 parent 4888fcc commit 87cca8b
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 24 deletions.
17 changes: 13 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion crates/ai00-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ bytemuck = "1"
cbor4ii = { version = "0.3.2", features = ["serde1"] }
fastrand = "2"
half = "2.4"
kbnf = "0.1.3"
kbnf = "0.1.6"
voracious_radix_sort = "1.2.0"
qp-trie = "0.8"
rustc-hash = "1.1.0"
uuid = { version = "1.8.0", features = ["serde", "v4"] }
Expand Down
17 changes: 12 additions & 5 deletions crates/ai00-core/src/sampler/mirostat.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use super::Sampler;
use super::{utils, Sampler};
use derivative::Derivative;
use itertools::Itertools;
use salvo::oapi::ToSchema;
use serde::{Deserialize, Serialize};
use voracious_radix_sort::RadixSort;

#[derive(Debug, Clone, Derivative, Serialize, Deserialize, ToSchema)]
#[derivative(Default)]
Expand Down Expand Up @@ -44,19 +45,25 @@ impl Sampler for MirostatSampler {
let MirostatSampler { params, state } = self;

// sort the surprise values and truncate
let sorted = probs
let mut sorted = probs
.iter()
.copied()
.enumerate()
.sorted_unstable_by(|(_, x), (_, y)| y.total_cmp(x))
.scan((0, 0.0, 0.0), |(_, cum, _), (id, x)| {
.map(|(id, x)| utils::F32WithIndex(id, x))
.collect_vec();
sorted.voracious_sort();
let sorted = sorted
.into_iter()
.rev()
.scan((0, 0.0, 0.0), |(_, cum, _), utils::F32WithIndex(id, x)| {
// if *cum > params.top_p {
// None
// } else {
// *cum += x;
// Some((id, *cum, *x))
// }
*cum += x;
Some((id, *cum, *x))
Some((id, *cum, x))
})
.collect_vec();
let k = sorted
Expand Down
2 changes: 1 addition & 1 deletion crates/ai00-core/src/sampler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ pub mod bnf;
pub mod mirostat;
pub mod nucleus;
pub mod typical;

mod utils;
pub trait Sampler {
/// Initialize the sampler state.
fn init(&mut self, model_tokens: &[u16]);
Expand Down
20 changes: 12 additions & 8 deletions crates/ai00-core/src/sampler/nucleus.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use std::collections::HashMap;

use super::{utils, Sampler};
use derivative::Derivative;
use itertools::Itertools;
use salvo::oapi::ToSchema;
use serde::{Deserialize, Serialize};

use super::Sampler;
use voracious_radix_sort::RadixSort;

#[derive(Debug, Clone, Derivative, Serialize, Deserialize, ToSchema)]
#[derivative(Default)]
Expand Down Expand Up @@ -68,18 +68,23 @@ impl Sampler for NucleusSampler {

fn sample(&mut self, probs: &[f32]) -> u16 {
let NucleusSampler { params, state } = self;

let sorted = probs
let mut sorted = probs
.iter()
.copied()
.enumerate()
.sorted_unstable_by(|(_, x), (_, y)| y.total_cmp(x))
.map(|(id, x)| utils::F32WithIndex(id, x))
.collect_vec();
sorted.voracious_sort();
let sorted = sorted
.into_iter()
.rev()
.take(params.top_k)
.scan((0, 0.0, 0.0), |(_, cum, _), (id, x)| {
.scan((0, 0.0, 0.0), |(_, cum, _), utils::F32WithIndex(id, x)| {
if *cum > params.top_p {
None
} else {
*cum += x;
Some((id, *cum, *x))
Some((id, *cum, x))
}
})
.map(|(id, _, x)| (id, x.powf(1.0 / params.temperature)))
Expand All @@ -94,7 +99,6 @@ impl Sampler for NucleusSampler {
Some((id, *cum))
})
.collect_vec();

let rand = fastrand::f32();
let token = sorted
.into_iter()
Expand Down
15 changes: 10 additions & 5 deletions crates/ai00-core/src/sampler/typical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ use derivative::Derivative;
use itertools::Itertools;
use salvo::oapi::ToSchema;
use serde::{Deserialize, Serialize};
use voracious_radix_sort::RadixSort;

use super::Sampler;
use super::{utils, Sampler};

#[derive(Debug, Clone, Derivative, Serialize, Deserialize, ToSchema)]
#[derivative(Default)]
Expand Down Expand Up @@ -76,11 +77,15 @@ impl Sampler for TypicalSampler {
.map(|(id, &x)| (id, x, -x.ln()))
.collect_vec();
let entropy = probs.iter().map(|(_, x, y)| x * y).sum::<f32>();
let sorted = probs
let mut sorted = probs
.into_iter()
.map(|(id, x, y)| (id, x, (y - entropy).abs()))
.sorted_unstable_by(|(_, _, x), (_, _, y)| x.total_cmp(y))
.map(|(id, x, _)| (id, x))
.map(|(id, x, y)| utils::DoubleF32WithIndex(id, x, (y - entropy).abs()))
.collect_vec();
sorted.voracious_sort();
let sorted = sorted
.into_iter()
.rev()
.map(|utils::DoubleF32WithIndex(id, x, _)| (id, x))
.take(params.top_k)
.scan((0, 0.0, 0.0), |(_, cum, _), (id, x)| {
if *cum > params.tau {
Expand Down
40 changes: 40 additions & 0 deletions crates/ai00-core/src/sampler/utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
use std::cmp::Ordering;
use voracious_radix_sort::Radixable;
/// An index (field 0) paired with an `f32` value (field 1).
///
/// Equality and ordering look **only** at the `f32` in field 1; the index in
/// field 0 never takes part in a comparison. Because the comparison is the
/// plain `f32` one, a NaN value is unordered (`partial_cmp` yields `None`)
/// and is not equal to anything, itself included.
#[derive(Copy, Clone, Debug)]
pub struct F32WithIndex(pub usize, pub f32);

impl PartialEq for F32WithIndex {
    /// Two entries are equal when their `f32` values are equal.
    fn eq(&self, other: &Self) -> bool {
        let Self(_, lhs) = self;
        let Self(_, rhs) = other;
        lhs == rhs
    }
}

impl PartialOrd for F32WithIndex {
    /// Orders entries by their `f32` value alone.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let Self(_, lhs) = self;
        let Self(_, rhs) = other;
        lhs.partial_cmp(rhs)
    }
}
// Lets `voracious_radix_sort` sort `F32WithIndex` values, keyed on the `f32`
// stored in field 1.
impl Radixable<f32> for F32WithIndex {
    type Key = f32;

    /// Returns the `f32` in field 1 as the sort key.
    #[inline]
    fn key(&self) -> Self::Key {
        let Self(_, value) = *self;
        value
    }
}
/// An index (field 0) with two `f32` values (fields 1 and 2).
///
/// Equality and ordering look **only** at the `f32` in field 2; the index in
/// field 0 and the value in field 1 are carried along untouched and never
/// take part in a comparison. Because the comparison is the plain `f32` one,
/// a NaN in field 2 is unordered (`partial_cmp` yields `None`) and is not
/// equal to anything, itself included.
#[derive(Copy, Clone, Debug)]
pub struct DoubleF32WithIndex(pub usize, pub f32, pub f32);

impl PartialEq for DoubleF32WithIndex {
    /// Two entries are equal when their field-2 values are equal.
    fn eq(&self, other: &Self) -> bool {
        let Self(_, _, lhs) = self;
        let Self(_, _, rhs) = other;
        lhs == rhs
    }
}

impl PartialOrd for DoubleF32WithIndex {
    /// Orders entries by their field-2 value alone.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let Self(_, _, lhs) = self;
        let Self(_, _, rhs) = other;
        lhs.partial_cmp(rhs)
    }
}
// Lets `voracious_radix_sort` sort `DoubleF32WithIndex` values, keyed on the
// `f32` stored in field 2 (field 1 is not part of the key).
impl Radixable<f32> for DoubleF32WithIndex {
    type Key = f32;

    /// Returns the `f32` in field 2 as the sort key.
    #[inline]
    fn key(&self) -> Self::Key {
        let Self(_, _, value) = *self;
        value
    }
}

0 comments on commit 87cca8b

Please sign in to comment.