Skip to content

Commit

Permalink
fix fetch device (#399)
Browse files Browse the repository at this point in the history
  • Loading branch information
ordabayevy authored Nov 20, 2022
1 parent 034bf5e commit 900973e
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 14 deletions.
2 changes: 1 addition & 1 deletion tapqir/models/cosmos.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def model(self):

with channels as cdx, aois as ndx:
ndx = ndx[:, None, None]
mask = Vindex(self.data.mask)[ndx].to(self.device)
mask = Vindex(self.data.mask.to(self.device))[ndx]
with handlers.mask(mask=mask):
# background mean and std
background_mean = pyro.sample(
Expand Down
2 changes: 1 addition & 1 deletion tapqir/models/crosstalk.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def guide(self):

with aois as ndx:
ndx = ndx[:, None]
mask = Vindex(self.data.mask)[ndx].to(self.device)
mask = Vindex(self.data.mask.to(self.device))[ndx]
with handlers.mask(mask=mask):
pyro.sample(
"background_mean",
Expand Down
2 changes: 1 addition & 1 deletion tapqir/models/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def guide(self):

with channels as cdx, aois as ndx:
ndx = ndx[:, None, None]
mask = Vindex(self.data.mask)[ndx].to(self.device)
mask = Vindex(self.data.mask.to(self.device))[ndx]
with handlers.mask(mask=mask):
pyro.sample(
"background_mean",
Expand Down
8 changes: 0 additions & 8 deletions tapqir/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,6 @@ def to(self, device: str, dtype: str = "double") -> None:
torch.set_default_tensor_type(torch.DoubleTensor)
else:
torch.set_default_tensor_type(torch.FloatTensor)
# change loaded data device
if hasattr(self, "data"):
self.data.ontarget = self.data.ontarget._replace(device=self.device)
self.data.offtarget = self.data.offtarget._replace(device=self.device)
self.data.offset = self.data.offset._replace(
samples=self.data.offset.samples.to(self.device),
weights=self.data.offset.weights.to(self.device),
)

@property
def Q(self):
Expand Down
6 changes: 3 additions & 3 deletions tapqir/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,9 @@ def median(self) -> torch.Tensor:

def fetch(self, ndx, fdx, cdx):
return (
Vindex(self.images)[ndx, fdx, cdx].to(self.device),
Vindex(self.xy)[ndx, fdx, cdx].to(self.device),
Vindex(self.is_ontarget)[ndx].to(self.device),
Vindex(self.images.to(self.device))[ndx, fdx, cdx],
Vindex(self.xy.to(self.device))[ndx, fdx, cdx],
Vindex(self.is_ontarget.to(self.device))[ndx],
)

@lazy_property
Expand Down

0 comments on commit 900973e

Please sign in to comment.