Skip to content

Commit

Permalink
Remove test that won't work
Browse files Browse the repository at this point in the history
Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
  • Loading branch information
klecki committed Jan 17, 2024
1 parent 769f2f5 commit b2b8221
Showing 1 changed file with 4 additions and 39 deletions.
43 changes: 4 additions & 39 deletions dali/test/python/operator_1/test_numba_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,12 +311,14 @@ def test_numba_func():

@with_setup(check_numba_compatibility_cpu)
def test_numba_func_with_cond():
# Check if the do_not_convert decorator doesn't mess with numba running the function.
# When the function is not converted, the numba still works with no issues.
# AG conversion or using a complex enough decorator would break this.
# TODO(klecki): Can we add any additional safeguards?
_testimpl_numba_func(
device="cpu",
shapes=[(10, 10, 10)],
dtype=np.uint8,
run_fn=do_not_convert(set_all_values_to_255_batch),
run_fn=set_all_values_to_255_batch,
out_types=[dali_types.UINT8],
in_types=[dali_types.UINT8],
outs_ndim=[3],
Expand All @@ -328,43 +330,6 @@ def test_numba_func_with_cond():
)


@with_setup(check_numba_compatibility_cpu)
def test_numba_func_with_converted_processing():
# Check if the autograph doesn't mess too much with numba running the function.

device = ("cpu",)
shapes = ([(10, 10, 10)],)
dtype = (np.uint8,)
expected_out = ([np.full((10, 10, 10), 255, dtype=np.uint8)],)

@pipeline_def(enable_conditionals=True)
def numba_func_pipe():
def set_all_values_to_255_batch(out0, in0):
out0[0][:] = 255

data = fn.external_source(lambda: get_data(shapes, dtype), batch=True, device=device)
return numba_function(
data,
run_fn=set_all_values_to_255_batch,
out_types=[dali_types.UINT8],
in_types=[dali_types.UINT8],
outs_ndim=[3],
ins_ndim=[3],
setup_fn=None,
batch_processing=True,
device="cpu",
)

batch_size = len(shapes)
pipe = numba_func_pipe(batch_size=batch_size, num_threads=1, device_id=0)
pipe.build()
for it in range(3):
outs = pipe.run()
for i in range(batch_size):
out_arr = to_array(outs[0][i])
assert np.array_equal(out_arr, expected_out[i])


@with_setup(check_numba_compatibility_gpu)
def test_numba_func_gpu():
# shape, dtype, run_fn, out_types,
Expand Down

0 comments on commit b2b8221

Please sign in to comment.