diff --git a/tapqir/models/cosmos.py b/tapqir/models/cosmos.py index 6c8e2bf..d2941f9 100644 --- a/tapqir/models/cosmos.py +++ b/tapqir/models/cosmos.py @@ -395,7 +395,7 @@ def guide(self): with channels as cdx, aois as ndx: ndx = ndx[:, None, None] - mask = Vindex(self.data.mask)[ndx].to(self.device) + mask = Vindex(self.data.mask.to(self.device))[ndx] with handlers.mask(mask=mask): pyro.sample( "background_mean", diff --git a/tapqir/models/crosstalk.py b/tapqir/models/crosstalk.py index 5027fa7..ec623b0 100644 --- a/tapqir/models/crosstalk.py +++ b/tapqir/models/crosstalk.py @@ -126,7 +126,7 @@ def model(self): with aois as ndx: ndx = ndx[:, None] - mask = Vindex(self.data.mask)[ndx].to(self.device) + mask = Vindex(self.data.mask.to(self.device))[ndx] with handlers.mask(mask=mask): # background mean and std background_mean = pyro.sample( diff --git a/tapqir/models/hmm.py b/tapqir/models/hmm.py index 4beeeaf..5b51d06 100644 --- a/tapqir/models/hmm.py +++ b/tapqir/models/hmm.py @@ -138,7 +138,7 @@ def model(self): with channels as cdx, aois as ndx: ndx = ndx[:, None, None] - mask = Vindex(self.data.mask)[ndx].to(self.device) + mask = Vindex(self.data.mask.to(self.device))[ndx] with handlers.mask(mask=mask): # background mean and std background_mean = pyro.sample(