diff --git a/README.md b/README.md index d2a5f0e6c1..e070bda2ad 100644 --- a/README.md +++ b/README.md @@ -74,8 +74,9 @@ The datasets module currently contains: - Sentiment analysis: SST and IMDb - Question classification: TREC - Entailment: SNLI -- Language modeling: WikiText-2 -- Machine translation: Multi30k, IWSLT, WMT14 +- Language modeling: abstract class + WikiText-2 +- Machine translation: abstract class + Multi30k, IWSLT, WMT14 +- Sequence tagging (e.g. POS/NER): abstract class + UDPOS Others are planned or a work in progress: diff --git a/test/sequence_tagging.py b/test/sequence_tagging.py new file mode 100644 index 0000000000..eff0809288 --- /dev/null +++ b/test/sequence_tagging.py @@ -0,0 +1,47 @@ +from torchtext import data +from torchtext import datasets + +# Define the fields associated with the sequences. +WORD = data.Field(init_token="", eos_token="") +UD_TAG = data.Field(init_token="", eos_token="") + +# Download and the load default data. +train, val, test = datasets.UDPOS.splits( + fields=(('word', WORD), ('udtag', UD_TAG), (None, None))) + +print(train.fields) +print(len(train)) +print(vars(train[0])) + +# We can also define more than two columns. +WORD = data.Field(init_token="", eos_token="") +UD_TAG = data.Field(init_token="", eos_token="") +PTB_TAG = data.Field(init_token="", eos_token="") + +# Load the specified data. +train, val, test = datasets.UDPOS.splits( + fields=(('word', WORD), ('udtag', UD_TAG), ('ptbtag', PTB_TAG)), + path=".data/sequence-labeling/en-ud-v2", + train="en-ud-tag.v2.train.txt", + validation="en-ud-tag.v2.dev.txt", + test="en-ud-tag.v2.test.txt") + +print(train.fields) +print(len(train)) +print(vars(train[0])) + +WORD.build_vocab(train.word, min_freq=3) +UD_TAG.build_vocab(train.udtag) +PTB_TAG.build_vocab(train.ptbtag) + +print(UD_TAG.vocab.freqs) +print(PTB_TAG.vocab.freqs) + +train_iter, val_iter = data.BucketIterator.splits( + (train, val), batch_size=3, device=0) + +batch = next(iter(train_iter)) + +print("words", batch.word) +print("udtags", batch.udtag) +print("ptbtags", batch.ptbtag) diff --git a/torchtext/datasets/__init__.py b/torchtext/datasets/__init__.py index 43caad3f1e..c6541b79b5 100644 --- a/torchtext/datasets/__init__.py +++ b/torchtext/datasets/__init__.py @@ -2,6 +2,7 @@ from .snli import SNLI from .sst import SST from .translation import TranslationDataset, Multi30k, IWSLT, WMT14 # NOQA +from .sequence_tagging import SequenceTaggingDataset, UDPOS # NOQA from .trec import TREC from .imdb import IMDB @@ -15,4 +16,6 @@ 'WMT14' 'WikiText2', 'TREC', - 'IMDB'] + 'IMDB', + 'SequenceTaggingDataset', + 'UDPOS'] diff --git a/torchtext/datasets/sequence_tagging.py b/torchtext/datasets/sequence_tagging.py new file mode 100644 index 0000000000..7f1ccef06f --- /dev/null +++ b/torchtext/datasets/sequence_tagging.py @@ -0,0 +1,65 @@ +from .. import data + + +class SequenceTaggingDataset(data.Dataset): + """Defines a dataset for sequence tagging. Examples in this dataset + contain paired lists -- paired list of words and tags. + + For example, in the case of part-of-speech tagging, an example is of the + form + [I, love, PyTorch, .] paired with [PRON, VERB, PROPN, PUNCT] + + See torchtext/test/sequence_tagging.py on how to use this class. + """ + + @staticmethod + def sort_key(example): + for attr in dir(example): + if not callable(getattr(example, attr)) and \ + not attr.startswith("__"): + return len(getattr(example, attr)) + return 0 + + def __init__(self, path, fields, **kwargs): + examples = [] + columns = [] + + with open(path) as input_file: + for line in input_file: + line = line.strip() + if line == "": + if columns: + examples.append(data.Example.fromlist(columns, fields)) + columns = [] + else: + for i, column in enumerate(line.split("\t")): + if len(columns) < i + 1: + columns.append([]) + columns[i].append(column) + + if columns: + examples.append(data.Example.fromlist(columns, fields)) + super(SequenceTaggingDataset, self).__init__(examples, fields, + **kwargs) + + +class UDPOS(SequenceTaggingDataset): + + # Universal Dependencies English Web Treebank. + # Download original at http://universaldependencies.org/ + # License: http://creativecommons.org/licenses/by-sa/4.0/ + urls = ['https://bitbucket.org/sivareddyg/public/downloads/en-ud-v2.zip'] + dirname = 'en-ud-v2' + name = 'udpos' + + @classmethod + def splits(cls, fields, root=".data", train="en-ud-tag.v2.train.txt", + validation="en-ud-tag.v2.dev.txt", + test="en-ud-tag.v2.test.txt", **kwargs): + """Downloads and loads the Universal Dependencies Version 2 POS Tagged + data. + """ + + return super(UDPOS, cls).splits( + fields=fields, root=root, train=train, validation=validation, + test=test, **kwargs)