diff --git a/cpp_ext/TorchValues.cpp b/cpp_ext/TorchValues.cpp index a2c31ae..a7ef850 100644 --- a/cpp_ext/TorchValues.cpp +++ b/cpp_ext/TorchValues.cpp @@ -616,6 +616,38 @@ void PyTorch_IntValue::bindDerived(ClassTy &c) { &DefaultingPyInsertionPoint::resolve()); }, "other"_a); + c.def( + "__lt__", + [](const PyTorch_IntValue &self, + const PyTorch_IntValue &other) -> PyTorch_BoolValue { + auto loc = getValueLocation(self); + return lt(self, other, &loc, &DefaultingPyInsertionPoint::resolve()); + }, + "other"_a); + c.def( + "__le__", + [](const PyTorch_IntValue &self, + const PyTorch_IntValue &other) -> PyTorch_BoolValue { + auto loc = getValueLocation(self); + return le(self, other, &loc, &DefaultingPyInsertionPoint::resolve()); + }, + "other"_a); + c.def( + "__gt__", + [](const PyTorch_IntValue &self, + const PyTorch_IntValue &other) -> PyTorch_BoolValue { + auto loc = getValueLocation(self); + return gt(self, other, &loc, &DefaultingPyInsertionPoint::resolve()); + }, + "other"_a); + c.def( + "__ge__", + [](const PyTorch_IntValue &self, + const PyTorch_IntValue &other) -> PyTorch_BoolValue { + auto loc = getValueLocation(self); + return ge(self, other, &loc, &DefaultingPyInsertionPoint::resolve()); + }, + "other"_a); py::implicitly_convertible(); py::implicitly_convertible(); } @@ -648,6 +680,30 @@ void PyTorch_FloatValue::bindDerived(ClassTy &c) { return div(self, other, &loc, &DefaultingPyInsertionPoint::resolve()); }, "other"_a); + c.def( + "__lt__", + [](const PyTorch_FloatValue &self, + const PyTorch_FloatValue &other) -> PyTorch_BoolValue { + auto loc = getValueLocation(self); + return lt(self, other, &loc, &DefaultingPyInsertionPoint::resolve()); + }, + "other"_a); + c.def( + "__gt__", + [](const PyTorch_FloatValue &self, + const PyTorch_FloatValue &other) -> PyTorch_BoolValue { + auto loc = getValueLocation(self); + return gt(self, other, &loc, &DefaultingPyInsertionPoint::resolve()); + }, + "other"_a); + c.def( + "__ge__", + [](const PyTorch_FloatValue &self, + const PyTorch_FloatValue &other) -> PyTorch_BoolValue { + auto loc = getValueLocation(self); + return ge(self, other, &loc, &DefaultingPyInsertionPoint::resolve()); + }, + "other"_a); py::implicitly_convertible(); }