Skip to content

Commit 665d4a1

Browse files
committedMar 1, 2025··
fix: comparison operator parsing
1 parent 6a2dbad commit 665d4a1

File tree

3 files changed

+29
-4
lines changed

3 files changed

+29
-4
lines changed
 

‎docs/operators.md

+7-4
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,12 @@ it can exponentially increase the search space.
1616
|--------------|------------|----------|
1717
| `+` | `max` | `logical_or`[^2] |
1818
| `-` | `min` | `logical_and`[^3]|
19-
| `*` | `greater`[^4] | |
20-
| `/` | `cond`[^5] | |
21-
| `^` | `mod` | |
19+
| `*` | `>`[^4] | |
20+
| `/` | `>=` | |
21+
| `^` | `<` | |
22+
| | `<=` | |
23+
| | `cond`[^5] | |
24+
| | `mod` | |
2225

2326
**Unary Operators**
2427

@@ -74,5 +77,5 @@ any invalid values over the training dataset.
7477
[^1]: However, you will need to define a sympy equivalent in `extra_sympy_mapping` if you want to use a function not in the above list.
7578
[^2]: `logical_or` is equivalent to `(x, y) -> (x > 0 || y > 0) ? 1 : 0`
7679
[^3]: `logical_and` is equivalent to `(x, y) -> (x > 0 && y > 0) ? 1 : 0`
77-
[^4]: `greater` is equivalent to `(x, y) -> x > y ? 1 : 0`
80+
[^4]: `>` is equivalent to `(x, y) -> x > y ? 1 : 0`
7881
[^5]: `cond` is equivalent to `(x, y) -> x > 0 ? y : 0`

‎pysr/julia_import.py

+3
Original file line numberDiff line numberDiff line change
@@ -67,5 +67,8 @@ def _import_juliacall():
6767
# Expose `D` operator:
6868
jl.seval("using SymbolicRegression: D")
6969

70+
# Expose other operators:
71+
jl.seval("using SymbolicRegression: less, greater_equal, less_equal")
72+
7073
jl.seval("using Pkg: Pkg")
7174
Pkg = jl.Pkg

‎pysr/test/test_main.py

+19
Original file line numberDiff line numberDiff line change
@@ -758,6 +758,25 @@ def test_tensorboard_logger(self):
758758
# Verify model still works as expected
759759
self.assertLessEqual(model.get_best()["loss"], 1e-4)
760760

761+
def test_comparison_operator(self):
762+
X = self.rstate.randn(100, 2)
763+
y = ((X[:, 0] + X[:, 1]) < (X[:, 0] * X[:, 1])).astype(float)
764+
765+
model = PySRRegressor(
766+
binary_operators=["<", "+", "*"],
767+
**self.default_test_kwargs,
768+
early_stop_condition="stop_if(loss, complexity) = loss < 1e-4 && complexity <= 7",
769+
)
770+
771+
model.fit(X, y)
772+
773+
best_equation = model.get_best()["equation"]
774+
self.assertIn("less", best_equation)
775+
self.assertLessEqual(model.get_best()["loss"], 1e-4)
776+
777+
y_pred = model.predict(X)
778+
np.testing.assert_array_almost_equal(y, y_pred, decimal=3)
779+
761780

762781
def manually_create_model(equations, feature_names=None):
763782
if feature_names is None:

0 commit comments

Comments
 (0)
Please sign in to comment.