Skip to content

Commit

Permalink
Sequence Labeling Dataset (#157)
Browse files Browse the repository at this point in the history
  • Loading branch information
sivareddyg authored and jekbradbury committed Dec 23, 2017
1 parent 08adbbf commit 7a2e442
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 3 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,9 @@ The datasets module currently contains:
- Sentiment analysis: SST and IMDb
- Question classification: TREC
- Entailment: SNLI
- Language modeling: WikiText-2
- Machine translation: Multi30k, IWSLT, WMT14
- Language modeling: abstract class + WikiText-2
- Machine translation: abstract class + Multi30k, IWSLT, WMT14
- Sequence tagging (e.g. POS/NER): abstract class + UDPOS

Others are planned or a work in progress:

Expand Down
47 changes: 47 additions & 0 deletions test/sequence_tagging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Example script exercising the sequence tagging datasets: it loads the
# UDPOS splits twice (two columns, then three columns), builds vocabularies
# and prints one batch. Running it loads (and per the comment below,
# downloads) the data.
from torchtext import data
from torchtext import datasets

# Define the fields associated with the sequences. Both fields add
# "<bos>"/"<eos>" markers around every sequence.
WORD = data.Field(init_token="<bos>", eos_token="<eos>")
UD_TAG = data.Field(init_token="<bos>", eos_token="<eos>")

# Download and load the default data.
# The third column is mapped to (None, None) -- presumably this makes
# torchtext ignore that column; TODO confirm against data.Example.fromlist.
train, val, test = datasets.UDPOS.splits(
fields=(('word', WORD), ('udtag', UD_TAG), (None, None)))

# Quick sanity check: field mapping, number of examples, first example.
print(train.fields)
print(len(train))
print(vars(train[0]))

# We can also define more than two columns.
# Fresh Field objects are created here, replacing the ones used above.
WORD = data.Field(init_token="<bos>", eos_token="<eos>")
UD_TAG = data.Field(init_token="<bos>", eos_token="<eos>")
PTB_TAG = data.Field(init_token="<bos>", eos_token="<eos>")

# Load the specified data, this time naming the directory and the three
# split files explicitly and keeping all three columns.
train, val, test = datasets.UDPOS.splits(
fields=(('word', WORD), ('udtag', UD_TAG), ('ptbtag', PTB_TAG)),
path=".data/sequence-labeling/en-ud-v2",
train="en-ud-tag.v2.train.txt",
validation="en-ud-tag.v2.dev.txt",
test="en-ud-tag.v2.test.txt")

print(train.fields)
print(len(train))
print(vars(train[0]))

# Build one vocabulary per column from the training split; the word
# vocabulary is built with min_freq=3, the tag vocabularies with defaults.
WORD.build_vocab(train.word, min_freq=3)
UD_TAG.build_vocab(train.udtag)
PTB_TAG.build_vocab(train.ptbtag)

# Show the tag frequency counts collected while building the vocabularies.
print(UD_TAG.vocab.freqs)
print(PTB_TAG.vocab.freqs)

# Batch the train/validation splits three examples at a time.
# NOTE(review): device=0 presumably selects the first GPU -- confirm, and
# adjust when running on a CPU-only machine.
train_iter, val_iter = data.BucketIterator.splits(
(train, val), batch_size=3, device=0)

# Pull a single batch and print each column of it.
batch = next(iter(train_iter))

print("words", batch.word)
print("udtags", batch.udtag)
print("ptbtags", batch.ptbtag)
5 changes: 4 additions & 1 deletion torchtext/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .snli import SNLI
from .sst import SST
from .translation import TranslationDataset, Multi30k, IWSLT, WMT14 # NOQA
from .sequence_tagging import SequenceTaggingDataset, UDPOS # NOQA
from .trec import TREC
from .imdb import IMDB

Expand All @@ -15,4 +16,6 @@
'WMT14'
'WikiText2',
'TREC',
'IMDB']
'IMDB',
'SequenceTaggingDataset',
'UDPOS']
65 changes: 65 additions & 0 deletions torchtext/datasets/sequence_tagging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from .. import data


