From 44e5a6a24454876b91e8fd0c4db173ecdc46698c Mon Sep 17 00:00:00 2001
From: speedcell4
Date: Fri, 15 Mar 2024 00:08:23 +0900
Subject: [PATCH] Feat: Support ExceptionSemiring

---
 tests/test_cky.py  | 115 ++++++++++++++++++++++++++-------------------
 tests/test_crf.py  |  46 +++++++++---------
 torchlatent/abc.py |  10 ++--
 torchlatent/cky.py | 115 ++++++++++++++++++++++++++++++++++++++-------
 torchlatent/crf.py |  84 ++++++++++++++++-----------------
 5 files changed, 235 insertions(+), 135 deletions(-)

diff --git a/tests/test_cky.py b/tests/test_cky.py
index e627e8e..2513779 100644
--- a/tests/test_cky.py
+++ b/tests/test_cky.py
@@ -1,58 +1,76 @@
 import torch
-from hypothesis import given, settings, strategies as st
+from hypothesis import given, settings
 from torch_struct import TreeCRF
 from torchnyan import BATCH_SIZE, TINY_TOKEN_SIZE, assert_close, assert_grad_close, device, sizes
 from torchrua import C
 
-from torchlatent.cky import CkyDecoder, cky_partitions, cky_scores
-from torchlatent.semiring import Log
+from torchlatent.cky import CkyDecoder, masked_select
 
 
-def get_argmax(cky):
-    argmax = cky.argmax
-    mask = argmax > 0
+@settings(deadline=None)
+@given(
+    token_sizes=sizes(BATCH_SIZE, TINY_TOKEN_SIZE),
+    num_targets=sizes(TINY_TOKEN_SIZE),
+)
+def test_cky_scores(token_sizes, num_targets):
+    logits = torch.randn(
+        (len(token_sizes), max(token_sizes), max(token_sizes), num_targets),
+        device=device, requires_grad=True,
+    )
+    token_sizes = torch.tensor(token_sizes, device=device)
 
-    _, t, _, n = mask.size()
-    index = torch.arange(t, device=mask.device)
-    x = torch.masked_select(index[None, :, None, None], mask=mask)
-    y = torch.masked_select(index[None, None, :, None], mask=mask)
+    expected = TreeCRF(logits, lengths=token_sizes)
+    actual = CkyDecoder(num_targets=num_targets)(logits=C(logits, token_sizes))
 
-    index = torch.arange(n, device=mask.device)
-    z = torch.masked_select(index[None, None, None, :], mask=mask)
+    expected = expected.log_prob(expected.argmax)
+    actual = actual.log_probs(actual.argmax)
 
-    return argmax, x, y, z
+    assert_close(actual=actual, expected=expected)
+    assert_grad_close(actual=actual, expected=expected, inputs=(logits,))
 
 
 @settings(deadline=None)
 @given(
     token_sizes=sizes(BATCH_SIZE, TINY_TOKEN_SIZE),
     num_targets=sizes(TINY_TOKEN_SIZE),
-    rua_targets=st.sampled_from([C.cat, C.pad, C.pack]),
 )
-def test_cky_scores(token_sizes, num_targets, rua_targets):
-    emissions = torch.randn(
+def test_cky_partitions(token_sizes, num_targets):
+    logits = torch.randn(
         (len(token_sizes), max(token_sizes), max(token_sizes), num_targets),
         device=device, requires_grad=True,
     )
     token_sizes = torch.tensor(token_sizes, device=device)
 
-    expected_cky = TreeCRF(emissions, lengths=token_sizes)
-    argmax, x, y, z = get_argmax(expected_cky)
+    expected = TreeCRF(logits, lengths=token_sizes)
+    actual = CkyDecoder(num_targets=num_targets)(logits=C(logits, token_sizes))
+
+    expected = expected.partition
+    actual = actual.log_partitions
 
-    emissions = torch.randn_like(emissions, requires_grad=True)
+    assert_close(actual=actual, expected=expected)
+    assert_grad_close(actual=actual, expected=expected, inputs=(logits,))
 
-    expected_cky = TreeCRF(emissions, lengths=token_sizes)
-    expected = expected_cky.log_prob(argmax) + expected_cky.partition
 
-    targets = C(data=torch.stack([x, y, z], dim=-1), token_sizes=token_sizes * 2 - 1)
-    actual = cky_scores(
-        emissions=C(emissions, token_sizes),
-        targets=rua_targets(targets),
-        semiring=Log,
+@settings(deadline=None)
+@given(
+    token_sizes=sizes(BATCH_SIZE, TINY_TOKEN_SIZE),
+    num_targets=sizes(TINY_TOKEN_SIZE),
+)
+def test_cky_argmax(token_sizes, num_targets):
+    logits = torch.randn(
+        (len(token_sizes), max(token_sizes), max(token_sizes), num_targets),
+        device=device, requires_grad=True,
     )
+    token_sizes = torch.tensor(token_sizes, device=device)
 
-    assert_close(actual=actual, expected=expected)
-    assert_grad_close(actual=actual, expected=expected, inputs=(emissions,))
+    expected = TreeCRF(logits, lengths=token_sizes)
+    actual = CkyDecoder(num_targets=num_targets)(logits=C(logits, token_sizes))
+
+    expected = C(data=torch.stack(masked_select(expected.argmax.bool()), dim=-1), token_sizes=token_sizes * 2 - 1)
+    actual = actual.argmax
+
+    for actual, expected in zip(actual.tolist(), expected.tolist()):
+        assert set(map(tuple, actual)) == set(map(tuple, expected))
 
 
 @settings(deadline=None)
 @given(
@@ -60,23 +78,21 @@ def test_cky_scores(token_sizes, num_targets, rua_targets):
     token_sizes=sizes(BATCH_SIZE, TINY_TOKEN_SIZE),
     num_targets=sizes(TINY_TOKEN_SIZE),
 )
-def test_cky_partitions(token_sizes, num_targets):
-    emissions = torch.randn(
+def test_cky_entropy(token_sizes, num_targets):
+    logits = torch.randn(
         (len(token_sizes), max(token_sizes), max(token_sizes), num_targets),
         device=device, requires_grad=True,
     )
     token_sizes = torch.tensor(token_sizes, device=device)
 
-    expected = TreeCRF(emissions, lengths=token_sizes).partition
+    expected = TreeCRF(logits, lengths=token_sizes)
+    actual = CkyDecoder(num_targets=num_targets)(logits=C(logits, token_sizes))
 
-    actual_emissions = C(
-        data=emissions.logsumexp(dim=-1),
-        token_sizes=token_sizes,
-    )
-    actual = cky_partitions(actual_emissions, Log)
+    expected = expected.entropy
+    actual = actual.entropy
 
     assert_close(actual=actual, expected=expected)
-    assert_grad_close(actual=actual, expected=expected, inputs=(emissions,))
+    assert_grad_close(actual=actual, expected=expected, inputs=(logits,), rtol=1e-4, atol=1e-4)
 
 
 @settings(deadline=None)
 @given(
@@ -84,21 +100,24 @@ def test_cky_partitions(token_sizes, num_targets):
     token_sizes=sizes(BATCH_SIZE, TINY_TOKEN_SIZE),
     num_targets=sizes(TINY_TOKEN_SIZE),
 )
-def test_cky_argmax(token_sizes, num_targets):
-    emissions = torch.randn(
+def test_cky_kl(token_sizes, num_targets):
+    logits1 = torch.randn(
+        (len(token_sizes), max(token_sizes), max(token_sizes), num_targets),
+        device=device, requires_grad=True,
+    )
+    logits2 = torch.randn(
         (len(token_sizes), max(token_sizes), max(token_sizes), num_targets),
         device=device, requires_grad=True,
     )
     token_sizes = torch.tensor(token_sizes, device=device)
 
-    expected_cky = TreeCRF(emissions, lengths=token_sizes)
-
-    _, x, y, z = get_argmax(expected_cky)
+    expected1 = TreeCRF(logits1, lengths=token_sizes)
+    expected2 = TreeCRF(logits2, lengths=token_sizes)
+    actual1 = CkyDecoder(num_targets=num_targets)(logits=C(logits1, token_sizes))
+    actual2 = CkyDecoder(num_targets=num_targets)(logits=C(logits2, token_sizes))
 
-    expected = C(data=torch.stack([x, y, z], dim=-1), token_sizes=token_sizes * 2 - 1)
+    expected = expected1.kl(expected2)
+    actual = actual1.kl(actual2)
 
-    actual_cky = CkyDecoder(num_targets=num_targets)
-    actual = actual_cky(emissions=C(emissions, token_sizes)).argmax
-
-    for actual, expected in zip(actual.tolist(), expected.tolist()):
-        assert set(map(tuple, actual)) == set(map(tuple, expected))
+    assert_close(actual=actual, expected=expected)
+    assert_grad_close(actual=actual, expected=expected, inputs=(logits1, logits2), rtol=1e-4, atol=1e-4)
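The tests above check the new entropy and kl code paths of CkyDecoder against torch_struct's TreeCRF. For reproducing one of these checks outside the hypothesis harness, here is a minimal sketch with fixed shapes on CPU, mirroring test_cky_entropy; it uses only the APIs the tests themselves exercise:

import torch
from torch_struct import TreeCRF
from torchrua import C

from torchlatent.cky import CkyDecoder

num_targets = 7
logits = torch.randn((2, 4, 4, num_targets), requires_grad=True)
token_sizes = torch.tensor([4, 3])

# the tree entropy, computed once by torch_struct and once by torchlatent
expected = TreeCRF(logits, lengths=token_sizes).entropy
actual = CkyDecoder(num_targets=num_targets)(logits=C(logits, token_sizes)).entropy

print(torch.allclose(actual, expected, rtol=1e-4, atol=1e-4))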
diff --git a/tests/test_crf.py b/tests/test_crf.py
index f148245..22e7cf0 100644
--- a/tests/test_crf.py
+++ b/tests/test_crf.py
@@ -12,10 +12,10 @@
 @given(
     token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE),
     num_targets=sizes(TOKEN_SIZE),
-    rua_emissions=st.sampled_from([C.new, D.new, P.new]),
+    rua_logits=st.sampled_from([C.new, D.new, P.new]),
     rua_targets=st.sampled_from([C.new, D.new, P.new]),
 )
-def test_crf_scores(token_sizes, num_targets, rua_emissions, rua_targets):
+def test_crf_scores(token_sizes, num_targets, rua_logits, rua_targets):
     inputs = [
         torch.randn((token_size, num_targets), device=device, requires_grad=True)
         for token_size in token_sizes
@@ -28,19 +28,19 @@ def test_crf_scores(token_sizes, num_targets, rua_emissions, rua_targets):
 
     expected_crf = CRF(num_tags=num_targets, batch_first=False).to(device=device)
 
-    expected_emissions = D.new(inputs)
+    expected_logits = D.new(inputs)
     expected_tags = D.new(targets)
 
     expected = expected_crf._compute_score(
-        expected_emissions.data.transpose(0, 1),
+        expected_logits.data.transpose(0, 1),
         expected_tags.data.transpose(0, 1),
-        expected_emissions.mask().transpose(0, 1),
+        expected_logits.mask().transpose(0, 1),
     )
 
     actual = crf_scores(
-        emissions=rua_emissions(inputs),
+        logits=rua_logits(inputs),
         targets=rua_targets(targets),
-        transitions=(expected_crf.transitions, expected_crf.start_transitions, expected_crf.end_transitions),
+        bias=(expected_crf.transitions, expected_crf.start_transitions, expected_crf.end_transitions),
         semiring=Log,
     )
 
@@ -52,9 +52,9 @@ def test_crf_scores(token_sizes, num_targets, rua_emissions, rua_targets):
 @given(
     token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE),
     num_targets=sizes(TOKEN_SIZE),
-    rua_emissions=st.sampled_from([C.new, D.new, P.new]),
+    rua_logits=st.sampled_from([C.new, D.new, P.new]),
 )
-def test_crf_partitions(token_sizes, num_targets, rua_emissions):
+def test_crf_partitions(token_sizes, num_targets, rua_logits):
     inputs = [
         torch.randn((token_size, num_targets), device=device, requires_grad=True)
         for token_size in token_sizes
@@ -62,16 +62,16 @@ def test_crf_partitions(token_sizes, num_targets, rua_emissions):
 
     expected_crf = CRF(num_tags=num_targets, batch_first=False).to(device=device)
 
-    expected_emissions = D.new(inputs)
+    expected_logits = D.new(inputs)
 
     expected = expected_crf._compute_normalizer(
-        expected_emissions.data.transpose(0, 1),
-        expected_emissions.mask().t(),
+        expected_logits.data.transpose(0, 1),
+        expected_logits.mask().t(),
     )
 
     actual = crf_partitions(
-        emissions=rua_emissions(inputs),
-        transitions=(expected_crf.transitions, expected_crf.start_transitions, expected_crf.end_transitions),
+        logits=rua_logits(inputs),
+        bias=(expected_crf.transitions, expected_crf.start_transitions, expected_crf.end_transitions),
         semiring=Log,
     )
 
@@ -83,9 +83,9 @@ def test_crf_partitions(token_sizes, num_targets, rua_emissions):
 @given(
     token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE),
     num_targets=sizes(TOKEN_SIZE),
-    rua_emissions=st.sampled_from([C.new, D.new, P.new]),
+    rua_logits=st.sampled_from([C.new, D.new, P.new]),
 )
-def test_crf_argmax(token_sizes, num_targets, rua_emissions):
+def test_crf_argmax(token_sizes, num_targets, rua_logits):
     inputs = [
         torch.randn((token_size, num_targets), device=device, requires_grad=True)
         for token_size in token_sizes
@@ -93,19 +93,19 @@ def test_crf_argmax(token_sizes, num_targets, rua_emissions):
 
     expected_crf = CRF(num_tags=num_targets, batch_first=False).to(device=device)
 
-    expected_emissions = D.new(inputs)
+    expected_logits = D.new(inputs)
 
     expected = expected_crf.decode(
-        expected_emissions.data.transpose(0, 1),
-        expected_emissions.mask().t(),
+        expected_logits.data.transpose(0, 1),
+        expected_logits.mask().t(),
     )
     expected = C.new([torch.tensor(tensor, device=device) for tensor in expected])
 
     actual_crf = CrfDecoder(num_targets=num_targets)
-    actual_crf.transitions = expected_crf.transitions
-    actual_crf.head_transitions = expected_crf.start_transitions
-    actual_crf.last_transitions = expected_crf.end_transitions
+    actual_crf.bias = expected_crf.transitions
+    actual_crf.head_bias = expected_crf.start_transitions
+    actual_crf.last_bias = expected_crf.end_transitions
 
-    actual = actual_crf(rua_emissions(inputs)).argmax.cat()
+    actual = actual_crf(rua_logits(inputs)).argmax.cat()
 
     assert_sequence_close(actual=actual, expected=expected)
diff --git a/torchlatent/abc.py b/torchlatent/abc.py
index d5e088d..7bb5164 100644
--- a/torchlatent/abc.py
+++ b/torchlatent/abc.py
@@ -9,9 +9,9 @@
 
 
 class StructuredDistribution(object, metaclass=ABCMeta):
-    def __init__(self, emissions: Union[C, D, P]) -> None:
+    def __init__(self, logits: Union[C, D, P]) -> None:
         super(StructuredDistribution, self).__init__()
-        self.emissions = emissions
+        self.logits = logits
 
     def log_scores(self, targets: Union[C, D, P]) -> Tensor:
         raise NotImplementedError
@@ -26,7 +26,7 @@ def log_partitions(self) -> Tensor:
 
     @lazy_property
     def marginals(self) -> Tensor:
         grad, = torch.autograd.grad(
-            self.log_partitions, self.emissions.data, torch.ones_like(self.log_partitions),
+            self.log_partitions, self.logits.data, torch.ones_like(self.log_partitions),
             create_graph=True, retain_graph=True, only_inputs=True, allow_unused=True,
         )
 
@@ -39,7 +39,7 @@ def max(self) -> Tensor:
     @lazy_property
     def argmax(self) -> Tensor:
         grad, = torch.autograd.grad(
-            self.max, self.emissions.data, torch.ones_like(self.max),
+            self.max, self.logits.data, torch.ones_like(self.max),
             create_graph=False, retain_graph=False, only_inputs=True, allow_unused=True,
         )
         return grad
@@ -56,5 +56,5 @@ def reset_parameters(self) -> None:
     def extra_repr(self) -> str:
         return f'num_targets={self.num_targets}'
 
-    def forward(self, emissions: Union[C, D, P]) -> StructuredDistribution:
+    def forward(self, logits: Union[C, D, P]) -> StructuredDistribution:
         raise NotImplementedError
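The marginals and argmax properties above rely on a standard autograd identity: the gradient of the log-partition function with respect to the scores is the vector of marginal probabilities, and the gradient of the max score is a one-hot indicator of the argmax. A tiny self-contained illustration on a plain categorical distribution (pure PyTorch, no torchlatent involved):

import torch

logits = torch.randn(5, requires_grad=True)

# d/d logits of log Z recovers the softmax marginals
log_partition = torch.logsumexp(logits, dim=-1)
marginals, = torch.autograd.grad(log_partition, logits, create_graph=True)
print(torch.allclose(marginals, logits.softmax(dim=-1)))  # True

# d/d logits of the max recovers a one-hot argmax indicator
one_hot, = torch.autograd.grad(logits.max(), logits)
print(one_hot)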
diff --git a/torchlatent/cky.py b/torchlatent/cky.py
index b8c863d..5f12d1a 100644
--- a/torchlatent/cky.py
+++ b/torchlatent/cky.py
@@ -6,15 +6,15 @@
 from torchrua import C, D, P
 
 from torchlatent.abc import StructuredDecoder, StructuredDistribution
-from torchlatent.semiring import Log, Max, Semiring
+from torchlatent.semiring import Div, ExceptionSemiring, Log, Max, Semiring, Xen
 
 
-def cky_scores(emissions: C, targets: Union[C, D, P], semiring: Type[Semiring]) -> Tensor:
+def cky_scores(logits: C, targets: Union[C, D, P], semiring: Type[Semiring]) -> Tensor:
     xyz, token_sizes = targets = targets.cat()
     batch_ptr, _ = targets.ptr()
 
-    emissions = emissions.data[batch_ptr, xyz[..., 0], xyz[..., 1], xyz[..., 2]]
-    return semiring.segment_prod(emissions, token_sizes)
+    logits = logits.data[batch_ptr, xyz[..., 0], xyz[..., 1], xyz[..., 2]]
+    return semiring.segment_prod(logits, token_sizes)
 
 
 def diag(tensor: Tensor, offset: int) -> Tensor:
@@ -43,17 +43,53 @@ def right(chart: Tensor, offset: int) -> Tensor:
     )
 
 
-def cky_partitions(emissions: C, semiring: Type[Semiring]) -> Tensor:
-    chart = torch.full_like(emissions.data, fill_value=semiring.zero, requires_grad=False)
+def cky_partitions(logits: C, semiring: Type[Semiring]) -> Tensor:
+    chart = torch.full_like(logits.data, fill_value=semiring.zero, requires_grad=False)
 
-    diag_scatter(chart, diag(emissions.data, offset=0), offset=0)
+    diag_scatter(chart, diag(logits.data, offset=0), offset=0)
     for w in range(1, chart.size()[1]):
         score = semiring.sum(semiring.mul(left(chart, offset=w), right(chart, offset=w)), dim=2)
-        diag_scatter(chart, semiring.mul(score, diag(emissions.data, offset=w)), offset=w)
+        diag_scatter(chart, semiring.mul(score, diag(logits.data, offset=w)), offset=w)
 
     index = torch.arange(chart.size()[0], dtype=torch.long, device=chart.device)
-    return chart[index, 0, emissions.token_sizes - 1]
+    return chart[index, 0, logits.token_sizes - 1]
+
+
+def cky_exceptions(logits1: C, logits2: C, log_prob: C,
+                   semiring: Type[Semiring], exception: Type[ExceptionSemiring]) -> Tensor:
+    chart1 = torch.full_like(logits1.data, fill_value=semiring.zero, requires_grad=False)
+    chart2 = torch.full_like(logits2.data, fill_value=semiring.zero, requires_grad=False)
+    chart3 = torch.full_like(log_prob.data, fill_value=exception.zero, requires_grad=False)
+
+    diag_scatter(chart1, diag(logits1.data, offset=0), offset=0)
+    diag_scatter(chart2, diag(logits2.data, offset=0), offset=0)
+    diag_scatter(chart3, diag(log_prob.data, offset=0), offset=0)
+
+    for w in range(1, chart3.size()[1]):
+        score1 = semiring.mul(left(chart1, offset=w), right(chart1, offset=w))
+        score2 = semiring.mul(left(chart2, offset=w), right(chart2, offset=w))
+        tensor = exception.mul(left(chart3, offset=w), right(chart3, offset=w))
+
+        diag_scatter(chart1, semiring.mul(
+            semiring.sum(score1, dim=2),
+            diag(logits1.data, offset=w),
+        ), offset=w)
+
+        diag_scatter(chart2, semiring.mul(
+            semiring.sum(score2, dim=2),
+            diag(logits2.data, offset=w),
+        ), offset=w)
+
+        log_p = score1 - semiring.sum(score1, dim=-1, keepdim=True)
+        log_q = score2 - semiring.sum(score2, dim=-1, keepdim=True)
+        diag_scatter(chart3, exception.mul(
+            exception.sum(tensor, log_p=log_p, log_q=log_q, dim=2),
+            diag(log_prob.data, offset=w),
+        ), offset=w)
+
+    index = torch.arange(chart3.size()[0], dtype=torch.long, device=chart3.device)
+    return chart3[index, 0, log_prob.token_sizes - 1]
 
 
 def masked_select(mask: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
@@ -70,29 +106,58 @@ def masked_select(mask: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
 
 
 class CkyDistribution(StructuredDistribution):
-    def __init__(self, emissions: C) -> None:
-        super(CkyDistribution, self).__init__(emissions=emissions)
+    def __init__(self, logits: C) -> None:
+        super(CkyDistribution, self).__init__(logits=logits)
 
     def log_scores(self, targets: Union[C, D, P]) -> Tensor:
         return cky_scores(
-            emissions=self.emissions, targets=targets,
+            logits=self.logits, targets=targets,
             semiring=Log,
         )
 
     @lazy_property
     def log_partitions(self) -> Tensor:
         return cky_partitions(
-            emissions=self.emissions._replace(data=Log.sum(self.emissions.data, dim=-1)),
+            logits=self.logits._replace(data=Log.sum(self.logits.data, dim=-1)),
             semiring=Log,
         )
 
     @lazy_property
     def max(self) -> Tensor:
         return cky_partitions(
-            emissions=self.emissions._replace(data=Max.sum(self.emissions.data, dim=-1)),
+            logits=self.logits._replace(data=Max.sum(self.logits.data, dim=-1)),
             semiring=Max,
         )
 
+    @lazy_property
+    def entropy(self) -> Tensor:
+        return cky_exceptions(
+            logits1=self.logits._replace(data=Log.sum(self.logits.data, dim=-1)),
+            logits2=self.logits._replace(data=Log.sum(self.logits.data, dim=-1)),
+            log_prob=self.logits._replace(data=Xen.sum(
+                torch.full_like(self.logits.data, fill_value=Xen.zero),
+                torch.log_softmax(self.logits.data, dim=-1),
+                torch.log_softmax(self.logits.data, dim=-1),
+                dim=-1,
+            )),
+            semiring=Log,
+            exception=Xen,
+        )
+
+    def kl(self, other: 'CkyDistribution') -> Tensor:
+        return cky_exceptions(
+            logits1=self.logits._replace(data=Log.sum(self.logits.data, dim=-1)),
+            logits2=other.logits._replace(data=Log.sum(other.logits.data, dim=-1)),
+            log_prob=self.logits._replace(data=Div.sum(
+                torch.full_like(self.logits.data, fill_value=Div.zero),
+                torch.log_softmax(self.logits.data, dim=-1),
+                torch.log_softmax(other.logits.data, dim=-1),
+                dim=-1,
+            )),
+            semiring=Log,
+            exception=Div,
+        )
+
     @lazy_property
     def argmax(self) -> C:
         argmax = super(CkyDistribution, self).argmax
@@ -100,7 +165,7 @@ def argmax(self) -> C:
 
         return C(
             data=torch.stack([x, y, z], dim=-1),
-            token_sizes=self.emissions.token_sizes * 2 - 1,
+            token_sizes=self.logits.token_sizes * 2 - 1,
         )
 
 
@@ -108,5 +173,21 @@ class CkyDecoder(StructuredDecoder):
     def __init__(self, *, num_targets: int) -> None:
         super(CkyDecoder, self).__init__(num_targets=num_targets)
 
-    def forward(self, emissions: C) -> CkyDistribution:
-        return CkyDistribution(emissions=emissions)
+    def forward(self, logits: C) -> CkyDistribution:
+        return CkyDistribution(logits=logits)
+
+
+if __name__ == '__main__':
+    from torch_struct import TreeCRF
+
+    num_targets = 17
+    logits1 = C(data=torch.randn((3, 5, 5, num_targets), requires_grad=True), token_sizes=torch.tensor([5, 2, 3]))
+    logits2 = C(data=torch.randn((3, 5, 5, num_targets), requires_grad=True), token_sizes=torch.tensor([5, 2, 3]))
+
+    expected1 = TreeCRF(logits1.data, logits1.token_sizes)
+    expected2 = TreeCRF(logits2.data, logits2.token_sizes)
+    print(expected1.kl(expected2))
+
+    actual1 = CkyDecoder(num_targets=num_targets)(logits1)
+    actual2 = CkyDecoder(num_targets=num_targets)(logits2)
+    print(actual1.kl(actual2))
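Div, ExceptionSemiring, and Xen are imported from torchlatent.semiring, which this patch leaves untouched; cky_exceptions only assumes they follow the usual first-order expectation-semiring contract, where values accumulate additively and sum takes an expectation under the local split distribution log_p while adding a per-split cross term. A hypothetical sketch of that contract, written here purely for illustration (the names match the imports, but the real implementation may differ):

import torch
from torch import Tensor


class Xen:
    # cross entropy; with log_p == log_q this accumulates the entropy
    zero = 0.

    @classmethod
    def mul(cls, x: Tensor, y: Tensor) -> Tensor:
        return x + y

    @classmethod
    def sum(cls, tensor: Tensor, log_p: Tensor, log_q: Tensor, dim: int) -> Tensor:
        # E_p[child terms] plus the local cross term -E_p[log q]
        return (log_p.exp() * (tensor - log_q)).sum(dim=dim)


class Div(Xen):
    # KL divergence
    @classmethod
    def sum(cls, tensor: Tensor, log_p: Tensor, log_q: Tensor, dim: int) -> Tensor:
        # E_p[child terms] plus the local KL term E_p[log p - log q]
        return (log_p.exp() * (tensor + log_p - log_q)).sum(dim=dim)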
diff --git a/torchlatent/crf.py b/torchlatent/crf.py
index 83d1b09..61e8e03 100644
--- a/torchlatent/crf.py
+++ b/torchlatent/crf.py
@@ -12,102 +12,102 @@
 T = Tuple[Tensor, Tensor, Tensor]
 
 
-def crf_scores(emissions: Union[C, D, P], targets: Union[C, D, P], transitions: T, semiring: Type[Semiring]) -> Tensor:
-    transitions, head_transitions, last_transitions = transitions
+def crf_scores(logits: Union[C, D, P], targets: Union[C, D, P], bias: T, semiring: Type[Semiring]) -> Tensor:
+    bias, head_bias, last_bias = bias
     targets = _, token_sizes = targets.cat()
 
-    head_transitions = targets.head().rua(head_transitions)
-    last_transitions = targets.last().rua(last_transitions)
-    transitions = targets.data.roll(1).rua(transitions, targets)
+    head_bias = targets.head().rua(head_bias)
+    last_bias = targets.last().rua(last_bias)
+    bias = targets.data.roll(1).rua(bias, targets)
 
-    emissions, _ = emissions.idx().cat().rua(emissions, targets)
-    emissions = semiring.segment_prod(emissions, sizes=token_sizes)
+    logits, _ = logits.idx().cat().rua(logits, targets)
+    logits = semiring.segment_prod(logits, sizes=token_sizes)
 
     token_sizes = torch.stack([torch.ones_like(token_sizes), token_sizes - 1], dim=-1)
-    transitions = semiring.segment_prod(transitions, sizes=token_sizes.view(-1))[1::2]
+    bias = semiring.segment_prod(bias, sizes=token_sizes.view(-1))[1::2]
 
     return semiring.mul(
-        semiring.mul(head_transitions, last_transitions),
-        semiring.mul(emissions, transitions),
+        semiring.mul(head_bias, last_bias),
+        semiring.mul(logits, bias),
     )
 
 
-def crf_partitions(emissions: Union[C, D, P], transitions: T, semiring: Type[Semiring]) -> Tensor:
-    transitions, head_transitions, last_transitions = transitions
+def crf_partitions(logits: Union[C, D, P], bias: T, semiring: Type[Semiring]) -> Tensor:
+    bias, head_bias, last_bias = bias
 
-    emissions = emissions.pack()
-    last_indices = emissions.idx().last()
-    emissions, batch_sizes, _, _ = emissions
+    logits = logits.pack()
+    last_indices = logits.idx().last()
+    logits, batch_sizes, _, _ = logits
     _, *batch_sizes = sections = batch_sizes.detach().cpu().tolist()
 
-    emission, *emissions = torch.split(emissions, sections, dim=0)
+    emission, *logits = torch.split(logits, sections, dim=0)
 
-    charts = [semiring.mul(head_transitions, emission)]
-    for emission, batch_size in zip(emissions, batch_sizes):
+    charts = [semiring.mul(head_bias, emission)]
+    for emission, batch_size in zip(logits, batch_sizes):
         charts.append(semiring.mul(
-            semiring.bmm(charts[-1][:batch_size], transitions),
+            semiring.bmm(charts[-1][:batch_size], bias),
             emission,
         ))
 
     emission = torch.cat(charts, dim=0)[last_indices]
-    return semiring.sum(semiring.mul(emission, last_transitions), dim=-1)
+    return semiring.sum(semiring.mul(emission, last_bias), dim=-1)
 
 
 class CrfDistribution(StructuredDistribution):
-    def __init__(self, emissions: Union[C, D, P], transitions: T) -> None:
-        super(CrfDistribution, self).__init__(emissions=emissions)
-        self.transitions = transitions
+    def __init__(self, logits: Union[C, D, P], bias: T) -> None:
+        super(CrfDistribution, self).__init__(logits=logits)
+        self.bias = bias
 
     def log_scores(self, targets: Union[C, D, P]) -> Tensor:
         return crf_scores(
-            emissions=self.emissions, targets=targets,
-            transitions=self.transitions,
+            logits=self.logits, targets=targets,
+            bias=self.bias,
             semiring=Log,
         )
 
     @lazy_property
     def log_partitions(self) -> Tensor:
         return crf_partitions(
-            emissions=self.emissions,
-            transitions=self.transitions,
+            logits=self.logits,
+            bias=self.bias,
             semiring=Log,
         )
 
     @lazy_property
     def max(self) -> Tensor:
         return crf_partitions(
-            emissions=self.emissions,
-            transitions=self.transitions,
+            logits=self.logits,
+            bias=self.bias,
             semiring=Max,
         )
 
     @lazy_property
     def argmax(self) -> Union[C, D, P]:
         argmax = super(CrfDistribution, self).argmax.argmax(dim=-1)
-        return self.emissions._replace(data=argmax)
+        return self.logits._replace(data=argmax)
 
 
 class CrfDecoder(StructuredDecoder):
     def __init__(self, *, num_targets: int) -> None:
         super(CrfDecoder, self).__init__(num_targets=num_targets)
 
-        self.transitions = nn.Parameter(torch.empty((num_targets, num_targets)))
-        self.head_transitions = nn.Parameter(torch.empty((num_targets,)))
-        self.last_transitions = nn.Parameter(torch.empty((num_targets,)))
+        self.bias = nn.Parameter(torch.empty((num_targets, num_targets)))
+        self.head_bias = nn.Parameter(torch.empty((num_targets,)))
+        self.last_bias = nn.Parameter(torch.empty((num_targets,)))
 
         self.reset_parameters()
 
     def reset_parameters(self) -> None:
-        init.zeros_(self.transitions)
-        init.zeros_(self.head_transitions)
-        init.zeros_(self.last_transitions)
+        init.zeros_(self.bias)
+        init.zeros_(self.head_bias)
+        init.zeros_(self.last_bias)
 
-    def forward(self, emissions: Union[C, D, P]) -> CrfDistribution:
+    def forward(self, logits: Union[C, D, P]) -> CrfDistribution:
         return CrfDistribution(
-            emissions=emissions,
-            transitions=(
-                self.transitions,
-                self.head_transitions,
-                self.last_transitions,
+            logits=logits,
+            bias=(
+                self.bias,
+                self.head_bias,
+                self.last_bias,
             ),
         )
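Finally, a short usage sketch of the renamed CRF interface. The parameter and attribute names (logits, bias, head_bias, last_bias) are the ones introduced by this patch; the sketch assumes log_probs returns per-sequence log-likelihoods, as exercised by test_cky.py above:

import torch
from torchrua import C

from torchlatent.crf import CrfDecoder

num_targets = 5
decoder = CrfDecoder(num_targets=num_targets)

logits = [torch.randn((size, num_targets), requires_grad=True) for size in (4, 2, 3)]
targets = [torch.randint(num_targets, (size,)) for size in (4, 2, 3)]

dist = decoder(C.new(logits))

# negative log-likelihood of the gold tag sequences
loss = -dist.log_probs(C.new(targets)).mean()
loss.backward()

print(dist.argmax.cat())  # most likely tag index per token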