Commit 6fe7686 (1 parent: e10dab1)
Showing 3 changed files with 140 additions and 1 deletion.
@@ -0,0 +1,26 @@
## Data Preparation

You can download the datasets from the [XML repo](http://manikvarma.org/downloads/XC/XMLRepository.html).

A dataset folder should have the following directory structure. Below we show it for the LF-AmazonTitles-131K dataset:

```bash
📁 LF-AmazonTitles-131K/
    📄 trn_X_Y.txt            # contains mappings from train IDs to label IDs
    📄 trn_filter_labels.txt  # contains train reciprocal pairs to be ignored in evaluation
    📄 tst_X_Y.txt            # contains mappings from test IDs to label IDs
    📄 tst_filter_labels.txt  # contains test reciprocal pairs to be ignored in evaluation
    📄 trn_X.txt              # each line contains the raw input train text; this needs to be tokenized
    📄 tst_X.txt              # each line contains the raw input test text; this needs to be tokenized
    📄 Y.txt                  # each line contains the raw label text; this needs to be tokenized
```
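The `*_filter_labels.txt` files list reciprocal (document, label) pairs whose scores should be masked out before computing evaluation metrics. Below is a minimal sketch of how such a filter file could be applied to a dense score matrix; it assumes each line holds a space-separated `doc_id label_id` pair of 0-based indices (check your copy of the dataset if the convention differs), and the `apply_filter` helper is ours, not part of the repo:

```python
import numpy as np

def apply_filter(score_mat: np.ndarray, filter_path: str) -> np.ndarray:
    """Zero out reciprocal (doc, label) pairs listed in a filter file.

    Assumes each line of the filter file contains a space-separated
    `doc_id label_id` pair of 0-based indices.
    """
    pairs = np.loadtxt(filter_path, dtype=np.int64).reshape(-1, 2)
    score_mat = score_mat.copy()
    # Ignore these pairs when computing precision/recall-style metrics.
    score_mat[pairs[:, 0], pairs[:, 1]] = 0.0
    return score_mat

# Hypothetical usage with a (num_test_docs x num_labels) score matrix:
# scores = apply_filter(scores, "LF-AmazonTitles-131K/tst_filter_labels.txt")
```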
To tokenize the raw train, test and label texts, we can use the following command (change the path of the dataset folder accordingly):

```bash
python -W ignore -u utils/CreateTokenizedFiles.py \
    --data-dir xc/Datasets/LF-AmazonTitles-131K \
    --max-length 32 \
    --tokenizer-type bert-base-uncased \
    --tokenize-label-texts
```
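The script (shown in full in the diff below) writes `{prefix}_input_ids.dat` and `{prefix}_attention_mask.dat` as int64 NumPy memmaps of shape `(num_texts, max_length)` inside `{output-folder}/{tokenizer-type}-{max-length}`. Here is a minimal sketch for reading them back; the `load_tokenized` helper is ours, and the shape has to be supplied by the caller (e.g. from the line count of the corresponding text file):

```python
import numpy as np

def load_tokenized(tokenization_dir: str, prefix: str, num_texts: int, max_len: int = 32):
    """Memory-map the tokenized ids/masks written by CreateTokenizedFiles.py."""
    input_ids = np.memmap(f"{tokenization_dir}/{prefix}_input_ids.dat",
                          dtype="int64", mode="r", shape=(num_texts, max_len))
    attention_mask = np.memmap(f"{tokenization_dir}/{prefix}_attention_mask.dat",
                               dtype="int64", mode="r", shape=(num_texts, max_len))
    return input_ids, attention_mask

# Example (paths are illustrative; num_texts is the number of lines in trn_X.txt):
# num_trn = sum(1 for _ in open("LF-AmazonTitles-131K/trn_X.txt", encoding="utf-8"))
# trn_ids, trn_mask = load_tokenized("LF-AmazonTitles-131K/bert-base-uncased-32",
#                                    "trn_doc", num_texts=num_trn)
```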
utils/CreateTokenizedFiles.py
@@ -0,0 +1,111 @@
""" | ||
For a dataset, create tokenized files in the folder `{tokenizer-type}-{maxlen}` folder inside the `output-path` (default is `data-dir`) folder. | ||
Sample usage: | ||
python -W ignore -u CreateTokenizedFiles.py \ | ||
--data-dir <Dataset-dir> \ | ||
--max-length 32 \ | ||
--tokenizer-type bert-base-uncased \ | ||
--tokenize-label-texts | ||
""" | ||
import torch.multiprocessing as mp
from transformers import AutoTokenizer
import os
import numpy as np
import time
import functools
import argparse
from tqdm import tqdm


def timeit(func):
    """Print the runtime of the decorated function"""
    @functools.wraps(func)
    def wrapper_timer(*args, **kwargs):
        start_time = time.perf_counter()
        value = func(*args, **kwargs)
        end_time = time.perf_counter()
        run_time = end_time - start_time
        print("Finished {} in {:.4f} secs".format(func.__name__, run_time))
        return value
    return wrapper_timer


def _tokenize(batch_input):
    tokenizer, max_len, batch_corpus = batch_input[0], batch_input[1], batch_input[2]
    temp = tokenizer.batch_encode_plus(
        batch_corpus,                    # Sentences to encode.
        add_special_tokens=True,         # Add '[CLS]' and '[SEP]'
        max_length=max_len,              # Pad & truncate all sentences.
        padding='max_length',
        return_attention_mask=True,      # Construct attn. masks.
        return_tensors='np',             # Return numpy tensors.
        truncation=True
    )

    return (temp['input_ids'], temp['attention_mask'])


def convert(corpus, tokenizer, max_len, num_threads, bsz=100000):
    batches = [(tokenizer, max_len, corpus[batch_start: batch_start + bsz]) for batch_start in range(0, len(corpus), bsz)]

    pool = mp.Pool(num_threads)
    batch_tokenized = pool.map(_tokenize, batches)
    pool.close()

    input_ids = np.vstack([x[0] for x in batch_tokenized])
    attention_mask = np.vstack([x[1] for x in batch_tokenized])

    del batch_tokenized

    return input_ids, attention_mask


@timeit
def tokenize_dump_memmap(corpus, tokenization_dir, tokenizer, max_len, prefix, num_threads, batch_size=10000000):
    ii = np.memmap(f"{tokenization_dir}/{prefix}_input_ids.dat", dtype='int64', mode='w+', shape=(len(corpus), max_len))
    am = np.memmap(f"{tokenization_dir}/{prefix}_attention_mask.dat", dtype='int64', mode='w+', shape=(len(corpus), max_len))
    for i in tqdm(range(0, len(corpus), batch_size)):
        _input_ids, _attention_mask = convert(corpus[i: i + batch_size], tokenizer, max_len, num_threads)
        ii[i: i + _input_ids.shape[0], :] = _input_ids
        am[i: i + _input_ids.shape[0], :] = _attention_mask


parser = argparse.ArgumentParser()

parser.add_argument("--data-dir", type=str, required=True, help="Data directory path - with {trn,tst}_X.txt and Y.txt where each line contains train/test and label text respectively")
parser.add_argument("--max-length", type=int, help="Max sequence length for tokenizer", default=32)
parser.add_argument("--tokenizer-type", type=str, help="Tokenizer to use", default="bert-base-uncased")
parser.add_argument("--num-threads", type=int, help="Number of threads to use", default=32)
parser.add_argument("--output-folder", type=str, help="Folder to dump the tokenization to, if not provided, uses data-dir", default="")
parser.add_argument("--tokenize-label-texts", action="store_true", default=False, help="Whether to tokenize label texts or not")

args = parser.parse_args()

DATA_DIR = args.data_dir

trnX = [x.strip() for x in open(f'{DATA_DIR}/trn_X.txt', "r", encoding="utf-8").readlines()]
tstX = [x.strip() for x in open(f'{DATA_DIR}/tst_X.txt', "r", encoding="utf-8").readlines()]

if args.tokenize_label_texts:
    Y = [x.strip() for x in open(f'{DATA_DIR}/Y.txt', "r", encoding="utf-8").readlines()]
    print(len(Y))
    print(Y[:10])

print(len(trnX), len(tstX))
print(trnX[:10], tstX[:10])

max_len = args.max_length
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_type, do_lower_case=True)

if len(args.output_folder) == 0:
    args.output_folder = args.data_dir
tokenization_dir = f"{args.output_folder}/{args.tokenizer_type}-{max_len}" | ||
os.makedirs(tokenization_dir, exist_ok=True)

print(f"Dumping files in {tokenization_dir}...")
print("Dumping for trnX...")
tokenize_dump_memmap(trnX, tokenization_dir, tokenizer, max_len, "trn_doc", args.num_threads)
print("Dumping for tstX...")
tokenize_dump_memmap(tstX, tokenization_dir, tokenizer, max_len, "tst_doc", args.num_threads)

if args.tokenize_label_texts:
    print("Dumping for Y...")
    tokenize_dump_memmap(Y, tokenization_dir, tokenizer, max_len, "lbl", args.num_threads)