Skip to content

Commit

Permalink
Impl normal and uniform rng dists
Browse files Browse the repository at this point in the history
  • Loading branch information
BradenEverson committed Apr 8, 2024
1 parent 8608264 commit c218991
Show file tree
Hide file tree
Showing 9 changed files with 142 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/core/graph/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ impl Context {
| Operation::LessThan(_, _)
| Operation::LessThanEq(_, _)
| Operation::GreaterThan(_, _)
| Operation::RngNormal(_, _, _)
| Operation::RngUniform(_, _, _)
| Operation::GreaterThanEq(_, _) => {
return Err(ContextError::NonDifferentiableOpError(
self.nodes[dependent_node].callsite.clone(),
Expand Down
31 changes: 31 additions & 0 deletions src/core/graph/compile.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use super::*;
use slotmap::SlotMap;
use smallvec::SmallVec;
use xla::{XlaOp, ArrayShape};
use std::collections::{HashMap, HashSet, VecDeque};

#[derive(thiserror::Error, Debug)]
Expand Down Expand Up @@ -168,6 +169,36 @@ impl Context {
}
}

Operation::RngNormal(mu, sigma, shape) => {
if unda_xla_map.contains_key(&mu)
&& unda_xla_map.contains_key(&sigma)
&& xla_op_slotmap.contains_key(unda_xla_map[&mu])
&& xla_op_slotmap.contains_key(unda_xla_map[&sigma])
{
let dtype = self.nodes[mu].dtype;
let xla_op = XlaOp::rng_normal(&xla_op_slotmap[unda_xla_map[&mu]],
&xla_op_slotmap[unda_xla_map[&sigma]], &shape.to_array_shape(dtype))?;
let xla_id = xla_op_slotmap.insert(xla_op);
unda_xla_map.insert(*dependent_op, xla_id);
unda_op_queue.push_back(*dependent_op);
covered_ops.insert(*dependent_op);
}
}
Operation::RngUniform(min, max, shape) => {
if unda_xla_map.contains_key(&min)
&& unda_xla_map.contains_key(&max)
&& xla_op_slotmap.contains_key(unda_xla_map[&min])
&& xla_op_slotmap.contains_key(unda_xla_map[&max])
{
let dtype = self.nodes[min].dtype;
let xla_op = XlaOp::rng_uniform(&xla_op_slotmap[unda_xla_map[&min]],
&xla_op_slotmap[unda_xla_map[&max]], &shape.to_array_shape(dtype))?;
let xla_id = xla_op_slotmap.insert(xla_op);
unda_xla_map.insert(*dependent_op, xla_id);
unda_op_queue.push_back(*dependent_op);
covered_ops.insert(*dependent_op);
}
}
Operation::Pow(a, b) => {
if unda_xla_map.contains_key(&a)
&& unda_xla_map.contains_key(&b)
Expand Down
12 changes: 12 additions & 0 deletions src/core/graph/consteval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,16 @@ impl Context {
changed = true;
}
}
Operation::RngUniform(a, b, shape)
| Operation::RngNormal(a, b, shape) => {
if a == to_remove {
self.nodes[dep_node].operation = Operation::RngUniform(rep_with, b, shape);
changed = true;
} else if b == to_remove {
self.nodes[dep_node].operation = Operation::RngUniform(a, rep_with, shape);
changed = true;
}
}
Operation::Pow(a, b) => {
if a == to_remove && a == b {
self.nodes[dep_node].operation = Operation::Pow(rep_with, rep_with);
Expand Down Expand Up @@ -478,6 +488,8 @@ impl Context {
| Operation::NotEqual(a, b)
| Operation::Div(a, b)
| Operation::Pow(a, b)
| Operation::RngUniform(a, b, _)
| Operation::RngNormal(a, b, _)
| Operation::MatMul(a, b) => {
if self.nodes[a].is_const().is_none() {
to_visit.push(a);
Expand Down
4 changes: 3 additions & 1 deletion src/core/graph/context.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::HashMap;
use std::{collections::HashMap, fmt::format};

use super::*;

Expand Down Expand Up @@ -104,6 +104,8 @@ impl Context {
Operation::Exp(a) => format!("Exp ({})", self.to_string(a)),
Operation::Log(a) => format!("Log ({})", self.to_string(a)),
Operation::Transpose(a, b) => format!("Transpose: ({}) ({:?})", self.to_string(a), b),
Operation::RngUniform(a, b, shape) => format!("RngUniform: ({}) ({}) ({})", self.to_string(a), self.to_string(b), shape),
Operation::RngNormal(a, b, shape) => format!("RngNormal: ({}) ({}) ({})", self.to_string(a), self.to_string(b), shape),
Operation::Equal(a, b) => {
format!("LessThan ({}) ({})", self.to_string(a), self.to_string(b))
}
Expand Down
46 changes: 46 additions & 0 deletions src/core/graph/math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,52 @@ impl Context {
Ok(node_id)
}

pub fn rng_uniform(&mut self, min: NodeIdentifier, max: NodeIdentifier, shape: &[u32]) -> Result<NodeIdentifier> {
if self.nodes[min].dtype != self.nodes[max].dtype {
Err(ContextError::IncompatibleOperandTypes(
self.nodes[min].dtype,
self.nodes[max].dtype,
callsite!(1),
))
} else {
let shape_node = Shape::from(shape);
let node = Node {
callsite: callsite!(1),
shape: shape_node.clone(),
operation: Operation::RngUniform(min, max, shape_node),
dtype: self.nodes[min].dtype,
};
let node_id = self.nodes.insert(node);
self.dependent_nodes.entry(min).or_default().push(node_id);
self.dependent_nodes.entry(max).or_default().push(node_id);

Ok(node_id)
}
}

pub fn rng_normal(&mut self, mu: NodeIdentifier, sigma: NodeIdentifier, shape: &[u32]) -> Result<NodeIdentifier> {
if self.nodes[mu].dtype != self.nodes[sigma].dtype {
Err(ContextError::IncompatibleOperandTypes(
self.nodes[mu].dtype,
self.nodes[sigma].dtype,
callsite!(1),
))
} else {
let shape_node = Shape::from(shape);
let node = Node {
callsite: callsite!(1),
shape: shape_node.clone(),
operation: Operation::RngNormal(mu, sigma, shape_node),
dtype: self.nodes[mu].dtype,
};
let node_id = self.nodes.insert(node);
self.dependent_nodes.entry(mu).or_default().push(node_id);
self.dependent_nodes.entry(sigma).or_default().push(node_id);

Ok(node_id)
}
}

pub fn exp(&mut self, a: NodeIdentifier) -> Result<NodeIdentifier> {
let node = Node {
callsite: callsite!(1),
Expand Down
10 changes: 10 additions & 0 deletions src/core/graph/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ pub enum Operation {
},

OneHot(NodeIdentifier),
RngUniform(NodeIdentifier, NodeIdentifier, Shape),
RngNormal(NodeIdentifier, NodeIdentifier, Shape)
}

impl Hash for Operation {
Expand Down Expand Up @@ -147,6 +149,12 @@ impl Hash for Operation {
n_tiles.hash(state);
dim.hash(state);
}
Self::RngUniform(a, b, dim)
| Self::RngNormal(a, b, dim) => {
a.hash(state);
b.hash(state);
dim.hash(state);
}
}
}
}
Expand Down Expand Up @@ -190,6 +198,8 @@ impl PartialEq for Operation {
},
) => pred == pred2 && on_true == on_true2 && on_false == on_false2,
(&Self::TypeCast(a, ty), &Self::TypeCast(b, ty2)) => a == b && ty == ty2,
(&Self::RngUniform(a, b, shape), &Self::RngUniform(a2, b2, shape2)) => a == a2 && b == b2 && shape == shape2,
(&Self::RngNormal(a, b, shape), &Self::RngNormal(a2, b2, shape2)) => a == a2 && b == b2 && shape == shape2,
(&Self::Transpose(a, dim), &Self::Transpose(b, dim2)) => a == b && dim == dim2,
(
&Self::SliceInDim {
Expand Down
5 changes: 5 additions & 0 deletions src/core/graph/shape.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use smallvec::SmallVec;
use xla::ArrayShape;

use super::callsite::Callsite;

Expand Down Expand Up @@ -64,6 +65,10 @@ impl Shape {
}
}

pub fn to_array_shape(&self, dtype: xla::ElementType) -> ArrayShape {
ArrayShape::new(self.sizes.iter().map(|d| *d as i64).collect(), dtype)
}

pub fn matmul_shape(&self, dim2: &[u32]) -> Option<Vec<u32>> {
let dim1 = &self.sizes;
if dim1.last()? == dim2.get(dim2.len().saturating_sub(2))? {
Expand Down
2 changes: 2 additions & 0 deletions src/core/graph/subterm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ impl Context {
| Operation::GreaterThanEq(a, b)
| Operation::LessThanEq(a, b)
| Operation::MatMul(a, b)
| Operation::RngNormal(a, b, _)
| Operation::RngUniform(a, b, _)
| Operation::Pow(a, b) => {
to_visit.push(a);
to_visit.push(b);
Expand Down
32 changes: 31 additions & 1 deletion src/core/graph/tests_cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,39 @@ mod tests {
create_test!(test_add_1_2, add, F32, 1f32, 2f32, 3f32);
create_test!(test_sub_1_2, sub, F32, 1f32, 2f32, -1f32);

#[test]
fn test_normal_dist() {
let mut ctx = Context::new();
let mu = ctx.scalar(0, xla::ElementType::F32).expect("mu = 0");
let sigma = ctx.scalar(1, xla::ElementType::F32).expect("sigma = 1");
let mat = ctx.rng_normal(mu, sigma, &[2,3]).expect("sample the normal distribution");

let client = xla::PjRtClient::cpu().expect("client");
let name = "test";
let executable = ctx.compile(&name, [mat], &client).expect("executable");

let device_result = executable.execute::<Literal>(&[]).expect("execute");
let host_result = device_result[0][0]
.to_literal_sync()
.expect("to_literal_sync");
let untupled_result = host_result.to_tuple1().expect("untuple");
let rust_result = untupled_result.to_vec::<f32>().expect("to_vec");
println!("{:?}", rust_result);

match untupled_result.shape().unwrap() {
Shape::Array(shape) => {
assert_eq!(shape.dims(), &[2,3]);
},
_ => {
panic!("Shape is not correct");
}
}
}


#[test]
fn test_large_cte() {
let mut ctx = Context::new();
let mut ctx = Context::new();
let a = ctx.parameter("a", [], xla::ElementType::F32).expect("a");
let two = ctx.scalar(2, xla::ElementType::F32).expect("2");

Expand Down

0 comments on commit c218991

Please sign in to comment.