diff --git a/msdsl/eqn/lds.py b/msdsl/eqn/lds.py index 5dde501..52a174f 100644 --- a/msdsl/eqn/lds.py +++ b/msdsl/eqn/lds.py @@ -3,6 +3,10 @@ import numpy as np import scipy.linalg +import sympy as sp +from msdsl.expr.extras import msdsl_ast_to_sympy +from msdsl.assignment import Assignment + class LDS: def __init__(self, A=None, B=None, C=None, D=None): # save settings @@ -55,6 +59,31 @@ def __str__(self): # return result return retval + + def convert_to_sympy(self, states, inputs, outputs): + state_strings = list(map(lambda x: str(x), states)) + inputs_strings = list(map(lambda x: str(x), inputs)) + outputs_strings = list(map(lambda x: str(x), outputs)) + # Convert state, input, and output strings to sympy symbols + states = sp.symbols(state_strings) + inputs = sp.symbols(inputs_strings) + outputs = sp.symbols(outputs_strings) + + # Convert numpy arrays to sympy matrices + A_sym = sp.Matrix(self.A) + B_sym = sp.Matrix(self.B) + C_sym = sp.Matrix(self.C) + D_sym = sp.Matrix(self.D) + + # Define state-space equations + state_eq = A_sym * sp.Matrix(states) + B_sym * sp.Matrix(inputs) + output_eq = C_sym * sp.Matrix(states) + D_sym * sp.Matrix(inputs) + + # Explicitly compute the derivatives (dot{x}) + state_ode = sp.Matrix([sp.diff(state, 't') for state in states]) - state_eq + + return state_ode, output_eq + class LdsCollection: def __init__(self): @@ -75,3 +104,68 @@ def append(self, lds: LDS): self.B = np.concatenate((self.B, B), axis=2) if self.B is not None else B self.C = np.concatenate((self.C, C), axis=2) if self.C is not None else C self.D = np.concatenate((self.D, D), axis=2) if self.D is not None else D + + + def convert_to_sympy_piecewise(self, states, inputs, outputs, sel_bits, sel_eqns): + # Convert states, inputs, and outputs to sympy symbols + state_strings = list(map(str, states)) + inputs_strings = list(map(str, inputs)) + outputs_strings = list(map(str, outputs)) + + states = sp.symbols(state_strings) + inputs = sp.symbols(inputs_strings) + outputs = sp.symbols(outputs_strings) + + # Convert sel_bits to SymPy symbols if they aren't already + sel_bits_sympy = [sp.Symbol(str(sel_bit)) if not isinstance(sel_bit, sp.Basic) else sel_bit for sel_bit in sel_bits] + + # Initialize lists for state and output equations with default expressions + state_eq_piecewise = [sp.Piecewise((0, True)) for _ in range(len(states))] + output_eq_piecewise = [sp.Piecewise((0, True)) for _ in range(len(outputs))] + + # Iterate over all possible configurations of sel_bits + + for k in range(self.A.shape[2]): # Number of scenarios + A_sym = sp.Matrix(self.A[:, :, k]) + B_sym = sp.Matrix(self.B[:, :, k]) + C_sym = sp.Matrix(self.C[:, :, k]) + D_sym = sp.Matrix(self.D[:, :, k]) + + # Define state-space equations + state_eq = A_sym * sp.Matrix(states) + B_sym * sp.Matrix(inputs) + output_eq = C_sym * sp.Matrix(states) + D_sym * sp.Matrix(inputs) + + # Compute derivatives (dot{x}) for the state equations + state_ode = sp.Matrix([sp.diff(state, 't') for state in states]) - state_eq + + # Create the condition for this scenario + condition = True + for i, sel_bit_sym in enumerate(sel_bits_sympy): + bit_value = (k >> i) & 1 + # Use logical AND to build up the condition + condition = sp.And(condition, sp.Eq(sel_bit_sym, bit_value)) + + + # Process sel_eqns to adjust the condition if needed + for sel_eqn in sel_eqns: + + + if isinstance(sel_eqn, Assignment): + signal = sp.Symbol(sel_eqn.signal.name) + expr = msdsl_ast_to_sympy(sel_eqn.expr) + + condition = condition.subs(signal, expr) + + # Assign the corresponding equations to the Piecewise objects + for i in range(len(states)): + state_eq_piecewise[i] = sp.Eq( sp.Derivative(states[i],sp.Symbol('t')), sp.Piecewise((state_ode[i], condition), (state_eq_piecewise[i], True))) + + for i in range(len(outputs)): + output_eq_piecewise[i] = sp.Eq(outputs[i], sp.Piecewise((output_eq[i], condition), (output_eq_piecewise[i], True))) + + # For each equation we need to substitute in + return state_eq_piecewise + output_eq_piecewise + + + + diff --git a/msdsl/expr/expr.py b/msdsl/expr/expr.py index cf5a708..0bd3287 100644 --- a/msdsl/expr/expr.py +++ b/msdsl/expr/expr.py @@ -6,6 +6,7 @@ from msdsl.expr.format import RealFormat, SIntFormat, UIntFormat, Format, IntFormat +import sympy as sp # constant wrapping def wrap_constant(operand): @@ -1039,6 +1040,9 @@ def mt19937(clk=None, rst=None, cke=None, seed=None): def lcg_op(clk=None, rst=None, cke=None, seed=None): return LCG(clk=clk, rst=rst, cke=cke, seed=seed) + + + # testing def main(): diff --git a/msdsl/expr/extras.py b/msdsl/expr/extras.py index 794397f..9969790 100644 --- a/msdsl/expr/extras.py +++ b/msdsl/expr/extras.py @@ -1,6 +1,9 @@ from typing import Union, List from numbers import Number, Integral -from msdsl.expr.expr import ModelExpr, concatenate, BitwiseAnd, array +from msdsl.expr.expr import ModelExpr, concatenate, BitwiseAnd, array, LessThan, GreaterThan, Product, Sum, EqualTo, Array, Constant +from msdsl.expr.signals import AnalogSignal, DigitalSignal + +import sympy as sp def all_between(x: List[ModelExpr], lo: Union[Number, ModelExpr], hi: Union[Number, ModelExpr]) -> ModelExpr: """ @@ -37,4 +40,35 @@ def if_(condition, then, else_): :param else_: Action to be executed for False case :return: Boolean """ - return array([else_, then], condition) \ No newline at end of file + return array([else_, then], condition) + +def msdsl_ast_to_sympy(ast): + """ + Convert an AST from msdsl to a sympy expression. + """ + if isinstance(ast, LessThan): + return sp.Lt(msdsl_ast_to_sympy(ast.lhs), msdsl_ast_to_sympy(ast.rhs)) + elif isinstance(ast, GreaterThan): + return sp.Gt(msdsl_ast_to_sympy(ast.lhs), msdsl_ast_to_sympy(ast.rhs)) + elif isinstance(ast, Product): + accum = 1 # Corrected from 0 to 1 to properly accumulate products + for operand in ast.operands: + accum *= msdsl_ast_to_sympy(operand) + return accum + elif isinstance(ast, Sum): + accum = 0 + for operand in ast.operands: + accum += msdsl_ast_to_sympy(operand) + return accum + elif isinstance(ast, EqualTo): + return sp.Eq(msdsl_ast_to_sympy(ast.lhs), msdsl_ast_to_sympy(ast.rhs)) + elif isinstance(ast, Array): + elements = ast.operands[:-1] + address = msdsl_ast_to_sympy(ast.operands[-1]) + return sp.Piecewise(*[(msdsl_ast_to_sympy(elem), address if i == 1 else sp.Not(address)) for i, elem in enumerate(elements)]) + elif isinstance(ast, Constant): + return ast.value + elif isinstance(ast, AnalogSignal) or isinstance(ast, DigitalSignal): + return sp.Symbol(str(ast)) + else: + raise Exception(f"Unsupported AST node: {type(ast)}") \ No newline at end of file diff --git a/msdsl/model.py b/msdsl/model.py index b4df97e..8ab3068 100644 --- a/msdsl/model.py +++ b/msdsl/model.py @@ -1,4 +1,5 @@ -from collections import OrderedDict, Iterable +from collections import OrderedDict +from collections.abc import Iterable from itertools import chain from numbers import Integral, Number from typing import List, Set, Union @@ -30,8 +31,12 @@ from msdsl.function import GeneralFunction, Function, PlaceholderFunction, MultiFunction from msdsl.lfsr import LFSR +from msdsl.expr.extras import msdsl_ast_to_sympy + from scipy.signal import cont2discrete +import sympy as sp + class Bus: def __init__(self, signal: Signal, n: Integral): self.signal = signal @@ -778,7 +783,7 @@ def get_equation_io(self, eqn_sys: EqnSys): # determine sel_bits sel_bit_names = set(signal_names(eqn_sys.get_sel_bits())) sel_bits = self.get_signals(sel_bit_names) - + # return result return inputs, states, outputs, sel_bits @@ -794,8 +799,12 @@ def add_eqn_sys(self, eqns: List[ModelExpr], extra_outputs=None, clk=None, rst=N :param extra_outputs: List of internal variables in the system of equations that should be bound to analog signals. :param clk: Name of clock signal to use (None will default to `CLK_MSDSL) :param rst: Name of the reset signal to use (None will default to `RST_MSDSL) + + Returns an LDSCollection Object """ + + # set defaults extra_outputs = extra_outputs if extra_outputs is not None else [] @@ -804,7 +813,7 @@ def add_eqn_sys(self, eqns: List[ModelExpr], extra_outputs=None, clk=None, rst=N # analyze equation to find out knowns and unknowns inputs, states, outputs, sel_bits = self.get_equation_io(eqn_sys) - + # add the extra outputs as needed for extra_output in extra_outputs: if not isinstance(extra_output, Signal): @@ -832,6 +841,8 @@ def add_eqn_sys(self, eqns: List[ModelExpr], extra_outputs=None, clk=None, rst=N # add to collection of LDS systems collection.append(lds) + + # construct address for selection if len(sel_bits) > 0: @@ -843,6 +854,10 @@ def add_eqn_sys(self, eqns: List[ModelExpr], extra_outputs=None, clk=None, rst=N self.add_discrete_time_lds(collection=collection, inputs=inputs, states=states, outputs=outputs, sel=sel, clk=clk, rst=rst) + + return collection, inputs, states, outputs, sel_bits + + def add_discrete_time_lds(self, collection, inputs=None, states=None, outputs=None, sel=None, clk=None, rst=None): @@ -870,6 +885,45 @@ def add_discrete_time_lds(self, collection, inputs=None, states=None, outputs=No else: self.bind_name(outputs[row].name, expr) + def get_sympy_lds(self): + """ + Must be run after add_eqn_sys. + Returns an array of piecewise LDSs in the form of sympy equations. + Appends all assignments. + """ + + sympy_lds = [] + + #self.assignments + + for circuit in self.circuits: + filtered_states = list(filter(lambda x: not isinstance(x, DigitalInput), circuit.sel_bits)) + state_str = list(map(lambda x: str(x), filtered_states)) + + sel_eqns = self.get_assignments(state_str) + + + circuit_lds = circuit.collection.convert_to_sympy_piecewise(circuit.states, circuit.inputs, circuit.outputs, circuit.sel_bits, sel_eqns) + circuit.sympy_eqs = circuit_lds + sympy_lds += circuit_lds + + for symbol_name, assignment_obj in self.unmodified_assignments.items(): + sympy_lds.append(sp.Eq(sp.Symbol(symbol_name), msdsl_ast_to_sympy(assignment_obj.expr))) + + return sympy_lds + + def write_sympy_lds_to_file(self, filename): + """ + Must be run after compile(). + Writes the sympy equations to json format. + """ + + sympy_lds = self.diffeqs + + with open(filename, 'w') as f: + for eq in sympy_lds: + f.write(str(eq) + '\n') + def set_tf(self, input_: Signal, output: Signal, tf, clk=None, rst=None): """ Method to assign an output signal as a function of the input signal by applying a given transfer function. @@ -1012,10 +1066,24 @@ def make_circuit(self, clk=None, rst=None): def compile(self, gen: CodeGenerator): # compile circuits + self.unmodified_assignments = deepcopy(self.assignments) + for circuit in self.circuits: eqns = circuit.compile_to_eqn_list() - self.add_eqn_sys(eqns, circuit.extra_outputs, clk=circuit.clk, rst=circuit.rst) + ldscollection, inputs, states, outputs, sel_bits = self.add_eqn_sys(eqns, circuit.extra_outputs, clk=circuit.clk, rst=circuit.rst) + + circuit.collection = ldscollection #Assign the circuit.collection object so we may use it later. + circuit.inputs = inputs + circuit.states = states + circuit.outputs = outputs + circuit.sel_bits = sel_bits + + + + self.diffeqs = self.get_sympy_lds() + + # determine the I/Os and internal variables ios = [] internals = [] diff --git a/msdsl/util.py b/msdsl/util.py index 3cab005..9112d61 100644 --- a/msdsl/util.py +++ b/msdsl/util.py @@ -33,9 +33,42 @@ def warn(s): def list2dict(l): return {elem: k for k, elem in enumerate(l)} + +def msdsl_ast_to_sympy(ast): + """ + Convert an AST from msdsl to a sympy expression. + """ + if isinstance(ast, LessThan): + return sp.Lt(msdsl_ast_to_sympy(ast.lhs), msdsl_ast_to_sympy(ast.rhs)) + elif isinstance(ast, GreaterThan): + return sp.Gt(msdsl_ast_to_sympy(ast.lhs), msdsl_ast_to_sympy(ast.rhs)) + elif isinstance(ast, Product): + accum = 1 # Corrected from 0 to 1 to properly accumulate products + for operand in ast.operands: + accum *= msdsl_ast_to_sympy(operand) + return accum + elif isinstance(ast, Sum): + accum = 0 + for operand in ast.operands: + accum += msdsl_ast_to_sympy(operand) + return accum + elif isinstance(ast, EqualTo): + return sp.Eq(msdsl_ast_to_sympy(ast.lhs), msdsl_ast_to_sympy(ast.rhs)) + elif isinstance(ast, Array): + elements = ast.operands[:-1] + address = msdsl_ast_to_sympy(ast.operands[-1]) + return sp.Piecewise(*[(msdsl_ast_to_sympy(elem), address if i == 1 else sp.Not(address)) for i, elem in enumerate(elements)]) + elif isinstance(ast, Constant): + return ast.value + elif isinstance(ast, AnalogSignal) or isinstance(ast, DigitalSignal): + return sp.Symbol(str(ast)) + else: + raise Exception(f"Unsupported AST node: {type(ast)}") + def main(): # list2dict tests print(list2dict(['a', 'b', 'c'])) if __name__ == '__main__': - main() \ No newline at end of file + main() +