Skip to content

Commit

Permalink
Add replace_map for replacements using a function
Browse files Browse the repository at this point in the history
- Do not set a static trace level
- Allow transformers on rhs of contains
- Add is_type condition on transformers
  • Loading branch information
benruijl committed Dec 4, 2024
1 parent 88a9cbd commit fe31c83
Show file tree
Hide file tree
Showing 5 changed files with 193 additions and 7 deletions.
5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ crate-type = ["lib"]
name = "symbolica"

[features]
default = []
default = ["tracing_only_warnings"]
# if using this, make sure jemalloc is compiled with --disable-initial-exec-tls
# if symbolica is used as a dynamic library (as is the case for the Python API)
faster_alloc = ["tikv-jemallocator"]
Expand All @@ -37,6 +37,7 @@ python_api = ["pyo3", "bincode"]
python_no_module = ["python_api"]
# build a module that is independent of the specific Python version
python_abi3 = ["pyo3/abi3", "pyo3/abi3-py37"]
tracing_only_warnings = ["tracing/release_max_level_warn"]

[dependencies.pyo3]
features = ["extension-module", "abi3", "py-clone"]
Expand Down Expand Up @@ -67,6 +68,6 @@ smallvec = "1.13"
smartstring = "1.0"
tikv-jemallocator = {version = "0.5.4", optional = true}
tinyjson = "2.5"
tracing = {version = "0.1", features = ["max_level_trace", "release_max_level_warn"]}
tracing = "0.1"
wide = "0.7"
wolfram-library-link = {version = "0.2.9", optional = true}
25 changes: 21 additions & 4 deletions src/api/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,23 @@ impl PythonTransformer {
})
}

/// Test if the expression is of a certain type.
pub fn is_type(&self, atom_type: PythonAtomType) -> PythonCondition {
PythonCondition {
condition: Condition::Yield(Relation::IsType(
self.expr.clone(),
match atom_type {
PythonAtomType::Num => AtomType::Num,
PythonAtomType::Var => AtomType::Var,
PythonAtomType::Add => AtomType::Add,
PythonAtomType::Mul => AtomType::Mul,
PythonAtomType::Pow => AtomType::Pow,
PythonAtomType::Fn => AtomType::Fun,
},
)),
}
}

