Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
djmaxus committed Dec 26, 2024
1 parent 28b97c2 commit 897a75b
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 30 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ unwrap_in_result = "warn"
cast_lossless = "warn"
indexing_slicing = "warn"
trivially_copy_pass_by_ref = "warn"
let_unit_value = "warn"

[lints.rustdoc]
private_doc_tests = "warn"
Expand Down
2 changes: 1 addition & 1 deletion src/fluid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ impl<V: Value, G> Grad<V> for G where
{
}

// FIXME: replace some trait bounds with another bounded traits like `NumOps` to write less code
// FIXME: reduce trait bounds following the API best practices
// TODO: implement construction of independent variables here
// TODO: core::ops::Index(Mut) ? implement/require Iterator?
// TODO: implement `eval/map` methods (for IntoVariable output structs asl well) to sequentially evaluate functions on dual number(s)
Expand Down
48 changes: 21 additions & 27 deletions src/solid/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,18 @@ use core::{
};
use num_traits::Zero;

/// For statically-known number of variables
///```
/// use autodj::fluid::Dual;
/// use autodj::solid::array::{DualNumber,IntoVariables};
/// let x0 : DualNumber<f64,2> = 1.0.into(); // Parameter
/// let [x, y] = [2.,3.].into_variables();
/// let f = (x - x0) * y;
/// assert_eq!(f.value(), &3.);
/// assert_eq!(f.dual().as_ref().len(), 2);
/// ```
pub type DualNumber<V, const N: usize> = crate::solid::DualNumber<V, Grad<V, N>>;

/// Array of dual components
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
pub struct Grad<V: Value, const N: usize>([V; N]);
Expand All @@ -26,12 +38,9 @@ impl<V: Value, const N: usize, Arr: Into<[V; N]>> From<Arr> for Grad<V, N> {

impl<V: Value, const N: usize> AddAssign for Grad<V, N> {
fn add_assign(&mut self, rhs: Self) {
debug_assert_eq!(rhs.0.len(), self.0.len());
for (index, elem) in self.0.iter_mut().enumerate() {
// SAFETY: `Self` (and `self`) are both wrapped arrays of the same length
let value = unsafe { rhs.0.get_unchecked(index) }.to_owned();
self.0.iter_mut().zip(rhs.0).for_each(|(elem, value)| {
*elem += value;
}
});
}
}

Expand Down Expand Up @@ -88,24 +97,12 @@ where
}
}

/// For statically-known number of variables
///```
/// use autodj::fluid::Dual;
/// use autodj::solid::array::*;
/// let x0 : DualNumber<f64,2> = 1.0.into(); // Parameter
/// let [x, y] = [2.,3.].into_variables();
/// let f = (x - x0) * y;
/// assert_eq!(f.value(), &3.);
/// assert_eq!(f.dual().as_ref().len(), 2);
/// ```
pub type DualNumber<V, const N: usize> = crate::solid::DualNumber<V, Grad<V, N>>;

/// Construct independent variables from array
pub trait IntoVariables<V: Value, const N: usize>: Into<[V; N]> {
/// Construct independent variables from array
fn into_variables(self) -> [DualNumber<V, N>; N] {
let arr: [V; N] = self.into();
from_fn(|index| {
from_fn(move |index| {
let grad: [V; N] = from_fn(|grad_index| {
if grad_index == index {
V::one()
Expand All @@ -130,17 +127,14 @@ impl<V: Value, const N: usize> Display for Grad<V, N> {
impl<V: Value + LowerExp, const N: usize> LowerExp for Grad<V, N> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "+[")?;
for index in 1..=N {
// SAFETY (index - 1) in 0..=(N-1) by construction
debug_assert!(index >= 1);
debug_assert!((index - 1) <= (N - 1));
let deriv_value = unsafe { self.0.get_unchecked(index - 1) };
for (index, deriv_value) in self.0.iter().enumerate() {
write!(f, "{deriv_value:e}")?;
if index == N {
break;
if index == (N - 1) {
write!(f, "]")?;
} else {
write!(f, ", ")?;
}
write!(f, ", ")?;
}
write!(f, "]")
Ok(())
}
}
10 changes: 10 additions & 0 deletions tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,3 +176,13 @@ mod vector {
assert_eq!(f, 0.5);
}
}

mod array {
#[test]
fn zero_sized_grad() {
use autodj::fluid::Dual;
let scalar = autodj::solid::array::DualNumber::<f64, 0>::new(1.0, [].into());
let grad_len = scalar.dual().as_ref().len();
assert_eq!(grad_len, 0);
}
}

0 comments on commit 897a75b

Please sign in to comment.