sinkhorn2 and its functorch.vmap compatibility #482
🚀 Feature
Make the ot.sinkhorn2 function compatible with functorch.vmap.

Motivation
I'm using the Python Optimal Transport library. I want to define a loss function that iterates over every sample in my batch and calculates the Sinkhorn distance between that sample and its ground-truth value. Previously I used a for-loop for this, but it is far too slow for my application. While reading through functorch, I saw that I should be able to use its vmap functionality instead.
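For reference, the loop-versus-vmap pattern looks roughly like the sketch below. per_sample_loss is a hypothetical stand-in for the per-sample Sinkhorn distance (the original loop code is not reproduced here); because it is plain tensor arithmetic with no branching on tensor values, vmap can vectorize it. The sketch assumes PyTorch 2.x, where functorch.vmap is available as torch.func.vmap.

```python
import torch
from torch.func import vmap  # PyTorch 2.x location of functorch.vmap


def per_sample_loss(pred, target):
    # Hypothetical per-sample distance: only tensor ops, no value-dependent `if`.
    return ((pred - target) ** 2).sum()


torch.manual_seed(0)
batch, n = 8, 5
preds = torch.rand(batch, n)
targets = torch.rand(batch, n)

# For-loop version: one Python call per sample in the batch.
loop_losses = torch.stack([per_sample_loss(p, t) for p, t in zip(preds, targets)])

# vmap version: the same function, vectorized over dimension 0 of both inputs.
vmap_losses = vmap(per_sample_loss)(preds, targets)

print(vmap_losses.shape, torch.allclose(loop_losses, vmap_losses))
```

The per-sample function is written once; vmap supplies the batch dimension, which is what removes the Python-level loop overhead.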
But after wrapping my function in vmap, I get an error.

Pitch
Apparently, the data-dependent if-statement needs to be replaced with an alternative that vmap can handle. Any help is appreciated.
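One way to avoid a data-dependent if, sketched here as an assumption rather than a description of POT's internals, is to drop the early-exit convergence check (a typical example being `if err < tol: break`) and run a fixed number of Sinkhorn iterations. The code below is a minimal pure-PyTorch sketch: sinkhorn_plan and sinkhorn_cost are hypothetical names, not POT functions, and it assumes strictly positive weights a and b, n_iter >= 1, and a regularization reg large enough that exp(-M / reg) does not underflow.

```python
import torch
from torch.func import vmap  # PyTorch 2.x location of functorch.vmap


def sinkhorn_plan(a, b, M, reg, n_iter=200):
    """Entropic OT plan from a fixed number of Sinkhorn iterations.

    Hypothetical helper, not part of POT. There is no early-exit convergence
    check, so no Python branch ever depends on tensor values.
    """
    K = torch.exp(-M / reg)      # Gibbs kernel, shape (n, m)
    u = torch.ones_like(a)       # row scaling, shape (n,)
    for _ in range(n_iter):      # loop length is a Python int, not a tensor value
        v = b / (u @ K)          # rescale columns toward the target marginal b
        u = a / (K @ v)          # rescale rows so that P.sum(1) equals a
    return u[:, None] * K * v[None, :]


def sinkhorn_cost(a, b, M, reg, n_iter=200):
    # Transport cost <P, M> of the entropic plan (hypothetical stand-in for ot.sinkhorn2).
    P = sinkhorn_plan(a, b, M, reg, n_iter)
    return (P * M).sum()


torch.manual_seed(0)
batch, n, m, reg = 4, 6, 5, 0.5
x = torch.rand(batch, n, 2, dtype=torch.float64)          # source points per sample
y = torch.rand(batch, m, 2, dtype=torch.float64)          # target points per sample
Ms = torch.cdist(x, y) ** 2                               # (batch, n, m) squared distances
A = torch.full((batch, n), 1.0 / n, dtype=torch.float64)  # uniform source weights
B = torch.full((batch, m), 1.0 / m, dtype=torch.float64)  # uniform target weights

# Batch over dim 0 of a, b, M; reg is a plain float shared by every sample.
costs = vmap(sinkhorn_cost, in_dims=(0, 0, 0, None))(A, B, Ms, reg)
loop_costs = torch.stack([sinkhorn_cost(A[i], B[i], Ms[i], reg) for i in range(batch)])
print(costs.shape, torch.allclose(costs, loop_costs))
```

Fixing n_iter trades the early exit for predictable control flow: every sample runs the same number of iterations, which is what vmap needs. Gradients flow through the unrolled iterations, so the result can serve as a training loss. For small reg, exp(-M / reg) can underflow; a log-domain variant built on torch.logsumexp is the usual remedy and fits the same fixed-iteration scheme.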