Skip to content

Commit

Permalink
Merge pull request #12 from Yoctol/serializable_cherry_pick
Browse files Browse the repository at this point in the history
serialization cherry pick
  • Loading branch information
SoluMilken authored Nov 9, 2018
2 parents e8c2e1e + 1b3c931 commit 452702e
Show file tree
Hide file tree
Showing 13 changed files with 276 additions and 35 deletions.
1 change: 0 additions & 1 deletion text_indexer/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@
from .char_indexer import CharIndexer
2 changes: 2 additions & 0 deletions text_indexer/indexers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .base import Indexer
from .char_indexer import CharIndexer
12 changes: 10 additions & 2 deletions text_indexer/base.py → text_indexer/indexers/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import List, Tuple

from abc import abstractmethod, ABC
from abc import abstractmethod, abstractclassmethod, ABC


class BaseIndexer(ABC):
class Indexer(ABC):

@abstractmethod
def fit(self, utterances: List[str]):
Expand All @@ -24,3 +24,11 @@ def inverse_transform(
) -> List[str]:
"""Restore indices to strings"""
pass

    @abstractmethod
    def save(self, output_path: str):
        """Persist the fitted indexer to ``output_path`` (implemented by subclasses)."""
        pass

@abstractclassmethod
def load(cls, output_path: str):
pass
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import List
import os
import warnings

import strpipe as sp

from .base import Indexer
from .pipe_indexer import PipeIndexer
from .utils import load_json, save_json, mkdir_p


class CharIndexer(PipeIndexer):
Expand Down Expand Up @@ -114,3 +117,22 @@ def fit(self, utterances: List[str]):
UserWarning,
)
self.pipe.fit(['dummy fit'])

def save(self, output_dir: str):
mkdir_p(output_dir)
params = {
"maxlen": self.maxlen,
"sos_token": self.sos_token,
"eos_token": self.eos_token,
"pad_token": self.pad_token,
"unk_token": self.unk_token,
}
save_json(params, os.path.join(output_dir, 'indexer.json'))
self.pipe.save_json(os.path.join(output_dir, 'pipe.json'))

@classmethod
def load(cls, output_dir: str) -> Indexer:
params = load_json(os.path.join(output_dir, 'indexer.json'))
indexer = cls.create_without_word2vec(**params)
indexer.pipe = sp.Pipe.restore_from_json(os.path.join(output_dir, 'pipe.json'))
return indexer
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import abc
from typing import List, Tuple

from .base import BaseIndexer
from .base import Indexer


class PipeIndexer(BaseIndexer):
class PipeIndexer(Indexer):

def __init__(
self,
Expand Down
Empty file.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,4 +1,16 @@
import abc
import os
from pathlib import Path
import shutil


def list_all_files(root: str):
    """Return the paths of every file under *root*, recursively."""
    return [
        os.path.join(prefix, name)
        for prefix, _, filenames in os.walk(root)
        for name in filenames
    ]


class TestTemplate(abc.ABC):
Expand All @@ -16,11 +28,20 @@ def setUpClass(cls):
'安靜的祥睿', # equal to 7 after adding sos eos
'喔', # shorter than 7 after adding sos eos
]
cls.output_dir = Path(__file__).parent / 'example_indexer/'

def setUp(self):
self.indexer = self.get_indexer()
self.indexer.fit(self.input_data)

def tearDown(self):
if self.output_dir.exists():
shutil.rmtree(str(self.output_dir))

    @abc.abstractmethod
    def get_indexer_class(self):
        """Return the Indexer subclass under test (used by test_load to call .load)."""
        pass

    @abc.abstractmethod
    def get_indexer(self):
        """Return a ready-to-fit indexer instance under test."""
        pass
Expand All @@ -40,19 +61,26 @@ def test_index2word_out_of_range(self):
def test_transform(self):
tx_data, meta = self.indexer.transform(self.input_data)
correct_idxs, correct_seqs = self.get_correct_idxs_and_seqlen_of_input_data()
self.assertEqual(
correct_idxs,
tx_data,
)
self.assertEqual(
correct_seqs,
meta['seqlen'],
)
self.assertEqual(correct_idxs, tx_data)
self.assertEqual(correct_seqs, meta['seqlen'])

def test_inverse_transform(self):
tx_data, meta = self.indexer.transform(self.input_data)
output = self.indexer.inverse_transform(tx_data, meta['inv_info'])
self.assertEqual(output, self.input_data)

def test_save(self):
output_dir = str(self.output_dir)
self.indexer.save(output_dir)
self.assertEqual(
output,
self.input_data,
set([os.path.join(output_dir, filepath) for filepath in ['pipe.json', 'indexer.json']]),
set(list_all_files(output_dir)),
)

