This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Merge pull request #116 from juho-p/fix-out-of-context-memory
Reserve more eval memory and use ggml scratch buffers
philpax authored Apr 13, 2023
2 parents 7dd6748 + c48ab9f commit 5db8b4f
Showing 2 changed files with 88 additions and 1 deletion.
47 changes: 47 additions & 0 deletions ggml/src/lib.rs
@@ -378,6 +378,28 @@ impl Context {
pub fn used_mem(&self) -> usize {
unsafe { ggml_sys::ggml_used_mem(self.ptr.as_ptr()) }
}

/// Sets the scratch buffer to be used by this [Context].
///
/// If `scratch_buffer` is `None`, the scratch buffer will be disabled.
pub fn use_scratch<'a>(&'a self, scratch_buffer: Option<&'a mut Buffer>) {
let (size, data) = if let Some(buffer) = scratch_buffer {
(buffer.data.len(), buffer.data.as_ptr() as *mut c_void)
} else {
(0, std::ptr::null_mut())
};
// SAFETY: this just passes a (most likely uninitialized) memory buffer to the ggml C API
unsafe {
ggml_sys::ggml_set_scratch(
self.ptr.as_ptr(),
ggml_sys::ggml_scratch {
offs: 0,
size,
data,
},
);
}
}
}

impl Drop for Context {
@@ -390,6 +412,31 @@ impl Drop for Context {
}
}

/// A buffer of memory that can be used as a scratch buffer for a [Context].
///
/// See [Context::use_scratch].
pub struct Buffer {
data: Box<[u8]>,
}

impl Buffer {
/// Creates a new buffer of the specified size.
pub fn new(size: usize) -> Self {
let mut data: Vec<u8> = Vec::with_capacity(size);

// SAFETY: The contents are intentionally uninitialized, as they will be passed to
// the ggml C API which will fill them with data.
#[allow(clippy::uninit_vec)]
unsafe {
data.set_len(size);
}

Buffer {
data: data.into_boxed_slice(),
}
}
}

/// Tensors are owned by the context. A tensor is alive as long as the
/// underlying context it was created with is alive.
pub struct Tensor {
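Taken together, the new `Buffer` type and `Context::use_scratch` let a caller route ggml's intermediate tensor allocations into a preallocated region and switch back to the context's own memory pool for results that need to outlive it. A minimal sketch of the intended call pattern (the `Context::init` constructor and the sizes below are illustrative assumptions, not part of this change):

    // Sketch only: assumes `ggml::Context::init(mem_size)` is the context constructor.
    let mut scratch = ggml::Buffer::new(64 * 1024 * 1024);
    let ctx = ggml::Context::init(16 * 1024 * 1024);

    // Tensors created after this call are allocated inside `scratch`.
    ctx.use_scratch(Some(&mut scratch));
    // ... build the temporary parts of the graph ...

    // Disable the scratch buffer so that tensors which must persist
    // (e.g. the final logits) come from the context's own memory pool.
    ctx.use_scratch(None);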
42 changes: 41 additions & 1 deletion llama-rs/src/lib.rs
@@ -27,6 +27,12 @@ mod util;
/// The end of text token.
pub const EOT_TOKEN_ID: TokenId = 2; // Hardcoded (for now?)

// The size of a scratch buffer used for inference. This is used for temporary
// storage of intermediate results during inference.
//
// The specific value was copied from `llama.cpp`.
const SCRATCH_SIZE: usize = 512 * 1024 * 1024;

/// The hyperparameters of the model.
#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord, Deserialize)]
pub struct Hyperparameters {
@@ -103,6 +109,12 @@ pub struct InferenceSession {

/// The logits that were last predicted by the network. Zeroed out otherwise.
last_logits: Vec<f32>,

/// Scratch buffers used during inference.
///
/// The number of scratch buffers was copied from `llama.cpp`.
/// There is no specific reason for this number, but one is insufficient.
scratch: [ggml::Buffer; 2],
}
impl InferenceSession {
fn repetition_penalty_tokens(&self) -> &[TokenId] {
@@ -128,10 +140,18 @@ impl Clone for InferenceSession {
mem_per_token: self.mem_per_token,
tokens: self.tokens.clone(),
last_logits: self.last_logits.clone(),
scratch: inference_session_scratch_buffers(),
}
}
}

fn inference_session_scratch_buffers() -> [ggml::Buffer; 2] {
[
ggml::Buffer::new(SCRATCH_SIZE),
ggml::Buffer::new(SCRATCH_SIZE),
]
}
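With `SCRATCH_SIZE` at 512 MiB, each `InferenceSession` therefore reserves 2 × 512 MiB = 1 GiB of scratch memory up front, and cloning a session allocates a fresh pair of buffers rather than sharing the originals.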

#[derive(serde::Serialize, Clone, PartialEq)]
/// A serializable snapshot of the inference process.
/// Can be created by calling [InferenceSession::get_snapshot].
@@ -1121,6 +1141,7 @@ impl Model {
mem_per_token: 0,
tokens: vec![],
last_logits: vec![0.0; n_vocab],
scratch: inference_session_scratch_buffers(),
}
}

@@ -1157,7 +1178,18 @@

// For the first run, we need to guess a maximum buffer size so we can measure
// the actual memory consumption of the temporary ggml context.
let mut buf_size = 1024 * 1024 * 1024;
//
// These numbers are from `llama.cpp`, and could potentially be made more efficient.
let mut buf_size = {
let buf_size_mb = if n_layer >= 80 {
1536
} else if n_layer >= 60 {
1280
} else {
1024
};
buf_size_mb * 1024 * 1024
};
if session.mem_per_token > 0 && session.mem_per_token * n > buf_size {
// add 10% to account for ggml object overhead
buf_size = (1.1f64 * session.mem_per_token as f64 * n as f64) as usize;
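For example, assuming the standard LLaMA layer counts (32 layers for 7B, 40 for 13B, 60 for 30B, 80 for 65B), the first-run guess is 1024 MiB for the 7B and 13B models, 1280 MiB for 30B, and 1536 MiB for 65B; once `mem_per_token` has been measured, subsequent calls size the buffer from actual usage plus 10% for ggml object overhead.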
@@ -1175,6 +1207,8 @@
let input_self_attention = input_layer.share();
let mut current: ggml::Tensor;

ctx0.use_scratch(Some(&mut session.scratch[0]));

// norm
{
current = ctx0.op_rms_norm(&input_layer);
@@ -1300,6 +1334,8 @@ impl Model {
current = ctx0.op_mul_mat(&self.layers[il].wo, &current);
}

ctx0.use_scratch(Some(&mut session.scratch[1]));

let input_feed_forward = ctx0.op_add(&current, &input_self_attention);

// feed-forward network
@@ -1333,6 +1369,8 @@
input_layer = current;
}

ctx0.use_scratch(Some(&mut session.scratch[0]));

// Used at the end to optionally extract the embeddings.
let embeddings_tensor;

@@ -1350,6 +1388,8 @@
input_layer = ctx0.op_mul_mat(&self.output, &input_layer);
}

ctx0.use_scratch(None);

// logits -> probs
// inpL = ctx0.op_soft_max(&inpL);

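The net effect in `evaluate` is a simple alternation: each transformer layer builds its self-attention intermediates in `scratch[0]` and its feed-forward intermediates in `scratch[1]`, the final norm and output projection reuse `scratch[0]`, and the scratch buffer is disabled before the logits are extracted so that they are allocated from the main context. Schematically (a condensed restatement of the diff above, not additional code from this commit):

    for il in 0..n_layer {
        ctx0.use_scratch(Some(&mut session.scratch[0]));
        // ... RMS norm and self-attention for layer `il` ...

        ctx0.use_scratch(Some(&mut session.scratch[1]));
        // ... feed-forward network for layer `il` ...
    }

    ctx0.use_scratch(Some(&mut session.scratch[0]));
    // ... final norm and output projection ...

    ctx0.use_scratch(None);
    // ... logits (and optional embeddings) are read from the context's own memory ...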
