Skip to content

Commit

Permalink
optimize data indexing (#421)
Browse files Browse the repository at this point in the history
  • Loading branch information
ordabayevy authored Feb 27, 2023
1 parent 9c27683 commit 318e8fc
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
8 changes: 2 additions & 6 deletions tapqir/models/cosmos.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,6 @@ def model(self):
dim=-1,
)

# spots
spots = pyro.plate("spots", self.K)
# aoi sites
aois = pyro.plate(
"aois",
Expand Down Expand Up @@ -258,7 +256,7 @@ def model(self):
onehot_theta = one_hot(theta, num_classes=1 + self.K)

ms, heights, widths, xs, ys = [], [], [], [], []
for kdx in spots:
for kdx in range(self.K):
specific = onehot_theta[..., 1 + kdx]
# spot presence
m = pyro.sample(
Expand Down Expand Up @@ -369,8 +367,6 @@ def guide(self):
),
)

# spots
spots = pyro.plate("spots", self.K)
# aoi sites
aois = pyro.plate(
"aois",
Expand Down Expand Up @@ -418,7 +414,7 @@ def guide(self):
),
)

for kdx in spots:
for kdx in range(self.K):
# sample spot presence m
m = pyro.sample(
f"m_k{kdx}",
Expand Down
12 changes: 9 additions & 3 deletions tapqir/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,16 @@ def median(self) -> torch.Tensor:
)

def fetch(self, ndx, fdx, cdx):
if isinstance(ndx, torch.Tensor):
ndx = ndx.cpu()
if isinstance(fdx, torch.Tensor):
fdx = fdx.cpu()
if isinstance(cdx, torch.Tensor):
cdx = cdx.cpu()
return (
Vindex(self.images.to(self.device))[ndx, fdx, cdx],
Vindex(self.xy.to(self.device))[ndx, fdx, cdx],
Vindex(self.is_ontarget.to(self.device))[ndx],
Vindex(self.images)[ndx, fdx, cdx].to(self.device),
Vindex(self.xy)[ndx, fdx, cdx].to(self.device),
Vindex(self.is_ontarget)[ndx].to(self.device),
)

@lazy_property
Expand Down

0 comments on commit 318e8fc

Please sign in to comment.