Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
kpbaks committed Feb 6, 2024
2 parents f15c016 + 0a53329 commit f48de1e
Showing 7 changed files with 235 additions and 2 deletions.
38 changes: 38 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion gbp-rs/Cargo.toml
Original file line number Diff line number Diff line change
@@ -17,7 +17,7 @@ rstest = "0.18.2"
rand = "0.8.5"
rand_distr = "0.4.3"
typed-builder = "0.18.1"

num = "0.4.1"

[dev-dependencies]
charming = "0.3.1"
3 changes: 3 additions & 0 deletions gbp-rs/src/factorgraph/factor.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use crate::gaussian::Gaussian;

use super::message::Message;

#[derive(Debug)]
@@ -15,6 +17,7 @@ pub trait Factor {
fn robustify_loss(&self);
fn measurement_model(&self) -> MeasurementModel;
fn linerisation_point(&self) -> nalgebra::DVector<f64>;
fn get_gaussian(&self) -> &Gaussian;
}

// [1,,3].iter().
136 changes: 136 additions & 0 deletions gbp-rs/src/factorgraph/factorgraph.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use std::ops::AddAssign;

use crate::factorgraph::factor::Factor;
use crate::factorgraph::variable::Variable;

use super::factor::MeasurementModel;
use super::{Dropout, UnitInterval};
use crate::gaussian::Gaussian;

use typed_builder::TypedBuilder;

@@ -236,4 +239,137 @@ impl<F: Factor, V: Variable> FactorGraph<F, V> {
.map(|variable_node| variable_node.dofs)
.sum()
}

/// Get the joint distribution over all variables in the information form.
/// If non-linear factors exist, it is taken at the linearisation point.
fn joint_distribution(&self) -> Gaussian {
let dim = self.get_joint_dim();
let mut joint = Gaussian::new(dim, None, None);

// Priors
let mut var_ix = vec![0; self.variables.len()];
let mut counter = 0;

for variable_node in self.variables.iter() {
let variable = &variable_node.variable;

var_ix[variable_node.id] = counter;

joint
.information_vector
.rows_mut(counter, variable_node.dofs)
.add_assign(
&variable
.get_prior()
.information_vector
.rows(counter, variable_node.dofs),
);
joint
.precision_matrix
.view_mut(
(counter, counter + variable_node.dofs),
(counter, counter + variable_node.dofs),
)
.add_assign(
&variable
.get_prior()
.precision_matrix
.view((counter, variable_node.dofs), (counter, variable_node.dofs)),
);
counter += variable_node.dofs;
}

// Other factors
for factor_node in self.factors.iter() {
let mut fact_ix = 0;
for &adjacent_variable_node_id in factor_node.adjacent_variables.iter() {
let adjacent_variable_node = &self.variables[adjacent_variable_node_id];

// Diagonal contribution of factor
joint
.information_vector
.rows_mut(
var_ix[adjacent_variable_node_id],
adjacent_variable_node.dofs,
)
.add_assign(
factor_node
.factor
.get_gaussian()
.information_vector
.rows(fact_ix, adjacent_variable_node.dofs),
);

// joint.lam[var_ix[vID]:var_ix[vID] + adj_var_node.dofs, var_ix[vID]:var_ix[vID] + adj_var_node.dofs] += \
// factor.factor.lam[factor_ix:factor_ix + adj_var_node.dofs, factor_ix:factor_ix + adj_var_node.dofs]
joint
.precision_matrix
.view_mut(
(
var_ix[adjacent_variable_node_id],
adjacent_variable_node.dofs,
),
(
var_ix[adjacent_variable_node_id],
adjacent_variable_node.dofs,
),
)
.add_assign(factor_node.factor.get_gaussian().precision_matrix.view(
(fact_ix, adjacent_variable_node.dofs),
(fact_ix, adjacent_variable_node.dofs),
));

let mut other_fact_ix = 0;
for &other_adjacent_variable_node_id in factor_node.adjacent_variables.iter() {
if &other_adjacent_variable_node_id > &adjacent_variable_node_id {
let other_adjacent_variable_node =
&self.variables[other_adjacent_variable_node_id];

// off diagonal contributions of factor
// joint.lam[var_ix[vID]:var_ix[vID] + adj_var_node.dofs, var_ix[other_vID]:var_ix[other_vID] + other_adj_var_node.dofs] += \
// factor.factor.lam[factor_ix:factor_ix + adj_var_node.dofs, other_factor_ix:other_factor_ix + other_adj_var_node.dofs]
joint
.precision_matrix
.view_mut(
(
var_ix[adjacent_variable_node_id],
adjacent_variable_node.dofs,
),
(
var_ix[adjacent_variable_node_id],
adjacent_variable_node.dofs,
),
)
.add_assign(factor_node.factor.get_gaussian().precision_matrix.view(
(fact_ix, adjacent_variable_node.dofs),
(other_fact_ix, adjacent_variable_node.dofs),
));
// joint.lam[var_ix[other_vID]:var_ix[other_vID] + other_adj_var_node.dofs, var_ix[vID]:var_ix[vID] + adj_var_node.dofs] += \
// factor.factor.lam[other_factor_ix:other_factor_ix + other_adj_var_node.dofs, factor_ix:factor_ix + adj_var_node.dofs]
joint
.precision_matrix
.view_mut(
(
var_ix[other_adjacent_variable_node_id],
other_adjacent_variable_node.dofs,
),
(
var_ix[adjacent_variable_node_id],
adjacent_variable_node.dofs,
),
)
.add_assign(factor_node.factor.get_gaussian().precision_matrix.view(
(other_fact_ix, other_adjacent_variable_node.dofs),
(fact_ix, adjacent_variable_node.dofs),
));
other_fact_ix += other_adjacent_variable_node.dofs;
}

fact_ix += adjacent_variable_node.dofs;
}
}
}

joint
}
}
5 changes: 4 additions & 1 deletion gbp-rs/src/factorgraph/variable.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
///
// pub trait Variable {}

