Skip to content

Commit

Permalink
Wip
Browse files Browse the repository at this point in the history
Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
  • Loading branch information
klecki committed Jan 17, 2024
1 parent 4e532c1 commit d665985
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 46 deletions.
1 change: 1 addition & 0 deletions dali/python/nvidia/dali/_autograph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from nvidia.dali._autograph.impl.api import convert
from nvidia.dali._autograph.impl.api import converted_call
from nvidia.dali._autograph.impl.api import do_not_convert
from nvidia.dali._autograph.impl.api import autograph_artifact
from nvidia.dali._autograph.impl.api import is_autograph_artifact

# from nvidia.dali._autograph.impl.api import StackTraceMapper
Expand Down
24 changes: 16 additions & 8 deletions dali/python/nvidia/dali/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1866,6 +1866,9 @@ def my_pipe():
"""

def actual_decorator(func):
if _conditionals._autograph.is_autograph_artifact(func):
raise ValueError("Pipeline definition cannot be marked with @do_not_convert.")

@functools.wraps(func)
def create_pipeline(*args, **kwargs):
conditionals_on = kwargs.get("enable_conditionals", enable_conditionals)
Expand Down Expand Up @@ -1897,23 +1900,27 @@ def do_not_convert(func: _F = None) -> _F:
to transform the code, enabling us to rewrite and detect the ``if`` statements, so they can be
used in processing the DALI pipeline.
When used with :meth:`external source <nvidia.dali.fn.external_source>` in parallel mode
(``parallel=True``), this may interfere with the serialization of the provided ``source``
parameter. To prevent this, functions that are used to create the ``source`` parameter,
should be decorated with :meth:`@do_not_convert <nvidia.dali.pipeline.do_not_convert>`.
The AutoGraph conversion is applied to any top-level function or method called within the
pipeline definition (as well as the pipeline definition itself).
When a function is converted, all functions defined within its syntactical scope are also
converted.
converted. The rewriting, among other effects, makes these functions non-serializable.
To stop a function from being converted, its top-level encompassing function must be marked
with this decorator. This may sometimes require moving the function to an outer scope.
The parallel mode of :meth:`external source <nvidia.dali.fn.external_source>` (``parallel=True``)
requires that its ``source`` parameter is serializable. To prevent the rewriting of the
``source``, the functions that are used to create the ``source``
should be decorated with :meth:`@do_not_convert <nvidia.dali.pipeline.do_not_convert>`.
.. note::
Only functions that do not process :class:`DataNode` (so do not use DALI operators)
should be marked with this decorator.
.. note::
If a function is declared outside of the pipeline definition, and is passed as a parameter,
but not directly invoked within the pipeline definition, it will not be converted.
For example::
from nvidia.dali import pipeline_def, fn
Expand Down Expand Up @@ -1954,12 +1961,12 @@ def pipe():
return do_not_convert

if getattr(func, "_is_pipeline_def", False):
# TODO(klecki): The other way round as well?
raise ValueError("Pipeline definition cannot be marked with @do_not_convert.")

def wrapper(*args, **kwargs):
result = func(*args, **kwargs)

# Best effort at preventing user from not-converting pipeline code.
def disallow_data_node(node):
if isinstance(node, DataNode):
raise TypeError(
Expand All @@ -1974,7 +1981,8 @@ def disallow_data_node(node):
if inspect.isfunction(func) or inspect.ismethod(func):
wrapper = functools.update_wrapper(wrapper, func)

return _conditionals._autograph.do_not_convert(wrapper)
# TODO(klecki): We may also just use _autograph.autograph_artifact(func) here.
return _conditionals._autograph.autograph_artifact(wrapper)


def _collect_ops(output_nodes):
Expand Down
13 changes: 7 additions & 6 deletions dali/test/python/conditionals/test_external_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,18 +127,19 @@ def test_parallel_es_with_not_converted_callback():
es_with_nonlocal_not_converted_source(True)


@raises(ValueError, "EEEEAEEAEAE")
@raises(
TypeError,
"Functions that process DataNodes should not be marked with @do_not_convert. Found return "
"element of class DataNode when calling*",
)
def test_do_not_convert_data_node():
# TODO(klecki): Somehow this breaks the single required argument
@do_not_convert
def source(si):
def helper_constant():
return types.Constant(np.array([10]))

print(f"{inspect.signature(source)=}")

@pipeline_def(batch_size=4, num_threads=1, device_id=None, enable_conditionals=True)
def pipe():
return fn.external_source(source=source, batch=False)
return helper_constant()

p = pipe()
p.build()
Expand Down
121 changes: 89 additions & 32 deletions dali/test/python/operator_1/test_numba_func.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -15,6 +15,7 @@
import numpy as np
import os
from nvidia.dali import pipeline_def
from nvidia.dali.pipeline import do_not_convert
import nvidia.dali as dali
import nvidia.dali.fn as fn
import nvidia.dali.types as dali_types
Expand Down Expand Up @@ -127,37 +128,6 @@ def get_data_zeros(shapes, dtype):
return [np.zeros(shape, dtype=dtype) for shape in shapes]


@pipeline_def
def numba_func_pipe(
    shapes,
    dtype,
    device="cpu",
    run_fn=None,
    out_types=None,
    in_types=None,
    outs_ndim=None,
    ins_ndim=None,
    setup_fn=None,
    batch_processing=None,
    blocks=None,
    threads_per_block=None,
):
    # Pipeline that feeds batches produced by get_data(shapes, dtype) through
    # numba_function; every remaining argument is forwarded to the operator unchanged.
    operator_kwargs = dict(
        run_fn=run_fn,
        out_types=out_types,
        in_types=in_types,
        outs_ndim=outs_ndim,
        ins_ndim=ins_ndim,
        setup_fn=setup_fn,
        batch_processing=batch_processing,
        device=device,
        blocks=blocks,
        threads_per_block=threads_per_block,
    )
    # The external source is placed on the same device as the numba operator.
    batch = fn.external_source(lambda: get_data(shapes, dtype), batch=True, device=device)
    return numba_function(batch, **operator_kwargs)


def _testimpl_numba_func(
device,
shapes,
Expand All @@ -172,7 +142,38 @@ def _testimpl_numba_func(
expected_out,
blocks=None,
threads_per_block=None,
enable_conditionals=False,
):
@pipeline_def(enable_conditionals=enable_conditionals)
def numba_func_pipe(
shapes,
dtype,
device="cpu",
run_fn=None,
out_types=None,
in_types=None,
outs_ndim=None,
ins_ndim=None,
setup_fn=None,
batch_processing=None,
blocks=None,
threads_per_block=None,
):
data = fn.external_source(lambda: get_data(shapes, dtype), batch=True, device=device)
return numba_function(
data,
run_fn=run_fn,
out_types=out_types,
in_types=in_types,
outs_ndim=outs_ndim,
ins_ndim=ins_ndim,
setup_fn=setup_fn,
batch_processing=batch_processing,
device=device,
blocks=blocks,
threads_per_block=threads_per_block,
)

batch_size = len(shapes)
pipe = numba_func_pipe(
batch_size=batch_size,
Expand Down Expand Up @@ -308,6 +309,62 @@ def test_numba_func():
)


@with_setup(check_numba_compatibility_cpu)
def test_numba_func_with_cond():
    # Check if the do_not_convert decorator doesn't mess with numba running the function.
    sample_shape = (10, 10, 10)
    # The run function is wrapped explicitly, and the pipeline is built with
    # conditionals enabled.
    not_converted_run_fn = do_not_convert(set_all_values_to_255_batch)
    # The run function is expected to fill the whole output sample with 255.
    reference = np.full(sample_shape, 255, dtype=np.uint8)
    _testimpl_numba_func(
        device="cpu",
        shapes=[sample_shape],
        dtype=np.uint8,
        run_fn=not_converted_run_fn,
        out_types=[dali_types.UINT8],
        in_types=[dali_types.UINT8],
        outs_ndim=[3],
        ins_ndim=[3],
        setup_fn=None,
        batch_processing=True,
        expected_out=[reference],
        enable_conditionals=True,
    )


@with_setup(check_numba_compatibility_cpu)
def test_numba_func_with_converted_processing():
    # Check if the autograph doesn't mess too much with numba running the function.

    # NOTE: these are plain values. Trailing commas here would wrap each of them in a
    # 1-element tuple, which breaks get_data(), the `device` argument and the
    # comparison against expected_out below.
    device = "cpu"
    shapes = [(10, 10, 10)]
    dtype = np.uint8
    expected_out = [np.full((10, 10, 10), 255, dtype=np.uint8)]

    @pipeline_def(enable_conditionals=True)
    def numba_func_pipe():
        # The run function is deliberately defined inside the pipeline definition and is
        # not marked with @do_not_convert, so it lies within the scope of the code
        # processed with conditionals enabled.
        def set_all_values_to_255_batch(out0, in0):
            out0[0][:] = 255

        data = fn.external_source(lambda: get_data(shapes, dtype), batch=True, device=device)
        return numba_function(
            data,
            run_fn=set_all_values_to_255_batch,
            out_types=[dali_types.UINT8],
            in_types=[dali_types.UINT8],
            outs_ndim=[3],
            ins_ndim=[3],
            setup_fn=None,
            batch_processing=True,
            device=device,
        )

    # One sample per entry in `shapes`.
    batch_size = len(shapes)
    pipe = numba_func_pipe(batch_size=batch_size, num_threads=1, device_id=0)
    pipe.build()
    # Run a few iterations and compare every output sample with its reference.
    for _ in range(3):
        outs = pipe.run()
        for i in range(batch_size):
            out_arr = to_array(outs[0][i])
            assert np.array_equal(out_arr, expected_out[i])


@with_setup(check_numba_compatibility_gpu)
def test_numba_func_gpu():
# shape, dtype, run_fn, out_types,
Expand Down

0 comments on commit d665985

Please sign in to comment.