diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml
new file mode 100644
index 0000000..e7f3d7f
--- /dev/null
+++ b/.github/workflows/publish-package.yml
@@ -0,0 +1,26 @@
+name: publish package
+
+on:
+ release:
+ types: [ created ]
+
+jobs:
+ build:
+ runs-on: ubuntu-latest
+
+ steps:
+ - uses: actions/checkout@v4
+ - uses: actions/setup-python@v5
+ with:
+ python-version: '3.8'
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip setuptools wheel
+ python -m pip install twine
+ - name: Build and publish
+ env:
+ TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
+ TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
+ run: |
+ python setup.py sdist bdist_wheel
+ twine upload dist/*
diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml
deleted file mode 100644
index 9993084..0000000
--- a/.github/workflows/python-publish.yml
+++ /dev/null
@@ -1,31 +0,0 @@
-# This workflows will upload a Python Package using Twine when a release is created
-# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
-
-name: Upload Python Package
-
-on:
- release:
- types: [created]
-
-jobs:
- deploy:
-
- runs-on: ubuntu-latest
-
- steps:
- - uses: actions/checkout@v2
- - name: Set up Python
- uses: actions/setup-python@v2
- with:
- python-version: '3.8'
- - name: Install dependencies
- run: |
- python -m pip install --upgrade pip
- python -m pip install setuptools wheel twine
- - name: Build and publish
- env:
- TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
- TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
- run: |
- python setup.py sdist bdist_wheel
- twine upload dist/*
diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml
index db1e5bc..0717269 100644
--- a/.github/workflows/unit-tests.yml
+++ b/.github/workflows/unit-tests.yml
@@ -1,23 +1,27 @@
-name: Unit Tests
+name: unit tests
-on: [push]
+on:
+ workflow_dispatch:
+ push:
+ schedule:
+ - cron: "0 21 * * 6"
jobs:
build:
-
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v2
- - name: Set up Python
- uses: actions/setup-python@v2
+ - uses: actions/checkout@v4
+ - uses: actions/setup-python@v5
with:
python-version: '3.8'
- name: Install dependencies
run: |
- python -m pip install --upgrade pip
- python -m pip install torch
- python -m pip install -e '.[dev]'
+ python -m pip install pip --upgrade
+ python -m pip install -r requirements.txt
+ python -m pip install pytest hypothesis torchnyan
+ python -m pip install git+https://github.com/speedcell4/torchrua.git@develop --force-reinstall --no-deps
+ python -m pip install pytorch-crf torch-struct
- name: Test with pytest
run: |
python -m pytest tests
\ No newline at end of file
diff --git a/README.md b/README.md
index 1186c7b..3c63b53 100644
--- a/README.md
+++ b/README.md
@@ -1,111 +1,24 @@
-# TorchLatent
+
-![Unit Tests](https://github.com/speedcell4/torchlatent/workflows/Unit%20Tests/badge.svg)
-[![PyPI version](https://badge.fury.io/py/torchlatent.svg)](https://badge.fury.io/py/torchlatent)
-[![Downloads](https://pepy.tech/badge/torchrua)](https://pepy.tech/project/torchrua)
+# TorchLatent
-## Requirements
+![GitHub Workflow Status (with event)](https://img.shields.io/github/actions/workflow/status/speedcell4/torchlatent/unit-tests.yml?cacheSeconds=0)
+![PyPI - Version](https://img.shields.io/pypi/v/torchlatent?label=pypi%20version&cacheSeconds=0)
+![PyPI - Downloads](https://img.shields.io/pypi/dm/torchlatent?cacheSeconds=0)
-- Python 3.8
-- PyTorch 1.10.2
+
## Installation
-`python3 -m pip torchlatent`
-
-## Performance
-
-```
-TorchLatent (0.109244) => 0.003781 0.017763 0.087700 0.063497
-Third (0.232487) => 0.103277 0.129209 0.145311
-```
-
-## Usage
-
-```python
-import torch
-from torchrua import pack_sequence
-
-from torchlatent.crf import CrfDecoder
-
-num_tags = 3
-num_conjugates = 1
-
-decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates)
-
-emissions = pack_sequence([
- torch.randn((5, num_conjugates, num_tags), requires_grad=True),
- torch.randn((2, num_conjugates, num_tags), requires_grad=True),
- torch.randn((3, num_conjugates, num_tags), requires_grad=True),
-])
-
-tags = pack_sequence([
- torch.randint(0, num_tags, (5, num_conjugates)),
- torch.randint(0, num_tags, (2, num_conjugates)),
- torch.randint(0, num_tags, (3, num_conjugates)),
-])
-
-print(decoder.fit(emissions=emissions, tags=tags))
-# tensor([[-6.7424],
-# [-5.1288],
-# [-2.7283]], grad_fn=)
-
-print(decoder.decode(emissions=emissions))
-# PackedSequence(data=tensor([[2],
-# [0],
-# [1],
-# [0],
-# [2],
-# [0],
-# [2],
-# [0],
-# [1],
-# [2]]),
-# batch_sizes=tensor([3, 3, 2, 1, 1]),
-# sorted_indices=tensor([0, 2, 1]),
-# unsorted_indices=tensor([0, 2, 1]))
-
-print(decoder.marginals(emissions=emissions))
-# tensor([[[0.1040, 0.1001, 0.7958]],
-#
-# [[0.5736, 0.0784, 0.3479]],
-#
-# [[0.0932, 0.8797, 0.0271]],
-#
-# [[0.6558, 0.0472, 0.2971]],
-#
-# [[0.2740, 0.1109, 0.6152]],
-#
-# [[0.4811, 0.2163, 0.3026]],
-#
-# [[0.2321, 0.3478, 0.4201]],
-#
-# [[0.4987, 0.1986, 0.3027]],
-#
-# [[0.2029, 0.5888, 0.2083]],
-#
-# [[0.2802, 0.2358, 0.4840]]], grad_fn=)
-```
+`python -m pip install torchlatent`
## Latent Structures
-- [ ] Conditional Random Fields (CRF)
- - [x] Conjugated
- - [ ] Dynamic Transition Matrix
- - [ ] Second-order
- - [ ] Variant-order
-- [ ] Tree CRF
+- [x] Conditional Random Fields (CRF)
+- [x] Cocke–Kasami–Younger Algorithm (CKY)
+- [ ] Probabilistic Context-free Grammars (PCFG)
+- [ ] Connectionist Temporal Classification (CTC)
+- [ ] Recurrent Neural Network Grammars (RNNG)
- [ ] Non-Projective Dependency Tree (Matrix-tree Theorem)
-- [ ] Probabilistic Context-free Grammars (PCFG)
- [ ] Dependency Model with Valence (DMV)
-
-## Citation
-
-```
-@misc{wang2020torchlatent,
- title={TorchLatent: High Performance Structured Prediction in PyTorch},
- author={Yiran Wang},
- year={2020},
- howpublished = "\url{https://github.com/speedcell4/torchlatent}"
-}
-```
+- [ ] Autoregressive Decoding (Beam Search)
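The rewritten README no longer carries a usage example. A minimal sketch of the new distribution-based API, inferred from `tests/test_crf.py` and `torchlatent/crf.py` (the `C.new` constructor, shapes, and method names are read off the tests rather than documented guarantees):

```python
# Sketch only: API usage inferred from tests/test_crf.py, not from documentation.
import torch
from torchrua import C

from torchlatent.crf import CrfDecoder

num_targets = 3
decoder = CrfDecoder(num_targets=num_targets)

# catted sequences: one [token_size, num_targets] emission matrix per sequence
emissions = C.new([
    torch.randn((5, num_targets), requires_grad=True),
    torch.randn((2, num_targets), requires_grad=True),
])
targets = C.new([
    torch.randint(0, num_targets, (5,)),
    torch.randint(0, num_targets, (2,)),
])

dist = decoder(emissions)                # CrfDistribution
loss = -dist.log_probs(targets).mean()   # negative log-likelihood
print(dist.argmax)                       # best tag sequences, same layout as emissions
```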
diff --git a/benchmark/__init__.py b/benchmark/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/benchmark/__main__.py b/benchmark/__main__.py
deleted file mode 100644
index 351d087..0000000
--- a/benchmark/__main__.py
+++ /dev/null
@@ -1,9 +0,0 @@
-from aku import Aku
-
-from benchmark.crf import benchmark_crf
-
-aku = Aku()
-
-aku.option(benchmark_crf)
-
-aku.run()
diff --git a/benchmark/crf.py b/benchmark/crf.py
deleted file mode 100644
index a74dce0..0000000
--- a/benchmark/crf.py
+++ /dev/null
@@ -1,62 +0,0 @@
-import torch
-from torchrua import pack_sequence
-from tqdm import tqdm
-
-from benchmark.meter import TimeMeter
-from tests.third_party import ThirdPartyCrfDecoder
-from torchlatent.crf import CrfDecoder
-
-
-def benchmark_crf(num_tags: int = 50, num_conjugates: int = 1, num_runs: int = 100,
- batch_size: int = 32, max_token_size: int = 512):
- j1, f1, b1, d1, = TimeMeter(), TimeMeter(), TimeMeter(), TimeMeter()
- j2, f2, b2, d2, = TimeMeter(), TimeMeter(), TimeMeter(), TimeMeter()
-
- if torch.cuda.is_available():
- device = torch.device('cuda:0')
- else:
- device = torch.device('cpu')
- print(f'device => {device}')
-
- decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates).to(device=device)
- print(f'decoder => {decoder}')
-
- third_decoder = ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates).to(device=device)
- print(f'third_decoder => {third_decoder}')
-
- for _ in tqdm(range(num_runs)):
- token_sizes = torch.randint(1, max_token_size + 1, (batch_size,), device=device).detach().cpu().tolist()
-
- emissions = pack_sequence([
- torch.randn((token_size, num_conjugates, num_tags), device=device, requires_grad=True)
- for token_size in token_sizes
- ])
-
- tags = pack_sequence([
- torch.randint(0, num_tags, (token_size, num_conjugates), device=device)
- for token_size in token_sizes
- ])
-
- with j1:
- indices = decoder.compile_indices(emissions=emissions, tags=tags)
-
- with f1:
- loss = decoder.fit(emissions=emissions, tags=tags, indices=indices).neg().mean()
-
- with b1:
- _, torch.autograd.grad(loss, emissions.data, torch.ones_like(loss))
-
- with d1:
- _ = decoder.decode(emissions=emissions, indices=indices)
-
- with f2:
- loss = third_decoder.fit(emissions=emissions, tags=tags).neg().mean()
-
- with b2:
- _, torch.autograd.grad(loss, emissions.data, torch.ones_like(loss))
-
- with d2:
- _ = third_decoder.decode(emissions=emissions)
-
- print(f'TorchLatent ({j1.merit + f1.merit + b1.merit:.6f}) => {j1} {f1} {b1} {d1}')
- print(f'Third ({j2.merit + f2.merit + b2.merit:.6f}) => {j2} {f2} {b2} {d2}')
diff --git a/benchmark/meter.py b/benchmark/meter.py
deleted file mode 100644
index 64e1c6e..0000000
--- a/benchmark/meter.py
+++ /dev/null
@@ -1,23 +0,0 @@
-from datetime import datetime
-
-
-class TimeMeter(object):
- def __init__(self) -> None:
- super(TimeMeter, self).__init__()
-
- self.seconds = 0
- self.counts = 0
-
- def __enter__(self):
- self.start_tm = datetime.now()
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- self.seconds += (datetime.now() - self.start_tm).total_seconds()
- self.counts += 1
-
- @property
- def merit(self) -> float:
- return self.seconds / max(1, self.counts)
-
- def __repr__(self) -> str:
- return f'{self.merit :.6f}'
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..9443809
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,2 @@
+torch
+torchrua
diff --git a/setup.py b/setup.py
index 6f8a71b..ba04b02 100644
--- a/setup.py
+++ b/setup.py
@@ -1,10 +1,16 @@
-from setuptools import setup, find_packages
+from pathlib import Path
+
+from setuptools import find_packages, setup
name = 'torchlatent'
+root_dir = Path(__file__).parent.resolve()
+with (root_dir / 'requirements.txt').open(mode='r', encoding='utf-8') as fp:
+ install_requires = [install_require.strip() for install_require in fp]
+
setup(
name=name,
- version='0.4.2',
+ version='0.4.3',
packages=[package for package in find_packages() if package.startswith(name)],
url='https://github.com/speedcell4/torchlatent',
license='MIT',
@@ -12,16 +18,5 @@
author_email='speedcell4@gmail.com',
description='High Performance Structured Prediction in PyTorch',
python_requires='>=3.8',
- install_requires=[
- 'numpy',
- 'torchrua>=0.4.0',
- ],
- extras_require={
- 'dev': [
- 'einops',
- 'pytest',
- 'hypothesis',
- 'pytorch-crf',
- ],
- }
+ install_requires=install_requires,
)
diff --git a/tests/strategies.py b/tests/strategies.py
deleted file mode 100644
index 785f929..0000000
--- a/tests/strategies.py
+++ /dev/null
@@ -1,34 +0,0 @@
-import torch
-
-from hypothesis import strategies as st
-
-TINY_BATCH_SIZE = 6
-TINY_TOKEN_SIZE = 12
-
-BATCH_SIZE = 24
-TOKEN_SIZE = 50
-NUM_TAGS = 8
-NUM_CONJUGATES = 5
-
-
-@st.composite
-def devices(draw):
- if not torch.cuda.is_available():
- device = torch.device('cpu')
- else:
- device = torch.device('cuda:0')
- _ = torch.empty((1,), device=device)
- return device
-
-
-@st.composite
-def sizes(draw, *size: int, min_size: int = 1):
- max_size, *size = size
-
- if len(size) == 0:
- return draw(st.integers(min_value=min_size, max_value=max_size))
- else:
- return [
- draw(sizes(*size, min_size=min_size))
- for _ in range(draw(st.integers(min_value=min_size, max_value=max_size)))
- ]
diff --git a/tests/test_cky.py b/tests/test_cky.py
new file mode 100644
index 0000000..e627e8e
--- /dev/null
+++ b/tests/test_cky.py
@@ -0,0 +1,104 @@
+import torch
+from hypothesis import given, settings, strategies as st
+from torch_struct import TreeCRF
+from torchnyan import BATCH_SIZE, TINY_TOKEN_SIZE, assert_close, assert_grad_close, device, sizes
+from torchrua import C
+
+from torchlatent.cky import CkyDecoder, cky_partitions, cky_scores
+from torchlatent.semiring import Log
+
+
+def get_argmax(cky):
+ argmax = cky.argmax
+ mask = argmax > 0
+
+ _, t, _, n = mask.size()
+ index = torch.arange(t, device=mask.device)
+ x = torch.masked_select(index[None, :, None, None], mask=mask)
+ y = torch.masked_select(index[None, None, :, None], mask=mask)
+
+ index = torch.arange(n, device=mask.device)
+ z = torch.masked_select(index[None, None, None, :], mask=mask)
+
+ return argmax, x, y, z
+
+
+@settings(deadline=None)
+@given(
+ token_sizes=sizes(BATCH_SIZE, TINY_TOKEN_SIZE),
+ num_targets=sizes(TINY_TOKEN_SIZE),
+ rua_targets=st.sampled_from([C.cat, C.pad, C.pack]),
+)
+def test_cky_scores(token_sizes, num_targets, rua_targets):
+ emissions = torch.randn(
+ (len(token_sizes), max(token_sizes), max(token_sizes), num_targets),
+ device=device, requires_grad=True,
+ )
+ token_sizes = torch.tensor(token_sizes, device=device)
+ expected_cky = TreeCRF(emissions, lengths=token_sizes)
+
+ argmax, x, y, z = get_argmax(expected_cky)
+
+ emissions = torch.randn_like(emissions, requires_grad=True)
+
+ expected_cky = TreeCRF(emissions, lengths=token_sizes)
+ expected = expected_cky.log_prob(argmax) + expected_cky.partition
+
+ targets = C(data=torch.stack([x, y, z], dim=-1), token_sizes=token_sizes * 2 - 1)
+ actual = cky_scores(
+ emissions=C(emissions, token_sizes),
+ targets=rua_targets(targets),
+ semiring=Log,
+ )
+
+ assert_close(actual=actual, expected=expected)
+ assert_grad_close(actual=actual, expected=expected, inputs=(emissions,))
+
+
+@settings(deadline=None)
+@given(
+ token_sizes=sizes(BATCH_SIZE, TINY_TOKEN_SIZE),
+ num_targets=sizes(TINY_TOKEN_SIZE),
+)
+def test_cky_partitions(token_sizes, num_targets):
+ emissions = torch.randn(
+ (len(token_sizes), max(token_sizes), max(token_sizes), num_targets),
+ device=device, requires_grad=True,
+ )
+ token_sizes = torch.tensor(token_sizes, device=device)
+
+ expected = TreeCRF(emissions, lengths=token_sizes).partition
+
+ actual_emissions = C(
+ data=emissions.logsumexp(dim=-1),
+ token_sizes=token_sizes,
+ )
+ actual = cky_partitions(actual_emissions, Log)
+
+ assert_close(actual=actual, expected=expected)
+ assert_grad_close(actual=actual, expected=expected, inputs=(emissions,))
+
+
+@settings(deadline=None)
+@given(
+ token_sizes=sizes(BATCH_SIZE, TINY_TOKEN_SIZE),
+ num_targets=sizes(TINY_TOKEN_SIZE),
+)
+def test_cky_argmax(token_sizes, num_targets):
+ emissions = torch.randn(
+ (len(token_sizes), max(token_sizes), max(token_sizes), num_targets),
+ device=device, requires_grad=True,
+ )
+ token_sizes = torch.tensor(token_sizes, device=device)
+
+ expected_cky = TreeCRF(emissions, lengths=token_sizes)
+
+ _, x, y, z = get_argmax(expected_cky)
+
+ expected = C(data=torch.stack([x, y, z], dim=-1), token_sizes=token_sizes * 2 - 1)
+
+ actual_cky = CkyDecoder(num_targets=num_targets)
+ actual = actual_cky(emissions=C(emissions, token_sizes)).argmax
+
+ for actual, expected in zip(actual.tolist(), expected.tolist()):
+ assert set(map(tuple, actual)) == set(map(tuple, expected))
diff --git a/tests/test_crf.py b/tests/test_crf.py
index 2164a7b..f148245 100644
--- a/tests/test_crf.py
+++ b/tests/test_crf.py
@@ -1,117 +1,111 @@
import torch
-from hypothesis import given
-from torchrua import pack_sequence, cat_sequence, pack_catted_sequence
+from hypothesis import given, settings, strategies as st
+from torchcrf import CRF
+from torchnyan import BATCH_SIZE, TOKEN_SIZE, assert_close, assert_grad_close, assert_sequence_close, device, sizes
+from torchrua import C, D, P
-from tests.strategies import devices, sizes, BATCH_SIZE, TOKEN_SIZE, NUM_CONJUGATES, NUM_TAGS
-from tests.third_party import ThirdPartyCrfDecoder
-from tests.utils import assert_close, assert_grad_close, assert_packed_sequence_equal
-from torchlatent.crf import CrfDecoder
+from torchlatent.crf import CrfDecoder, crf_partitions, crf_scores
+from torchlatent.semiring import Log
+@settings(deadline=None)
@given(
- device=devices(),
token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE),
- num_conjugate=sizes(NUM_CONJUGATES),
- num_tags=sizes(NUM_TAGS),
+ num_targets=sizes(TOKEN_SIZE),
+ rua_emissions=st.sampled_from([C.new, D.new, P.new]),
+ rua_targets=st.sampled_from([C.new, D.new, P.new]),
)
-def test_crf_packed_fit(device, token_sizes, num_conjugate, num_tags):
- emissions = pack_sequence([
- torch.randn((token_size, num_conjugate, num_tags), device=device, requires_grad=True)
+def test_crf_scores(token_sizes, num_targets, rua_emissions, rua_targets):
+ inputs = [
+ torch.randn((token_size, num_targets), device=device, requires_grad=True)
for token_size in token_sizes
- ], device=device)
+ ]
- tags = pack_sequence([
- torch.randint(0, num_tags, (token_size, num_conjugate), device=device)
+ targets = [
+ torch.randint(0, num_targets, (token_size,), device=device)
for token_size in token_sizes
- ], device=device)
-
- actual_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device)
- expected_decoder = ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device)
- expected_decoder.reset_parameters_with_(decoder=actual_decoder)
-
- actual = actual_decoder.fit(emissions=emissions, tags=tags)
- expected = expected_decoder.fit(emissions=emissions, tags=tags)
-
- assert_close(actual=actual, expected=expected)
- assert_grad_close(actual=actual, expected=expected, inputs=(emissions.data,))
+ ]
+ expected_crf = CRF(num_tags=num_targets, batch_first=False).to(device=device)
-@given(
- device=devices(),
- token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE),
- num_conjugate=sizes(NUM_CONJUGATES),
- num_tags=sizes(NUM_TAGS),
-)
-def test_crf_packed_decode(device, token_sizes, num_conjugate, num_tags):
- emissions = pack_sequence([
- torch.randn((token_size, num_conjugate, num_tags), device=device, requires_grad=True)
- for token_size in token_sizes
- ], device=device)
+ expected_emissions = D.new(inputs)
+ expected_tags = D.new(targets)
- actual_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device)
- expected_decoder = ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device)
- expected_decoder.reset_parameters_with_(decoder=actual_decoder)
+ expected = expected_crf._compute_score(
+ expected_emissions.data.transpose(0, 1),
+ expected_tags.data.transpose(0, 1),
+ expected_emissions.mask().transpose(0, 1),
+ )
- expected = expected_decoder.decode(emissions=emissions)
- actual = actual_decoder.decode(emissions=emissions)
+ actual = crf_scores(
+ emissions=rua_emissions(inputs),
+ targets=rua_targets(targets),
+ transitions=(expected_crf.transitions, expected_crf.start_transitions, expected_crf.end_transitions),
+ semiring=Log,
+ )
- assert_packed_sequence_equal(actual=actual, expected=expected)
+ assert_close(actual=actual, expected=expected)
+ assert_grad_close(actual=actual, expected=expected, inputs=inputs)
+@settings(deadline=None)
@given(
- device=devices(),
token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE),
- num_conjugate=sizes(NUM_CONJUGATES),
- num_tags=sizes(NUM_TAGS),
+ num_targets=sizes(TOKEN_SIZE),
+ rua_emissions=st.sampled_from([C.new, D.new, P.new]),
)
-def test_crf_catted_fit(device, token_sizes, num_conjugate, num_tags):
- emissions = [
- torch.randn((token_size, num_conjugate, num_tags), device=device, requires_grad=True)
- for token_size in token_sizes
- ]
- tags = [
- torch.randint(0, num_tags, (token_size, num_conjugate), device=device)
+def test_crf_partitions(token_sizes, num_targets, rua_emissions):
+ inputs = [
+ torch.randn((token_size, num_targets), device=device, requires_grad=True)
for token_size in token_sizes
]
- catted_emissions = cat_sequence(emissions, device=device)
- packed_emissions = pack_sequence(emissions, device=device)
+ expected_crf = CRF(num_tags=num_targets, batch_first=False).to(device=device)
- catted_tags = cat_sequence(tags, device=device)
- packed_tags = pack_sequence(tags, device=device)
+ expected_emissions = D.new(inputs)
- actual_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device)
- expected_decoder = ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device)
- expected_decoder.reset_parameters_with_(decoder=actual_decoder)
+ expected = expected_crf._compute_normalizer(
+ expected_emissions.data.transpose(0, 1),
+ expected_emissions.mask().t(),
+ )
- actual = actual_decoder.fit(emissions=catted_emissions, tags=catted_tags)
- expected = expected_decoder.fit(emissions=packed_emissions, tags=packed_tags)
+ actual = crf_partitions(
+ emissions=rua_emissions(inputs),
+ transitions=(expected_crf.transitions, expected_crf.start_transitions, expected_crf.end_transitions),
+ semiring=Log,
+ )
- assert_close(actual=actual, expected=expected)
- assert_grad_close(actual=actual, expected=expected, inputs=tuple(emissions))
+ assert_close(actual=actual, expected=expected, rtol=1e-4, atol=1e-4)
+ assert_grad_close(actual=actual, expected=expected, inputs=inputs, rtol=1e-4, atol=1e-4)
+@settings(deadline=None)
@given(
- device=devices(),
token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE),
- num_conjugate=sizes(NUM_CONJUGATES),
- num_tags=sizes(NUM_TAGS),
+ num_targets=sizes(TOKEN_SIZE),
+ rua_emissions=st.sampled_from([C.new, D.new, P.new]),
)
-def test_crf_catted_decode(device, token_sizes, num_conjugate, num_tags):
- emissions = [
- torch.randn((token_size, num_conjugate, num_tags), device=device, requires_grad=True)
+def test_crf_argmax(token_sizes, num_targets, rua_emissions):
+ inputs = [
+ torch.randn((token_size, num_targets), device=device, requires_grad=True)
for token_size in token_sizes
]
- catted_emissions = cat_sequence(emissions, device=device)
- packed_emissions = pack_sequence(emissions, device=device)
+ expected_crf = CRF(num_tags=num_targets, batch_first=False).to(device=device)
+
+ expected_emissions = D.new(inputs)
+
+ expected = expected_crf.decode(
+ expected_emissions.data.transpose(0, 1),
+ expected_emissions.mask().t(),
+ )
+ expected = C.new([torch.tensor(tensor, device=device) for tensor in expected])
- actual_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device)
- expected_decoder = ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device)
- expected_decoder.reset_parameters_with_(decoder=actual_decoder)
+ actual_crf = CrfDecoder(num_targets=num_targets)
+ actual_crf.transitions = expected_crf.transitions
+ actual_crf.head_transitions = expected_crf.start_transitions
+ actual_crf.last_transitions = expected_crf.end_transitions
- expected = expected_decoder.decode(emissions=packed_emissions)
- actual = actual_decoder.decode(emissions=catted_emissions)
- actual = pack_catted_sequence(*actual, device=device)
+ actual = actual_crf(rua_emissions(inputs)).argmax.cat()
- assert_packed_sequence_equal(actual=actual, expected=expected)
+ assert_sequence_close(actual=actual, expected=expected)
diff --git a/tests/test_functional.py b/tests/test_functional.py
index c12b74a..2686b68 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -1,16 +1,16 @@
import torch
-from hypothesis import given, strategies as st
+from hypothesis import given, settings, strategies as st
+from torchnyan.assertion import assert_close, assert_grad_close
+from torchnyan.strategy import TINY_BATCH_SIZE, TINY_TOKEN_SIZE, device, sizes
-from tests.strategies import devices, sizes, TINY_TOKEN_SIZE, TINY_BATCH_SIZE
-from tests.utils import assert_close, assert_grad_close
from torchlatent.functional import logaddexp, logsumexp
+@settings(deadline=None)
@given(
- device=devices(),
token_sizes=sizes(TINY_BATCH_SIZE, TINY_TOKEN_SIZE)
)
-def test_logaddexp(device, token_sizes):
+def test_logaddexp(token_sizes):
x = torch.randn(token_sizes, device=device, requires_grad=True)
y = torch.randn(token_sizes, device=device, requires_grad=True)
@@ -21,12 +21,12 @@ def test_logaddexp(device, token_sizes):
assert_grad_close(actual=actual, expected=expected, inputs=(x, y))
+@settings(deadline=None)
@given(
data=st.data(),
- device=devices(),
token_sizes=sizes(TINY_BATCH_SIZE, TINY_TOKEN_SIZE)
)
-def test_logsumexp(data, device, token_sizes):
+def test_logsumexp(data, token_sizes):
tensor = torch.randn(token_sizes, device=device, requires_grad=True)
dim = data.draw(st.integers(min_value=-len(token_sizes), max_value=len(token_sizes) - 1))
diff --git a/tests/third_party.py b/tests/third_party.py
deleted file mode 100644
index 6172f91..0000000
--- a/tests/third_party.py
+++ /dev/null
@@ -1,83 +0,0 @@
-import torch
-import torchcrf
-from torch import Tensor, nn
-from torch.nn.utils.rnn import PackedSequence
-from torch.types import Device
-from torchrua import pad_catted_indices, pad_packed_sequence, pack_sequence
-
-from torchlatent.crf import CrfDecoder
-
-
-@torch.no_grad()
-def token_sizes_to_mask(sizes: Tensor, batch_first: bool, device: Device = None) -> Tensor:
- if device is None:
- device = sizes.device
-
- size, ptr = pad_catted_indices(sizes, batch_first=batch_first, device=device)
- mask = torch.zeros(size, device=device, dtype=torch.bool)
- mask[ptr] = True
- return mask
-
-
-class ThirdPartyCrfDecoder(nn.Module):
- def __init__(self, num_tags: int, num_conjugates: int) -> None:
- super(ThirdPartyCrfDecoder, self).__init__()
- self.num_tags = num_tags
- self.num_conjugates = num_conjugates
-
- self.decoders = nn.ModuleList([
- torchcrf.CRF(num_tags=num_tags, batch_first=False)
- for _ in range(num_conjugates)
- ])
-
- @torch.no_grad()
- def reset_parameters_with_(self, decoder: CrfDecoder) -> None:
- assert self.num_tags == decoder.num_tags
- assert self.num_conjugates == decoder.num_conjugates
-
- for index in range(self.num_conjugates):
- self.decoders[index].transitions.data[::] = decoder.transitions[:, index, :, :]
- self.decoders[index].start_transitions.data[::] = decoder.head_transitions[:, index, :]
- self.decoders[index].end_transitions.data[::] = decoder.last_transitions[:, index, :]
-
- def fit(self, emissions: PackedSequence, tags: PackedSequence, **kwargs) -> Tensor:
- num_emissions_conjugates = emissions.data.size()[1]
- num_decoders_conjugates = self.num_conjugates
- num_conjugates = max(num_emissions_conjugates, num_decoders_conjugates)
-
- emissions, token_sizes = pad_packed_sequence(emissions, batch_first=False)
- tags, _ = pad_packed_sequence(tags, batch_first=False)
- mask = token_sizes_to_mask(sizes=token_sizes, batch_first=False)
-
- log_probs = []
- for index in range(num_conjugates):
- decoder = self.decoders[index % num_decoders_conjugates]
- emission = emissions[:, :, index % num_emissions_conjugates]
- tag = tags[:, :, index % num_emissions_conjugates]
-
- log_probs.append(decoder(emissions=emission, tags=tag, mask=mask, reduction='none'))
-
- return torch.stack(log_probs, dim=-1)
-
- def decode(self, emissions: PackedSequence, **kwargs) -> PackedSequence:
- num_emissions_conjugates = emissions.data.size()[1]
- num_decoders_conjugates = self.num_conjugates
- num_conjugates = max(num_emissions_conjugates, num_decoders_conjugates)
-
- emissions, token_sizes = pad_packed_sequence(emissions, batch_first=False)
- mask = token_sizes_to_mask(sizes=token_sizes, batch_first=False)
-
- predictions = []
- for index in range(num_conjugates):
- decoder = self.decoders[index % num_decoders_conjugates]
- emission = emissions[:, :, index % num_emissions_conjugates]
-
- prediction = decoder.decode(emissions=emission, mask=mask)
- predictions.append(pack_sequence([torch.tensor(p) for p in prediction], device=emissions.device))
-
- return PackedSequence(
- torch.stack([prediction.data for prediction in predictions], dim=1),
- batch_sizes=predictions[0].batch_sizes,
- sorted_indices=predictions[0].sorted_indices,
- unsorted_indices=predictions[0].unsorted_indices,
- )
diff --git a/tests/utils.py b/tests/utils.py
deleted file mode 100644
index 3040db4..0000000
--- a/tests/utils.py
+++ /dev/null
@@ -1,90 +0,0 @@
-from typing import List, Tuple, Union
-
-import torch
-from torch import Tensor
-from torch.nn.utils.rnn import PackedSequence
-from torch.testing import assert_close
-from torchrua.catting import CattedSequence
-
-__all__ = [
- 'assert_equal', 'assert_close', 'assert_grad_close',
- 'assert_catted_sequence_equal', 'assert_catted_sequence_close',
- 'assert_packed_sequence_equal', 'assert_packed_sequence_close',
-]
-
-
-def assert_equal(actual: Tensor, expected: Tensor) -> None:
- assert torch.equal(actual, expected)
-
-
-def assert_grad_close(
- actual: Tensor, expected: Tensor,
- inputs: Union[Tensor, List[Tensor], Tuple[Tensor, ...]],
- allow_unused: bool = False,
- check_device: bool = True, check_dtype: bool = True, check_stride: bool = True) -> None:
- kwargs = dict(check_device=check_device, check_dtype=check_dtype, check_stride=check_stride)
-
- grad = torch.rand_like(actual)
-
- actual_grads = torch.autograd.grad(
- actual, inputs, grad,
- create_graph=False,
- allow_unused=allow_unused,
- )
-
- expected_grads = torch.autograd.grad(
- expected, inputs, grad,
- create_graph=False,
- allow_unused=allow_unused,
- )
-
- for actual_grad, expected_grad in zip(actual_grads, expected_grads):
- assert_close(actual=actual_grad, expected=expected_grad, **kwargs)
-
-
-def assert_catted_sequence_close(
- actual: CattedSequence, expected: CattedSequence,
- check_device: bool = True, check_dtype: bool = True, check_stride: bool = True) -> None:
- kwargs = dict(check_device=check_device, check_dtype=check_dtype, check_stride=check_stride)
-
- assert_close(actual=actual.data, expected=expected.data, **kwargs)
- assert_equal(actual=actual.token_sizes, expected=expected.token_sizes)
-
-
-def assert_catted_sequence_equal(actual: CattedSequence, expected: CattedSequence) -> None:
- assert_equal(actual=actual.data, expected=expected.data)
- assert_equal(actual=actual.token_sizes, expected=expected.token_sizes)
-
-
-def assert_packed_sequence_close(
- actual: PackedSequence, expected: PackedSequence,
- check_device: bool = True, check_dtype: bool = True, check_stride: bool = True) -> None:
- kwargs = dict(check_device=check_device, check_dtype=check_dtype, check_stride=check_stride)
-
- assert_close(actual=actual.data, expected=expected.data, **kwargs)
- assert_equal(actual=actual.batch_sizes, expected=expected.batch_sizes)
-
- if actual.sorted_indices is None:
- assert expected.sorted_indices is None
- else:
- assert_equal(actual=actual.sorted_indices, expected=expected.sorted_indices)
-
- if actual.unsorted_indices is None:
- assert expected.unsorted_indices is None
- else:
- assert_equal(actual=actual.unsorted_indices, expected=expected.unsorted_indices)
-
-
-def assert_packed_sequence_equal(actual: PackedSequence, expected: PackedSequence) -> None:
- assert_equal(actual=actual.data, expected=expected.data)
- assert_equal(actual=actual.batch_sizes, expected=expected.batch_sizes)
-
- if actual.sorted_indices is None:
- assert expected.sorted_indices is None
- else:
- assert_equal(actual=actual.sorted_indices, expected=expected.sorted_indices)
-
- if actual.unsorted_indices is None:
- assert expected.unsorted_indices is None
- else:
- assert_equal(actual=actual.unsorted_indices, expected=expected.unsorted_indices)
diff --git a/torchlatent/__init__.py b/torchlatent/__init__.py
index e69de29..8203d58 100644
--- a/torchlatent/__init__.py
+++ b/torchlatent/__init__.py
@@ -0,0 +1 @@
+from torchlatent.crf import CrfDecoder, CrfDistribution
diff --git a/torchlatent/abc.py b/torchlatent/abc.py
new file mode 100644
index 0000000..d5e088d
--- /dev/null
+++ b/torchlatent/abc.py
@@ -0,0 +1,60 @@
+from abc import ABCMeta
+from typing import Union
+
+import torch
+import torch.autograd
+from torch import Tensor, nn
+from torch.distributions.utils import lazy_property
+from torchrua import C, D, P
+
+
+class StructuredDistribution(object, metaclass=ABCMeta):
+ def __init__(self, emissions: Union[C, D, P]) -> None:
+ super(StructuredDistribution, self).__init__()
+ self.emissions = emissions
+
+ def log_scores(self, targets: Union[C, D, P]) -> Tensor:
+ raise NotImplementedError
+
+ def log_probs(self, targets: Union[C, D, P]) -> Tensor:
+ return self.log_scores(targets=targets) - self.log_partitions
+
+ @lazy_property
+ def log_partitions(self) -> Tensor:
+ raise NotImplementedError
+
+ @lazy_property
+ def marginals(self) -> Tensor:
+ grad, = torch.autograd.grad(
+ self.log_partitions, self.emissions.data, torch.ones_like(self.log_partitions),
+ create_graph=True, retain_graph=True, only_inputs=True, allow_unused=True,
+ )
+ return grad
+
+ @lazy_property
+ def max(self) -> Tensor:
+ raise NotImplementedError
+
+ @lazy_property
+ def argmax(self) -> Tensor:
+ grad, = torch.autograd.grad(
+ self.max, self.emissions.data, torch.ones_like(self.max),
+ create_graph=False, retain_graph=False, only_inputs=True, allow_unused=True,
+ )
+ return grad
+
+
+class StructuredDecoder(nn.Module):
+ def __init__(self, *, num_targets: int) -> None:
+ super(StructuredDecoder, self).__init__()
+ self.num_targets = num_targets
+
+ def reset_parameters(self) -> None:
+ pass
+
+ def extra_repr(self) -> str:
+ return f'num_targets={self.num_targets}'
+
+ def forward(self, emissions: Union[C, D, P]) -> StructuredDistribution:
+ raise NotImplementedError
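Both `marginals` and `argmax` lean on the fact that differentiating a (max-)log-partition with respect to the emissions recovers marginals and the argmax indicator, respectively. A toy check of that identity for a one-position structure, where the log-partition is a plain `logsumexp` (a standard autograd fact, not TorchLatent-specific code):

```python
# d logsumexp(s) / d s = softmax(s): the gradient of the log-partition is the
# marginal distribution, which is exactly what StructuredDistribution exploits.
import torch

scores = torch.randn(5, requires_grad=True)
log_z = scores.logsumexp(dim=0)

grad, = torch.autograd.grad(log_z, scores)
assert torch.allclose(grad, scores.softmax(dim=0))
```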
diff --git a/torchlatent/cky.py b/torchlatent/cky.py
new file mode 100644
index 0000000..b8c863d
--- /dev/null
+++ b/torchlatent/cky.py
@@ -0,0 +1,112 @@
+from typing import Tuple, Type, Union
+
+import torch
+from torch import Tensor
+from torch.distributions.utils import lazy_property
+from torchrua import C, D, P
+
+from torchlatent.abc import StructuredDecoder, StructuredDistribution
+from torchlatent.semiring import Log, Max, Semiring
+
+
+def cky_scores(emissions: C, targets: Union[C, D, P], semiring: Type[Semiring]) -> Tensor:
+ xyz, token_sizes = targets = targets.cat()
+ batch_ptr, _ = targets.ptr()
+
+ emissions = emissions.data[batch_ptr, xyz[..., 0], xyz[..., 1], xyz[..., 2]]
+ return semiring.segment_prod(emissions, token_sizes)
+
+
+def diag(tensor: Tensor, offset: int) -> Tensor:
+ return tensor.diagonal(offset=offset, dim1=1, dim2=2)
+
+
+def diag_scatter(chart: Tensor, score: Tensor, offset: int) -> None:
+ chart.diagonal(offset=offset, dim1=1, dim2=2)[::] = score
+
+
+def left(chart: Tensor, offset: int) -> Tensor:
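+ # strided view: left[b, i, k] = chart[b, i, i + k], the left child spanning (i, i + k)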
+ b, t, _, *size = chart.size()
+ c, n, m, *stride = chart.stride()
+ return chart.as_strided(
+ size=(b, t - offset, offset, *size),
+ stride=(c, n + m, m, *stride),
+ )
+
+
+def right(chart: Tensor, offset: int) -> Tensor:
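+ # strided view: right[b, i, k] = chart[b, i + k + 1, i + offset], the matching right child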
+ b, t, _, *size = chart.size()
+ c, n, m, *stride = chart.stride()
+ return chart[:, 1:, offset:].as_strided(
+ size=(b, t - offset, offset, *size),
+ stride=(c, n + m, n, *stride),
+ )
+
+
+def cky_partitions(emissions: C, semiring: Type[Semiring]) -> Tensor:
+ chart = torch.full_like(emissions.data, fill_value=semiring.zero, requires_grad=False)
+
+ diag_scatter(chart, diag(emissions.data, offset=0), offset=0)
+
+ for w in range(1, chart.size()[1]):
+ score = semiring.sum(semiring.mul(left(chart, offset=w), right(chart, offset=w)), dim=2)
+ diag_scatter(chart, semiring.mul(score, diag(emissions.data, offset=w)), offset=w)
+
+ index = torch.arange(chart.size()[0], dtype=torch.long, device=chart.device)
+ return chart[index, 0, emissions.token_sizes - 1]
+
+
+def masked_select(mask: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+ _, t, _, n = mask.size()
+
+ index = torch.arange(t, device=mask.device)
+ x = torch.masked_select(index[None, :, None, None], mask=mask)
+ y = torch.masked_select(index[None, None, :, None], mask=mask)
+
+ index = torch.arange(n, device=mask.device)
+ z = torch.masked_select(index[None, None, None, :], mask=mask)
+
+ return x, y, z
+
+
+class CkyDistribution(StructuredDistribution):
+ def __init__(self, emissions: C) -> None:
+ super(CkyDistribution, self).__init__(emissions=emissions)
+
+ def log_scores(self, targets: Union[C, D, P]) -> Tensor:
+ return cky_scores(
+ emissions=self.emissions, targets=targets,
+ semiring=Log,
+ )
+
+ @lazy_property
+ def log_partitions(self) -> Tensor:
+ return cky_partitions(
+ emissions=self.emissions._replace(data=Log.sum(self.emissions.data, dim=-1)),
+ semiring=Log,
+ )
+
+ @lazy_property
+ def max(self) -> Tensor:
+ return cky_partitions(
+ emissions=self.emissions._replace(data=Max.sum(self.emissions.data, dim=-1)),
+ semiring=Max,
+ )
+
+ @lazy_property
+ def argmax(self) -> C:
+ argmax = super(CkyDistribution, self).argmax
+ x, y, z = masked_select(argmax > 0)
+
+ return C(
+ data=torch.stack([x, y, z], dim=-1),
+ token_sizes=self.emissions.token_sizes * 2 - 1,
+ )
+
+
+class CkyDecoder(StructuredDecoder):
+ def __init__(self, *, num_targets: int) -> None:
+ super(CkyDecoder, self).__init__(num_targets=num_targets)
+
+ def forward(self, emissions: C) -> CkyDistribution:
+ return CkyDistribution(emissions=emissions)
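A usage sketch for the CKY decoder, mirrored from `tests/test_cky.py` (the `[batch, start, end, label]` emission layout is an assumption read off the tests):

```python
# Sketch only: mirrors tests/test_cky.py rather than a documented interface.
import torch
from torchrua import C

from torchlatent.cky import CkyDecoder

token_sizes = torch.tensor([5, 3])
num_targets = 4

# emissions[b, i, j, n]: score of labeling span (i, j) of sequence b with label n
emissions = torch.randn((2, 5, 5, num_targets), requires_grad=True)

decoder = CkyDecoder(num_targets=num_targets)
dist = decoder(emissions=C(emissions, token_sizes))

print(dist.log_partitions)  # one log-partition per sequence
print(dist.argmax)          # (start, end, label) triples; 2 * n - 1 spans per tree
```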
diff --git a/torchlatent/crf.py b/torchlatent/crf.py
new file mode 100644
index 0000000..83d1b09
--- /dev/null
+++ b/torchlatent/crf.py
@@ -0,0 +1,113 @@
+from typing import Tuple, Type, Union
+
+import torch
+from torch import Tensor, nn
+from torch.distributions.utils import lazy_property
+from torch.nn import init
+from torchrua import C, D, P
+
+from torchlatent.abc import StructuredDecoder, StructuredDistribution
+from torchlatent.semiring import Log, Max, Semiring
+
+T = Tuple[Tensor, Tensor, Tensor]
+
+
+def crf_scores(emissions: Union[C, D, P], targets: Union[C, D, P], transitions: T, semiring: Type[Semiring]) -> Tensor:
+ transitions, head_transitions, last_transitions = transitions
+
+ targets = _, token_sizes = targets.cat()
+ head_transitions = targets.head().rua(head_transitions)
+ last_transitions = targets.last().rua(last_transitions)
+ transitions = targets.data.roll(1).rua(transitions, targets)
+
+ emissions, _ = emissions.idx().cat().rua(emissions, targets)
+ emissions = semiring.segment_prod(emissions, sizes=token_sizes)
+
+ token_sizes = torch.stack([torch.ones_like(token_sizes), token_sizes - 1], dim=-1)
+ transitions = semiring.segment_prod(transitions, sizes=token_sizes.view(-1))[1::2]
+
+ return semiring.mul(
+ semiring.mul(head_transitions, last_transitions),
+ semiring.mul(emissions, transitions),
+ )
+
+
+def crf_partitions(emissions: Union[C, D, P], transitions: T, semiring: Type[Semiring]) -> Tensor:
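+ # forward algorithm over the packed batch: each step advances alpha by a
+ # semiring matrix product with the transitions, then folds in the emissions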
+ transitions, head_transitions, last_transitions = transitions
+
+ emissions = emissions.pack()
+ last_indices = emissions.idx().last()
+ emissions, batch_sizes, _, _ = emissions
+
+ _, *batch_sizes = sections = batch_sizes.detach().cpu().tolist()
+ emission, *emissions = torch.split(emissions, sections, dim=0)
+
+ charts = [semiring.mul(head_transitions, emission)]
+ for emission, batch_size in zip(emissions, batch_sizes):
+ charts.append(semiring.mul(
+ semiring.bmm(charts[-1][:batch_size], transitions),
+ emission,
+ ))
+
+ emission = torch.cat(charts, dim=0)[last_indices]
+ return semiring.sum(semiring.mul(emission, last_transitions), dim=-1)
+
+
+class CrfDistribution(StructuredDistribution):
+ def __init__(self, emissions: Union[C, D, P], transitions: T) -> None:
+ super(CrfDistribution, self).__init__(emissions=emissions)
+ self.transitions = transitions
+
+ def log_scores(self, targets: Union[C, D, P]) -> Tensor:
+ return crf_scores(
+ emissions=self.emissions, targets=targets,
+ transitions=self.transitions,
+ semiring=Log,
+ )
+
+ @lazy_property
+ def log_partitions(self) -> Tensor:
+ return crf_partitions(
+ emissions=self.emissions,
+ transitions=self.transitions,
+ semiring=Log,
+ )
+
+ @lazy_property
+ def max(self) -> Tensor:
+ return crf_partitions(
+ emissions=self.emissions,
+ transitions=self.transitions,
+ semiring=Max,
+ )
+
+ @lazy_property
+ def argmax(self) -> Union[C, D, P]:
+ argmax = super(CrfDistribution, self).argmax.argmax(dim=-1)
+ return self.emissions._replace(data=argmax)
+
+
+class CrfDecoder(StructuredDecoder):
+ def __init__(self, *, num_targets: int) -> None:
+ super(CrfDecoder, self).__init__(num_targets=num_targets)
+
+ self.transitions = nn.Parameter(torch.empty((num_targets, num_targets)))
+ self.head_transitions = nn.Parameter(torch.empty((num_targets,)))
+ self.last_transitions = nn.Parameter(torch.empty((num_targets,)))
+
+ self.reset_parameters()
+
+ def reset_parameters(self) -> None:
+ init.zeros_(self.transitions)
+ init.zeros_(self.head_transitions)
+ init.zeros_(self.last_transitions)
+
+ def forward(self, emissions: Union[C, D, P]) -> CrfDistribution:
+ return CrfDistribution(
+ emissions=emissions,
+ transitions=(
+ self.transitions,
+ self.head_transitions,
+ self.last_transitions,
+ ),
+ )
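`crf_partitions` runs the forward algorithm over a packed batch through the semiring `bmm`. For a single sequence, the same recursion in plain log-space PyTorch reads as follows (a reference sketch under the `torchcrf` convention that `transitions[i, j]` scores moving from tag `i` to tag `j`):

```python
# Single-sequence forward algorithm; what crf_partitions computes under Log.
import torch

def log_partition(emission, transitions, head, last):
    # emission: [t, n]; transitions: [n, n]; head, last: [n]
    alpha = head + emission[0]
    for step in emission[1:]:
        # alpha[j] = logsumexp_i(alpha[i] + transitions[i, j]) + step[j]
        alpha = torch.logsumexp(alpha[:, None] + transitions, dim=0) + step
    return torch.logsumexp(alpha + last, dim=-1)
```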
diff --git a/torchlatent/crf/__init__.py b/torchlatent/crf/__init__.py
deleted file mode 100644
index a126167..0000000
--- a/torchlatent/crf/__init__.py
+++ /dev/null
@@ -1,123 +0,0 @@
-from abc import ABCMeta
-from typing import Optional, Tuple, Union
-
-import torch
-from torch import Tensor
-from torch import nn
-from torch.nn import init
-from torchrua import ReductionIndices, PackedSequence, CattedSequence
-from torchrua import reduce_packed_indices, reduce_catted_indices
-
-from torchlatent.crf.catting import CattedCrfDistribution
-from torchlatent.crf.packing import PackedCrfDistribution
-
-__all__ = [
- 'CrfDecoderABC', 'CrfDecoder',
- 'PackedCrfDistribution',
- 'CattedCrfDistribution',
- 'Sequence',
-]
-
-Sequence = Union[
- PackedSequence,
- CattedSequence,
-]
-
-
-class CrfDecoderABC(nn.Module, metaclass=ABCMeta):
- def __init__(self, num_tags: int, num_conjugates: int) -> None:
- super(CrfDecoderABC, self).__init__()
-
- self.num_tags = num_tags
- self.num_conjugates = num_conjugates
-
- def reset_parameters(self) -> None:
- raise NotImplementedError
-
- def extra_repr(self) -> str:
- return ', '.join([
- f'num_tags={self.num_tags}',
- f'num_conjugates={self.num_conjugates}',
- ])
-
- @staticmethod
- def compile_indices(emissions: Sequence,
- tags: Optional[Sequence] = None,
- indices: Optional[ReductionIndices] = None, **kwargs):
- assert emissions.data.dim() == 3, f'{emissions.data.dim()} != {3}'
- if tags is not None:
- assert tags.data.dim() == 2, f'{tags.data.dim()} != {2}'
-
- if indices is None:
- if isinstance(emissions, PackedSequence):
- batch_sizes = emissions.batch_sizes.to(device=emissions.data.device)
- return reduce_packed_indices(batch_sizes=batch_sizes)
-
- if isinstance(emissions, CattedSequence):
- token_sizes = emissions.token_sizes.to(device=emissions.data.device)
- return reduce_catted_indices(token_sizes=token_sizes)
-
- return indices
-
- def obtain_parameters(self, **kwargs) -> Tuple[Tensor, Tensor, Tensor]:
- return self.transitions, self.head_transitions, self.last_transitions
-
- def forward(self, emissions: Sequence, tags: Optional[Sequence] = None,
- indices: Optional[ReductionIndices] = None, **kwargs):
- indices = self.compile_indices(emissions=emissions, tags=tags, indices=indices)
- transitions, head_transitions, last_transitions = self.obtain_parameters(
- emissions=emissions, tags=tags, indices=indices,
- )
-
- if isinstance(emissions, PackedSequence):
- dist = PackedCrfDistribution(
- emissions=emissions, indices=indices,
- transitions=transitions,
- head_transitions=head_transitions,
- last_transitions=last_transitions,
- )
- return dist, tags
-
- if isinstance(emissions, CattedSequence):
- dist = CattedCrfDistribution(
- emissions=emissions, indices=indices,
- transitions=transitions,
- head_transitions=head_transitions,
- last_transitions=last_transitions,
- )
- return dist, tags
-
- raise TypeError(f'{type(emissions)} is not supported.')
-
- def fit(self, emissions: Sequence, tags: Sequence,
- indices: Optional[ReductionIndices] = None, **kwargs) -> Tensor:
- dist, tags = self(emissions=emissions, tags=tags, instr=indices, **kwargs)
-
- return dist.log_prob(tags=tags)
-
- def decode(self, emissions: Sequence,
- indices: Optional[ReductionIndices] = None, **kwargs) -> Sequence:
- dist, _ = self(emissions=emissions, tags=None, instr=indices, **kwargs)
- return dist.argmax
-
- def marginals(self, emissions: Sequence,
- indices: Optional[ReductionIndices] = None, **kwargs) -> Tensor:
- dist, _ = self(emissions=emissions, tags=None, instr=indices, **kwargs)
- return dist.marginals
-
-
-class CrfDecoder(CrfDecoderABC):
- def __init__(self, num_tags: int, num_conjugates: int = 1) -> None:
- super(CrfDecoder, self).__init__(num_tags=num_tags, num_conjugates=num_conjugates)
-
- self.transitions = nn.Parameter(torch.empty((1, self.num_conjugates, self.num_tags, self.num_tags)))
- self.head_transitions = nn.Parameter(torch.empty((1, self.num_conjugates, self.num_tags)))
- self.last_transitions = nn.Parameter(torch.empty((1, self.num_conjugates, self.num_tags)))
-
- self.reset_parameters()
-
- @torch.no_grad()
- def reset_parameters(self, bound: float = 0.01) -> None:
- init.uniform_(self.transitions, -bound, +bound)
- init.uniform_(self.head_transitions, -bound, +bound)
- init.uniform_(self.last_transitions, -bound, +bound)
diff --git a/torchlatent/crf/catting.py b/torchlatent/crf/catting.py
deleted file mode 100644
index 474f704..0000000
--- a/torchlatent/crf/catting.py
+++ /dev/null
@@ -1,137 +0,0 @@
-from typing import Type
-
-import torch
-from torch import Tensor, autograd
-from torch.distributions.utils import lazy_property
-from torchrua import CattedSequence
-from torchrua import ReductionIndices, head_catted_indices
-from torchrua import roll_catted_sequence, head_catted_sequence, last_catted_sequence
-
-from torchlatent.semiring import Semiring, Log, Max
-
-__all__ = [
- 'compute_catted_sequence_scores',
- 'compute_catted_sequence_partitions',
- 'CattedCrfDistribution',
-]
-
-
-def compute_catted_sequence_scores(semiring: Type[Semiring]):
- def _compute_catted_sequence_scores(
- emissions: CattedSequence, tags: CattedSequence,
- transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor) -> Tensor:
- device = transitions.device
-
- emission_scores = emissions.data.gather(dim=-1, index=tags.data[..., None])[..., 0] # [t, c]
-
- h = emissions.token_sizes.size()[0]
- t = torch.arange(transitions.size()[0], device=device) # [t]
- c = torch.arange(transitions.size()[1], device=device) # [c]
-
- x, y = roll_catted_sequence(tags, shifts=1).data, tags.data # [t, c]
- head = head_catted_sequence(tags) # [h, c]
- last = last_catted_sequence(tags) # [h, c]
-
- transition_scores = transitions[t[:, None], c[None, :], x, y] # [t, c]
- transition_head_scores = head_transitions[t[:h, None], c[None, :], head] # [h, c]
- transition_last_scores = last_transitions[t[:h, None], c[None, :], last] # [h, c]
-
- head_indices = head_catted_indices(emissions.token_sizes)
- transition_scores[head_indices] = transition_head_scores # [h, c]
-
- batch_ptr = torch.repeat_interleave(emissions.token_sizes)
- scores = semiring.mul(emission_scores, transition_scores)
- scores = semiring.scatter_mul(scores, index=batch_ptr)
-
- scores = semiring.mul(scores, transition_last_scores)
-
- return scores
-
- return _compute_catted_sequence_scores
-
-
-def compute_catted_sequence_partitions(semiring: Type[Semiring]):
- def _compute_catted_sequence_partitions(
- emissions: CattedSequence, indices: ReductionIndices,
- transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor, eye: Tensor) -> Tensor:
- h = emissions.token_sizes.size()[0]
- t = torch.arange(transitions.size()[0], device=transitions.device) # [t]
- c = torch.arange(transitions.size()[1], device=transitions.device) # [c]
- head_indices = head_catted_indices(emissions.token_sizes)
-
- emission_scores = semiring.mul(transitions, emissions.data[..., None, :]) # [t, c, n, n]
- emission_scores[head_indices] = eye[None, None, :, :]
- emission_scores = semiring.reduce(tensor=emission_scores, indices=indices)
-
- emission_head_scores = emissions.data[head_indices, :, None, :]
- transition_head_scores = head_transitions[t[:h, None], c[None, :], None, :]
- transition_last_scores = last_transitions[t[:h, None], c[None, :], :, None]
-
- scores = semiring.mul(transition_head_scores, emission_head_scores)
- scores = semiring.bmm(scores, emission_scores)
- scores = semiring.bmm(scores, transition_last_scores)[..., 0, 0]
-
- return scores
-
- return _compute_catted_sequence_partitions
-
-
-class CattedCrfDistribution(object):
- def __init__(self, emissions: CattedSequence, indices: ReductionIndices,
- transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor) -> None:
- super(CattedCrfDistribution, self).__init__()
- self.emissions = emissions
- self.indices = indices
-
- self.transitions = transitions
- self.head_transitions = head_transitions
- self.last_transitions = last_transitions
-
- def semiring_scores(self, semiring: Type[Semiring], tags: CattedSequence) -> Tensor:
- return compute_catted_sequence_scores(semiring=semiring)(
- emissions=self.emissions, tags=tags,
- transitions=self.transitions,
- head_transitions=self.head_transitions,
- last_transitions=self.last_transitions,
- )
-
- def semiring_partitions(self, semiring: Type[Semiring]) -> Tensor:
- return compute_catted_sequence_partitions(semiring=semiring)(
- emissions=self.emissions, indices=self.indices,
- transitions=self.transitions,
- head_transitions=self.head_transitions,
- last_transitions=self.last_transitions,
- eye=semiring.eye_like(self.transitions),
- )
-
- def log_prob(self, tags: CattedSequence) -> Tensor:
- return self.log_scores(tags=tags) - self.log_partitions
-
- def log_scores(self, tags: CattedSequence) -> Tensor:
- return self.semiring_scores(semiring=Log, tags=tags)
-
- @lazy_property
- def log_partitions(self) -> Tensor:
- return self.semiring_partitions(semiring=Log)
-
- @lazy_property
- def marginals(self) -> Tensor:
- log_partitions = self.log_partitions
- grad, = autograd.grad(
- log_partitions, self.emissions.data, torch.ones_like(log_partitions),
- create_graph=True, only_inputs=True, allow_unused=False,
- )
- return grad
-
- @lazy_property
- def argmax(self) -> CattedSequence:
- max_partitions = self.semiring_partitions(semiring=Max)
-
- grad, = torch.autograd.grad(
- max_partitions, self.emissions.data, torch.ones_like(max_partitions),
- retain_graph=False, create_graph=False, allow_unused=False,
- )
- return CattedSequence(
- data=grad.argmax(dim=-1),
- token_sizes=self.emissions.token_sizes,
- )
diff --git a/torchlatent/crf/packing.py b/torchlatent/crf/packing.py
deleted file mode 100644
index ec22c38..0000000
--- a/torchlatent/crf/packing.py
+++ /dev/null
@@ -1,143 +0,0 @@
-from typing import Type
-
-import torch
-from torch import Tensor, autograd
-from torch.distributions.utils import lazy_property
-from torch.nn.utils.rnn import PackedSequence
-from torchrua import head_packed_indices, ReductionIndices
-from torchrua import roll_packed_sequence, head_packed_sequence, last_packed_sequence, major_sizes_to_ptr
-
-from torchlatent.semiring import Semiring, Log, Max
-
-__all__ = [
- 'compute_packed_sequence_scores',
- 'compute_packed_sequence_partitions',
- 'PackedCrfDistribution',
-]
-
-
-def compute_packed_sequence_scores(semiring: Type[Semiring]):
- def _compute_packed_sequence_scores(
- emissions: PackedSequence, tags: PackedSequence,
- transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor) -> Tensor:
- device = transitions.device
-
- emission_scores = emissions.data.gather(dim=-1, index=tags.data[..., None])[..., 0] # [t, c]
-
- h = emissions.batch_sizes[0].item()
- t = torch.arange(transitions.size()[0], device=device) # [t]
- c = torch.arange(transitions.size()[1], device=device) # [c]
-
- x, y = roll_packed_sequence(tags, shifts=1).data, tags.data # [t, c]
- head = head_packed_sequence(tags, unsort=False) # [h, c]
- last = last_packed_sequence(tags, unsort=False) # [h, c]
-
- transition_scores = transitions[t[:, None], c[None, :], x, y] # [t, c]
- transition_head_scores = head_transitions[t[:h, None], c[None, :], head] # [h, c]
- transition_last_scores = last_transitions[t[:h, None], c[None, :], last] # [h, c]
-
- indices = head_packed_indices(tags.batch_sizes)
- transition_scores[indices] = transition_head_scores # [h, c]
-
- batch_ptr, _ = major_sizes_to_ptr(sizes=emissions.batch_sizes)
- scores = semiring.mul(emission_scores, transition_scores)
- scores = semiring.scatter_mul(scores, index=batch_ptr)
-
- scores = semiring.mul(scores, transition_last_scores)
-
- if emissions.unsorted_indices is not None:
- scores = scores[emissions.unsorted_indices]
-
- return scores
-
- return _compute_packed_sequence_scores
-
-
-def compute_packed_sequence_partitions(semiring: Type[Semiring]):
- def _compute_packed_sequence_partitions(
- emissions: PackedSequence, indices: ReductionIndices,
- transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor, eye: Tensor) -> Tensor:
- h = emissions.batch_sizes[0].item()
- t = torch.arange(transitions.size()[0], device=transitions.device) # [t]
- c = torch.arange(transitions.size()[1], device=transitions.device) # [c]
-
- emission_scores = semiring.mul(transitions, emissions.data[..., None, :]) # [t, c, n, n]
- emission_scores[:h] = eye[None, None, :, :]
- emission_scores = semiring.reduce(tensor=emission_scores, indices=indices)
-
- emission_head_scores = emissions.data[:h, :, None, :]
- transition_head_scores = head_transitions[t[:h, None], c[None, :], None, :]
- transition_last_scores = last_transitions[t[:h, None], c[None, :], :, None]
-
- scores = semiring.mul(transition_head_scores, emission_head_scores)
- scores = semiring.bmm(scores, emission_scores)
- scores = semiring.bmm(scores, transition_last_scores)[..., 0, 0]
-
- if emissions.unsorted_indices is not None:
- scores = scores[emissions.unsorted_indices]
- return scores
-
- return _compute_packed_sequence_partitions
-
-
-class PackedCrfDistribution(object):
- def __init__(self, emissions: PackedSequence, indices: ReductionIndices,
- transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor) -> None:
- super(PackedCrfDistribution, self).__init__()
- self.emissions = emissions
- self.indices = indices
-
- self.transitions = transitions
- self.head_transitions = head_transitions
- self.last_transitions = last_transitions
-
- def semiring_scores(self, semiring: Type[Semiring], tags: PackedSequence) -> Tensor:
- return compute_packed_sequence_scores(semiring=semiring)(
- emissions=self.emissions, tags=tags,
- transitions=self.transitions,
- head_transitions=self.head_transitions,
- last_transitions=self.last_transitions,
- )
-
- def semiring_partitions(self, semiring: Type[Semiring]) -> Tensor:
- return compute_packed_sequence_partitions(semiring=semiring)(
- emissions=self.emissions, indices=self.indices,
- transitions=self.transitions,
- head_transitions=self.head_transitions,
- last_transitions=self.last_transitions,
- eye=semiring.eye_like(self.transitions),
- )
-
- def log_prob(self, tags: PackedSequence) -> Tensor:
- return self.log_scores(tags=tags) - self.log_partitions
-
- def log_scores(self, tags: PackedSequence) -> Tensor:
- return self.semiring_scores(semiring=Log, tags=tags)
-
- @lazy_property
- def log_partitions(self) -> Tensor:
- return self.semiring_partitions(semiring=Log)
-
- @lazy_property
- def marginals(self) -> Tensor:
- log_partitions = self.log_partitions
- grad, = autograd.grad(
- log_partitions, self.emissions.data, torch.ones_like(log_partitions),
- create_graph=True, only_inputs=True, allow_unused=False,
- )
- return grad
-
- @lazy_property
- def argmax(self) -> PackedSequence:
- max_partitions = self.semiring_partitions(semiring=Max)
-
- grad, = torch.autograd.grad(
- max_partitions, self.emissions.data, torch.ones_like(max_partitions),
- retain_graph=False, create_graph=False, allow_unused=False,
- )
- return PackedSequence(
- data=grad.argmax(dim=-1),
- batch_sizes=self.emissions.batch_sizes,
- sorted_indices=self.emissions.sorted_indices,
- unsorted_indices=self.emissions.unsorted_indices,
- )
diff --git a/torchlatent/semiring.py b/torchlatent/semiring.py
index 1f6e44a..3f11e6d 100644
--- a/torchlatent/semiring.py
+++ b/torchlatent/semiring.py
@@ -1,13 +1,13 @@
import torch
from torch import Tensor
-from torchrua.scatter import scatter_add, scatter_max, scatter_mul, scatter_logsumexp
-from torchrua.reduction import reduce_sequence, ReductionIndices
+from torchrua import segment_logsumexp, segment_max, segment_prod, segment_sum
-from torchlatent.functional import logsumexp, logaddexp
+from torchlatent.functional import logaddexp, logsumexp
__all__ = [
- 'Semiring',
- 'Std', 'Log', 'Max',
+ 'Semiring', 'ExceptionSemiring',
+ 'Std', 'Log', 'Max', 'Xen', 'Div',
]
@@ -41,21 +41,17 @@ def prod(cls, tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor:
raise NotImplementedError
@classmethod
- def scatter_add(cls, tensor: Tensor, index: Tensor) -> Tensor:
+ def segment_sum(cls, tensor: Tensor, sizes: Tensor) -> Tensor:
raise NotImplementedError
@classmethod
- def scatter_mul(cls, tensor: Tensor, index: Tensor) -> Tensor:
+ def segment_prod(cls, tensor: Tensor, sizes: Tensor) -> Tensor:
raise NotImplementedError
@classmethod
def bmm(cls, x: Tensor, y: Tensor) -> Tensor:
return cls.sum(cls.mul(x[..., :, :, None], y[..., None, :, :]), dim=-2, keepdim=False)
- @classmethod
- def reduce(cls, tensor: Tensor, indices: ReductionIndices) -> Tensor:
- return reduce_sequence(cls.bmm)(tensor=tensor, indices=indices)
-
class Std(Semiring):
zero = 0.
@@ -78,12 +74,12 @@ def prod(cls, tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor:
return torch.prod(tensor, dim=dim, keepdim=keepdim)
@classmethod
- def scatter_add(cls, tensor: Tensor, index: Tensor) -> Tensor:
- return scatter_add(tensor=tensor, index=index)
+ def segment_sum(cls, tensor: Tensor, sizes: Tensor) -> Tensor:
+ return segment_sum(tensor, segment_sizes=sizes)
@classmethod
- def scatter_mul(cls, tensor: Tensor, index: Tensor) -> Tensor:
- return scatter_mul(tensor=tensor, index=index)
+ def segment_prod(cls, tensor: Tensor, sizes: Tensor) -> Tensor:
+ return segment_prod(tensor, segment_sizes=sizes)
class Log(Semiring):
@@ -107,12 +103,12 @@ def prod(cls, tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor:
return torch.sum(tensor, dim=dim, keepdim=keepdim)
@classmethod
- def scatter_add(cls, tensor: Tensor, index: Tensor) -> Tensor:
- return scatter_logsumexp(tensor=tensor, index=index)
+ def segment_sum(cls, tensor: Tensor, sizes: Tensor) -> Tensor:
+ return segment_logsumexp(tensor, segment_sizes=sizes)
@classmethod
- def scatter_mul(cls, tensor: Tensor, index: Tensor) -> Tensor:
- return scatter_add(tensor=tensor, index=index)
+ def segment_prod(cls, tensor: Tensor, sizes: Tensor) -> Tensor:
+ return segment_sum(tensor, segment_sizes=sizes)
class Max(Semiring):
@@ -136,9 +132,77 @@ def prod(cls, tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor:
return torch.sum(tensor, dim=dim, keepdim=keepdim)
@classmethod
- def scatter_add(cls, tensor: Tensor, index: Tensor) -> Tensor:
- return scatter_max(tensor=tensor, index=index)
+ def segment_sum(cls, tensor: Tensor, sizes: Tensor) -> Tensor:
+ return segment_max(tensor, segment_sizes=sizes)
+
+ @classmethod
+ def segment_prod(cls, tensor: Tensor, sizes: Tensor) -> Tensor:
+ return segment_sum(tensor, segment_sizes=sizes)
+
+
+class ExceptionSemiring(Semiring):
+ @classmethod
+ def sum(cls, tensor: Tensor, log_p: Tensor, log_q: Tensor, dim: int, keepdim: bool = False) -> Tensor:
+ raise NotImplementedError
+
+ @classmethod
+ def segment_sum(cls, tensor: Tensor, log_p: Tensor, log_q: Tensor, sizes: Tensor) -> Tensor:
+ raise NotImplementedError
+
+
+class Xen(ExceptionSemiring):
+ zero = 0.
+ one = 0.
+
+ @classmethod
+ def add(cls, x: Tensor, y: Tensor) -> Tensor:
+ raise NotImplementedError
+
+ @classmethod
+ def mul(cls, x: Tensor, y: Tensor) -> Tensor:
+ return x + y
+
+ @classmethod
+ def sum(cls, tensor: Tensor, log_p: Tensor, log_q: Tensor, dim: int, keepdim: bool = False) -> Tensor:
+ return torch.sum((tensor - log_q) * log_p.exp(), dim=dim, keepdim=keepdim)
+
+ @classmethod
+ def prod(cls, tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor:
+ return torch.sum(tensor, dim=dim, keepdim=keepdim)
+
+ @classmethod
+ def segment_sum(cls, tensor: Tensor, log_p: Tensor, log_q: Tensor, sizes: Tensor) -> Tensor:
+ return segment_sum((tensor - log_q) * log_p.exp(), segment_sizes=sizes)
+
+ @classmethod
+ def segment_prod(cls, tensor: Tensor, sizes: Tensor) -> Tensor:
+ return segment_sum(tensor, segment_sizes=sizes)
+
+
+class Div(ExceptionSemiring):
+ zero = 0.
+ one = 0.
+
+ @classmethod
+ def add(cls, x: Tensor, y: Tensor) -> Tensor:
+ raise NotImplementedError
+
+ @classmethod
+ def mul(cls, x: Tensor, y: Tensor) -> Tensor:
+ return x + y
+
+ @classmethod
+ def sum(cls, tensor: Tensor, log_p: Tensor, log_q: Tensor, dim: int, keepdim: bool = False) -> Tensor:
+ return torch.sum((tensor - log_q + log_p) * log_p.exp(), dim=dim, keepdim=keepdim)
+
+ @classmethod
+ def prod(cls, tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor:
+ return torch.sum(tensor, dim=dim, keepdim=keepdim)
+
+ @classmethod
+ def segment_sum(cls, tensor: Tensor, log_p: Tensor, log_q: Tensor, sizes: Tensor) -> Tensor:
+ return segment_sum((tensor - log_q + log_p) * log_p.exp(), segment_sizes=sizes)
@classmethod
- def scatter_mul(cls, tensor: Tensor, index: Tensor) -> Tensor:
- return scatter_add(tensor=tensor, index=index)
+ def segment_prod(cls, tensor: Tensor, sizes: Tensor) -> Tensor:
+ return segment_sum(tensor, segment_sizes=sizes)
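Under `Log`, the inherited `bmm` becomes a log-space matrix product, which is the primitive the CRF and CKY recursions chain together. A small sanity sketch:

```python
# Log.bmm(x, y)[i, k] == logsumexp_j(x[i, j] + y[j, k])
import torch

from torchlatent.semiring import Log

x, y = torch.randn(3, 4), torch.randn(4, 5)
expected = torch.logsumexp(x[:, :, None] + y[None, :, :], dim=1)

assert torch.allclose(Log.bmm(x, y), expected, atol=1e-6)
```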