diff --git a/pystreamapi/loaders/__csv_loader.py b/pystreamapi/loaders/__csv_loader.py index 9c9ef97..3f19274 100644 --- a/pystreamapi/loaders/__csv_loader.py +++ b/pystreamapi/loaders/__csv_loader.py @@ -3,8 +3,10 @@ from collections import namedtuple from csv import reader +from pystreamapi.loaders.__lazy_file_iterable import LazyFileIterable -def csv(file_path: str, delimiter=',', encoding="utf-8") -> list: + +def csv(file_path: str, delimiter=',', encoding="utf-8") -> LazyFileIterable: """ Loads a CSV file and converts it into a list of namedtuples. @@ -15,8 +17,13 @@ def csv(file_path: str, delimiter=',', encoding="utf-8") -> list: :param delimiter: The delimiter used in the CSV file. """ file_path = __validate_path(file_path) + return LazyFileIterable(lambda: __load_csv(file_path, delimiter, encoding)) + + +def __load_csv(file_path, delimiter, encoding): + """Load a CSV file and convert it into a list of namedtuples""" # skipcq: PTC-W6004 - with open(file_path, 'r', newline='', encoding=encoding) as csvfile: + with open(file_path, mode='r', newline='', encoding=encoding) as csvfile: csvreader = reader(csvfile, delimiter=delimiter) # Create a namedtuple type, casting the header values to int or float if possible @@ -24,18 +31,15 @@ def csv(file_path: str, delimiter=',', encoding="utf-8") -> list: # Process the data, casting values to int or float if possible data = [Row(*[__try_cast(value) for value in row]) for row in csvreader] - return data def __validate_path(file_path: str): - """Validate a path string to prevent path traversal attacks""" - if not os.path.isabs(file_path): - raise ValueError("The file_path must be an absolute path.") - + """Validate the path to the CSV file""" if not os.path.exists(file_path): raise FileNotFoundError("The specified file does not exist.") - + if not os.path.isfile(file_path): + raise ValueError("The specified path is not a file.") return file_path diff --git a/pystreamapi/loaders/__lazy_file_iterable.py b/pystreamapi/loaders/__lazy_file_iterable.py new file mode 100644 index 0000000..8440b15 --- /dev/null +++ b/pystreamapi/loaders/__lazy_file_iterable.py @@ -0,0 +1,23 @@ +class LazyFileIterable: + """LazyFileIterable is an iterable that loads data from a data source lazily.""" + + def __init__(self, loader): + self.__loader = loader + self.__data = None + + def __iter__(self): + self.__load_data() + return iter(self.__data) + + def __getitem__(self, index): + self.__load_data() + return self.__data[index] + + def __len__(self): + self.__load_data() + return len(self.__data) + + def __load_data(self): + """Loads the data from the data source if it has not been loaded yet.""" + if self.__data is None: + self.__data = self.__loader() diff --git a/tests/test_loaders.py b/tests/test_loaders.py index 161551a..674d248 100644 --- a/tests/test_loaders.py +++ b/tests/test_loaders.py @@ -1,7 +1,9 @@ import os from unittest import TestCase + from pystreamapi.loaders import csv + class TestLoaders(TestCase): def setUp(self) -> None: @@ -18,6 +20,10 @@ def test_csv_loader(self): self.assertEqual(data[1].attr1, 'a') self.assertIsInstance(data[1].attr1, str) + def test_csv_loader_is_iterable(self): + data = csv(f'{self.path}/data.csv') + self.assertEqual(len(list(iter(data))), 2) + def test_csv_loader_with_custom_delimiter(self): data = csv(f'{self.path}/data2.csv', delimiter=';') self.assertEqual(len(data), 1) @@ -32,6 +38,6 @@ def test_csv_loader_with_invalid_path(self): with self.assertRaises(FileNotFoundError): csv(f'{self.path}/invalid.csv') - def test_csv_loader_with_non_absolute_path(self): + def test_csv_loader_with_non_file(self): with self.assertRaises(ValueError): - csv('invalid.csv') + csv(f'{self.path}/')