diff --git a/text_indexer/__init__.py b/text_indexer/__init__.py index f5143cf..e69de29 100644 --- a/text_indexer/__init__.py +++ b/text_indexer/__init__.py @@ -1 +0,0 @@ -from .char_indexer import CharIndexer diff --git a/text_indexer/indexers/__init__.py b/text_indexer/indexers/__init__.py new file mode 100644 index 0000000..d0fe22c --- /dev/null +++ b/text_indexer/indexers/__init__.py @@ -0,0 +1,2 @@ +from .base import Indexer +from .char_indexer import CharIndexer diff --git a/text_indexer/base.py b/text_indexer/indexers/base.py similarity index 67% rename from text_indexer/base.py rename to text_indexer/indexers/base.py index 0b050b5..b787d8f 100644 --- a/text_indexer/base.py +++ b/text_indexer/indexers/base.py @@ -1,9 +1,9 @@ from typing import List, Tuple -from abc import abstractmethod, ABC +from abc import abstractmethod, abstractclassmethod, ABC -class BaseIndexer(ABC): +class Indexer(ABC): @abstractmethod def fit(self, utterances: List[str]): @@ -24,3 +24,11 @@ def inverse_transform( ) -> List[str]: """Restore indices to strings""" pass + + @abstractmethod + def save(self, output_path: str): + pass + + @abstractclassmethod + def load(cls, output_path: str): + pass diff --git a/text_indexer/char_indexer.py b/text_indexer/indexers/char_indexer.py similarity index 79% rename from text_indexer/char_indexer.py rename to text_indexer/indexers/char_indexer.py index 5b41145..683793f 100644 --- a/text_indexer/char_indexer.py +++ b/text_indexer/indexers/char_indexer.py @@ -1,9 +1,12 @@ from typing import List +import os import warnings import strpipe as sp +from .base import Indexer from .pipe_indexer import PipeIndexer +from .utils import load_json, save_json, mkdir_p class CharIndexer(PipeIndexer): @@ -114,3 +117,22 @@ def fit(self, utterances: List[str]): UserWarning, ) self.pipe.fit(['dummy fit']) + + def save(self, output_dir: str): + mkdir_p(output_dir) + params = { + "maxlen": self.maxlen, + "sos_token": self.sos_token, + "eos_token": self.eos_token, + "pad_token": self.pad_token, + "unk_token": self.unk_token, + } + save_json(params, os.path.join(output_dir, 'indexer.json')) + self.pipe.save_json(os.path.join(output_dir, 'pipe.json')) + + @classmethod + def load(cls, output_dir: str) -> Indexer: + params = load_json(os.path.join(output_dir, 'indexer.json')) + indexer = cls.create_without_word2vec(**params) + indexer.pipe = sp.Pipe.restore_from_json(os.path.join(output_dir, 'pipe.json')) + return indexer diff --git a/text_indexer/pipe_indexer.py b/text_indexer/indexers/pipe_indexer.py similarity index 96% rename from text_indexer/pipe_indexer.py rename to text_indexer/indexers/pipe_indexer.py index 942c75d..45c3d8b 100644 --- a/text_indexer/pipe_indexer.py +++ b/text_indexer/indexers/pipe_indexer.py @@ -1,10 +1,10 @@ import abc from typing import List, Tuple -from .base import BaseIndexer +from .base import Indexer -class PipeIndexer(BaseIndexer): +class PipeIndexer(Indexer): def __init__( self, diff --git a/text_indexer/indexers/tests/__init__.py b/text_indexer/indexers/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/text_indexer/tests/data/example.json b/text_indexer/indexers/tests/data/example.json similarity index 100% rename from text_indexer/tests/data/example.json rename to text_indexer/indexers/tests/data/example.json diff --git a/text_indexer/tests/data/example.msg b/text_indexer/indexers/tests/data/example.msg similarity index 100% rename from text_indexer/tests/data/example.msg rename to text_indexer/indexers/tests/data/example.msg diff --git a/text_indexer/tests/template.py b/text_indexer/indexers/tests/template.py similarity index 53% rename from text_indexer/tests/template.py rename to text_indexer/indexers/tests/template.py index 7ef8bae..22c8ca6 100644 --- a/text_indexer/tests/template.py +++ b/text_indexer/indexers/tests/template.py @@ -1,4 +1,16 @@ import abc +import os +from pathlib import Path +import shutil + + +def list_all_files(root: str): + output = [] + for prefix, _, files in os.walk(root): + for f in files: + path = os.path.join(prefix, f) + output.append(path) + return output class TestTemplate(abc.ABC): @@ -16,11 +28,20 @@ def setUpClass(cls): '安靜的祥睿', # equal to 7 after adding sos eos '喔', # shorter than 7 after adding sos eos ] + cls.output_dir = Path(__file__).parent / 'example_indexer/' def setUp(self): self.indexer = self.get_indexer() self.indexer.fit(self.input_data) + def tearDown(self): + if self.output_dir.exists(): + shutil.rmtree(str(self.output_dir)) + + @abc.abstractmethod + def get_indexer_class(self): + pass + @abc.abstractmethod def get_indexer(self): pass @@ -40,19 +61,26 @@ def test_index2word_out_of_range(self): def test_transform(self): tx_data, meta = self.indexer.transform(self.input_data) correct_idxs, correct_seqs = self.get_correct_idxs_and_seqlen_of_input_data() - self.assertEqual( - correct_idxs, - tx_data, - ) - self.assertEqual( - correct_seqs, - meta['seqlen'], - ) + self.assertEqual(correct_idxs, tx_data) + self.assertEqual(correct_seqs, meta['seqlen']) def test_inverse_transform(self): tx_data, meta = self.indexer.transform(self.input_data) output = self.indexer.inverse_transform(tx_data, meta['inv_info']) + self.assertEqual(output, self.input_data) + + def test_save(self): + output_dir = str(self.output_dir) + self.indexer.save(output_dir) self.assertEqual( - output, - self.input_data, + set([os.path.join(output_dir, filepath) for filepath in ['pipe.json', 'indexer.json']]), + set(list_all_files(output_dir)), ) + + def test_load(self): + self.indexer.save(str(self.output_dir)) + indexer = self.get_indexer_class().load(str(self.output_dir)) + tx_data, meta = indexer.transform(self.input_data) + correct_idxs, correct_seqs = self.get_correct_idxs_and_seqlen_of_input_data() + self.assertEqual(correct_idxs, tx_data) + self.assertEqual(correct_seqs, meta['seqlen']) diff --git a/text_indexer/tests/test_char_indexer.py b/text_indexer/indexers/tests/test_char_indexer.py similarity index 85% rename from text_indexer/tests/test_char_indexer.py rename to text_indexer/indexers/tests/test_char_indexer.py index d03222b..df95b8b 100644 --- a/text_indexer/tests/test_char_indexer.py +++ b/text_indexer/indexers/tests/test_char_indexer.py @@ -15,6 +15,9 @@ def load_w2v(path: str): class CharIndexerWithoutW2vTestCase(TestTemplate, TestCase): + def get_indexer_class(self): + return CharIndexer + def get_indexer(self): return CharIndexer.create_without_word2vec( sos_token=self.sos_token, @@ -75,28 +78,13 @@ def get_indexer(self): ) def test_embedding_correct(self): - self.assertEqual( - self.indexer.word2vec, - self.test_emb, - ) + self.assertEqual(self.indexer.word2vec, self.test_emb) def test_transform_and_fit_dont_change(self): tx_data, meta = self.indexer.transform(self.input_data) correct_idxs, correct_seqs = self.get_correct_idxs_and_seqlen_of_input_data() - self.assertEqual( - correct_idxs, - tx_data, - ) - self.assertEqual( - correct_seqs, - meta['seqlen'], - ) + self.assertEqual(correct_idxs, tx_data) + self.assertEqual(correct_seqs, meta['seqlen']) self.indexer.fit(self.input_data) - self.assertEqual( - correct_idxs, - tx_data, - ) - self.assertEqual( - correct_seqs, - meta['seqlen'], - ) + self.assertEqual(correct_idxs, tx_data) + self.assertEqual(correct_seqs, meta['seqlen']) diff --git a/text_indexer/indexers/utils.py b/text_indexer/indexers/utils.py new file mode 100644 index 0000000..3f473d9 --- /dev/null +++ b/text_indexer/indexers/utils.py @@ -0,0 +1,24 @@ +import os +import errno +import json + + +def save_json(data, path): + with open(path, 'w', encoding='utf-8') as filep: + json.dump(data, filep, ensure_ascii=False, indent=2) + + +def load_json(path): + with open(path, 'r', encoding='utf-8') as filep: + output = json.load(filep) + return output + + +def mkdir_p(path): + try: + os.makedirs(path) + except OSError as exc: + if exc.errno == errno.EEXIST and os.path.isdir(path): + pass + else: + raise diff --git a/text_indexer/io.py b/text_indexer/io.py new file mode 100644 index 0000000..24f0360 --- /dev/null +++ b/text_indexer/io.py @@ -0,0 +1,117 @@ +from os.path import join, dirname, basename, isdir, isfile +import tarfile +import shutil +import logging + +from .indexers import ( + Indexer, + CharIndexer, +) + + +LOGGER = logging.getLogger('__file__') +INDEXERS = { + indexer_cls.__class__.__name__: indexer_cls for indexer_cls in [ + CharIndexer, + ] +} + + +def save_indexer( + indexer: Indexer, + output_dir: str, + logger: logging.Logger = LOGGER, + ) -> str: + + _validate_dir(output_dir) + + # save indexer class name + class_name = indexer.__class__.__name__ + _save_name(class_name, _gen_name_path(output_dir)) + + # save indexer + indexer.save(output_dir) # save indexer + del indexer + + # compress + compressed_filepath = _compress_to_tar(output_dir) # compressed + shutil.rmtree(output_dir) # remove output_dir + logger.info(f'Export to {compressed_filepath}') + + return compressed_filepath + + +def load_indexer( + path: str, + logger: logging.Logger = LOGGER, + ) -> Indexer: + + _validate_file(path) + + # extract + output_dir = _extract_from_tar(path) + logger.info(f'Extract to {output_dir}') + + # load indexer + indexer_name = _load_name(_gen_name_path(output_dir)) + indexer_module = _get_indexer_module(indexer_name) + indexer = indexer_module.load(output_dir) + + return indexer + + +def _validate_file(path: str): + if not isfile(path): + raise ValueError(f'[{path}] is not a file path.') + + +def _validate_dir(directory: str): + if not isdir(directory): + raise ValueError(f'[{directory}] is not a directory.') + + +def _save_name(name: str, path: str) -> None: + with open(path, 'w', encoding='utf-8') as text_file: + text_file.write(name) + + +def _load_name(path: str) -> str: + with open(path, 'r', encoding='utf-8') as text_file: + name = text_file.read() + return name + + +def _compress_to_tar(output_dir: str) -> str: + tar_path = _gen_compression_path(output_dir) + with tarfile.open(tar_path, "w:gz") as tar: + tar.add(output_dir, arcname=basename(output_dir)) + return tar_path + + +def _extract_from_tar(path: str) -> str: + output_dir = _gen_extraction_dir(path) + with tarfile.open(path, "r:gz") as tar: + tar.extractall(path=output_dir) + return output_dir + + +def _gen_name_path(directory: str) -> str: + return join(directory, 'name') + + +def _gen_compression_path(directory: str) -> str: + parent_dir = dirname(dirname(directory)) + dir_name = basename(dirname(directory)) + path = join(parent_dir, f'{dir_name}-all.tar.gz') + return path + + +def _gen_extraction_dir(path: str) -> str: + parent_dir = dirname(path) + filename = basename(path) + output_dirname = '{}/'.format(filename.split('-')[0]) + return join(parent_dir, output_dirname) + + +def _get_indexer_module(indexer_name: str) -> Indexer: + return INDEXERS[indexer_name] diff --git a/text_indexer/tests/test_io.py b/text_indexer/tests/test_io.py new file mode 100644 index 0000000..8b3cdee --- /dev/null +++ b/text_indexer/tests/test_io.py @@ -0,0 +1,53 @@ +from unittest import TestCase +from unittest.mock import patch +import shutil +from os.path import join, abspath, exists, dirname +import os + +from ..io import save_indexer, load_indexer +from text_indexer.indexers.utils import save_json, load_json + + +class MockIndexer(object): + + def __init__(self, aa=1, bb=2): + self.aa = aa + self.bb = bb + self.a = 1 + self.b = 2 + + def save(self, output_dir): + save_json({'a': self.a, 'b': self.b}, join(output_dir, 'fake_pipe.json')) + save_json({'aa': self.aa, 'bb': self.bb}, join(output_dir, 'fake_indexer.json')) + + @classmethod + def load(cls, output_dir): + pipe = load_json(join(output_dir, 'fake_pipe.json')) + params = load_json(join(output_dir, 'fake_indexer.json')) + indexer = cls(**params) + indexer.pipe = pipe + return indexer + + +class IOTestCase(TestCase): + + def setUp(self): + root_dir = dirname(abspath(__file__)) + self.output_dir = join(root_dir, 'example/') + os.mkdir(self.output_dir) + + def tearDown(self): + if exists(self.output_dir): + shutil.rmtree(self.output_dir) + + def test_save_indexer(self): + export_path = save_indexer(indexer=MockIndexer(), output_dir=self.output_dir) + self.assertTrue(exists(export_path)) + os.remove(export_path) + + def test_load_indexer(self): + export_path = save_indexer(indexer=MockIndexer(), output_dir=self.output_dir) + with patch('text_indexer.io._get_indexer_module', return_value=MockIndexer): + load_indexer(export_path) + self.assertTrue(exists(self.output_dir)) + os.remove(export_path)