Skip to content

Commit

Permalink
Now support user-defined torch_numeric.
Browse files Browse the repository at this point in the history
  • Loading branch information
AmitSolomonPrinceton committed Sep 24, 2024
1 parent ec612c7 commit b8073a7
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion cvxtorch/torch_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,8 @@ def apply_torch_numeric(self, expr: Expression, values: list[torch.Tensor]) -> t
This function returns self.torch_numeric(values) if it exists,
and self.numeric(values) otherwise.
"""
torch_numeric = EXPR2TORCH.get(type(expr))
# torch_numeric = EXPR2TORCH.get(type(expr))
torch_numeric = get_torch_numeric(expr)
if torch_numeric:
return torch_numeric.torch_numeric(expr, values)
elif not self.implemented_only:
Expand All @@ -340,3 +341,16 @@ def apply_torch_numeric(self, expr: Expression, values: list[torch.Tensor]) -> t
raise NotImplementedError(f"torch_numeric function of {type(expr)} is not implemented."
f"If you want to use CVXPY's numeric instead, pass"
f"implemented_only=False.")

def get_torch_numeric(expr: Expression) -> callable:
"""
This function returns the torch_numeric function of this atom.
It supports creating custom torch_numeric functions for atoms by the user.
If the user provides a torch_numeric function for the atom, this function will return it.
Otherwise, get the default one provided by this pacjage (from EXPR2TORCH).
"""

torch_numeric = getattr(expr, "torch_numeric", None)
if torch_numeric:
return torch_numeric
return EXPR2TORCH.get(type(expr))

0 comments on commit b8073a7

Please sign in to comment.