✨ Minimal RAM use while handling a huge corpus. (#273)
BrikerMan committed Mar 16, 2020
1 parent 938a96c commit 03c335a
Showing 16 changed files with 301 additions and 91 deletions.
87 changes: 87 additions & 0 deletions examples/custom_generator.py
@@ -0,0 +1,87 @@
# encoding: utf-8

# author: BrikerMan
# contact: eliyar917@gmail.com
# blog: https://eliyar.biz

# file: custom_generator.py
# time: 4:13 PM

import os
import linecache
from tensorflow.keras.utils import get_file
from kashgari.generators import ABCGenerator


def download_data(duplicate=1000):
url_list = [
'https://raw.githubusercontent.com/BrikerMan/JointSLU/master/data/atis-2.train.w-intent.iob',
'https://raw.githubusercontent.com/BrikerMan/JointSLU/master/data/atis-2.dev.w-intent.iob',
'https://raw.githubusercontent.com/BrikerMan/JointSLU/master/data/atis.test.w-intent.iob',
'https://raw.githubusercontent.com/BrikerMan/JointSLU/master/data/atis.train.w-intent.iob'
]
files = []
for url in url_list:
files.append(get_file(url.split('/')[-1], url))

return files * duplicate


class ClassificationGenerator:
def __init__(self, files):
self.files = files
self._line_count = sum(sum(1 for line in open(file, 'r')) for file in files)

@property
def steps(self) -> int:
return self._line_count

def __iter__(self):
for file in self.files:
with open(file, 'r') as f:
for line in f:
rows = line.split('\t')
x = rows[0].strip().split(' ')[1:-1]
y = rows[1].strip().split(' ')[-1]
yield x, y


class LabelingGenerator(ABCGenerator):
def __init__(self, files):
self.files = files
self._line_count = sum(sum(1 for line in open(file, 'r')) for file in files)

@property
def steps(self) -> int:
return self._line_count

def __iter__(self):
for file in self.files:
with open(file, 'r') as f:
for line in f:
rows = line.split('\t')
x = rows[0].strip().split(' ')[1:-1]
y = rows[1].strip().split(' ')[1:-1]
yield x, y


def run_classification_model():
from kashgari.tasks.classification import BiGRU_Model
files = download_data()
gen = ClassificationGenerator(files)

model = BiGRU_Model()
model.fit_generator(gen)


def run_labeling_model():
from kashgari.tasks.labeling import BiGRU_Model
files = download_data()
gen = LabelingGenerator(files)

model = BiGRU_Model()
model.fit_generator(gen)


if __name__ == "__main__":
run_classification_model()
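
A quick sanity check before training, as a sketch that reuses only the helpers defined above (note that steps counts samples, i.e. corpus lines; BatchDataGenerator derives per-epoch batch counts from it later):

files = download_data(duplicate=1)  # skip the 1000x duplication for a smoke test
gen = ClassificationGenerator(files)
print(gen.steps)        # total line count across the four ATIS files
x, y = next(iter(gen))
print(x[:5], y)         # first tokens of the first utterance and its intent label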
12 changes: 12 additions & 0 deletions kashgari/embeddings/abc_embedding.py
@@ -7,6 +7,7 @@
# file: abc_embedding.py
# time: 2:43 PM

import tqdm
import json
import pydoc
import logging
@@ -73,13 +74,24 @@ def __init__(self,
self.segment = False  # segment input is not needed by default
self.kwargs = kwargs

self.embedding_size = None

def set_sequence_length(self, length: int):
self.sequence_length = length
if self.embed_model is not None:
logging.info(f"Rebuild embedding model with sequence length: {length}")
self.embed_model = None
self.build_embedding_model()

def calculate_sequence_length_if_needs(self, corpus_gen: CorpusGenerator, cover_rate: float = 0.95):
if self.sequence_length is None:
seq_lens = []
for sentence, _ in tqdm.tqdm(corpus_gen, total=corpus_gen.steps,
desc="Calculating sequence length"):
seq_lens.append(len(sentence))
self.sequence_length = sorted(seq_lens)[int(cover_rate * len(seq_lens))]
logging.warning(f'Calculated sequence length = {self.sequence_length}')

def build(self, x_data: TextSamplesVar, y_data: LabelSamplesVar):
gen = CorpusGenerator(x_data=x_data, y_data=y_data)
self.build_with_generator(gen)
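The cover-rate rule above chooses the value at the cover_rate percentile of the sorted sentence lengths, so roughly 95% of sentences fit without truncation. A standalone sketch of the same arithmetic (plain Python; the function name is illustrative):

def sequence_length_for_cover_rate(sentences, cover_rate=0.95):
    # Mirrors sorted(seq_lens)[int(cover_rate * len(seq_lens))] above.
    seq_lens = sorted(len(s) for s in sentences)
    return seq_lens[int(cover_rate * len(seq_lens))]

sentences = [['tok'] * n for n in range(1, 41)]   # lengths 1..40
print(sequence_length_for_cover_rate(sentences))  # -> 39; only the longest tail is truncated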
11 changes: 6 additions & 5 deletions kashgari/embeddings/transformer_embedding.py
@@ -59,6 +59,7 @@ def __init__(self,
self.segment = True

self.vocab_list = []
self.max_sequence_length = None

def build_text_vocab(self, gen: CorpusGenerator = None, force=False):
if not self.text_processor.is_vocab_build:
@@ -78,14 +79,13 @@ def build_text_vocab(self, gen: CorpusGenerator = None, force=False):

def build_embedding_model(self):
if self.embed_model is None:
kwargs = {}
config_path = self.config_path

config = json.load(open(config_path))
if self.sequence_length:
if self.sequence_length > config.get('max_position_embeddings'):
self.sequence_length = config.get('max_position_embeddings')
logging.warning(f"Max seq length is {self.sequence_length}")
if 'max_position' in config:
self.max_sequence_length = config['max_position']
else:
self.max_sequence_length = config.get('max_position_embeddings')

bert_model = build_transformer_model(config_path=self.config_path,
checkpoint_path=self.checkpoint_path,
@@ -94,6 +94,7 @@
return_keras_model=True)

self.embed_model = bert_model
self.embedding_size = bert_model.output.shape[-1]


if __name__ == "__main__":
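build_embedding_model now records the checkpoint's positional limit in max_sequence_length instead of silently clamping sequence_length: it prefers a 'max_position' key and falls back to 'max_position_embeddings', the key used in Google's original BERT configs (e.g. uncased BERT-Base ships "max_position_embeddings": 512). A minimal sketch of the same lookup:

import json

def read_max_sequence_length(config_path: str) -> int:
    # Prefer 'max_position', fall back to 'max_position_embeddings',
    # mirroring the branch in build_embedding_model above.
    with open(config_path) as f:
        config = json.load(f)
    if 'max_position' in config:
        return config['max_position']
    return config.get('max_position_embeddings')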
76 changes: 36 additions & 40 deletions kashgari/generators.py
@@ -7,44 +7,43 @@
# file: generator.py
# time: 4:53 PM

from abc import ABC
import random
from typing import List
from typing import Iterable


class CorpusGenerator:
class ABCGenerator(Iterable, ABC):

def __init__(self, x_data: List, y_data: List):
@property
def steps(self) -> int:
raise NotImplementedError

def __iter__(self):
raise NotImplementedError


class CorpusGenerator(ABCGenerator):

def __init__(self, x_data: List, y_data: List, shuffle=True):
self.x_data = x_data
self.y_data = y_data

self._index_list = list(range(len(self.x_data)))
self._current_index = 0

random.shuffle(self._index_list)
if shuffle:
random.shuffle(self._index_list)

def reset(self):
self._current_index = 0
def __iter__(self):
for i in self._index_list:
yield self.x_data[i], self.y_data[i]

@property
def steps(self) -> int:
return len(self.x_data)

def __iter__(self):
return self

def __next__(self):
self._current_index += 1
if self._current_index >= len(self.x_data) - 1:
raise StopIteration()

sample_index = self._index_list[self._current_index]
return self.x_data[sample_index], self.y_data[sample_index]

def __call__(self, *args, **kwargs):
return self


class BatchDataGenerator:
class BatchDataGenerator(Iterable):
def __init__(self,
corpus,
text_processor,
@@ -66,27 +65,24 @@ def steps(self) -> int:
return self.corpus.steps // self.batch_size

def __iter__(self):
return self

def __next__(self):
x_set = []
y_set = []
for i in range(self.batch_size):
try:
x, y = next(self.corpus)
except StopIteration:
self.corpus.reset()
x, y = next(self.corpus)
x_set, y_set = [], []
for x, y in self.corpus:
x_set.append(x)
y_set.append(y)
if len(x_set) == self.batch_size:
x_tensor = self.text_processor.numerize_samples(x_set, seq_length=self.seq_length, segment=self.segment)
y_tensor = self.label_processor.numerize_samples(y_set, seq_length=self.seq_length, one_hot=True)
yield x_tensor, y_tensor
x_set, y_set = [], []
# final step
if x_set:
x_tensor = self.text_processor.numerize_samples(x_set, seq_length=self.seq_length, segment=self.segment)
y_tensor = self.label_processor.numerize_samples(y_set, seq_length=self.seq_length, one_hot=True)
yield x_tensor, y_tensor

x_tensor = self.text_processor.numerize_samples(x_set, seq_length=self.seq_length, segment=self.segment)
y_tensor = self.label_processor.numerize_samples(y_set, seq_length=self.seq_length, one_hot=True)
return x_tensor, y_tensor

def __call__(self, *args, **kwargs):
def __next__(self):
return self


if __name__ == "__main__":
pass
def generator(self):
for item in self:
yield item
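
With the index/reset machinery gone, iteration restarts automatically on every fresh for loop because __iter__ is now a generator function, and partial final batches are no longer padded out by wrapping around the corpus. A usage sketch with stub processors (EchoProcessor is purely illustrative, and the keyword names after text_processor are assumed from the attributes used above, since the middle of __init__ is folded in this view):

from kashgari.generators import CorpusGenerator, BatchDataGenerator

class EchoProcessor:
    # Stand-in for the real text/label processors, which map
    # tokens to padded index arrays.
    def numerize_samples(self, samples, **kwargs):
        return samples

x = [['show', 'flights', 'to', 'boston'], ['list', 'fares'], ['book', 'a', 'seat']]
y = ['flight', 'airfare', 'booking']

corpus = CorpusGenerator(x_data=x, y_data=y, shuffle=False)
batches = BatchDataGenerator(corpus,
                             text_processor=EchoProcessor(),
                             label_processor=EchoProcessor(),
                             seq_length=None, segment=False, batch_size=2)
print(batches.steps)              # 3 // 2 = 1
for x_batch, y_batch in batches:  # a full batch of 2, then the remainder of 1
    print(x_batch, y_batch)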
4 changes: 1 addition & 3 deletions kashgari/processors/abc_processor.py
@@ -25,8 +25,6 @@ def __init__(self, **kwargs):
self.vocab2idx = kwargs.get('vocab2idx', {})
self.idx2vocab = dict([(v, k) for k, v in self.vocab2idx.items()])

self.corpus_sequence_length = kwargs.get('corpus_sequence_length', None)

@property
def vocab_size(self) -> int:
return len(self.vocab2idx)
@@ -35,7 +33,7 @@ def vocab_size(self) -> int:
def is_vocab_build(self) -> bool:
return self.vocab_size != 0

def build_vocab_dict_if_needs(self, generator: Generator, min_count: int = 3):
def build_vocab_dict_if_needs(self, generator: Generator):
raise NotImplementedError


1 change: 0 additions & 1 deletion kashgari/processors/class_processor.py
@@ -22,7 +22,6 @@
class ClassificationProcessor(ABCProcessor):

def build_vocab_dict_if_needs(self, generator: CorpusGenerator):
generator.reset()
if not self.vocab2idx:
vocab2idx = {}

26 changes: 8 additions & 18 deletions kashgari/processors/sequence_processor.py
@@ -14,6 +14,7 @@
import tqdm
import numpy as np
from typing import Dict, List

from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import to_categorical

@@ -73,19 +74,19 @@ def __init__(self,
else:
self._initial_vocab_dic = {}

self._showed_seq_len_warning = False

def build_vocab_dict_if_needs(self, generator: CorpusGenerator):
if not self.vocab2idx:
vocab2idx = self._initial_vocab_dic

token2count = {}
seq_lens = []
generator.reset()

for sentence, label in tqdm.tqdm(generator, total=generator.steps, desc="Preparing text vocab dict"):
if self.vocab_dict_type == 'text':
target = sentence
else:
target = label
seq_lens.append(len(target))
for token in target:
count = token2count.get(token, 0)
token2count[token] = count + 1
@@ -101,20 +102,10 @@ def build_vocab_dict_if_needs(self, generator: CorpusGenerator):
self.vocab2idx = vocab2idx
self.idx2vocab = dict([(v, k) for k, v in self.vocab2idx.items()])

if self.corpus_sequence_length is None:
self.corpus_sequence_length = sorted(seq_lens)[int(0.95 * len(seq_lens))]

logging.info("------ Build vocab dict finished, Top 10 token ------")
for token, index in list(self.vocab2idx.items())[:10]:
logging.info(f"Token: {token:8s} -> {index}")
logging.info("------ Build vocab dict finished, Top 10 token ------")
else:
if self.corpus_sequence_length is None:
seq_lens = []
generator.reset()
for sentence, _ in generator:
seq_lens.append(len(sentence))
self.corpus_sequence_length = sorted(seq_lens)[int(0.95 * len(seq_lens))]

def numerize_samples(self,
samples: TextSamplesVar,
Expand All @@ -124,8 +115,10 @@ def numerize_samples(self,
**kwargs) -> np.ndarray:
if seq_length is None:
seq_length = max([len(i) for i in samples])
logging.warning(
f'Sequence length is None, will use the max length of the samples, which is {seq_length}')
if not self._showed_seq_len_warning:
logging.warning(
f'Sequence length is None, will use the max length of the samples, which is {seq_length}')
self._showed_seq_len_warning = True

numerized_samples = []
for seq in samples:
@@ -173,6 +166,3 @@ def reverse_numerize(self,
p.build_vocab_dict_if_needs(gen)
print(p.vocab2idx)

p2 = SequenceProcessor()
p2.build_vocab_dict_if_needs(gen)
print(p2.vocab2idx)
1 change: 1 addition & 0 deletions kashgari/tasks/abs_task_model.py
@@ -102,6 +102,7 @@ def build_model(self,
raise ValueError('Need to set default_labeling_processor')
self.embedding.label_processor = self.default_labeling_processor
self.embedding.build_with_generator(train_gen)
self.embedding.calculate_sequence_length_if_needs(train_gen)
if self.tf_model is None:
self.build_model_arc()
self.compile_model()
4 changes: 2 additions & 2 deletions kashgari/tasks/classification/abc_model.py
@@ -131,13 +131,13 @@ def fit_generator(self,
segment=self.embedding.segment,
seq_length=self.embedding.sequence_length,
batch_size=batch_size)
fit_kwargs['validation_data'] = valid_gen
fit_kwargs['validation_data'] = valid_gen.generator()
fit_kwargs['validation_steps'] = valid_gen.steps

if callbacks:
fit_kwargs['callbacks'] = callbacks

return self.tf_model.fit(train_gen,
return self.tf_model.fit(train_gen.generator(),
steps_per_epoch=train_gen.steps,
epochs=epochs,
callbacks=callbacks)
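Both iterables are now handed to tf.keras as .generator() rather than as the objects themselves, presumably because Model.fit accepts Python generators (and Sequence/Dataset instances) but not arbitrary iterables. The confirmed call pattern is the one from examples/custom_generator.py; passing validation data is sketched with an assumed argument name, since the full signature is folded in this view:

from kashgari.tasks.classification import BiGRU_Model

model = BiGRU_Model()
model.fit_generator(train_gen)  # as in the example file above
# With validation (second argument name assumed):
# model.fit_generator(train_gen, valid_gen)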
1 change: 1 addition & 0 deletions kashgari/tasks/labeling/__init__.py
@@ -7,6 +7,7 @@
# file: __init__.py
# time: 4:30 PM

from .bi_gru_model import BiGRU_Model
from .bi_lstm_model import BiLSTM_Model

if __name__ == "__main__":