diff --git a/tapqir/models/cosmos.py b/tapqir/models/cosmos.py index 1e5717c..8917fea 100644 --- a/tapqir/models/cosmos.py +++ b/tapqir/models/cosmos.py @@ -190,8 +190,6 @@ def model(self): dim=-1, ) - # spots - spots = pyro.plate("spots", self.K) # aoi sites aois = pyro.plate( "aois", @@ -258,7 +256,7 @@ def model(self): onehot_theta = one_hot(theta, num_classes=1 + self.K) ms, heights, widths, xs, ys = [], [], [], [], [] - for kdx in spots: + for kdx in range(self.K): specific = onehot_theta[..., 1 + kdx] # spot presence m = pyro.sample( @@ -369,8 +367,6 @@ def guide(self): ), ) - # spots - spots = pyro.plate("spots", self.K) # aoi sites aois = pyro.plate( "aois", @@ -418,7 +414,7 @@ def guide(self): ), ) - for kdx in spots: + for kdx in range(self.K): # sample spot presence m m = pyro.sample( f"m_k{kdx}", diff --git a/tapqir/utils/dataset.py b/tapqir/utils/dataset.py index 70dbb78..931fbfd 100644 --- a/tapqir/utils/dataset.py +++ b/tapqir/utils/dataset.py @@ -138,10 +138,16 @@ def median(self) -> torch.Tensor: ) def fetch(self, ndx, fdx, cdx): + if isinstance(ndx, torch.Tensor): + ndx = ndx.cpu() + if isinstance(fdx, torch.Tensor): + fdx = fdx.cpu() + if isinstance(cdx, torch.Tensor): + cdx = cdx.cpu() return ( - Vindex(self.images.to(self.device))[ndx, fdx, cdx], - Vindex(self.xy.to(self.device))[ndx, fdx, cdx], - Vindex(self.is_ontarget.to(self.device))[ndx], + Vindex(self.images)[ndx, fdx, cdx].to(self.device), + Vindex(self.xy)[ndx, fdx, cdx].to(self.device), + Vindex(self.is_ontarget)[ndx].to(self.device), ) @lazy_property