Skip to content

Commit

Permalink
fix(invoke): Write in UTF-8 string instead of bytes (#5642)
Browse files Browse the repository at this point in the history
* Revert "fix: Revert UTF-8 fixes #5485 and #5427 (#5512)"

This reverts commit 36f8bf9.

* Enforce utf8 on stdout/stderr/logfile

---------

Co-authored-by: Jacob Fuss <jfuss@users.noreply.github.com>
  • Loading branch information
jfuss and jfuss authored Aug 1, 2023
1 parent d264edd commit 253852c
Show file tree
Hide file tree
Showing 23 changed files with 191 additions and 112 deletions.
8 changes: 4 additions & 4 deletions samcli/commands/local/cli_common/invoke_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import os
from enum import Enum
from pathlib import Path
from typing import IO, Any, Dict, List, Optional, Tuple, Type, cast
from typing import Any, Dict, List, Optional, TextIO, Tuple, Type, cast

from samcli.commands._utils.template import TemplateFailedParsingException, TemplateNotFoundException
from samcli.commands.exceptions import ContainersInitializationException
Expand Down Expand Up @@ -196,7 +196,7 @@ def __init__(
self._stacks: List[Stack] = None # type: ignore
self._env_vars_value: Optional[Dict] = None
self._container_env_vars_value: Optional[Dict] = None
self._log_file_handle: Optional[IO] = None
self._log_file_handle: Optional[TextIO] = None
self._debug_context: Optional[DebugContext] = None
self._layers_downloader: Optional[LayerDownloader] = None
self._container_manager: Optional[ContainerManager] = None
Expand Down Expand Up @@ -490,7 +490,7 @@ def _get_env_vars_value(filename: Optional[str]) -> Optional[Dict]:
) from ex

@staticmethod
def _setup_log_file(log_file: Optional[str]) -> Optional[IO]:
def _setup_log_file(log_file: Optional[str]) -> Optional[TextIO]:
"""
Open a log file if necessary and return the file handle. This will create a file if it does not exist
Expand All @@ -500,7 +500,7 @@ def _setup_log_file(log_file: Optional[str]) -> Optional[IO]:
if not log_file:
return None

return open(log_file, "wb")
return open(log_file, "w", encoding="utf8")

@staticmethod
def _get_debug_context(
Expand Down
4 changes: 2 additions & 2 deletions samcli/commands/remote/remote_invoke_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ class DefaultRemoteInvokeResponseConsumer(RemoteInvokeConsumer[RemoteInvokeRespo
_stream_writer: StreamWriter

def consume(self, remote_invoke_response: RemoteInvokeResponse) -> None:
self._stream_writer.write(cast(str, remote_invoke_response.response).encode())
self._stream_writer.write_bytes(cast(str, remote_invoke_response.response).encode())


@dataclass
Expand All @@ -254,4 +254,4 @@ class DefaultRemoteInvokeLogConsumer(RemoteInvokeConsumer[RemoteInvokeLogOutput]
_stream_writer: StreamWriter

def consume(self, remote_invoke_response: RemoteInvokeLogOutput) -> None:
self._stream_writer.write(remote_invoke_response.log_output.encode())
self._stream_writer.write_bytes(remote_invoke_response.log_output.encode())
28 changes: 13 additions & 15 deletions samcli/lib/docker/log_streamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,23 +47,21 @@ def stream_progress(self, logs: docker.APIClient.logs):
else:
curr_log_line_id = ids[_id]
change_cursor_count = len(ids) - curr_log_line_id
self._stream.write(
self._stream.write_str(
self._cursor_up_formatter.cursor_format(change_cursor_count)
+ self._cursor_left_formatter.cursor_format(),
encode=True,
+ self._cursor_left_formatter.cursor_format()
)

self._stream_write(_id, status, stream, progress, error)

if _id:
self._stream.write(
self._stream.write_str(
self._cursor_down_formatter.cursor_format(change_cursor_count)
+ self._cursor_left_formatter.cursor_format(),
encode=True,
+ self._cursor_left_formatter.cursor_format()
)
self._stream.write(os.linesep, encode=True)
self._stream.write_str(os.linesep)

def _stream_write(self, _id: str, status: str, stream: bytes, progress: str, error: str):
def _stream_write(self, _id: str, status: str, stream: str, progress: str, error: str):
"""
Write stream information to stderr, if the stream information contains a log id,
use the carriage return character to rewrite that particular line.
Expand All @@ -80,14 +78,14 @@ def _stream_write(self, _id: str, status: str, stream: bytes, progress: str, err

# NOTE(sriram-mv): Required for the purposes of when the cursor overflows existing terminal buffer.
if not stream:
self._stream.write(os.linesep, encode=True)
self._stream.write(
self._cursor_up_formatter.cursor_format() + self._cursor_left_formatter.cursor_format(), encode=True
self._stream.write_str(os.linesep)
self._stream.write_str(
self._cursor_up_formatter.cursor_format() + self._cursor_left_formatter.cursor_format()
)
self._stream.write(self._cursor_clear_formatter.cursor_format(), encode=True)
self._stream.write_str(self._cursor_clear_formatter.cursor_format())

if not _id:
self._stream.write(stream, encode=True)
self._stream.write(status, encode=True)
self._stream.write_str(stream)
self._stream.write_str(status)
else:
self._stream.write(f"\r{_id}: {status} {progress}", encode=True)
self._stream.write_str(f"\r{_id}: {status} {progress}")
4 changes: 2 additions & 2 deletions samcli/lib/package/ecr_uploader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
Client for uploading packaged artifacts to ecr
"""
import base64
import io
import logging
from io import StringIO
from typing import Dict

import botocore
Expand Down Expand Up @@ -94,7 +94,7 @@ def upload(self, image, resource_name):
else:
# we need to wait till the image got pushed to ecr, without this workaround sam sync for template
# contains image always fail, because the provided ecr uri is not exist.
_log_streamer = LogStreamer(stream=StreamWriter(stream=io.BytesIO(), auto_flush=True))
_log_streamer = LogStreamer(stream=StreamWriter(stream=StringIO(), auto_flush=True))
_log_streamer.stream_progress(push_logs)

except (BuildError, APIError, LogStreamError) as ex:
Expand Down
2 changes: 1 addition & 1 deletion samcli/lib/package/s3_uploader.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,4 +265,4 @@ def on_progress(self, bytes_transferred, **kwargs):
)
sys.stderr.flush()
if int(percentage) == 100: # noqa: PLR2004
sys.stderr.write("\n")
sys.stderr.write(os.linesep)
19 changes: 15 additions & 4 deletions samcli/lib/utils/osutils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Common OS utilities
"""
import io
import logging
import os
import shutil
Expand Down Expand Up @@ -78,7 +79,7 @@ def rmtree_if_exists(path: Union[str, Path]):
shutil.rmtree(path_obj)


def stdout():
def stdout() -> io.TextIOWrapper:
"""
Returns stdout as a text stream, reconfigured to use UTF-8 encoding
Expand All @@ -87,10 +88,15 @@ def stdout():
io.TextIOWrapper
Text stream of stdout, encoded as UTF-8
"""
return sys.stdout.buffer
# ensure stdout is utf8
sys.stdout.reconfigure(encoding="utf-8") # type:ignore[attr-defined]

# Note(jfuss): sys.stdout is typed as typing.TextIO but is initialized to an
# io.TextIOWrapper. To make mypy and typing play well, tell mypy to ignore.
return sys.stdout # type:ignore[return-value]

def stderr():

def stderr() -> io.TextIOWrapper:
"""
Returns stderr as a text stream, reconfigured to use UTF-8 encoding
Expand All @@ -99,7 +105,12 @@ def stderr():
io.TextIOWrapper
Text stream of stderr, encoded as UTF-8
"""
return sys.stderr.buffer
# ensure stderr is utf8
sys.stderr.reconfigure(encoding="utf-8") # type:ignore[attr-defined]

# Note(jfuss): sys.stderr is typed as typing.TextIO but is initialized to an
# io.TextIOWrapper. To make mypy and typing play well, tell mypy to ignore.
return sys.stderr # type:ignore[return-value]


def remove(path):
Expand Down
25 changes: 20 additions & 5 deletions samcli/lib/utils/stream_writer.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""
This class acts like a wrapper around output streams to provide any flexibility with output we need
"""
from typing import TextIO, Union


class StreamWriter:
def __init__(self, stream, auto_flush=False):
def __init__(self, stream: TextIO, auto_flush: bool = False):
"""
Instantiates a new StreamWriter for the specified stream
Expand All @@ -19,19 +20,33 @@ def __init__(self, stream, auto_flush=False):
self._auto_flush = auto_flush

@property
def stream(self):
def stream(self) -> TextIO:
return self._stream

def write(self, output, encode=False):
def write_bytes(self, output: Union[bytes, bytearray]):
"""
Writes specified text to the underlying stream
Parameters
----------
output bytes-like object
Bytes to write
Bytes to write into buffer
"""
self._stream.write(output.encode() if encode else output)
self._stream.buffer.write(output)

if self._auto_flush:
self._stream.flush()

def write_str(self, output: str):
    """
    Write text to the wrapped stream, flushing right away when auto-flush is enabled

    Parameters
    ----------
    output : str
        Text to hand to the underlying stream's ``write``
    """
    target = self._stream
    target.write(output)

    # Nothing more to do unless the writer was created with auto_flush turned on
    if not self._auto_flush:
        return

    target.flush()
Expand Down
4 changes: 2 additions & 2 deletions samcli/lib/utils/subprocess_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def default_loading_pattern(stream_writer: Optional[StreamWriter] = None, loadin
How frequently to generate the pattern
"""
stream_writer = stream_writer or StreamWriter(sys.stderr)
stream_writer.write(".")
stream_writer.write_str(".")
stream_writer.flush()
sleep(loading_pattern_rate)

Expand Down Expand Up @@ -96,7 +96,7 @@ def _print_loading_pattern():
return_code = process.wait()
keep_printing = False

stream_writer.write(os.linesep)
stream_writer.write_str(os.linesep)
stream_writer.flush()
process_stderr = _check_and_convert_stream_to_string(process.stderr)

Expand Down
4 changes: 2 additions & 2 deletions samcli/local/apigw/local_apigw_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
import logging
from datetime import datetime
from io import BytesIO
from io import StringIO
from time import time
from typing import Any, Dict, List, Optional

Expand Down Expand Up @@ -605,7 +605,7 @@ def _invoke_lambda_function(self, lambda_function_name: str, event: dict) -> str
str
A string containing the output from the Lambda function
"""
with BytesIO() as stdout:
with StringIO() as stdout:
event_str = json.dumps(event, sort_keys=True)
stdout_writer = StreamWriter(stdout, auto_flush=True)

Expand Down
47 changes: 38 additions & 9 deletions samcli/local/docker/container.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""
Representation of a generic Docker container
"""
import io
import json
import logging
import os
import pathlib
Expand All @@ -9,14 +11,15 @@
import tempfile
import threading
import time
from typing import Optional
from typing import Iterator, Optional, Tuple, Union

import docker
import requests
from docker.errors import NotFound as DockerNetworkNotFound

from samcli.lib.constants import DOCKER_MIN_API_VERSION
from samcli.lib.utils.retry import retry
from samcli.lib.utils.stream_writer import StreamWriter
from samcli.lib.utils.tar import extract_tarfile
from samcli.local.docker.effective_user import ROOT_USER_ID, EffectiveUser
from samcli.local.docker.exceptions import ContainerNotStartableException, PortAlreadyInUse
Expand Down Expand Up @@ -318,7 +321,7 @@ def start(self, input_data=None):
raise ex

@retry(exc=requests.exceptions.RequestException, exc_raise=ContainerResponseException)
def wait_for_http_response(self, name, event, stdout):
def wait_for_http_response(self, name, event, stdout) -> str:
# TODO(sriram-mv): `aws-lambda-rie` is in a mode where the function_name is always "function"
# NOTE(sriram-mv): There is a connection timeout set on the http call to `aws-lambda-rie`, however there is not
# a read time out for the response received from the server.
Expand All @@ -328,7 +331,7 @@ def wait_for_http_response(self, name, event, stdout):
data=event.encode("utf-8"),
timeout=(self.RAPID_CONNECTION_TIMEOUT, None),
)
stdout.write(resp.content)
return json.dumps(json.loads(resp.content), ensure_ascii=False)

def wait_for_result(self, full_path, event, stdout, stderr, start_timer=None):
# NOTE(sriram-mv): Let logging happen in its own thread, so that a http request can be sent.
Expand All @@ -348,11 +351,21 @@ def wait_for_result(self, full_path, event, stdout, stderr, start_timer=None):
# start the timer for function timeout right before executing the function, as waiting for the socket
# can take some time
timer = start_timer() if start_timer else None
self.wait_for_http_response(full_path, event, stdout)
response = self.wait_for_http_response(full_path, event, stdout)
if timer:
timer.cancel()

def wait_for_logs(self, stdout=None, stderr=None):
# NOTE(jfuss): Adding a sleep after we get a response from the container but before
# we write the response to ensure the last thing written to stdout is the container response
time.sleep(1)
stdout.write_str(response)
stdout.flush()

def wait_for_logs(
self,
stdout: Optional[Union[StreamWriter, io.BytesIO, io.TextIOWrapper]] = None,
stderr: Optional[Union[StreamWriter, io.BytesIO, io.TextIOWrapper]] = None,
):
# Return instantly if we don't have to fetch any logs
if not stdout and not stderr:
return
Expand All @@ -364,7 +377,6 @@ def wait_for_logs(self, stdout=None, stderr=None):

# Fetch both stdout and stderr streams from Docker as a single iterator.
logs_itr = real_container.attach(stream=True, logs=True, demux=True)

self._write_container_output(logs_itr, stdout=stdout, stderr=stderr)

def _wait_for_socket_connection(self) -> None:
Expand Down Expand Up @@ -415,7 +427,11 @@ def copy(self, from_container_path, to_host_path) -> None:
extract_tarfile(file_obj=fp, unpack_dir=to_host_path)

@staticmethod
def _write_container_output(output_itr, stdout=None, stderr=None):
def _write_container_output(
output_itr: Iterator[Tuple[bytes, bytes]],
stdout: Optional[Union[StreamWriter, io.BytesIO, io.TextIOWrapper]] = None,
stderr: Optional[Union[StreamWriter, io.BytesIO, io.TextIOWrapper]] = None,
):
"""
Based on the data returned from the Container output, via the iterator, write it to the appropriate streams
Expand All @@ -434,13 +450,26 @@ def _write_container_output(output_itr, stdout=None, stderr=None):
# Iterator returns a tuple of (stdout, stderr)
for stdout_data, stderr_data in output_itr:
if stdout_data and stdout:
stdout.write(stdout_data)
Container._handle_data_writing(stdout, stdout_data)

if stderr_data and stderr:
stderr.write(stderr_data)
Container._handle_data_writing(stderr, stderr_data)

except Exception as ex:
LOG.debug("Failed to get the logs from the container", exc_info=ex)

@staticmethod
def _handle_data_writing(output_stream: Union[StreamWriter, io.BytesIO, io.TextIOWrapper], output_data: bytes):
    """
    Write raw container output to a stream, picking the write call that matches the stream's type

    Parameters
    ----------
    output_stream : Union[StreamWriter, io.BytesIO, io.TextIOWrapper]
        Destination for the data. Any other type is left untouched.
    output_data : bytes
        Bytes received from the container
    """
    if isinstance(output_stream, StreamWriter):
        # StreamWriter exposes a dedicated bytes entry point; flush so the output shows up immediately
        output_stream.write_bytes(output_data)
        output_stream.flush()
    elif isinstance(output_stream, io.BytesIO):
        # In-memory binary buffer accepts the bytes as-is
        output_stream.write(output_data)
    elif isinstance(output_stream, io.TextIOWrapper):
        # Text wrapper: bypass the text layer and write to its underlying binary buffer
        output_stream.buffer.write(output_data)

@property
def network_id(self):
"""
Expand Down
Loading

0 comments on commit 253852c

Please sign in to comment.