Skip to content

Commit

Permalink
Refactor: Rename emission to logit
Browse files Browse the repository at this point in the history
  • Loading branch information
speedcell4 committed Sep 10, 2024
1 parent 3348f52 commit 92e62bb
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
6 changes: 6 additions & 0 deletions torchlatent/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,7 @@
from torchlatent.cky import CkyDecoder, CkyDistribution
from torchlatent.crf import CrfDecoder, CrfDistribution

__all__ = [
'CkyDecoder', 'CkyDistribution',
'CrfDecoder', 'CrfDistribution',
]
14 changes: 7 additions & 7 deletions torchlatent/crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def crf_scores(logits: Z, targets: Z, bias: T, semiring: Type[Semiring]) -> Tens
bias = bias[targets.data.roll(1), targets.data]

batch_ptr, token_ptr = targets.ptr()
logits = logits.left().data[batch_ptr, token_ptr, targets.data]
logits ,_= logits[batch_ptr, token_ptr, targets.data]
logits = semiring.segment_prod(logits, sizes=token_sizes)

token_sizes = torch.stack([torch.ones_like(token_sizes), token_sizes - 1], dim=-1)
Expand All @@ -41,17 +41,17 @@ def crf_partitions(logits: Z, bias: T, semiring: Type[Semiring]) -> Tensor:
logits, batch_sizes, _, _ = logits

_, *batch_sizes = sections = batch_sizes.detach().cpu().tolist()
emission, *logits = torch.split(logits, sections, dim=0)
logit, *logits = torch.split(logits, sections, dim=0)

charts = [semiring.mul(head_bias, emission)]
for emission, batch_size in zip(logits, batch_sizes):
charts = [semiring.mul(head_bias, logit)]
for logit, batch_size in zip(logits, batch_sizes):
charts.append(semiring.mul(
semiring.bmm(charts[-1][:batch_size], bias),
emission,
logit,
))

emission = torch.cat(charts, dim=0)[last_indices]
return semiring.sum(semiring.mul(emission, last_bias), dim=-1)
logit = torch.cat(charts, dim=0)[last_indices]
return semiring.sum(semiring.mul(logit, last_bias), dim=-1)


class CrfDistribution(StructuredDistribution):
Expand Down

0 comments on commit 92e62bb

Please sign in to comment.