diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 2944977..9993084 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -17,7 +17,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v2 with: - python-version: '3.7' + python-version: '3.8' - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 63384e4..db1e5bc 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -9,10 +9,10 @@ jobs: steps: - uses: actions/checkout@v2 - - name: Set up Python 3.7 + - name: Set up Python uses: actions/setup-python@v2 with: - python-version: 3.7 + python-version: '3.8' - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/README.md b/README.md index 7516561..1186c7b 100644 --- a/README.md +++ b/README.md @@ -6,13 +6,20 @@ ## Requirements -- Python 3.7 -- PyTorch 1.6.0 +- Python 3.8 +- PyTorch 1.10.2 ## Installation `python3 -m pip torchlatent` +## Performance + +``` +TorchLatent (0.109244) => 0.003781 0.017763 0.087700 0.063497 +Third (0.232487) => 0.103277 0.129209 0.145311 +``` + ## Usage ```python diff --git a/benchmark/__init__.py b/benchmark/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/benchmark/__main__.py b/benchmark/__main__.py new file mode 100644 index 0000000..351d087 --- /dev/null +++ b/benchmark/__main__.py @@ -0,0 +1,9 @@ +from aku import Aku + +from benchmark.crf import benchmark_crf + +aku = Aku() + +aku.option(benchmark_crf) + +aku.run() diff --git a/benchmark/crf.py b/benchmark/crf.py new file mode 100644 index 0000000..a74dce0 --- /dev/null +++ b/benchmark/crf.py @@ -0,0 +1,62 @@ +import torch +from torchrua import pack_sequence +from tqdm import tqdm + +from benchmark.meter import TimeMeter +from tests.third_party import ThirdPartyCrfDecoder +from torchlatent.crf import CrfDecoder + 
+ +def benchmark_crf(num_tags: int = 50, num_conjugates: int = 1, num_runs: int = 100, + batch_size: int = 32, max_token_size: int = 512): + j1, f1, b1, d1, = TimeMeter(), TimeMeter(), TimeMeter(), TimeMeter() + j2, f2, b2, d2, = TimeMeter(), TimeMeter(), TimeMeter(), TimeMeter() + + if torch.cuda.is_available(): + device = torch.device('cuda:0') + else: + device = torch.device('cpu') + print(f'device => {device}') + + decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates).to(device=device) + print(f'decoder => {decoder}') + + third_decoder = ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates).to(device=device) + print(f'third_decoder => {third_decoder}') + + for _ in tqdm(range(num_runs)): + token_sizes = torch.randint(1, max_token_size + 1, (batch_size,), device=device).detach().cpu().tolist() + + emissions = pack_sequence([ + torch.randn((token_size, num_conjugates, num_tags), device=device, requires_grad=True) + for token_size in token_sizes + ]) + + tags = pack_sequence([ + torch.randint(0, num_tags, (token_size, num_conjugates), device=device) + for token_size in token_sizes + ]) + + with j1: + indices = decoder.compile_indices(emissions=emissions, tags=tags) + + with f1: + loss = decoder.fit(emissions=emissions, tags=tags, indices=indices).neg().mean() + + with b1: + _, = torch.autograd.grad(loss, emissions.data, torch.ones_like(loss)) + + with d1: + _ = decoder.decode(emissions=emissions, indices=indices) + + with f2: + loss = third_decoder.fit(emissions=emissions, tags=tags).neg().mean() + + with b2: + _, = torch.autograd.grad(loss, emissions.data, torch.ones_like(loss)) + + with d2: + _ = third_decoder.decode(emissions=emissions) + + print(f'TorchLatent ({j1.merit + f1.merit + b1.merit:.6f}) => {j1} {f1} {b1} {d1}') + print(f'Third ({j2.merit + f2.merit + b2.merit:.6f}) => {j2} {f2} {b2} {d2}') diff --git a/benchmark/meter.py b/benchmark/meter.py new file mode 100644 index 0000000..64e1c6e --- /dev/null +++ 
b/benchmark/meter.py @@ -0,0 +1,23 @@ +from datetime import datetime + + +class TimeMeter(object): + def __init__(self) -> None: + super(TimeMeter, self).__init__() + + self.seconds = 0 + self.counts = 0 + + def __enter__(self): + self.start_tm = datetime.now() + + def __exit__(self, exc_type, exc_val, exc_tb): + self.seconds += (datetime.now() - self.start_tm).total_seconds() + self.counts += 1 + + @property + def merit(self) -> float: + return self.seconds / max(1, self.counts) + + def __repr__(self) -> str: + return f'{self.merit :.6f}' diff --git a/setup.py b/setup.py index b7687ba..6f8a71b 100644 --- a/setup.py +++ b/setup.py @@ -4,17 +4,17 @@ setup( name=name, - version='0.4.1', + version='0.4.2', packages=[package for package in find_packages() if package.startswith(name)], url='https://github.com/speedcell4/torchlatent', license='MIT', author='speedcell4', author_email='speedcell4@gmail.com', description='High Performance Structured Prediction in PyTorch', - python_requires='>=3.7', + python_requires='>=3.8', install_requires=[ 'numpy', - 'torchrua>=0.3.0', + 'torchrua>=0.4.0', ], extras_require={ 'dev': [ diff --git a/tests/strategies.py b/tests/strategies.py index 9579c16..785f929 100644 --- a/tests/strategies.py +++ b/tests/strategies.py @@ -2,20 +2,14 @@ from hypothesis import strategies as st -if torch.cuda.is_available(): - MAX_BATCH_SIZE = 120 - MAX_TOKEN_SIZE = 512 - MAX_NUM_TAGS = 100 - MAX_NUM_CONJUGATES = 16 -else: - MAX_BATCH_SIZE = 12 - MAX_TOKEN_SIZE = 24 - MAX_NUM_TAGS = 12 - MAX_NUM_CONJUGATES = 6 - TINY_BATCH_SIZE = 6 TINY_TOKEN_SIZE = 12 +BATCH_SIZE = 24 +TOKEN_SIZE = 50 +NUM_TAGS = 8 +NUM_CONJUGATES = 5 + @st.composite def devices(draw): @@ -28,36 +22,13 @@ def devices(draw): @st.composite -def batch_sizes(draw, max_value: int = MAX_BATCH_SIZE): - return draw(st.integers(min_value=1, max_value=max_value)) - - -@st.composite -def batch_size_lists(draw, max_batch_size: int = MAX_BATCH_SIZE): - return [ - 
draw(batch_sizes(max_value=max_batch_size)) - for _ in range(draw(batch_sizes(max_value=max_batch_size))) - ] - - -@st.composite -def token_sizes(draw, max_value: int = MAX_TOKEN_SIZE): - return draw(st.integers(min_value=1, max_value=max_value)) - +def sizes(draw, *size: int, min_size: int = 1): + max_size, *size = size -@st.composite -def token_size_lists(draw, max_token_size: int = MAX_TOKEN_SIZE, max_batch_size: int = MAX_BATCH_SIZE): - return [ - draw(token_sizes(max_value=max_token_size)) - for _ in range(draw(batch_sizes(max_value=max_batch_size))) - ] - - -@st.composite -def tag_sizes(draw, max_value: int = MAX_NUM_TAGS): - return draw(st.integers(min_value=1, max_value=max_value)) - - -@st.composite -def conjugate_sizes(draw, max_value: int = MAX_NUM_CONJUGATES): - return draw(st.integers(min_value=1, max_value=max_value)) + if len(size) == 0: + return draw(st.integers(min_value=min_size, max_value=max_size)) + else: + return [ + draw(sizes(*size, min_size=min_size)) + for _ in range(draw(st.integers(min_value=min_size, max_value=max_size))) + ] diff --git a/tests/test_crf.py b/tests/test_crf.py index 226c161..2164a7b 100644 --- a/tests/test_crf.py +++ b/tests/test_crf.py @@ -1,87 +1,20 @@ import torch -import torchcrf from hypothesis import given -from torch import Tensor -from torch import nn -from torch.nn.utils.rnn import PackedSequence -from torchrua import pad_packed_sequence, token_sizes_to_mask, pack_sequence +from torchrua import pack_sequence, cat_sequence, pack_catted_sequence -from tests.strategies import devices, token_size_lists, conjugate_sizes, tag_sizes -from tests.utils import assert_close, assert_grad_close, assert_packed_equal +from tests.strategies import devices, sizes, BATCH_SIZE, TOKEN_SIZE, NUM_CONJUGATES, NUM_TAGS +from tests.third_party import ThirdPartyCrfDecoder +from tests.utils import assert_close, assert_grad_close, assert_packed_sequence_equal from torchlatent.crf import CrfDecoder -class ThirdPartyCrfDecoder(nn.Module): - 
def __init__(self, num_tags: int, num_conjugates: int) -> None: - super(ThirdPartyCrfDecoder, self).__init__() - self.num_tags = num_tags - self.num_conjugates = num_conjugates - - self.decoders = nn.ModuleList([ - torchcrf.CRF(num_tags=num_tags, batch_first=False) - for _ in range(num_conjugates) - ]) - - @torch.no_grad() - def reset_parameters_with_(self, decoder: CrfDecoder) -> None: - assert self.num_tags == decoder.num_tags - assert self.num_conjugates == decoder.num_conjugates - - for index in range(self.num_conjugates): - self.decoders[index].transitions.data[::] = decoder.transitions[:, index, :, :] - self.decoders[index].start_transitions.data[::] = decoder.head_transitions[:, index, :] - self.decoders[index].end_transitions.data[::] = decoder.tail_transitions[:, index, :] - - def fit(self, emissions: PackedSequence, tags: PackedSequence, **kwargs) -> Tensor: - num_emissions_conjugates = emissions.data.size()[1] - num_decoders_conjugates = self.num_conjugates - num_conjugates = max(num_emissions_conjugates, num_decoders_conjugates) - - emissions, token_sizes = pad_packed_sequence(emissions, batch_first=False) - tags, _ = pad_packed_sequence(tags, batch_first=False) - mask = token_sizes_to_mask(token_sizes=token_sizes, batch_first=False) - - log_probs = [] - for index in range(num_conjugates): - decoder = self.decoders[index % num_decoders_conjugates] - emission = emissions[:, :, index % num_emissions_conjugates] - tag = tags[:, :, index % num_emissions_conjugates] - - log_probs.append(decoder(emissions=emission, tags=tag, mask=mask, reduction='none')) - - return torch.stack(log_probs, dim=-1) - - def decode(self, emissions: PackedSequence, **kwargs) -> PackedSequence: - num_emissions_conjugates = emissions.data.size()[1] - num_decoders_conjugates = self.num_conjugates - num_conjugates = max(num_emissions_conjugates, num_decoders_conjugates) - - emissions, token_sizes = pad_packed_sequence(emissions, batch_first=False) - mask = 
token_sizes_to_mask(token_sizes=token_sizes, batch_first=False) - - predictions = [] - for index in range(num_conjugates): - decoder = self.decoders[index % num_decoders_conjugates] - emission = emissions[:, :, index % num_emissions_conjugates] - - prediction = decoder.decode(emissions=emission, mask=mask) - predictions.append(pack_sequence([torch.tensor(p) for p in prediction], device=emissions.device)) - - return PackedSequence( - torch.stack([prediction.data for prediction in predictions], dim=1), - batch_sizes=predictions[0].batch_sizes, - sorted_indices=predictions[0].sorted_indices, - unsorted_indices=predictions[0].unsorted_indices, - ) - - @given( device=devices(), - token_sizes=token_size_lists(), - num_conjugate=conjugate_sizes(), - num_tags=tag_sizes(), + token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), + num_conjugate=sizes(NUM_CONJUGATES), + num_tags=sizes(NUM_TAGS), ) -def test_crf_decoder_fit(device, token_sizes, num_conjugate, num_tags): +def test_crf_packed_fit(device, token_sizes, num_conjugate, num_tags): emissions = pack_sequence([ torch.randn((token_size, num_conjugate, num_tags), device=device, requires_grad=True) for token_size in token_sizes @@ -92,8 +25,8 @@ def test_crf_decoder_fit(device, token_sizes, num_conjugate, num_tags): for token_size in token_sizes ], device=device) - actual_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate) - expected_decoder = ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate) + actual_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device) + expected_decoder = ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device) expected_decoder.reset_parameters_with_(decoder=actual_decoder) actual = actual_decoder.fit(emissions=emissions, tags=tags) @@ -105,21 +38,80 @@ def test_crf_decoder_fit(device, token_sizes, num_conjugate, num_tags): @given( device=devices(), - token_sizes=token_size_lists(), - 
num_conjugate=conjugate_sizes(), - num_tags=tag_sizes(), + token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), + num_conjugate=sizes(NUM_CONJUGATES), + num_tags=sizes(NUM_TAGS), ) -def test_crf_decoder_decode(device, token_sizes, num_conjugate, num_tags): +def test_crf_packed_decode(device, token_sizes, num_conjugate, num_tags): emissions = pack_sequence([ torch.randn((token_size, num_conjugate, num_tags), device=device, requires_grad=True) for token_size in token_sizes ], device=device) - actual_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate) - expected_decoder = ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate) + actual_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device) + expected_decoder = ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device) expected_decoder.reset_parameters_with_(decoder=actual_decoder) expected = expected_decoder.decode(emissions=emissions) actual = actual_decoder.decode(emissions=emissions) - assert_packed_equal(actual=actual, expected=expected) + assert_packed_sequence_equal(actual=actual, expected=expected) + + +@given( + device=devices(), + token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), + num_conjugate=sizes(NUM_CONJUGATES), + num_tags=sizes(NUM_TAGS), +) +def test_crf_catted_fit(device, token_sizes, num_conjugate, num_tags): + emissions = [ + torch.randn((token_size, num_conjugate, num_tags), device=device, requires_grad=True) + for token_size in token_sizes + ] + tags = [ + torch.randint(0, num_tags, (token_size, num_conjugate), device=device) + for token_size in token_sizes + ] + + catted_emissions = cat_sequence(emissions, device=device) + packed_emissions = pack_sequence(emissions, device=device) + + catted_tags = cat_sequence(tags, device=device) + packed_tags = pack_sequence(tags, device=device) + + actual_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device) + expected_decoder = 
ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device) + expected_decoder.reset_parameters_with_(decoder=actual_decoder) + + actual = actual_decoder.fit(emissions=catted_emissions, tags=catted_tags) + expected = expected_decoder.fit(emissions=packed_emissions, tags=packed_tags) + + assert_close(actual=actual, expected=expected) + assert_grad_close(actual=actual, expected=expected, inputs=tuple(emissions)) + + +@given( + device=devices(), + token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), + num_conjugate=sizes(NUM_CONJUGATES), + num_tags=sizes(NUM_TAGS), +) +def test_crf_catted_decode(device, token_sizes, num_conjugate, num_tags): + emissions = [ + torch.randn((token_size, num_conjugate, num_tags), device=device, requires_grad=True) + for token_size in token_sizes + ] + + catted_emissions = cat_sequence(emissions, device=device) + packed_emissions = pack_sequence(emissions, device=device) + + actual_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device) + expected_decoder = ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device) + expected_decoder.reset_parameters_with_(decoder=actual_decoder) + + expected = expected_decoder.decode(emissions=packed_emissions) + actual = actual_decoder.decode(emissions=catted_emissions) + actual = pack_catted_sequence(*actual, device=device) + + assert_packed_sequence_equal(actual=actual, expected=expected) diff --git a/tests/test_functional.py b/tests/test_functional.py index 4924882..c12b74a 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1,14 +1,14 @@ import torch from hypothesis import given, strategies as st -from tests.strategies import devices, token_size_lists, TINY_TOKEN_SIZE, TINY_BATCH_SIZE +from tests.strategies import devices, sizes, TINY_TOKEN_SIZE, TINY_BATCH_SIZE from tests.utils import assert_close, assert_grad_close from torchlatent.functional import logaddexp, logsumexp @given( device=devices(), - 
token_sizes=token_size_lists(max_token_size=TINY_TOKEN_SIZE, max_batch_size=TINY_BATCH_SIZE) + token_sizes=sizes(TINY_BATCH_SIZE, TINY_TOKEN_SIZE) ) def test_logaddexp(device, token_sizes): x = torch.randn(token_sizes, device=device, requires_grad=True) @@ -24,7 +24,7 @@ def test_logaddexp(device, token_sizes): @given( data=st.data(), device=devices(), - token_sizes=token_size_lists(max_token_size=TINY_TOKEN_SIZE, max_batch_size=TINY_BATCH_SIZE) + token_sizes=sizes(TINY_BATCH_SIZE, TINY_TOKEN_SIZE) ) def test_logsumexp(data, device, token_sizes): tensor = torch.randn(token_sizes, device=device, requires_grad=True) diff --git a/tests/third_party.py b/tests/third_party.py new file mode 100644 index 0000000..6172f91 --- /dev/null +++ b/tests/third_party.py @@ -0,0 +1,83 @@ +import torch +import torchcrf +from torch import Tensor, nn +from torch.nn.utils.rnn import PackedSequence +from torch.types import Device +from torchrua import pad_catted_indices, pad_packed_sequence, pack_sequence + +from torchlatent.crf import CrfDecoder + + +@torch.no_grad() +def token_sizes_to_mask(sizes: Tensor, batch_first: bool, device: Device = None) -> Tensor: + if device is None: + device = sizes.device + + size, ptr = pad_catted_indices(sizes, batch_first=batch_first, device=device) + mask = torch.zeros(size, device=device, dtype=torch.bool) + mask[ptr] = True + return mask + + +class ThirdPartyCrfDecoder(nn.Module): + def __init__(self, num_tags: int, num_conjugates: int) -> None: + super(ThirdPartyCrfDecoder, self).__init__() + self.num_tags = num_tags + self.num_conjugates = num_conjugates + + self.decoders = nn.ModuleList([ + torchcrf.CRF(num_tags=num_tags, batch_first=False) + for _ in range(num_conjugates) + ]) + + @torch.no_grad() + def reset_parameters_with_(self, decoder: CrfDecoder) -> None: + assert self.num_tags == decoder.num_tags + assert self.num_conjugates == decoder.num_conjugates + + for index in range(self.num_conjugates): + self.decoders[index].transitions.data[::] = 
decoder.transitions[:, index, :, :] + self.decoders[index].start_transitions.data[::] = decoder.head_transitions[:, index, :] + self.decoders[index].end_transitions.data[::] = decoder.last_transitions[:, index, :] + + def fit(self, emissions: PackedSequence, tags: PackedSequence, **kwargs) -> Tensor: + num_emissions_conjugates = emissions.data.size()[1] + num_decoders_conjugates = self.num_conjugates + num_conjugates = max(num_emissions_conjugates, num_decoders_conjugates) + + emissions, token_sizes = pad_packed_sequence(emissions, batch_first=False) + tags, _ = pad_packed_sequence(tags, batch_first=False) + mask = token_sizes_to_mask(sizes=token_sizes, batch_first=False) + + log_probs = [] + for index in range(num_conjugates): + decoder = self.decoders[index % num_decoders_conjugates] + emission = emissions[:, :, index % num_emissions_conjugates] + tag = tags[:, :, index % num_emissions_conjugates] + + log_probs.append(decoder(emissions=emission, tags=tag, mask=mask, reduction='none')) + + return torch.stack(log_probs, dim=-1) + + def decode(self, emissions: PackedSequence, **kwargs) -> PackedSequence: + num_emissions_conjugates = emissions.data.size()[1] + num_decoders_conjugates = self.num_conjugates + num_conjugates = max(num_emissions_conjugates, num_decoders_conjugates) + + emissions, token_sizes = pad_packed_sequence(emissions, batch_first=False) + mask = token_sizes_to_mask(sizes=token_sizes, batch_first=False) + + predictions = [] + for index in range(num_conjugates): + decoder = self.decoders[index % num_decoders_conjugates] + emission = emissions[:, :, index % num_emissions_conjugates] + + prediction = decoder.decode(emissions=emission, mask=mask) + predictions.append(pack_sequence([torch.tensor(p) for p in prediction], device=emissions.device)) + + return PackedSequence( + torch.stack([prediction.data for prediction in predictions], dim=1), + batch_sizes=predictions[0].batch_sizes, + sorted_indices=predictions[0].sorted_indices, + 
unsorted_indices=predictions[0].unsorted_indices, + ) diff --git a/tests/utils.py b/tests/utils.py index a7958f5..3040db4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,34 +1,67 @@ -from typing import Tuple +from typing import List, Tuple, Union import torch from torch import Tensor from torch.nn.utils.rnn import PackedSequence -from torch.testing import assert_close, assert_equal +from torch.testing import assert_close +from torchrua.catting import CattedSequence __all__ = [ - 'assert_close', - 'assert_equal', - 'assert_grad_close', - 'assert_packed_close', - 'assert_packed_equal', + 'assert_equal', 'assert_close', 'assert_grad_close', + 'assert_catted_sequence_equal', 'assert_catted_sequence_close', + 'assert_packed_sequence_equal', 'assert_packed_sequence_close', ] -def assert_grad_close(actual: Tensor, expected: Tensor, inputs: Tuple[Tensor, ...]) -> None: - grad = torch.randn_like(actual) +def assert_equal(actual: Tensor, expected: Tensor) -> None: + assert torch.equal(actual, expected) - actual_grads = torch.autograd.grad(actual, inputs, grad) - expected_grads = torch.autograd.grad(expected, inputs, grad) + +def assert_grad_close( + actual: Tensor, expected: Tensor, + inputs: Union[Tensor, List[Tensor], Tuple[Tensor, ...]], + allow_unused: bool = False, + check_device: bool = True, check_dtype: bool = True, check_stride: bool = True) -> None: + kwargs = dict(check_device=check_device, check_dtype=check_dtype, check_stride=check_stride) + + grad = torch.rand_like(actual) + + actual_grads = torch.autograd.grad( + actual, inputs, grad, + create_graph=False, + allow_unused=allow_unused, + ) + + expected_grads = torch.autograd.grad( + expected, inputs, grad, + create_graph=False, + allow_unused=allow_unused, + ) for actual_grad, expected_grad in zip(actual_grads, expected_grads): - if actual_grad is None: - assert expected_grad is None - else: - assert_close(actual=actual_grad, expected=expected_grad) + assert_close(actual=actual_grad, 
expected=expected_grad, **kwargs) + + +def assert_catted_sequence_close( + actual: CattedSequence, expected: CattedSequence, + check_device: bool = True, check_dtype: bool = True, check_stride: bool = True) -> None: + kwargs = dict(check_device=check_device, check_dtype=check_dtype, check_stride=check_stride) + assert_close(actual=actual.data, expected=expected.data, **kwargs) + assert_equal(actual=actual.token_sizes, expected=expected.token_sizes) -def assert_packed_equal(actual: PackedSequence, expected: PackedSequence) -> None: + +def assert_catted_sequence_equal(actual: CattedSequence, expected: CattedSequence) -> None: assert_equal(actual=actual.data, expected=expected.data) + assert_equal(actual=actual.token_sizes, expected=expected.token_sizes) + + +def assert_packed_sequence_close( + actual: PackedSequence, expected: PackedSequence, + check_device: bool = True, check_dtype: bool = True, check_stride: bool = True) -> None: + kwargs = dict(check_device=check_device, check_dtype=check_dtype, check_stride=check_stride) + + assert_close(actual=actual.data, expected=expected.data, **kwargs) assert_equal(actual=actual.batch_sizes, expected=expected.batch_sizes) if actual.sorted_indices is None: @@ -42,8 +75,8 @@ def assert_packed_equal(actual: PackedSequence, expected: PackedSequence) -> Non assert_equal(actual=actual.unsorted_indices, expected=expected.unsorted_indices) -def assert_packed_close(actual: PackedSequence, expected: PackedSequence) -> None: - assert_close(actual=actual.data, expected=expected.data) +def assert_packed_sequence_equal(actual: PackedSequence, expected: PackedSequence) -> None: + assert_equal(actual=actual.data, expected=expected.data) assert_equal(actual=actual.batch_sizes, expected=expected.batch_sizes) if actual.sorted_indices is None: diff --git a/torchlatent/crf.py b/torchlatent/crf.py deleted file mode 100644 index fe8bfc2..0000000 --- a/torchlatent/crf.py +++ /dev/null @@ -1,238 +0,0 @@ -from abc import ABCMeta -from typing import 
Optional, Type, Tuple - -import torch -from torch import Tensor -from torch import nn, autograd -from torch.distributions.utils import lazy_property -from torch.nn import init -from torch.nn.utils.rnn import PackedSequence -from torchrua import TreeReduceIndices, tree_reduce_packed_indices, batch_sizes_to_ptr -from torchrua import select_head, select_last, roll_packed_sequence - -from torchlatent.semiring import Semiring, Log, Max - -__all__ = [ - 'compute_scores', - 'compute_partitions', - 'CrfDistribution', - 'CrfDecoderABC', 'CrfDecoder', -] - - -def compute_scores(semiring: Type[Semiring]): - def _compute_scores( - emissions: PackedSequence, tags: PackedSequence, - transitions: Tensor, head_transitions: Tensor, tail_transitions: Tensor) -> Tensor: - device = transitions.device - - emission_scores = emissions.data.gather(dim=-1, index=tags.data[..., None])[..., 0] # [t, c] - - h = emissions.batch_sizes[0].item() - t = torch.arange(transitions.size()[0], device=device) # [t] - c = torch.arange(transitions.size()[1], device=device) # [c] - - x, y = roll_packed_sequence(tags, shifts=1).data, tags.data # [t, c] - head = select_head(tags, unsort=False) # [h, c] - tail = select_last(tags, unsort=False) # [h, c] - - transition_scores = transitions[t[:, None], c[None, :], x, y] # [t, c] - transition_head_scores = head_transitions[t[:h, None], c[None, :], head] # [h, c] - transition_tail_scores = tail_transitions[t[:h, None], c[None, :], tail] # [h, c] - - transition_scores[:h] = transition_head_scores # [h, c] - - _, batch_ptr, _ = batch_sizes_to_ptr(batch_sizes=emissions.batch_sizes) - scores = semiring.mul( - semiring.scatter_mul(semiring.mul(emission_scores, transition_scores), index=batch_ptr), - transition_tail_scores, - ) - - if emissions.unsorted_indices is not None: - scores = scores[emissions.unsorted_indices] - - return scores - - return _compute_scores - - -def compute_partitions(semiring: Type[Semiring]): - def _compute_partitions( - emissions: 
PackedSequence, indices: TreeReduceIndices, - transitions: Tensor, head_transitions: Tensor, tail_transitions: Tensor, eye: Tensor) -> Tensor: - h = emissions.batch_sizes[0].item() - t = torch.arange(transitions.size()[0], device=transitions.device) # [t] - c = torch.arange(transitions.size()[1], device=transitions.device) # [c] - - scores = semiring.mul(transitions, emissions.data[..., None, :]) # [t, c, n, n] - scores[:h] = eye[None, None, :, :] - scores = semiring.reduce(tensor=scores, indices=indices) - - emission_head_scores = emissions.data[:h, :, None, :] - transition_head_scores = head_transitions[t[:h, None], c[None, :], None, :] - transition_tail_scores = tail_transitions[t[:h, None], c[None, :], :, None] - - scores = semiring.bmm( - semiring.bmm(semiring.mul(transition_head_scores, emission_head_scores), scores), - transition_tail_scores, - )[..., 0, 0] - - if emissions.unsorted_indices is not None: - scores = scores[emissions.unsorted_indices] - return scores - - return _compute_partitions - - -class CrfDistribution(object): - def __init__(self, emissions: PackedSequence, indices: TreeReduceIndices, - transitions: Tensor, head_transitions: Tensor, tail_transitions: Tensor) -> None: - super(CrfDistribution, self).__init__() - self.emissions = emissions - self.indices = indices - - self.transitions = transitions - self.head_transitions = head_transitions - self.tail_transitions = tail_transitions - - def semiring_scores(self, semiring: Type[Semiring], tags: PackedSequence) -> Tensor: - return compute_scores(semiring=semiring)( - emissions=self.emissions, tags=tags, - transitions=self.transitions, - head_transitions=self.head_transitions, - tail_transitions=self.tail_transitions, - ) - - def semiring_partitions(self, semiring: Type[Semiring]) -> Tensor: - return compute_partitions(semiring=semiring)( - emissions=self.emissions, indices=self.indices, - transitions=self.transitions, - head_transitions=self.head_transitions, - 
tail_transitions=self.tail_transitions, - eye=semiring.eye_like(self.transitions), - ) - - def log_prob(self, tags: PackedSequence) -> Tensor: - return self.log_scores(tags=tags) - self.log_partitions - - def log_scores(self, tags: PackedSequence) -> Tensor: - return self.semiring_scores(semiring=Log, tags=tags) - - @lazy_property - def log_partitions(self) -> Tensor: - return self.semiring_partitions(semiring=Log) - - @lazy_property - def marginals(self) -> Tensor: - log_partitions = self.log_partitions - grad, = autograd.grad( - log_partitions, self.emissions.data, torch.ones_like(log_partitions), - create_graph=True, only_inputs=True, allow_unused=False, - ) - return grad - - @lazy_property - def argmax(self) -> PackedSequence: - max_partitions = self.semiring_partitions(semiring=Max) - - grad, = torch.autograd.grad( - max_partitions, self.emissions.data, torch.ones_like(max_partitions), - retain_graph=False, create_graph=False, allow_unused=False, - ) - return PackedSequence( - data=grad.argmax(dim=-1), - batch_sizes=self.emissions.batch_sizes, - sorted_indices=self.emissions.sorted_indices, - unsorted_indices=self.emissions.unsorted_indices, - ) - - -class CrfDecoderABC(nn.Module, metaclass=ABCMeta): - def __init__(self, num_tags: int, num_conjugates: int): - super(CrfDecoderABC, self).__init__() - - self.num_tags = num_tags - self.num_conjugates = num_conjugates - - def reset_parameters(self) -> None: - raise NotImplementedError - - def extra_repr(self) -> str: - return ', '.join([ - f'num_tags={self.num_tags}', - f'num_conjugates={self.num_conjugates}', - ]) - - def compile_indices(self, emissions: PackedSequence, tags: Optional[PackedSequence] = None, - indices: Optional[TreeReduceIndices] = None, **kwargs): - assert emissions.data.dim() == 3, f'{emissions.data.dim()} != {3}' - if tags is not None: - assert tags.data.dim() == 2, f'{tags.data.dim()} != {2}' - - if indices is None: - batch_sizes = emissions.batch_sizes.to(device=emissions.data.device) - 
indices = tree_reduce_packed_indices(batch_sizes=batch_sizes) - - return indices - - def obtain_parameters(self, **kwargs) -> Tuple[Tensor, Tensor, Tensor]: - return self.transitions, self.head_transitions, self.tail_transitions - - def forward(self, emissions: PackedSequence, tags: Optional[PackedSequence] = None, - indices: Optional[TreeReduceIndices] = None, **kwargs): - indices = self.compile_indices(emissions=emissions, tags=tags, indices=indices) - transitions, head_transitions, tail_transitions = self.obtain_parameters( - emissions=emissions, tags=tags, indices=indices, - ) - - dist = CrfDistribution( - emissions=emissions, indices=indices, - transitions=transitions, - head_transitions=head_transitions, - tail_transitions=tail_transitions, - ) - - return dist, tags - - def fit(self, emissions: PackedSequence, tags: PackedSequence, - indices: Optional[TreeReduceIndices] = None, **kwargs) -> Tensor: - dist, tags = self(emissions=emissions, tags=tags, instr=indices, **kwargs) - - return dist.log_prob(tags=tags) - - def decode(self, emissions: PackedSequence, - indices: Optional[TreeReduceIndices] = None, **kwargs) -> PackedSequence: - dist, _ = self(emissions=emissions, tags=None, instr=indices, **kwargs) - return dist.argmax - - def marginals(self, emissions: PackedSequence, - indices: Optional[TreeReduceIndices] = None, **kwargs) -> Tensor: - dist, _ = self(emissions=emissions, tags=None, instr=indices, **kwargs) - return dist.marginals - - -class CrfDecoder(CrfDecoderABC): - def __init__(self, num_tags: int, num_conjugates: int = 1) -> None: - super(CrfDecoder, self).__init__(num_tags=num_tags, num_conjugates=num_conjugates) - - self.transitions = nn.Parameter( - torch.empty((1, self.num_conjugates, self.num_tags, self.num_tags)), - requires_grad=True, - ) - self.head_transitions = nn.Parameter( - torch.empty((1, self.num_conjugates, self.num_tags)), - requires_grad=True, - ) - self.tail_transitions = nn.Parameter( - torch.empty((1, self.num_conjugates, 
self.num_tags)), - requires_grad=True, - ) - - self.reset_parameters() - - @torch.no_grad() - def reset_parameters(self) -> None: - bound = 0.01 - init.uniform_(self.transitions, -bound, +bound) - init.uniform_(self.head_transitions, -bound, +bound) - init.uniform_(self.tail_transitions, -bound, +bound) diff --git a/torchlatent/crf/__init__.py b/torchlatent/crf/__init__.py new file mode 100644 index 0000000..a126167 --- /dev/null +++ b/torchlatent/crf/__init__.py @@ -0,0 +1,123 @@ +from abc import ABCMeta +from typing import Optional, Tuple, Union + +import torch +from torch import Tensor +from torch import nn +from torch.nn import init +from torchrua import ReductionIndices, PackedSequence, CattedSequence +from torchrua import reduce_packed_indices, reduce_catted_indices + +from torchlatent.crf.catting import CattedCrfDistribution +from torchlatent.crf.packing import PackedCrfDistribution + +__all__ = [ + 'CrfDecoderABC', 'CrfDecoder', + 'PackedCrfDistribution', + 'CattedCrfDistribution', + 'Sequence', +] + +Sequence = Union[ + PackedSequence, + CattedSequence, +] + + +class CrfDecoderABC(nn.Module, metaclass=ABCMeta): + def __init__(self, num_tags: int, num_conjugates: int) -> None: + super(CrfDecoderABC, self).__init__() + + self.num_tags = num_tags + self.num_conjugates = num_conjugates + + def reset_parameters(self) -> None: + raise NotImplementedError + + def extra_repr(self) -> str: + return ', '.join([ + f'num_tags={self.num_tags}', + f'num_conjugates={self.num_conjugates}', + ]) + + @staticmethod + def compile_indices(emissions: Sequence, + tags: Optional[Sequence] = None, + indices: Optional[ReductionIndices] = None, **kwargs): + assert emissions.data.dim() == 3, f'{emissions.data.dim()} != {3}' + if tags is not None: + assert tags.data.dim() == 2, f'{tags.data.dim()} != {2}' + + if indices is None: + if isinstance(emissions, PackedSequence): + batch_sizes = emissions.batch_sizes.to(device=emissions.data.device) + return 
reduce_packed_indices(batch_sizes=batch_sizes) + + if isinstance(emissions, CattedSequence): + token_sizes = emissions.token_sizes.to(device=emissions.data.device) + return reduce_catted_indices(token_sizes=token_sizes) + + return indices + + def obtain_parameters(self, **kwargs) -> Tuple[Tensor, Tensor, Tensor]: + return self.transitions, self.head_transitions, self.last_transitions + + def forward(self, emissions: Sequence, tags: Optional[Sequence] = None, + indices: Optional[ReductionIndices] = None, **kwargs): + indices = self.compile_indices(emissions=emissions, tags=tags, indices=indices) + transitions, head_transitions, last_transitions = self.obtain_parameters( + emissions=emissions, tags=tags, indices=indices, + ) + + if isinstance(emissions, PackedSequence): + dist = PackedCrfDistribution( + emissions=emissions, indices=indices, + transitions=transitions, + head_transitions=head_transitions, + last_transitions=last_transitions, + ) + return dist, tags + + if isinstance(emissions, CattedSequence): + dist = CattedCrfDistribution( + emissions=emissions, indices=indices, + transitions=transitions, + head_transitions=head_transitions, + last_transitions=last_transitions, + ) + return dist, tags + + raise TypeError(f'{type(emissions)} is not supported.') + + def fit(self, emissions: Sequence, tags: Sequence, + indices: Optional[ReductionIndices] = None, **kwargs) -> Tensor: + dist, tags = self(emissions=emissions, tags=tags, instr=indices, **kwargs) + + return dist.log_prob(tags=tags) + + def decode(self, emissions: Sequence, + indices: Optional[ReductionIndices] = None, **kwargs) -> Sequence: + dist, _ = self(emissions=emissions, tags=None, instr=indices, **kwargs) + return dist.argmax + + def marginals(self, emissions: Sequence, + indices: Optional[ReductionIndices] = None, **kwargs) -> Tensor: + dist, _ = self(emissions=emissions, tags=None, instr=indices, **kwargs) + return dist.marginals + + +class CrfDecoder(CrfDecoderABC): + def __init__(self, num_tags: 
int, num_conjugates: int = 1) -> None: + super(CrfDecoder, self).__init__(num_tags=num_tags, num_conjugates=num_conjugates) + + self.transitions = nn.Parameter(torch.empty((1, self.num_conjugates, self.num_tags, self.num_tags))) + self.head_transitions = nn.Parameter(torch.empty((1, self.num_conjugates, self.num_tags))) + self.last_transitions = nn.Parameter(torch.empty((1, self.num_conjugates, self.num_tags))) + + self.reset_parameters() + + @torch.no_grad() + def reset_parameters(self, bound: float = 0.01) -> None: + init.uniform_(self.transitions, -bound, +bound) + init.uniform_(self.head_transitions, -bound, +bound) + init.uniform_(self.last_transitions, -bound, +bound) diff --git a/torchlatent/crf/catting.py b/torchlatent/crf/catting.py new file mode 100644 index 0000000..474f704 --- /dev/null +++ b/torchlatent/crf/catting.py @@ -0,0 +1,137 @@ +from typing import Type + +import torch +from torch import Tensor, autograd +from torch.distributions.utils import lazy_property +from torchrua import CattedSequence +from torchrua import ReductionIndices, head_catted_indices +from torchrua import roll_catted_sequence, head_catted_sequence, last_catted_sequence + +from torchlatent.semiring import Semiring, Log, Max + +__all__ = [ + 'compute_catted_sequence_scores', + 'compute_catted_sequence_partitions', + 'CattedCrfDistribution', +] + + +def compute_catted_sequence_scores(semiring: Type[Semiring]): + def _compute_catted_sequence_scores( + emissions: CattedSequence, tags: CattedSequence, + transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor) -> Tensor: + device = transitions.device + + emission_scores = emissions.data.gather(dim=-1, index=tags.data[..., None])[..., 0] # [t, c] + + h = emissions.token_sizes.size()[0] + t = torch.arange(transitions.size()[0], device=device) # [t] + c = torch.arange(transitions.size()[1], device=device) # [c] + + x, y = roll_catted_sequence(tags, shifts=1).data, tags.data # [t, c] + head = head_catted_sequence(tags) # 
[h, c] + last = last_catted_sequence(tags) # [h, c] + + transition_scores = transitions[t[:, None], c[None, :], x, y] # [t, c] + transition_head_scores = head_transitions[t[:h, None], c[None, :], head] # [h, c] + transition_last_scores = last_transitions[t[:h, None], c[None, :], last] # [h, c] + + head_indices = head_catted_indices(emissions.token_sizes) + transition_scores[head_indices] = transition_head_scores # [h, c] + + batch_ptr = torch.repeat_interleave(emissions.token_sizes) + scores = semiring.mul(emission_scores, transition_scores) + scores = semiring.scatter_mul(scores, index=batch_ptr) + + scores = semiring.mul(scores, transition_last_scores) + + return scores + + return _compute_catted_sequence_scores + + +def compute_catted_sequence_partitions(semiring: Type[Semiring]): + def _compute_catted_sequence_partitions( + emissions: CattedSequence, indices: ReductionIndices, + transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor, eye: Tensor) -> Tensor: + h = emissions.token_sizes.size()[0] + t = torch.arange(transitions.size()[0], device=transitions.device) # [t] + c = torch.arange(transitions.size()[1], device=transitions.device) # [c] + head_indices = head_catted_indices(emissions.token_sizes) + + emission_scores = semiring.mul(transitions, emissions.data[..., None, :]) # [t, c, n, n] + emission_scores[head_indices] = eye[None, None, :, :] + emission_scores = semiring.reduce(tensor=emission_scores, indices=indices) + + emission_head_scores = emissions.data[head_indices, :, None, :] + transition_head_scores = head_transitions[t[:h, None], c[None, :], None, :] + transition_last_scores = last_transitions[t[:h, None], c[None, :], :, None] + + scores = semiring.mul(transition_head_scores, emission_head_scores) + scores = semiring.bmm(scores, emission_scores) + scores = semiring.bmm(scores, transition_last_scores)[..., 0, 0] + + return scores + + return _compute_catted_sequence_partitions + + +class CattedCrfDistribution(object): + def 
__init__(self, emissions: CattedSequence, indices: ReductionIndices, + transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor) -> None: + super(CattedCrfDistribution, self).__init__() + self.emissions = emissions + self.indices = indices + + self.transitions = transitions + self.head_transitions = head_transitions + self.last_transitions = last_transitions + + def semiring_scores(self, semiring: Type[Semiring], tags: CattedSequence) -> Tensor: + return compute_catted_sequence_scores(semiring=semiring)( + emissions=self.emissions, tags=tags, + transitions=self.transitions, + head_transitions=self.head_transitions, + last_transitions=self.last_transitions, + ) + + def semiring_partitions(self, semiring: Type[Semiring]) -> Tensor: + return compute_catted_sequence_partitions(semiring=semiring)( + emissions=self.emissions, indices=self.indices, + transitions=self.transitions, + head_transitions=self.head_transitions, + last_transitions=self.last_transitions, + eye=semiring.eye_like(self.transitions), + ) + + def log_prob(self, tags: CattedSequence) -> Tensor: + return self.log_scores(tags=tags) - self.log_partitions + + def log_scores(self, tags: CattedSequence) -> Tensor: + return self.semiring_scores(semiring=Log, tags=tags) + + @lazy_property + def log_partitions(self) -> Tensor: + return self.semiring_partitions(semiring=Log) + + @lazy_property + def marginals(self) -> Tensor: + log_partitions = self.log_partitions + grad, = autograd.grad( + log_partitions, self.emissions.data, torch.ones_like(log_partitions), + create_graph=True, only_inputs=True, allow_unused=False, + ) + return grad + + @lazy_property + def argmax(self) -> CattedSequence: + max_partitions = self.semiring_partitions(semiring=Max) + + grad, = torch.autograd.grad( + max_partitions, self.emissions.data, torch.ones_like(max_partitions), + retain_graph=False, create_graph=False, allow_unused=False, + ) + return CattedSequence( + data=grad.argmax(dim=-1), + 
token_sizes=self.emissions.token_sizes, + ) diff --git a/torchlatent/crf/packing.py b/torchlatent/crf/packing.py new file mode 100644 index 0000000..ec22c38 --- /dev/null +++ b/torchlatent/crf/packing.py @@ -0,0 +1,143 @@ +from typing import Type + +import torch +from torch import Tensor, autograd +from torch.distributions.utils import lazy_property +from torch.nn.utils.rnn import PackedSequence +from torchrua import head_packed_indices, ReductionIndices +from torchrua import roll_packed_sequence, head_packed_sequence, last_packed_sequence, major_sizes_to_ptr + +from torchlatent.semiring import Semiring, Log, Max + +__all__ = [ + 'compute_packed_sequence_scores', + 'compute_packed_sequence_partitions', + 'PackedCrfDistribution', +] + + +def compute_packed_sequence_scores(semiring: Type[Semiring]): + def _compute_packed_sequence_scores( + emissions: PackedSequence, tags: PackedSequence, + transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor) -> Tensor: + device = transitions.device + + emission_scores = emissions.data.gather(dim=-1, index=tags.data[..., None])[..., 0] # [t, c] + + h = emissions.batch_sizes[0].item() + t = torch.arange(transitions.size()[0], device=device) # [t] + c = torch.arange(transitions.size()[1], device=device) # [c] + + x, y = roll_packed_sequence(tags, shifts=1).data, tags.data # [t, c] + head = head_packed_sequence(tags, unsort=False) # [h, c] + last = last_packed_sequence(tags, unsort=False) # [h, c] + + transition_scores = transitions[t[:, None], c[None, :], x, y] # [t, c] + transition_head_scores = head_transitions[t[:h, None], c[None, :], head] # [h, c] + transition_last_scores = last_transitions[t[:h, None], c[None, :], last] # [h, c] + + indices = head_packed_indices(tags.batch_sizes) + transition_scores[indices] = transition_head_scores # [h, c] + + batch_ptr, _ = major_sizes_to_ptr(sizes=emissions.batch_sizes) + scores = semiring.mul(emission_scores, transition_scores) + scores = semiring.scatter_mul(scores, 
index=batch_ptr) + + scores = semiring.mul(scores, transition_last_scores) + + if emissions.unsorted_indices is not None: + scores = scores[emissions.unsorted_indices] + + return scores + + return _compute_packed_sequence_scores + + +def compute_packed_sequence_partitions(semiring: Type[Semiring]): + def _compute_packed_sequence_partitions( + emissions: PackedSequence, indices: ReductionIndices, + transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor, eye: Tensor) -> Tensor: + h = emissions.batch_sizes[0].item() + t = torch.arange(transitions.size()[0], device=transitions.device) # [t] + c = torch.arange(transitions.size()[1], device=transitions.device) # [c] + + emission_scores = semiring.mul(transitions, emissions.data[..., None, :]) # [t, c, n, n] + emission_scores[:h] = eye[None, None, :, :] + emission_scores = semiring.reduce(tensor=emission_scores, indices=indices) + + emission_head_scores = emissions.data[:h, :, None, :] + transition_head_scores = head_transitions[t[:h, None], c[None, :], None, :] + transition_last_scores = last_transitions[t[:h, None], c[None, :], :, None] + + scores = semiring.mul(transition_head_scores, emission_head_scores) + scores = semiring.bmm(scores, emission_scores) + scores = semiring.bmm(scores, transition_last_scores)[..., 0, 0] + + if emissions.unsorted_indices is not None: + scores = scores[emissions.unsorted_indices] + return scores + + return _compute_packed_sequence_partitions + + +class PackedCrfDistribution(object): + def __init__(self, emissions: PackedSequence, indices: ReductionIndices, + transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor) -> None: + super(PackedCrfDistribution, self).__init__() + self.emissions = emissions + self.indices = indices + + self.transitions = transitions + self.head_transitions = head_transitions + self.last_transitions = last_transitions + + def semiring_scores(self, semiring: Type[Semiring], tags: PackedSequence) -> Tensor: + return 
compute_packed_sequence_scores(semiring=semiring)( + emissions=self.emissions, tags=tags, + transitions=self.transitions, + head_transitions=self.head_transitions, + last_transitions=self.last_transitions, + ) + + def semiring_partitions(self, semiring: Type[Semiring]) -> Tensor: + return compute_packed_sequence_partitions(semiring=semiring)( + emissions=self.emissions, indices=self.indices, + transitions=self.transitions, + head_transitions=self.head_transitions, + last_transitions=self.last_transitions, + eye=semiring.eye_like(self.transitions), + ) + + def log_prob(self, tags: PackedSequence) -> Tensor: + return self.log_scores(tags=tags) - self.log_partitions + + def log_scores(self, tags: PackedSequence) -> Tensor: + return self.semiring_scores(semiring=Log, tags=tags) + + @lazy_property + def log_partitions(self) -> Tensor: + return self.semiring_partitions(semiring=Log) + + @lazy_property + def marginals(self) -> Tensor: + log_partitions = self.log_partitions + grad, = autograd.grad( + log_partitions, self.emissions.data, torch.ones_like(log_partitions), + create_graph=True, only_inputs=True, allow_unused=False, + ) + return grad + + @lazy_property + def argmax(self) -> PackedSequence: + max_partitions = self.semiring_partitions(semiring=Max) + + grad, = torch.autograd.grad( + max_partitions, self.emissions.data, torch.ones_like(max_partitions), + retain_graph=False, create_graph=False, allow_unused=False, + ) + return PackedSequence( + data=grad.argmax(dim=-1), + batch_sizes=self.emissions.batch_sizes, + sorted_indices=self.emissions.sorted_indices, + unsorted_indices=self.emissions.unsorted_indices, + ) diff --git a/torchlatent/semiring.py b/torchlatent/semiring.py index 5c0940a..1f6e44a 100644 --- a/torchlatent/semiring.py +++ b/torchlatent/semiring.py @@ -1,8 +1,7 @@ import torch from torch import Tensor -from torch.types import Device from torchrua.scatter import scatter_add, scatter_max, scatter_mul, scatter_logsumexp -from torchrua.tree_reduction 
import tree_reduce_sequence, TreeReduceIndices +from torchrua.reduction import reduce_sequence, ReductionIndices from torchlatent.functional import logsumexp, logaddexp @@ -17,16 +16,12 @@ class Semiring(object): one: float @classmethod - def eye_like(cls, tensor: Tensor, dtype: torch.dtype = None, device: Device = None) -> Tensor: - if dtype is None: - dtype = tensor.dtype - if device is None: - device = tensor.device - + def eye_like(cls, tensor: Tensor) -> Tensor: *_, n = tensor.size() - eye = torch.full((n, n), fill_value=cls.zero, dtype=dtype, device=device) - idx = torch.arange(n, dtype=torch.long, device=device) - eye[idx, idx] = cls.one + + eye = torch.full((n, n), fill_value=cls.zero, dtype=tensor.dtype, device=tensor.device) + index = torch.arange(n, dtype=torch.long, device=tensor.device) + eye[index, index] = cls.one return eye @classmethod @@ -58,8 +53,8 @@ def bmm(cls, x: Tensor, y: Tensor) -> Tensor: return cls.sum(cls.mul(x[..., :, :, None], y[..., None, :, :]), dim=-2, keepdim=False) @classmethod - def reduce(cls, tensor: Tensor, indices: TreeReduceIndices) -> Tensor: - return tree_reduce_sequence(cls.bmm)(tensor=tensor, indices=indices) + def reduce(cls, tensor: Tensor, indices: ReductionIndices) -> Tensor: + return reduce_sequence(cls.bmm)(tensor=tensor, indices=indices) class Std(Semiring):