Merge master into api branch #78

Merged 4 commits into api from master on Apr 8, 2024
4 changes: 3 additions & 1 deletion src/core/graph/autodiff.rs
@@ -64,7 +64,9 @@ impl Context {
//Again, clone() here is not wonderful; there's got to be a better way to
//store the i64 vec for Transpose
match self.nodes[dependent_node].operation.clone() {
Operation::Constant(_) => panic!("Constant found as dependent node!"),
Operation::Constant(_)
| Operation::RngUniform(_, _, _)
| Operation::RngNormal(_, _, _) => panic!("Constant or RNG source found as dependent node!"),
Operation::Parameter(_) => panic!("Parameter found as dependent node!"),
Operation::StopGradient(_) => continue,

31 changes: 31 additions & 0 deletions src/core/graph/compile.rs
@@ -1,6 +1,7 @@
use super::*;
use slotmap::SlotMap;
use smallvec::SmallVec;
use xla::{XlaOp, ArrayShape};
use std::collections::{HashMap, HashSet, VecDeque};

#[derive(thiserror::Error, Debug)]
@@ -168,6 +169,36 @@ impl Context {
}
}

Operation::RngNormal(mu, sigma, shape) => {
if unda_xla_map.contains_key(&mu)
&& unda_xla_map.contains_key(&sigma)
&& xla_op_slotmap.contains_key(unda_xla_map[&mu])
&& xla_op_slotmap.contains_key(unda_xla_map[&sigma])
{
let dtype = self.nodes[mu].dtype;
let xla_op = XlaOp::rng_normal(&xla_op_slotmap[unda_xla_map[&mu]],
&xla_op_slotmap[unda_xla_map[&sigma]], &shape.to_array_shape(dtype))?;
let xla_id = xla_op_slotmap.insert(xla_op);
unda_xla_map.insert(*dependent_op, xla_id);
unda_op_queue.push_back(*dependent_op);
covered_ops.insert(*dependent_op);
}
}
Operation::RngUniform(min, max, shape) => {
if unda_xla_map.contains_key(&min)
&& unda_xla_map.contains_key(&max)
&& xla_op_slotmap.contains_key(unda_xla_map[&min])
&& xla_op_slotmap.contains_key(unda_xla_map[&max])
{
let dtype = self.nodes[min].dtype;
let xla_op = XlaOp::rng_uniform(&xla_op_slotmap[unda_xla_map[&min]],
&xla_op_slotmap[unda_xla_map[&max]], &shape.to_array_shape(dtype))?;
let xla_id = xla_op_slotmap.insert(xla_op);
unda_xla_map.insert(*dependent_op, xla_id);
unda_op_queue.push_back(*dependent_op);
covered_ops.insert(*dependent_op);
}
}
Operation::Pow(a, b) => {
if unda_xla_map.contains_key(&a)
&& unda_xla_map.contains_key(&b)
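Both new compile arms follow the readiness discipline used by the existing binary ops: a node is lowered only once both of its operands already have XLA counterparts in the op slotmap; until then it simply stays on the queue for a later pass. As a standalone illustration of that worklist pattern (the types below are stand-ins, not the crate's SlotMap-backed ones):

    use std::collections::{HashMap, VecDeque};

    type NodeId = usize;

    // Lower ops in dependency order: an op is emitted only after both of its
    // operands are available; otherwise it is re-queued for a later pass.
    fn lower(deps: &HashMap<NodeId, (NodeId, NodeId)>, leaves: &[NodeId]) -> Vec<NodeId> {
        let mut emitted: Vec<NodeId> = leaves.to_vec();
        let mut queue: VecDeque<NodeId> = deps.keys().copied().collect();
        let mut stalls = 0;
        while let Some(op) = queue.pop_front() {
            let (a, b) = deps[&op];
            if emitted.contains(&a) && emitted.contains(&b) {
                emitted.push(op);
                stalls = 0;
            } else {
                queue.push_back(op); // operands not ready yet; retry later
                stalls += 1;
                if stalls > queue.len() {
                    break; // no further progress possible
                }
            }
        }
        emitted
    }

    fn main() {
        // Node 2 depends on (0, 1); node 3 depends on (2, 0).
        let deps = HashMap::from([(2, (0, 1)), (3, (2, 0))]);
        println!("{:?}", lower(&deps, &[0, 1]));
    }

The real compiler keys everything through unda_xla_map and covered_ops, but the retry-until-ready shape of the loop is the same.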
20 changes: 20 additions & 0 deletions src/core/graph/consteval.rs
@@ -35,6 +35,24 @@ impl Context {
changed = true;
}
}
Operation::RngUniform(a, b, shape) => {
if a == to_remove {
self.nodes[dep_node].operation = Operation::RngUniform(rep_with, b, shape);
changed = true;
} else if b == to_remove {
self.nodes[dep_node].operation = Operation::RngUniform(a, rep_with, shape);
changed = true;
}
}
Operation::RngNormal(a, b, shape) => {
if a == to_remove {
self.nodes[dep_node].operation = Operation::RngNormal(rep_with, b, shape);
changed = true;
} else if b == to_remove {
self.nodes[dep_node].operation = Operation::RngNormal(a, rep_with, shape);
changed = true;
}
}
Operation::Pow(a, b) => {
if a == to_remove && a == b {
self.nodes[dep_node].operation = Operation::Pow(rep_with, rep_with);
@@ -478,6 +488,8 @@ impl Context {
| Operation::NotEqual(a, b)
| Operation::Div(a, b)
| Operation::Pow(a, b)
| Operation::RngUniform(a, b, _)
| Operation::RngNormal(a, b, _)
| Operation::MatMul(a, b) => {
if self.nodes[a].is_const().is_none() {
to_visit.push(a);
4 changes: 3 additions & 1 deletion src/core/graph/context.rs
@@ -1,4 +1,4 @@
use std::collections::HashMap;
use std::{collections::HashMap, fmt::format};

use super::*;

@@ -104,6 +104,8 @@ impl Context {
Operation::Exp(a) => format!("Exp ({})", self.to_string(a)),
Operation::Log(a) => format!("Log ({})", self.to_string(a)),
Operation::Transpose(a, b) => format!("Transpose: ({}) ({:?})", self.to_string(a), b),
Operation::RngUniform(a, b, shape) => format!("RngUniform: ({}) ({}) ({})", self.to_string(a), self.to_string(b), shape),
Operation::RngNormal(a, b, shape) => format!("RngNormal: ({}) ({}) ({})", self.to_string(a), self.to_string(b), shape),
Operation::Equal(a, b) => {
format!("LessThan ({}) ({})", self.to_string(a), self.to_string(b))
}
46 changes: 46 additions & 0 deletions src/core/graph/math.rs
@@ -102,6 +102,52 @@ impl Context {
Ok(node_id)
}

pub fn rng_uniform(&mut self, min: NodeIdentifier, max: NodeIdentifier, shape: &[u32]) -> Result<NodeIdentifier> {
if self.nodes[min].dtype != self.nodes[max].dtype {
Err(ContextError::IncompatibleOperandTypes(
self.nodes[min].dtype,
self.nodes[max].dtype,
callsite!(1),
))
} else {
let shape_node = Shape::from(shape);
let node = Node {
callsite: callsite!(1),
shape: shape_node.clone(),
operation: Operation::RngUniform(min, max, shape_node),
dtype: self.nodes[min].dtype,
};
let node_id = self.nodes.insert(node);
self.dependent_nodes.entry(min).or_default().push(node_id);
self.dependent_nodes.entry(max).or_default().push(node_id);

Ok(node_id)
}
}

pub fn rng_normal(&mut self, mu: NodeIdentifier, sigma: NodeIdentifier, shape: &[u32]) -> Result<NodeIdentifier> {
if self.nodes[mu].dtype != self.nodes[sigma].dtype {
Err(ContextError::IncompatibleOperandTypes(
self.nodes[mu].dtype,
self.nodes[sigma].dtype,
callsite!(1),
))
} else {
let shape_node = Shape::from(shape);
let node = Node {
callsite: callsite!(1),
shape: shape_node.clone(),
operation: Operation::RngNormal(mu, sigma, shape_node),
dtype: self.nodes[mu].dtype,
};
let node_id = self.nodes.insert(node);
self.dependent_nodes.entry(mu).or_default().push(node_id);
self.dependent_nodes.entry(sigma).or_default().push(node_id);

Ok(node_id)
}
}

pub fn exp(&mut self, a: NodeIdentifier) -> Result<NodeIdentifier> {
let node = Node {
callsite: callsite!(1),
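For orientation, a minimal usage sketch of the two new builders, mirroring the tests added below (the unda::core::graph import path is an assumption from the source layout, and Result is the crate's own alias used in the signatures above):

    use unda::core::graph::{Context, NodeIdentifier, Result};

    // Build one uniform and one normal sampling node; both builders require
    // their two operands to share a dtype, otherwise they return
    // ContextError::IncompatibleOperandTypes.
    fn build_samplers(ctx: &mut Context) -> Result<(NodeIdentifier, NodeIdentifier)> {
        let min = ctx.scalar(0, xla::ElementType::F32)?;
        let max = ctx.scalar(1, xla::ElementType::F32)?;
        let uniform = ctx.rng_uniform(min, max, &[2, 3])?; // 2x3 draws from U(min, max)

        let mu = ctx.scalar(0, xla::ElementType::F32)?;
        let sigma = ctx.scalar(1, xla::ElementType::F32)?;
        let normal = ctx.rng_normal(mu, sigma, &[2, 3])?; // 2x3 draws from N(mu, sigma)
        Ok((uniform, normal))
    }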
10 changes: 10 additions & 0 deletions src/core/graph/operation.rs
@@ -70,6 +70,8 @@ pub enum Operation {
},

OneHot(NodeIdentifier),
RngUniform(NodeIdentifier, NodeIdentifier, Shape),
RngNormal(NodeIdentifier, NodeIdentifier, Shape),
}

impl Hash for Operation {
@@ -147,6 +149,12 @@ impl Hash for Operation {
n_tiles.hash(state);
dim.hash(state);
}
Self::RngUniform(a, b, shape)
| Self::RngNormal(a, b, shape) => {
a.hash(state);
b.hash(state);
shape.hash(state);
}
}
}
}
@@ -190,6 +198,8 @@ impl PartialEq for Operation {
},
) => pred == pred2 && on_true == on_true2 && on_false == on_false2,
(&Self::TypeCast(a, ty), &Self::TypeCast(b, ty2)) => a == b && ty == ty2,
(&Self::RngUniform(a, b, shape), &Self::RngUniform(a2, b2, shape2)) => a == a2 && b == b2 && shape == shape2,
(&Self::RngNormal(a, b, shape), &Self::RngNormal(a2, b2, shape2)) => a == a2 && b == b2 && shape == shape2,
(&Self::Transpose(a, dim), &Self::Transpose(b, dim2)) => a == b && dim == dim2,
(
&Self::SliceInDim {
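Because Operation implements Hash and PartialEq by hand, the two new variants must uphold the usual contract — values that compare equal must hash equally — or anything that keys a hash map by Operation will treat identical RNG nodes as distinct. A self-contained illustration of the contract with stand-in types (not the crate's real NodeIdentifier or Shape):

    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    // Illustrative stand-ins for the crate's node ids and shapes.
    type NodeId = u32;

    #[derive(Hash, PartialEq, Eq, Clone, Debug)]
    enum Op {
        RngUniform(NodeId, NodeId, Vec<u32>),
    }

    fn fingerprint<T: Hash>(value: &T) -> u64 {
        let mut h = DefaultHasher::new();
        value.hash(&mut h);
        h.finish()
    }

    fn main() {
        let a = Op::RngUniform(0, 1, vec![2, 3]);
        let b = a.clone();
        assert_eq!(a, b); // equal by PartialEq...
        assert_eq!(fingerprint(&a), fingerprint(&b)); // ...therefore equal hashes
    }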
5 changes: 5 additions & 0 deletions src/core/graph/shape.rs
@@ -1,4 +1,5 @@
use smallvec::SmallVec;
use xla::ArrayShape;

use super::callsite::Callsite;

@@ -64,6 +65,10 @@ impl Shape {
}
}

pub fn to_array_shape(&self, dtype: xla::ElementType) -> ArrayShape {
ArrayShape::new(self.sizes.iter().map(|d| *d as i64).collect(), dtype)
}

pub fn matmul_shape(&self, dim2: &[u32]) -> Option<Vec<u32>> {
let dim1 = &self.sizes;
if dim1.last()? == dim2.get(dim2.len().saturating_sub(2))? {
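For reference, to_array_shape only widens the stored u32 sizes to the i64 dims that xla::ArrayShape expects; a free-standing sketch of the same conversion (the function name is reused here purely for illustration):

    // Equivalent standalone conversion, mirroring Shape::to_array_shape.
    fn to_array_shape(sizes: &[u32], dtype: xla::ElementType) -> xla::ArrayShape {
        let dims: Vec<i64> = sizes.iter().map(|d| *d as i64).collect();
        xla::ArrayShape::new(dims, dtype)
    }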
2 changes: 2 additions & 0 deletions src/core/graph/subterm.rs
@@ -47,6 +47,8 @@ impl Context {
| Operation::GreaterThanEq(a, b)
| Operation::LessThanEq(a, b)
| Operation::MatMul(a, b)
| Operation::RngNormal(a, b, _)
| Operation::RngUniform(a, b, _)
| Operation::Pow(a, b) => {
to_visit.push(a);
to_visit.push(b);
61 changes: 60 additions & 1 deletion src/core/graph/tests_cpu.rs
@@ -68,9 +68,68 @@ mod tests {
create_test!(test_add_1_2, add, F32, 1f32, 2f32, 3f32);
create_test!(test_sub_1_2, sub, F32, 1f32, 2f32, -1f32);

#[test]
fn test_normal_dist() {
let mut ctx = Context::new();
let mu = ctx.scalar(0, xla::ElementType::F32).expect("mu = 0");
let sigma = ctx.scalar(1, xla::ElementType::F32).expect("sigma = 1");
let mat = ctx.rng_normal(mu, sigma, &[2,3]).expect("sample the normal distribution");

let client = xla::PjRtClient::cpu().expect("client");
let name = "test";
let executable = ctx.compile(&name, [mat], &client).expect("executable");

let device_result = executable.execute::<Literal>(&[]).expect("execute");
let host_result = device_result[0][0]
.to_literal_sync()
.expect("to_literal_sync");
let untupled_result = host_result.to_tuple1().expect("untuple");
let rust_result = untupled_result.to_vec::<f32>().expect("to_vec");
println!("{:?}", rust_result);

match untupled_result.shape().unwrap() {
Shape::Array(shape) => {
assert_eq!(shape.dims(), &[2,3]);
},
_ => {
panic!("Shape is not correct");
}
}
}

#[test]
fn test_uniform_dist() {
let mut ctx = Context::new();
let min = ctx.scalar(0, xla::ElementType::F32).expect("min = 0");
let max = ctx.scalar(1, xla::ElementType::F32).expect("max = 1");
let mat = ctx.rng_uniform(min, max, &[10,1]).expect("sample the uniform distribution");

let client = xla::PjRtClient::cpu().expect("client");
let name = "test";
let executable = ctx.compile(&name, [mat], &client).expect("executable");

let device_result = executable.execute::<Literal>(&[]).expect("execute");
let host_result = device_result[0][0]
.to_literal_sync()
.expect("to_literal_sync");
let untupled_result = host_result.to_tuple1().expect("untuple");
let rust_result = untupled_result.to_vec::<f32>().expect("to_vec");
println!("{:?}", rust_result);

match untupled_result.shape().unwrap() {
Shape::Array(shape) => {
assert_eq!(shape.dims(), &[10,1]);
},
_ => {
panic!("Shape is not correct");
}
}
}


#[test]
fn test_large_cte() {
let mut ctx = Context::new();
let a = ctx.parameter("a", [], xla::ElementType::F32).expect("a");
let two = ctx.scalar(2, xla::ElementType::F32).expect("2");

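A hedged extension for the tests above (not part of this PR): besides asserting the output shape, one could bounds-check the uniform samples, assuming XLA's RngUniform draws from the half-open interval [min, max):

    // Inside test_uniform_dist, after rust_result is materialized:
    for v in &rust_result {
        assert!((0.0f32..1.0f32).contains(v), "sample {} outside [0, 1)", v);
    }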