def test_load(self):
self.indexer.save(str(self.output_dir))
indexer = self.get_indexer_class().load(str(self.output_dir))
tx_data, meta = indexer.transform(self.input_data)
correct_idxs, correct_seqs = self.get_correct_idxs_and_seqlen_of_input_data()
self.assertEqual(correct_idxs, tx_data)
self.assertEqual(correct_seqs, meta['seqlen'])
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ def load_w2v(path: str):

class CharIndexerWithoutW2vTestCase(TestTemplate, TestCase):

    def get_indexer_class(self):
        # Class handle used by TestTemplate.test_load to call CharIndexer.load()
        return CharIndexer

def get_indexer(self):
return CharIndexer.create_without_word2vec(
sos_token=self.sos_token,
Expand Down Expand Up @@ -75,28 +78,13 @@ def get_indexer(self):
)

def test_embedding_correct(self):
self.assertEqual(
self.indexer.word2vec,
self.test_emb,
)
self.assertEqual(self.indexer.word2vec, self.test_emb)

def test_transform_and_fit_dont_change(self):
tx_data, meta = self.indexer.transform(self.input_data)
correct_idxs, correct_seqs = self.get_correct_idxs_and_seqlen_of_input_data()
self.assertEqual(
correct_idxs,
tx_data,
)
self.assertEqual(
correct_seqs,
meta['seqlen'],
)
self.assertEqual(correct_idxs, tx_data)
self.assertEqual(correct_seqs, meta['seqlen'])
self.indexer.fit(self.input_data)
self.assertEqual(
correct_idxs,
tx_data,
)
self.assertEqual(
correct_seqs,
meta['seqlen'],
)
self.assertEqual(correct_idxs, tx_data)
self.assertEqual(correct_seqs, meta['seqlen'])
24 changes: 24 additions & 0 deletions text_indexer/indexers/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import os
import errno
import json


def save_json(data, path):
    """Serialize *data* as pretty-printed UTF-8 JSON (non-ASCII kept literal)."""
    serialized = json.dumps(data, ensure_ascii=False, indent=2)
    with open(path, 'w', encoding='utf-8') as filep:
        filep.write(serialized)


def load_json(path):
    """Deserialize and return the JSON document stored at *path*."""
    with open(path, 'r', encoding='utf-8') as filep:
        return json.load(filep)


def mkdir_p(path):
    """Create *path* (with parents) if needed; no-op if it already exists.

    Replaces the Python-2-era errno.EEXIST dance: ``exist_ok=True`` has the
    same semantics, including raising if *path* exists as a non-directory.
    """
    os.makedirs(path, exist_ok=True)
117 changes: 117 additions & 0 deletions text_indexer/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from os.path import join, dirname, basename, isdir, isfile
import tarfile
import shutil
import logging

from .indexers import (
Indexer,
CharIndexer,
)


# Fix: getLogger('__file__') named the logger after the *literal string*
# '__file__'; the conventional module logger uses __name__.
LOGGER = logging.getLogger(__name__)
# Registry mapping a class name (as written by save_indexer, which records
# indexer.__class__.__name__ of an *instance*) back to the class itself.
# Fix: the original keyed on ``indexer_cls.__class__.__name__`` — for a class
# object that is the *metaclass* name (e.g. 'ABCMeta'), so load_indexer's
# lookup by 'CharIndexer' could never succeed. ``indexer_cls.__name__`` is
# the class's own name.
INDEXERS = {
    indexer_cls.__name__: indexer_cls for indexer_cls in [
        CharIndexer,
    ]
}


def save_indexer(
    indexer: Indexer,
    output_dir: str,
    logger: logging.Logger = LOGGER,
) -> str:
    """Serialize *indexer* into *output_dir*, tar-gzip it, and return the archive path.

    The working directory is deleted after compression; only the archive remains.
    """
    _validate_dir(output_dir)

    # Record the concrete class name so load_indexer can find the right class.
    _save_name(indexer.__class__.__name__, _gen_name_path(output_dir))

    indexer.save(output_dir)
    del indexer  # drop the (possibly large) object before compressing

    compressed_filepath = _compress_to_tar(output_dir)
    shutil.rmtree(output_dir)  # the archive now holds everything
    logger.info(f'Export to {compressed_filepath}')

    return compressed_filepath


def load_indexer(
    path: str,
    logger: logging.Logger = LOGGER,
) -> Indexer:
    """Inverse of :func:`save_indexer`: unpack the archive and rebuild the indexer."""
    _validate_file(path)

    extraction_dir = _extract_from_tar(path)
    logger.info(f'Extract to {extraction_dir}')

    # The class name stored at save time selects the class whose load() we call.
    stored_name = _load_name(_gen_name_path(extraction_dir))
    indexer_cls = _get_indexer_module(stored_name)
    return indexer_cls.load(extraction_dir)


