diff --git a/gokart/file_processor.py b/gokart/file_processor.py index 67f8bbe5..1e08c3be 100644 --- a/gokart/file_processor.py +++ b/gokart/file_processor.py @@ -2,6 +2,7 @@ import pickle import xml.etree.ElementTree as ET from abc import abstractmethod +from io import BytesIO from logging import getLogger import luigi @@ -203,11 +204,11 @@ def __init__(self, engine='pyarrow', compression=None): super(ParquetFileProcessor, self).__init__() def format(self): - return None + return luigi.format.Nop def load(self, file): # MEMO: read_parquet only supports a filepath as string (not a file handle) - return pd.read_parquet(file.name) + return pd.read_parquet(BytesIO(file.read())) def dump(self, obj, file): assert isinstance(obj, (pd.DataFrame)), f'requires pd.DataFrame, but {type(obj)} is passed.' @@ -222,10 +223,10 @@ def __init__(self, store_index_in_feather: bool): self.INDEX_COLUMN_PREFIX = '__feather_gokart_index__' def format(self): - return None + return luigi.format.Nop def load(self, file): - loaded_df = pd.read_feather(file.name) + loaded_df = pd.read_feather(BytesIO(file.read())) if self._store_index_in_feather: if any(col.startswith(self.INDEX_COLUMN_PREFIX) for col in loaded_df.columns): diff --git a/test/test_target.py b/test/test_target.py index 8b1332c0..8b6fe56c 100644 --- a/test/test_target.py +++ b/test/test_target.py @@ -208,6 +208,30 @@ def test_last_modified_time_without_file(self): with self.assertRaises(FileNotFoundError): target.last_modification_time() + @mock_s3 + def test_save_on_s3_feather(self): + conn = boto3.resource('s3', region_name='us-east-1') + conn.create_bucket(Bucket='test') + + obj = pd.DataFrame(dict(a=[1, 2], b=[3, 4])) + file_path = os.path.join('s3://test/', 'test.feather') + + target = make_target(file_path=file_path, unique_id=None) + target.dump(obj) + loaded = target.load() + + @mock_s3 + def test_save_on_s3_parquet(self): + conn = boto3.resource('s3', region_name='us-east-1') + conn.create_bucket(Bucket='test') + + obj = pd.DataFrame(dict(a=[1, 2], b=[3, 4])) + file_path = os.path.join('s3://test/', 'test.parquet') + + target = make_target(file_path=file_path, unique_id=None) + target.dump(obj) + loaded = target.load() + class ModelTargetTest(unittest.TestCase): def tearDown(self):