From 8c71964ae3b7b700b1f8b9af947ae08d30ffe364 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Thu, 30 Sep 2021 21:05:50 +0900 Subject: [PATCH 01/18] Chore: Add cliff.toml configuration for git-cliff --- cliff.toml | 52 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 cliff.toml diff --git a/cliff.toml b/cliff.toml new file mode 100644 index 0000000..0d345e2 --- /dev/null +++ b/cliff.toml @@ -0,0 +1,52 @@ +# configuration file for git-cliff (0.1.0) + +[changelog] +# changelog header +header = """ +# Changelog +All notable changes to this project will be documented in this file.\n +""" +# template for the changelog body +# https://tera.netlify.app/docs/#introduction +body = """ +{% if version %}\ + ## [{{ version | replace(from="v", to="") }}] - {{ timestamp | date(format="%Y-%m-%d") }} +{% else %}\ + ## [unreleased] +{% endif %}\ +{% for group, commits in commits | group_by(attribute="group") %} + ### {{ group | upper_first }} + {% for commit in commits %} + - {{ commit.message | upper_first }}\ + {% endfor %} +{% endfor %}\n +""" +# remove the leading and trailing whitespaces from the template +trim = true +# changelog footer +footer = """ + +""" + +[git] +# allow only conventional commits +# https://www.conventionalcommits.org +conventional_commits = true +# regex for parsing and grouping commits +commit_parsers = [ + { message = "^Feat*", group = "Features" }, + { message = "^Fix*", group = "Bug Fixes" }, + { message = "^Doc*", group = "Documentation" }, + { message = "^Perf*", group = "Performance" }, + { message = "^Refactor*", group = "Refactor" }, + { message = "^Style*", group = "Styling" }, + { message = "^Test*", group = "Testing" }, + { message = "^Chore\\(release\\): prepare for*", skip = true }, + { message = "^Chore*", group = "Miscellaneous Tasks" }, +] +# filter out the commits that are not matched by commit parsers +filter_commits = false +# glob pattern for matching git tags +tag_pattern = 
"v[0-9]*" +# regex for skipping tags +skip_tags = "v0.1.0-beta.1" From 049027022a92ac984a37c55a7d5ce24b21cdc2f6 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Mon, 22 Nov 2021 14:26:44 +0900 Subject: [PATCH 02/18] Fix: Impl assert_equal --- tests/utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index a7958f5..aef47dc 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,17 +3,20 @@ import torch from torch import Tensor from torch.nn.utils.rnn import PackedSequence -from torch.testing import assert_close, assert_equal +from torch.testing import assert_close __all__ = [ 'assert_close', - 'assert_equal', 'assert_grad_close', 'assert_packed_close', 'assert_packed_equal', ] +def assert_equal(actual: Tensor, expected: Tensor, **kwargs) -> None: + assert torch.equal(actual, expected) + + def assert_grad_close(actual: Tensor, expected: Tensor, inputs: Tuple[Tensor, ...]) -> None: grad = torch.randn_like(actual) From 9a7b7d054671328d0022f1b76de745a79538e0a8 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Mon, 22 Nov 2021 14:41:49 +0900 Subject: [PATCH 03/18] Refactor: Add compute_packed_sequence_scores and compute_packed_sequence_partitions --- torchlatent/crf.py | 40 +++++++++++++++++++--------------------- 1 file changed, 19 insertions(+), 21 deletions(-) diff --git a/torchlatent/crf.py b/torchlatent/crf.py index fe8bfc2..8f359fa 100644 --- a/torchlatent/crf.py +++ b/torchlatent/crf.py @@ -13,15 +13,15 @@ from torchlatent.semiring import Semiring, Log, Max __all__ = [ - 'compute_scores', - 'compute_partitions', + 'compute_packed_sequence_scores', + 'compute_packed_sequence_partitions', 'CrfDistribution', 'CrfDecoderABC', 'CrfDecoder', ] -def compute_scores(semiring: Type[Semiring]): - def _compute_scores( +def compute_packed_sequence_scores(semiring: Type[Semiring]): + def _compute_packed_sequence_scores( emissions: PackedSequence, tags: PackedSequence, transitions: Tensor, head_transitions: Tensor, 
tail_transitions: Tensor) -> Tensor: device = transitions.device @@ -43,45 +43,43 @@ def _compute_scores( transition_scores[:h] = transition_head_scores # [h, c] _, batch_ptr, _ = batch_sizes_to_ptr(batch_sizes=emissions.batch_sizes) - scores = semiring.mul( - semiring.scatter_mul(semiring.mul(emission_scores, transition_scores), index=batch_ptr), - transition_tail_scores, - ) + scores = semiring.mul(emission_scores, transition_scores) + scores = semiring.scatter_mul(scores, index=batch_ptr) + scores = semiring.mul(scores, transition_tail_scores) if emissions.unsorted_indices is not None: scores = scores[emissions.unsorted_indices] return scores - return _compute_scores + return _compute_packed_sequence_scores -def compute_partitions(semiring: Type[Semiring]): - def _compute_partitions( +def compute_packed_sequence_partitions(semiring: Type[Semiring]): + def _compute_packed_sequence_partitions( emissions: PackedSequence, indices: TreeReduceIndices, transitions: Tensor, head_transitions: Tensor, tail_transitions: Tensor, eye: Tensor) -> Tensor: h = emissions.batch_sizes[0].item() t = torch.arange(transitions.size()[0], device=transitions.device) # [t] c = torch.arange(transitions.size()[1], device=transitions.device) # [c] - scores = semiring.mul(transitions, emissions.data[..., None, :]) # [t, c, n, n] - scores[:h] = eye[None, None, :, :] - scores = semiring.reduce(tensor=scores, indices=indices) + emission_scores = semiring.mul(transitions, emissions.data[..., None, :]) # [t, c, n, n] + emission_scores[:h] = eye[None, None, :, :] + emission_scores = semiring.reduce(tensor=emission_scores, indices=indices) emission_head_scores = emissions.data[:h, :, None, :] transition_head_scores = head_transitions[t[:h, None], c[None, :], None, :] transition_tail_scores = tail_transitions[t[:h, None], c[None, :], :, None] - scores = semiring.bmm( - semiring.bmm(semiring.mul(transition_head_scores, emission_head_scores), scores), - transition_tail_scores, - )[..., 0, 0] + scores 
= semiring.mul(transition_head_scores, emission_head_scores) + scores = semiring.bmm(scores, emission_scores) + scores = semiring.bmm(scores, transition_tail_scores)[..., 0, 0] if emissions.unsorted_indices is not None: scores = scores[emissions.unsorted_indices] return scores - return _compute_partitions + return _compute_packed_sequence_partitions class CrfDistribution(object): @@ -96,7 +94,7 @@ def __init__(self, emissions: PackedSequence, indices: TreeReduceIndices, self.tail_transitions = tail_transitions def semiring_scores(self, semiring: Type[Semiring], tags: PackedSequence) -> Tensor: - return compute_scores(semiring=semiring)( + return compute_packed_sequence_scores(semiring=semiring)( emissions=self.emissions, tags=tags, transitions=self.transitions, head_transitions=self.head_transitions, @@ -104,7 +102,7 @@ def semiring_scores(self, semiring: Type[Semiring], tags: PackedSequence) -> Ten ) def semiring_partitions(self, semiring: Type[Semiring]) -> Tensor: - return compute_partitions(semiring=semiring)( + return compute_packed_sequence_partitions(semiring=semiring)( emissions=self.emissions, indices=self.indices, transitions=self.transitions, head_transitions=self.head_transitions, From 1b0194f5a9d3b14d67bb4b0274e9128f101f603f Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Mon, 22 Nov 2021 14:47:07 +0900 Subject: [PATCH 04/18] Refactor: Separate packing.py --- torchlatent/crf/__init__.py | 106 ++++++++++++++++++++++++ torchlatent/{crf.py => crf/packing.py} | 108 ++----------------------- 2 files changed, 112 insertions(+), 102 deletions(-) create mode 100644 torchlatent/crf/__init__.py rename torchlatent/{crf.py => crf/packing.py} (58%) diff --git a/torchlatent/crf/__init__.py b/torchlatent/crf/__init__.py new file mode 100644 index 0000000..757b5c7 --- /dev/null +++ b/torchlatent/crf/__init__.py @@ -0,0 +1,106 @@ +from abc import ABCMeta +from typing import Optional, Tuple + +import torch +from torch import Tensor +from torch import nn +from torch.nn 
import init +from torch.nn.utils.rnn import PackedSequence +from torchrua import TreeReduceIndices, tree_reduce_packed_indices + +from torchlatent.crf.packing import PackedCrfDistribution + +__all__ = { + 'CrfDecoderABC', 'CrfDecoder', +} + + +class CrfDecoderABC(nn.Module, metaclass=ABCMeta): + def __init__(self, num_tags: int, num_conjugates: int): + super(CrfDecoderABC, self).__init__() + + self.num_tags = num_tags + self.num_conjugates = num_conjugates + + def reset_parameters(self) -> None: + raise NotImplementedError + + def extra_repr(self) -> str: + return ', '.join([ + f'num_tags={self.num_tags}', + f'num_conjugates={self.num_conjugates}', + ]) + + def compile_indices(self, emissions: PackedSequence, tags: Optional[PackedSequence] = None, + indices: Optional[TreeReduceIndices] = None, **kwargs): + assert emissions.data.dim() == 3, f'{emissions.data.dim()} != {3}' + if tags is not None: + assert tags.data.dim() == 2, f'{tags.data.dim()} != {2}' + + if indices is None: + batch_sizes = emissions.batch_sizes.to(device=emissions.data.device) + indices = tree_reduce_packed_indices(batch_sizes=batch_sizes) + + return indices + + def obtain_parameters(self, **kwargs) -> Tuple[Tensor, Tensor, Tensor]: + return self.transitions, self.head_transitions, self.tail_transitions + + def forward(self, emissions: PackedSequence, tags: Optional[PackedSequence] = None, + indices: Optional[TreeReduceIndices] = None, **kwargs): + indices = self.compile_indices(emissions=emissions, tags=tags, indices=indices) + transitions, head_transitions, tail_transitions = self.obtain_parameters( + emissions=emissions, tags=tags, indices=indices, + ) + + dist = PackedCrfDistribution( + emissions=emissions, indices=indices, + transitions=transitions, + head_transitions=head_transitions, + tail_transitions=tail_transitions, + ) + + return dist, tags + + def fit(self, emissions: PackedSequence, tags: PackedSequence, + indices: Optional[TreeReduceIndices] = None, **kwargs) -> Tensor: + dist, 
tags = self(emissions=emissions, tags=tags, instr=indices, **kwargs) + + return dist.log_prob(tags=tags) + + def decode(self, emissions: PackedSequence, + indices: Optional[TreeReduceIndices] = None, **kwargs) -> PackedSequence: + dist, _ = self(emissions=emissions, tags=None, instr=indices, **kwargs) + return dist.argmax + + def marginals(self, emissions: PackedSequence, + indices: Optional[TreeReduceIndices] = None, **kwargs) -> Tensor: + dist, _ = self(emissions=emissions, tags=None, instr=indices, **kwargs) + return dist.marginals + + +class CrfDecoder(CrfDecoderABC): + def __init__(self, num_tags: int, num_conjugates: int = 1) -> None: + super(CrfDecoder, self).__init__(num_tags=num_tags, num_conjugates=num_conjugates) + + self.transitions = nn.Parameter( + torch.empty((1, self.num_conjugates, self.num_tags, self.num_tags)), + requires_grad=True, + ) + self.head_transitions = nn.Parameter( + torch.empty((1, self.num_conjugates, self.num_tags)), + requires_grad=True, + ) + self.tail_transitions = nn.Parameter( + torch.empty((1, self.num_conjugates, self.num_tags)), + requires_grad=True, + ) + + self.reset_parameters() + + @torch.no_grad() + def reset_parameters(self) -> None: + bound = 0.01 + init.uniform_(self.transitions, -bound, +bound) + init.uniform_(self.head_transitions, -bound, +bound) + init.uniform_(self.tail_transitions, -bound, +bound) diff --git a/torchlatent/crf.py b/torchlatent/crf/packing.py similarity index 58% rename from torchlatent/crf.py rename to torchlatent/crf/packing.py index 8f359fa..08b87bc 100644 --- a/torchlatent/crf.py +++ b/torchlatent/crf/packing.py @@ -1,22 +1,17 @@ -from abc import ABCMeta -from typing import Optional, Type, Tuple +from typing import Type import torch -from torch import Tensor -from torch import nn, autograd +from torch import Tensor, autograd from torch.distributions.utils import lazy_property -from torch.nn import init from torch.nn.utils.rnn import PackedSequence -from torchrua import TreeReduceIndices, 
tree_reduce_packed_indices, batch_sizes_to_ptr -from torchrua import select_head, select_last, roll_packed_sequence +from torchrua import roll_packed_sequence, select_head, select_last, batch_sizes_to_ptr, TreeReduceIndices from torchlatent.semiring import Semiring, Log, Max __all__ = [ 'compute_packed_sequence_scores', 'compute_packed_sequence_partitions', - 'CrfDistribution', - 'CrfDecoderABC', 'CrfDecoder', + 'PackedCrfDistribution', ] @@ -82,10 +77,10 @@ def _compute_packed_sequence_partitions( return _compute_packed_sequence_partitions -class CrfDistribution(object): +class PackedCrfDistribution(object): def __init__(self, emissions: PackedSequence, indices: TreeReduceIndices, transitions: Tensor, head_transitions: Tensor, tail_transitions: Tensor) -> None: - super(CrfDistribution, self).__init__() + super(PackedCrfDistribution, self).__init__() self.emissions = emissions self.indices = indices @@ -143,94 +138,3 @@ def argmax(self) -> PackedSequence: sorted_indices=self.emissions.sorted_indices, unsorted_indices=self.emissions.unsorted_indices, ) - - -class CrfDecoderABC(nn.Module, metaclass=ABCMeta): - def __init__(self, num_tags: int, num_conjugates: int): - super(CrfDecoderABC, self).__init__() - - self.num_tags = num_tags - self.num_conjugates = num_conjugates - - def reset_parameters(self) -> None: - raise NotImplementedError - - def extra_repr(self) -> str: - return ', '.join([ - f'num_tags={self.num_tags}', - f'num_conjugates={self.num_conjugates}', - ]) - - def compile_indices(self, emissions: PackedSequence, tags: Optional[PackedSequence] = None, - indices: Optional[TreeReduceIndices] = None, **kwargs): - assert emissions.data.dim() == 3, f'{emissions.data.dim()} != {3}' - if tags is not None: - assert tags.data.dim() == 2, f'{tags.data.dim()} != {2}' - - if indices is None: - batch_sizes = emissions.batch_sizes.to(device=emissions.data.device) - indices = tree_reduce_packed_indices(batch_sizes=batch_sizes) - - return indices - - def 
obtain_parameters(self, **kwargs) -> Tuple[Tensor, Tensor, Tensor]: - return self.transitions, self.head_transitions, self.tail_transitions - - def forward(self, emissions: PackedSequence, tags: Optional[PackedSequence] = None, - indices: Optional[TreeReduceIndices] = None, **kwargs): - indices = self.compile_indices(emissions=emissions, tags=tags, indices=indices) - transitions, head_transitions, tail_transitions = self.obtain_parameters( - emissions=emissions, tags=tags, indices=indices, - ) - - dist = CrfDistribution( - emissions=emissions, indices=indices, - transitions=transitions, - head_transitions=head_transitions, - tail_transitions=tail_transitions, - ) - - return dist, tags - - def fit(self, emissions: PackedSequence, tags: PackedSequence, - indices: Optional[TreeReduceIndices] = None, **kwargs) -> Tensor: - dist, tags = self(emissions=emissions, tags=tags, instr=indices, **kwargs) - - return dist.log_prob(tags=tags) - - def decode(self, emissions: PackedSequence, - indices: Optional[TreeReduceIndices] = None, **kwargs) -> PackedSequence: - dist, _ = self(emissions=emissions, tags=None, instr=indices, **kwargs) - return dist.argmax - - def marginals(self, emissions: PackedSequence, - indices: Optional[TreeReduceIndices] = None, **kwargs) -> Tensor: - dist, _ = self(emissions=emissions, tags=None, instr=indices, **kwargs) - return dist.marginals - - -class CrfDecoder(CrfDecoderABC): - def __init__(self, num_tags: int, num_conjugates: int = 1) -> None: - super(CrfDecoder, self).__init__(num_tags=num_tags, num_conjugates=num_conjugates) - - self.transitions = nn.Parameter( - torch.empty((1, self.num_conjugates, self.num_tags, self.num_tags)), - requires_grad=True, - ) - self.head_transitions = nn.Parameter( - torch.empty((1, self.num_conjugates, self.num_tags)), - requires_grad=True, - ) - self.tail_transitions = nn.Parameter( - torch.empty((1, self.num_conjugates, self.num_tags)), - requires_grad=True, - ) - - self.reset_parameters() - - @torch.no_grad() - 
def reset_parameters(self) -> None: - bound = 0.01 - init.uniform_(self.transitions, -bound, +bound) - init.uniform_(self.head_transitions, -bound, +bound) - init.uniform_(self.tail_transitions, -bound, +bound) From 3484a41aeb09f834ceec6097b73f7d21728c31c7 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Mon, 22 Nov 2021 15:39:23 +0900 Subject: [PATCH 05/18] Feat: Add CattedCrfDistribution --- torchlatent/crf/catting.py | 134 +++++++++++++++++++++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 torchlatent/crf/catting.py diff --git a/torchlatent/crf/catting.py b/torchlatent/crf/catting.py new file mode 100644 index 0000000..b1290bf --- /dev/null +++ b/torchlatent/crf/catting.py @@ -0,0 +1,134 @@ +from typing import Type + +import torch +from torch import Tensor, autograd +from torch.distributions.utils import lazy_property +from torchrua import CattedSequence +from torchrua import roll_catted_sequence, head_catted_sequence, last_catted_sequence, batch_sizes_to_ptr, \ + TreeReduceIndices + +from torchlatent.semiring import Semiring, Log, Max + +__all__ = [ + 'compute_catted_sequence_scores', + 'compute_catted_sequence_partitions', + 'CattedCrfDistribution', +] + + +def compute_catted_sequence_scores(semiring: Type[Semiring]): + def _compute_catted_sequence_scores( + emissions: CattedSequence, tags: CattedSequence, + transitions: Tensor, head_transitions: Tensor, tail_transitions: Tensor) -> Tensor: + device = transitions.device + + emission_scores = emissions.data.gather(dim=-1, index=tags.data[..., None])[..., 0] # [t, c] + + h = emissions.token_sizes.size()[0] + t = torch.arange(transitions.size()[0], device=device) # [t] + c = torch.arange(transitions.size()[1], device=device) # [c] + + x, y = roll_catted_sequence(tags, shifts=1).data, tags.data # [t, c] + head = head_catted_sequence(tags) # [h, c] + tail = last_catted_sequence(tags) # [h, c] + + transition_scores = transitions[t[:, None], c[None, :], x, y] # [t, c] + transition_head_scores = 
head_transitions[t[:h, None], c[None, :], head] # [h, c] + transition_tail_scores = tail_transitions[t[:h, None], c[None, :], tail] # [h, c] + + transition_scores[:h] = transition_head_scores # [h, c] + + _, batch_ptr, _ = batch_sizes_to_ptr(batch_sizes=emissions.token_sizes) + scores = semiring.mul(emission_scores, transition_scores) + scores = semiring.scatter_mul(scores, index=batch_ptr) + scores = semiring.mul(scores, transition_tail_scores) + + return scores + + return _compute_catted_sequence_scores + + +def compute_catted_sequence_partitions(semiring: Type[Semiring]): + def _compute_catted_sequence_partitions( + emissions: CattedSequence, indices: TreeReduceIndices, + transitions: Tensor, head_transitions: Tensor, tail_transitions: Tensor, eye: Tensor) -> Tensor: + h = emissions.token_sizes.size()[0] + t = torch.arange(transitions.size()[0], device=transitions.device) # [t] + c = torch.arange(transitions.size()[1], device=transitions.device) # [c] + + emission_scores = semiring.mul(transitions, emissions.data[..., None, :]) # [t, c, n, n] + emission_scores[:h] = eye[None, None, :, :] + emission_scores = semiring.reduce(tensor=emission_scores, indices=indices) + + emission_head_scores = emissions.data[:h, :, None, :] + transition_head_scores = head_transitions[t[:h, None], c[None, :], None, :] + transition_tail_scores = tail_transitions[t[:h, None], c[None, :], :, None] + + scores = semiring.mul(transition_head_scores, emission_head_scores) + scores = semiring.bmm(scores, emission_scores) + scores = semiring.bmm(scores, transition_tail_scores)[..., 0, 0] + + return scores + + return _compute_catted_sequence_partitions + + +class CattedCrfDistribution(object): + def __init__(self, emissions: CattedSequence, indices: TreeReduceIndices, + transitions: Tensor, head_transitions: Tensor, tail_transitions: Tensor) -> None: + super(CattedCrfDistribution, self).__init__() + self.emissions = emissions + self.indices = indices + + self.transitions = transitions + 
self.head_transitions = head_transitions + self.tail_transitions = tail_transitions + + def semiring_scores(self, semiring: Type[Semiring], tags: CattedSequence) -> Tensor: + return compute_catted_sequence_scores(semiring=semiring)( + emissions=self.emissions, tags=tags, + transitions=self.transitions, + head_transitions=self.head_transitions, + tail_transitions=self.tail_transitions, + ) + + def semiring_partitions(self, semiring: Type[Semiring]) -> Tensor: + return compute_catted_sequence_partitions(semiring=semiring)( + emissions=self.emissions, indices=self.indices, + transitions=self.transitions, + head_transitions=self.head_transitions, + tail_transitions=self.tail_transitions, + eye=semiring.eye_like(self.transitions), + ) + + def log_prob(self, tags: CattedSequence) -> Tensor: + return self.log_scores(tags=tags) - self.log_partitions + + def log_scores(self, tags: CattedSequence) -> Tensor: + return self.semiring_scores(semiring=Log, tags=tags) + + @lazy_property + def log_partitions(self) -> Tensor: + return self.semiring_partitions(semiring=Log) + + @lazy_property + def marginals(self) -> Tensor: + log_partitions = self.log_partitions + grad, = autograd.grad( + log_partitions, self.emissions.data, torch.ones_like(log_partitions), + create_graph=True, only_inputs=True, allow_unused=False, + ) + return grad + + @lazy_property + def argmax(self) -> CattedSequence: + max_partitions = self.semiring_partitions(semiring=Max) + + grad, = torch.autograd.grad( + max_partitions, self.emissions.data, torch.ones_like(max_partitions), + retain_graph=False, create_graph=False, allow_unused=False, + ) + return CattedSequence( + data=grad.argmax(dim=-1), + token_sizes=self.emissions.token_sizes, + ) From 1cb9ce246bde642ed6b98706de0febef1a89e150 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Mon, 22 Nov 2021 15:45:58 +0900 Subject: [PATCH 06/18] Refactor: Update CrfDecoderABC --- torchlatent/crf/__init__.py | 91 ++++++++++++++++++++++--------------- 1 file changed, 54 
insertions(+), 37 deletions(-) diff --git a/torchlatent/crf/__init__.py b/torchlatent/crf/__init__.py index 757b5c7..841e85c 100644 --- a/torchlatent/crf/__init__.py +++ b/torchlatent/crf/__init__.py @@ -1,22 +1,31 @@ from abc import ABCMeta -from typing import Optional, Tuple +from typing import Optional, Tuple, Union, Type import torch from torch import Tensor from torch import nn from torch.nn import init -from torch.nn.utils.rnn import PackedSequence -from torchrua import TreeReduceIndices, tree_reduce_packed_indices +from torchrua import TreeReduceIndices, PackedSequence, CattedSequence +from torchrua import tree_reduce_packed_indices, tree_reduce_catted_indices +from torchlatent.crf.catting import CattedCrfDistribution from torchlatent.crf.packing import PackedCrfDistribution -__all__ = { +__all__ = [ 'CrfDecoderABC', 'CrfDecoder', -} + 'PackedCrfDistribution', + 'CattedCrfDistribution', + 'Sequence', +] + +Sequence = Union[ + Type[PackedSequence], + Type[CattedSequence], +] class CrfDecoderABC(nn.Module, metaclass=ABCMeta): - def __init__(self, num_tags: int, num_conjugates: int): + def __init__(self, num_tags: int, num_conjugates: int) -> None: super(CrfDecoderABC, self).__init__() self.num_tags = num_tags @@ -31,49 +40,67 @@ def extra_repr(self) -> str: f'num_conjugates={self.num_conjugates}', ]) - def compile_indices(self, emissions: PackedSequence, tags: Optional[PackedSequence] = None, + @staticmethod + def compile_indices(emissions: Sequence, + tags: Optional[Sequence] = None, indices: Optional[TreeReduceIndices] = None, **kwargs): assert emissions.data.dim() == 3, f'{emissions.data.dim()} != {3}' if tags is not None: assert tags.data.dim() == 2, f'{tags.data.dim()} != {2}' if indices is None: - batch_sizes = emissions.batch_sizes.to(device=emissions.data.device) - indices = tree_reduce_packed_indices(batch_sizes=batch_sizes) + if isinstance(emissions, PackedSequence): + batch_sizes = emissions.batch_sizes.to(device=emissions.data.device) + return 
tree_reduce_packed_indices(batch_sizes=batch_sizes) + + if isinstance(emissions, CattedSequence): + token_sizes = emissions.token_sizes.to(device=emissions.data.device) + return tree_reduce_catted_indices(token_sizes=token_sizes) return indices def obtain_parameters(self, **kwargs) -> Tuple[Tensor, Tensor, Tensor]: return self.transitions, self.head_transitions, self.tail_transitions - def forward(self, emissions: PackedSequence, tags: Optional[PackedSequence] = None, + def forward(self, emissions: Sequence, tags: Optional[Sequence] = None, indices: Optional[TreeReduceIndices] = None, **kwargs): indices = self.compile_indices(emissions=emissions, tags=tags, indices=indices) transitions, head_transitions, tail_transitions = self.obtain_parameters( emissions=emissions, tags=tags, indices=indices, ) - dist = PackedCrfDistribution( - emissions=emissions, indices=indices, - transitions=transitions, - head_transitions=head_transitions, - tail_transitions=tail_transitions, - ) - - return dist, tags - - def fit(self, emissions: PackedSequence, tags: PackedSequence, + if isinstance(emissions, PackedSequence): + dist = PackedCrfDistribution( + emissions=emissions, indices=indices, + transitions=transitions, + head_transitions=head_transitions, + tail_transitions=tail_transitions, + ) + return dist, tags + + if isinstance(emissions, CattedSequence): + dist = CattedCrfDistribution( + emissions=emissions, indices=indices, + transitions=transitions, + head_transitions=head_transitions, + tail_transitions=tail_transitions, + ) + return dist, tags + + raise TypeError(f'{type(emissions)} is not supported.') + + def fit(self, emissions: Sequence, tags: Sequence, indices: Optional[TreeReduceIndices] = None, **kwargs) -> Tensor: dist, tags = self(emissions=emissions, tags=tags, instr=indices, **kwargs) return dist.log_prob(tags=tags) - def decode(self, emissions: PackedSequence, - indices: Optional[TreeReduceIndices] = None, **kwargs) -> PackedSequence: + def decode(self, emissions: 
Sequence, + indices: Optional[TreeReduceIndices] = None, **kwargs) -> Sequence: dist, _ = self(emissions=emissions, tags=None, instr=indices, **kwargs) return dist.argmax - def marginals(self, emissions: PackedSequence, + def marginals(self, emissions: Sequence, indices: Optional[TreeReduceIndices] = None, **kwargs) -> Tensor: dist, _ = self(emissions=emissions, tags=None, instr=indices, **kwargs) return dist.marginals @@ -83,24 +110,14 @@ class CrfDecoder(CrfDecoderABC): def __init__(self, num_tags: int, num_conjugates: int = 1) -> None: super(CrfDecoder, self).__init__(num_tags=num_tags, num_conjugates=num_conjugates) - self.transitions = nn.Parameter( - torch.empty((1, self.num_conjugates, self.num_tags, self.num_tags)), - requires_grad=True, - ) - self.head_transitions = nn.Parameter( - torch.empty((1, self.num_conjugates, self.num_tags)), - requires_grad=True, - ) - self.tail_transitions = nn.Parameter( - torch.empty((1, self.num_conjugates, self.num_tags)), - requires_grad=True, - ) + self.transitions = nn.Parameter(torch.empty((1, self.num_conjugates, self.num_tags, self.num_tags))) + self.head_transitions = nn.Parameter(torch.empty((1, self.num_conjugates, self.num_tags))) + self.tail_transitions = nn.Parameter(torch.empty((1, self.num_conjugates, self.num_tags))) self.reset_parameters() @torch.no_grad() - def reset_parameters(self) -> None: - bound = 0.01 + def reset_parameters(self, bound: float = 0.01) -> None: init.uniform_(self.transitions, -bound, +bound) init.uniform_(self.head_transitions, -bound, +bound) init.uniform_(self.tail_transitions, -bound, +bound) From 51b03f67fdc405292754edf76d5297c2d91c5fa9 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Tue, 23 Nov 2021 11:09:01 +0900 Subject: [PATCH 07/18] Fix: Resolve the head selecting bug of CattedSequence --- torchlatent/crf/__init__.py | 7 ++++--- torchlatent/crf/catting.py | 15 +++++++++------ torchlatent/crf/packing.py | 5 ++++- 3 files changed, 17 insertions(+), 10 deletions(-) diff --git 
a/torchlatent/crf/__init__.py b/torchlatent/crf/__init__.py index 841e85c..79d70e1 100644 --- a/torchlatent/crf/__init__.py +++ b/torchlatent/crf/__init__.py @@ -1,5 +1,5 @@ from abc import ABCMeta -from typing import Optional, Tuple, Union, Type +from typing import Optional, Tuple, Union import torch from torch import Tensor @@ -19,8 +19,8 @@ ] Sequence = Union[ - Type[PackedSequence], - Type[CattedSequence], + PackedSequence, + CattedSequence, ] @@ -65,6 +65,7 @@ def obtain_parameters(self, **kwargs) -> Tuple[Tensor, Tensor, Tensor]: def forward(self, emissions: Sequence, tags: Optional[Sequence] = None, indices: Optional[TreeReduceIndices] = None, **kwargs): indices = self.compile_indices(emissions=emissions, tags=tags, indices=indices) + print(f'indices => {indices}') transitions, head_transitions, tail_transitions = self.obtain_parameters( emissions=emissions, tags=tags, indices=indices, ) diff --git a/torchlatent/crf/catting.py b/torchlatent/crf/catting.py index b1290bf..508bd90 100644 --- a/torchlatent/crf/catting.py +++ b/torchlatent/crf/catting.py @@ -4,8 +4,8 @@ from torch import Tensor, autograd from torch.distributions.utils import lazy_property from torchrua import CattedSequence -from torchrua import roll_catted_sequence, head_catted_sequence, last_catted_sequence, batch_sizes_to_ptr, \ - TreeReduceIndices +from torchrua import TreeReduceIndices, head_catted_indices +from torchrua import roll_catted_sequence, head_catted_sequence, last_catted_sequence, batch_sizes_to_ptr from torchlatent.semiring import Semiring, Log, Max @@ -36,11 +36,13 @@ def _compute_catted_sequence_scores( transition_head_scores = head_transitions[t[:h, None], c[None, :], head] # [h, c] transition_tail_scores = tail_transitions[t[:h, None], c[None, :], tail] # [h, c] - transition_scores[:h] = transition_head_scores # [h, c] + head_indices = head_catted_indices(emissions) + transition_scores[head_indices] = transition_head_scores # [h, c] - _, batch_ptr, _ = 
batch_sizes_to_ptr(batch_sizes=emissions.token_sizes) + batch_ptr, _, _ = batch_sizes_to_ptr(batch_sizes=emissions.token_sizes) scores = semiring.mul(emission_scores, transition_scores) scores = semiring.scatter_mul(scores, index=batch_ptr) + scores = semiring.mul(scores, transition_tail_scores) return scores @@ -55,12 +57,13 @@ def _compute_catted_sequence_partitions( h = emissions.token_sizes.size()[0] t = torch.arange(transitions.size()[0], device=transitions.device) # [t] c = torch.arange(transitions.size()[1], device=transitions.device) # [c] + head_indices = head_catted_indices(emissions) emission_scores = semiring.mul(transitions, emissions.data[..., None, :]) # [t, c, n, n] - emission_scores[:h] = eye[None, None, :, :] + emission_scores[head_indices] = eye[None, None, :, :] emission_scores = semiring.reduce(tensor=emission_scores, indices=indices) - emission_head_scores = emissions.data[:h, :, None, :] + emission_head_scores = emissions.data[head_indices, :, None, :] transition_head_scores = head_transitions[t[:h, None], c[None, :], None, :] transition_tail_scores = tail_transitions[t[:h, None], c[None, :], :, None] diff --git a/torchlatent/crf/packing.py b/torchlatent/crf/packing.py index 08b87bc..7b5325c 100644 --- a/torchlatent/crf/packing.py +++ b/torchlatent/crf/packing.py @@ -4,6 +4,7 @@ from torch import Tensor, autograd from torch.distributions.utils import lazy_property from torch.nn.utils.rnn import PackedSequence +from torchrua import head_indices from torchrua import roll_packed_sequence, select_head, select_last, batch_sizes_to_ptr, TreeReduceIndices from torchlatent.semiring import Semiring, Log, Max @@ -35,11 +36,13 @@ def _compute_packed_sequence_scores( transition_head_scores = head_transitions[t[:h, None], c[None, :], head] # [h, c] transition_tail_scores = tail_transitions[t[:h, None], c[None, :], tail] # [h, c] - transition_scores[:h] = transition_head_scores # [h, c] + indices = head_indices(tags, unsort=False) + 
transition_scores[indices] = transition_head_scores # [h, c] _, batch_ptr, _ = batch_sizes_to_ptr(batch_sizes=emissions.batch_sizes) scores = semiring.mul(emission_scores, transition_scores) scores = semiring.scatter_mul(scores, index=batch_ptr) + scores = semiring.mul(scores, transition_tail_scores) if emissions.unsorted_indices is not None: From c409c500c32ea4f496abf0b6cdb1b5d28da9957d Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Tue, 23 Nov 2021 11:15:40 +0900 Subject: [PATCH 08/18] Test: Add unit test test_crf_catted_fit --- tests/test_crf.py | 40 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 37 insertions(+), 3 deletions(-) diff --git a/tests/test_crf.py b/tests/test_crf.py index 226c161..beed167 100644 --- a/tests/test_crf.py +++ b/tests/test_crf.py @@ -4,7 +4,8 @@ from torch import Tensor from torch import nn from torch.nn.utils.rnn import PackedSequence -from torchrua import pad_packed_sequence, token_sizes_to_mask, pack_sequence +from torchrua import pad_packed_sequence, token_sizes_to_mask, pack_sequence, cat_sequence, cat_packed_sequence, \ + pack_catted_sequence from tests.strategies import devices, token_size_lists, conjugate_sizes, tag_sizes from tests.utils import assert_close, assert_grad_close, assert_packed_equal @@ -81,7 +82,7 @@ def decode(self, emissions: PackedSequence, **kwargs) -> PackedSequence: num_conjugate=conjugate_sizes(), num_tags=tag_sizes(), ) -def test_crf_decoder_fit(device, token_sizes, num_conjugate, num_tags): +def test_crf_packed_fit(device, token_sizes, num_conjugate, num_tags): emissions = pack_sequence([ torch.randn((token_size, num_conjugate, num_tags), device=device, requires_grad=True) for token_size in token_sizes @@ -109,7 +110,7 @@ def test_crf_decoder_fit(device, token_sizes, num_conjugate, num_tags): num_conjugate=conjugate_sizes(), num_tags=tag_sizes(), ) -def test_crf_decoder_decode(device, token_sizes, num_conjugate, num_tags): +def test_crf_packed_decode(device, token_sizes, num_conjugate, 
num_tags): emissions = pack_sequence([ torch.randn((token_size, num_conjugate, num_tags), device=device, requires_grad=True) for token_size in token_sizes @@ -123,3 +124,36 @@ def test_crf_decoder_decode(device, token_sizes, num_conjugate, num_tags): actual = actual_decoder.decode(emissions=emissions) assert_packed_equal(actual=actual, expected=expected) + + +@given( + device=devices(), + token_sizes=token_size_lists(), + num_conjugate=conjugate_sizes(), + num_tags=tag_sizes(), +) +def test_crf_catted_fit(device, token_sizes, num_conjugate, num_tags): + emissions = [ + torch.randn((token_size, num_conjugate, num_tags), device=device, requires_grad=True) + for token_size in token_sizes + ] + tags = [ + torch.randint(0, num_tags, (token_size, num_conjugate), device=device) + for token_size in token_sizes + ] + + packed_emissions = pack_sequence(emissions, device=device) + packed_tags = pack_sequence(tags, device=device) + + catted_emissions = cat_sequence(emissions, device=device) + catted_tags = cat_sequence(tags, device=device) + + actual_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate) + expected_decoder = ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate) + expected_decoder.reset_parameters_with_(decoder=actual_decoder) + + actual = actual_decoder.fit(emissions=catted_emissions, tags=catted_tags) + expected = expected_decoder.fit(emissions=packed_emissions, tags=packed_tags) + + assert_close(actual=actual, expected=expected) + assert_grad_close(actual=actual, expected=expected, inputs=tuple(emissions)) From b9d837c1571a520e057cf59fa68db1f0c8422d40 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Tue, 23 Nov 2021 11:22:05 +0900 Subject: [PATCH 09/18] Test: Add unit test test_crf_catted_decode --- tests/test_crf.py | 33 +++++++++++++++++++++++++++++---- tests/utils.py | 19 ++++++++++++++++--- torchlatent/crf/__init__.py | 1 - 3 files changed, 45 insertions(+), 8 deletions(-) diff --git a/tests/test_crf.py b/tests/test_crf.py 
index beed167..32632af 100644 --- a/tests/test_crf.py +++ b/tests/test_crf.py @@ -4,8 +4,7 @@ from torch import Tensor from torch import nn from torch.nn.utils.rnn import PackedSequence -from torchrua import pad_packed_sequence, token_sizes_to_mask, pack_sequence, cat_sequence, cat_packed_sequence, \ - pack_catted_sequence +from torchrua import pad_packed_sequence, token_sizes_to_mask, pack_sequence, cat_sequence, pack_catted_sequence from tests.strategies import devices, token_size_lists, conjugate_sizes, tag_sizes from tests.utils import assert_close, assert_grad_close, assert_packed_equal @@ -142,11 +141,11 @@ def test_crf_catted_fit(device, token_sizes, num_conjugate, num_tags): for token_size in token_sizes ] + catted_emissions = cat_sequence(emissions, device=device) packed_emissions = pack_sequence(emissions, device=device) - packed_tags = pack_sequence(tags, device=device) - catted_emissions = cat_sequence(emissions, device=device) catted_tags = cat_sequence(tags, device=device) + packed_tags = pack_sequence(tags, device=device) actual_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate) expected_decoder = ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate) @@ -157,3 +156,29 @@ def test_crf_catted_fit(device, token_sizes, num_conjugate, num_tags): assert_close(actual=actual, expected=expected) assert_grad_close(actual=actual, expected=expected, inputs=tuple(emissions)) + + +@given( + device=devices(), + token_sizes=token_size_lists(), + num_conjugate=conjugate_sizes(), + num_tags=tag_sizes(), +) +def test_crf_catted_decode(device, token_sizes, num_conjugate, num_tags): + emissions = [ + torch.randn((token_size, num_conjugate, num_tags), device=device, requires_grad=True) + for token_size in token_sizes + ] + + catted_emissions = cat_sequence(emissions, device=device) + packed_emissions = pack_sequence(emissions, device=device) + + actual_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate) + 
expected_decoder = ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate) + expected_decoder.reset_parameters_with_(decoder=actual_decoder) + + expected = expected_decoder.decode(emissions=packed_emissions) + actual = actual_decoder.decode(emissions=catted_emissions) + actual = pack_catted_sequence(*actual, device=device) + + assert_packed_equal(actual=actual, expected=expected) diff --git a/tests/utils.py b/tests/utils.py index aef47dc..ff59c50 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,15 +1,18 @@ -from typing import Tuple +from typing import Tuple, List, Union import torch from torch import Tensor -from torch.nn.utils.rnn import PackedSequence from torch.testing import assert_close +from torchrua import CattedSequence, PackedSequence __all__ = [ 'assert_close', + 'assert_equal', 'assert_grad_close', 'assert_packed_close', 'assert_packed_equal', + 'assert_catted_close', + 'assert_catted_equal', ] @@ -17,7 +20,7 @@ def assert_equal(actual: Tensor, expected: Tensor, **kwargs) -> None: assert torch.equal(actual, expected) -def assert_grad_close(actual: Tensor, expected: Tensor, inputs: Tuple[Tensor, ...]) -> None: +def assert_grad_close(actual: Tensor, expected: Tensor, inputs: Union[List[Tensor], Tuple[Tensor, ...]]) -> None: grad = torch.randn_like(actual) actual_grads = torch.autograd.grad(actual, inputs, grad) @@ -58,3 +61,13 @@ def assert_packed_close(actual: PackedSequence, expected: PackedSequence) -> Non assert expected.unsorted_indices is None else: assert_equal(actual=actual.unsorted_indices, expected=expected.unsorted_indices) + + +def assert_catted_equal(actual: CattedSequence, expected: CattedSequence) -> None: + assert_equal(actual.data, expected.data) + assert_equal(actual.token_sizes, expected.token_sizes) + + +def assert_catted_close(actual: CattedSequence, expected: CattedSequence) -> None: + assert_close(actual.data, expected.data) + assert_close(actual.token_sizes, expected.token_sizes) diff --git 
a/torchlatent/crf/__init__.py b/torchlatent/crf/__init__.py index 79d70e1..91d5fe8 100644 --- a/torchlatent/crf/__init__.py +++ b/torchlatent/crf/__init__.py @@ -65,7 +65,6 @@ def obtain_parameters(self, **kwargs) -> Tuple[Tensor, Tensor, Tensor]: def forward(self, emissions: Sequence, tags: Optional[Sequence] = None, indices: Optional[TreeReduceIndices] = None, **kwargs): indices = self.compile_indices(emissions=emissions, tags=tags, indices=indices) - print(f'indices => {indices}') transitions, head_transitions, tail_transitions = self.obtain_parameters( emissions=emissions, tags=tags, indices=indices, ) From e9c30f31d69efa7ef0abd2ec2a1712269dc63363 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Fri, 18 Feb 2022 22:20:01 +0900 Subject: [PATCH 10/18] Refactor: Migrate to new TorchRua --- tests/test_crf.py | 18 +++++++++++++++--- torchlatent/crf/catting.py | 8 ++++---- torchlatent/crf/packing.py | 12 ++++++------ torchlatent/semiring.py | 15 +++++---------- 4 files changed, 30 insertions(+), 23 deletions(-) diff --git a/tests/test_crf.py b/tests/test_crf.py index 32632af..ebf70f8 100644 --- a/tests/test_crf.py +++ b/tests/test_crf.py @@ -4,13 +4,25 @@ from torch import Tensor from torch import nn from torch.nn.utils.rnn import PackedSequence -from torchrua import pad_packed_sequence, token_sizes_to_mask, pack_sequence, cat_sequence, pack_catted_sequence +from torch.types import Device +from torchrua import pad_packed_sequence, pack_sequence, cat_sequence, pack_catted_sequence, pad_catted_indices from tests.strategies import devices, token_size_lists, conjugate_sizes, tag_sizes from tests.utils import assert_close, assert_grad_close, assert_packed_equal from torchlatent.crf import CrfDecoder +@torch.no_grad() +def token_sizes_to_mask(sizes: Tensor, batch_first: bool, device: Device = None) -> Tensor: + if device is None: + device = sizes.device + + size, ptr = pad_catted_indices(sizes, batch_first=batch_first, device=device) + mask = torch.zeros(size, 
device=device, dtype=torch.bool) + mask[ptr] = True + return mask + + class ThirdPartyCrfDecoder(nn.Module): def __init__(self, num_tags: int, num_conjugates: int) -> None: super(ThirdPartyCrfDecoder, self).__init__() @@ -39,7 +51,7 @@ def fit(self, emissions: PackedSequence, tags: PackedSequence, **kwargs) -> Tens emissions, token_sizes = pad_packed_sequence(emissions, batch_first=False) tags, _ = pad_packed_sequence(tags, batch_first=False) - mask = token_sizes_to_mask(token_sizes=token_sizes, batch_first=False) + mask = token_sizes_to_mask(sizes=token_sizes, batch_first=False) log_probs = [] for index in range(num_conjugates): @@ -57,7 +69,7 @@ def decode(self, emissions: PackedSequence, **kwargs) -> PackedSequence: num_conjugates = max(num_emissions_conjugates, num_decoders_conjugates) emissions, token_sizes = pad_packed_sequence(emissions, batch_first=False) - mask = token_sizes_to_mask(token_sizes=token_sizes, batch_first=False) + mask = token_sizes_to_mask(sizes=token_sizes, batch_first=False) predictions = [] for index in range(num_conjugates): diff --git a/torchlatent/crf/catting.py b/torchlatent/crf/catting.py index 508bd90..629c846 100644 --- a/torchlatent/crf/catting.py +++ b/torchlatent/crf/catting.py @@ -5,7 +5,7 @@ from torch.distributions.utils import lazy_property from torchrua import CattedSequence from torchrua import TreeReduceIndices, head_catted_indices -from torchrua import roll_catted_sequence, head_catted_sequence, last_catted_sequence, batch_sizes_to_ptr +from torchrua import roll_catted_sequence, head_catted_sequence, last_catted_sequence from torchlatent.semiring import Semiring, Log, Max @@ -36,10 +36,10 @@ def _compute_catted_sequence_scores( transition_head_scores = head_transitions[t[:h, None], c[None, :], head] # [h, c] transition_tail_scores = tail_transitions[t[:h, None], c[None, :], tail] # [h, c] - head_indices = head_catted_indices(emissions) + head_indices = head_catted_indices(emissions.token_sizes) 
transition_scores[head_indices] = transition_head_scores # [h, c] - batch_ptr, _, _ = batch_sizes_to_ptr(batch_sizes=emissions.token_sizes) + batch_ptr = torch.repeat_interleave(emissions.token_sizes) scores = semiring.mul(emission_scores, transition_scores) scores = semiring.scatter_mul(scores, index=batch_ptr) @@ -57,7 +57,7 @@ def _compute_catted_sequence_partitions( h = emissions.token_sizes.size()[0] t = torch.arange(transitions.size()[0], device=transitions.device) # [t] c = torch.arange(transitions.size()[1], device=transitions.device) # [c] - head_indices = head_catted_indices(emissions) + head_indices = head_catted_indices(emissions.token_sizes) emission_scores = semiring.mul(transitions, emissions.data[..., None, :]) # [t, c, n, n] emission_scores[head_indices] = eye[None, None, :, :] diff --git a/torchlatent/crf/packing.py b/torchlatent/crf/packing.py index 7b5325c..28159ec 100644 --- a/torchlatent/crf/packing.py +++ b/torchlatent/crf/packing.py @@ -4,8 +4,8 @@ from torch import Tensor, autograd from torch.distributions.utils import lazy_property from torch.nn.utils.rnn import PackedSequence -from torchrua import head_indices -from torchrua import roll_packed_sequence, select_head, select_last, batch_sizes_to_ptr, TreeReduceIndices +from torchrua import head_packed_indices, TreeReduceIndices +from torchrua import roll_packed_sequence, head_packed_sequence, last_packed_sequence, major_sizes_to_ptr from torchlatent.semiring import Semiring, Log, Max @@ -29,17 +29,17 @@ def _compute_packed_sequence_scores( c = torch.arange(transitions.size()[1], device=device) # [c] x, y = roll_packed_sequence(tags, shifts=1).data, tags.data # [t, c] - head = select_head(tags, unsort=False) # [h, c] - tail = select_last(tags, unsort=False) # [h, c] + head = head_packed_sequence(tags, unsort=False) # [h, c] + tail = last_packed_sequence(tags, unsort=False) # [h, c] transition_scores = transitions[t[:, None], c[None, :], x, y] # [t, c] transition_head_scores = 
head_transitions[t[:h, None], c[None, :], head] # [h, c] transition_tail_scores = tail_transitions[t[:h, None], c[None, :], tail] # [h, c] - indices = head_indices(tags, unsort=False) + indices = head_packed_indices(tags.batch_sizes) transition_scores[indices] = transition_head_scores # [h, c] - _, batch_ptr, _ = batch_sizes_to_ptr(batch_sizes=emissions.batch_sizes) + batch_ptr, _ = major_sizes_to_ptr(sizes=emissions.batch_sizes) scores = semiring.mul(emission_scores, transition_scores) scores = semiring.scatter_mul(scores, index=batch_ptr) diff --git a/torchlatent/semiring.py b/torchlatent/semiring.py index 5c0940a..10f20d0 100644 --- a/torchlatent/semiring.py +++ b/torchlatent/semiring.py @@ -1,6 +1,5 @@ import torch from torch import Tensor -from torch.types import Device from torchrua.scatter import scatter_add, scatter_max, scatter_mul, scatter_logsumexp from torchrua.tree_reduction import tree_reduce_sequence, TreeReduceIndices @@ -17,16 +16,12 @@ class Semiring(object): one: float @classmethod - def eye_like(cls, tensor: Tensor, dtype: torch.dtype = None, device: Device = None) -> Tensor: - if dtype is None: - dtype = tensor.dtype - if device is None: - device = tensor.device - + def eye_like(cls, tensor: Tensor) -> Tensor: *_, n = tensor.size() - eye = torch.full((n, n), fill_value=cls.zero, dtype=dtype, device=device) - idx = torch.arange(n, dtype=torch.long, device=device) - eye[idx, idx] = cls.one + + eye = torch.full((n, n), fill_value=cls.zero, dtype=tensor.dtype, device=tensor.device) + index = torch.arange(n, dtype=torch.long, device=tensor.device) + eye[index, index] = cls.one return eye @classmethod From dbaef410b922a71045f5ed7455ad55e0391eee4a Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Fri, 18 Feb 2022 22:23:10 +0900 Subject: [PATCH 11/18] Refactor: Rename to last_transitions --- tests/test_crf.py | 2 +- torchlatent/crf/__init__.py | 12 ++++++------ torchlatent/crf/catting.py | 22 +++++++++++----------- torchlatent/crf/packing.py | 22 
+++++++++++----------- 4 files changed, 29 insertions(+), 29 deletions(-) diff --git a/tests/test_crf.py b/tests/test_crf.py index ebf70f8..900afa6 100644 --- a/tests/test_crf.py +++ b/tests/test_crf.py @@ -42,7 +42,7 @@ def reset_parameters_with_(self, decoder: CrfDecoder) -> None: for index in range(self.num_conjugates): self.decoders[index].transitions.data[::] = decoder.transitions[:, index, :, :] self.decoders[index].start_transitions.data[::] = decoder.head_transitions[:, index, :] - self.decoders[index].end_transitions.data[::] = decoder.tail_transitions[:, index, :] + self.decoders[index].end_transitions.data[::] = decoder.last_transitions[:, index, :] def fit(self, emissions: PackedSequence, tags: PackedSequence, **kwargs) -> Tensor: num_emissions_conjugates = emissions.data.size()[1] diff --git a/torchlatent/crf/__init__.py b/torchlatent/crf/__init__.py index 91d5fe8..e67f7f0 100644 --- a/torchlatent/crf/__init__.py +++ b/torchlatent/crf/__init__.py @@ -60,12 +60,12 @@ def compile_indices(emissions: Sequence, return indices def obtain_parameters(self, **kwargs) -> Tuple[Tensor, Tensor, Tensor]: - return self.transitions, self.head_transitions, self.tail_transitions + return self.transitions, self.head_transitions, self.last_transitions def forward(self, emissions: Sequence, tags: Optional[Sequence] = None, indices: Optional[TreeReduceIndices] = None, **kwargs): indices = self.compile_indices(emissions=emissions, tags=tags, indices=indices) - transitions, head_transitions, tail_transitions = self.obtain_parameters( + transitions, head_transitions, last_transitions = self.obtain_parameters( emissions=emissions, tags=tags, indices=indices, ) @@ -74,7 +74,7 @@ def forward(self, emissions: Sequence, tags: Optional[Sequence] = None, emissions=emissions, indices=indices, transitions=transitions, head_transitions=head_transitions, - tail_transitions=tail_transitions, + last_transitions=last_transitions, ) return dist, tags @@ -83,7 +83,7 @@ def forward(self, 
emissions: Sequence, tags: Optional[Sequence] = None, emissions=emissions, indices=indices, transitions=transitions, head_transitions=head_transitions, - tail_transitions=tail_transitions, + last_transitions=last_transitions, ) return dist, tags @@ -112,7 +112,7 @@ def __init__(self, num_tags: int, num_conjugates: int = 1) -> None: self.transitions = nn.Parameter(torch.empty((1, self.num_conjugates, self.num_tags, self.num_tags))) self.head_transitions = nn.Parameter(torch.empty((1, self.num_conjugates, self.num_tags))) - self.tail_transitions = nn.Parameter(torch.empty((1, self.num_conjugates, self.num_tags))) + self.last_transitions = nn.Parameter(torch.empty((1, self.num_conjugates, self.num_tags))) self.reset_parameters() @@ -120,4 +120,4 @@ def __init__(self, num_tags: int, num_conjugates: int = 1) -> None: def reset_parameters(self, bound: float = 0.01) -> None: init.uniform_(self.transitions, -bound, +bound) init.uniform_(self.head_transitions, -bound, +bound) - init.uniform_(self.tail_transitions, -bound, +bound) + init.uniform_(self.last_transitions, -bound, +bound) diff --git a/torchlatent/crf/catting.py b/torchlatent/crf/catting.py index 629c846..8cff36e 100644 --- a/torchlatent/crf/catting.py +++ b/torchlatent/crf/catting.py @@ -19,7 +19,7 @@ def compute_catted_sequence_scores(semiring: Type[Semiring]): def _compute_catted_sequence_scores( emissions: CattedSequence, tags: CattedSequence, - transitions: Tensor, head_transitions: Tensor, tail_transitions: Tensor) -> Tensor: + transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor) -> Tensor: device = transitions.device emission_scores = emissions.data.gather(dim=-1, index=tags.data[..., None])[..., 0] # [t, c] @@ -30,11 +30,11 @@ def _compute_catted_sequence_scores( x, y = roll_catted_sequence(tags, shifts=1).data, tags.data # [t, c] head = head_catted_sequence(tags) # [h, c] - tail = last_catted_sequence(tags) # [h, c] + last = last_catted_sequence(tags) # [h, c] transition_scores = 
transitions[t[:, None], c[None, :], x, y] # [t, c] transition_head_scores = head_transitions[t[:h, None], c[None, :], head] # [h, c] - transition_tail_scores = tail_transitions[t[:h, None], c[None, :], tail] # [h, c] + transition_last_scores = last_transitions[t[:h, None], c[None, :], last] # [h, c] head_indices = head_catted_indices(emissions.token_sizes) transition_scores[head_indices] = transition_head_scores # [h, c] @@ -43,7 +43,7 @@ def _compute_catted_sequence_scores( scores = semiring.mul(emission_scores, transition_scores) scores = semiring.scatter_mul(scores, index=batch_ptr) - scores = semiring.mul(scores, transition_tail_scores) + scores = semiring.mul(scores, transition_last_scores) return scores @@ -53,7 +53,7 @@ def _compute_catted_sequence_scores( def compute_catted_sequence_partitions(semiring: Type[Semiring]): def _compute_catted_sequence_partitions( emissions: CattedSequence, indices: TreeReduceIndices, - transitions: Tensor, head_transitions: Tensor, tail_transitions: Tensor, eye: Tensor) -> Tensor: + transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor, eye: Tensor) -> Tensor: h = emissions.token_sizes.size()[0] t = torch.arange(transitions.size()[0], device=transitions.device) # [t] c = torch.arange(transitions.size()[1], device=transitions.device) # [c] @@ -65,11 +65,11 @@ def _compute_catted_sequence_partitions( emission_head_scores = emissions.data[head_indices, :, None, :] transition_head_scores = head_transitions[t[:h, None], c[None, :], None, :] - transition_tail_scores = tail_transitions[t[:h, None], c[None, :], :, None] + transition_last_scores = last_transitions[t[:h, None], c[None, :], :, None] scores = semiring.mul(transition_head_scores, emission_head_scores) scores = semiring.bmm(scores, emission_scores) - scores = semiring.bmm(scores, transition_tail_scores)[..., 0, 0] + scores = semiring.bmm(scores, transition_last_scores)[..., 0, 0] return scores @@ -78,21 +78,21 @@ def _compute_catted_sequence_partitions( 
class CattedCrfDistribution(object): def __init__(self, emissions: CattedSequence, indices: TreeReduceIndices, - transitions: Tensor, head_transitions: Tensor, tail_transitions: Tensor) -> None: + transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor) -> None: super(CattedCrfDistribution, self).__init__() self.emissions = emissions self.indices = indices self.transitions = transitions self.head_transitions = head_transitions - self.tail_transitions = tail_transitions + self.last_transitions = last_transitions def semiring_scores(self, semiring: Type[Semiring], tags: CattedSequence) -> Tensor: return compute_catted_sequence_scores(semiring=semiring)( emissions=self.emissions, tags=tags, transitions=self.transitions, head_transitions=self.head_transitions, - tail_transitions=self.tail_transitions, + last_transitions=self.last_transitions, ) def semiring_partitions(self, semiring: Type[Semiring]) -> Tensor: @@ -100,7 +100,7 @@ def semiring_partitions(self, semiring: Type[Semiring]) -> Tensor: emissions=self.emissions, indices=self.indices, transitions=self.transitions, head_transitions=self.head_transitions, - tail_transitions=self.tail_transitions, + last_transitions=self.last_transitions, eye=semiring.eye_like(self.transitions), ) diff --git a/torchlatent/crf/packing.py b/torchlatent/crf/packing.py index 28159ec..0ef82ed 100644 --- a/torchlatent/crf/packing.py +++ b/torchlatent/crf/packing.py @@ -19,7 +19,7 @@ def compute_packed_sequence_scores(semiring: Type[Semiring]): def _compute_packed_sequence_scores( emissions: PackedSequence, tags: PackedSequence, - transitions: Tensor, head_transitions: Tensor, tail_transitions: Tensor) -> Tensor: + transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor) -> Tensor: device = transitions.device emission_scores = emissions.data.gather(dim=-1, index=tags.data[..., None])[..., 0] # [t, c] @@ -30,11 +30,11 @@ def _compute_packed_sequence_scores( x, y = roll_packed_sequence(tags, shifts=1).data, 
tags.data # [t, c] head = head_packed_sequence(tags, unsort=False) # [h, c] - tail = last_packed_sequence(tags, unsort=False) # [h, c] + last = last_packed_sequence(tags, unsort=False) # [h, c] transition_scores = transitions[t[:, None], c[None, :], x, y] # [t, c] transition_head_scores = head_transitions[t[:h, None], c[None, :], head] # [h, c] - transition_tail_scores = tail_transitions[t[:h, None], c[None, :], tail] # [h, c] + transition_last_scores = last_transitions[t[:h, None], c[None, :], last] # [h, c] indices = head_packed_indices(tags.batch_sizes) transition_scores[indices] = transition_head_scores # [h, c] @@ -43,7 +43,7 @@ def _compute_packed_sequence_scores( scores = semiring.mul(emission_scores, transition_scores) scores = semiring.scatter_mul(scores, index=batch_ptr) - scores = semiring.mul(scores, transition_tail_scores) + scores = semiring.mul(scores, transition_last_scores) if emissions.unsorted_indices is not None: scores = scores[emissions.unsorted_indices] @@ -56,7 +56,7 @@ def _compute_packed_sequence_scores( def compute_packed_sequence_partitions(semiring: Type[Semiring]): def _compute_packed_sequence_partitions( emissions: PackedSequence, indices: TreeReduceIndices, - transitions: Tensor, head_transitions: Tensor, tail_transitions: Tensor, eye: Tensor) -> Tensor: + transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor, eye: Tensor) -> Tensor: h = emissions.batch_sizes[0].item() t = torch.arange(transitions.size()[0], device=transitions.device) # [t] c = torch.arange(transitions.size()[1], device=transitions.device) # [c] @@ -67,11 +67,11 @@ def _compute_packed_sequence_partitions( emission_head_scores = emissions.data[:h, :, None, :] transition_head_scores = head_transitions[t[:h, None], c[None, :], None, :] - transition_tail_scores = tail_transitions[t[:h, None], c[None, :], :, None] + transition_last_scores = last_transitions[t[:h, None], c[None, :], :, None] scores = semiring.mul(transition_head_scores, 
emission_head_scores) scores = semiring.bmm(scores, emission_scores) - scores = semiring.bmm(scores, transition_tail_scores)[..., 0, 0] + scores = semiring.bmm(scores, transition_last_scores)[..., 0, 0] if emissions.unsorted_indices is not None: scores = scores[emissions.unsorted_indices] @@ -82,21 +82,21 @@ def _compute_packed_sequence_partitions( class PackedCrfDistribution(object): def __init__(self, emissions: PackedSequence, indices: TreeReduceIndices, - transitions: Tensor, head_transitions: Tensor, tail_transitions: Tensor) -> None: + transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor) -> None: super(PackedCrfDistribution, self).__init__() self.emissions = emissions self.indices = indices self.transitions = transitions self.head_transitions = head_transitions - self.tail_transitions = tail_transitions + self.last_transitions = last_transitions def semiring_scores(self, semiring: Type[Semiring], tags: PackedSequence) -> Tensor: return compute_packed_sequence_scores(semiring=semiring)( emissions=self.emissions, tags=tags, transitions=self.transitions, head_transitions=self.head_transitions, - tail_transitions=self.tail_transitions, + last_transitions=self.last_transitions, ) def semiring_partitions(self, semiring: Type[Semiring]) -> Tensor: @@ -104,7 +104,7 @@ def semiring_partitions(self, semiring: Type[Semiring]) -> Tensor: emissions=self.emissions, indices=self.indices, transitions=self.transitions, head_transitions=self.head_transitions, - tail_transitions=self.tail_transitions, + last_transitions=self.last_transitions, eye=semiring.eye_like(self.transitions), ) From f2259990eebadeb9ce35970d3af03cb4a80c99e7 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Fri, 18 Feb 2022 22:51:12 +0900 Subject: [PATCH 12/18] Refactor: Separate ThirdPartyCrfDecoder --- tests/test_crf.py | 83 ++------------------------------------------ tests/third_party.py | 83 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+), 81 
deletions(-) create mode 100644 tests/third_party.py diff --git a/tests/test_crf.py b/tests/test_crf.py index 900afa6..ee78ebb 100644 --- a/tests/test_crf.py +++ b/tests/test_crf.py @@ -1,92 +1,13 @@ import torch -import torchcrf from hypothesis import given -from torch import Tensor -from torch import nn -from torch.nn.utils.rnn import PackedSequence -from torch.types import Device -from torchrua import pad_packed_sequence, pack_sequence, cat_sequence, pack_catted_sequence, pad_catted_indices +from torchrua import pack_sequence, cat_sequence, pack_catted_sequence from tests.strategies import devices, token_size_lists, conjugate_sizes, tag_sizes +from tests.third_party import ThirdPartyCrfDecoder from tests.utils import assert_close, assert_grad_close, assert_packed_equal from torchlatent.crf import CrfDecoder -@torch.no_grad() -def token_sizes_to_mask(sizes: Tensor, batch_first: bool, device: Device = None) -> Tensor: - if device is None: - device = sizes.device - - size, ptr = pad_catted_indices(sizes, batch_first=batch_first, device=device) - mask = torch.zeros(size, device=device, dtype=torch.bool) - mask[ptr] = True - return mask - - -class ThirdPartyCrfDecoder(nn.Module): - def __init__(self, num_tags: int, num_conjugates: int) -> None: - super(ThirdPartyCrfDecoder, self).__init__() - self.num_tags = num_tags - self.num_conjugates = num_conjugates - - self.decoders = nn.ModuleList([ - torchcrf.CRF(num_tags=num_tags, batch_first=False) - for _ in range(num_conjugates) - ]) - - @torch.no_grad() - def reset_parameters_with_(self, decoder: CrfDecoder) -> None: - assert self.num_tags == decoder.num_tags - assert self.num_conjugates == decoder.num_conjugates - - for index in range(self.num_conjugates): - self.decoders[index].transitions.data[::] = decoder.transitions[:, index, :, :] - self.decoders[index].start_transitions.data[::] = decoder.head_transitions[:, index, :] - self.decoders[index].end_transitions.data[::] = decoder.last_transitions[:, index, :] - - def 
fit(self, emissions: PackedSequence, tags: PackedSequence, **kwargs) -> Tensor: - num_emissions_conjugates = emissions.data.size()[1] - num_decoders_conjugates = self.num_conjugates - num_conjugates = max(num_emissions_conjugates, num_decoders_conjugates) - - emissions, token_sizes = pad_packed_sequence(emissions, batch_first=False) - tags, _ = pad_packed_sequence(tags, batch_first=False) - mask = token_sizes_to_mask(sizes=token_sizes, batch_first=False) - - log_probs = [] - for index in range(num_conjugates): - decoder = self.decoders[index % num_decoders_conjugates] - emission = emissions[:, :, index % num_emissions_conjugates] - tag = tags[:, :, index % num_emissions_conjugates] - - log_probs.append(decoder(emissions=emission, tags=tag, mask=mask, reduction='none')) - - return torch.stack(log_probs, dim=-1) - - def decode(self, emissions: PackedSequence, **kwargs) -> PackedSequence: - num_emissions_conjugates = emissions.data.size()[1] - num_decoders_conjugates = self.num_conjugates - num_conjugates = max(num_emissions_conjugates, num_decoders_conjugates) - - emissions, token_sizes = pad_packed_sequence(emissions, batch_first=False) - mask = token_sizes_to_mask(sizes=token_sizes, batch_first=False) - - predictions = [] - for index in range(num_conjugates): - decoder = self.decoders[index % num_decoders_conjugates] - emission = emissions[:, :, index % num_emissions_conjugates] - - prediction = decoder.decode(emissions=emission, mask=mask) - predictions.append(pack_sequence([torch.tensor(p) for p in prediction], device=emissions.device)) - - return PackedSequence( - torch.stack([prediction.data for prediction in predictions], dim=1), - batch_sizes=predictions[0].batch_sizes, - sorted_indices=predictions[0].sorted_indices, - unsorted_indices=predictions[0].unsorted_indices, - ) - - @given( device=devices(), token_sizes=token_size_lists(), diff --git a/tests/third_party.py b/tests/third_party.py new file mode 100644 index 0000000..6172f91 --- /dev/null +++ 
b/tests/third_party.py @@ -0,0 +1,83 @@ +import torch +import torchcrf +from torch import Tensor, nn +from torch.nn.utils.rnn import PackedSequence +from torch.types import Device +from torchrua import pad_catted_indices, pad_packed_sequence, pack_sequence + +from torchlatent.crf import CrfDecoder + + +@torch.no_grad() +def token_sizes_to_mask(sizes: Tensor, batch_first: bool, device: Device = None) -> Tensor: + if device is None: + device = sizes.device + + size, ptr = pad_catted_indices(sizes, batch_first=batch_first, device=device) + mask = torch.zeros(size, device=device, dtype=torch.bool) + mask[ptr] = True + return mask + + +class ThirdPartyCrfDecoder(nn.Module): + def __init__(self, num_tags: int, num_conjugates: int) -> None: + super(ThirdPartyCrfDecoder, self).__init__() + self.num_tags = num_tags + self.num_conjugates = num_conjugates + + self.decoders = nn.ModuleList([ + torchcrf.CRF(num_tags=num_tags, batch_first=False) + for _ in range(num_conjugates) + ]) + + @torch.no_grad() + def reset_parameters_with_(self, decoder: CrfDecoder) -> None: + assert self.num_tags == decoder.num_tags + assert self.num_conjugates == decoder.num_conjugates + + for index in range(self.num_conjugates): + self.decoders[index].transitions.data[::] = decoder.transitions[:, index, :, :] + self.decoders[index].start_transitions.data[::] = decoder.head_transitions[:, index, :] + self.decoders[index].end_transitions.data[::] = decoder.last_transitions[:, index, :] + + def fit(self, emissions: PackedSequence, tags: PackedSequence, **kwargs) -> Tensor: + num_emissions_conjugates = emissions.data.size()[1] + num_decoders_conjugates = self.num_conjugates + num_conjugates = max(num_emissions_conjugates, num_decoders_conjugates) + + emissions, token_sizes = pad_packed_sequence(emissions, batch_first=False) + tags, _ = pad_packed_sequence(tags, batch_first=False) + mask = token_sizes_to_mask(sizes=token_sizes, batch_first=False) + + log_probs = [] + for index in range(num_conjugates): + 
decoder = self.decoders[index % num_decoders_conjugates] + emission = emissions[:, :, index % num_emissions_conjugates] + tag = tags[:, :, index % num_emissions_conjugates] + + log_probs.append(decoder(emissions=emission, tags=tag, mask=mask, reduction='none')) + + return torch.stack(log_probs, dim=-1) + + def decode(self, emissions: PackedSequence, **kwargs) -> PackedSequence: + num_emissions_conjugates = emissions.data.size()[1] + num_decoders_conjugates = self.num_conjugates + num_conjugates = max(num_emissions_conjugates, num_decoders_conjugates) + + emissions, token_sizes = pad_packed_sequence(emissions, batch_first=False) + mask = token_sizes_to_mask(sizes=token_sizes, batch_first=False) + + predictions = [] + for index in range(num_conjugates): + decoder = self.decoders[index % num_decoders_conjugates] + emission = emissions[:, :, index % num_emissions_conjugates] + + prediction = decoder.decode(emissions=emission, mask=mask) + predictions.append(pack_sequence([torch.tensor(p) for p in prediction], device=emissions.device)) + + return PackedSequence( + torch.stack([prediction.data for prediction in predictions], dim=1), + batch_sizes=predictions[0].batch_sizes, + sorted_indices=predictions[0].sorted_indices, + unsorted_indices=predictions[0].unsorted_indices, + ) From b8e9a2ff7da7413e7318c5d8070859fb4b93e9b4 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Fri, 18 Feb 2022 23:12:28 +0900 Subject: [PATCH 13/18] Benchmark: Add benchmark_crf --- benchmark/__init__.py | 0 benchmark/__main__.py | 9 ++++++++ benchmark/crf.py | 50 +++++++++++++++++++++++++++++++++++++++++++ benchmark/meter.py | 19 ++++++++++++++++ 4 files changed, 78 insertions(+) create mode 100644 benchmark/__init__.py create mode 100644 benchmark/__main__.py create mode 100644 benchmark/crf.py create mode 100644 benchmark/meter.py diff --git a/benchmark/__init__.py b/benchmark/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/benchmark/__main__.py b/benchmark/__main__.py new 
file mode 100644 index 0000000..351d087 --- /dev/null +++ b/benchmark/__main__.py @@ -0,0 +1,9 @@ +from aku import Aku + +from benchmark.crf import benchmark_crf + +aku = Aku() + +aku.option(benchmark_crf) + +aku.run() diff --git a/benchmark/crf.py b/benchmark/crf.py new file mode 100644 index 0000000..0ab929d --- /dev/null +++ b/benchmark/crf.py @@ -0,0 +1,50 @@ +import torch +from torchrua import cat_sequence +from tqdm import tqdm + +from benchmark.meter import TimeMeter +from torchlatent.crf import CrfDecoder + + +def benchmark_crf(num_tags: int = 32, num_conjugates: int = 4, num_runs: int = 100, + batch_size: int = 120, max_token_size: int = 512): + jit_timer, fwd_timer, bwd_timer, dec_timer, = TimeMeter(), TimeMeter(), TimeMeter(), TimeMeter() + + if torch.cuda.is_available(): + device = torch.device('cuda:0') + else: + device = torch.device('cpu') + print(f'device => {device}') + + decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates).to(device=device) + print(f'decoder => {decoder}') + + for _ in tqdm(range(num_runs)): + token_sizes = torch.randint(1, max_token_size + 1, (batch_size,), device=device).detach().cpu().tolist() + + emissions = cat_sequence([ + torch.randn((token_size, num_conjugates, num_tags), device=device, requires_grad=True) + for token_size in token_sizes + ]) + + tags = cat_sequence([ + torch.randint(0, num_tags, (token_size, num_conjugates), device=device) + for token_size in token_sizes + ]) + + with jit_timer: + indices = decoder.compile_indices(emissions=emissions, tags=tags) + + with fwd_timer: + loss = decoder.fit(emissions=emissions, tags=tags, indices=indices).neg().mean() + + with bwd_timer: + _, torch.autograd.grad(loss, emissions.data, torch.ones_like(loss)) + + with dec_timer: + _ = decoder.decode(emissions=emissions, indices=indices) + + print(f'compile => {jit_timer}') + print(f'forward => {fwd_timer}') + print(f'backward => {bwd_timer}') + print(f'decode => {dec_timer}') diff --git a/benchmark/meter.py 
b/benchmark/meter.py new file mode 100644 index 0000000..f2d0fb0 --- /dev/null +++ b/benchmark/meter.py @@ -0,0 +1,19 @@ +from datetime import datetime + + +class TimeMeter(object): + def __init__(self) -> None: + super(TimeMeter, self).__init__() + + self.seconds = 0 + self.counts = 0 + + def __enter__(self): + self.start_tm = datetime.now() + + def __exit__(self, exc_type, exc_val, exc_tb): + self.seconds += (datetime.now() - self.start_tm).total_seconds() + self.counts += 1 + + def __repr__(self) -> str: + return f'{self.seconds / self.counts:.6f}' From bc666027b35090367934b4561581354f9c9c64cc Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Fri, 18 Feb 2022 23:24:30 +0900 Subject: [PATCH 14/18] Benchmark: Add third_decoder --- benchmark/crf.py | 40 ++++++++++++++++++++++++++-------------- benchmark/meter.py | 6 +++++- 2 files changed, 31 insertions(+), 15 deletions(-) diff --git a/benchmark/crf.py b/benchmark/crf.py index 0ab929d..a74dce0 100644 --- a/benchmark/crf.py +++ b/benchmark/crf.py @@ -1,14 +1,16 @@ import torch -from torchrua import cat_sequence +from torchrua import pack_sequence from tqdm import tqdm from benchmark.meter import TimeMeter +from tests.third_party import ThirdPartyCrfDecoder from torchlatent.crf import CrfDecoder -def benchmark_crf(num_tags: int = 32, num_conjugates: int = 4, num_runs: int = 100, - batch_size: int = 120, max_token_size: int = 512): - jit_timer, fwd_timer, bwd_timer, dec_timer, = TimeMeter(), TimeMeter(), TimeMeter(), TimeMeter() +def benchmark_crf(num_tags: int = 50, num_conjugates: int = 1, num_runs: int = 100, + batch_size: int = 32, max_token_size: int = 512): + j1, f1, b1, d1, = TimeMeter(), TimeMeter(), TimeMeter(), TimeMeter() + j2, f2, b2, d2, = TimeMeter(), TimeMeter(), TimeMeter(), TimeMeter() if torch.cuda.is_available(): device = torch.device('cuda:0') @@ -19,32 +21,42 @@ def benchmark_crf(num_tags: int = 32, num_conjugates: int = 4, num_runs: int = 1 decoder = CrfDecoder(num_tags=num_tags, 
num_conjugates=num_conjugates).to(device=device) print(f'decoder => {decoder}') + third_decoder = ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates).to(device=device) + print(f'third_decoder => {third_decoder}') + for _ in tqdm(range(num_runs)): token_sizes = torch.randint(1, max_token_size + 1, (batch_size,), device=device).detach().cpu().tolist() - emissions = cat_sequence([ + emissions = pack_sequence([ torch.randn((token_size, num_conjugates, num_tags), device=device, requires_grad=True) for token_size in token_sizes ]) - tags = cat_sequence([ + tags = pack_sequence([ torch.randint(0, num_tags, (token_size, num_conjugates), device=device) for token_size in token_sizes ]) - with jit_timer: + with j1: indices = decoder.compile_indices(emissions=emissions, tags=tags) - with fwd_timer: + with f1: loss = decoder.fit(emissions=emissions, tags=tags, indices=indices).neg().mean() - with bwd_timer: + with b1: _, torch.autograd.grad(loss, emissions.data, torch.ones_like(loss)) - with dec_timer: + with d1: _ = decoder.decode(emissions=emissions, indices=indices) - print(f'compile => {jit_timer}') - print(f'forward => {fwd_timer}') - print(f'backward => {bwd_timer}') - print(f'decode => {dec_timer}') + with f2: + loss = third_decoder.fit(emissions=emissions, tags=tags).neg().mean() + + with b2: + _, torch.autograd.grad(loss, emissions.data, torch.ones_like(loss)) + + with d2: + _ = third_decoder.decode(emissions=emissions) + + print(f'TorchLatent ({j1.merit + f1.merit + b1.merit:.6f}) => {j1} {f1} {b1} {d1}') + print(f'Third ({j2.merit + f2.merit + b2.merit:.6f}) => {j2} {f2} {b2} {d2}') diff --git a/benchmark/meter.py b/benchmark/meter.py index f2d0fb0..64e1c6e 100644 --- a/benchmark/meter.py +++ b/benchmark/meter.py @@ -15,5 +15,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.seconds += (datetime.now() - self.start_tm).total_seconds() self.counts += 1 + @property + def merit(self) -> float: + return self.seconds / max(1, self.counts) + def 
__repr__(self) -> str: - return f'{self.seconds / self.counts:.6f}' + return f'{self.merit :.6f}' From e09ce154e8a08d88aa7ba9108777fc41e00ac20f Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Fri, 18 Feb 2022 23:26:29 +0900 Subject: [PATCH 15/18] Doc: Add Performance --- README.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/README.md b/README.md index 1200d43..3ac1ea3 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,13 @@ `python3 -m pip torchlatent` +## Performance + +``` +TorchLatent (0.109244) => 0.003781 0.017763 0.087700 0.063497 +Third (0.232487) => 0.103277 0.129209 0.145311 +``` + ## Usage ```python From ddde512952bce144734af1cae3b9376a937ab60e Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Fri, 4 Mar 2022 00:26:35 +0900 Subject: [PATCH 16/18] Refactor: Migrate to new TorchRua --- torchlatent/crf/__init__.py | 18 +++++++++--------- torchlatent/crf/catting.py | 6 +++--- torchlatent/crf/packing.py | 6 +++--- torchlatent/semiring.py | 6 +++--- 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/torchlatent/crf/__init__.py b/torchlatent/crf/__init__.py index e67f7f0..a126167 100644 --- a/torchlatent/crf/__init__.py +++ b/torchlatent/crf/__init__.py @@ -5,8 +5,8 @@ from torch import Tensor from torch import nn from torch.nn import init -from torchrua import TreeReduceIndices, PackedSequence, CattedSequence -from torchrua import tree_reduce_packed_indices, tree_reduce_catted_indices +from torchrua import ReductionIndices, PackedSequence, CattedSequence +from torchrua import reduce_packed_indices, reduce_catted_indices from torchlatent.crf.catting import CattedCrfDistribution from torchlatent.crf.packing import PackedCrfDistribution @@ -43,7 +43,7 @@ def extra_repr(self) -> str: @staticmethod def compile_indices(emissions: Sequence, tags: Optional[Sequence] = None, - indices: Optional[TreeReduceIndices] = None, **kwargs): + indices: Optional[ReductionIndices] = None, **kwargs): assert emissions.data.dim() == 3, f'{emissions.data.dim()} 
!= {3}' if tags is not None: assert tags.data.dim() == 2, f'{tags.data.dim()} != {2}' @@ -51,11 +51,11 @@ def compile_indices(emissions: Sequence, if indices is None: if isinstance(emissions, PackedSequence): batch_sizes = emissions.batch_sizes.to(device=emissions.data.device) - return tree_reduce_packed_indices(batch_sizes=batch_sizes) + return reduce_packed_indices(batch_sizes=batch_sizes) if isinstance(emissions, CattedSequence): token_sizes = emissions.token_sizes.to(device=emissions.data.device) - return tree_reduce_catted_indices(token_sizes=token_sizes) + return reduce_catted_indices(token_sizes=token_sizes) return indices @@ -63,7 +63,7 @@ def obtain_parameters(self, **kwargs) -> Tuple[Tensor, Tensor, Tensor]: return self.transitions, self.head_transitions, self.last_transitions def forward(self, emissions: Sequence, tags: Optional[Sequence] = None, - indices: Optional[TreeReduceIndices] = None, **kwargs): + indices: Optional[ReductionIndices] = None, **kwargs): indices = self.compile_indices(emissions=emissions, tags=tags, indices=indices) transitions, head_transitions, last_transitions = self.obtain_parameters( emissions=emissions, tags=tags, indices=indices, @@ -90,18 +90,18 @@ def forward(self, emissions: Sequence, tags: Optional[Sequence] = None, raise TypeError(f'{type(emissions)} is not supported.') def fit(self, emissions: Sequence, tags: Sequence, - indices: Optional[TreeReduceIndices] = None, **kwargs) -> Tensor: + indices: Optional[ReductionIndices] = None, **kwargs) -> Tensor: dist, tags = self(emissions=emissions, tags=tags, instr=indices, **kwargs) return dist.log_prob(tags=tags) def decode(self, emissions: Sequence, - indices: Optional[TreeReduceIndices] = None, **kwargs) -> Sequence: + indices: Optional[ReductionIndices] = None, **kwargs) -> Sequence: dist, _ = self(emissions=emissions, tags=None, instr=indices, **kwargs) return dist.argmax def marginals(self, emissions: Sequence, - indices: Optional[TreeReduceIndices] = None, **kwargs) -> 
Tensor: + indices: Optional[ReductionIndices] = None, **kwargs) -> Tensor: dist, _ = self(emissions=emissions, tags=None, instr=indices, **kwargs) return dist.marginals diff --git a/torchlatent/crf/catting.py b/torchlatent/crf/catting.py index 8cff36e..474f704 100644 --- a/torchlatent/crf/catting.py +++ b/torchlatent/crf/catting.py @@ -4,7 +4,7 @@ from torch import Tensor, autograd from torch.distributions.utils import lazy_property from torchrua import CattedSequence -from torchrua import TreeReduceIndices, head_catted_indices +from torchrua import ReductionIndices, head_catted_indices from torchrua import roll_catted_sequence, head_catted_sequence, last_catted_sequence from torchlatent.semiring import Semiring, Log, Max @@ -52,7 +52,7 @@ def _compute_catted_sequence_scores( def compute_catted_sequence_partitions(semiring: Type[Semiring]): def _compute_catted_sequence_partitions( - emissions: CattedSequence, indices: TreeReduceIndices, + emissions: CattedSequence, indices: ReductionIndices, transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor, eye: Tensor) -> Tensor: h = emissions.token_sizes.size()[0] t = torch.arange(transitions.size()[0], device=transitions.device) # [t] @@ -77,7 +77,7 @@ def _compute_catted_sequence_partitions( class CattedCrfDistribution(object): - def __init__(self, emissions: CattedSequence, indices: TreeReduceIndices, + def __init__(self, emissions: CattedSequence, indices: ReductionIndices, transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor) -> None: super(CattedCrfDistribution, self).__init__() self.emissions = emissions diff --git a/torchlatent/crf/packing.py b/torchlatent/crf/packing.py index 0ef82ed..ec22c38 100644 --- a/torchlatent/crf/packing.py +++ b/torchlatent/crf/packing.py @@ -4,7 +4,7 @@ from torch import Tensor, autograd from torch.distributions.utils import lazy_property from torch.nn.utils.rnn import PackedSequence -from torchrua import head_packed_indices, TreeReduceIndices +from 
torchrua import head_packed_indices, ReductionIndices from torchrua import roll_packed_sequence, head_packed_sequence, last_packed_sequence, major_sizes_to_ptr from torchlatent.semiring import Semiring, Log, Max @@ -55,7 +55,7 @@ def _compute_packed_sequence_scores( def compute_packed_sequence_partitions(semiring: Type[Semiring]): def _compute_packed_sequence_partitions( - emissions: PackedSequence, indices: TreeReduceIndices, + emissions: PackedSequence, indices: ReductionIndices, transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor, eye: Tensor) -> Tensor: h = emissions.batch_sizes[0].item() t = torch.arange(transitions.size()[0], device=transitions.device) # [t] @@ -81,7 +81,7 @@ def _compute_packed_sequence_partitions( class PackedCrfDistribution(object): - def __init__(self, emissions: PackedSequence, indices: TreeReduceIndices, + def __init__(self, emissions: PackedSequence, indices: ReductionIndices, transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor) -> None: super(PackedCrfDistribution, self).__init__() self.emissions = emissions diff --git a/torchlatent/semiring.py b/torchlatent/semiring.py index 10f20d0..1f6e44a 100644 --- a/torchlatent/semiring.py +++ b/torchlatent/semiring.py @@ -1,7 +1,7 @@ import torch from torch import Tensor from torchrua.scatter import scatter_add, scatter_max, scatter_mul, scatter_logsumexp -from torchrua.tree_reduction import tree_reduce_sequence, TreeReduceIndices +from torchrua.reduction import reduce_sequence, ReductionIndices from torchlatent.functional import logsumexp, logaddexp @@ -53,8 +53,8 @@ def bmm(cls, x: Tensor, y: Tensor) -> Tensor: return cls.sum(cls.mul(x[..., :, :, None], y[..., None, :, :]), dim=-2, keepdim=False) @classmethod - def reduce(cls, tensor: Tensor, indices: TreeReduceIndices) -> Tensor: - return tree_reduce_sequence(cls.bmm)(tensor=tensor, indices=indices) + def reduce(cls, tensor: Tensor, indices: ReductionIndices) -> Tensor: + return 
reduce_sequence(cls.bmm)(tensor=tensor, indices=indices) class Std(Semiring): From efec2a5c1f8b453dd0ee84517bddb3d5c0e77747 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Fri, 4 Mar 2022 00:30:12 +0900 Subject: [PATCH 17/18] Test: Add sizes --- tests/strategies.py | 57 +++++++---------------------- tests/test_crf.py | 48 ++++++++++++------------ tests/test_functional.py | 6 +-- tests/utils.py | 79 ++++++++++++++++++++++++---------------- 4 files changed, 89 insertions(+), 101 deletions(-) diff --git a/tests/strategies.py b/tests/strategies.py index 9579c16..785f929 100644 --- a/tests/strategies.py +++ b/tests/strategies.py @@ -2,20 +2,14 @@ from hypothesis import strategies as st -if torch.cuda.is_available(): - MAX_BATCH_SIZE = 120 - MAX_TOKEN_SIZE = 512 - MAX_NUM_TAGS = 100 - MAX_NUM_CONJUGATES = 16 -else: - MAX_BATCH_SIZE = 12 - MAX_TOKEN_SIZE = 24 - MAX_NUM_TAGS = 12 - MAX_NUM_CONJUGATES = 6 - TINY_BATCH_SIZE = 6 TINY_TOKEN_SIZE = 12 +BATCH_SIZE = 24 +TOKEN_SIZE = 50 +NUM_TAGS = 8 +NUM_CONJUGATES = 5 + @st.composite def devices(draw): @@ -28,36 +22,13 @@ def devices(draw): @st.composite -def batch_sizes(draw, max_value: int = MAX_BATCH_SIZE): - return draw(st.integers(min_value=1, max_value=max_value)) - - -@st.composite -def batch_size_lists(draw, max_batch_size: int = MAX_BATCH_SIZE): - return [ - draw(batch_sizes(max_value=max_batch_size)) - for _ in range(draw(batch_sizes(max_value=max_batch_size))) - ] - - -@st.composite -def token_sizes(draw, max_value: int = MAX_TOKEN_SIZE): - return draw(st.integers(min_value=1, max_value=max_value)) - +def sizes(draw, *size: int, min_size: int = 1): + max_size, *size = size -@st.composite -def token_size_lists(draw, max_token_size: int = MAX_TOKEN_SIZE, max_batch_size: int = MAX_BATCH_SIZE): - return [ - draw(token_sizes(max_value=max_token_size)) - for _ in range(draw(batch_sizes(max_value=max_batch_size))) - ] - - -@st.composite -def tag_sizes(draw, max_value: int = MAX_NUM_TAGS): - return 
draw(st.integers(min_value=1, max_value=max_value)) - - -@st.composite -def conjugate_sizes(draw, max_value: int = MAX_NUM_CONJUGATES): - return draw(st.integers(min_value=1, max_value=max_value)) + if len(size) == 0: + return draw(st.integers(min_value=min_size, max_value=max_size)) + else: + return [ + draw(sizes(*size, min_size=min_size)) + for _ in range(draw(st.integers(min_value=min_size, max_value=max_size))) + ] diff --git a/tests/test_crf.py b/tests/test_crf.py index ee78ebb..2164a7b 100644 --- a/tests/test_crf.py +++ b/tests/test_crf.py @@ -2,17 +2,17 @@ from hypothesis import given from torchrua import pack_sequence, cat_sequence, pack_catted_sequence -from tests.strategies import devices, token_size_lists, conjugate_sizes, tag_sizes +from tests.strategies import devices, sizes, BATCH_SIZE, TOKEN_SIZE, NUM_CONJUGATES, NUM_TAGS from tests.third_party import ThirdPartyCrfDecoder -from tests.utils import assert_close, assert_grad_close, assert_packed_equal +from tests.utils import assert_close, assert_grad_close, assert_packed_sequence_equal from torchlatent.crf import CrfDecoder @given( device=devices(), - token_sizes=token_size_lists(), - num_conjugate=conjugate_sizes(), - num_tags=tag_sizes(), + token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), + num_conjugate=sizes(NUM_CONJUGATES), + num_tags=sizes(NUM_TAGS), ) def test_crf_packed_fit(device, token_sizes, num_conjugate, num_tags): emissions = pack_sequence([ @@ -25,8 +25,8 @@ def test_crf_packed_fit(device, token_sizes, num_conjugate, num_tags): for token_size in token_sizes ], device=device) - actual_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate) - expected_decoder = ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate) + actual_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device) + expected_decoder = ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device) 
expected_decoder.reset_parameters_with_(decoder=actual_decoder) actual = actual_decoder.fit(emissions=emissions, tags=tags) @@ -38,9 +38,9 @@ def test_crf_packed_fit(device, token_sizes, num_conjugate, num_tags): @given( device=devices(), - token_sizes=token_size_lists(), - num_conjugate=conjugate_sizes(), - num_tags=tag_sizes(), + token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), + num_conjugate=sizes(NUM_CONJUGATES), + num_tags=sizes(NUM_TAGS), ) def test_crf_packed_decode(device, token_sizes, num_conjugate, num_tags): emissions = pack_sequence([ @@ -48,21 +48,21 @@ def test_crf_packed_decode(device, token_sizes, num_conjugate, num_tags): for token_size in token_sizes ], device=device) - actual_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate) - expected_decoder = ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate) + actual_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device) + expected_decoder = ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device) expected_decoder.reset_parameters_with_(decoder=actual_decoder) expected = expected_decoder.decode(emissions=emissions) actual = actual_decoder.decode(emissions=emissions) - assert_packed_equal(actual=actual, expected=expected) + assert_packed_sequence_equal(actual=actual, expected=expected) @given( device=devices(), - token_sizes=token_size_lists(), - num_conjugate=conjugate_sizes(), - num_tags=tag_sizes(), + token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), + num_conjugate=sizes(NUM_CONJUGATES), + num_tags=sizes(NUM_TAGS), ) def test_crf_catted_fit(device, token_sizes, num_conjugate, num_tags): emissions = [ @@ -80,8 +80,8 @@ def test_crf_catted_fit(device, token_sizes, num_conjugate, num_tags): catted_tags = cat_sequence(tags, device=device) packed_tags = pack_sequence(tags, device=device) - actual_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate) - expected_decoder = 
ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate) + actual_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device) + expected_decoder = ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device) expected_decoder.reset_parameters_with_(decoder=actual_decoder) actual = actual_decoder.fit(emissions=catted_emissions, tags=catted_tags) @@ -93,9 +93,9 @@ def test_crf_catted_fit(device, token_sizes, num_conjugate, num_tags): @given( device=devices(), - token_sizes=token_size_lists(), - num_conjugate=conjugate_sizes(), - num_tags=tag_sizes(), + token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), + num_conjugate=sizes(NUM_CONJUGATES), + num_tags=sizes(NUM_TAGS), ) def test_crf_catted_decode(device, token_sizes, num_conjugate, num_tags): emissions = [ @@ -106,12 +106,12 @@ def test_crf_catted_decode(device, token_sizes, num_conjugate, num_tags): catted_emissions = cat_sequence(emissions, device=device) packed_emissions = pack_sequence(emissions, device=device) - actual_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate) - expected_decoder = ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate) + actual_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device) + expected_decoder = ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device) expected_decoder.reset_parameters_with_(decoder=actual_decoder) expected = expected_decoder.decode(emissions=packed_emissions) actual = actual_decoder.decode(emissions=catted_emissions) actual = pack_catted_sequence(*actual, device=device) - assert_packed_equal(actual=actual, expected=expected) + assert_packed_sequence_equal(actual=actual, expected=expected) diff --git a/tests/test_functional.py b/tests/test_functional.py index 4924882..c12b74a 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1,14 +1,14 @@ import torch from hypothesis import given, 
strategies as st -from tests.strategies import devices, token_size_lists, TINY_TOKEN_SIZE, TINY_BATCH_SIZE +from tests.strategies import devices, sizes, TINY_TOKEN_SIZE, TINY_BATCH_SIZE from tests.utils import assert_close, assert_grad_close from torchlatent.functional import logaddexp, logsumexp @given( device=devices(), - token_sizes=token_size_lists(max_token_size=TINY_TOKEN_SIZE, max_batch_size=TINY_BATCH_SIZE) + token_sizes=sizes(TINY_BATCH_SIZE, TINY_TOKEN_SIZE) ) def test_logaddexp(device, token_sizes): x = torch.randn(token_sizes, device=device, requires_grad=True) @@ -24,7 +24,7 @@ def test_logaddexp(device, token_sizes): @given( data=st.data(), device=devices(), - token_sizes=token_size_lists(max_token_size=TINY_TOKEN_SIZE, max_batch_size=TINY_BATCH_SIZE) + token_sizes=sizes(TINY_BATCH_SIZE, TINY_TOKEN_SIZE) ) def test_logsumexp(data, device, token_sizes): tensor = torch.randn(token_sizes, device=device, requires_grad=True) diff --git a/tests/utils.py b/tests/utils.py index ff59c50..3040db4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,40 +1,67 @@ -from typing import Tuple, List, Union +from typing import List, Tuple, Union import torch from torch import Tensor +from torch.nn.utils.rnn import PackedSequence from torch.testing import assert_close -from torchrua import CattedSequence, PackedSequence +from torchrua.catting import CattedSequence __all__ = [ - 'assert_close', - 'assert_equal', - 'assert_grad_close', - 'assert_packed_close', - 'assert_packed_equal', - 'assert_catted_close', - 'assert_catted_equal', + 'assert_equal', 'assert_close', 'assert_grad_close', + 'assert_catted_sequence_equal', 'assert_catted_sequence_close', + 'assert_packed_sequence_equal', 'assert_packed_sequence_close', ] -def assert_equal(actual: Tensor, expected: Tensor, **kwargs) -> None: +def assert_equal(actual: Tensor, expected: Tensor) -> None: assert torch.equal(actual, expected) -def assert_grad_close(actual: Tensor, expected: Tensor, inputs: Union[List[Tensor], 
Tuple[Tensor, ...]]) -> None: - grad = torch.randn_like(actual) +def assert_grad_close( + actual: Tensor, expected: Tensor, + inputs: Union[Tensor, List[Tensor], Tuple[Tensor, ...]], + allow_unused: bool = False, + check_device: bool = True, check_dtype: bool = True, check_stride: bool = True) -> None: + kwargs = dict(check_device=check_device, check_dtype=check_dtype, check_stride=check_stride) - actual_grads = torch.autograd.grad(actual, inputs, grad) - expected_grads = torch.autograd.grad(expected, inputs, grad) + grad = torch.rand_like(actual) + + actual_grads = torch.autograd.grad( + actual, inputs, grad, + create_graph=False, + allow_unused=allow_unused, + ) + + expected_grads = torch.autograd.grad( + expected, inputs, grad, + create_graph=False, + allow_unused=allow_unused, + ) for actual_grad, expected_grad in zip(actual_grads, expected_grads): - if actual_grad is None: - assert expected_grad is None - else: - assert_close(actual=actual_grad, expected=expected_grad) + assert_close(actual=actual_grad, expected=expected_grad, **kwargs) + + +def assert_catted_sequence_close( + actual: CattedSequence, expected: CattedSequence, + check_device: bool = True, check_dtype: bool = True, check_stride: bool = True) -> None: + kwargs = dict(check_device=check_device, check_dtype=check_dtype, check_stride=check_stride) + assert_close(actual=actual.data, expected=expected.data, **kwargs) + assert_equal(actual=actual.token_sizes, expected=expected.token_sizes) -def assert_packed_equal(actual: PackedSequence, expected: PackedSequence) -> None: + +def assert_catted_sequence_equal(actual: CattedSequence, expected: CattedSequence) -> None: assert_equal(actual=actual.data, expected=expected.data) + assert_equal(actual=actual.token_sizes, expected=expected.token_sizes) + + +def assert_packed_sequence_close( + actual: PackedSequence, expected: PackedSequence, + check_device: bool = True, check_dtype: bool = True, check_stride: bool = True) -> None: + kwargs = 
dict(check_device=check_device, check_dtype=check_dtype, check_stride=check_stride) + + assert_close(actual=actual.data, expected=expected.data, **kwargs) assert_equal(actual=actual.batch_sizes, expected=expected.batch_sizes) if actual.sorted_indices is None: @@ -48,8 +75,8 @@ def assert_packed_equal(actual: PackedSequence, expected: PackedSequence) -> Non assert_equal(actual=actual.unsorted_indices, expected=expected.unsorted_indices) -def assert_packed_close(actual: PackedSequence, expected: PackedSequence) -> None: - assert_close(actual=actual.data, expected=expected.data) +def assert_packed_sequence_equal(actual: PackedSequence, expected: PackedSequence) -> None: + assert_equal(actual=actual.data, expected=expected.data) assert_equal(actual=actual.batch_sizes, expected=expected.batch_sizes) if actual.sorted_indices is None: @@ -61,13 +88,3 @@ def assert_packed_close(actual: PackedSequence, expected: PackedSequence) -> Non assert expected.unsorted_indices is None else: assert_equal(actual=actual.unsorted_indices, expected=expected.unsorted_indices) - - -def assert_catted_equal(actual: CattedSequence, expected: CattedSequence) -> None: - assert_equal(actual.data, expected.data) - assert_equal(actual.token_sizes, expected.token_sizes) - - -def assert_catted_close(actual: CattedSequence, expected: CattedSequence) -> None: - assert_close(actual.data, expected.data) - assert_close(actual.token_sizes, expected.token_sizes) From 1508825feeba6cd1d2e0c1cbe7b78ef289f85615 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Fri, 4 Mar 2022 00:39:16 +0900 Subject: [PATCH 18/18] Chore: Update version number --- .github/workflows/python-publish.yml | 2 +- .github/workflows/unit-tests.yml | 4 ++-- README.md | 4 ++-- setup.py | 6 +++--- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 2944977..9993084 100644 --- a/.github/workflows/python-publish.yml +++ 
b/.github/workflows/python-publish.yml @@ -17,7 +17,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v2 with: - python-version: '3.7' + python-version: '3.8' - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 63384e4..db1e5bc 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -9,10 +9,10 @@ jobs: steps: - uses: actions/checkout@v2 - - name: Set up Python 3.7 + - name: Set up Python uses: actions/setup-python@v2 with: - python-version: 3.7 + python-version: '3.8' - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/README.md b/README.md index 3ac1ea3..e6d403b 100644 --- a/README.md +++ b/README.md @@ -6,8 +6,8 @@ ## Requirements -- Python 3.7 -- PyTorch 1.6.0 +- Python 3.8 +- PyTorch 1.10.2 ## Installation diff --git a/setup.py b/setup.py index b7687ba..6f8a71b 100644 --- a/setup.py +++ b/setup.py @@ -4,17 +4,17 @@ setup( name=name, - version='0.4.1', + version='0.4.2', packages=[package for package in find_packages() if package.startswith(name)], url='https://github.com/speedcell4/torchlatent', license='MIT', author='speedcell4', author_email='speedcell4@gmail.com', description='High Performance Structured Prediction in PyTorch', - python_requires='>=3.7', + python_requires='>=3.8', install_requires=[ 'numpy', - 'torchrua>=0.3.0', + 'torchrua>=0.4.0', ], extras_require={ 'dev': [