Skip to content

Commit

Permalink
fix mask error (#398)
Browse files Browse the repository at this point in the history
  • Loading branch information
ordabayevy authored Nov 19, 2022
1 parent 5809962 commit 034bf5e
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion tapqir/models/cosmos.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def guide(self):

with channels as cdx, aois as ndx:
ndx = ndx[:, None, None]
mask = Vindex(self.data.mask)[ndx].to(self.device)
mask = Vindex(self.data.mask.to(self.device))[ndx]
with handlers.mask(mask=mask):
pyro.sample(
"background_mean",
Expand Down
2 changes: 1 addition & 1 deletion tapqir/models/crosstalk.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def model(self):

with aois as ndx:
ndx = ndx[:, None]
mask = Vindex(self.data.mask)[ndx].to(self.device)
mask = Vindex(self.data.mask.to(self.device))[ndx]
with handlers.mask(mask=mask):
# background mean and std
background_mean = pyro.sample(
Expand Down
2 changes: 1 addition & 1 deletion tapqir/models/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def model(self):

with channels as cdx, aois as ndx:
ndx = ndx[:, None, None]
mask = Vindex(self.data.mask)[ndx].to(self.device)
mask = Vindex(self.data.mask.to(self.device))[ndx]
with handlers.mask(mask=mask):
# background mean and std
background_mean = pyro.sample(
Expand Down

0 comments on commit 034bf5e

Please sign in to comment.