diff --git a/samcli/commands/local/cli_common/invoke_context.py b/samcli/commands/local/cli_common/invoke_context.py
index 837b122558..30dc8b4c03 100644
--- a/samcli/commands/local/cli_common/invoke_context.py
+++ b/samcli/commands/local/cli_common/invoke_context.py
@@ -7,7 +7,7 @@
 import os
 from enum import Enum
 from pathlib import Path
-from typing import IO, Any, Dict, List, Optional, Tuple, Type, cast
+from typing import Any, Dict, List, Optional, TextIO, Tuple, Type, cast
 
 from samcli.commands._utils.template import TemplateFailedParsingException, TemplateNotFoundException
 from samcli.commands.exceptions import ContainersInitializationException
@@ -196,7 +196,7 @@ def __init__(
         self._stacks: List[Stack] = None  # type: ignore
         self._env_vars_value: Optional[Dict] = None
         self._container_env_vars_value: Optional[Dict] = None
-        self._log_file_handle: Optional[IO] = None
+        self._log_file_handle: Optional[TextIO] = None
         self._debug_context: Optional[DebugContext] = None
         self._layers_downloader: Optional[LayerDownloader] = None
         self._container_manager: Optional[ContainerManager] = None
@@ -490,7 +490,7 @@ def _get_env_vars_value(filename: Optional[str]) -> Optional[Dict]:
             ) from ex
 
     @staticmethod
-    def _setup_log_file(log_file: Optional[str]) -> Optional[IO]:
+    def _setup_log_file(log_file: Optional[str]) -> Optional[TextIO]:
         """
         Open a log file if necessary and return the file handle. This will create a file if it does not exist
 
@@ -500,7 +500,7 @@ def _setup_log_file(log_file: Optional[str]) -> Optional[IO]:
         if not log_file:
             return None
 
-        return open(log_file, "wb")
+        return open(log_file, "w", encoding="utf8")
 
     @staticmethod
     def _get_debug_context(
diff --git a/samcli/commands/remote/remote_invoke_context.py b/samcli/commands/remote/remote_invoke_context.py
index b710df8410..d1294983bc 100644
--- a/samcli/commands/remote/remote_invoke_context.py
+++ b/samcli/commands/remote/remote_invoke_context.py
@@ -242,7 +242,7 @@ class DefaultRemoteInvokeResponseConsumer(RemoteInvokeConsumer[RemoteInvokeRespo
     _stream_writer: StreamWriter
 
     def consume(self, remote_invoke_response: RemoteInvokeResponse) -> None:
-        self._stream_writer.write(cast(str, remote_invoke_response.response).encode())
+        self._stream_writer.write_bytes(cast(str, remote_invoke_response.response).encode())
 
 
 @dataclass
@@ -254,4 +254,4 @@ class DefaultRemoteInvokeLogConsumer(RemoteInvokeConsumer[RemoteInvokeLogOutput]
     _stream_writer: StreamWriter
 
     def consume(self, remote_invoke_response: RemoteInvokeLogOutput) -> None:
-        self._stream_writer.write(remote_invoke_response.log_output.encode())
+        self._stream_writer.write_bytes(remote_invoke_response.log_output.encode())
diff --git a/samcli/lib/docker/log_streamer.py b/samcli/lib/docker/log_streamer.py
index b013459bae..3bb437781a 100644
--- a/samcli/lib/docker/log_streamer.py
+++ b/samcli/lib/docker/log_streamer.py
@@ -47,23 +47,21 @@ def stream_progress(self, logs: docker.APIClient.logs):
                 else:
                     curr_log_line_id = ids[_id]
                     change_cursor_count = len(ids) - curr_log_line_id
-                    self._stream.write(
+                    self._stream.write_str(
                         self._cursor_up_formatter.cursor_format(change_cursor_count)
-                        + self._cursor_left_formatter.cursor_format(),
-                        encode=True,
+                        + self._cursor_left_formatter.cursor_format()
                     )
 
             self._stream_write(_id, status, stream, progress, error)
 
             if _id:
-                self._stream.write(
+                self._stream.write_str(
                     self._cursor_down_formatter.cursor_format(change_cursor_count)
-                    + self._cursor_left_formatter.cursor_format(),
-                    encode=True,
+                    + self._cursor_left_formatter.cursor_format()
                 )
-        self._stream.write(os.linesep, encode=True)
+        self._stream.write_str(os.linesep)
 
-    def _stream_write(self, _id: str, status: str, stream: bytes, progress: str, error: str):
+    def _stream_write(self, _id: str, status: str, stream: str, progress: str, error: str):
         """
         Write stream information to stderr,
         if the stream information contains a log id, use the carriage return character to rewrite that particular line.
@@ -80,14 +78,14 @@ def _stream_write(self, _id: str, status: str, stream: bytes, progress: str, err
 
         # NOTE(sriram-mv): Required for the purposes of when the cursor overflows existing terminal buffer.
         if not stream:
-            self._stream.write(os.linesep, encode=True)
-            self._stream.write(
-                self._cursor_up_formatter.cursor_format() + self._cursor_left_formatter.cursor_format(), encode=True
+            self._stream.write_str(os.linesep)
+            self._stream.write_str(
+                self._cursor_up_formatter.cursor_format() + self._cursor_left_formatter.cursor_format()
             )
-            self._stream.write(self._cursor_clear_formatter.cursor_format(), encode=True)
+            self._stream.write_str(self._cursor_clear_formatter.cursor_format())
 
         if not _id:
-            self._stream.write(stream, encode=True)
-            self._stream.write(status, encode=True)
+            self._stream.write_str(stream)
+            self._stream.write_str(status)
         else:
-            self._stream.write(f"\r{_id}: {status} {progress}", encode=True)
+            self._stream.write_str(f"\r{_id}: {status} {progress}")
diff --git a/samcli/lib/package/ecr_uploader.py b/samcli/lib/package/ecr_uploader.py
index f2d4371407..0393596b39 100644
--- a/samcli/lib/package/ecr_uploader.py
+++ b/samcli/lib/package/ecr_uploader.py
@@ -2,8 +2,8 @@
 Client for uploading packaged artifacts to ecr
 """
 import base64
-import io
 import logging
+from io import StringIO
 from typing import Dict
 
 import botocore
@@ -94,7 +94,7 @@ def upload(self, image, resource_name):
             else:
                 # we need to wait till the image got pushed to ecr, without this workaround sam sync for template
                 # contains image always fail, because the provided ecr uri is not exist.
-                _log_streamer = LogStreamer(stream=StreamWriter(stream=io.BytesIO(), auto_flush=True))
+                _log_streamer = LogStreamer(stream=StreamWriter(stream=StringIO(), auto_flush=True))
                 _log_streamer.stream_progress(push_logs)
 
         except (BuildError, APIError, LogStreamError) as ex:
diff --git a/samcli/lib/package/s3_uploader.py b/samcli/lib/package/s3_uploader.py
index fe141ada51..95981e92ed 100644
--- a/samcli/lib/package/s3_uploader.py
+++ b/samcli/lib/package/s3_uploader.py
@@ -265,4 +265,4 @@ def on_progress(self, bytes_transferred, **kwargs):
         )
         sys.stderr.flush()
         if int(percentage) == 100:  # noqa: PLR2004
-            sys.stderr.write("\n")
+            sys.stderr.write(os.linesep)
diff --git a/samcli/lib/utils/osutils.py b/samcli/lib/utils/osutils.py
index d53dc9ffb5..a9a12bf88c 100644
--- a/samcli/lib/utils/osutils.py
+++ b/samcli/lib/utils/osutils.py
@@ -1,6 +1,7 @@
 """
 Common OS utilities
 """
+import io
 import logging
 import os
 import shutil
@@ -78,7 +79,7 @@ def rmtree_if_exists(path: Union[str, Path]):
         shutil.rmtree(path_obj)
 
 
-def stdout():
+def stdout() -> io.TextIOWrapper:
     """
     Returns the stdout as a byte stream in a Py2/PY3 compatible manner
 
@@ -87,10 +88,15 @@ def stdout():
     io.BytesIO
         Byte stream of Stdout
     """
-    return sys.stdout.buffer
+    # ensure stdout is utf8
+    sys.stdout.reconfigure(encoding="utf-8")  # type:ignore[attr-defined]
+
+    # Note(jfuss): sys.stdout is a type typing.TextIO but are initialized to
+    # io.TextIOWrapper. To make mypy and typing play well, tell mypy to ignore.
+    return sys.stdout  # type:ignore[return-value]
 
 
-def stderr():
+def stderr() -> io.TextIOWrapper:
     """
     Returns the stderr as a byte stream in a Py2/PY3 compatible manner
 
@@ -99,7 +105,12 @@ def stderr():
     io.BytesIO
         Byte stream of stderr
     """
-    return sys.stderr.buffer
+    # ensure stderr is utf8
+    sys.stderr.reconfigure(encoding="utf-8")  # type:ignore[attr-defined]
+
+    # Note(jfuss): sys.stderr is a type typing.TextIO but are initialized to
+    # io.TextIOWrapper. To make mypy and typing play well, tell mypy to ignore.
+    return sys.stderr  # type:ignore[return-value]
 
 
 def remove(path):
diff --git a/samcli/lib/utils/stream_writer.py b/samcli/lib/utils/stream_writer.py
index 1fc62fa690..99f72c1036 100644
--- a/samcli/lib/utils/stream_writer.py
+++ b/samcli/lib/utils/stream_writer.py
@@ -1,10 +1,11 @@
 """
 This class acts like a wrapper around output streams to provide any flexibility with output we need
 """
+from typing import TextIO, Union
 
 
 class StreamWriter:
-    def __init__(self, stream, auto_flush=False):
+    def __init__(self, stream: TextIO, auto_flush: bool = False):
         """
         Instatiates new StreamWriter to the specified stream
 
@@ -19,19 +20,33 @@ def __init__(self, stream, auto_flush=False):
         self._auto_flush = auto_flush
 
     @property
-    def stream(self):
+    def stream(self) -> TextIO:
         return self._stream
 
-    def write(self, output, encode=False):
+    def write_bytes(self, output: Union[bytes, bytearray]):
         """
         Writes specified text to the underlying stream
 
         Parameters
         ----------
         output bytes-like object
-            Bytes to write
+            Bytes to write into buffer
         """
-        self._stream.write(output.encode() if encode else output)
+        self._stream.buffer.write(output)
+
+        if self._auto_flush:
+            self._stream.flush()
+
+    def write_str(self, output: str):
+        """
+        Writes specified text to the underlying stream
+
+        Parameters
+        ----------
+        output string object
+            String to write
+        """
+        self._stream.write(output)
 
         if self._auto_flush:
             self._stream.flush()
diff --git a/samcli/lib/utils/subprocess_utils.py b/samcli/lib/utils/subprocess_utils.py
index e08ec12e49..1937a44eeb 100644
--- a/samcli/lib/utils/subprocess_utils.py
+++ b/samcli/lib/utils/subprocess_utils.py
@@ -34,7 +34,7 @@ def default_loading_pattern(stream_writer: Optional[StreamWriter] = None, loadin
         How frequently to generate the pattern
     """
     stream_writer = stream_writer or StreamWriter(sys.stderr)
-    stream_writer.write(".")
+    stream_writer.write_str(".")
     stream_writer.flush()
     sleep(loading_pattern_rate)
 
@@ -96,7 +96,7 @@ def _print_loading_pattern():
             return_code = process.wait()
             keep_printing = False
 
-            stream_writer.write(os.linesep)
+            stream_writer.write_str(os.linesep)
             stream_writer.flush()
 
             process_stderr = _check_and_convert_stream_to_string(process.stderr)
diff --git a/samcli/local/apigw/local_apigw_service.py b/samcli/local/apigw/local_apigw_service.py
index f979b2e9a3..b80b1fc2c2 100644
--- a/samcli/local/apigw/local_apigw_service.py
+++ b/samcli/local/apigw/local_apigw_service.py
@@ -4,7 +4,7 @@
 import json
 import logging
 from datetime import datetime
-from io import BytesIO
+from io import StringIO
 from time import time
 from typing import Any, Dict, List, Optional
 
@@ -605,7 +605,7 @@ def _invoke_lambda_function(self, lambda_function_name: str, event: dict) -> str
         str
             A string containing the output from the Lambda function
         """
-        with BytesIO() as stdout:
+        with StringIO() as stdout:
             event_str = json.dumps(event, sort_keys=True)
             stdout_writer = StreamWriter(stdout, auto_flush=True)
 
diff --git a/samcli/local/docker/container.py b/samcli/local/docker/container.py
index fc2a190b53..d4c7c93fef 100644
--- a/samcli/local/docker/container.py
+++ b/samcli/local/docker/container.py
@@ -1,6 +1,8 @@
 """
 Representation of a generic Docker container
 """
+import io
+import json
 import logging
 import os
 import pathlib
@@ -9,7 +11,7 @@
 import tempfile
 import threading
 import time
-from typing import Optional
+from typing import Iterator, Optional, Tuple, Union
 
 import docker
 import requests
@@ -17,6 +19,7 @@
 
 from samcli.lib.constants import DOCKER_MIN_API_VERSION
 from samcli.lib.utils.retry import retry
+from samcli.lib.utils.stream_writer import StreamWriter
 from samcli.lib.utils.tar import extract_tarfile
 from samcli.local.docker.effective_user import ROOT_USER_ID, EffectiveUser
 from samcli.local.docker.exceptions import ContainerNotStartableException, PortAlreadyInUse
@@ -318,7 +321,7 @@ def start(self, input_data=None):
             raise ex
 
     @retry(exc=requests.exceptions.RequestException, exc_raise=ContainerResponseException)
-    def wait_for_http_response(self, name, event, stdout):
+    def wait_for_http_response(self, name, event, stdout) -> str:
         # TODO(sriram-mv): `aws-lambda-rie` is in a mode where the function_name is always "function"
         # NOTE(sriram-mv): There is a connection timeout set on the http call to `aws-lambda-rie`, however there is not
         # a read time out for the response received from the server.
@@ -328,7 +331,7 @@ def wait_for_http_response(self, name, event, stdout):
             data=event.encode("utf-8"),
             timeout=(self.RAPID_CONNECTION_TIMEOUT, None),
         )
-        stdout.write(resp.content)
+        return json.dumps(json.loads(resp.content), ensure_ascii=False)
 
     def wait_for_result(self, full_path, event, stdout, stderr, start_timer=None):
         # NOTE(sriram-mv): Let logging happen in its own thread, so that a http request can be sent.
@@ -348,11 +351,21 @@ def wait_for_result(self, full_path, event, stdout, stderr, start_timer=None):
             # start the timer for function timeout right before executing the function, as waiting for the socket
             # can take some time
             timer = start_timer() if start_timer else None
-            self.wait_for_http_response(full_path, event, stdout)
+            response = self.wait_for_http_response(full_path, event, stdout)
             if timer:
                 timer.cancel()
 
-    def wait_for_logs(self, stdout=None, stderr=None):
+            # NOTE(jfuss): Adding a sleep after we get a response from the container but before we
+            # write the response to ensure the last thing written to stdout is the container response
+            time.sleep(1)
+            stdout.write_str(response)
+            stdout.flush()
+
+    def wait_for_logs(
+        self,
+        stdout: Optional[Union[StreamWriter, io.BytesIO, io.TextIOWrapper]] = None,
+        stderr: Optional[Union[StreamWriter, io.BytesIO, io.TextIOWrapper]] = None,
+    ):
         # Return instantly if we don't have to fetch any logs
         if not stdout and not stderr:
             return
@@ -364,7 +377,6 @@ def wait_for_logs(self, stdout=None, stderr=None):
 
         # Fetch both stdout and stderr streams from Docker as a single iterator.
         logs_itr = real_container.attach(stream=True, logs=True, demux=True)
-
         self._write_container_output(logs_itr, stdout=stdout, stderr=stderr)
 
     def _wait_for_socket_connection(self) -> None:
@@ -415,7 +427,11 @@ def copy(self, from_container_path, to_host_path) -> None:
             extract_tarfile(file_obj=fp, unpack_dir=to_host_path)
 
     @staticmethod
-    def _write_container_output(output_itr, stdout=None, stderr=None):
+    def _write_container_output(
+        output_itr: Iterator[Tuple[bytes, bytes]],
+        stdout: Optional[Union[StreamWriter, io.BytesIO, io.TextIOWrapper]] = None,
+        stderr: Optional[Union[StreamWriter, io.BytesIO, io.TextIOWrapper]] = None,
+    ):
         """
         Based on the data returned from the Container output, via the iterator, write it to the appropriate streams
 
@@ -434,13 +450,26 @@ def _write_container_output(output_itr, stdout=None, stderr=None):
             # Iterator returns a tuple of (stdout, stderr)
             for stdout_data, stderr_data in output_itr:
                 if stdout_data and stdout:
-                    stdout.write(stdout_data)
+                    Container._handle_data_writing(stdout, stdout_data)
 
                 if stderr_data and stderr:
-                    stderr.write(stderr_data)
+                    Container._handle_data_writing(stderr, stderr_data)
+
         except Exception as ex:
             LOG.debug("Failed to get the logs from the container", exc_info=ex)
 
+    @staticmethod
+    def _handle_data_writing(output_stream: Union[StreamWriter, io.BytesIO, io.TextIOWrapper], output_data: bytes):
+        if isinstance(output_stream, StreamWriter):
+            output_stream.write_bytes(output_data)
+            output_stream.flush()
+
+        if isinstance(output_stream, io.BytesIO):
+            output_stream.write(output_data)
+
+        if isinstance(output_stream, io.TextIOWrapper):
+            output_stream.buffer.write(output_data)
+
     @property
     def network_id(self):
         """
diff --git a/samcli/local/docker/lambda_image.py b/samcli/local/docker/lambda_image.py
index 923b740edc..72ad813de9 100644
--- a/samcli/local/docker/lambda_image.py
+++ b/samcli/local/docker/lambda_image.py
@@ -3,6 +3,7 @@
 """
 import hashlib
 import logging
+import os
 import platform
 import re
 import sys
@@ -227,7 +228,7 @@ def build(self, runtime, packagetype, image, layers, architecture, stream=None,
             or not runtime
         ):
             stream_writer = stream or StreamWriter(sys.stderr)
-            stream_writer.write("Building image...")
+            stream_writer.write_str("Building image...")
             stream_writer.flush()
             self._build_image(
                 image if image else base_image, rapid_image, downloaded_layers, architecture, stream=stream_writer
@@ -338,15 +339,15 @@ def set_item_permission(tar_info):
                 platform=get_docker_platform(architecture),
             )
             for log in resp_stream:
-                stream_writer.write(".")
+                stream_writer.write_str(".")
                 stream_writer.flush()
                 if "error" in log:
-                    stream_writer.write("\n")
+                    stream_writer.write_str(os.linesep)
                     LOG.exception("Failed to build Docker Image")
                     raise ImageBuildException("Error building docker image: {}".format(log["error"]))
-            stream_writer.write("\n")
+            stream_writer.write_str(os.linesep)
         except (docker.errors.BuildError, docker.errors.APIError) as ex:
-            stream_writer.write("\n")
+            stream_writer.write_str(os.linesep)
             LOG.exception("Failed to build Docker Image")
             raise ImageBuildException("Building Image failed.") from ex
         finally:
diff --git a/samcli/local/docker/manager.py b/samcli/local/docker/manager.py
index a035003bb0..6975828cd1 100644
--- a/samcli/local/docker/manager.py
+++ b/samcli/local/docker/manager.py
@@ -168,16 +168,16 @@ def pull_image(self, image_name, tag=None, stream=None):
             raise DockerImagePullFailedException(str(ex)) from ex
 
         # io streams, especially StringIO, work only with unicode strings
-        stream_writer.write("\nFetching {}:{} Docker container image...".format(image_name, tag))
+        stream_writer.write_str("\nFetching {}:{} Docker container image...".format(image_name, tag))
 
         # Each line contains information on progress of the pull. Each line is a JSON string
        for _ in result_itr:
             # For every line, print a dot to show progress
-            stream_writer.write(".")
+            stream_writer.write_str(".")
             stream_writer.flush()
 
         # We are done. Go to the next line
-        stream_writer.write("\n")
+        stream_writer.write_str("\n")
 
     def has_image(self, image_name):
         """
diff --git a/samcli/local/lambda_service/local_lambda_invoke_service.py b/samcli/local/lambda_service/local_lambda_invoke_service.py
index c6d7506fb2..546066449c 100644
--- a/samcli/local/lambda_service/local_lambda_invoke_service.py
+++ b/samcli/local/lambda_service/local_lambda_invoke_service.py
@@ -162,7 +162,7 @@ def _invoke_request_handler(self, function_name):
 
         request_data = request_data.decode("utf-8")
 
-        stdout_stream = io.BytesIO()
+        stdout_stream = io.StringIO()
         stdout_stream_writer = StreamWriter(stdout_stream, auto_flush=True)
 
         try:
diff --git a/samcli/local/services/base_local_service.py b/samcli/local/services/base_local_service.py
index fcb7cd95ae..671d48888c 100644
--- a/samcli/local/services/base_local_service.py
+++ b/samcli/local/services/base_local_service.py
@@ -82,7 +82,7 @@ def service_response(body, headers, status_code):
 
 class LambdaOutputParser:
     @staticmethod
-    def get_lambda_output(stdout_stream: io.BytesIO) -> Tuple[str, bool]:
+    def get_lambda_output(stdout_stream: io.StringIO) -> Tuple[str, bool]:
         """
         This method will extract read the given stream and return the response from Lambda function separated out
         from any log statements it might have outputted. Logs end up in the stdout stream if the Lambda function
@@ -100,7 +100,7 @@ def get_lambda_output(stdout_stream: io.BytesIO) -> Tuple[str, bool]:
         bool
             If the response is an error/exception from the container
         """
-        lambda_response = stdout_stream.getvalue().decode("utf-8")
+        lambda_response = stdout_stream.getvalue()
 
         # When the Lambda Function returns an Error/Exception, the output is added to the stdout of the container. From
         # our perspective, the container returned some value, which is not always true. Since the output is the only
diff --git a/tests/integration/local/invoke/test_integrations_cli.py b/tests/integration/local/invoke/test_integrations_cli.py
index 3604fc4010..70711459d6 100644
--- a/tests/integration/local/invoke/test_integrations_cli.py
+++ b/tests/integration/local/invoke/test_integrations_cli.py
@@ -291,6 +291,27 @@ def test_invoke_returns_expected_result_when_no_event_given(self):
         self.assertEqual(process.returncode, 0)
         self.assertEqual("{}", process_stdout.decode("utf-8"))
 
+    @pytest.mark.flaky(reruns=3)
+    def test_invoke_returns_utf8(self):
+        command_list = InvokeIntegBase.get_command_list(
+            "EchoEventFunction", template_path=self.template_path, event_path=self.event_utf8_path
+        )
+
+        process = Popen(command_list, stdout=PIPE)
+        try:
+            stdout, _ = process.communicate(timeout=TIMEOUT)
+        except TimeoutExpired:
+            process.kill()
+            raise
+
+        process_stdout = stdout.strip()
+
+        with open(self.event_utf8_path) as f:
+            expected_output = json.dumps(json.load(f), ensure_ascii=False)
+
+        self.assertEqual(process.returncode, 0)
+        self.assertEqual(expected_output, process_stdout.decode("utf-8"))
+
     @pytest.mark.flaky(reruns=3)
     def test_invoke_with_env_using_parameters(self):
         command_list = InvokeIntegBase.get_command_list(
diff --git a/tests/unit/commands/local/cli_common/test_invoke_context.py b/tests/unit/commands/local/cli_common/test_invoke_context.py
index 3cab08c82a..b89d5b6115 100644
--- a/tests/unit/commands/local/cli_common/test_invoke_context.py
+++ b/tests/unit/commands/local/cli_common/test_invoke_context.py
@@ -1106,7 +1106,7 @@ def test_must_open_file_for_writing(self):
         with patch("samcli.commands.local.cli_common.invoke_context.open", m):
             InvokeContext._setup_log_file(filename)
 
-        m.assert_called_with(filename, "wb")
+        m.assert_called_with(filename, "w", encoding="utf8")
 
 
 class TestInvokeContext_get_debug_context(TestCase):
diff --git a/tests/unit/lib/utils/test_osutils.py b/tests/unit/lib/utils/test_osutils.py
index bf4794f2c4..6f7a6cf4df 100644
--- a/tests/unit/lib/utils/test_osutils.py
+++ b/tests/unit/lib/utils/test_osutils.py
@@ -34,9 +34,7 @@ def test_raises_on_cleanup_failure(self, rmdir_mock):
     @patch("os.rmdir")
     def test_handles_ignore_error_case(self, rmdir_mock):
         rmdir_mock.side_effect = OSError("fail")
-        dir_name = None
         with osutils.mkdir_temp(ignore_errors=True) as tempdir:
-            dir_name = tempdir
             self.assertTrue(os.path.exists(tempdir))
 
 
@@ -44,9 +42,6 @@ class Test_stderr(TestCase):
     def test_must_return_sys_stderr(self):
         expected_stderr = sys.stderr
 
-        if sys.version_info.major > 2:
-            expected_stderr = sys.stderr.buffer
-
         self.assertEqual(expected_stderr, osutils.stderr())
 
 
@@ -54,9 +49,6 @@ class Test_stdout(TestCase):
     def test_must_return_sys_stdout(self):
         expected_stdout = sys.stdout
 
-        if sys.version_info.major > 2:
-            expected_stdout = sys.stdout.buffer
-
         self.assertEqual(expected_stdout, osutils.stdout())
diff --git a/tests/unit/lib/utils/test_stream_writer.py b/tests/unit/lib/utils/test_stream_writer.py
index cb48955850..a6875b59da 100644
--- a/tests/unit/lib/utils/test_stream_writer.py
+++ b/tests/unit/lib/utils/test_stream_writer.py
@@ -1,6 +1,7 @@
 """
 Tests for StreamWriter
 """
+import io
 from unittest import TestCase
 
 
@@ -11,13 +12,13 @@
 
 class TestStreamWriter(TestCase):
     def test_must_write_to_stream(self):
-        buffer = "something"
+        buffer = b"something"
         stream_mock = Mock()
 
         writer = StreamWriter(stream_mock)
-        writer.write(buffer)
+        writer.write_bytes(buffer)
 
-        stream_mock.write.assert_called_once_with(buffer)
+        stream_mock.buffer.write.assert_called_once_with(buffer)
 
     def test_must_flush_underlying_stream(self):
         stream_mock = Mock()
@@ -31,7 +32,7 @@ def test_auto_flush_must_be_off_by_default(self):
         stream_mock = Mock()
 
         writer = StreamWriter(stream_mock)
-        writer.write("something")
+        writer.write_str("something")
 
         stream_mock.flush.assert_not_called()
 
@@ -46,6 +47,6 @@ def test_when_auto_flush_on_flush_after_each_write(self):
         writer = StreamWriter(stream_mock, True)
 
         for line in lines:
-            writer.write(line)
+            writer.write_str(line)
             flush_mock.assert_called_once_with()
             flush_mock.reset_mock()
diff --git a/tests/unit/lib/utils/test_subprocess_utils.py b/tests/unit/lib/utils/test_subprocess_utils.py
index 969f06085b..a9d39afdd2 100644
--- a/tests/unit/lib/utils/test_subprocess_utils.py
+++ b/tests/unit/lib/utils/test_subprocess_utils.py
@@ -11,6 +11,7 @@
 from parameterized import parameterized
 from unittest.mock import patch, Mock, call, ANY
 
+from samcli.lib.utils.stream_writer import StreamWriter
 from samcli.lib.utils.subprocess_utils import (
     default_loading_pattern,
     invoke_subprocess_with_loading_pattern,
@@ -64,7 +65,7 @@ def test_loader_stream_uses_passed_in_stdout(
     @patch("samcli.lib.utils.subprocess_utils.Popen")
     def test_loader_raises_exception_non_zero_exit_code(self, patched_Popen):
         standard_error = "an error has occurred"
-        mock_stream_writer = Mock()
+        mock_stream_writer = Mock(spec=StreamWriter)
         mock_process = Mock()
         mock_process.returncode = 1
         mock_process.stdout = None
@@ -74,7 +75,7 @@ def test_loader_raises_exception_non_zero_exit_code(self, patched_Popen):
         with self.assertRaises(LoadingPatternError) as ex:
             invoke_subprocess_with_loading_pattern({"args": ["ls"]}, mock_pattern, mock_stream_writer)
         self.assertIn(standard_error, ex.exception.message)
-        mock_stream_writer.write.assert_called_once_with(os.linesep)
+        mock_stream_writer.write_str.assert_called_once_with(os.linesep)
         mock_stream_writer.flush.assert_called_once_with()
 
     @patch("samcli.lib.utils.subprocess_utils.Popen")
@@ -95,19 +96,19 @@ def test_loader_raises_exception_bad_process(self, patched_Popen):
 
     @patch("samcli.lib.utils.subprocess_utils.StreamWriter")
     def test_default_pattern_default_stream_writer(self, patched_stream_writer):
-        stream_writer_mock = Mock()
+        stream_writer_mock = Mock(spec=StreamWriter)
         patched_stream_writer.return_value = stream_writer_mock
         default_loading_pattern(loading_pattern_rate=0.01)
         patched_stream_writer.assert_called_once_with(sys.stderr)
-        stream_writer_mock.write.assert_called_once_with(".")
+        stream_writer_mock.write_str.assert_called_once_with(".")
         stream_writer_mock.flush.assert_called_once_with()
 
     @patch("samcli.lib.utils.subprocess_utils.StreamWriter")
     def test_default_pattern(self, patched_stream_writer):
-        stream_writer_mock = Mock()
+        stream_writer_mock = Mock(spec=StreamWriter)
         default_loading_pattern(stream_writer_mock, 0.01)
         patched_stream_writer.assert_not_called()
-        stream_writer_mock.write.assert_called_once_with(".")
+        stream_writer_mock.write_str.assert_called_once_with(".")
         stream_writer_mock.flush.assert_called_once_with()
 
     @parameterized.expand([("hello".encode("utf-8"), "hello"), ("hello", "hello")])
diff --git a/tests/unit/local/docker/test_container.py b/tests/unit/local/docker/test_container.py
index b7b7311563..064bb845cb 100644
--- a/tests/unit/local/docker/test_container.py
+++ b/tests/unit/local/docker/test_container.py
@@ -9,6 +9,7 @@
 from requests import RequestException
 
 from samcli.lib.utils.packagetype import IMAGE
+from samcli.lib.utils.stream_writer import StreamWriter
 from samcli.local.docker.container import (
     Container,
     ContainerResponseException,
@@ -721,17 +722,17 @@ def test_wait_for_result_waits_for_socket_before_post_request(self, patched_time
         self.assertEqual(mock_requests.post.call_count, 0)
 
     def test_write_container_output_successful(self):
-        stdout_mock = Mock()
-        stderr_mock = Mock()
+        stdout_mock = Mock(spec=StreamWriter)
+        stderr_mock = Mock(spec=StreamWriter)
 
         def _output_iterator():
-            yield "Hello", None
-            yield None, "World"
+            yield b"Hello", None
+            yield None, b"World"
             raise ValueError("The pipe has been ended.")
 
         Container._write_container_output(_output_iterator(), stdout_mock, stderr_mock)
 
-        stdout_mock.assert_has_calls([call.write("Hello")])
-        stderr_mock.assert_has_calls([call.write("World")])
+        stdout_mock.assert_has_calls([call.write_bytes(b"Hello")])
+        stderr_mock.assert_has_calls([call.write_bytes(b"World")])
 
 
 class TestContainer_wait_for_logs(TestCase):
@@ -785,33 +786,33 @@ class TestContainer_write_container_output(TestCase):
     def setUp(self):
         self.output_itr = [(b"stdout1", None), (None, b"stderr1"), (b"stdout2", b"stderr2"), (None, None)]
 
-        self.stdout_mock = Mock()
-        self.stderr_mock = Mock()
+        self.stdout_mock = Mock(spec=StreamWriter)
+        self.stderr_mock = Mock(spec=StreamWriter)
 
     def test_must_write_stdout_and_stderr_data(self):
         # All the invalid frames must be ignored
 
         Container._write_container_output(self.output_itr, stdout=self.stdout_mock, stderr=self.stderr_mock)
 
-        self.stdout_mock.write.assert_has_calls([call(b"stdout1"), call(b"stdout2")])
+        self.stdout_mock.write_bytes.assert_has_calls([call(b"stdout1"), call(b"stdout2")])
 
-        self.stderr_mock.write.assert_has_calls([call(b"stderr1"), call(b"stderr2")])
+        self.stderr_mock.write_bytes.assert_has_calls([call(b"stderr1"), call(b"stderr2")])
 
     def test_must_write_only_stderr(self):
         # All the invalid frames must be ignored
 
         Container._write_container_output(self.output_itr, stdout=None, stderr=self.stderr_mock)
 
-        self.stdout_mock.write.assert_not_called()
+        self.stdout_mock.write_bytes.assert_not_called()
 
-        self.stderr_mock.write.assert_has_calls([call(b"stderr1"), call(b"stderr2")])
+        self.stderr_mock.write_bytes.assert_has_calls([call(b"stderr1"), call(b"stderr2")])
 
     def test_must_write_only_stdout(self):
         Container._write_container_output(self.output_itr, stdout=self.stdout_mock, stderr=None)
 
-        self.stdout_mock.write.assert_has_calls([call(b"stdout1"), call(b"stdout2")])
+        self.stdout_mock.write_bytes.assert_has_calls([call(b"stdout1"), call(b"stdout2")])
 
-        self.stderr_mock.write.assert_not_called()  # stderr must never be called
+        self.stderr_mock.write_bytes.assert_not_called()  # stderr must never be called
 
 
 class TestContainer_wait_for_socket_connection(TestCase):
diff --git a/tests/unit/local/docker/test_lambda_image.py b/tests/unit/local/docker/test_lambda_image.py
index 1e8f936d98..03b57be804 100644
--- a/tests/unit/local/docker/test_lambda_image.py
+++ b/tests/unit/local/docker/test_lambda_image.py
@@ -1,4 +1,3 @@
-import io
 import tempfile
 from unittest import TestCase
 
@@ -271,7 +270,7 @@ def test_force_building_image_that_doesnt_already_exists(
         docker_client_mock.images.get.side_effect = ImageNotFound("image not found")
         docker_client_mock.images.list.return_value = []
 
-        stream = io.StringIO()
+        stream = Mock()
 
         lambda_image = LambdaImage(layer_downloader_mock, False, True, docker_client=docker_client_mock)
         actual_image_id = lambda_image.build(
@@ -311,7 +310,7 @@ def test_force_building_image_on_daemon_404(
         docker_client_mock.images.get.side_effect = NotFound("image not found")
         docker_client_mock.images.list.return_value = []
 
-        stream = io.StringIO()
+        stream = Mock()
 
         lambda_image = LambdaImage(layer_downloader_mock, False, True, docker_client=docker_client_mock)
         actual_image_id = lambda_image.build(
@@ -351,7 +350,7 @@ def test_docker_distribution_api_error_on_daemon_api_error(
         docker_client_mock.images.get.side_effect = APIError("error from docker daemon")
         docker_client_mock.images.list.return_value = []
 
-        stream = io.StringIO()
+        stream = Mock()
 
         lambda_image = LambdaImage(layer_downloader_mock, False, True, docker_client=docker_client_mock)
         with self.assertRaises(DockerDistributionAPIError):
@@ -377,7 +376,7 @@ def test_not_force_building_image_that_doesnt_already_exists(
         docker_client_mock.images.get.side_effect = ImageNotFound("image not found")
         docker_client_mock.images.list.return_value = []
 
-        stream = io.StringIO()
+        stream = Mock()
 
         lambda_image = LambdaImage(layer_downloader_mock, False, False, docker_client=docker_client_mock)
         actual_image_id = lambda_image.build(
diff --git a/tests/unit/local/docker/test_manager.py b/tests/unit/local/docker/test_manager.py
index ada69903ea..4cb42bbd02 100644
--- a/tests/unit/local/docker/test_manager.py
+++ b/tests/unit/local/docker/test_manager.py
@@ -1,8 +1,6 @@
 """
 Tests container manager
 """
-
-import io
 import importlib
 from unittest import TestCase
 from unittest.mock import Mock, patch, MagicMock, ANY, call
@@ -218,17 +216,29 @@ def setUp(self):
         self.manager = ContainerManager(docker_client=self.mock_docker_client)
 
     def test_must_pull_and_print_progress_dots(self):
-        stream = io.StringIO()
+        stream = Mock()
         pull_result = [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]
         self.mock_docker_client.api.pull.return_value = pull_result
-        expected_stream_output = "\nFetching {}:latest Docker container image...{}\n".format(
-            self.image_name, "." * len(pull_result)  # Progress bar will print one dot per response from pull API
-        )
+        expected_stream_calls = [
+            call(f"\nFetching {self.image_name}:latest Docker container image..."),
+            call("."),
+            call("."),
+            call("."),
+            call("."),
+            call("."),
+            call("."),
+            call("."),
+            call("."),
+            call("."),
+            call("."),
+            call("\n"),
+        ]
 
         self.manager.pull_image(self.image_name, stream=stream)
 
         self.mock_docker_client.api.pull.assert_called_with(self.image_name, stream=True, decode=True, tag="latest")
-        self.assertEqual(stream.getvalue(), expected_stream_output)
+
+        stream.write_str.assert_has_calls(expected_stream_calls)
 
     def test_must_raise_if_image_not_found(self):
         msg = "some error"
diff --git a/tests/unit/local/services/test_base_local_service.py b/tests/unit/local/services/test_base_local_service.py
index fec13e25c9..34bc44c193 100644
--- a/tests/unit/local/services/test_base_local_service.py
+++ b/tests/unit/local/services/test_base_local_service.py
@@ -66,17 +66,17 @@ def test_create_returns_not_implemented(self):
 class TestLambdaOutputParser(TestCase):
     @parameterized.expand(
         [
-            param("with mixed data and json response", b'data\n{"a": "b"}', 'data\n{"a": "b"}'),
-            param("with response as string", b"response", "response"),
-            param("with json response only", b'{"a": "b"}', '{"a": "b"}'),
-            param("with one new line and json", b'\n{"a": "b"}', '\n{"a": "b"}'),
-            param("with response only as string", b"this is the response line", "this is the response line"),
-            param("with whitespaces", b'data\n{"a": "b"} \n\n\n', 'data\n{"a": "b"} \n\n\n'),
-            param("with empty data", b"", ""),
-            param("with just new lines", b"\n\n", "\n\n"),
+            param("with mixed data and json response", 'data\n{"a": "b"}', 'data\n{"a": "b"}'),
+            param("with response as string", "response", "response"),
+            param("with json response only", '{"a": "b"}', '{"a": "b"}'),
+            param("with one new line and json", '\n{"a": "b"}', '\n{"a": "b"}'),
+            param("with response only as string", "this is the response line", "this is the response line"),
+            param("with whitespaces", 'data\n{"a": "b"} \n\n\n', 'data\n{"a": "b"} \n\n\n'),
+            param("with empty data", "", ""),
+            param("with just new lines", "\n\n", "\n\n"),
             param(
                 "with whitespaces",
-                b"\n \n \n",
+                "\n \n \n",
                 "\n \n \n",
             ),
         ]
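
For reference, a minimal usage sketch of the reworked StreamWriter API introduced in this change, assuming the interfaces shown in the diff above (the writer variable and example strings are illustrative only, not part of the change):

    from samcli.lib.utils import osutils
    from samcli.lib.utils.stream_writer import StreamWriter

    # StreamWriter now wraps a text stream; osutils.stdout()/stderr() return
    # sys.stdout/sys.stderr reconfigured to UTF-8.
    writer = StreamWriter(osutils.stdout(), auto_flush=True)

    # Text goes through write_str() and is encoded by the underlying text stream.
    writer.write_str("Fetching image...\n")

    # Raw bytes (e.g. container output) go through write_bytes(), which writes
    # to the stream's underlying binary buffer.
    writer.write_bytes("こんにちは".encode("utf-8"))

Splitting the old write(output, encode=False) into write_str()/write_bytes() makes the text/bytes boundary explicit at each call site instead of relying on an encode flag, which is what lets stdout/stderr stay UTF-8 text streams everywhere else in this diff.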