Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jul 29, 2024
1 parent 925783c commit 1a45b90
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 9 deletions.
6 changes: 2 additions & 4 deletions pysr/regressor_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,11 @@ def __init__(
):
super().__init__(**kwargs)
self.recursive_history_length = recursive_history_length

def _variable_names(self, y, variable_names=None):
if not variable_names:
if y.shape[1] == 1:
return [
f"xt_{i}" for i in range(self.recursive_history_length, 0, -1)
]
return [f"xt_{i}" for i in range(self.recursive_history_length, 0, -1)]
else:
return [
f"x{i}t_{j}"
Expand Down
26 changes: 21 additions & 5 deletions pysr/test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,24 +715,40 @@ def test_sequence_2D_data(self):
)
model.fit(X)
self.assertLessEqual(model.get_best()[0]["loss"], 1e-4)

def test_sequence_variable_names(self):
model = PySRSequenceRegressor(
**self.default_test_kwargs,
)
y = np.ones((5, 3))
sequence_variable_names = model._variable_names(y)
print(sequence_variable_names)
self.assertListEqual(sequence_variable_names, ["x0t_3", "x1t_3", "x2t_3", "x0t_2", "x1t_2", "x2t_2", "x0t_1", "x1t_1", "x2t_1"])

self.assertListEqual(
sequence_variable_names,
[
"x0t_3",
"x1t_3",
"x2t_3",
"x0t_2",
"x1t_2",
"x2t_2",
"x0t_1",
"x1t_1",
"x2t_1",
],
)

def test_sequence_custom_variable_names(self):
model = PySRSequenceRegressor(
**self.default_test_kwargs,
)
variable_names = ["a", "b", "c"]
y = np.array([[1] * 5]*3)
y = np.array([[1] * 5] * 3)
sequence_variable_names = model._variable_names(y, variable_names)
self.assertListEqual(sequence_variable_names, ["at_3", "bt_3", "ct_3", "at_2", "bt_2", "ct_2", "at_1", "bt_1", "ct_1"])
self.assertListEqual(
sequence_variable_names,
["at_3", "bt_3", "ct_3", "at_2", "bt_2", "ct_2", "at_1", "bt_1", "ct_1"],
)

def test_unused_variables(self):
X = [1, 1]
Expand Down

0 comments on commit 1a45b90

Please sign in to comment.