Skip to content

Commit

Permalink
[Cherry-Pick][TOPI] improve inclusive_scan for thrust
Browse files Browse the repository at this point in the history
Fix comments
  • Loading branch information
yongwww authored and MasterJH5574 committed Feb 29, 2024
1 parent 2c1ce3a commit f178458
Showing 1 changed file with 33 additions and 9 deletions.
42 changes: 33 additions & 9 deletions python/tvm/topi/cuda/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,21 @@ def _get_thrust_func_name(tvmop):
return tvmop_to_thrust_func_name[tvmop]


def _can_use_scan_thrust(binop):
    """Check whether scan_thrust can be used for the current target and binary op.

    Parameters
    ----------
    binop : callable
        The binary operator requested for the scan. Only
        ``tvm.tir.generic.add`` is accepted, because the only packed function
        probed here is the sum scan (``tvm.contrib.thrust.sum_scan``).

    Returns
    -------
    usable : bool
        False when no target is currently set or when ``binop`` is not
        ``tvm.tir.generic.add``. Otherwise True iff ``can_use_thrust`` or
        ``can_use_rocthrust`` reports the sum-scan function as usable for
        the current target.
    """
    target = tvm.target.Target.current()
    if target is None:
        return False
    if not binop == tvm.tir.generic.add:
        return False

    func_name = "tvm.contrib.thrust.sum_scan"
    # Short-circuit: the rocthrust probe only runs when the thrust probe fails.
    # bool() keeps the return value a strict boolean in case a probe hands back
    # a non-bool truthy object (helper return types are not visible here).
    return bool(can_use_thrust(target, func_name) or can_use_rocthrust(target, func_name))


def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, identity_value=0):
"""Low level IR to do exclusive sum scan along rows of 2D input.
Expand Down Expand Up @@ -363,17 +378,9 @@ def exclusive_scan(
"""

def do_scan(data, output_dtype):
target = tvm.target.Target.current()

# TODO: add support for a prod_scan
if (
target
and binop == tvm.tir.generic.add
and (
can_use_thrust(target, "tvm.contrib.thrust.sum_scan")
or can_use_rocthrust(target, "tvm.contrib.thrust.sum_scan")
)
):
if _can_use_scan_thrust(binop):
return scan_thrust(
data, output_dtype, exclusive=True, return_reduction=return_reduction, binop=binop
)
Expand Down Expand Up @@ -479,6 +486,23 @@ def inclusive_scan(data, axis=-1, output_dtype=None, binop=tvm.tir.generic.add,
output : tvm.te.Tensor
A N-D tensor of the same rank N as the input data.
"""

if _can_use_scan_thrust(binop):
if output_dtype is None or output_dtype == "":
output_dtype = data.dtype
ndim = len(data.shape)
if axis < 0:
axis += ndim

if axis != ndim - 1:
axes = swap(list(range(ndim)), axis)
data = transpose(data, axes)
output = scan_thrust(data, output_dtype, exclusive=False, binop=binop)
if axis != ndim - 1:
axes = swap(list(range(ndim)), axis)
output = transpose(output, axes)
return output

ex_scan = exclusive_scan(
data, axis, output_dtype=output_dtype, binop=binop, identity_value=identity_value
)
Expand Down

0 comments on commit f178458

Please sign in to comment.