Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add missing PyTorch/JAX export for logical_or, logical_and, and relu #433

Open
wants to merge 23 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
98cca40
Refactor utility functions
MilesCranmer Sep 17, 2023
4713607
Move denoising functionality to separate file
MilesCranmer Sep 17, 2023
3ae241a
Move feature selection functionality to separate file
MilesCranmer Sep 17, 2023
ff2ef42
Mypy compatibility
MilesCranmer Sep 17, 2023
135a464
Move all deprecated functions to deprecated.py
MilesCranmer Sep 17, 2023
6c92e1c
Store `sr_options_` and rename state to `sr_state_`
MilesCranmer Sep 17, 2023
ff2f93a
Add missing sympy operators for boolean logic
MilesCranmer Sep 19, 2023
d5787b2
Add missing sympy operators for relu
MilesCranmer Sep 19, 2023
47823ba
Add functionality for piecewise export to torch
MilesCranmer Sep 22, 2023
73d0f8a
Clean up error message in exports
MilesCranmer Sep 22, 2023
f92a935
Implement relu, logical_or, logical_and
MilesCranmer Sep 22, 2023
2a20447
Remove unnecessary as_bool
MilesCranmer Sep 22, 2023
22b047a
Merge tag 'v0.16.4' into sympy-or
MilesCranmer Dec 14, 2023
11dea32
Replace Heaviside with piecewise
MilesCranmer Dec 14, 2023
208307d
Merge tag 'v0.16.4' into store-options
MilesCranmer Dec 14, 2023
f21e3d6
Merge branch 'master' into store-options
MilesCranmer Dec 14, 2023
50c1407
Merge branch 'store-options' into sympy-or
MilesCranmer Dec 14, 2023
cff611a
Apply suggestions from code review
MilesCranmer Jun 3, 2024
3f1524b
Update pysr/export_torch.py
MilesCranmer Jun 3, 2024
01e1a15
Update pysr/export_torch.py
MilesCranmer Jun 3, 2024
5c0a49a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 3, 2024
0f47a59
Merge tag 'v0.18.4' into sympy-or
MilesCranmer Jun 3, 2024
c008678
Merge branch 'master' into sympy-or
MilesCranmer Jun 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pysr/export_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def sympy2jaxtext(expr, parameters, symbols_in, extra_jax_mappings=None):
_func = {**_jnp_func_lookup, **extra_jax_mappings}[expr.func]
except KeyError:
raise KeyError(
f"Function {expr.func} was not found in JAX function mappings."
f"Function {expr.func} was not found in JAX function mappings. "
"Please add it to extra_jax_mappings in the format, e.g., "
"{sympy.sqrt: 'jnp.sqrt'}."
)
Expand Down
73 changes: 71 additions & 2 deletions pysr/export_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,64 @@ def _initialize_torch():

torch = _torch

# Allows PyTorch to map Piecewise functions:
def expr_cond_pair(expr, cond):
if isinstance(cond, torch.Tensor) and not isinstance(expr, torch.Tensor):
expr = torch.tensor(expr, dtype=cond.dtype, device=cond.device)
elif isinstance(expr, torch.Tensor) and not isinstance(cond, torch.Tensor):
cond = torch.tensor(cond, dtype=expr.dtype, device=expr.device)
else:
return expr, cond

# First, make sure expr and cond are same size:
if expr.shape != cond.shape:
if len(expr.shape) == 0:
expr = expr.expand(cond.shape)
elif len(cond.shape) == 0:
cond = cond.expand(expr.shape)
else:
raise ValueError(
"expr and cond must have same shape, or one must be a scalar."
)
return expr, cond

MilesCranmer marked this conversation as resolved.
Show resolved Hide resolved
def piecewise(*expr_conds):
output = None
already_used = None
for expr, cond in expr_conds:
if not isinstance(cond, torch.Tensor) and not isinstance(
expr, torch.Tensor
):
# When we just have scalars, have to do this a bit more complicated
# due to the fact that we need to evaluate on the correct device.
if output is None:
already_used = cond
output = expr if cond else 0.0
else:
if not isinstance(output, torch.Tensor):
output += expr if cond and not already_used else 0.0
already_used = already_used or cond
else:
expr = torch.tensor(
expr, dtype=output.dtype, device=output.device
).expand(output.shape)
output += torch.where(
cond & ~already_used, expr, torch.zeros_like(expr)
)
already_used = already_used | cond
else:
if output is None:
already_used = cond
output = torch.where(cond, expr, torch.zeros_like(expr))
else:
output += torch.where(
cond & ~already_used, expr, torch.zeros_like(expr)
MilesCranmer marked this conversation as resolved.
Show resolved Hide resolved
)
already_used = already_used | cond
MilesCranmer marked this conversation as resolved.
Show resolved Hide resolved
return output

