Skip to content

Commit

Permalink
fix: adapt with feather and parquet in S3
Browse files Browse the repository at this point in the history
  • Loading branch information
Mamoru Miura committed Mar 22, 2024
1 parent 837742d commit d43deef
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 4 deletions.
9 changes: 5 additions & 4 deletions gokart/file_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pickle
import xml.etree.ElementTree as ET
from abc import abstractmethod
from io import BytesIO
from logging import getLogger

import luigi
Expand Down Expand Up @@ -203,11 +204,11 @@ def __init__(self, engine='pyarrow', compression=None):
super(ParquetFileProcessor, self).__init__()

def format(self):
return None
return luigi.format.Nop

def load(self, file):
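# MEMO: read the whole file into an in-memory buffer; file.name is presumably not a usable local path for remote targets such as S3, and read_parquet also accepts file-like objects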
return pd.read_parquet(file.name)
return pd.read_parquet(BytesIO(file.read()))

def dump(self, obj, file):
assert isinstance(obj, (pd.DataFrame)), f'requires pd.DataFrame, but {type(obj)} is passed.'
Expand All @@ -222,10 +223,10 @@ def __init__(self, store_index_in_feather: bool):
self.INDEX_COLUMN_PREFIX = '__feather_gokart_index__'

def format(self):
return None
return luigi.format.Nop

def load(self, file):
loaded_df = pd.read_feather(file.name)
loaded_df = pd.read_feather(BytesIO(file.read()))

if self._store_index_in_feather:
if any(col.startswith(self.INDEX_COLUMN_PREFIX) for col in loaded_df.columns):
Expand Down
24 changes: 24 additions & 0 deletions test/test_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,30 @@ def test_last_modified_time_without_file(self):
with self.assertRaises(FileNotFoundError):
target.last_modification_time()

@mock_s3
def test_save_on_s3_feather(self):
    """Round-trip a DataFrame through a feather target on a mocked S3 bucket."""
    # The bucket has to exist in the mocked S3 before the target writes to it.
    conn = boto3.resource('s3', region_name='us-east-1')
    conn.create_bucket(Bucket='test')

    obj = pd.DataFrame(dict(a=[1, 2], b=[3, 4]))
    file_path = os.path.join('s3://test/', 'test.feather')

    target = make_target(file_path=file_path, unique_id=None)
    target.dump(obj)
    loaded = target.load()

    # Without this comparison the test only proves that dump/load do not
    # raise; it would still pass if the loaded data were wrong.
    pd.testing.assert_frame_equal(loaded, obj)

@mock_s3
def test_save_on_s3_parquet(self):
    """Round-trip a DataFrame through a parquet target on a mocked S3 bucket."""
    # The bucket has to exist in the mocked S3 before the target writes to it.
    conn = boto3.resource('s3', region_name='us-east-1')
    conn.create_bucket(Bucket='test')

    obj = pd.DataFrame(dict(a=[1, 2], b=[3, 4]))
    file_path = os.path.join('s3://test/', 'test.parquet')

    target = make_target(file_path=file_path, unique_id=None)
    target.dump(obj)
    loaded = target.load()

    # Without this comparison the test only proves that dump/load do not
    # raise; it would still pass if the loaded data were wrong.
    pd.testing.assert_frame_equal(loaded, obj)


class ModelTargetTest(unittest.TestCase):
def tearDown(self):
Expand Down

0 comments on commit d43deef

Please sign in to comment.