Fix Pipeline reference leak in PythonFunction. (#5668)
* Use a stub pipeline as PythonFunction's "current pipeline", so that the pipeline does not reference itself and cannot end up deleting itself from within its own ThreadPool.

---------

Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
mzient authored Oct 10, 2024
1 parent 988265a commit f8a76a6
Showing 5 changed files with 53 additions and 31 deletions.
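For context, here is a minimal, self-contained sketch of the leak this commit fixes and of how the stub breaks it. This is illustrative Python only, not DALI code: `Pipeline`, `PythonFunctionOp` and `Backend` below are stand-ins. The pipeline owns its operators, and the operator used to store `Pipeline.current()` back on itself, so the two objects kept each other alive; storing a backend-less shallow copy instead lets the real pipeline's reference count drop to zero as soon as the user releases it.

    import copy
    import gc
    import weakref


    class Backend:
        """Stand-in for the C++ pipeline backend and its ThreadPool."""


    class Pipeline:
        def __init__(self):
            self._pipe = Backend()    # backend resources
            self.ops = []             # the pipeline owns its operators

        def _stub(self):
            stub = copy.copy(self)    # shallow copy: same attributes...
            stub._pipe = None         # ...but no backend
            return stub


    class PythonFunctionOp:
        def __init__(self, pipe, use_stub):
            pipe.ops.append(self)
            # Before the fix the operator stored the pipeline itself,
            # closing a cycle: pipe -> op -> pipe.
            self.pipeline = pipe._stub() if use_stub else pipe


    for use_stub in (False, True):
        pipe = Pipeline()
        PythonFunctionOp(pipe, use_stub)
        ref = weakref.ref(pipe)
        del pipe
        # With the cycle, only the garbage collector can reclaim the pipeline,
        # possibly from an arbitrary thread; with the stub it dies right here.
        print(use_stub, ref() is None)    # False False, then True True
        gc.collect()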
10 changes: 7 additions & 3 deletions dali/operators/python_function/python_function.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -20,7 +20,11 @@ namespace dali {

 DALI_SCHEMA(PythonFunctionBase)
   .AddArg("function",
-          "Function object.",
+          R"code(A callable object that defines the function of the operator.
+
+.. warning::
+    The function must not hold a reference to the pipeline in which it is used. If it does,
+    a circular reference to the pipeline will form and the pipeline will never be freed.)code",
           DALI_PYTHON_OBJECT)
   .AddOptionalArg("num_outputs", R"code(Number of outputs.)code", 1)
   .AddOptionalArg<std::vector<TensorLayout>>("output_layouts",
@@ -41,7 +45,7 @@ a more universal data format, see :meth:`nvidia.dali.fn.dl_tensor_python_function`.
 The function should not modify input tensors.

 .. warning::
-This operator is not compatible with TensorFlow integration.
+    This operator is not compatible with TensorFlow integration.

 .. warning::
     When the pipeline has conditional execution enabled, additional steps must be taken to
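The pattern the new docstring warns against looks roughly like this. This is hypothetical user code, not from the commit; the point is that `leaky` closing over its own pipeline is what creates the cycle:

    from nvidia.dali import fn, types
    from nvidia.dali.pipeline import Pipeline

    pipe = Pipeline(batch_size=1, num_threads=1, device_id=None)


    def leaky(x):
        print(pipe)  # closes over `pipe`: function -> pipeline -> function
        return x


    with pipe:
        # The backend keeps `leaky` alive and `leaky` keeps `pipe` alive,
        # so this pipeline would never be freed.
        pipe.set_outputs(fn.python_function(types.Constant(0), function=leaky))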
5 changes: 3 additions & 2 deletions dali/python/nvidia/dali/ops/_operators/python_function.py
@@ -51,9 +51,10 @@ def __init__(self, function, num_outputs=1, **kwargs):

     def __call__(self, *inputs, **kwargs):
         inputs = ops._preprocess_inputs(inputs, impl_name, self._device, None)
-        self.pipeline = _Pipeline.current()
-        if self.pipeline is None:
+        curr_pipe = _Pipeline.current()
+        if curr_pipe is None:
             _Pipeline._raise_pipeline_required("PythonFunction operator")
+        self.pipeline = curr_pipe._stub()

         for inp in inputs:
             if not isinstance(inp, _DataNode):
28 changes: 28 additions & 0 deletions dali/python/nvidia/dali/pipeline.py
@@ -26,6 +26,7 @@
 from threading import local as tls
 from . import data_node as _data_node
 import atexit
+import copy
 import ctypes
 import functools
 import inspect
@@ -1764,6 +1765,33 @@ def _generate_build_args(self):
             for (name, dev), dtype, ndim in zip(self._names_and_devices, dtypes, ndims)
         ]

+    def _stub(self):
+        """Produce a stub by shallow-copying the pipeline, removing the backend and forbidding
+        operations that require the backend.
+
+        Stub pipelines are necessary in contexts where passing the actual pipeline would cause
+        a circular reference - notably, the PythonFunction operator.
+        """
+        stub = copy.copy(self)
+        stub._pipe = None
+
+        def short_circuit(self, *args, **kwargs):
+            raise RuntimeError("This method is forbidden in current context")
+
+        stub.start_py_workers = short_circuit
+        stub.build = short_circuit
+        stub.run = short_circuit
+        stub.schedule_run = short_circuit
+        stub.outputs = short_circuit
+        stub.share_outputs = short_circuit
+        stub.release_outputs = short_circuit
+        stub.add_sink = short_circuit
+        stub.checkpoint = short_circuit
+        stub.set_outputs = short_circuit
+        stub.executor_statistics = short_circuit
+        stub.external_source_shm_statistics = short_circuit
+        return stub
+

 def _shutdown_pipelines():
     for weak in list(Pipeline._pipes):
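Based on the method above, a stub behaves like the pipeline for plain attribute access (it is a shallow copy) but refuses anything that needs the backend. A rough usage sketch, assuming `max_batch_size` is an ordinary attribute backed by the copied instance state:

    from nvidia.dali.pipeline import Pipeline

    pipe = Pipeline(batch_size=1, num_threads=1, device_id=None)
    stub = pipe._stub()

    print(stub.max_batch_size)  # plain attributes survive the shallow copy
    try:
        stub.run()              # backend-dependent methods are short-circuited
    except Exception as e:      # the short-circuit refuses to run the method
        print(type(e).__name__, e)

Note that `short_circuit` is assigned on the instance rather than the class, so it is not bound as a method; the call still fails fast, which is all the stub needs.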
2 changes: 1 addition & 1 deletion dali/python/nvidia/dali/plugin/pytorch/_torch_function.py
@@ -56,7 +56,7 @@ def torch_wrapper(self, batch_processing, function, device, *args):
         )

     def __call__(self, *inputs, **kwargs):
-        pipeline = Pipeline.current()
+        pipeline = Pipeline.current()._stub()
         if pipeline is None:
             Pipeline._raise_pipeline_required("TorchPythonFunction")
         if self.stream is None:
39 changes: 14 additions & 25 deletions dali/test/python/operator_2/test_python_function.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -649,30 +649,6 @@ def py_fun_pipeline():
     pipe.run()


-def verify_pipeline(pipeline, input):
-    assert pipeline is Pipeline.current()
-    return input
-
-
-def test_current_pipeline():
-    pipe1 = Pipeline(13, 4, 0)
-    with pipe1:
-        dummy = types.Constant(numpy.ones((1)))
-        output = fn.python_function(dummy, function=lambda inp: verify_pipeline(pipe1, inp))
-        pipe1.set_outputs(output)
-
-    pipe2 = Pipeline(6, 2, 0)
-    with pipe2:
-        dummy = types.Constant(numpy.ones((1)))
-        output = fn.python_function(dummy, function=lambda inp: verify_pipeline(pipe2, inp))
-        pipe2.set_outputs(output)
-
-    pipe1.build()
-    pipe2.build()
-    pipe1.run()
-    pipe2.run()
-
-
 @params(
     numpy.bool_,
     numpy.int_,
@@ -716,3 +692,16 @@ def test_pipe():
     pipe.build()

     _ = pipe.run()
+
+
+def test_delete_pipe_while_function_running():
+    def func(x):
+        time.sleep(0.02)
+        return x
+
+    for i in range(5):
+        with Pipeline(batch_size=1, num_threads=1, device_id=None) as pipe:
+            pipe.set_outputs(fn.python_function(types.Constant(0), function=func))
+            pipe.build()
+            pipe.run()
+        del pipe

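The new test exercises deletion while the function is still sleeping in the thread pool. A complementary check (hypothetical, not part of this commit; it assumes Pipeline instances support weak references) could assert that the pipeline is actually collectable once the cycle is gone:

    import gc
    import weakref

    from nvidia.dali import fn, types
    from nvidia.dali.pipeline import Pipeline


    def test_pipeline_is_freed():
        with Pipeline(batch_size=1, num_threads=1, device_id=None) as pipe:
            pipe.set_outputs(fn.python_function(types.Constant(0), function=lambda x: x))
            pipe.build()
            pipe.run()
        ref = weakref.ref(pipe)
        del pipe
        gc.collect()
        # Before this fix, the backend-held function kept the pipeline alive.
        assert ref() is None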