# TODO: Add test that makes sure tensors are on the same device

_global_func_lookup = {
sympy.Mul: _reduce(torch.mul),
sympy.Add: _reduce(torch.add),
Expand Down Expand Up @@ -81,6 +139,11 @@ def _initialize_torch():
sympy.Heaviside: torch.heaviside,
sympy.core.numbers.Half: (lambda: 0.5),
sympy.core.numbers.One: (lambda: 1.0),
sympy.logic.boolalg.Boolean: lambda x: x,
sympy.logic.boolalg.BooleanTrue: (lambda: True),
sympy.logic.boolalg.BooleanFalse: (lambda: False),
sympy.functions.elementary.piecewise.ExprCondPair: expr_cond_pair,
sympy.Piecewise: piecewise,
MilesCranmer marked this conversation as resolved.
Show resolved Hide resolved
}

class _Node(torch.nn.Module):
Expand Down Expand Up @@ -125,7 +188,7 @@ def __init__(self, *, expr, _memodict, _func_lookup, **kwargs):
self._torch_func = _func_lookup[expr.func]
except KeyError:
raise KeyError(
f"Function {expr.func} was not found in Torch function mappings."
f"Function {expr.func} was not found in Torch function mappings. "
"Please add it to extra_torch_mappings in the format, e.g., "
"{sympy.sqrt: torch.sqrt}."
)
Expand Down Expand Up @@ -153,7 +216,13 @@ def forward(self, memodict):
arg_ = arg(memodict)
memodict[arg] = arg_
args.append(arg_)
return self._torch_func(*args)
try:
return self._torch_func(*args)
except Exception as err:
# Add information about the current node to the error:
raise type(err)(
f"Error occurred in node {self._sympy_func} with args {args}"
)

class _SingleSymPyModule(torch.nn.Module):
"""SympyTorch code from https://github.com/patrick-kidger/sympytorch"""
Expand Down
40 changes: 28 additions & 12 deletions pysr/sr.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,8 +603,12 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
Path to the temporary equations directory.
equation_file_ : str
Output equation file name produced by the julia backend.
raw_julia_state_ : tuple[list[PyCall.jlwrap], PyCall.jlwrap]
sr_state_ : tuple[list[PyCall.jlwrap], PyCall.jlwrap]
The state for the julia SymbolicRegression.jl backend post fitting.
sr_options_ : PyCall.jlwrap
The options used by `SymbolicRegression.jl`, created during
a call to `.fit`. You may use this to manually call functions
in `SymbolicRegression` which take an `::Options` argument.
equation_file_contents_ : list[pandas.DataFrame]
Contents of the equation file output by the Julia backend.
show_pickle_warnings_ : bool
Expand Down Expand Up @@ -1031,7 +1035,7 @@ def __getstate__(self):
serialization.

Thus, for `PySRRegressor` to support pickle serialization, the
`raw_julia_state_` attribute must be hidden from pickle. This will
`sr_state_` attribute must be hidden from pickle. This will
prevent the `warm_start` of any model that is loaded via `pickle.loads()`,
but does allow all other attributes of a fitted `PySRRegressor` estimator
to be serialized. Note: Jax and Torch format equations are also removed
Expand All @@ -1041,9 +1045,9 @@ def __getstate__(self):
show_pickle_warning = not (
"show_pickle_warnings_" in state and not state["show_pickle_warnings_"]
)
if "raw_julia_state_" in state and show_pickle_warning:
if ("sr_state_" in state or "sr_options_" in state) and show_pickle_warning:
warnings.warn(
"raw_julia_state_ cannot be pickled and will be removed from the "
"sr_state_ and sr_options_ cannot be pickled and will be removed from the "
"serialized instance. This will prevent a `warm_start` fit of any "
"model that is deserialized via `pickle.load()`."
)
Expand All @@ -1055,7 +1059,10 @@ def __getstate__(self):
"serialized instance. When loading the model, please redefine "
f"`{state_key}` at runtime."
)
state_keys_to_clear = ["raw_julia_state_"] + state_keys_containing_lambdas
state_keys_to_clear = [
"sr_state_",
"sr_options_",
] + state_keys_containing_lambdas
pickled_state = {
key: (None if key in state_keys_to_clear else value)
for key, value in state.items()
Expand Down Expand Up @@ -1105,6 +1112,14 @@ def equations(self): # pragma: no cover
)
return self.equations_

