Skip to content

Commit

Permalink
Make dwelltime analysis faster (#413)
Browse files Browse the repository at this point in the history
* fixes

* lint
  • Loading branch information
ordabayevy authored Jan 26, 2023
1 parent 05bb6e8 commit 623e272
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 8 deletions.
2 changes: 1 addition & 1 deletion tapqir/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -1311,7 +1311,7 @@ def postUI(out):
dt_layout.add_child(
"num_samples",
widgets.IntText(
value=2000,
value=500,
description="Number of posterior samples",
style={"description_width": "initial"},
),
Expand Down
14 changes: 9 additions & 5 deletions tapqir/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -1163,7 +1163,7 @@ def dwelltime(
show_default=False,
),
num_samples: int = typer.Option(
2000,
500,
"--num-samples",
"-n",
help="Number of posterior samples",
Expand Down Expand Up @@ -1215,9 +1215,9 @@ def dwelltime(
for c in range(model.data.C):
logger.info(f"Channel #{c} ({model.data.channels[c]})")
intervals = count_intervals(z_samples_masked[..., c])
intervals.to_csv(cd / f"{model.name}_dwelltime-intervals-channel{c}.csv")
intervals.to_pickle(cd / f"{model.name}_dwelltime-intervals-channel{c}.pkl")
logger.info(
f"Saved time intervals in {model.name}_dwelltime-intervals-channel{c}.csv file"
f"Saved time intervals in {model.name}_dwelltime-intervals-channel{c}.pkl file"
)

logger.info("Off-rate calculation ...")
Expand Down Expand Up @@ -1273,7 +1273,9 @@ def dwelltime(
ax.hist(
bound_dwell_times(
count_intervals(
model.params["z_map"][None, model.data.mask[: model.data.N], :, c]
model.params["z_map"][: model.data.N][
None, model.data.mask[: model.data.N], :, c
]
)
)[0],
bins=100,
Expand Down Expand Up @@ -1349,7 +1351,9 @@ def dwelltime(
ax.hist(
unbound_dwell_times(
count_intervals(
model.params["z_map"][None, model.data.mask[: model.data.N], :, c]
model.params["z_map"][: model.data.N][
None, model.data.mask[: model.data.N], :, c
]
)
)[0],
bins=100,
Expand Down
4 changes: 2 additions & 2 deletions tapqir/utils/imscroll.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _(labels):
stop_type[..., -1] += 2
stop_type = stop_type[stop_sample, stop_aoi, stop_frame]

assert all(start_aoi == stop_aoi)
assert np.array_equal(start_aoi, stop_aoi)

low_or_high = np.where(abs(start_type) > abs(stop_type), start_type, stop_type)
z_type = z[start_sample, start_aoi, start_frame]
Expand Down Expand Up @@ -91,7 +91,7 @@ def _(labels):
stop_type[..., -1] += 2
stop_type = stop_type[stop_sample, stop_aoi, stop_frame]

assert all(start_aoi == stop_aoi)
assert torch.equal(start_aoi, stop_aoi)

low_or_high = torch.where(abs(start_type) > abs(stop_type), start_type, stop_type)
z_type = z[start_sample, start_aoi, start_frame]
Expand Down

0 comments on commit 623e272

Please sign in to comment.