Replies: 2 comments 1 reply
-
I believe the OT library doesn't have a built-in feature that directly computes the gradient of the GW distance with respect to one of the datasets. In particular, within DL frameworks like PyTorch, there isn't a straightforward way to achieve this using the library's existing functions.
-
gwloss is used internally in the ot.gromov.gromov_wasserstein(2) solvers, but yes, it uses only PyTorch operations when computed on PyTorch tensors, so it should provide proper gradients with respect to the data (you should still check this numerically). Keep in mind, though, that gwloss returns the loss for a given OT plan, not the optimal one; for that you need to use the ot.gromov.gromov_wasserstein2 function, which defines all gradients properly (also with respect to the marginal weights and the data).
-
I am currently using the ot.gromov.gwloss and ot.gromov.init_matrix functions from the OT library to compute the GW distance between two datasets and, eventually, its gradient with respect to one of the datasets. While integrating these functions, I stumbled upon a critical question.
Given that the ot.gromov.gwloss function seems to follow an abstract backend approach (via nx.backend), I have passed PyTorch tensors as inputs, hoping to maintain the computational graph and use PyTorch's autograd mechanism. While the function does return a torch tensor as output, I am uncertain whether the operations inside ot.gromov.gwloss are PyTorch-native. I have observed methods like nx.dot, which appear to adapt to the backend, but I wanted to make sure there aren't any underlying numpy operations that would break the computational graph.
My questions are:
Can you confirm that when PyTorch tensors are input to ot.gromov.gwloss, all internal operations (and any auxiliary functions it calls) remain fully PyTorch-native and differentiable, so that gradients obtained via .backward() are guaranteed to be correct?
If not, does the OT library support taking the gradient of the GW distance with respect to one of the datasets in any other manner? (I have already looked at the ot.gromov.gwggrad function, but apparently it does not compute the gradient that I want, i.e., with respect to one of the datasets only.)