diff --git a/src/collect.rs b/src/collect.rs index 0e2b913d..cd048c3b 100644 --- a/src/collect.rs +++ b/src/collect.rs @@ -549,7 +549,23 @@ impl<'a> AtomView<'a> { None } } - AtomView::Pow(_) | AtomView::Var(_) | AtomView::Fun(_) => None, + AtomView::Pow(p) => { + let (b, e) = p.get_base_exp(); + if let Ok(e) = i64::try_from(e) { + if let Some(n) = get_num(b) { + if let Coefficient::Rational(r) = n { + if e < 0 { + return Some(r.pow((-e) as u64).inv().into()); + } else { + return Some(r.pow(e as u64).into()); + } + } + } + } + + None + } + AtomView::Var(_) | AtomView::Fun(_) => None, } } @@ -609,6 +625,25 @@ impl<'a> AtomView<'a> { changed } + AtomView::Pow(p) => { + let (b, e) = p.get_base_exp(); + + let mut changed = false; + let mut nb = ws.new_atom(); + changed |= b.collect_num_impl(ws, &mut nb); + let mut ne = ws.new_atom(); + changed |= e.collect_num_impl(ws, &mut ne); + + if !changed { + out.set_from_view(self); + } else { + let mut np = ws.new_atom(); + np.to_pow(nb.as_view(), ne.as_view()); + np.as_view().normalize(ws, out); + } + + changed + } _ => { out.set_from_view(self); false diff --git a/src/domains/atom.rs b/src/domains/atom.rs index 85da7dcd..71384881 100644 --- a/src/domains/atom.rs +++ b/src/domains/atom.rs @@ -7,10 +7,33 @@ use super::{ integer::Integer, Derivable, EuclideanDomain, Field, InternalOrdering, Ring, SelfRing, }; +use dyn_clone::DynClone; use rand::Rng; -#[derive(Clone, Copy, PartialEq, Eq, Hash)] -pub struct AtomField {} +pub trait Map: Fn(AtomView, &mut Atom) -> bool + DynClone + Send + Sync {} +dyn_clone::clone_trait_object!(Map); +impl, &mut Atom) -> bool> Map for T {} + +/// The field of general expressions. +#[derive(Clone)] +pub struct AtomField { + /// Perform a cancellation check of numerators and denominators after a division. + pub cancel_check_on_division: bool, + /// A custom normalization function applied after every operation. + pub custom_normalization: Option>, +} + +impl PartialEq for AtomField { + fn eq(&self, _other: &Self) -> bool { + true + } +} + +impl Eq for AtomField {} + +impl std::hash::Hash for AtomField { + fn hash(&self, _state: &mut H) {} +} impl Default for AtomField { fn default() -> Self { @@ -20,7 +43,34 @@ impl Default for AtomField { impl AtomField { pub fn new() -> AtomField { - AtomField {} + AtomField { + custom_normalization: None, + cancel_check_on_division: false, + } + } + + #[inline(always)] + fn normalize(&self, r: Atom) -> Atom { + if let Some(f) = &self.custom_normalization { + let mut res = Atom::new(); + if f(r.as_view(), &mut res) { + res + } else { + r + } + } else { + r + } + } + + #[inline(always)] + fn normalize_mut(&self, r: &mut Atom) { + if let Some(f) = &self.custom_normalization { + let mut res = Atom::new(); + if f(r.as_view(), &mut res) { + std::mem::swap(r, &mut res); + } + } } } @@ -46,39 +96,44 @@ impl Ring for AtomField { type Element = Atom; fn add(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { - a + b + self.normalize(a + b) } fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { - a - b + self.normalize(a - b) } fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { - a * b + self.normalize(a * b) } fn add_assign(&self, a: &mut Self::Element, b: &Self::Element) { *a = &*a + b; + self.normalize_mut(a); } fn sub_assign(&self, a: &mut Self::Element, b: &Self::Element) { *a = &*a - b; + self.normalize_mut(a); } fn mul_assign(&self, a: &mut Self::Element, b: &Self::Element) { *a = self.mul(a, b); + self.normalize_mut(a); } fn add_mul_assign(&self, a: &mut Self::Element, b: &Self::Element, c: &Self::Element) { *a = &*a + self.mul(b, c); + self.normalize_mut(a); } fn sub_mul_assign(&self, a: &mut Self::Element, b: &Self::Element, c: &Self::Element) { *a = &*a - self.mul(b, c); + self.normalize_mut(a); } fn neg(&self, a: &Self::Element) -> Self::Element { - -a + self.normalize(-a) } fn zero(&self) -> Self::Element { @@ -90,11 +145,12 @@ impl Ring for AtomField { } fn pow(&self, b: &Self::Element, e: u64) -> Self::Element { - b.npow(Integer::from(e)) + self.normalize(b.npow(Integer::from(e))) } + /// Check if the result could be 0 using a statistical method. fn is_zero(a: &Self::Element) -> bool { - a.is_zero() + !a.as_view().zero_test(10, f64::EPSILON).is_false() } fn is_one(&self, a: &Self::Element) -> bool { @@ -162,7 +218,7 @@ impl EuclideanDomain for AtomField { } fn quot_rem(&self, a: &Self::Element, b: &Self::Element) -> (Self::Element, Self::Element) { - (a / b, self.zero()) + (self.div(a, b), self.zero()) } fn gcd(&self, _a: &Self::Element, _b: &Self::Element) -> Self::Element { @@ -173,16 +229,28 @@ impl EuclideanDomain for AtomField { impl Field for AtomField { fn div(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { - a / b + let r = a / b; + + self.normalize(if self.cancel_check_on_division { + r.cancel() + } else { + r + }) } fn div_assign(&self, a: &mut Self::Element, b: &Self::Element) { *a = self.div(a, b); + + if self.cancel_check_on_division { + *a = a.cancel(); + } + + self.normalize_mut(a); } fn inv(&self, a: &Self::Element) -> Self::Element { let one = Atom::new_num(1); - self.div(&one, a) + self.normalize(self.div(&one, a)) } } diff --git a/src/domains/float.rs b/src/domains/float.rs index d7fb1e67..c7cea6f2 100644 --- a/src/domains/float.rs +++ b/src/domains/float.rs @@ -1993,6 +1993,22 @@ pub struct ErrorPropagatingFloat { abs_err: f64, } +impl From for ErrorPropagatingFloat { + fn from(value: f64) -> Self { + if value == 0. { + ErrorPropagatingFloat { + value, + abs_err: f64::EPSILON, + } + } else { + ErrorPropagatingFloat { + value, + abs_err: f64::EPSILON * value.abs(), + } + } + } +} + impl Neg for ErrorPropagatingFloat { type Output = Self; diff --git a/src/evaluate.rs b/src/evaluate.rs index ba176cc2..48ad50dd 100644 --- a/src/evaluate.rs +++ b/src/evaluate.rs @@ -17,10 +17,13 @@ use crate::{ coefficient::CoefficientView, combinatorics::unique_permutations, domains::{ - float::{Complex, NumericalFloatLike, Real}, + float::{ + Complex, ErrorPropagatingFloat, NumericalFloatLike, Real, RealNumberLike, SingleFloat, + }, integer::Integer, rational::Rational, }, + id::ConditionResult, state::State, LicenseManager, }; @@ -237,6 +240,12 @@ impl Atom { optimization_settings.verbose, )) } + + /// Check if the expression could be 0, using (potentially) numerical sampling with + /// a given tolerance and number of iterations. + pub fn zero_test(&self, iterations: usize, tolerance: f64) -> ConditionResult { + self.as_view().zero_test(iterations, tolerance) + } } #[derive(Debug, Clone)] @@ -4155,6 +4164,164 @@ impl<'a> AtomView<'a> { } } } + + /// Check if the expression could be 0, using (potentially) numerical sampling with + /// a given tolerance and number of iterations. + pub fn zero_test(&self, iterations: usize, tolerance: f64) -> ConditionResult { + match self { + AtomView::Num(num_view) => { + if num_view.is_zero() { + ConditionResult::True + } else { + ConditionResult::False + } + } + AtomView::Var(_) => ConditionResult::False, + AtomView::Fun(_) => ConditionResult::False, + AtomView::Pow(p) => p.get_base().zero_test(iterations, tolerance), + AtomView::Mul(mul_view) => { + let mut is_zero = ConditionResult::False; + for arg in mul_view { + match arg.zero_test(iterations, tolerance) { + ConditionResult::True => return ConditionResult::True, + ConditionResult::False => {} + ConditionResult::Inconclusive => { + is_zero = ConditionResult::Inconclusive; + } + } + } + + is_zero + } + AtomView::Add(_) => { + // an expanded polynomial is only zero if it is a literal zero + if self.is_polynomial(false, true).is_some() { + ConditionResult::False + } else { + self.zero_test_impl(iterations, tolerance) + } + } + } + } + + fn zero_test_impl(&self, iterations: usize, tolerance: f64) -> ConditionResult { + // collect all variables and functions and fill in random variables + + let mut rng = rand::thread_rng(); + + if self.contains_symbol(State::I) { + let mut vars: HashMap<_, _> = self + .get_all_indeterminates(true) + .into_iter() + .filter_map(|x| { + let s = x.get_symbol().unwrap(); + if !State::is_builtin(s) || s == Atom::DERIVATIVE { + Some((x, Complex::new(0f64.into(), 0f64.into()))) + } else { + None + } + }) + .collect(); + + let mut cache = HashMap::default(); + + for _ in 0..iterations { + cache.clear(); + + for x in vars.values_mut() { + *x = x.sample_unit(&mut rng); + } + + let r = self + .evaluate( + |x| { + Complex::new( + ErrorPropagatingFloat::new( + 0f64.from_rational(x), + -0f64.get_epsilon().log10(), + ), + ErrorPropagatingFloat::new( + 0f64.zero(), + -0f64.get_epsilon().log10(), + ), + ) + }, + &vars, + &HashMap::default(), + &mut cache, + ) + .unwrap(); + + let res_re = r.re.get_num().to_f64(); + let res_im = r.im.get_num().to_f64(); + if res_re.is_finite() + && (res_re - r.re.get_absolute_error() > 0. + || res_re + r.re.get_absolute_error() < 0.) + || res_im.is_finite() + && (res_im - r.im.get_absolute_error() > 0. + || res_im + r.im.get_absolute_error() < 0.) + { + return ConditionResult::False; + } + + if vars.len() == 0 && r.re.get_absolute_error() < tolerance { + return ConditionResult::True; + } + } + + ConditionResult::Inconclusive + } else { + let mut vars: HashMap<_, ErrorPropagatingFloat> = self + .get_all_indeterminates(true) + .into_iter() + .filter_map(|x| { + let s = x.get_symbol().unwrap(); + if !State::is_builtin(s) || s == Atom::DERIVATIVE { + Some((x, 0f64.into())) + } else { + None + } + }) + .collect(); + + let mut cache = HashMap::default(); + + for _ in 0..iterations { + cache.clear(); + + for x in vars.values_mut() { + *x = x.sample_unit(&mut rng); + } + + let r = self + .evaluate( + |x| { + ErrorPropagatingFloat::new( + 0f64.from_rational(x), + -0f64.get_epsilon().log10(), + ) + }, + &vars, + &HashMap::default(), + &mut cache, + ) + .unwrap(); + + let res = r.get_num().to_f64(); + if res.is_finite() + && (res - r.get_absolute_error() > 0. || res + r.get_absolute_error() < 0.) + { + return ConditionResult::False; + } + + if vars.len() == 0 && r.get_absolute_error() < tolerance { + return ConditionResult::True; + } + } + + ConditionResult::Inconclusive + } + } } #[cfg(test)] @@ -4165,6 +4332,7 @@ mod test { atom::Atom, domains::{float::Float, rational::Rational}, evaluate::{EvaluationFn, FunctionMap, OptimizationSettings}, + id::ConditionResult, state::State, }; @@ -4304,4 +4472,13 @@ mod test { let r = e_f64.evaluate_single(&[1.1]); assert!((r - 1622709.2254269677).abs() / 1622709.2254269677 < 1e-10); } + + #[test] + fn zero_test() { + let e = Atom::parse("(sin(v1)^2-sin(v1))(sin(v1)^2+sin(v1))^2 - (1/4 sin(2v1)^2-1/2 sin(2v1)cos(v1)-2 cos(v1)^2+1/2 sin(2v1)cos(v1)^3+3 cos(v1)^4-cos(v1)^6)").unwrap(); + assert_eq!(e.zero_test(10, f64::EPSILON), ConditionResult::Inconclusive); + + let e = Atom::parse("x + (1+x)^2 + (x+2)*5").unwrap(); + assert_eq!(e.zero_test(10, f64::EPSILON), ConditionResult::False); + } } diff --git a/src/id.rs b/src/id.rs index b00795f0..fd92a636 100644 --- a/src/id.rs +++ b/src/id.rs @@ -122,6 +122,21 @@ impl Atom { self.as_view().contains(s.as_atom_view()) } + /// Check if the expression can be considered a polynomial in some variables, including + /// redefinitions. For example `f(x)+y` is considered a polynomial in `f(x)` and `y`, whereas + /// `f(x)+x` is not a polynomial. + /// + /// Rational powers or powers in variables are not rewritten, e.g. `x^(2y)` is not considered + /// polynomial in `x^y`. + pub fn is_polynomial( + &self, + allow_not_expanded: bool, + allow_negative_powers: bool, + ) -> Option>> { + self.as_view() + .is_polynomial(allow_not_expanded, allow_negative_powers) + } + /// Replace all occurrences of the pattern. pub fn replace_all( &self, @@ -344,6 +359,150 @@ impl<'a> AtomView<'a> { false } + /// Check if the expression can be considered a polynomial in some variables, including + /// redefinitions. For example `f(x)+y` is considered a polynomial in `f(x)` and `y`, whereas + /// `f(x)+x` is not a polynomial. + /// + /// Rational powers or powers in variables are not rewritten, e.g. `x^(2y)` is not considered + /// polynomial in `x^y`. + pub fn is_polynomial( + &self, + allow_not_expanded: bool, + allow_negative_powers: bool, + ) -> Option>> { + let mut vars = HashMap::default(); + let mut symbol_cache = HashSet::default(); + if self.is_polynomial_impl( + allow_not_expanded, + allow_negative_powers, + &mut vars, + &mut symbol_cache, + ) { + symbol_cache.clear(); + for (k, v) in vars { + if v { + symbol_cache.insert(k); + } + } + + Some(symbol_cache) + } else { + None + } + } + + fn is_polynomial_impl( + &self, + allow_not_expanded: bool, + allow_negative_powers: bool, + variables: &mut HashMap, bool>, + symbol_cache: &mut HashSet>, + ) -> bool { + if let Some(x) = variables.get(self) { + return *x; + } + + macro_rules! block_check { + ($e: expr) => { + symbol_cache.clear(); + $e.get_all_indeterminates_impl(true, symbol_cache); + for x in symbol_cache.drain() { + if variables.contains_key(&x) { + return false; + } else { + variables.insert(x, false); // disallow at any level + } + } + + variables.insert(*$e, true); // overwrites block above + }; + } + + match self { + AtomView::Num(_) => true, + AtomView::Var(_) => { + variables.insert(*self, true); + true + } + AtomView::Fun(_) => { + block_check!(self); + true + } + AtomView::Pow(pow_view) => { + // x^y is allowed if x and y do not appear elsewhere + let (base, exp) = pow_view.get_base_exp(); + + if let AtomView::Num(_) = exp { + let (positive, integer) = if let Ok(k) = i64::try_from(exp) { + (k >= 0, true) + } else { + (false, false) + }; + + if integer && (allow_negative_powers || positive) { + if variables.get(&base) == Some(&true) { + return true; + } + + if allow_not_expanded && positive { + // do not consider (x+y)^-2 a polynomial in x and y + return base.is_polynomial_impl( + allow_not_expanded, + allow_negative_powers, + variables, + symbol_cache, + ); + } + + // turn the base into a variable + block_check!(&base); + return true; + } + } + + block_check!(self); + true + } + AtomView::Mul(mul_view) => { + for child in mul_view { + if !allow_not_expanded { + if let AtomView::Add(_) = child { + if variables.get(&child) == Some(&true) { + continue; + } + + block_check!(&child); + continue; + } + } + + if !child.is_polynomial_impl( + allow_not_expanded, + allow_negative_powers, + variables, + symbol_cache, + ) { + return false; + } + } + true + } + AtomView::Add(add_view) => { + for child in add_view { + if !child.is_polynomial_impl( + allow_not_expanded, + allow_negative_powers, + variables, + symbol_cache, + ) { + return false; + } + } + true + } + } + } + /// Replace part of an expression by calling the map `m` on each subexpression. /// The function `m` must return `true` if the expression was replaced and must write the new expression to `out`. /// A [Context] object is passed to the function, which contains information about the current position in the expression. @@ -3650,4 +3809,15 @@ mod test { let expr = p.replace_all(expr.as_view(), &rhs, None, None); assert_eq!(expr, Atom::new_num(1)); } + + #[test] + fn is_polynomial() { + let e = Atom::parse("v1^2 + (1+v5)^3 / v1 + (1+v3)*(1+v4)^v7 + v1^2 + (v1+v2)^3").unwrap(); + let vars = e.as_view().is_polynomial(true, true).unwrap(); + assert_eq!(vars.len(), 5); + + let e = Atom::parse("(1+v5)^(3/2) / v6 + (1+v3)*(1+v4)^v7 + (v1+v2)^3").unwrap(); + let vars = e.as_view().is_polynomial(false, false).unwrap(); + assert_eq!(vars.len(), 5); + } }