Merge pull request #42 from PickwickSoft/bugfix/#41/stream-not-closing-after-terminal-operation

Bugfix/#41/stream not closing after terminal operation
garlontas authored Jul 20, 2023
2 parents 178b928 + 9e21499 commit 710cc68
Showing 7 changed files with 190 additions and 11 deletions.
60 changes: 53 additions & 7 deletions pystreamapi/_streams/__base_stream.py
@@ -1,31 +1,50 @@
# pylint: disable=protected-access
from __future__ import annotations
import functools
import itertools
from abc import abstractmethod
from builtins import reversed
from functools import cmp_to_key
from typing import Iterable, Callable, Any, TypeVar, Iterator
from typing import Iterable, Callable, Any, TypeVar, Iterator, TYPE_CHECKING

from pystreamapi.__optional import Optional
from pystreamapi._itertools.tools import dropwhile
from pystreamapi._lazy.process import Process
from pystreamapi._lazy.queue import ProcessQueue
from pystreamapi._streams.error.__error import ErrorHandler
from pystreamapi._itertools.tools import dropwhile
if TYPE_CHECKING:
from pystreamapi._streams.numeric.__numeric_base_stream import NumericBaseStream

K = TypeVar('K')
_V = TypeVar('_V')
_identity_missing = object()


def _operation(func):
"""
Decorator to verify that the stream has not been closed before executing the decorated function.
To be applied to intermediate operations.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
self: BaseStream = args[0]
self._verify_open()
return func(*args, **kwargs)

return wrapper


def terminal(func):
"""
Decorator to execute all queued processes and close the stream before executing the decorated
function. To be applied to terminal operations.
"""
@functools.wraps(func)
@_operation
def wrapper(*args, **kwargs):
self: BaseStream = args[0]
# pylint: disable=protected-access
self._queue.execute_all()
self._close()
return func(*args, **kwargs)

return wrapper
@@ -47,6 +66,16 @@ class BaseStream(Iterable[K], ErrorHandler):
def __init__(self, source: Iterable[K]):
self._source = source
self._queue = ProcessQueue()
self._open = True

def _close(self):
"""Close the stream."""
self._open = False

def _verify_open(self):
"""Verify if stream is open. If not, raise an exception."""
if not self._open:
raise RuntimeError("The stream has been closed")

@terminal
def __iter__(self) -> Iterator[K]:
@@ -63,6 +92,7 @@ def concat(cls, *streams: "BaseStream[K]"):
"""
return cls(itertools.chain(*list(streams)))

@_operation
def distinct(self) -> 'BaseStream[_V]':
"""Returns a stream consisting of the distinct elements of this stream."""
self._queue.append(Process(self.__distinct))
@@ -72,6 +102,7 @@ def __distinct(self):
"""Removes duplicate elements from the stream."""
self._source = list(set(self._source))

@_operation
def drop_while(self, predicate: Callable[[K], bool]) -> 'BaseStream[_V]':
"""
Returns, if this stream is ordered, a stream consisting of the remaining elements of this
@@ -86,6 +117,7 @@ def __drop_while(self, predicate: Callable[[Any], bool]):
"""Drops elements from the stream while the predicate is true."""
self._source = list(dropwhile(predicate, self._source, self))

@_operation
def filter(self, predicate: Callable[[K], bool]) -> 'BaseStream[K]':
"""
Returns a stream consisting of the elements of this stream that match the given predicate.
@@ -99,6 +131,7 @@ def filter(self, predicate: Callable[[K], bool]) -> 'BaseStream[K]':
def _filter(self, predicate: Callable[[K], bool]):
"""Implementation of filter. Should be implemented by subclasses."""

@_operation
def flat_map(self, predicate: Callable[[K], Iterable[_V]]) -> 'BaseStream[_V]':
"""
Returns a stream consisting of the results of replacing each element of this stream with
@@ -114,6 +147,7 @@ def flat_map(self, predicate: Callable[[K], Iterable[_V]]) -> 'BaseStream[_V]':
def _flat_map(self, predicate: Callable[[K], Iterable[_V]]):
"""Implementation of flat_map. Should be implemented by subclasses."""

@_operation
def group_by(self, key_mapper: Callable[[K], Any]) -> 'BaseStream[K]':
"""
Returns a Stream consisting of the results of grouping the elements of this stream
@@ -133,6 +167,7 @@ def __group_by(self, key_mapper: Callable[[Any], Any]):
def _group_to_dict(self, key_mapper: Callable[[K], Any]) -> dict[K, list]:
"""Groups the stream into a dictionary. Should be implemented by subclasses."""

@_operation
def limit(self, max_size: int) -> 'BaseStream[_V]':
"""
Returns a stream consisting of the elements of this stream, truncated to be no longer
@@ -147,6 +182,7 @@ def __limit(self, max_size: int):
"""Limits the stream to the first n elements."""
self._source = itertools.islice(self._source, max_size)

@_operation
def map(self, mapper: Callable[[K], _V]) -> 'BaseStream[_V]':
"""
Returns a stream consisting of the results of applying the given function to the elements
@@ -161,18 +197,20 @@ def map(self, mapper: Callable[[K], _V]) -> 'BaseStream[_V]':
def _map(self, mapper: Callable[[K], _V]):
"""Implementation of map. Should be implemented by subclasses."""

def map_to_int(self) -> 'BaseStream[_V]':
@_operation
def map_to_int(self) -> NumericBaseStream[_V]:
"""
Returns a stream consisting of the results of converting the elements of this stream to
integers.
"""
self._queue.append(Process(self.__map_to_int))
return self
return self._to_numeric_stream()

def __map_to_int(self):
"""Converts the stream to integers."""
self._map(int)

@_operation
def map_to_str(self) -> 'BaseStream[_V]':
"""
Returns a stream consisting of the results of converting the elements of this stream to
@@ -185,6 +223,7 @@ def __map_to_str(self):
"""Converts the stream to strings."""
self._map(str)

@_operation
def peek(self, action: Callable) -> 'BaseStream[_V]':
"""
Returns a stream consisting of the elements of this stream, additionally performing the
@@ -196,9 +235,11 @@ def peek(self, action: Callable) -> 'BaseStream[_V]':
return self

@abstractmethod
@_operation
def _peek(self, action: Callable):
"""Implementation of peek. Should be implemented by subclasses."""

@_operation
def reversed(self) -> 'BaseStream[_V]':
"""
Returns a stream consisting of the elements of this stream, with their order being
@@ -214,6 +255,7 @@ def __reversed(self):
except TypeError:
self._source = reversed(list(self._source))

@_operation
def skip(self, n: int) -> 'BaseStream[_V]':
"""
Returns a stream consisting of the remaining elements of this stream after discarding the
@@ -228,6 +270,7 @@ def __skip(self, n: int):
"""Skips the first n elements of the stream."""
self._source = self._source[n:]

@_operation
def sorted(self, comparator: Callable[[K], int] = None) -> 'BaseStream[_V]':
"""
Returns a stream consisting of the elements of this stream, sorted according to natural
@@ -243,6 +286,7 @@ def __sorted(self, comparator: Callable[[K], int] = None):
else:
self._source = sorted(self._source, key=cmp_to_key(comparator))

@_operation
def take_while(self, predicate: Callable[[K], bool]) -> 'BaseStream[_V]':
"""
Returns, if this stream is ordered, a stream consisting of the longest prefix of elements
@@ -257,8 +301,6 @@ def __take_while(self, predicate: Callable[[Any], bool]):
"""Takes elements from the stream while the predicate is true."""
self._source = list(itertools.takewhile(predicate, self._source))

# Terminal Operations:

@abstractmethod
@terminal
def all_match(self, predicate: Callable[[K], bool]):
@@ -373,3 +415,7 @@ def to_dict(self, key_mapper: Callable[[K], Any]) -> dict:
:param key_mapper:
"""

@abstractmethod
def _to_numeric_stream(self) -> NumericBaseStream[_V]:
"""Converts a stream to a numeric stream. To be implemented by subclasses."""
6 changes: 6 additions & 0 deletions pystreamapi/_streams/__parallel_stream.py
@@ -91,3 +91,9 @@ def _set_parallelizer_src(self):

def __mapper(self, mapper):
return lambda x: self._one(mapper=mapper, item=x)

def _to_numeric_stream(self):
# pylint: disable=import-outside-toplevel
from pystreamapi._streams.numeric.__parallel_numeric_stream import ParallelNumericStream
self.__class__ = ParallelNumericStream
return self
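
This _to_numeric_stream implementation (like the sequential one below) converts the existing instance in place by reassigning __class__ instead of copying the data into a new stream. A standalone sketch of that pattern, using hypothetical class names unrelated to pystreamapi:

class Base:
    def __init__(self, data):
        self.data = data


class Numeric(Base):
    def total(self):
        return sum(self.data)


obj = Base([1, 2, 3])
obj.__class__ = Numeric  # same object, now with Numeric's methods available
print(obj.total())       # 6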
6 changes: 6 additions & 0 deletions pystreamapi/_streams/__sequential_stream.py
@@ -62,3 +62,9 @@ def reduce(self, predicate: Callable, identity=_identity_missing, depends_on_sta
@stream.terminal
def to_dict(self, key_mapper: Callable[[Any], Any]) -> dict:
return self._group_to_dict(key_mapper)

def _to_numeric_stream(self):
# pylint: disable=import-outside-toplevel
from pystreamapi._streams.numeric.__sequential_numeric_stream import SequentialNumericStream
self.__class__ = SequentialNumericStream
return self
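
With the conversion in place, map_to_int() now hands back a numeric stream, so the numeric terminal operations can be chained directly. A usage sketch; the expected value assumes the private __median helper implements the standard median and that interquartile_range() is a terminal operation, as the diff below suggests:

from pystreamapi._streams.__sequential_stream import SequentialStream

iqr = SequentialStream(["1", "2", "3", "4"]).map_to_int().interquartile_range()
print(iqr)  # 2.0: Q3 (median of [3, 4] = 3.5) minus Q1 (median of [1, 2] = 1.5)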
16 changes: 14 additions & 2 deletions pystreamapi/_streams/numeric/__numeric_base_stream.py
@@ -18,14 +18,22 @@ def interquartile_range(self) -> Union[float, int, None]:
Calculates the interquartile range of a numerical Stream
:return: The interquartile range, can be int or float
"""
return self.third_quartile() - self.first_quartile() if len(self._source) > 0 else None
return self._interquartile_range()

def _interquartile_range(self):
"""Implementation of the interquartile range calculation"""
return self._third_quartile() - self._first_quartile() if len(self._source) > 0 else None

@terminal
def first_quartile(self) -> Union[float, int, None]:
"""
Calculates the first quartile of a numerical Stream
:return: The first quartile, can be int or float
"""
return self._first_quartile()

def _first_quartile(self):
"""Implementation of the first quartile calculation"""
self._source = sorted(self._source)
return self.__median(self._source[:(len(self._source)) // 2])

@@ -59,7 +67,7 @@ def __median(source) -> Union[float, int, None]:
@terminal
def mode(self) -> Union[list[Union[int, float]], None]:
"""
Calculates the mode(s) (most frequently occurring element) of a numerical Stream
Calculates the mode/modes (most frequently occurring element/elements) of a numerical Stream
:return: The mode, can be int or float
"""
frequency = Counter(self._source)
@@ -90,5 +98,9 @@ def third_quartile(self) -> Union[float, int, None]:
Calculates the third quartile of a numerical Stream
:return: The third quartile, can be int or float
"""
return self._third_quartile()

def _third_quartile(self):
"""Implementation of the third quartile calculation"""
self._source = sorted(self._source)
return self.__median(self._source[(len(self._source) + 1) // 2:])
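
The quartile helpers split the sorted source in two and take the median of each half; the interquartile range is the difference between the two medians. A self-contained sketch of the same arithmetic, where median() is a hypothetical stand-in for the private __median helper:

def median(values):
    mid = len(values) // 2
    return values[mid] if len(values) % 2 else (values[mid - 1] + values[mid]) / 2


data = sorted([5, 1, 3, 2, 4])            # [1, 2, 3, 4, 5]
q1 = median(data[:len(data) // 2])        # lower half [1, 2]  -> 1.5
q3 = median(data[(len(data) + 1) // 2:])  # upper half [4, 5]  -> 4.5
print(q3 - q1)                            # interquartile range -> 3.0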
1 change: 0 additions & 1 deletion pystreamapi/_streams/numeric/__parallel_numeric_stream.py
@@ -19,7 +19,6 @@ def sum(self) -> Union[float, int, None]:
_sum = self.__sum()
return 0 if _sum == [] else _sum

@terminal
def __sum(self):
"""Parallel sum method"""
self._set_parallelizer_src()
108 changes: 108 additions & 0 deletions tests/test_stream_closed.py
@@ -0,0 +1,108 @@
import unittest

from parameterized import parameterized_class

from pystreamapi._streams.__parallel_stream import ParallelStream
from pystreamapi._streams.__sequential_stream import SequentialStream
from pystreamapi._streams.numeric.__parallel_numeric_stream import ParallelNumericStream
from pystreamapi._streams.numeric.__sequential_numeric_stream import SequentialNumericStream


@parameterized_class("stream", [
[SequentialStream],
[ParallelStream],
[SequentialNumericStream],
[ParallelNumericStream]])
class BaseStreamClosed(unittest.TestCase):
def test_closed_stream_throws_exception(self):
# pylint: disable=too-many-statements
closed_stream = self.stream([])
closed_stream.for_each(lambda _: ...)

# Verify that all methods throw a RuntimeError
with self.assertRaises(RuntimeError):
list(closed_stream)

with self.assertRaises(RuntimeError):
closed_stream.distinct()

with self.assertRaises(RuntimeError):
closed_stream.drop_while(lambda x: True)

with self.assertRaises(RuntimeError):
closed_stream.filter(lambda x: True)

with self.assertRaises(RuntimeError):
closed_stream.flat_map(lambda x: [x])

with self.assertRaises(RuntimeError):
closed_stream.group_by(lambda x: x)

with self.assertRaises(RuntimeError):
closed_stream.limit(5)

with self.assertRaises(RuntimeError):
closed_stream.map(lambda x: x)

with self.assertRaises(RuntimeError):
closed_stream.map_to_int()

with self.assertRaises(RuntimeError):
closed_stream.map_to_str()

with self.assertRaises(RuntimeError):
closed_stream.peek(lambda x: None)

with self.assertRaises(RuntimeError):
closed_stream.reversed()

with self.assertRaises(RuntimeError):
closed_stream.skip(5)

with self.assertRaises(RuntimeError):
closed_stream.sorted()

with self.assertRaises(RuntimeError):
closed_stream.take_while(lambda x: True)

with self.assertRaises(RuntimeError):
closed_stream.all_match(lambda x: True)

with self.assertRaises(RuntimeError):
closed_stream.any_match(lambda x: True)

with self.assertRaises(RuntimeError):
closed_stream.count()

with self.assertRaises(RuntimeError):
closed_stream.find_any()

with self.assertRaises(RuntimeError):
closed_stream.find_first()

with self.assertRaises(RuntimeError):
closed_stream.for_each(lambda x: None)

with self.assertRaises(RuntimeError):
closed_stream.none_match(lambda x: True)

with self.assertRaises(RuntimeError):
closed_stream.min()

with self.assertRaises(RuntimeError):
closed_stream.max()

with self.assertRaises(RuntimeError):
closed_stream.reduce(lambda x, y: x + y)

with self.assertRaises(RuntimeError):
closed_stream.to_list()

with self.assertRaises(RuntimeError):
closed_stream.to_tuple()

with self.assertRaises(RuntimeError):
closed_stream.to_set()

with self.assertRaises(RuntimeError):
closed_stream.to_dict(lambda x: x)
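
For reference, parameterized_class generates one TestCase subclass per entry and binds it to the named attribute, which is why the single test body above runs against all four stream implementations. A minimal sketch of the same pattern with a hypothetical attribute:

import unittest

from parameterized import parameterized_class


@parameterized_class("value", [[1], [2]])
class ExampleTest(unittest.TestCase):
    # Generated twice: once with self.value == 1, once with self.value == 2.
    def test_value_is_positive(self):
        self.assertGreater(self.value, 0)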