Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed shape checking, added comparison and select operations #37

Merged
merged 5 commits into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 95 additions & 5 deletions src/core/graph/autodiff.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
use super::*;

impl Context {
pub fn stop_gradient(&mut self, node: NodeIdentifier) -> NodeIdentifier {
self.nodes.insert(Node {
callsite: callsite!(1),
shape: self.nodes[node].shape.clone(),
operation: Operation::StopGradient(node),
dtype: self.nodes[node].dtype,
})
}

pub fn diff(&mut self, node: NodeIdentifier, with_respect_to: Parameter) -> NodeIdentifier {
self.nodes.insert(Node {
callsite: callsite!(1),
Expand Down Expand Up @@ -28,17 +37,31 @@ impl Context {
// leaf nodes mean no further processing
Operation::Constant(_) => Ok(false),
Operation::Parameter(_) => Ok(false),
Operation::StopGradient(_) => Ok(false),
// operations mean we need to go deeper
Operation::Add(a, b) => {
Operation::Add(a, b)
| Operation::Mul(a, b)
| Operation::Equal(a, b)
| Operation::LessThan(a, b)
| Operation::GreaterThan(a, b)
| Operation::LessThanEq(a, b)
| Operation::GreaterThanEq(a, b) => {
let r = self.autodiff(a, modification_limit)?;
self.autodiff(b, modification_limit - (r as usize))
.map(|v| v || r)
}
Operation::Mul(a, b) => {
let r = self.autodiff(a, modification_limit)?;
self.autodiff(b, modification_limit - (r as usize))
Operation::Select {
pred: _,
on_true,
on_false,
} => {
let r = self.autodiff(on_true, modification_limit)?;
self.autodiff(on_false, modification_limit - (r as usize))
.map(|v| v || r)
}
Operation::TypeCast(node, ty) => self.autodiff(node, modification_limit),
Operation::SliceInDim { node, .. } => self.autodiff(node, modification_limit),
Operation::ZerosLike(node) => self.autodiff(node, modification_limit),
// finally a Diff node, lets distribute it
Operation::Diff(outer, outer_param) => {
let outer_node = &self.nodes[outer];
Expand All @@ -56,13 +79,14 @@ impl Context {
// derivative of a parameter with respect to itself is one, and otherwise zero
self.nodes[input].operation = Operation::Constant(ConstantBinding {
value: xla::Literal::scalar(
(outer == outer_param.into()) as u32 as f32,
(outer == outer_param.into()) as u32,
)
.convert(outer_dtype)?,
});
self.nodes[input].shape = [].into();
Ok(true)
}
Operation::StopGradient(_) => Ok(false),
Operation::Add(a, b) => {
// derivative of a sum is the sum of derivatives
// Diff (Sum a b) x = Sum (Diff a x) (Diff b x)
Expand Down Expand Up @@ -129,6 +153,72 @@ impl Context {
// rerun autodiff on the node we replaced
self.autodiff(input, modification_limit - 1)
}

Operation::Equal(_, _)
| Operation::LessThan(_, _)
| Operation::GreaterThan(_, _)
| Operation::LessThanEq(_, _)
| Operation::GreaterThanEq(_, _) => Err(ContextError::NonDifferentiableError(outer_node.callsite.clone())),
Operation::TypeCast(_, _) => Err(ContextError::NonDifferentiableError(outer_node.callsite.clone())),

Operation::Select {
pred,
on_true,
on_false,
} => {
// derivative of select is select of derivatives
let diff_true_node = Node {
// propagate original Diff callsite to the new Diff node
callsite: input_node.callsite.clone(),
shape: self.nodes[on_true].shape.clone(),
operation: Operation::Diff(on_true, outer_param),
dtype: self.nodes[on_true].dtype,
};
let diff_false_node = Node {
// propagate original Diff callsite to the new Diff node
callsite: input_node.callsite.clone(),
shape: self.nodes[on_false].shape.clone(),
operation: Operation::Diff(on_false, outer_param),
dtype: self.nodes[on_false].dtype,
};
// propagate original Mul callsite to the new Add node
self.nodes[input].callsite = outer_node.callsite.clone();
let diff_true = self.nodes.insert(diff_true_node);
let diff_false = self.nodes.insert(diff_false_node);

self.nodes[input].operation = Operation::Select {
pred: pred,
on_true: diff_true,
on_false: diff_false,
};
// rerun autodiff on the node we replaced
self.autodiff(input, modification_limit - 1)
}

/*
Operation::SliceInDim { node, start, stop, stride, dim } => {
let diff_node = Node {
callsite: input_node.callsite.clone(),
shape: self.nodes[node].shape.clone(),
operation: Operation::Diff(node, outer_param),
dtype: self.nodes[node].dtype
};
self.nodes[input].callsite = outer_node.callsite.clone();
let diff_node = self.nodes.insert(diff_node);
let zero_node = self.nodes.insert(Node {
callsite: input_node.callsite.clone(),
shape: node.shape.clone(),
operation: Operation::ZerosLike(node)
});
}
*/
Operation::SliceInDim { node, start, stop, stride, dim } => panic!("Differentiating SliceInDim not yet supported, xla-rs must implement scatter operation."),

Operation::ZerosLike(node) => {
self.nodes[input].operation = Operation::ZerosLike(node);
self.autodiff(input, modification_limit - 1)
}

Operation::Diff(inner, _) => {
// derivative of a derivative, apply the inner one first then try again on the outer.
let r = self.autodiff(inner, modification_limit)?;
Expand Down
186 changes: 184 additions & 2 deletions src/core/graph/compile.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::*;
use serde_json::de;
use slotmap::SlotMap;
use smallvec::SmallVec;
use std::collections::{HashMap, HashSet, VecDeque};
Expand Down Expand Up @@ -37,8 +38,21 @@ impl Context {
parameters.insert(this_node_id);
Ok(())
}
Operation::StopGradient(node) => {
dep_nodes
.entry(node)
.or_insert(Vec::new())
.push(this_node_id);
self.get_dependent_nodes(node, dep_nodes, constants, parameters)
}
Operation::Diff(_, _) => Err(CompileError::DiffNode(input_node.callsite.clone()))?,
Operation::Mul(node1, node2) | Operation::Add(node1, node2) => {
Operation::Mul(node1, node2)
| Operation::Add(node1, node2)
| Operation::Equal(node1, node2)
| Operation::LessThan(node1, node2)
| Operation::GreaterThan(node1, node2)
| Operation::LessThanEq(node1, node2)
| Operation::GreaterThanEq(node1, node2) => {
dep_nodes
.entry(node1)
.or_insert(Vec::new())
Expand All @@ -50,6 +64,48 @@ impl Context {
self.get_dependent_nodes(node1, dep_nodes, constants, parameters)?;
self.get_dependent_nodes(node2, dep_nodes, constants, parameters)
}
Operation::TypeCast(node, _) => {
dep_nodes
.entry(node)
.or_insert(Vec::new())
.push(this_node_id);
self.get_dependent_nodes(node, dep_nodes, constants, parameters)
}
Operation::SliceInDim{ node, .. } => {
dep_nodes
.entry(node)
.or_insert(Vec::new())
.push(this_node_id);
self.get_dependent_nodes(node, dep_nodes, constants, parameters)
}
Operation::Select {
pred,
on_true,
on_false,
} => {
dep_nodes
.entry(pred)
.or_insert(Vec::new())
.push(this_node_id);
dep_nodes
.entry(on_true)
.or_insert(Vec::new())
.push(this_node_id);
dep_nodes
.entry(on_false)
.or_insert(Vec::new())
.push(this_node_id);
self.get_dependent_nodes(pred, dep_nodes, constants, parameters)?;
self.get_dependent_nodes(on_true, dep_nodes, constants, parameters)?;
self.get_dependent_nodes(on_false, dep_nodes, constants, parameters)
}
Operation::ZerosLike(node) => {
dep_nodes
.entry(node)
.or_insert(Vec::new())
.push(this_node_id);
self.get_dependent_nodes(node, dep_nodes, constants, parameters)
}
}
}

Expand Down Expand Up @@ -167,6 +223,12 @@ impl Context {
}
Operation::Constant(_) => unreachable!("Constants can't depend on other nodes"),
Operation::Diff(_, _) => Err(CompileError::DiffNode(node.callsite.clone()))?,
Operation::StopGradient(node) => {
let xla_id = unda_xla_map[&node];
unda_xla_map.insert(*dependent_op, xla_id);
unda_op_queue.push_back(*dependent_op);
covered_ops.insert(*dependent_op);
}

Operation::Mul(node1, node2) => {
if xla_op_slotmap.contains_key(unda_xla_map[&node1])
Expand All @@ -193,11 +255,131 @@ impl Context {
covered_ops.insert(*dependent_op);
}
}

Operation::Equal(node1, node2) => {
if xla_op_slotmap.contains_key(unda_xla_map[&node1])
&& xla_op_slotmap.contains_key(unda_xla_map[&node1])
{
let xla_op = xla_op_slotmap[unda_xla_map[&node1]]
.eq(&xla_op_slotmap[unda_xla_map[&node2]])?;
let xla_id = xla_op_slotmap.insert(xla_op);
unda_xla_map.insert(*dependent_op, xla_id);
unda_op_queue.push_back(*dependent_op);
covered_ops.insert(*dependent_op);
}
}

Operation::LessThan(node1, node2) => {
if xla_op_slotmap.contains_key(unda_xla_map[&node1])
&& xla_op_slotmap.contains_key(unda_xla_map[&node1])
{
let xla_op = xla_op_slotmap[unda_xla_map[&node1]]
.lt(&xla_op_slotmap[unda_xla_map[&node2]])?;
let xla_id = xla_op_slotmap.insert(xla_op);
unda_xla_map.insert(*dependent_op, xla_id);
unda_op_queue.push_back(*dependent_op);
covered_ops.insert(*dependent_op);
}
}

Operation::GreaterThan(node1, node2) => {
if xla_op_slotmap.contains_key(unda_xla_map[&node1])
&& xla_op_slotmap.contains_key(unda_xla_map[&node1])
{
let xla_op = xla_op_slotmap[unda_xla_map[&node1]]
.gt(&xla_op_slotmap[unda_xla_map[&node2]])?;
let xla_id = xla_op_slotmap.insert(xla_op);
unda_xla_map.insert(*dependent_op, xla_id);
unda_op_queue.push_back(*dependent_op);
covered_ops.insert(*dependent_op);
}
}

Operation::LessThanEq(node1, node2) => {
if xla_op_slotmap.contains_key(unda_xla_map[&node1])
&& xla_op_slotmap.contains_key(unda_xla_map[&node1])
{
let xla_op = xla_op_slotmap[unda_xla_map[&node1]]
.le(&xla_op_slotmap[unda_xla_map[&node2]])?;
let xla_id = xla_op_slotmap.insert(xla_op);
unda_xla_map.insert(*dependent_op, xla_id);
unda_op_queue.push_back(*dependent_op);
covered_ops.insert(*dependent_op);
}
}

Operation::GreaterThanEq(node1, node2) => {
if xla_op_slotmap.contains_key(unda_xla_map[&node1])
&& xla_op_slotmap.contains_key(unda_xla_map[&node1])
{
let xla_op = xla_op_slotmap[unda_xla_map[&node1]]
.ge(&xla_op_slotmap[unda_xla_map[&node2]])?;
let xla_id = xla_op_slotmap.insert(xla_op);
unda_xla_map.insert(*dependent_op, xla_id);
unda_op_queue.push_back(*dependent_op);
covered_ops.insert(*dependent_op);
}
}
Operation::Select {
pred,
on_true,
on_false,
} => {
if unda_xla_map.contains_key(&pred)
&& unda_xla_map.contains_key(&on_true)
&& unda_xla_map.contains_key(&on_false)
&& xla_op_slotmap.contains_key(unda_xla_map[&pred])
&& xla_op_slotmap.contains_key(unda_xla_map[&on_true])
&& xla_op_slotmap.contains_key(unda_xla_map[&on_false])
{
let xla_op = xla_op_slotmap[unda_xla_map[&pred]].select(
&xla_op_slotmap[unda_xla_map[&on_true]],
&xla_op_slotmap[unda_xla_map[&on_false]],
)?;
let xla_id = xla_op_slotmap.insert(xla_op);
unda_xla_map.insert(*dependent_op, xla_id);
unda_op_queue.push_back(*dependent_op);
covered_ops.insert(*dependent_op);
}
}
Operation::TypeCast(node, ty) => {
if xla_op_slotmap.contains_key(unda_xla_map[&node]) {
let xla_op =
xla_op_slotmap[unda_xla_map[&node]].convert(ty.primitive_type())?;
let xla_id = xla_op_slotmap.insert(xla_op);
unda_xla_map.insert(*dependent_op, xla_id);
unda_op_queue.push_back(*dependent_op);
covered_ops.insert(*dependent_op);
}
}
Operation::SliceInDim{ node, start, stop, stride, dim } => {
if xla_op_slotmap.contains_key(unda_xla_map[&node]) {
let xla_op =
xla_op_slotmap[unda_xla_map[&node]].slice_in_dim(start, stop, stride, dim)?;
let xla_id = xla_op_slotmap.insert(xla_op);
unda_xla_map.insert(*dependent_op, xla_id);
unda_op_queue.push_back(*dependent_op);
covered_ops.insert(*dependent_op);
}
}
Operation::ZerosLike(node) => {
if xla_op_slotmap.contains_key(unda_xla_map[&node]) {
let xla_op =
xla_op_slotmap[unda_xla_map[&node]].zeros_like()?;
let xla_id = xla_op_slotmap.insert(xla_op);
unda_xla_map.insert(*dependent_op, xla_id);
unda_op_queue.push_back(*dependent_op);
covered_ops.insert(*dependent_op);
}
}
}
}
}

let xla_return_vec: Vec<&xla::XlaOp> = returns.into_iter().map(|i| &xla_op_slotmap[unda_xla_map[&i.into()]]).collect();
let xla_return_vec: Vec<&xla::XlaOp> = returns
.into_iter()
.map(|i| &xla_op_slotmap[unda_xla_map[&i.into()]])
.collect();
let xla_return_tuple = builder.tuple(&xla_return_vec.as_slice())?;

let xla_computation = xla_return_tuple.build()?;
Expand Down
Loading
Loading