Skip to content

Commit

Permalink
Add cyclesymmetric symbols
Browse files Browse the repository at this point in the history
  • Loading branch information
benruijl committed Aug 25, 2024
1 parent af2fa52 commit 3323324
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 25 deletions.
34 changes: 28 additions & 6 deletions src/api/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1485,7 +1485,8 @@ macro_rules! req_wc_cmp {
impl PythonExpression {
/// Create a new symbol from a `name`. Symbols carry information about their attributes.
/// The symbol can signal that it is symmetric if it is used as a function
/// using `is_symmetric=True`, antisymmetric using `is_antisymmetric=True`, and
/// using `is_symmetric=True`, antisymmetric using `is_antisymmetric=True`,
/// cyclesymmetric using `is_cyclesymmetric=True` and
/// multilinear using `is_linear=True`. If no attributes
/// are specified, the attributes are inherited from the symbol if it was already defined,
/// otherwise all attributes are set to `false`.
Expand Down Expand Up @@ -1523,15 +1524,24 @@ impl PythonExpression {
name: &str,
is_symmetric: Option<bool>,
is_antisymmetric: Option<bool>,
is_cyclesymmetric: Option<bool>,
is_linear: Option<bool>,
) -> PyResult<Self> {
if is_symmetric.is_none() && is_antisymmetric.is_none() && is_linear.is_none() {
if is_symmetric.is_none()
&& is_antisymmetric.is_none()
&& is_cyclesymmetric.is_none()
&& is_linear.is_none()
{
return Ok(Atom::new_var(State::get_symbol(name)).into());
}

if is_symmetric == Some(true) && is_antisymmetric == Some(true) {
let count = (is_symmetric == Some(true)) as u8
+ (is_antisymmetric == Some(true)) as u8
+ (is_cyclesymmetric == Some(true)) as u8;

if count > 1 {
Err(exceptions::PyValueError::new_err(
"Function cannot be both symmetric and antisymmetric",
"Function cannot be both symmetric, antisymmetric or cyclesymmetric",
))?;
}

Expand All @@ -1545,6 +1555,10 @@ impl PythonExpression {
opts.push(FunctionAttribute::Antisymmetric);
}

if let Some(true) = is_cyclesymmetric {
opts.push(FunctionAttribute::CycleSymmetric);
}

if let Some(true) = is_linear {
opts.push(FunctionAttribute::Linear);
}
Expand All @@ -1563,20 +1577,28 @@ impl PythonExpression {
/// >>> e = f(1,x)
/// >>> print(e)
/// f(1,x)
#[pyo3(signature = (*args,is_symmetric=None,is_antisymmetric=None,is_linear=None))]
#[pyo3(signature = (*args,is_symmetric=None,is_antisymmetric=None,is_cyclesymmetric=None,is_linear=None))]
#[classmethod]
pub fn symbols(
cls: &PyType,
args: &PyTuple,
is_symmetric: Option<bool>,
is_antisymmetric: Option<bool>,
is_cyclesymmetric: Option<bool>,
is_linear: Option<bool>,
) -> PyResult<Vec<PythonExpression>> {
let mut result = Vec::with_capacity(args.len());

for a in args {
let name = a.extract::<&str>()?;
let s = Self::symbol(cls, name, is_symmetric, is_antisymmetric, is_linear)?;
let s = Self::symbol(
cls,
name,
is_symmetric,
is_antisymmetric,
is_cyclesymmetric,
is_linear,
)?;
result.push(s);
}

Expand Down
8 changes: 8 additions & 0 deletions src/atom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub struct Symbol {
wildcard_level: u8,
is_symmetric: bool,
is_antisymmetric: bool,
is_cyclesymmetric: bool,
is_linear: bool,
}

Expand All @@ -48,6 +49,7 @@ impl Symbol {
wildcard_level,
is_symmetric: false,
is_antisymmetric: false,
is_cyclesymmetric: false,
is_linear: false,
}
}
Expand All @@ -59,13 +61,15 @@ impl Symbol {
wildcard_level: u8,
is_symmetric: bool,
is_antisymmetric: bool,
is_cyclesymmetric: bool,
is_linear: bool,
) -> Self {
Symbol {
id,
wildcard_level,
is_symmetric,
is_antisymmetric,
is_cyclesymmetric,
is_linear,
}
}
Expand All @@ -86,6 +90,10 @@ impl Symbol {
self.is_antisymmetric
}

pub fn is_cyclesymmetric(&self) -> bool {
self.is_cyclesymmetric
}

pub fn is_linear(&self) -> bool {
self.is_linear
}
Expand Down
36 changes: 28 additions & 8 deletions src/atom/representation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ const VAR_WILDCARD_LEVEL_3: u8 = 0b00011000;
const FUN_SYMMETRIC_FLAG: u8 = 0b00100000;
const FUN_LINEAR_FLAG: u8 = 0b01000000;
const VAR_ANTISYMMETRIC_FLAG: u8 = 0b10000000;
const VAR_CYCLESYMMETRIC_FLAG: u8 = 0b10100000; // coded as symmetric | antisymmetric
const FUN_ANTISYMMETRIC_FLAG: u64 = 1 << 32; // stored in the function id
const MUL_HAS_COEFF_FLAG: u8 = 0b01000000;

Expand Down Expand Up @@ -64,6 +65,9 @@ impl InlineVar {
if symbol.is_antisymmetric {
flags |= VAR_ANTISYMMETRIC_FLAG;
}
if symbol.is_cyclesymmetric {
flags |= VAR_CYCLESYMMETRIC_FLAG;
}

data[0] = flags;

Expand Down Expand Up @@ -321,6 +325,9 @@ impl Var {
if symbol.is_antisymmetric {
flags |= VAR_ANTISYMMETRIC_FLAG;
}
if symbol.is_cyclesymmetric {
flags |= VAR_CYCLESYMMETRIC_FLAG;
}

self.data.put_u8(flags);

Expand Down Expand Up @@ -391,7 +398,7 @@ impl Fun {
_ => flags |= VAR_WILDCARD_LEVEL_3,
}

if symbol.is_symmetric {
if symbol.is_symmetric || symbol.is_cyclesymmetric {
flags |= FUN_SYMMETRIC_FLAG;
}
if symbol.is_linear {
Expand All @@ -404,7 +411,7 @@ impl Fun {

let buf_pos = self.data.len();

let id = if symbol.is_antisymmetric {
let id = if symbol.is_antisymmetric || symbol.is_cyclesymmetric {
symbol.id as u64 | FUN_ANTISYMMETRIC_FLAG
} else {
symbol.id as u64
Expand Down Expand Up @@ -902,11 +909,14 @@ impl<'a> VarView<'a> {

#[inline(always)]
pub fn get_symbol(&self) -> Symbol {
let is_cyclesymmetric = self.data[0] & VAR_CYCLESYMMETRIC_FLAG != 0;

Symbol::init_fn(
self.data[1..].get_frac_u64().0 as u32,
self.get_wildcard_level(),
self.data[0] & FUN_SYMMETRIC_FLAG != 0,
self.data[0] & VAR_ANTISYMMETRIC_FLAG != 0,
!is_cyclesymmetric && self.data[0] & FUN_SYMMETRIC_FLAG != 0,
!is_cyclesymmetric && self.data[0] & VAR_ANTISYMMETRIC_FLAG != 0,
is_cyclesymmetric,
self.data[0] & FUN_LINEAR_FLAG != 0,
)
}
Expand Down Expand Up @@ -993,24 +1003,34 @@ impl<'a> FunView<'a> {
pub fn get_symbol(&self) -> Symbol {
let id = self.data[1 + 4..].get_frac_u64().0;

let is_cyclesymmetric =
self.data[0] & FUN_SYMMETRIC_FLAG != 0 && id & FUN_ANTISYMMETRIC_FLAG != 0;

Symbol::init_fn(
id as u32,
self.get_wildcard_level(),
self.is_symmetric(),
id & FUN_ANTISYMMETRIC_FLAG != 0,
!is_cyclesymmetric && self.data[0] & FUN_SYMMETRIC_FLAG != 0,
!is_cyclesymmetric && id & FUN_ANTISYMMETRIC_FLAG != 0,
is_cyclesymmetric,
self.is_linear(),
)
}

#[inline(always)]
pub fn is_symmetric(&self) -> bool {
self.data[0] & FUN_SYMMETRIC_FLAG != 0
let id = self.data[1 + 4..].get_frac_u64().0;
self.data[0] & FUN_SYMMETRIC_FLAG != 0 && id & FUN_ANTISYMMETRIC_FLAG == 0
}

#[inline(always)]
pub fn is_antisymmetric(&self) -> bool {
let id = self.data[1 + 4..].get_frac_u64().0;
id & FUN_ANTISYMMETRIC_FLAG != 0
!self.is_symmetric() && id & FUN_ANTISYMMETRIC_FLAG != 0
}

#[inline(always)]
pub fn is_cyclesymmetric(&self) -> bool {
self.is_symmetric() && self.is_antisymmetric()
}

#[inline(always)]
Expand Down
33 changes: 33 additions & 0 deletions src/normalize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1087,6 +1087,39 @@ impl<'a> AtomView<'a> {
}

out_f.set_normalized(true);
} else if id.is_cyclesymmetric() {
let mut args: SmallVec<[_; 20]> = SmallVec::new();
for a in out_f.to_fun_view().iter() {
args.push(a);
}

let mut best_shift = 0;
'shift: for shift in 1..args.len() {
for i in 0..args.len() {
match args[(i + best_shift) % args.len()]
.cmp(&args[(i + shift) % args.len()])
{
std::cmp::Ordering::Equal => {}
std::cmp::Ordering::Less => {
continue 'shift;
}
std::cmp::Ordering::Greater => break,
}
}

best_shift = shift;
}

let mut f = workspace.new_atom();
let ff = f.to_fun(id);
for arg in args[best_shift..].iter().chain(&args[..best_shift]) {
ff.add_arg(*arg);
}

drop(args);

ff.set_normalized(true);
std::mem::swap(ff, out_f);
}
}
AtomView::Pow(p) => {
Expand Down
19 changes: 11 additions & 8 deletions src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ pub struct VariableListIndex(pub(crate) usize);
pub enum FunctionAttribute {
Symmetric,
Antisymmetric,
Cyclesymmetric,
Linear,
}

Expand Down Expand Up @@ -79,14 +80,14 @@ impl Default for State {
}

impl State {
pub const ARG: Symbol = Symbol::init_fn(0, 0, false, false, false);
pub const COEFF: Symbol = Symbol::init_fn(1, 0, false, false, false);
pub const EXP: Symbol = Symbol::init_fn(2, 0, false, false, false);
pub const LOG: Symbol = Symbol::init_fn(3, 0, false, false, false);
pub const SIN: Symbol = Symbol::init_fn(4, 0, false, false, false);
pub const COS: Symbol = Symbol::init_fn(5, 0, false, false, false);
pub const SQRT: Symbol = Symbol::init_fn(6, 0, false, false, false);
pub const DERIVATIVE: Symbol = Symbol::init_fn(7, 0, false, false, false);
pub const ARG: Symbol = Symbol::init_fn(0, 0, false, false, false, false);
pub const COEFF: Symbol = Symbol::init_fn(1, 0, false, false, false, false);
pub const EXP: Symbol = Symbol::init_fn(2, 0, false, false, false, false);
pub const LOG: Symbol = Symbol::init_fn(3, 0, false, false, false, false);
pub const SIN: Symbol = Symbol::init_fn(4, 0, false, false, false, false);
pub const COS: Symbol = Symbol::init_fn(5, 0, false, false, false, false);
pub const SQRT: Symbol = Symbol::init_fn(6, 0, false, false, false, false);
pub const DERIVATIVE: Symbol = Symbol::init_fn(7, 0, false, false, false, false);
pub const E: Symbol = Symbol::init_var(8, 0);
pub const I: Symbol = Symbol::init_var(9, 0);
pub const PI: Symbol = Symbol::init_var(10, 0);
Expand Down Expand Up @@ -265,6 +266,7 @@ impl State {
r.get_wildcard_level(),
attributes.contains(&FunctionAttribute::Symmetric),
attributes.contains(&FunctionAttribute::Antisymmetric),
attributes.contains(&FunctionAttribute::Cyclesymmetric),
attributes.contains(&FunctionAttribute::Linear),
);

Expand Down Expand Up @@ -297,6 +299,7 @@ impl State {
wildcard_level,
attributes.contains(&FunctionAttribute::Symmetric),
attributes.contains(&FunctionAttribute::Antisymmetric),
attributes.contains(&FunctionAttribute::Cyclesymmetric),
attributes.contains(&FunctionAttribute::Linear),
);

Expand Down
12 changes: 9 additions & 3 deletions symbolica.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,17 @@ class Expression:
"""The built-in logarithm function."""

@classmethod
def symbol(_cls, name: str, is_symmetric: Optional[bool] = None, is_antisymmetric: Optional[bool] = None, is_linear: Optional[bool] = None) -> Expression:
def symbol(_cls,
name: str,
is_symmetric: Optional[bool] = None,
is_antisymmetric: Optional[bool] = None,
is_cyclesymmetric: Optional[bool] = None,
is_linear: Optional[bool] = None) -> Expression:
"""
Create a new symbol from a `name`. Symbols carry information about their attributes.
The symbol can signal that it is symmetric if it is used as a function
using `is_symmetric=True`, antisymmetric using `is_antisymmetric=True`, and
using `is_symmetric=True`, antisymmetric using `is_antisymmetric=True`,
cyclesymmetric using `is_cyclesymmetric=True`, and
multilinear using `is_linear=True`. If no attributes
are specified, the attributes are inherited from the symbol if it was already defined,
otherwise all attributes are set to `false`.
Expand Down Expand Up @@ -154,7 +160,7 @@ class Expression:
"""

@classmethod
def symbols(_cls, *names: str, is_symmetric: Optional[bool] = None, is_antisymmetric: Optional[bool] = None, is_linear: Optional[bool] = None) -> Sequence[Expression]:
def symbols(_cls, *names: str, is_symmetric: Optional[bool] = None, is_antisymmetric: Optional[bool] = None, is_cyclesymmetric: Optional[bool] = None, is_linear: Optional[bool] = None) -> Sequence[Expression]:
"""
Create a Symbolica symbol for every name in `*names`. See `Expression.symbol` for more information.
Expand Down

0 comments on commit 3323324

Please sign in to comment.