Skip to content

Commit

Permalink
Merge branch 'release/v0.4.2'
Browse files Browse the repository at this point in the history
  • Loading branch information
speedcell4 committed Mar 3, 2022
2 parents a033a95 + 1508825 commit 630afcf
Show file tree
Hide file tree
Showing 18 changed files with 747 additions and 407 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.7'
python-version: '3.8'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/unit-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ jobs:

steps:
- uses: actions/checkout@v2
- name: Set up Python 3.7
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: 3.7
python-version: '3.8'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand Down
11 changes: 9 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,20 @@

## Requirements

- Python 3.7
- PyTorch 1.6.0
- Python 3.8
- PyTorch 1.10.2

## Installation

`python3 -m pip torchlatent`

## Performance

```
TorchLatent (0.109244) => 0.003781 0.017763 0.087700 0.063497
Third (0.232487) => 0.103277 0.129209 0.145311
```

## Usage

```python
Expand Down
Empty file added benchmark/__init__.py
Empty file.
9 changes: 9 additions & 0 deletions benchmark/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from aku import Aku

from benchmark.crf import benchmark_crf

aku = Aku()

aku.option(benchmark_crf)

aku.run()
62 changes: 62 additions & 0 deletions benchmark/crf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import torch
from torchrua import pack_sequence
from tqdm import tqdm

from benchmark.meter import TimeMeter
from tests.third_party import ThirdPartyCrfDecoder
from torchlatent.crf import CrfDecoder


def benchmark_crf(num_tags: int = 50, num_conjugates: int = 1, num_runs: int = 100,
batch_size: int = 32, max_token_size: int = 512):
j1, f1, b1, d1, = TimeMeter(), TimeMeter(), TimeMeter(), TimeMeter()
j2, f2, b2, d2, = TimeMeter(), TimeMeter(), TimeMeter(), TimeMeter()

if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
print(f'device => {device}')

decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates).to(device=device)
print(f'decoder => {decoder}')

third_decoder = ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates).to(device=device)
print(f'third_decoder => {third_decoder}')

for _ in tqdm(range(num_runs)):
token_sizes = torch.randint(1, max_token_size + 1, (batch_size,), device=device).detach().cpu().tolist()

emissions = pack_sequence([
torch.randn((token_size, num_conjugates, num_tags), device=device, requires_grad=True)
for token_size in token_sizes
])

tags = pack_sequence([
torch.randint(0, num_tags, (token_size, num_conjugates), device=device)
for token_size in token_sizes
])

with j1:
indices = decoder.compile_indices(emissions=emissions, tags=tags)

with f1:
loss = decoder.fit(emissions=emissions, tags=tags, indices=indices).neg().mean()

with b1:
_, torch.autograd.grad(loss, emissions.data, torch.ones_like(loss))

with d1:
_ = decoder.decode(emissions=emissions, indices=indices)

with f2:
loss = third_decoder.fit(emissions=emissions, tags=tags).neg().mean()

with b2:
_, torch.autograd.grad(loss, emissions.data, torch.ones_like(loss))

with d2:
_ = third_decoder.decode(emissions=emissions)

print(f'TorchLatent ({j1.merit + f1.merit + b1.merit:.6f}) => {j1} {f1} {b1} {d1}')
print(f'Third ({j2.merit + f2.merit + b2.merit:.6f}) => {j2} {f2} {b2} {d2}')
23 changes: 23 additions & 0 deletions benchmark/meter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from datetime import datetime


class TimeMeter(object):
def __init__(self) -> None:
super(TimeMeter, self).__init__()

self.seconds = 0
self.counts = 0

def __enter__(self):
self.start_tm = datetime.now()

def __exit__(self, exc_type, exc_val, exc_tb):
self.seconds += (datetime.now() - self.start_tm).total_seconds()
self.counts += 1

@property
def merit(self) -> float:
return self.seconds / max(1, self.counts)

def __repr__(self) -> str:
return f'{self.merit :.6f}'
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@

setup(
name=name,
version='0.4.1',
version='0.4.2',
packages=[package for package in find_packages() if package.startswith(name)],
url='https://github.com/speedcell4/torchlatent',
license='MIT',
author='speedcell4',
author_email='speedcell4@gmail.com',
description='High Performance Structured Prediction in PyTorch',
python_requires='>=3.7',
python_requires='>=3.8',
install_requires=[
'numpy',
'torchrua>=0.3.0',
'torchrua>=0.4.0',
],
extras_require={
'dev': [
Expand Down
57 changes: 14 additions & 43 deletions tests/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,14 @@

from hypothesis import strategies as st

if torch.cuda.is_available():
MAX_BATCH_SIZE = 120
MAX_TOKEN_SIZE = 512
MAX_NUM_TAGS = 100
MAX_NUM_CONJUGATES = 16
else:
MAX_BATCH_SIZE = 12
MAX_TOKEN_SIZE = 24
MAX_NUM_TAGS = 12
MAX_NUM_CONJUGATES = 6

TINY_BATCH_SIZE = 6
TINY_TOKEN_SIZE = 12

BATCH_SIZE = 24
TOKEN_SIZE = 50
NUM_TAGS = 8
NUM_CONJUGATES = 5


@st.composite
def devices(draw):
Expand All @@ -28,36 +22,13 @@ def devices(draw):


@st.composite
def batch_sizes(draw, max_value: int = MAX_BATCH_SIZE):
return draw(st.integers(min_value=1, max_value=max_value))


@st.composite
def batch_size_lists(draw, max_batch_size: int = MAX_BATCH_SIZE):
return [
draw(batch_sizes(max_value=max_batch_size))
for _ in range(draw(batch_sizes(max_value=max_batch_size)))
]


@st.composite
def token_sizes(draw, max_value: int = MAX_TOKEN_SIZE):
return draw(st.integers(min_value=1, max_value=max_value))

def sizes(draw, *size: int, min_size: int = 1):
max_size, *size = size

@st.composite
def token_size_lists(draw, max_token_size: int = MAX_TOKEN_SIZE, max_batch_size: int = MAX_BATCH_SIZE):
return [
draw(token_sizes(max_value=max_token_size))
for _ in range(draw(batch_sizes(max_value=max_batch_size)))
]


@st.composite
def tag_sizes(draw, max_value: int = MAX_NUM_TAGS):
return draw(st.integers(min_value=1, max_value=max_value))


@st.composite
def conjugate_sizes(draw, max_value: int = MAX_NUM_CONJUGATES):
return draw(st.integers(min_value=1, max_value=max_value))
if len(size) == 0:
return draw(st.integers(min_value=min_size, max_value=max_size))
else:
return [
draw(sizes(*size, min_size=min_size))
for _ in range(draw(st.integers(min_value=min_size, max_value=max_size)))
]
Loading

0 comments on commit 630afcf

Please sign in to comment.