def _validate_file(path: str):
if not isfile(path):
raise ValueError(f'[{path}] is not a file path.')


def _validate_dir(directory: str):
if not isdir(directory):
raise ValueError(f'[{directory}] is not a directory.')


def _save_name(name: str, path: str) -> None:
with open(path, 'w', encoding='utf-8') as text_file:
text_file.write(name)


def _load_name(path: str) -> str:
with open(path, 'r', encoding='utf-8') as text_file:
name = text_file.read()
return name


def _compress_to_tar(output_dir: str) -> str:
    """Pack *output_dir* into a sibling gzip tarball and return the tarball path."""
    tar_path = _gen_compression_path(output_dir)
    with tarfile.open(tar_path, "w:gz") as archive:
        # arcname keeps only the directory's own name inside the archive
        archive.add(output_dir, arcname=basename(output_dir))
    return tar_path


def _extract_from_tar(path: str) -> str:
    """Extract the .tar.gz at *path* into a derived directory and return it."""
    target_dir = _gen_extraction_dir(path)
    with tarfile.open(path, "r:gz") as archive:
        # NOTE(review): extractall on an untrusted archive is vulnerable to
        # path traversal ("tar slip"); acceptable only for archives we wrote
        # ourselves — confirm callers never pass user-supplied files.
        archive.extractall(path=target_dir)
    return target_dir


def _gen_name_path(directory: str) -> str:
return join(directory, 'name')


def _gen_compression_path(directory: str) -> str:
parent_dir = dirname(dirname(directory))
dir_name = basename(dirname(directory))
path = join(parent_dir, f'{dir_name}-all.tar.gz')
return path


def _gen_extraction_dir(path: str) -> str:
parent_dir = dirname(path)
filename = basename(path)
output_dirname = '{}/'.format(filename.split('-')[0])
return join(parent_dir, output_dirname)


def _get_indexer_module(indexer_name: str) -> Indexer:
    # Look up the registered indexer by name. NOTE(review): despite the
    # annotation this returns the *class* (not an instance), and an unknown
    # name raises KeyError — confirm callers expect both.
    return INDEXERS[indexer_name]
53 changes: 53 additions & 0 deletions text_indexer/tests/test_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from unittest import TestCase
from unittest.mock import patch
import shutil
from os.path import join, abspath, exists, dirname
import os

from ..io import save_indexer, load_indexer
from text_indexer.indexers.utils import save_json, load_json


class MockIndexer(object):
    """Lightweight Indexer stand-in whose save()/load() round-trip two JSON files."""

    def __init__(self, aa=1, bb=2):
        # fake hyper-parameters (restored through __init__ on load)
        self.aa = aa
        self.bb = bb
        # fake fitted state (restored as the ``pipe`` attribute on load)
        self.a = 1
        self.b = 2

    def save(self, output_dir):
        pipe_state = {'a': self.a, 'b': self.b}
        indexer_state = {'aa': self.aa, 'bb': self.bb}
        save_json(pipe_state, join(output_dir, 'fake_pipe.json'))
        save_json(indexer_state, join(output_dir, 'fake_indexer.json'))

    @classmethod
    def load(cls, output_dir):
        restored = cls(**load_json(join(output_dir, 'fake_indexer.json')))
        restored.pipe = load_json(join(output_dir, 'fake_pipe.json'))
        return restored


class IOTestCase(TestCase):
    """End-to-end tests for save_indexer / load_indexer using MockIndexer."""

    def setUp(self):
        # fresh working directory next to this test module
        self.output_dir = join(dirname(abspath(__file__)), 'example/')
        os.mkdir(self.output_dir)

    def tearDown(self):
        if exists(self.output_dir):
            shutil.rmtree(self.output_dir)

    def test_save_indexer(self):
        archive_path = save_indexer(indexer=MockIndexer(), output_dir=self.output_dir)
        self.assertTrue(exists(archive_path))
        os.remove(archive_path)

    def test_load_indexer(self):
        archive_path = save_indexer(indexer=MockIndexer(), output_dir=self.output_dir)
        # bypass the real registry so MockIndexer is "found" by its saved name
        with patch('text_indexer.io._get_indexer_module', return_value=MockIndexer):
            load_indexer(archive_path)
        # extraction recreates the working directory removed by save_indexer
        self.assertTrue(exists(self.output_dir))
        os.remove(archive_path)

0 comments on commit 452702e

Please sign in to comment.