Skip to content

Commit

Permalink
implement scalar comparisons
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental committed Jun 27, 2023
1 parent 748ef9e commit 9629bf4
Showing 1 changed file with 56 additions and 0 deletions.
56 changes: 56 additions & 0 deletions cpp_ext/TorchValues.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,38 @@ void PyTorch_IntValue::bindDerived(ClassTy &c) {
&DefaultingPyInsertionPoint::resolve());
},
"other"_a);
c.def(
"__lt__",
[](const PyTorch_IntValue &self,
const PyTorch_IntValue &other) -> PyTorch_BoolValue {
auto loc = getValueLocation(self);
return lt(self, other, &loc, &DefaultingPyInsertionPoint::resolve());
},
"other"_a);
c.def(
"__le__",
[](const PyTorch_IntValue &self,
const PyTorch_IntValue &other) -> PyTorch_BoolValue {
auto loc = getValueLocation(self);
return le(self, other, &loc, &DefaultingPyInsertionPoint::resolve());
},
"other"_a);
c.def(
"__gt__",
[](const PyTorch_IntValue &self,
const PyTorch_IntValue &other) -> PyTorch_BoolValue {
auto loc = getValueLocation(self);
return gt(self, other, &loc, &DefaultingPyInsertionPoint::resolve());
},
"other"_a);
c.def(
"__ge__",
[](const PyTorch_IntValue &self,
const PyTorch_IntValue &other) -> PyTorch_BoolValue {
auto loc = getValueLocation(self);
return ge(self, other, &loc, &DefaultingPyInsertionPoint::resolve());
},
"other"_a);
py::implicitly_convertible<int, PyTorch_IntValue>();
py::implicitly_convertible<DType, PyTorch_IntValue>();
}
Expand Down Expand Up @@ -648,6 +680,30 @@ void PyTorch_FloatValue::bindDerived(ClassTy &c) {
return div(self, other, &loc, &DefaultingPyInsertionPoint::resolve());
},
"other"_a);
c.def(
"__lt__",
[](const PyTorch_FloatValue &self,
const PyTorch_FloatValue &other) -> PyTorch_BoolValue {
auto loc = getValueLocation(self);
return lt(self, other, &loc, &DefaultingPyInsertionPoint::resolve());
},
"other"_a);
c.def(
"__gt__",
[](const PyTorch_FloatValue &self,
const PyTorch_FloatValue &other) -> PyTorch_BoolValue {
auto loc = getValueLocation(self);
return gt(self, other, &loc, &DefaultingPyInsertionPoint::resolve());
},
"other"_a);
c.def(
"__ge__",
[](const PyTorch_FloatValue &self,
const PyTorch_FloatValue &other) -> PyTorch_BoolValue {
auto loc = getValueLocation(self);
return ge(self, other, &loc, &DefaultingPyInsertionPoint::resolve());
},
"other"_a);
py::implicitly_convertible<float, PyTorch_FloatValue>();
}

Expand Down

0 comments on commit 9629bf4

Please sign in to comment.