-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
bcc1d13
commit 44e5a6a
Showing
5 changed files
with
235 additions
and
135 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,104 +1,123 @@ | ||
import torch | ||
from hypothesis import given, settings, strategies as st | ||
from hypothesis import given, settings | ||
from torch_struct import TreeCRF | ||
from torchnyan import BATCH_SIZE, TINY_TOKEN_SIZE, assert_close, assert_grad_close, device, sizes | ||
from torchrua import C | ||
|
||
from torchlatent.cky import CkyDecoder, cky_partitions, cky_scores | ||
from torchlatent.semiring import Log | ||
from torchlatent.cky import CkyDecoder, masked_select | ||
|
||
|
||
def get_argmax(cky): | ||
argmax = cky.argmax | ||
mask = argmax > 0 | ||
@settings(deadline=None) | ||
@given( | ||
token_sizes=sizes(BATCH_SIZE, TINY_TOKEN_SIZE), | ||
num_targets=sizes(TINY_TOKEN_SIZE), | ||
) | ||
def test_cky_scores(token_sizes, num_targets): | ||
logits = torch.randn( | ||
(len(token_sizes), max(token_sizes), max(token_sizes), num_targets), | ||
device=device, requires_grad=True, | ||
) | ||
token_sizes = torch.tensor(token_sizes, device=device) | ||
|
||
_, t, _, n = mask.size() | ||
index = torch.arange(t, device=mask.device) | ||
x = torch.masked_select(index[None, :, None, None], mask=mask) | ||
y = torch.masked_select(index[None, None, :, None], mask=mask) | ||
expected = TreeCRF(logits, lengths=token_sizes) | ||
actual = CkyDecoder(num_targets=num_targets)(logits=C(logits, token_sizes)) | ||
|
||
index = torch.arange(n, device=mask.device) | ||
z = torch.masked_select(index[None, None, None, :], mask=mask) | ||
expected = expected.log_prob(expected.argmax) | ||
actual = actual.log_probs(actual.argmax) | ||
|
||
return argmax, x, y, z | ||
assert_close(actual=actual, expected=expected) | ||
assert_grad_close(actual=actual, expected=expected, inputs=(logits,)) | ||
|
||
|
||
@settings(deadline=None) | ||
@given( | ||
token_sizes=sizes(BATCH_SIZE, TINY_TOKEN_SIZE), | ||
num_targets=sizes(TINY_TOKEN_SIZE), | ||
rua_targets=st.sampled_from([C.cat, C.pad, C.pack]), | ||
) | ||
def test_cky_scores(token_sizes, num_targets, rua_targets): | ||
emissions = torch.randn( | ||
def test_cky_partitions(token_sizes, num_targets): | ||
logits = torch.randn( | ||
(len(token_sizes), max(token_sizes), max(token_sizes), num_targets), | ||
device=device, requires_grad=True, | ||
) | ||
token_sizes = torch.tensor(token_sizes, device=device) | ||
expected_cky = TreeCRF(emissions, lengths=token_sizes) | ||
|
||
argmax, x, y, z = get_argmax(expected_cky) | ||
expected = TreeCRF(logits, lengths=token_sizes) | ||
actual = CkyDecoder(num_targets=num_targets)(logits=C(logits, token_sizes)) | ||
|
||
expected = expected.partition | ||
actual = actual.log_partitions | ||
|
||
emissions = torch.randn_like(emissions, requires_grad=True) | ||
assert_close(actual=actual, expected=expected) | ||
assert_grad_close(actual=actual, expected=expected, inputs=(logits,)) | ||
|
||
expected_cky = TreeCRF(emissions, lengths=token_sizes) | ||
expected = expected_cky.log_prob(argmax) + expected_cky.partition | ||
|
||
targets = C(data=torch.stack([x, y, z], dim=-1), token_sizes=token_sizes * 2 - 1) | ||
actual = cky_scores( | ||
emissions=C(emissions, token_sizes), | ||
targets=rua_targets(targets), | ||
semiring=Log, | ||
@settings(deadline=None) | ||
@given( | ||
token_sizes=sizes(BATCH_SIZE, TINY_TOKEN_SIZE), | ||
num_targets=sizes(TINY_TOKEN_SIZE), | ||
) | ||
def test_cky_argmax(token_sizes, num_targets): | ||
logits = torch.randn( | ||
(len(token_sizes), max(token_sizes), max(token_sizes), num_targets), | ||
device=device, requires_grad=True, | ||
) | ||
token_sizes = torch.tensor(token_sizes, device=device) | ||
|
||
assert_close(actual=actual, expected=expected) | ||
assert_grad_close(actual=actual, expected=expected, inputs=(emissions,)) | ||
expected = TreeCRF(logits, lengths=token_sizes) | ||
actual = CkyDecoder(num_targets=num_targets)(logits=C(logits, token_sizes)) | ||
|
||
expected = C(data=torch.stack(masked_select(expected.argmax.bool()), dim=-1), token_sizes=token_sizes * 2 - 1) | ||
actual = actual.argmax | ||
|
||
for actual, expected in zip(actual.tolist(), expected.tolist()): | ||
assert set(map(tuple, actual)) == set(map(tuple, expected)) | ||
|
||
|
||
@settings(deadline=None) | ||
@given( | ||
token_sizes=sizes(BATCH_SIZE, TINY_TOKEN_SIZE), | ||
num_targets=sizes(TINY_TOKEN_SIZE), | ||
) | ||
def test_cky_partitions(token_sizes, num_targets): | ||
emissions = torch.randn( | ||
def test_cky_entropy(token_sizes, num_targets): | ||
logits = torch.randn( | ||
(len(token_sizes), max(token_sizes), max(token_sizes), num_targets), | ||
device=device, requires_grad=True, | ||
) | ||
token_sizes = torch.tensor(token_sizes, device=device) | ||
|
||
expected = TreeCRF(emissions, lengths=token_sizes).partition | ||
expected = TreeCRF(logits, lengths=token_sizes) | ||
actual = CkyDecoder(num_targets=num_targets)(logits=C(logits, token_sizes)) | ||
|
||
actual_emissions = C( | ||
data=emissions.logsumexp(dim=-1), | ||
token_sizes=token_sizes, | ||
) | ||
actual = cky_partitions(actual_emissions, Log) | ||
expected = expected.entropy | ||
actual = actual.entropy | ||
|
||
assert_close(actual=actual, expected=expected) | ||
assert_grad_close(actual=actual, expected=expected, inputs=(emissions,)) | ||
assert_grad_close(actual=actual, expected=expected, inputs=(logits,), rtol=1e-4, atol=1e-4) | ||
|
||
|
||
@settings(deadline=None) | ||
@given( | ||
token_sizes=sizes(BATCH_SIZE, TINY_TOKEN_SIZE), | ||
num_targets=sizes(TINY_TOKEN_SIZE), | ||
) | ||
def test_cky_argmax(token_sizes, num_targets): | ||
emissions = torch.randn( | ||
def test_cky_kl(token_sizes, num_targets): | ||
logits1 = torch.randn( | ||
(len(token_sizes), max(token_sizes), max(token_sizes), num_targets), | ||
device=device, requires_grad=True, | ||
) | ||
logits2 = torch.randn( | ||
(len(token_sizes), max(token_sizes), max(token_sizes), num_targets), | ||
device=device, requires_grad=True, | ||
) | ||
token_sizes = torch.tensor(token_sizes, device=device) | ||
|
||
expected_cky = TreeCRF(emissions, lengths=token_sizes) | ||
|
||
_, x, y, z = get_argmax(expected_cky) | ||
expected1 = TreeCRF(logits1, lengths=token_sizes) | ||
expected2 = TreeCRF(logits2, lengths=token_sizes) | ||
actual1 = CkyDecoder(num_targets=num_targets)(logits=C(logits1, token_sizes)) | ||
actual2 = CkyDecoder(num_targets=num_targets)(logits=C(logits2, token_sizes)) | ||
|
||
expected = C(data=torch.stack([x, y, z], dim=-1), token_sizes=token_sizes * 2 - 1) | ||
expected = expected1.kl(expected2) | ||
actual = actual1.kl(actual2) | ||
|
||
actual_cky = CkyDecoder(num_targets=num_targets) | ||
actual = actual_cky(emissions=C(emissions, token_sizes)).argmax | ||
|
||
for actual, expected in zip(actual.tolist(), expected.tolist()): | ||
assert set(map(tuple, actual)) == set(map(tuple, expected)) | ||
assert_close(actual=actual, expected=expected) | ||
assert_grad_close(actual=actual, expected=expected, inputs=(logits1, logits2), rtol=1e-4, atol=1e-4) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.