Commit

Create cpu test file for easily swapping tests out
BradenEverson committed Mar 4, 2024
1 parent 76bb657 commit 1446593
Showing 2 changed files with 321 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/core/graph/mod.rs
@@ -12,6 +12,7 @@ mod parameter;
mod shape;
mod subterm;
mod tests;
//mod tests_cpu;

use callsite::{callsite, Callsite};
pub use compile::CompileError;
320 changes: 320 additions & 0 deletions src/core/graph/tests_cpu.rs
@@ -0,0 +1,320 @@
#[cfg(test)]
mod tests {
use crate::core::graph::Context;
use xla::FromRawBytes;

#[test]
fn test_mul_add_scalar_consts_and_params() {
let mut ctx = Context::new();

let three = ctx.scalar(3, xla::ElementType::F32).expect("three");

let x = ctx.parameter("x", [], xla::ElementType::F32).expect("x");
let y = ctx.parameter("y", [], xla::ElementType::F32).expect("y");

let product = ctx.mul(x, three).expect("product");
let sum = ctx.add(product, y).expect("sum");
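// The graph evaluates 3 * x + y; with the inputs below (x = 2, y = 3) the result should be 9.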

// output XLA
// The client must be exposed to the user; being able to control the device, memory fraction, and pre-allocation is very useful.
let client = xla::PjRtClient::cpu().expect("client");//gpu(0.7, false).expect("client");
let name = "test";
let executable = ctx.compile(&name, [sum], &client).expect("executable");

let x_input = xla::Literal::scalar(2f32);
let y_input = xla::Literal::scalar(3f32);
// Args are provided in the order the parameters were defined; it would be nice to pass a dict or something.
// A PjRtBuffer is just an array slice on some device,
// but it's unclear why execute returns a nested vector instead of a single one.
let device_result = executable.execute(&[x_input, y_input]).expect("execute");
let host_result = device_result[0][0]
.to_literal_sync()
.expect("to_literal_sync");
let untupled_result = host_result.to_tuple1().expect("untuple");
let rust_result = untupled_result.to_vec::<f32>().expect("to_vec");
println!("{:?}", rust_result);

assert_eq!(rust_result[0], 9f32);
}

#[test]
fn test_vector_matrix_bf16() {
let mut ctx = Context::new();

let foo = ctx.vector([1, 2, 3], xla::ElementType::Bf16).expect("foo");
let bar = ctx
.matrix([[4, 5, 6], [7, 8, 9], [10, 11, 12]], xla::ElementType::Bf16)
.expect("bar");

let baz = ctx.reshape_const(foo, [1, 3]).expect("baz");
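// The [1, 3] reshape lets foo broadcast across the rows of bar, so each row of bar is multiplied element-wise by [1, 2, 3].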
let barbaz = ctx.mul(bar, baz).expect("barbaz");

let client = xla::PjRtClient::cpu().expect("client");//gpu(0.7, false).expect("client");
let executable = ctx.compile("test", [barbaz], &client).expect("executable");

let device_result = executable.execute::<xla::Literal>(&[]).expect("execute");
let host_result = device_result[0][0]
.to_literal_sync()
.expect("to_literal_sync");
let untupled_result = host_result.to_tuple1().expect("untuple");
let f32_result = untupled_result
.convert(xla::ElementType::F32.primitive_type())
.expect("f32 conversion");
let rust_result = f32_result.to_vec::<f32>().expect("to_vec");
println!("{:?}", rust_result);
assert_eq!(
rust_result.as_slice(),
&[4f32, 10f32, 18f32, 7f32, 16f32, 27f32, 10f32, 22f32, 36f32]
);
}

#[test]
fn test_npy_loading() {
let mut ctx = Context::new();

let my_const = ctx.const_from_npy("test.npy").expect("my_const");
println!("{}", ctx.nodes[my_const].dtype);
let my_param = ctx
.parameter("my_param", [2, 2], xla::ElementType::S64)
.expect("my_param");

let sum = ctx.add(my_const, my_param).expect("sum");
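// my_const and the parameter input below are both loaded from test.npy, so sum should be twice that array.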

let client = xla::PjRtClient::cpu().expect("client");//gpu(0.7, false).expect("client");
let executable = ctx.compile("test", [sum], &client).expect("executable");

let my_param_input = xla::Literal::read_npy("test.npy", &()).expect("my_param_input");

let device_result = executable.execute(&[my_param_input]).expect("execute");
let host_result = device_result[0][0]
.to_literal_sync()
.expect("to_literal_sync");
let untupled_result = host_result.to_tuple1().expect("untuple");
let rust_result = untupled_result.to_vec::<i64>().expect("to_vec");
println!("{:?}", rust_result);
assert_eq!(rust_result.as_slice(), &[0, 2, 2, 0]);
}

#[test]
fn test_multiple_outputs() {
let mut ctx = Context::new();

let three = ctx.scalar(3, xla::ElementType::F32).expect("three");

let x = ctx.parameter("x", [], xla::ElementType::F32).expect("x");
let y = ctx.parameter("y", [], xla::ElementType::F32).expect("y");

let product = ctx.mul(x, three).expect("product");
let sum = ctx.add(product, y).expect("sum");
let sum2 = ctx.add(three, x).expect("sum2");
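// With x = 2 and y = 3: sum = 3 * x + y = 9, product = 3 * x = 6, sum2 = 3 + x = 5.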

// output XLA
// The client must be exposed to the user; being able to control the device, memory fraction, and pre-allocation is very useful.
let client = xla::PjRtClient::cpu().expect("client");//gpu(0.7, false).expect("client");
let name = "test";
let executable = ctx
.compile(&name, [sum, product, sum2], &client)
.expect("executable");

let x_input = xla::Literal::scalar(2f32);
let y_input = xla::Literal::scalar(3f32);
// Args are provided in the order the parameters were defined; it would be nice to pass a dict or something.
// A PjRtBuffer is just an array slice on some device,
// but it's unclear why execute returns a nested vector instead of a single one.
let device_result = executable.execute(&[x_input, y_input]).expect("execute");
let host_result = device_result[0][0]
.to_literal_sync()
.expect("to_literal_sync");
let (eval_sum, eval_product, eval_sum2) = host_result.to_tuple3().expect("untuple");
let rust_result1 = eval_sum.to_vec::<f32>().expect("to_vec");
println!("{:?}", rust_result1);

assert_eq!(rust_result1[0], 9f32);
let rust_result2 = eval_product.to_vec::<f32>().expect("to_vec");
assert_eq!(rust_result2[0], 6f32);
let rust_result3 = eval_sum2.to_vec::<f32>().expect("to_vec");
assert_eq!(rust_result3[0], 5f32);
}

#[test]
fn test_minimum() {
let mut ctx = Context::new();

let test_const1 = ctx.const_from_npy("test.npy").expect("test_const1");
let test_const2 = ctx.const_from_npy("test2.npy").expect("test_const2");
let min = ctx.minimum(test_const1, test_const2).expect("min");
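// min holds the element-wise minimum of the two arrays loaded from the .npy files.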

let client = xla::PjRtClient::cpu().expect("client");//gpu(0.7, false).expect("client");
let name = "test";
let executable = ctx.compile(&name, [min], &client).expect("executable");

let device_result = executable.execute::<xla::Literal>(&[]).expect("execute");
let host_result = device_result[0][0]
.to_literal_sync()
.expect("to_literal_sync");
let untupled_result = host_result.to_tuple1().expect("untuple");
let rust_result = untupled_result.to_vec::<i64>().expect("to_vec");
println!("{:?}", rust_result);
assert_eq!(rust_result.as_slice(), &[-2, 1, 1, -2]);
}

#[test]
fn test_relu() {
let mut ctx = Context::new();

let test_const = ctx.const_from_npy("test2.npy").expect("test_const");
let relu = ctx.relu(test_const).expect("relu");
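// relu clamps negative entries to zero, so only the non-negative values from test2.npy survive.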

let client = xla::PjRtClient::cpu().expect("client");//gpu(0.7, false).expect("client");
let name = "test";
let executable = ctx.compile(&name, [relu], &client).expect("executable");

let device_result = executable.execute::<xla::Literal>(&[]).expect("execute");
let host_result = device_result[0][0]
.to_literal_sync()
.expect("to_literal_sync");
let untupled_result = host_result.to_tuple1().expect("untuple");
let rust_result = untupled_result.to_vec::<i64>().expect("to_vec");
println!("{:?}", rust_result);
assert_eq!(rust_result.as_slice(), &[0, 4, 4, 0]);
}

#[test]
fn test_slice_in_dim() {
let mut ctx = Context::new();

let test_const = ctx.const_from_npy("test2.npy").expect("test_const");
let relu = ctx.relu(test_const).expect("relu");
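// NOTE: this currently mirrors test_relu; a slice_in_dim call still needs to be wired in here.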

let client = xla::PjRtClient::cpu().expect("client");//gpu(0.7, false).expect("client");
let name = "test";
let executable = ctx.compile(&name, [relu], &client).expect("executable");

let device_result = executable.execute::<xla::Literal>(&[]).expect("execute");
let host_result = device_result[0][0]
.to_literal_sync()
.expect("to_literal_sync");
let untupled_result = host_result.to_tuple1().expect("untuple");
let rust_result = untupled_result.to_vec::<i64>().expect("to_vec");
println!("{:?}", rust_result);
assert_eq!(rust_result.as_slice(), &[0, 4, 4, 0]);
}

#[test]
fn test_gradient_descent_polynomial() {
let mut ctx = Context::new();

let x = ctx.parameter("x", [], xla::ElementType::F32).expect("x");
let x2 = ctx.mul(x, x).expect("x2");
let x4 = ctx.mul(x2, x2).expect("x4");
let half = ctx.scalar(0.5, xla::ElementType::F32).expect("half");
let quadratic_term = ctx.mul(half, x2).expect("quadratic_term");
let quarter = ctx.scalar(0.25, xla::ElementType::F32).expect("quarter");
let quartic_term = ctx.mul(quarter, x4).expect("quartic_term");
let y = ctx.sub(quartic_term, quadratic_term).expect("y");
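// y = 0.25 * x^4 - 0.5 * x^2, so dy/dx = x^3 - x and the minima sit at x = +/-1 with y = -0.25.
// Starting from x = 0.5 with lr = 0.75, gradient descent should converge toward x = 1, y = -0.25.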

let dydx = ctx.diff(y, x.into()).expect("dydx");
println!("{}", ctx.to_string(dydx));
let lr = ctx.scalar(0.75, xla::ElementType::F32).expect("lr");
let update = ctx.mul(lr, dydx).expect("update");
let new_x = ctx.sub(x, update).expect("new_x");

let client = xla::PjRtClient::cpu().expect("client");//gpu(0.7, false).expect("client");
let name = "test";
let executable = ctx
.compile(&name, [y, dydx, new_x], &client)
.expect("executable");

let mut x_rust = 0.5f32;
println!("x = {}", x_rust);
let y_vals: [f32; 5] = [
-0.109375,
-0.21204352,
-0.24990773,
-0.24997526,
-0.24999404,
];

for i in 0..5 {
let x_xla = xla::Literal::scalar(x_rust);
let buffers = executable.execute(&[x_xla]).expect("execute");
let literals = buffers[0][0].to_literal_sync().expect("to_literal_sync");
let (y, dydx, x) = literals.to_tuple3().expect("untuple");
let y_rust = y.to_vec::<f32>().expect("to_vec")[0];
let dydx_rust = dydx.to_vec::<f32>().expect("to_vec")[0];
x_rust = x.to_vec::<f32>().expect("to_vec")[0];
println!("y = {}", y_rust);
assert_eq!(y_rust, y_vals[i]);
println!("dydx = {}", dydx_rust);
println!("x = {}", x_rust);
}
let x_xla = xla::Literal::scalar(x_rust);
let buffers = executable.execute(&[x_xla]).expect("execute");
let literals = buffers[0][0].to_literal_sync().expect("to_literal_sync");
let (y, dydx, x) = literals.to_tuple3().expect("untuple");
let y_rust = y.to_vec::<f32>().expect("to_vec")[0];
let dydx_rust = dydx.to_vec::<f32>().expect("to_vec")[0];
println!("y = {}", y_rust);
println!("dydx = {}", dydx_rust);
}

#[test]
fn test_gradient_descent_relu() {
let mut ctx = Context::new();

let x = ctx.parameter("x", [], xla::ElementType::F32).expect("x");
let rx = ctx.relu(x).expect("rx");
let nx = ctx.neg(x);
let rnx = ctx.relu(nx).expect("rnx");
let y = ctx.add(rnx, rx).expect("y");
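// y = relu(x) + relu(-x) = |x|, so dy/dx = 1 for x > 0.
// Starting from x = 1 with lr = 0.1, each step moves x toward zero by 0.1 and y drops by 0.1 per step.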

let dydx = ctx.diff(y, x.into()).expect("dydx");
println!("{}", ctx.to_string(dydx));
let lr = ctx.scalar(0.1, xla::ElementType::F32).expect("lr");
let update = ctx.mul(lr, dydx).expect("update");
let new_x = ctx.sub(x, update).expect("new_x");

let client = xla::PjRtClient::cpu().expect("client");//gpu(0.7, false).expect("client");
let name = "test";
let executable = ctx
.compile(&name, [y, dydx, new_x], &client)
.expect("executable");

let mut x_rust = 1f32;
println!("x = {}", x_rust);
let y_vals: [f32; 10] = [
1.0,
0.9,
0.79999995,
0.6999999,
0.5999999,
0.4999999,
0.39999992,
0.29999992,
0.19999993,
0.09999993,
];

for i in 0..10 {
let x_xla = xla::Literal::scalar(x_rust);
let buffers = executable.execute(&[x_xla]).expect("execute");
let literals = buffers[0][0].to_literal_sync().expect("to_literal_sync");
let (y, dydx, x) = literals.to_tuple3().expect("untuple");
let y_rust = y.to_vec::<f32>().expect("to_vec")[0];
let dydx_rust = dydx.to_vec::<f32>().expect("to_vec")[0];
x_rust = x.to_vec::<f32>().expect("to_vec")[0];
assert_eq!(y_rust, y_vals[i]);
println!("y = {}", y_rust);
println!("dydx = {}", dydx_rust);
println!("x = {}", x_rust);
}
let x_xla = xla::Literal::scalar(x_rust);
let buffers = executable.execute(&[x_xla]).expect("execute");
let literals = buffers[0][0].to_literal_sync().expect("to_literal_sync");
let (y, dydx, x) = literals.to_tuple3().expect("untuple");
let y_rust = y.to_vec::<f32>().expect("to_vec")[0];
let dydx_rust = dydx.to_vec::<f32>().expect("to_vec")[0];
println!("y = {}", y_rust);
println!("dydx = {}", dydx_rust);
}
}
