Feat: Support ExceptionSemiring
speedcell4 committed Mar 14, 2024
1 parent bcc1d13 commit 44e5a6a
Showing 5 changed files with 235 additions and 135 deletions.
115 changes: 67 additions & 48 deletions tests/test_cky.py

@@ -1,104 +1,123 @@
 import torch
-from hypothesis import given, settings, strategies as st
+from hypothesis import given, settings
 from torch_struct import TreeCRF
 from torchnyan import BATCH_SIZE, TINY_TOKEN_SIZE, assert_close, assert_grad_close, device, sizes
 from torchrua import C
 
-from torchlatent.cky import CkyDecoder, cky_partitions, cky_scores
-from torchlatent.semiring import Log
+from torchlatent.cky import CkyDecoder, masked_select
 
 
-def get_argmax(cky):
-    argmax = cky.argmax
-    mask = argmax > 0
-
-    _, t, _, n = mask.size()
-    index = torch.arange(t, device=mask.device)
-    x = torch.masked_select(index[None, :, None, None], mask=mask)
-    y = torch.masked_select(index[None, None, :, None], mask=mask)
-
-    index = torch.arange(n, device=mask.device)
-    z = torch.masked_select(index[None, None, None, :], mask=mask)
-
-    return argmax, x, y, z
-
-
 @settings(deadline=None)
 @given(
     token_sizes=sizes(BATCH_SIZE, TINY_TOKEN_SIZE),
     num_targets=sizes(TINY_TOKEN_SIZE),
-    rua_targets=st.sampled_from([C.cat, C.pad, C.pack]),
 )
-def test_cky_scores(token_sizes, num_targets, rua_targets):
-    emissions = torch.randn(
+def test_cky_scores(token_sizes, num_targets):
+    logits = torch.randn(
         (len(token_sizes), max(token_sizes), max(token_sizes), num_targets),
         device=device, requires_grad=True,
     )
     token_sizes = torch.tensor(token_sizes, device=device)
-    expected_cky = TreeCRF(emissions, lengths=token_sizes)
 
-    argmax, x, y, z = get_argmax(expected_cky)
+    expected = TreeCRF(logits, lengths=token_sizes)
+    actual = CkyDecoder(num_targets=num_targets)(logits=C(logits, token_sizes))
 
-    emissions = torch.randn_like(emissions, requires_grad=True)
+    expected = expected.log_prob(expected.argmax)
+    actual = actual.log_probs(actual.argmax)
 
-    expected_cky = TreeCRF(emissions, lengths=token_sizes)
-    expected = expected_cky.log_prob(argmax) + expected_cky.partition
-
-    targets = C(data=torch.stack([x, y, z], dim=-1), token_sizes=token_sizes * 2 - 1)
-    actual = cky_scores(
-        emissions=C(emissions, token_sizes),
-        targets=rua_targets(targets),
-        semiring=Log,
-    )
-
     assert_close(actual=actual, expected=expected)
-    assert_grad_close(actual=actual, expected=expected, inputs=(emissions,))
+    assert_grad_close(actual=actual, expected=expected, inputs=(logits,))
 
 
 @settings(deadline=None)
 @given(
     token_sizes=sizes(BATCH_SIZE, TINY_TOKEN_SIZE),
     num_targets=sizes(TINY_TOKEN_SIZE),
 )
 def test_cky_partitions(token_sizes, num_targets):
-    emissions = torch.randn(
+    logits = torch.randn(
         (len(token_sizes), max(token_sizes), max(token_sizes), num_targets),
         device=device, requires_grad=True,
     )
     token_sizes = torch.tensor(token_sizes, device=device)
 
-    expected = TreeCRF(emissions, lengths=token_sizes).partition
+    expected = TreeCRF(logits, lengths=token_sizes)
+    actual = CkyDecoder(num_targets=num_targets)(logits=C(logits, token_sizes))
 
-    actual_emissions = C(
-        data=emissions.logsumexp(dim=-1),
-        token_sizes=token_sizes,
-    )
-    actual = cky_partitions(actual_emissions, Log)
+    expected = expected.partition
+    actual = actual.log_partitions
 
     assert_close(actual=actual, expected=expected)
-    assert_grad_close(actual=actual, expected=expected, inputs=(emissions,))
+    assert_grad_close(actual=actual, expected=expected, inputs=(logits,))
 
 
 @settings(deadline=None)
 @given(
     token_sizes=sizes(BATCH_SIZE, TINY_TOKEN_SIZE),
     num_targets=sizes(TINY_TOKEN_SIZE),
 )
 def test_cky_argmax(token_sizes, num_targets):
-    emissions = torch.randn(
+    logits = torch.randn(
         (len(token_sizes), max(token_sizes), max(token_sizes), num_targets),
         device=device, requires_grad=True,
     )
     token_sizes = torch.tensor(token_sizes, device=device)
 
-    expected_cky = TreeCRF(emissions, lengths=token_sizes)
+    expected = TreeCRF(logits, lengths=token_sizes)
+    actual = CkyDecoder(num_targets=num_targets)(logits=C(logits, token_sizes))
 
-    _, x, y, z = get_argmax(expected_cky)
-
-    expected = C(data=torch.stack([x, y, z], dim=-1), token_sizes=token_sizes * 2 - 1)
-
-    actual_cky = CkyDecoder(num_targets=num_targets)
-    actual = actual_cky(emissions=C(emissions, token_sizes)).argmax
+    expected = C(data=torch.stack(masked_select(expected.argmax.bool()), dim=-1), token_sizes=token_sizes * 2 - 1)
+    actual = actual.argmax
 
     for actual, expected in zip(actual.tolist(), expected.tolist()):
         assert set(map(tuple, actual)) == set(map(tuple, expected))
+
+
+@settings(deadline=None)
+@given(
+    token_sizes=sizes(BATCH_SIZE, TINY_TOKEN_SIZE),
+    num_targets=sizes(TINY_TOKEN_SIZE),
+)
+def test_cky_entropy(token_sizes, num_targets):
+    logits = torch.randn(
+        (len(token_sizes), max(token_sizes), max(token_sizes), num_targets),
+        device=device, requires_grad=True,
+    )
+    token_sizes = torch.tensor(token_sizes, device=device)
+
+    expected = TreeCRF(logits, lengths=token_sizes)
+    actual = CkyDecoder(num_targets=num_targets)(logits=C(logits, token_sizes))
+
+    expected = expected.entropy
+    actual = actual.entropy
+
+    assert_close(actual=actual, expected=expected)
+    assert_grad_close(actual=actual, expected=expected, inputs=(logits,), rtol=1e-4, atol=1e-4)
+
+
+@settings(deadline=None)
+@given(
+    token_sizes=sizes(BATCH_SIZE, TINY_TOKEN_SIZE),
+    num_targets=sizes(TINY_TOKEN_SIZE),
+)
+def test_cky_kl(token_sizes, num_targets):
+    logits1 = torch.randn(
+        (len(token_sizes), max(token_sizes), max(token_sizes), num_targets),
+        device=device, requires_grad=True,
+    )
+    logits2 = torch.randn(
+        (len(token_sizes), max(token_sizes), max(token_sizes), num_targets),
+        device=device, requires_grad=True,
+    )
+    token_sizes = torch.tensor(token_sizes, device=device)
+
+    expected1 = TreeCRF(logits1, lengths=token_sizes)
+    expected2 = TreeCRF(logits2, lengths=token_sizes)
+    actual1 = CkyDecoder(num_targets=num_targets)(logits=C(logits1, token_sizes))
+    actual2 = CkyDecoder(num_targets=num_targets)(logits=C(logits2, token_sizes))
+
+    expected = expected1.kl(expected2)
+    actual = actual1.kl(actual2)
+
+    assert_close(actual=actual, expected=expected)
+    assert_grad_close(actual=actual, expected=expected, inputs=(logits1, logits2), rtol=1e-4, atol=1e-4)
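
The tests above exercise CkyDecoder's distribution-style surface: log_probs, log_partitions, argmax, entropy, and kl, each checked against torch_struct's TreeCRF. A minimal usage sketch of that surface follows; the batch size, lengths, and CPU tensors are illustrative assumptions, not part of the commit:

import torch
from torchrua import C

from torchlatent.cky import CkyDecoder

num_targets = 3
token_sizes = torch.tensor([5, 3])
logits = torch.randn((2, 5, 5, num_targets), requires_grad=True)  # (batch, length, length, label)

dist = CkyDecoder(num_targets=num_targets)(logits=C(logits, token_sizes))

log_z = dist.log_partitions       # log-partition per sequence
spans = dist.argmax               # best (left, right, label) triples, 2 * n - 1 per sequence
log_p = dist.log_probs(spans)     # log-probability of those spans
entropy = dist.entropy            # tree entropy, as verified against torch_struct above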
46 changes: 23 additions & 23 deletions tests/test_crf.py

@@ -12,10 +12,10 @@
 @given(
     token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE),
     num_targets=sizes(TOKEN_SIZE),
-    rua_emissions=st.sampled_from([C.new, D.new, P.new]),
+    rua_logits=st.sampled_from([C.new, D.new, P.new]),
     rua_targets=st.sampled_from([C.new, D.new, P.new]),
 )
-def test_crf_scores(token_sizes, num_targets, rua_emissions, rua_targets):
+def test_crf_scores(token_sizes, num_targets, rua_logits, rua_targets):
     inputs = [
         torch.randn((token_size, num_targets), device=device, requires_grad=True)
         for token_size in token_sizes
@@ -28,19 +28,19 @@ def test_crf_scores(token_sizes, num_targets, rua_emissions, rua_targets):
 
     expected_crf = CRF(num_tags=num_targets, batch_first=False).to(device=device)
 
-    expected_emissions = D.new(inputs)
+    expected_logits = D.new(inputs)
     expected_tags = D.new(targets)
 
     expected = expected_crf._compute_score(
-        expected_emissions.data.transpose(0, 1),
+        expected_logits.data.transpose(0, 1),
         expected_tags.data.transpose(0, 1),
-        expected_emissions.mask().transpose(0, 1),
+        expected_logits.mask().transpose(0, 1),
     )
 
     actual = crf_scores(
-        emissions=rua_emissions(inputs),
+        logits=rua_logits(inputs),
         targets=rua_targets(targets),
-        transitions=(expected_crf.transitions, expected_crf.start_transitions, expected_crf.end_transitions),
+        bias=(expected_crf.transitions, expected_crf.start_transitions, expected_crf.end_transitions),
         semiring=Log,
     )
 
@@ -52,26 +52,26 @@ def test_crf_scores(token_sizes, num_targets, rua_emissions, rua_targets):
 @given(
     token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE),
     num_targets=sizes(TOKEN_SIZE),
-    rua_emissions=st.sampled_from([C.new, D.new, P.new]),
+    rua_logits=st.sampled_from([C.new, D.new, P.new]),
 )
-def test_crf_partitions(token_sizes, num_targets, rua_emissions):
+def test_crf_partitions(token_sizes, num_targets, rua_logits):
     inputs = [
         torch.randn((token_size, num_targets), device=device, requires_grad=True)
         for token_size in token_sizes
     ]
 
     expected_crf = CRF(num_tags=num_targets, batch_first=False).to(device=device)
 
-    expected_emissions = D.new(inputs)
+    expected_logits = D.new(inputs)
 
     expected = expected_crf._compute_normalizer(
-        expected_emissions.data.transpose(0, 1),
-        expected_emissions.mask().t(),
+        expected_logits.data.transpose(0, 1),
+        expected_logits.mask().t(),
     )
 
     actual = crf_partitions(
-        emissions=rua_emissions(inputs),
-        transitions=(expected_crf.transitions, expected_crf.start_transitions, expected_crf.end_transitions),
+        logits=rua_logits(inputs),
+        bias=(expected_crf.transitions, expected_crf.start_transitions, expected_crf.end_transitions),
        semiring=Log,
     )
 
@@ -83,29 +83,29 @@ def test_crf_partitions(token_sizes, num_targets, rua_emissions):
 @given(
     token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE),
     num_targets=sizes(TOKEN_SIZE),
-    rua_emissions=st.sampled_from([C.new, D.new, P.new]),
+    rua_logits=st.sampled_from([C.new, D.new, P.new]),
 )
-def test_crf_argmax(token_sizes, num_targets, rua_emissions):
+def test_crf_argmax(token_sizes, num_targets, rua_logits):
     inputs = [
         torch.randn((token_size, num_targets), device=device, requires_grad=True)
         for token_size in token_sizes
     ]
 
     expected_crf = CRF(num_tags=num_targets, batch_first=False).to(device=device)
 
-    expected_emissions = D.new(inputs)
+    expected_logits = D.new(inputs)
 
     expected = expected_crf.decode(
-        expected_emissions.data.transpose(0, 1),
-        expected_emissions.mask().t(),
+        expected_logits.data.transpose(0, 1),
+        expected_logits.mask().t(),
     )
     expected = C.new([torch.tensor(tensor, device=device) for tensor in expected])
 
     actual_crf = CrfDecoder(num_targets=num_targets)
-    actual_crf.transitions = expected_crf.transitions
-    actual_crf.head_transitions = expected_crf.start_transitions
-    actual_crf.last_transitions = expected_crf.end_transitions
+    actual_crf.bias = expected_crf.transitions
+    actual_crf.head_bias = expected_crf.start_transitions
+    actual_crf.last_bias = expected_crf.end_transitions
 
-    actual = actual_crf(rua_emissions(inputs)).argmax.cat()
+    actual = actual_crf(rua_logits(inputs)).argmax.cat()
 
     assert_sequence_close(actual=actual, expected=expected)
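
The change to this file is a mechanical rename: emissions becomes logits, and the transition parameters (transitions, head_transitions, last_transitions) are regrouped as bias, head_bias, last_bias. A short sketch of the renamed surface as the updated tests use it; the module path torchlatent.crf and the toy sizes are assumptions for illustration:

import torch
from torchrua import C

from torchlatent.crf import CrfDecoder

num_targets = 4
inputs = [torch.randn((n, num_targets), requires_grad=True) for n in (5, 3)]  # variable-length logits

decoder = CrfDecoder(num_targets=num_targets)
dist = decoder(C.new(inputs))     # C, D, or P packaging, per the rua_logits strategy above

tags = dist.argmax.cat()          # best tag sequence per token, concatenated layout
log_z = dist.log_partitions       # per-sequence log-partition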
10 changes: 5 additions & 5 deletions torchlatent/abc.py

@@ -9,9 +9,9 @@
 
 
 class StructuredDistribution(object, metaclass=ABCMeta):
-    def __init__(self, emissions: Union[C, D, P]) -> None:
+    def __init__(self, logits: Union[C, D, P]) -> None:
         super(StructuredDistribution, self).__init__()
-        self.emissions = emissions
+        self.logits = logits
 
     def log_scores(self, targets: Union[C, D, P]) -> Tensor:
         raise NotImplementedError
@@ -26,7 +26,7 @@ def log_partitions(self) -> Tensor:
     @lazy_property
     def marginals(self) -> Tensor:
         grad, = torch.autograd.grad(
-            self.log_partitions, self.emissions.data, torch.ones_like(self.log_partitions),
+            self.log_partitions, self.logits.data, torch.ones_like(self.log_partitions),
             create_graph=True, retain_graph=True, only_inputs=True, allow_unused=True,
 
         )
@@ -39,7 +39,7 @@ def max(self) -> Tensor:
     @lazy_property
     def argmax(self) -> Tensor:
         grad, = torch.autograd.grad(
-            self.max, self.emissions.data, torch.ones_like(self.max),
+            self.max, self.logits.data, torch.ones_like(self.max),
             create_graph=False, retain_graph=False, only_inputs=True, allow_unused=True,
         )
         return grad
@@ -56,5 +56,5 @@ def reset_parameters(self) -> None:
     def extra_repr(self) -> str:
         return f'num_targets={self.num_targets}'
 
-    def forward(self, emissions: Union[C, D, P]) -> StructuredDistribution:
+    def forward(self, logits: Union[C, D, P]) -> StructuredDistribution:
         raise NotImplementedError
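
The two autograd calls in StructuredDistribution rely on a standard identity: for p(y) proportional to exp(score(y; logits)), the gradient of the log-partition with respect to the logits equals the marginal probabilities, and the (sub)gradient of the max score is a one-hot indicator of the argmax structure. A self-contained sketch of the first identity on a trivial model; the five-way categorical is an illustrative assumption, not torchlatent code:

import torch

logits = torch.randn(5, requires_grad=True)  # scores of five mutually exclusive outcomes
log_z = torch.logsumexp(logits, dim=-1)      # log-partition of this trivial "structure"

# d logZ / d logits recovers the marginal (here, categorical) probabilities.
marginals, = torch.autograd.grad(log_z, logits, torch.ones_like(log_z))
assert torch.allclose(marginals, logits.softmax(dim=-1))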