// struct RobotId(usize);
@@ -17,8 +16,12 @@
// Self { node_id, robot_id }
// }
// }
use crate::gaussian::Gaussian;

pub trait Variable {
fn update_belief(&mut self);
fn prior_energy(&self) -> f64;

fn get_belief(&self) -> &Gaussian;
fn get_prior(&self) -> &Gaussian;
}
51 changes: 51 additions & 0 deletions gbp-rs/src/gaussian.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
use nalgebra::{DMatrix, DVector};
// use num::Float;

// TODO: maybe there should be something to ensure the dimensions of the information vector and precision matrix match
#[derive(Debug)]
pub struct Gaussian {
size: usize, // dim
pub information_vector: DVector<f64>, // eta
pub precision_matrix: DMatrix<f64>, // lam
}

impl Gaussian {
/// information_vector commonly used symbol: lowercase eta (η)
/// precision_matrix commmonly used symbol: uppercase lambda (Λ)
pub fn new(
size: usize,
information_vector: Option<DVector<f64>>,
precision_matrix: Option<DMatrix<f64>>,
) -> Self {
let information_vector = information_vector.unwrap_or_else(|| DVector::zeros(size));
// check if the precision_matrix has an inverse
let precision_matrix = precision_matrix.unwrap_or_else(|| DMatrix::zeros(size, size));

Gaussian {
size,
information_vector,
precision_matrix,
}
}

pub fn mean(&self) -> DVector<f64> {
self.precision_matrix.try_inverse().unwrap() * &self.information_vector
}

pub fn covariance(&self) -> DMatrix<f64> {
self.precision_matrix.try_inverse().unwrap()
}

pub fn mean_and_coveriance(&self) -> (DVector<f64>, DMatrix<f64>) {
let covariance = self.covariance();
let mean = &covariance * &self.information_vector;

(mean, covariance)
}

pub fn set_with_covariance_form(&mut self, mean: DVector<f64>, covariance: DMatrix<f64>) {
// check for invertibility of covariance matrix input
self.precision_matrix = covariance.try_inverse().unwrap();
self.information_vector = &self.precision_matrix * mean;
}
}
2 changes: 2 additions & 0 deletions gbp-rs/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
pub mod factorgraph;
pub mod gaussian;

pub mod prelude {
pub use crate::factorgraph::factor::Factor;
pub use crate::factorgraph::factorgraph::FactorGraph;
pub use crate::factorgraph::message::Message;
pub use crate::factorgraph::variable::Variable;
pub use crate::gaussian;
}

0 comments on commit f48de1e

Please sign in to comment.