Skip to content

Commit

Permalink
ttfb command (#351)
Browse files Browse the repository at this point in the history
  • Loading branch information
ordabayevy authored Sep 8, 2022
1 parent 262f391 commit b5b1f04
Showing 1 changed file with 157 additions and 0 deletions.
157 changes: 157 additions & 0 deletions tapqir/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,6 +913,163 @@ def subset():
logger.info("Created a new data file at `subset/data.tpqr`")


@app.command()
def ttfb(
model: avail_models = typer.Option(
"cosmos", help="Tapqir model", prompt="Tapqir model"
),
):
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import pyro
import torch
from pyro import distributions as dist
from pyro.ops.stats import hpdi

from tapqir.models import models
from tapqir.utils.imscroll import time_to_first_binding
from tapqir.utils.mle_analysis import train, ttfb_guide, ttfb_model

logger = logging.getLogger("tapqir")

mpl.rc("text", usetex=True)
mpl.rcParams["font.family"] = "sans-serif"
mpl.rcParams.update({"font.size": 8})

global DEFAULTS
cd = DEFAULTS["cd"]

model = models[model](device="cpu", dtype="float")
try:
model.load(cd, data_only=False)
except TapqirFileNotFoundError as err:
logger.exception(f"Failed to load {err.name} file")
return 1

for c in range(model.data.C):
# sorted on-target
ttfb = time_to_first_binding(model.params["z_map"][: model.data.N, :, c])
# sort ttfb
sdx = torch.argsort(ttfb, descending=True)

fig, ax = plt.subplots()
norm = mpl.colors.Normalize(vmin=0, vmax=1)
ax.imshow(
model.params["z_probs"][: model.data.N, :, c][sdx],
norm=norm,
aspect="equal",
interpolation="none",
)
ax.set_xlabel("Time (frame)")
ax.set_ylabel("AOI")
ax.set_title(f"Channel {c}")
plt.savefig(f"ttfb_rastergram{c}.png", dpi=600)

fig, ax = plt.subplots()
# prepare data
Tmax = model.data.F
torch.manual_seed(0)
z = dist.Bernoulli(model.params["z_probs"][: model.data.N, :, c]).sample(
(2000,)
)
data = time_to_first_binding(z)

# use cuda
torch.set_default_tensor_type(torch.cuda.FloatTensor)

# Tapqir fit
train(
ttfb_model,
ttfb_guide,
lr=5e-3,
n_steps=15000,
data=data.cuda(),
control=None,
Tmax=Tmax,
jit=False,
)

results = pd.DataFrame(columns=["Mean", "95% LL", "95% UL"])

results.loc["ka", "Mean"] = pyro.param("ka").mean().item()
ll, ul = hpdi(pyro.param("ka").data.squeeze(), 0.95, dim=0)
results.loc["ka", "95% LL"], results.loc["ka", "95% UL"] = ll.item(), ul.item()

results.loc["kns", "Mean"] = pyro.param("kns").mean().item()
ll, ul = hpdi(pyro.param("kns").data.squeeze(), 0.95, dim=0)
results.loc["kns", "95% LL"], results.loc["kns", "95% UL"] = (
ll.item(),
ul.item(),
)

results.loc["Af", "Mean"] = pyro.param("Af").mean().item()
ll, ul = hpdi(pyro.param("Af").data.squeeze(), 0.95, dim=0)
results.loc["Af", "95% LL"], results.loc["Af", "95% UL"] = ll.item(), ul.item()
results.to_csv(f"ttfb{c}.csv")

# use cuda
torch.set_default_tensor_type(torch.FloatTensor)

nz = (data == 0).sum(1, keepdim=True)
N = data.shape[1]

fraction_bound = (data.unsqueeze(-1) < torch.arange(Tmax)).float().mean(1)
fb_ll, fb_ul = hpdi(fraction_bound, 0.95, dim=0)

ax.fill_between(torch.arange(Tmax), fb_ll, fb_ul, alpha=0.3, color="C2")
ax.plot(torch.arange(Tmax), fraction_bound.mean(0), color="C2")

ax.plot(
torch.arange(Tmax),
(
nz / N
+ (1 - nz / N)
* (
results.loc["Af", "Mean"]
* (
1
- torch.exp(
-(results.loc["ka", "Mean"] + results.loc["kns", "Mean"])
* torch.arange(Tmax)
)
)
+ (1 - results.loc["Af", "Mean"])
* (1 - torch.exp(-results.loc["kns", "Mean"] * torch.arange(Tmax)))
)
).mean(0),
color="k",
)

plt.minorticks_on()
ax.tick_params(
direction="in",
which="minor",
length=1,
bottom=True,
top=True,
left=True,
right=True,
)
ax.tick_params(
direction="in",
which="major",
length=2,
bottom=True,
top=True,
left=True,
right=True,
)
ax.set_yticks([0, 0.2, 0.4, 0.6, 0.8, 1])
ax.set_yticklabels([r"$0$", r"$0.2$", r"$0.4$", r"$0.6$", r"$0.8$", r"$1$"])
ax.set_xlabel("Time (frame)")
ax.set_ylabel("Fraction bound")
ax.set_title(f"Channel {c}")
ax.set_ylim(-0.05, 1.05)

plt.savefig(f"ttfb_fit{c}.png", dpi=600)


@app.callback()
def main(
cd: Path = typer.Option(
Expand Down

0 comments on commit b5b1f04

Please sign in to comment.