Merge pull request #12 from codertimo/alpha0.0.1a3
alpha0.0.1a3 version update
codertimo authored Oct 20, 2018
2 parents 5b9f139 + b8f27e3 · commit 7efd2b5
Showing 12 changed files with 153 additions and 220 deletions.
28 changes: 16 additions & 12 deletions README.md
````diff
@@ -1,9 +1,9 @@
 # BERT-pytorch
 
-[![LICENSE](https://img.shields.io/github/license/codertimo/BERT-pytorch.svg)](https://github.com/kor2vec/kor2vec/blob/master/LICENSE)
+[![LICENSE](https://img.shields.io/github/license/codertimo/BERT-pytorch.svg)](https://github.com/codertimo/BERT-pytorch/blob/master/LICENSE)
 ![GitHub issues](https://img.shields.io/github/issues/codertimo/BERT-pytorch.svg)
-[![GitHub stars](https://img.shields.io/github/stars/codertimo/BERT-pytorch.svg)](https://github.com/kor2vec/kor2vec/stargazers)
-[![CircleCI](https://circleci.com/gh/codertimo/BERT-pytorch.svg?style=shield)](https://circleci.com/gh/kor2vec/kor2vec)
+[![GitHub stars](https://img.shields.io/github/stars/codertimo/BERT-pytorch.svg)](https://github.com/codertimo/BERT-pytorch/stargazers)
+[![CircleCI](https://circleci.com/gh/codertimo/BERT-pytorch.svg?style=shield)](https://circleci.com/gh/codertimo/BERT-pytorch)
 [![PyPI](https://img.shields.io/pypi/v/bert-pytorch.svg)](https://pypi.org/project/bert_pytorch/)
 [![PyPI - Status](https://img.shields.io/pypi/status/bert-pytorch.svg)](https://pypi.org/project/bert_pytorch/)
 [![Documentation Status](https://readthedocs.org/projects/bert-pytorch/badge/?version=latest)](https://bert-pytorch.readthedocs.io/en/latest/?badge=latest)
@@ -39,24 +39,28 @@ pip install bert-pytorch
 ## Quickstart
 
 **NOTICE: your corpus should be prepared with two sentences per line, separated by a tab (\t)**
 
 ### 0. Prepare your corpus
 ```
-Welcome to the \t the jungle \n
-I can stay \t here all night \n
+Welcome to the \t the jungle\n
+I can stay \t here all night\n
 ```
 
-### 1. Building vocab based on your corpus
-```shell
-bert-vocab -c data/corpus.small -o data/corpus.small.vocab
+or a tokenized corpus (tokenization is not included in the package)
+```
+Wel_ _come _to _the \t _the _jungle\n
+_I _can _stay \t _here _all _night\n
 ```
 
 
-### 2. Building BERT train dataset with your corpus
+### 1. Building vocab based on your corpus
 ```shell
-bert-dataset -d data/corpus.small -v data/corpus.small.vocab -o data/dataset.small
+bert-vocab -c data/corpus.small -o data/vocab.small
 ```
 
-### 3. Train your own BERT model
+### 2. Train your own BERT model
 ```shell
-bert -d data/dataset.small -v data/corpus.small.vocab -o output/bert.model
+bert -c data/dataset.small -v data/vocab.small -o output/bert.model
 ```
 
 ## Language Model Pre-training
````
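With this release the quickstart drops the separate `bert-dataset` preprocessing step: masking and next-sentence pairs are now generated on the fly by `BERTDataset` (see `dataset.py` below). A minimal sketch of the new two-step flow, assuming a tab-separated corpus at `data/corpus.small` (file names are illustrative):

```shell
# 1. build the vocabulary once from the raw corpus
bert-vocab -c data/corpus.small -o data/vocab.small

# 2. train directly on the raw corpus; masking and next-sentence
#    sampling happen inside BERTDataset at iteration time
bert -c data/corpus.small -v data/vocab.small -o output/bert.model
```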
67 changes: 67 additions & 0 deletions bert_pytorch/__main__.py
@@ -0,0 +1,67 @@
```python
import argparse

from torch.utils.data import DataLoader

from .model import BERT
from .trainer import BERTTrainer
from .dataset import BERTDataset, WordVocab


def train():
    parser = argparse.ArgumentParser()

    # data paths
    parser.add_argument("-c", "--train_dataset", required=True, type=str)
    parser.add_argument("-t", "--test_dataset", type=str, default=None)
    parser.add_argument("-v", "--vocab_path", required=True, type=str)
    parser.add_argument("-o", "--output_path", required=True, type=str)

    # model size
    parser.add_argument("-hs", "--hidden", type=int, default=256)
    parser.add_argument("-l", "--layers", type=int, default=8)
    parser.add_argument("-a", "--attn_heads", type=int, default=8)
    parser.add_argument("-s", "--seq_len", type=int, default=20)

    # training setup
    parser.add_argument("-b", "--batch_size", type=int, default=64)
    parser.add_argument("-e", "--epochs", type=int, default=10)
    parser.add_argument("-w", "--num_workers", type=int, default=5)
    parser.add_argument("--with_cuda", type=bool, default=True)
    parser.add_argument("--log_freq", type=int, default=10)
    parser.add_argument("--corpus_lines", type=int, default=None)

    # Adam optimizer hyperparameters
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--adam_weight_decay", type=float, default=0.01)
    parser.add_argument("--adam_beta1", type=float, default=0.9)
    parser.add_argument("--adam_beta2", type=float, default=0.999)

    args = parser.parse_args()

    print("Loading Vocab", args.vocab_path)
    vocab = WordVocab.load_vocab(args.vocab_path)
    print("Vocab Size:", len(vocab))

    print("Loading Train Dataset", args.train_dataset)
    train_dataset = BERTDataset(args.train_dataset, vocab, seq_len=args.seq_len, corpus_lines=args.corpus_lines)

    print("Loading Test Dataset", args.test_dataset)
    test_dataset = BERTDataset(args.test_dataset, vocab,
                               seq_len=args.seq_len) if args.test_dataset is not None else None

    print("Creating Dataloader")
    train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
    test_data_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers) \
        if test_dataset is not None else None

    print("Building BERT model")
    bert = BERT(len(vocab), hidden=args.hidden, n_layers=args.layers, attn_heads=args.attn_heads)

    print("Creating BERT Trainer")
    trainer = BERTTrainer(bert, len(vocab), train_dataloader=train_data_loader, test_dataloader=test_data_loader,
                          lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
                          with_cuda=args.with_cuda, log_freq=args.log_freq)

    print("Training Start")
    for epoch in range(args.epochs):
        trainer.train(epoch)
        trainer.save(epoch, args.output_path)

        if test_data_loader is not None:
            trainer.test(epoch)
```
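The new `train()` function is the single CLI entry point for the whole pipeline: load the vocab, build the datasets and loaders, construct the model, and run the epoch loop. A hedged invocation example, assuming `bert` is wired to `bert_pytorch.__main__:train` as a console script (the packaging change is not part of this diff); the short flags map one-to-one onto the argparse definitions above:

```shell
bert -c data/corpus.small -t data/corpus.test -v data/vocab.small \
     -o output/bert.model -hs 512 -l 12 -a 8 -s 64 -b 32 -e 5
```

One caveat on the flags: `--with_cuda` is declared with `type=bool`, so passing `--with_cuda False` still evaluates truthy (any non-empty string does); leaving it at its default is the safe path.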
41 changes: 0 additions & 41 deletions bert_pytorch/build_dataset.py

This file was deleted.

19 changes: 0 additions & 19 deletions bert_pytorch/build_vocab.py

This file was deleted.

1 change: 0 additions & 1 deletion bert_pytorch/dataset/__init__.py
```diff
@@ -1,3 +1,2 @@
 from .dataset import BERTDataset
-from .creator import BERTDatasetCreator
 from .vocab import WordVocab
```
61 changes: 0 additions & 61 deletions bert_pytorch/dataset/creator.py

This file was deleted.

58 changes: 46 additions & 12 deletions bert_pytorch/dataset/dataset.py
@@ -1,32 +1,32 @@
from torch.utils.data import Dataset
import tqdm
import torch
import random


class BERTDataset(Dataset):
def __init__(self, corpus_path, vocab, seq_len, encoding="utf-8", corpus_lines=None):
self.vocab = vocab
self.seq_len = seq_len

self.datas = []
with open(corpus_path, "r", encoding=encoding) as f:
for line in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines):
t1, t2, t1_l, t2_l, is_next = line[:-1].split("\t")
t1, t2 = [[int(token) for token in t.split(",")] for t in [t1, t2]]
t1_l, t2_l = [[int(token) for token in label.split(",")] for label in [t1_l, t2_l]]
is_next = int(is_next)
self.datas.append({"t1": t1, "t2": t2, "t1_label": t1_l, "t2_label": t2_l, "is_next": is_next})
self.datas = [line[:-1].split("\t")
for line in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines)]

def __len__(self):
return len(self.datas)

def __getitem__(self, item):
t1, (t2, is_next_label) = self.datas[item][0], self.random_sent(item)
t1_random, t1_label = self.random_word(t1)
t2_random, t2_label = self.random_word(t2)

# [CLS] tag = SOS tag, [SEP] tag = EOS tag
t1 = [self.vocab.sos_index] + self.datas[item]["t1"] + [self.vocab.eos_index]
t2 = self.datas[item]["t2"] + [self.vocab.eos_index]
t1 = [self.vocab.sos_index] + t1_random + [self.vocab.eos_index]
t2 = t2_random + [self.vocab.eos_index]

t1_label = [0] + self.datas[item]["t1_label"] + [0]
t2_label = self.datas[item]["t2_label"] + [0]
t1_label = [self.vocab.pad_index] + t1_label + [self.vocab.pad_index]
t2_label = t2_label + [self.vocab.pad_index]

segment_label = ([1 for _ in range(len(t1))] + [2 for _ in range(len(t2))])[:self.seq_len]
bert_input = (t1 + t2)[:self.seq_len]
Expand All @@ -38,6 +38,40 @@ def __getitem__(self, item):
output = {"bert_input": bert_input,
"bert_label": bert_label,
"segment_label": segment_label,
"is_next": self.datas[item]["is_next"]}
"is_next": is_next_label}

return {key: torch.tensor(value) for key, value in output.items()}

def random_word(self, sentence):
tokens = sentence.split()
output_label = []

for i, token in enumerate(tokens):
prob = random.random()
if prob < 0.15:
# 80% randomly change token to make token
if prob < prob * 0.8:
tokens[i] = self.vocab.mask_index

# 10% randomly change token to random token
elif prob * 0.8 <= prob < prob * 0.9:
tokens[i] = random.randrange(len(self.vocab))

# 10% randomly change token to current token
elif prob >= prob * 0.9:
tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index)

output_label.append(self.vocab.stoi.get(token, self.vocab.unk_index))

else:
tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index)
output_label.append(0)

return tokens, output_label

def random_sent(self, index):
# output_text, label(isNotNext:0, isNext:1)
if random.random() > 0.5:
return self.datas[index][1], 1
else:
return self.datas[random.randrange(len(self.datas))][1], 0
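`random_word` implements BERT's masked-LM corruption: each token is selected with probability 0.15, and a selected token is masked 80% of the time, swapped for a random vocabulary index 10% of the time, and left unchanged (but still predicted) 10% of the time; unselected tokens get label 0 so the loss can ignore them. A self-contained sketch of the same sampling scheme (the toy indices and vocabulary size are made up for illustration):

```python
import random
from collections import Counter

VOCAB_SIZE = 1000  # toy value, standing in for len(vocab)

def corrupt(token_id: int):
    """Return (input_id, label) for one token under the 15% / 80-10-10 scheme."""
    prob = random.random()
    if prob >= 0.15:
        return token_id, 0          # untouched, no prediction target
    prob /= 0.15                    # rescale the selected mass onto [0, 1)
    if prob < 0.8:
        return -1, token_id         # -1 stands in for vocab.mask_index
    elif prob < 0.9:
        return random.randrange(VOCAB_SIZE), token_id  # random replacement
    return token_id, token_id       # kept, but still a prediction target

random.seed(0)
counts = Counter("masked" if x == -1 else ("labeled" if y else "plain")
                 for x, y in (corrupt(7) for _ in range(100_000)))
print(counts)  # expect roughly 85% plain, 12% masked, 3% labeled
```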
18 changes: 18 additions & 0 deletions bert_pytorch/dataset/vocab.py
```diff
@@ -165,3 +165,21 @@ def from_seq(self, seq, join=False, with_pad=False):
     def load_vocab(vocab_path: str) -> 'WordVocab':
         with open(vocab_path, "rb") as f:
             return pickle.load(f)
+
+
+def build():
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-c", "--corpus_path", required=True, type=str)
+    parser.add_argument("-o", "--output_path", required=True, type=str)
+    parser.add_argument("-s", "--vocab_size", type=int, default=None)
+    parser.add_argument("-e", "--encoding", type=str, default="utf-8")
+    parser.add_argument("-m", "--min_freq", type=int, default=1)
+    args = parser.parse_args()
+
+    with open(args.corpus_path, "r", encoding=args.encoding) as f:
+        vocab = WordVocab(f, max_size=args.vocab_size, min_freq=args.min_freq)
+
+    print("VOCAB SIZE:", len(vocab))
+    vocab.save_vocab(args.output_path)
```
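The module-level `build()` backs the `bert-vocab` command used in the README (assuming it is registered as a console script, which is not shown in this diff). Beyond the required corpus and output paths, it can cap the vocabulary size and drop rare tokens:

```shell
# keep at most 20,000 tokens and ignore any token seen fewer than 2 times
bert-vocab -c data/corpus.small -o data/vocab.small -s 20000 -m 2
```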