Skip to content

Commit

Permalink
BUG: Don't close stream passed to PdfWriter.write() (#2909)
Browse files Browse the repository at this point in the history
Closes #2905.
  • Loading branch information
alexaryn authored Oct 30, 2024
1 parent 9e0fce7 commit 98aa974
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 9 deletions.
33 changes: 24 additions & 9 deletions pypdf/_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
StreamType,
_get_max_pdf_version_header,
deprecate,
deprecate_no_replacement,
deprecation_with_replacement,
logger_warning,
)
Expand Down Expand Up @@ -236,10 +237,9 @@ def _get_clone_from(
or Path(str(fileobj)).stat().st_size == 0
):
cloning = False
if isinstance(fileobj, (IO, BytesIO)):
if isinstance(fileobj, (IOBase, BytesIO)):
t = fileobj.tell()
fileobj.seek(-1, 2)
if fileobj.tell() == 0:
if fileobj.seek(0, 2) == 0:
cloning = False
fileobj.seek(t, 0)
if cloning:
Expand All @@ -250,7 +250,8 @@ def _get_clone_from(
# to prevent overwriting
self.temp_fileobj = fileobj
self.fileobj = ""
self.with_as_usage = False
self._with_as_usage = False
self._cloned = False
# The root of our page tree node.
pages = DictionaryObject()
pages.update(
Expand All @@ -268,6 +269,7 @@ def _get_clone_from(
if not isinstance(clone_from, PdfReader):
clone_from = PdfReader(clone_from)
self.clone_document_from_reader(clone_from)
self._cloned = True
else:
self._pages = self._add_object(pages)
# root object
Expand Down Expand Up @@ -355,11 +357,23 @@ def xmp_metadata(self, value: Optional[XmpInformation]) -> None:

return self.root_object.xmp_metadata # type: ignore

@property
def with_as_usage(self) -> bool:
deprecate_no_replacement("with_as_usage", "6.0")
return self._with_as_usage

@with_as_usage.setter
def with_as_usage(self, value: bool) -> None:
deprecate_no_replacement("with_as_usage", "6.0")
self._with_as_usage = value

def __enter__(self) -> "PdfWriter":
"""Store that writer is initialized by 'with'."""
"""Store how writer is initialized by 'with'."""
c: bool = self._cloned
t = self.temp_fileobj
self.__init__() # type: ignore
self.with_as_usage = True
self._cloned = c
self._with_as_usage = True
self.fileobj = t # type: ignore
return self

Expand All @@ -370,7 +384,7 @@ def __exit__(
traceback: Optional[TracebackType],
) -> None:
"""Write data to the fileobj."""
if self.fileobj:
if self.fileobj and not self._cloned:
self.write(self.fileobj)

def _repr_mimebundle_(
Expand Down Expand Up @@ -1406,13 +1420,14 @@ def write(self, stream: Union[Path, StrByteType]) -> Tuple[bool, IO[Any]]:

if isinstance(stream, (str, Path)):
stream = FileIO(stream, "wb")
self.with_as_usage = True
my_file = True

self.write_stream(stream)

if self.with_as_usage:
if my_file:
stream.close()
else:
stream.flush()

return my_file, stream

Expand Down
40 changes: 40 additions & 0 deletions tests/test_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2481,3 +2481,43 @@ def test_append_pdf_with_dest_without_page(caplog):
writer.append(reader)
assert "/__WKANCHOR_8" not in writer.named_destinations
assert len(writer.named_destinations) == 3


def test_stream_not_closed():
"""Tests for #2905"""
src = RESOURCE_ROOT / "pdflatex-outline.pdf"
with NamedTemporaryFile(suffix=".pdf") as tmp:
with PdfReader(src) as reader, PdfWriter() as writer:
writer.add_page(reader.pages[0])
writer.write(tmp)
assert not tmp.file.closed

with NamedTemporaryFile(suffix=".pdf") as target:
with PdfWriter(target.file) as writer:
writer.add_blank_page(100, 100)
assert not target.file.closed

with open(src, "rb") as fileobj:
with PdfWriter(fileobj) as writer:
pass
assert not fileobj.closed


def test_auto_write(tmp_path):
"""Another test for #2905"""
target = tmp_path / "out.pdf"
with PdfWriter(target) as writer:
writer.add_blank_page(100, 100)
assert target.stat().st_size > 0


def test_deprecate_with_as():
"""Yet another test for #2905"""
with PdfWriter() as writer:
with pytest.warns(DeprecationWarning) as w:
val = writer.with_as_usage
assert "with_as_usage is deprecated" in w[0].message.args[0]
assert val
with pytest.warns(DeprecationWarning) as w:
writer.with_as_usage = val # old code allowed setting this, so...
assert "with_as_usage is deprecated" in w[0].message.args[0]

0 comments on commit 98aa974

Please sign in to comment.