Skip to content

Commit

Permalink
[Operator] Enhancements to Reduce (#366)
Browse files Browse the repository at this point in the history
In some input shapes, the current reduce schedule will underutilize the
GPU.
E.g., `reduce [1, 128, 128, 3] , dims=[1, 2]` will spawn 1 threadblock
with 3 threads that each iterate over 128*128 elements.
This PR made two changes to optimize these cases:
1. Add resolve_decompose in the resolve logic of Reduce. This will force
launch separate kernels for each reduce dimension, increasing
concurrency.
2. In the default reduce schedule template, spawn multiple warps within
the reduce dimensions, which then will communicate via shared memory or
use atomics to perform the reduce.

Also added a resolve rule for AdaptivePoolChannelLast.
  • Loading branch information
hjjq authored Dec 20, 2023
1 parent 57d859a commit 2040a7c
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 55 deletions.
3 changes: 2 additions & 1 deletion .github/scripts/run_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ def get_bench_cmd(run_type, run_id, run_name, run_param_name, dtype):
cmd = get_bench_cmd(run_type, run_id, run_name, run_param_name, run_dtype_name)
outputs = run_command(cmd)
if outputs:
latency = float(outputs[-1].split('\n')[0]) # Get last line
# The second last line of All benchmark scripts' stdout is the latency. (Last line is empty)
latency = float(outputs.split('\n')[-2])
run_config['latency'] = latency
else:
run_config['latency'] = 999.99
Expand Down
12 changes: 8 additions & 4 deletions apps/compile_server/resources/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,15 @@ def clone_github_repo(owner: str, repo: str, version: str) -> str:

# `version` is either a branch name, or 'pull/{n}' if coming from a pull request
if should_update(repo_timestamp):
branches = repo.git.branch("--all").split()
# If local branch already exists, delete it as we prepare to do a new fresh checkout
# This is because the local branch might be divergent with remote, so we just discard it
if version in branches:
repo.git.checkout('main')
repo.git.branch('-D', version)
if 'pull/' in version:
branches = repo.git.branch("--all").split()
if version not in branches:
# `git fetch origin pull/{n}/head:pull/{n}` checks out PR#n into branch 'pull/{n}'
repo.remotes.origin.fetch(version + '/head:' + version)
# Equivalent to `git fetch origin pull/{n}/head:pull/{n}`. Checks out PR#n into branch 'pull/{n}'
repo.remotes.origin.fetch(version + '/head:' + version)
repo.git.checkout(version)
repo.remotes.origin.pull(version + '/head')
else:
Expand Down
21 changes: 21 additions & 0 deletions python/hidet/graph/ops/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,3 +452,24 @@ def resolve(self, op: Operator) -> Optional[List[Tensor]]:
elif reduce_type == 'avg':
return [mean(x, dims=dims[2:], keep_dim=True)]
return None


@register_resolve_rule(AdaptivePoolChannelLastOp)
class AdaptivePoolChannelLastResolveRule(ResolveRule):
def resolve(self, op: Operator) -> Optional[List[Tensor]]:
assert isinstance(op, AdaptivePoolChannelLastOp)
x: Tensor = op.inputs[0]
# TODO: Deal with generic N-dimensional convolution
if len(x.shape) != 4:
return None
output_size = op.attrs['output_size']
reduce_type = op.reduce_type
resolve_to_reduce = output_size == 1 if isinstance(output_size, int) else all(d == 1 for d in output_size)
if resolve_to_reduce:
from hidet.graph.ops import mean, max

if reduce_type == 'max':
return [max(x, dims=[1, 2], keep_dim=True)]
elif reduce_type == 'avg':
return [mean(x, dims=[1, 2], keep_dim=True)]
return None
181 changes: 133 additions & 48 deletions python/hidet/graph/ops/reduce/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from hidet.ir.type import DataType
from hidet.ir.dtypes.vector import VectorType, vectorize
from hidet.ir.library import tune
from hidet.utils.py import cdiv
from ..arithmetic import square, sqrt
from ..utils import Task, Operator, Tensor, TensorNode, IRModule, ReduceType
from ..utils import compute, input_like, normalize_dim, arg_reduce
Expand Down Expand Up @@ -61,7 +62,7 @@ def implement_cuda(self, working_dir: str) -> Union[IRModule, List[IRModule]]:
if rank - 1 in self.dims:
return tune.extract_ir_modules(self.cuda_schedule_reduce_by_warp)
else:
return self.cuda_schedule_reduce_by_default()
return tune.extract_ir_modules(self.cuda_schedule_reduce_by_default)

@tune.space(2, use_atomic=[True, False])
@tune.space(1, use_atomic=[True, False])
Expand All @@ -87,7 +88,7 @@ def cuda_schedule_reduce_by_warp(self, use_atomic=True) -> IRModule:
lanes = num_eles
vtype = vectorize(xdtype, lanes)
read_shape = shape[:]
read_shape[-1] /= lanes
read_shape[-1] //= lanes
block_size = (read_shape[-1] + warp_size - 1) // warp_size * warp_size
block_size = hidet.ir.expr.if_then_else(block_size > 1024, 1024, block_size)

Expand Down Expand Up @@ -192,77 +193,123 @@ def reduce_kernel(x: xdtype[x.shape], y: xdtype[y.shape]):
ir_module = module.ir_module()
return ir_module

def cuda_schedule_reduce_by_default(self) -> IRModule:
@tune.space(2, max_block_size=[256, 512, 1024], use_atomic=[True, False])
@tune.space(1, max_block_size=[256, 512, 1024], use_atomic=[True, False])
def cuda_schedule_reduce_by_default(self, max_block_size=256, use_atomic=True) -> IRModule:
import hidet
from hidet.ir.compute import ReduceOperation
from hidet.ir.type import data_type, Int
from hidet.ir.type import data_type, Int, tensor_type
from hidet.ir.expr import cast, address, is_constant
from hidet.ir.layout import row_major
from hidet.lang import spatial, repeat, attrs, tensor_pointer
from hidet.ir.expr import cast, address
from hidet.lang.cuda import dynamic_shared_memory, syncthreads

WARP_SIZE = 32
max_num_warps = max_block_size // WARP_SIZE

x, y = self.inputs[0], self.outputs[0]
xdtype = x.type.dtype
shape: List[Int] = list(x.shape)

accumulate_dtype = data_type(self.attrs['accumulate_dtype'])
reduce_type = self.attrs['reduce_type']
ro = ReduceOperation.from_name(reduce_type)
perform_atomic_reduce = use_atomic and ro.has_atomic(accumulate_dtype)

lanes = 1
vtype: Union[VectorType, DataType] = xdtype
if xdtype.nbytes < 4:
num_eles: int = 4 // xdtype.nbytes
if shape[-1] % num_eles == 0:
lanes = num_eles
vtype = VectorType(xdtype, lanes)
vtype = vectorize(xdtype, lanes)

read_shape = shape[:]
read_shape[-1] /= lanes

read_shape[-1] //= lanes
x_vectorized_shape = read_shape[:]
dims = self.dims
read_remain_shape = [v for i, v in enumerate(shape) if i not in dims]
read_remain_shape[-1] //= lanes
if self.keep_dim:
remain_shape = [v if i not in dims else 1 for i, v in enumerate(shape)]
write_remain_shape = [v if i not in dims else 1 for i, v in enumerate(shape)]
write_remain_shape[-1] //= lanes
else:
remain_shape = [v for i, v in enumerate(shape) if i not in dims]
remain_shape[-1] /= lanes

reduce_extent = hidet.utils.prod(x.shape[i] for i in dims)

remain_extent = hidet.utils.prod(remain_shape)
block_size = hidet.ir.expr.if_then_else(256 < remain_extent, 256, remain_extent)
remain_layout = spatial(*remain_shape)

spatial_shape = []
repeat_shape = []
for i in range(len(read_shape)):
if i in dims:
spatial_shape.append(1)
repeat_shape.append(read_shape[i])
else:
spatial_shape.append(read_shape[i])
repeat_shape.append(1)
task_layout = repeat(*repeat_shape) * spatial(*spatial_shape)

grid_size = (remain_layout.num_workers + block_size - 1) // block_size
accumulate_dtype = self.attrs['accumulate_dtype']
reduce_type = self.attrs['reduce_type']
ro = ReduceOperation.from_name(reduce_type)
write_remain_shape = read_remain_shape[:]
y_vectorized_shape = write_remain_shape
reduce_shape = [v for i, v in enumerate(shape) if i in dims]
reduce_extent = hidet.utils.prod(reduce_shape)
remain_extent = hidet.utils.prod(read_remain_shape)
reduce_mapping = spatial(*reduce_shape)
remain_mapping = spatial(*read_remain_shape)
remain_threads = hidet.ir.expr.if_then_else(remain_extent > max_block_size, max_block_size, remain_extent)
remain_warps = cdiv(remain_threads, WARP_SIZE)
max_reduce_warps = max_num_warps // remain_warps
reduce_warps = cdiv(reduce_extent, WARP_SIZE)
reduce_warps = hidet.ir.expr.if_then_else(max_reduce_warps < reduce_warps, max_reduce_warps, reduce_warps)
repeats_per_reduce = cdiv(reduce_extent, reduce_warps)
num_warps = reduce_warps * remain_warps
block_size = num_warps * WARP_SIZE
grid_size = cdiv(remain_extent, remain_warps * WARP_SIZE)
read_task_mapping = (
spatial(1, grid_size) * spatial(reduce_warps, remain_warps * WARP_SIZE) * repeat(repeats_per_reduce, 1)
)
write_task_mapping = spatial(1, grid_size) * spatial(reduce_warps, remain_warps * WARP_SIZE)
remain_write_mapping = spatial(*write_remain_shape)

use_smem = not (is_constant(reduce_warps) and reduce_warps == 1)
smem_length = remain_warps * WARP_SIZE * lanes
smem_flattened_layout = row_major(smem_length)
smem_task_mapping = spatial(reduce_warps, remain_warps * WARP_SIZE) * repeat(1, lanes)
smem_type = tensor_type(accumulate_dtype, layout=smem_flattened_layout)
smem_needed = smem_type.storage_bytes() if use_smem else 0

def unflatten_read_idx(indices):
# indices should only contain a 2D coordinate in the (remain, reduce) space
assert len(indices) == 2
reduce_indices = reduce_mapping.map(indices[0])
remain_indices = remain_mapping.map(indices[1])
unflattened_indices = []
remain_dim = 0
reduce_dim = 0
for i in range(len(shape)):
if i in dims:
unflattened_indices.append(reduce_indices[reduce_dim])
reduce_dim += 1
else:
unflattened_indices.append(remain_indices[remain_dim])
remain_dim += 1
return unflattened_indices

with hidet.script_module() as module:

@hidet.script
def reduce_kernel(x: xdtype[x.shape], y: xdtype[y.shape]):
# Each 256-thread ThreadBlock handles 512 columns
attrs.cuda.grid_dim = grid_size
attrs.cuda.block_dim = block_size
attrs.cuda.min_blocks = 1
attrs.cuda.dynamic_smem_bytes = smem_needed

x_vectorized = tensor_pointer(vtype, shape=read_shape, init=cast(x, ~vtype))
y_vectorized = tensor_pointer(vtype, shape=remain_shape, init=cast(y, ~vtype))

x_vectorized = tensor_pointer(vtype, shape=x_vectorized_shape, init=cast(x, ~vtype))
y_vectorized = tensor_pointer(vtype, shape=y_vectorized_shape, init=cast(y, ~vtype))
rv = register_tensor(accumulate_dtype, [lanes])
for lane_id in grid(lanes, "u+"):
rv[lane_id] = ro.initial_value(data_type(accumulate_dtype))

rv[lane_id] = ro.initial_value(accumulate_dtype)
write_val = register_tensor(vtype, [1])

if threadIdx.x + blockIdx.x * block_size < remain_extent:
for indices in task_layout.on(threadIdx.x + blockIdx.x * block_size):
smem_staging = tensor_pointer(accumulate_dtype, layout=smem_flattened_layout)
if use_smem:
smem_base = dynamic_shared_memory(byte_offset=0, dtype=accumulate_dtype)
smem_staging = cast(smem_base, ~accumulate_dtype)

# Init smem if needed
if use_smem and threadIdx.x * lanes < smem_length:
for lane in range(lanes):
smem_staging[threadIdx.x * lanes + lane] = rv[0]

# Read from global memory and perform local reduce
for flat_indices in read_task_mapping.on(threadIdx.x + blockIdx.x * block_size):
if flat_indices[0] < reduce_extent and flat_indices[1] < remain_extent:
indices = unflatten_read_idx(flat_indices)
vec_read = x_vectorized[indices]
if lanes > 1:
for lane_id in grid(lanes, "u+"):
Expand All @@ -271,14 +318,52 @@ def reduce_kernel(x: xdtype[x.shape], y: xdtype[y.shape]):
else:
rv[0] = ro.combine(rv[0], cast(vec_read, accumulate_dtype))

if lanes > 1:
lane_vec = cast(~write_val, ~vtype.lane_type)
for lane_id in grid(lanes, "u+"):
lane_vec[lane_id] = ro.finalize(acc=rv[lane_id], size=reduce_extent)
# At this point, all threads contain their local reduction value in their register rv[]
# Next, need to reduce those values into respective smem location if needed
if use_smem:
syncthreads()
if perform_atomic_reduce:
for indices in smem_task_mapping.on(threadIdx.x):
remain_idx = indices[1]
ro.atomic_combine(~smem_staging[remain_idx], rv[remain_idx % lanes])
syncthreads()
else:
write_val[0] = ro.finalize(acc=rv[0], size=reduce_extent)
for indices in remain_layout.on(threadIdx.x + blockIdx.x * block_size):
y_vectorized[indices] = write_val[0]
# Reduce via multiround writebacks + syncthreads
for k in range(reduce_warps):
for indices in smem_task_mapping.on(threadIdx.x):
reduce_round = indices[0]
if reduce_round == k:
remain_idx = indices[1]
smem_staging[remain_idx] = ro.combine(
smem_staging[remain_idx], rv[remain_idx % lanes]
)
syncthreads()

# At this point, the shared memory (or rv, if not using smem) contains the final reduction value.
# Next, need to write back to global memory
if threadIdx.x < remain_warps * WARP_SIZE:
for indices in smem_task_mapping.on(threadIdx.x):
remain_idx = indices[1]
if lanes > 1:
lane_vec = cast(~write_val, ~vtype.lane_type)
if use_smem:
lane_vec[remain_idx % lanes] = ro.finalize(
acc=smem_staging[remain_idx], size=reduce_extent
)
else:
lane_vec[remain_idx % lanes] = ro.finalize(
acc=rv[remain_idx % lanes], size=reduce_extent
)
else:
if use_smem:
write_val[0] = ro.finalize(acc=smem_staging[remain_idx], size=reduce_extent)
else:
write_val[0] = ro.finalize(acc=rv[remain_idx % lanes], size=reduce_extent)
for flat_indices in write_task_mapping.on(threadIdx.x + blockIdx.x * block_size):
# flat_indices[0] will always be 0 because threadIdx.x < reduce_warps * WARP_SIZE
if flat_indices[1] < remain_extent:
indices = remain_write_mapping.map(flat_indices[1])
y_vectorized[indices] = write_val[0]

ir_module = module.ir_module()
return ir_module
Expand Down
17 changes: 15 additions & 2 deletions python/hidet/graph/ops/reduce/resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from hidet.graph.operator import Operator, Tensor
from hidet.graph.transforms import ResolveRule, register_resolve_rule
from hidet.graph.ops.utils import is_contiguous_dims
from hidet.graph.ops.utils import is_contiguous_dims, normalize_dim
from hidet.utils import prod
from .reduce import ReduceBaseOp

Expand Down Expand Up @@ -56,9 +56,22 @@ def resolve_simplify(self, op: Operator) -> Optional[List[Tensor]]:

return None

def resolve_decompose(self, op: Operator) -> Optional[List[Tensor]]:
dims = op.attrs['dims']
x: Tensor = op.inputs[0]
shape = x.shape
dims = normalize_dim(dims, len(shape))
if (len(shape) - 1) not in dims and len(dims) > 1:
# start from highest dim to support keepdims=True
dims.sort(reverse=True)
for dim in dims:
x = op.reforward([x], {'dims': [dim]})[0]
return [x]
return None

def resolve(self, op: Operator) -> Optional[List[Tensor]]:
assert isinstance(op, ReduceBaseOp)
resolve_funcs: List[Callable[[Operator], Any]] = [self.resolve_simplify]
resolve_funcs: List[Callable[[Operator], Any]] = [self.resolve_simplify, self.resolve_decompose]
for resolve_func in resolve_funcs:
outs = resolve_func(op)
if outs is not None:
Expand Down

0 comments on commit 2040a7c

Please sign in to comment.