class SequenceTaggingDataset(data.Dataset):
    """Defines a dataset for sequence tagging.

    Examples in this dataset contain paired lists -- a list of words paired
    with one list per tag column. For example, in the case of
    part-of-speech tagging, an example is of the form

        [I, love, PyTorch, .] paired with [PRON, VERB, PROPN, PUNCT]

    The input file is read line by line: every non-blank line holds one
    token with its columns delimited by ``separator``, and a blank line
    closes the current sequence.

    See test/sequence_tagging.py on how to use this class.
    """

    @staticmethod
    def sort_key(example):
        """Return the length of the first data attribute of ``example``.

        ``dir`` lists attribute names alphabetically, so the attribute used
        is the alphabetically first one that is neither callable nor a
        dunder name. Returns 0 when no such attribute exists.
        """
        for attr in dir(example):
            if not callable(getattr(example, attr)) and \
                    not attr.startswith("__"):
                return len(getattr(example, attr))
        return 0

    def __init__(self, path, fields, encoding="utf-8", separator="\t",
                 **kwargs):
        """Read a column-formatted tagging file into examples.

        Arguments:
            path: Path to the data file.
            fields: Fields passed, together with the parsed columns, to
                ``data.Example.fromlist`` and to the parent constructor.
            encoding: Text encoding used to decode the file.
                Default: "utf-8".
            separator: String delimiting the columns of a line.
                Default: a tab character.
            Remaining keyword arguments: Passed to the constructor of
                data.Dataset.
        """
        # Local import keeps this block self-contained; io.open accepts
        # ``encoding`` on both Python 2 and Python 3.
        import io

        examples = []
        columns = []

        # Decode explicitly instead of relying on the platform default
        # encoding, which misreads non-ASCII tokens on some systems.
        with io.open(path, encoding=encoding) as input_file:
            for line in input_file:
                line = line.strip()
                if line == "":
                    # A blank line ends the current sequence; consecutive
                    # blank lines must not produce empty examples.
                    if columns:
                        examples.append(
                            data.Example.fromlist(columns, fields))
                    columns = []
                else:
                    for i, column in enumerate(line.split(separator)):
                        # Grow the column list lazily the first time a
                        # given column index is seen.
                        if len(columns) < i + 1:
                            columns.append([])
                        columns[i].append(column)

        # Flush the last sequence when the file does not end with a blank
        # line.
        if columns:
            examples.append(data.Example.fromlist(columns, fields))
        super(SequenceTaggingDataset, self).__init__(examples, fields,
                                                     **kwargs)


class UDPOS(SequenceTaggingDataset):
    """Universal Dependencies English Web Treebank (POS tagged).

    Download original at http://universaldependencies.org/
    License: http://creativecommons.org/licenses/by-sa/4.0/
    """

    name = 'udpos'
    dirname = 'en-ud-v2'
    urls = ['https://bitbucket.org/sivareddyg/public/downloads/en-ud-v2.zip']

    @classmethod
    def splits(cls, fields, root=".data", train="en-ud-tag.v2.train.txt",
               validation="en-ud-tag.v2.dev.txt",
               test="en-ud-tag.v2.test.txt", **kwargs):
        """Load the Universal Dependencies Version 2 POS tagged splits.

        This only supplies the default root directory and split file
        names; all arguments are forwarded unchanged to the parent
        class's ``splits``.
        """
        # train/validation/test are named parameters, so they can never
        # already be present in kwargs; merging them is lossless.
        kwargs.update({"train": train, "validation": validation,
                       "test": test})
        parent_splits = super(UDPOS, cls).splits
        return parent_splits(fields=fields, root=root, **kwargs)

0 comments on commit 7a2e442

Please sign in to comment.