/// Returns true iff `self` contains `a` literally.
///
/// Examples
Expand Down Expand Up @@ -3181,13 +3198,13 @@ impl PythonExpression {
/// >>> e.contains(x) # True
/// >>> e.contains(x*y*z) # True
/// >>> e.contains(x*y) # False
pub fn contains(&self, s: ConvertibleToExpression) -> PythonCondition {
PythonCondition {
pub fn contains(&self, s: ConvertibleToPattern) -> PyResult<PythonCondition> {
Ok(PythonCondition {
condition: Condition::Yield(Relation::Contains(
self.expr.into_pattern(),
s.to_expression().expr.into_pattern(),
s.to_pattern()?.expr,
)),
}
})
}

/// Get all symbols in the current expression, optionally including function symbols.
Expand Down
1 change: 1 addition & 0 deletions src/atom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,7 @@ impl<'a> AtomView<'a> {
}
}

/// A mathematical expression.
#[derive(Clone)]
pub enum Atom {
Num(Num),
Expand Down
162 changes: 162 additions & 0 deletions src/id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,24 @@ impl Atom {
) -> bool {
self.as_view().replace_all_multiple_into(replacements, out)
}

/// Replace part of an expression by calling the map `m` on each subexpression.
/// The function `m` must return `true` if the expression was replaced and must write the new expression to `out`.
/// A [Context] object is passed to the function, which contains information about the current position in the expression.
pub fn replace_map<F: Fn(AtomView, &Context, &mut Atom) -> bool>(&self, m: &F) -> Atom {
self.as_view().replace_map(m)
}
}

/// The context of an atom.
#[derive(Clone, Copy, Debug)]
pub struct Context {
/// The level of the function in the expression tree.
pub function_level: usize,
/// The type of the parent atom.
pub parent_type: Option<AtomType>,
/// The index of the atom in the parent.
pub index: usize,
}

impl<'a> AtomView<'a> {
Expand Down Expand Up @@ -326,6 +344,134 @@ impl<'a> AtomView<'a> {
false
}

/// Replace part of an expression by calling the map `m` on each subexpression.
/// The function `m` must return `true` if the expression was replaced and must write the new expression to `out`.
/// A [Context] object is passed to the function, which contains information about the current position in the expression.
pub fn replace_map<F: Fn(AtomView, &Context, &mut Atom) -> bool>(&self, m: &F) -> Atom {
let mut out = Atom::new();
self.replace_map_into(m, &mut out);
out
}

/// Replace part of an expression by calling the map `m` on each subexpression.
/// The function `m` must return `true` if the expression was replaced and must write the new expression to `out`.
/// A [Context] object is passed to the function, which contains information about the current position in the expression.
pub fn replace_map_into<F: Fn(AtomView, &Context, &mut Atom) -> bool>(
&self,
m: &F,
out: &mut Atom,
) {
let context = Context {
function_level: 0,
parent_type: None,
index: 0,
};
Workspace::get_local().with(|ws| {
self.replace_map_impl(ws, m, context, out);
});
}

fn replace_map_impl<F: Fn(AtomView, &Context, &mut Atom) -> bool>(
&self,
ws: &Workspace,
m: &F,
mut context: Context,
out: &mut Atom,
) -> bool {
if m(*self, &context, out) {
return true;
}

let mut changed = false;
match self {
AtomView::Num(_) | AtomView::Var(_) => {
out.set_from_view(self);
}
AtomView::Fun(f) => {
let mut fun = ws.new_atom();
let fun = fun.to_fun(f.get_symbol());

context.parent_type = Some(AtomType::Fun);
context.function_level += 1;

for (i, arg) in f.iter().enumerate() {
context.index = i;

let mut arg_h = ws.new_atom();
changed |= arg.replace_map_impl(ws, m, context, &mut arg_h);
fun.add_arg(arg_h.as_view());
}

if changed {
fun.as_view().normalize(ws, out);
} else {
out.set_from_view(self);
}
}
AtomView::Pow(p) => {
let (base, exp) = p.get_base_exp();

context.parent_type = Some(AtomType::Pow);
context.index = 0;

let mut base_h = ws.new_atom();
changed |= base.replace_map_impl(ws, m, context, &mut base_h);

context.index = 1;
let mut exp_h = ws.new_atom();
changed |= exp.replace_map_impl(ws, m, context, &mut exp_h);

if changed {
let mut pow_h = ws.new_atom();
pow_h.to_pow(base_h.as_view(), exp_h.as_view());
pow_h.as_view().normalize(ws, out);
} else {
out.set_from_view(self);
}
}
AtomView::Mul(mm) => {
let mut mul_h = ws.new_atom();
let mul = mul_h.to_mul();

context.parent_type = Some(AtomType::Mul);

for (i, child) in mm.iter().enumerate() {
context.index = i;
let mut child_h = ws.new_atom();
changed |= child.replace_map_impl(ws, m, context, &mut child_h);
mul.extend(child_h.as_view());
}

if changed {
mul_h.as_view().normalize(ws, out);
} else {
out.set_from_view(self);
}
}
AtomView::Add(a) => {
let mut add_h = ws.new_atom();
let add = add_h.to_add();

context.parent_type = Some(AtomType::Add);

for (i, child) in a.iter().enumerate() {
context.index = i;
let mut child_h = ws.new_atom();
changed |= child.replace_map_impl(ws, m, context, &mut child_h);
add.extend(child_h.as_view());
}

if changed {
add_h.as_view().normalize(ws, out);
} else {
out.set_from_view(self);
}
}
}

changed
}

/// Replace all occurrences of the patterns, where replacements are tested in the order that they are given.
pub fn replace_all(
&self,
Expand Down Expand Up @@ -3313,6 +3459,22 @@ mod test {

use super::Pattern;

#[test]
fn replace_map() {
let a = Atom::parse("v1 + f1(1,2, f1((1+v1)^2), (v1+v2)^2)").unwrap();

let r = a.replace_map(&|arg, context, out| {
if context.function_level > 0 {
arg.expand_into(None, out)
} else {
false
}
});

let res = Atom::parse("v1+f1(1,2,f1(2*v1+v1^2+1),v1^2+v2^2+2*v1*v2)").unwrap();
assert_eq!(r, res);
}

#[test]
fn overlap() {
let a = Atom::parse("(v1*(v2+v2^2+1)+v2^2 + v2)").unwrap();
Expand Down
7 changes: 6 additions & 1 deletion symbolica.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ class Expression:
transformations can be applied.
"""

def contains(self, a: Expression | int | float | Decimal) -> Condition:
def contains(self, a: Transformer | Expression | int | float | Decimal) -> Condition:
"""Returns true iff `self` contains `a` literally.
Examples
Expand Down Expand Up @@ -1606,6 +1606,11 @@ class Transformer:
Compare two transformers. If any of the two expressions is not a rational number, an interal ordering is used.
"""

def is_type(self, atom_type: AtomType) -> Condition:
"""
Test if the transformed expression is of a certain type.
"""

def contains(self, element: Transformer | Expression | int | float | Decimal) -> Condition:
"""
Create a transformer that checks if the expression contains the given `element`.
Expand Down

0 comments on commit fe31c83

Please sign in to comment.