This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Reserve more eval memory and use ggml scratch buffers #116

Merged 3 commits on Apr 13, 2023
47 changes: 47 additions & 0 deletions ggml/src/lib.rs
@@ -378,6 +378,28 @@ impl Context {
pub fn used_mem(&self) -> usize {
unsafe { ggml_sys::ggml_used_mem(self.ptr.as_ptr()) }
}

/// Sets the scratch buffer to be used by this [Context].
///
/// If `scratch_buffer` is `None`, the scratch buffer will be disabled.
pub fn use_scratch<'a>(&'a self, scratch_buffer: Option<&'a mut Buffer>) {
let (size, data) = if let Some(buffer) = scratch_buffer {
(buffer.data.len(), buffer.data.as_ptr() as *mut c_void)
} else {
(0, std::ptr::null_mut())
};
// SAFETY: this just passes a (most likely uninitialized) memory buffer to the ggml C API
unsafe {
ggml_sys::ggml_set_scratch(
self.ptr.as_ptr(),
ggml_sys::ggml_scratch {
offs: 0,
size,
data,
},
);
}
}
}

impl Drop for Context {
@@ -390,6 +412,31 @@ impl Drop for Context {
}
}

/// A buffer of memory that can be used as a scratch buffer for a [Context].
///
/// See [Context::use_scratch].
pub struct Buffer {
data: Box<[u8]>,
}

impl Buffer {
/// Creates a new buffer of the specified size.
pub fn new(size: usize) -> Self {
let mut data: Vec<u8> = Vec::with_capacity(size);

// SAFETY: The contents are intentionally uninitialized, as they will be passed to
// the ggml C API which will fill them with data.
#[allow(clippy::uninit_vec)]
unsafe {
data.set_len(size);
}

Buffer {
data: data.into_boxed_slice(),
}
}
}

/// Tensors are owned by the context. A tensor is alive as long as the
/// underlying context it was created with is alive.
pub struct Tensor {
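
A rough usage sketch of the new scratch API (illustrative only, not part of this diff; it assumes `ggml::Context::init`, `new_f32`, and `op_add` keep their existing signatures in this crate):

// Illustrative only: create a context and a scratch buffer, route
// intermediate tensor data into the scratch buffer, then switch back.
let ctx = ggml::Context::init(16 * 1024 * 1024);
let mut scratch = ggml::Buffer::new(8 * 1024 * 1024);

// While a scratch buffer is active, the data of newly created tensors is
// placed in that buffer instead of the context's own memory pool.
ctx.use_scratch(Some(&mut scratch));
let a = ctx.new_f32(1.0);
let b = ctx.new_f32(2.0);
let _sum = ctx.op_add(&a, &b);

// Disable it again before creating tensors whose data must survive later
// scratch reuse (e.g. the final logits).
ctx.use_scratch(None);
println!("context bytes used: {}", ctx.used_mem());
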
42 changes: 41 additions & 1 deletion llama-rs/src/lib.rs
@@ -27,6 +27,12 @@ mod util;
/// The end of text token.
pub const EOT_TOKEN_ID: TokenId = 2; // Hardcoded (for now?)

// The size of each scratch buffer, used for temporary storage of intermediate
// results during inference.
//
// The specific value was copied from `llama.cpp`.
const SCRATCH_SIZE: usize = 512 * 1024 * 1024;

/// The hyperparameters of the model.
#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord, Deserialize)]
pub struct Hyperparameters {
@@ -103,6 +109,12 @@ pub struct InferenceSession {

/// The logits that were last predicted by the network. Zeroed out otherwise.
last_logits: Vec<f32>,

/// Scratch buffers used during inference.
///
/// The number of scratch buffers was copied from `llama.cpp`.
/// There is no specific reason for this number, but a single buffer is insufficient.
scratch: [ggml::Buffer; 2],
}
impl InferenceSession {
fn repetition_penalty_tokens(&self) -> &[TokenId] {
@@ -128,10 +140,18 @@ impl Clone for InferenceSession {
mem_per_token: self.mem_per_token,
tokens: self.tokens.clone(),
last_logits: self.last_logits.clone(),
scratch: inference_session_scratch_buffers(),
}
}
}

fn inference_session_scratch_buffers() -> [ggml::Buffer; 2] {
[
ggml::Buffer::new(SCRATCH_SIZE),
ggml::Buffer::new(SCRATCH_SIZE),
]
}

#[derive(serde::Serialize, Clone, PartialEq)]
/// A serializable snapshot of the inference process.
/// Can be created by calling [InferenceSession::get_snapshot].
@@ -1121,6 +1141,7 @@ impl Model {
mem_per_token: 0,
tokens: vec![],
last_logits: vec![0.0; n_vocab],
scratch: inference_session_scratch_buffers(),
}
}

@@ -1157,7 +1178,18 @@

// For the first run, we need to guess a maximum buffer size so we can measure
// the actual memory consumption of the temporary ggml context.
let mut buf_size = 1024 * 1024 * 1024;
//
// These numbers are from `llama.cpp`, and could potentially be tuned to be more efficient.
let mut buf_size = {
let buf_size_mb = if n_layer >= 80 {
1536
} else if n_layer >= 60 {
1280
} else {
1024
};
buf_size_mb * 1024 * 1024
};
if session.mem_per_token > 0 && session.mem_per_token * n > buf_size {
// add 10% to account for ggml object overhead
buf_size = (1.1f64 * session.mem_per_token as f64 * n as f64) as usize;
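
For reference, the sizing rule above reads as a small standalone helper (a sketch only, not part of this change):

// Sketch of the eval-buffer size guess: a layer-count-based baseline that is
// enlarged once the measured per-token requirement (plus 10%) exceeds it.
fn eval_buf_size(n_layer: usize, mem_per_token: usize, n: usize) -> usize {
    let base_mb: usize = if n_layer >= 80 {
        1536
    } else if n_layer >= 60 {
        1280
    } else {
        1024
    };
    let mut buf_size = base_mb * 1024 * 1024;
    if mem_per_token > 0 && mem_per_token * n > buf_size {
        // add 10% to account for ggml object overhead
        buf_size = (1.1f64 * mem_per_token as f64 * n as f64) as usize;
    }
    buf_size
}

For a 7B model (32 layers) the initial guess is therefore 1 GiB; once `mem_per_token` has been measured on the first call, the guess is enlarged whenever the measured requirement for `n` tokens exceeds it.
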
@@ -1175,6 +1207,8 @@ impl Model {
let input_self_attention = input_layer.share();
let mut current: ggml::Tensor;

ctx0.use_scratch(Some(&mut session.scratch[0]));

// norm
{
current = ctx0.op_rms_norm(&input_layer);
@@ -1300,6 +1334,8 @@ impl Model {
current = ctx0.op_mul_mat(&self.layers[il].wo, &current);
}

ctx0.use_scratch(Some(&mut session.scratch[1]));

let input_feed_forward = ctx0.op_add(&current, &input_self_attention);

// feed-forward network
@@ -1333,6 +1369,8 @@ impl Model {
input_layer = current;
}

ctx0.use_scratch(Some(&mut session.scratch[0]));

// Used at the end to optionally extract the embeddings.
let embeddings_tensor;

@@ -1350,6 +1388,8 @@ impl Model {
input_layer = ctx0.op_mul_mat(&self.output, &input_layer);
}

ctx0.use_scratch(None);

// logits -> probs
// inpL = ctx0.op_soft_max(&inpL);