@property
def raw_julia_state_(self): # pragma: no cover
warnings.warn(
"PySRRegressor.raw_julia_state_ is now deprecated. "
"Please use PySRRegressor.sr_state_ instead.",
)
return self.sr_state_

def get_best(self, index=None):
"""
Get best equation using `model_selection`.
Expand Down Expand Up @@ -1605,7 +1620,7 @@ def _run(self, X, y, mutated_params, weights, seed):

# Call to Julia backend.
# See https://github.com/MilesCranmer/SymbolicRegression.jl/blob/master/src/OptionsStruct.jl
options = SymbolicRegression.Options(
self.sr_options_ = SymbolicRegression.Options(
binary_operators=Main.eval(str(binary_operators).replace("'", "")),
unary_operators=Main.eval(str(unary_operators).replace("'", "")),
bin_constraints=bin_constraints,
Expand Down Expand Up @@ -1704,7 +1719,7 @@ def _run(self, X, y, mutated_params, weights, seed):

# Call to Julia backend.
# See https://github.com/MilesCranmer/SymbolicRegression.jl/blob/master/src/SymbolicRegression.jl
self.raw_julia_state_ = SymbolicRegression.equation_search(
self.sr_state_ = SymbolicRegression.equation_search(
Main.X,
Main.y,
weights=Main.weights,
Expand All @@ -1714,10 +1729,10 @@ def _run(self, X, y, mutated_params, weights, seed):
y_variable_names=y_variable_names,
X_units=self.X_units_,
y_units=self.y_units_,
options=options,
options=self.sr_options_,
numprocs=cprocs,
parallelism=parallelism,
saved_state=self.raw_julia_state_,
saved_state=self.sr_state_,
return_state=True,
addprocs_function=cluster_manager,
progress=progress and self.verbosity > 0 and len(y.shape) == 1,
Expand Down Expand Up @@ -1786,10 +1801,10 @@ def fit(
Fitted estimator.
"""
# Init attributes that are not specified in BaseEstimator
if self.warm_start and hasattr(self, "raw_julia_state_"):
if self.warm_start and hasattr(self, "sr_state_"):
pass
else:
if hasattr(self, "raw_julia_state_"):
if hasattr(self, "sr_state_"):
warnings.warn(
"The discovered expressions are being reset. "
"Please set `warm_start=True` if you wish to continue "
Expand All @@ -1799,7 +1814,8 @@ def fit(
self.equations_ = None
self.nout_ = 1
self.selection_mask_ = None
self.raw_julia_state_ = None
self.sr_state_ = None
self.sr_options_ = None
self.X_units_ = None
self.y_units_ = None

Expand Down
4 changes: 2 additions & 2 deletions pysr/test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def test_high_precision_search_custom_loss(self):
from pysr.sr import Main

# We should have that the model state is now a Float64 hof:
Main.test_state = model.raw_julia_state_
Main.test_state = model.sr_state_
self.assertTrue(Main.eval("typeof(test_state[2]).parameters[1] == Float64"))

def test_multioutput_custom_operator_quiet_custom_complexity(self):
Expand Down Expand Up @@ -232,7 +232,7 @@ def test_empty_operators_single_input_warm_start(self):
from pysr.sr import Main

# We should have that the model state is now a Float32 hof:
Main.test_state = regressor.raw_julia_state_
Main.test_state = regressor.sr_state_
self.assertTrue(Main.eval("typeof(test_state[2]).parameters[1] == Float32"))
# This should exit almost immediately, and use the old equations
regressor.fit(X, y)
Expand Down
Loading