diff --git a/.github/scripts/run_tests.py b/.github/scripts/run_tests.py
index 3e16c8ffe..80283c93c 100644
--- a/.github/scripts/run_tests.py
+++ b/.github/scripts/run_tests.py
@@ -57,7 +57,8 @@ def get_bench_cmd(run_type, run_id, run_name, run_param_name, dtype):
         cmd = get_bench_cmd(run_type, run_id, run_name, run_param_name, run_dtype_name)
         outputs = run_command(cmd)
         if outputs:
-            latency = float(outputs[-1].split('\n')[0])  # Get last line
+            # The second-to-last line of every benchmark script's stdout is the latency (the last line is empty)
+            latency = float(outputs.split('\n')[-2])
             run_config['latency'] = latency
         else:
            run_config['latency'] = 999.99
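The fix above relies on a stdout convention rather than indexing into a list of lines. A minimal standalone sketch of that convention, using made-up benchmark output:

# Hypothetical benchmark output: the latency is the last non-empty line, so with the
# trailing newline it is the second-to-last element after split('\n').
sample_stdout = "warming up...\nrunning 10 iters\n1.234\n"
latency = float(sample_stdout.split('\n')[-2])
assert latency == 1.234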
diff --git a/apps/compile_server/resources/compilation.py b/apps/compile_server/resources/compilation.py
index 4b25e449f..ae4cc5a33 100644
--- a/apps/compile_server/resources/compilation.py
+++ b/apps/compile_server/resources/compilation.py
@@ -53,11 +53,15 @@ def clone_github_repo(owner: str, repo: str, version: str) -> str:
 
     # `version` is either a branch name, or 'pull/{n}' if coming from a pull request
     if should_update(repo_timestamp):
+        branches = repo.git.branch("--all").split()
+        # If the local branch already exists, delete it in preparation for a fresh checkout;
+        # it may have diverged from the remote, so we simply discard it
+        if version in branches:
+            repo.git.checkout('main')
+            repo.git.branch('-D', version)
         if 'pull/' in version:
-            branches = repo.git.branch("--all").split()
-            if version not in branches:
-                # `git fetch origin pull/{n}/head:pull/{n}` checks out PR#n into branch 'pull/{n}'
-                repo.remotes.origin.fetch(version + '/head:' + version)
+            # Equivalent to `git fetch origin pull/{n}/head:pull/{n}`: fetches PR #n into branch 'pull/{n}'
+            repo.remotes.origin.fetch(version + '/head:' + version)
             repo.git.checkout(version)
             repo.remotes.origin.pull(version + '/head')
         else:
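For reference, the same refresh flow expressed as plain git commands run through subprocess instead of GitPython (a hedged sketch: the clone path and PR number are made up, and error handling is omitted):

import subprocess

def run(*args, cwd='/tmp/hidet-cache/hidet'):              # hypothetical local clone path
    return subprocess.run(['git', *args], cwd=cwd, check=True, capture_output=True, text=True)

version = 'pull/123'                                       # hypothetical PR ref
branches = run('branch', '--all').stdout.split()
if version in branches:
    run('checkout', 'main')                                # move off the stale local branch...
    run('branch', '-D', version)                           # ...and discard it, since it may have diverged
if version.startswith('pull/'):
    run('fetch', 'origin', f'{version}/head:{version}')    # fetch PR head into branch 'pull/123'
run('checkout', version)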
diff --git a/python/hidet/graph/ops/pool.py b/python/hidet/graph/ops/pool.py
index e41b4546b..339300aaf 100644
--- a/python/hidet/graph/ops/pool.py
+++ b/python/hidet/graph/ops/pool.py
@@ -452,3 +452,24 @@ def resolve(self, op: Operator) -> Optional[List[Tensor]]:
         elif reduce_type == 'avg':
             return [mean(x, dims=dims[2:], keep_dim=True)]
         return None
+
+
+@register_resolve_rule(AdaptivePoolChannelLastOp)
+class AdaptivePoolChannelLastResolveRule(ResolveRule):
+    def resolve(self, op: Operator) -> Optional[List[Tensor]]:
+        assert isinstance(op, AdaptivePoolChannelLastOp)
+        x: Tensor = op.inputs[0]
+        # TODO: Deal with generic N-dimensional adaptive pooling
+        if len(x.shape) != 4:
+            return None
+        output_size = op.attrs['output_size']
+        reduce_type = op.reduce_type
+        resolve_to_reduce = output_size == 1 if isinstance(output_size, int) else all(d == 1 for d in output_size)
+        if resolve_to_reduce:
+            from hidet.graph.ops import mean, max
+
+            if reduce_type == 'max':
+                return [max(x, dims=[1, 2], keep_dim=True)]
+            elif reduce_type == 'avg':
+                return [mean(x, dims=[1, 2], keep_dim=True)]
+        return None
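A standalone NumPy sketch (illustrative only, not part of the patch) of the equivalence the new resolve rule relies on: for an NHWC tensor, adaptive avg/max pooling down to a 1x1 spatial output is exactly a mean/max over the two spatial dims.

import numpy as np

x = np.random.rand(2, 7, 5, 3)  # NHWC: batch=2, H=7, W=5, channels=3

avg_pooled = x.mean(axis=(1, 2), keepdims=True)  # what mean(x, dims=[1, 2], keep_dim=True) computes
max_pooled = x.max(axis=(1, 2), keepdims=True)   # what max(x, dims=[1, 2], keep_dim=True) computes
assert avg_pooled.shape == (2, 1, 1, 3) and max_pooled.shape == (2, 1, 1, 3)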
diff --git a/python/hidet/graph/ops/reduce/reduce.py b/python/hidet/graph/ops/reduce/reduce.py
index aaf95142c..7a0979aac 100644
--- a/python/hidet/graph/ops/reduce/reduce.py
+++ b/python/hidet/graph/ops/reduce/reduce.py
@@ -17,6 +17,7 @@
 from hidet.ir.type import DataType
 from hidet.ir.dtypes.vector import VectorType, vectorize
 from hidet.ir.library import tune
+from hidet.utils.py import cdiv
 from ..arithmetic import square, sqrt
 from ..utils import Task, Operator, Tensor, TensorNode, IRModule, ReduceType
 from ..utils import compute, input_like, normalize_dim, arg_reduce
@@ -61,7 +62,7 @@ def implement_cuda(self, working_dir: str) -> Union[IRModule, List[IRModule]]:
         if rank - 1 in self.dims:
             return tune.extract_ir_modules(self.cuda_schedule_reduce_by_warp)
         else:
-            return self.cuda_schedule_reduce_by_default()
+            return tune.extract_ir_modules(self.cuda_schedule_reduce_by_default)
 
     @tune.space(2, use_atomic=[True, False])
     @tune.space(1, use_atomic=[True, False])
@@ -87,7 +88,7 @@ def cuda_schedule_reduce_by_warp(self, use_atomic=True) -> IRModule:
                 lanes = num_eles
                 vtype = vectorize(xdtype, lanes)
         read_shape = shape[:]
-        read_shape[-1] /= lanes
+        read_shape[-1] //= lanes
         block_size = (read_shape[-1] + warp_size - 1) // warp_size * warp_size
         block_size = hidet.ir.expr.if_then_else(block_size > 1024, 1024, block_size)
 
@@ -192,77 +193,123 @@ def reduce_kernel(x: xdtype[x.shape], y: xdtype[y.shape]):
         ir_module = module.ir_module()
         return ir_module
 
-    def cuda_schedule_reduce_by_default(self) -> IRModule:
+    @tune.space(2, max_block_size=[256, 512, 1024], use_atomic=[True, False])
+    @tune.space(1, max_block_size=[256, 512, 1024], use_atomic=[True, False])
+    def cuda_schedule_reduce_by_default(self, max_block_size=256, use_atomic=True) -> IRModule:
         import hidet
         from hidet.ir.compute import ReduceOperation
-        from hidet.ir.type import data_type, Int
+        from hidet.ir.type import data_type, Int, tensor_type
+        from hidet.ir.expr import cast, address, is_constant
+        from hidet.ir.layout import row_major
         from hidet.lang import spatial, repeat, attrs, tensor_pointer
-        from hidet.ir.expr import cast, address
+        from hidet.lang.cuda import dynamic_shared_memory, syncthreads
+
+        WARP_SIZE = 32
+        max_num_warps = max_block_size // WARP_SIZE
 
         x, y = self.inputs[0], self.outputs[0]
         xdtype = x.type.dtype
         shape: List[Int] = list(x.shape)
+        accumulate_dtype = data_type(self.attrs['accumulate_dtype'])
+        reduce_type = self.attrs['reduce_type']
+        ro = ReduceOperation.from_name(reduce_type)
+        perform_atomic_reduce = use_atomic and ro.has_atomic(accumulate_dtype)
+
         lanes = 1
         vtype: Union[VectorType, DataType] = xdtype
         if xdtype.nbytes < 4:
             num_eles: int = 4 // xdtype.nbytes
             if shape[-1] % num_eles == 0:
                 lanes = num_eles
-                vtype = VectorType(xdtype, lanes)
+                vtype = vectorize(xdtype, lanes)
 
         read_shape = shape[:]
-        read_shape[-1] /= lanes
-
+        read_shape[-1] //= lanes
+        x_vectorized_shape = read_shape[:]
         dims = self.dims
+        read_remain_shape = [v for i, v in enumerate(shape) if i not in dims]
+        read_remain_shape[-1] //= lanes
 
         if self.keep_dim:
-            remain_shape = [v if i not in dims else 1 for i, v in enumerate(shape)]
+            write_remain_shape = [v if i not in dims else 1 for i, v in enumerate(shape)]
+            write_remain_shape[-1] //= lanes
         else:
-            remain_shape = [v for i, v in enumerate(shape) if i not in dims]
-            remain_shape[-1] /= lanes
-
-        reduce_extent = hidet.utils.prod(x.shape[i] for i in dims)
-
-        remain_extent = hidet.utils.prod(remain_shape)
-        block_size = hidet.ir.expr.if_then_else(256 < remain_extent, 256, remain_extent)
-        remain_layout = spatial(*remain_shape)
-
-        spatial_shape = []
-        repeat_shape = []
-        for i in range(len(read_shape)):
-            if i in dims:
-                spatial_shape.append(1)
-                repeat_shape.append(read_shape[i])
-            else:
-                spatial_shape.append(read_shape[i])
-                repeat_shape.append(1)
-        task_layout = repeat(*repeat_shape) * spatial(*spatial_shape)
-
-        grid_size = (remain_layout.num_workers + block_size - 1) // block_size
-        accumulate_dtype = self.attrs['accumulate_dtype']
-        reduce_type = self.attrs['reduce_type']
-        ro = ReduceOperation.from_name(reduce_type)
+            write_remain_shape = read_remain_shape[:]
+        y_vectorized_shape = write_remain_shape
+        reduce_shape = [v for i, v in enumerate(shape) if i in dims]
+        reduce_extent = hidet.utils.prod(reduce_shape)
+        remain_extent = hidet.utils.prod(read_remain_shape)
+        reduce_mapping = spatial(*reduce_shape)
+        remain_mapping = spatial(*read_remain_shape)
+        remain_threads = hidet.ir.expr.if_then_else(remain_extent > max_block_size, max_block_size, remain_extent)
+        remain_warps = cdiv(remain_threads, WARP_SIZE)
+        max_reduce_warps = max_num_warps // remain_warps
+        reduce_warps = cdiv(reduce_extent, WARP_SIZE)
+        reduce_warps = hidet.ir.expr.if_then_else(max_reduce_warps < reduce_warps, max_reduce_warps, reduce_warps)
+        repeats_per_reduce = cdiv(reduce_extent, reduce_warps)
+        num_warps = reduce_warps * remain_warps
+        block_size = num_warps * WARP_SIZE
+        grid_size = cdiv(remain_extent, remain_warps * WARP_SIZE)
+        read_task_mapping = (
+            spatial(1, grid_size) * spatial(reduce_warps, remain_warps * WARP_SIZE) * repeat(repeats_per_reduce, 1)
+        )
+        write_task_mapping = spatial(1, grid_size) * spatial(reduce_warps, remain_warps * WARP_SIZE)
+        remain_write_mapping = spatial(*write_remain_shape)
+
+        use_smem = not (is_constant(reduce_warps) and reduce_warps == 1)
+        smem_length = remain_warps * WARP_SIZE * lanes
+        smem_flattened_layout = row_major(smem_length)
+        smem_task_mapping = spatial(reduce_warps, remain_warps * WARP_SIZE) * repeat(1, lanes)
+        smem_type = tensor_type(accumulate_dtype, layout=smem_flattened_layout)
+        smem_needed = smem_type.storage_bytes() if use_smem else 0
+
+        def unflatten_read_idx(indices):
+            # `indices` is a 2D coordinate in the (reduce, remain) space
+            assert len(indices) == 2
+            reduce_indices = reduce_mapping.map(indices[0])
+            remain_indices = remain_mapping.map(indices[1])
+            unflattened_indices = []
+            remain_dim = 0
+            reduce_dim = 0
+            for i in range(len(shape)):
+                if i in dims:
+                    unflattened_indices.append(reduce_indices[reduce_dim])
+                    reduce_dim += 1
+                else:
+                    unflattened_indices.append(remain_indices[remain_dim])
+                    remain_dim += 1
+            return unflattened_indices
 
         with hidet.script_module() as module:
 
             @hidet.script
             def reduce_kernel(x: xdtype[x.shape], y: xdtype[y.shape]):
-                # Each 256-thread ThreadBlock handles 512 columns
                 attrs.cuda.grid_dim = grid_size
                 attrs.cuda.block_dim = block_size
                 attrs.cuda.min_blocks = 1
+                attrs.cuda.dynamic_smem_bytes = smem_needed
 
-                x_vectorized = tensor_pointer(vtype, shape=read_shape, init=cast(x, ~vtype))
-                y_vectorized = tensor_pointer(vtype, shape=remain_shape, init=cast(y, ~vtype))
-
+                x_vectorized = tensor_pointer(vtype, shape=x_vectorized_shape, init=cast(x, ~vtype))
+                y_vectorized = tensor_pointer(vtype, shape=y_vectorized_shape, init=cast(y, ~vtype))
                 rv = register_tensor(accumulate_dtype, [lanes])
                 for lane_id in grid(lanes, "u+"):
-                    rv[lane_id] = ro.initial_value(data_type(accumulate_dtype))
-
+                    rv[lane_id] = ro.initial_value(accumulate_dtype)
                 write_val = register_tensor(vtype, [1])
 
-                if threadIdx.x + blockIdx.x * block_size < remain_extent:
-                    for indices in task_layout.on(threadIdx.x + blockIdx.x * block_size):
+                smem_staging = tensor_pointer(accumulate_dtype, layout=smem_flattened_layout)
+                if use_smem:
+                    smem_base = dynamic_shared_memory(byte_offset=0, dtype=accumulate_dtype)
+                    smem_staging = cast(smem_base, ~accumulate_dtype)
+
+                # Initialize the smem staging buffer if needed
+                if use_smem and threadIdx.x * lanes < smem_length:
+                    for lane in range(lanes):
+                        smem_staging[threadIdx.x * lanes + lane] = rv[0]
+
+                # Read from global memory and perform a per-thread local reduction
+                for flat_indices in read_task_mapping.on(threadIdx.x + blockIdx.x * block_size):
+                    if flat_indices[0] < reduce_extent and flat_indices[1] < remain_extent:
+                        indices = unflatten_read_idx(flat_indices)
                         vec_read = x_vectorized[indices]
                         if lanes > 1:
                             for lane_id in grid(lanes, "u+"):
@@ -271,14 +318,52 @@ def reduce_kernel(x: xdtype[x.shape], y: xdtype[y.shape]):
                         else:
                             rv[0] = ro.combine(rv[0], cast(vec_read, accumulate_dtype))
 
-                    if lanes > 1:
-                        lane_vec = cast(~write_val, ~vtype.lane_type)
-                        for lane_id in grid(lanes, "u+"):
-                            lane_vec[lane_id] = ro.finalize(acc=rv[lane_id], size=reduce_extent)
+                # At this point, every thread holds its local reduction value in its register array rv[]
+                # Next, reduce those values into their respective smem locations if needed
+                if use_smem:
+                    syncthreads()
+                    if perform_atomic_reduce:
+                        for indices in smem_task_mapping.on(threadIdx.x):
+                            remain_idx = indices[1]
+                            ro.atomic_combine(~smem_staging[remain_idx], rv[remain_idx % lanes])
+                        syncthreads()
                     else:
-                        write_val[0] = ro.finalize(acc=rv[0], size=reduce_extent)
-                    for indices in remain_layout.on(threadIdx.x + blockIdx.x * block_size):
-                        y_vectorized[indices] = write_val[0]
+                        # Reduce via multi-round writebacks + syncthreads
+                        for k in range(reduce_warps):
+                            for indices in smem_task_mapping.on(threadIdx.x):
+                                reduce_round = indices[0]
+                                if reduce_round == k:
+                                    remain_idx = indices[1]
+                                    smem_staging[remain_idx] = ro.combine(
+                                        smem_staging[remain_idx], rv[remain_idx % lanes]
+                                    )
+                            syncthreads()
+
+                # At this point, the shared memory (or rv, if smem is not used) contains the final reduction value
+                # Next, write the result back to global memory
+                if threadIdx.x < remain_warps * WARP_SIZE:
+                    for indices in smem_task_mapping.on(threadIdx.x):
+                        remain_idx = indices[1]
+                        if lanes > 1:
+                            lane_vec = cast(~write_val, ~vtype.lane_type)
+                            if use_smem:
+                                lane_vec[remain_idx % lanes] = ro.finalize(
+                                    acc=smem_staging[remain_idx], size=reduce_extent
+                                )
+                            else:
+                                lane_vec[remain_idx % lanes] = ro.finalize(
+                                    acc=rv[remain_idx % lanes], size=reduce_extent
+                                )
+                        else:
+                            if use_smem:
+                                write_val[0] = ro.finalize(acc=smem_staging[remain_idx], size=reduce_extent)
+                            else:
+                                write_val[0] = ro.finalize(acc=rv[remain_idx % lanes], size=reduce_extent)
+                        for flat_indices in write_task_mapping.on(threadIdx.x + blockIdx.x * block_size):
+                            # flat_indices[0] is always 0 because threadIdx.x < remain_warps * WARP_SIZE
+                            if flat_indices[1] < remain_extent:
+                                indices = remain_write_mapping.map(flat_indices[1])
+                                y_vectorized[indices] = write_val[0]
 
         ir_module = module.ir_module()
         return ir_module
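To make the launch-configuration arithmetic in cuda_schedule_reduce_by_default concrete, here is a plain-Python sketch with made-up sizes; it uses min() where the schedule uses if_then_else on possibly symbolic shapes, and the numbers are illustrative only.

# Example: reduce a float32 tensor of shape [8, 1024, 96] over dims (0, 1),
# so reduce_extent = 8 * 1024 and remain_extent = 96 (lanes = 1 for float32).
WARP_SIZE = 32
max_block_size = 256                      # tuning parameter, space {256, 512, 1024}

def cdiv(a, b):
    return (a + b - 1) // b

reduce_extent, remain_extent = 8 * 1024, 96

max_num_warps = max_block_size // WARP_SIZE                           # 8
remain_threads = min(remain_extent, max_block_size)                   # 96
remain_warps = cdiv(remain_threads, WARP_SIZE)                        # 3
max_reduce_warps = max_num_warps // remain_warps                      # 2
reduce_warps = min(cdiv(reduce_extent, WARP_SIZE), max_reduce_warps)  # 2
repeats_per_reduce = cdiv(reduce_extent, reduce_warps)                # 4096 reduce elements per thread
block_size = reduce_warps * remain_warps * WARP_SIZE                  # 192 threads per block
grid_size = cdiv(remain_extent, remain_warps * WARP_SIZE)             # 1 block

print(block_size, grid_size, repeats_per_reduce)                      # 192 1 4096

Putting the remain extent on the fastest-varying thread index in spatial(reduce_warps, remain_warps * WARP_SIZE) appears intended to keep neighbouring threads on neighbouring output columns, so global reads stay coalesced.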
diff --git a/python/hidet/graph/ops/reduce/resolve.py b/python/hidet/graph/ops/reduce/resolve.py
index 933fc7886..7b6e6087a 100644
--- a/python/hidet/graph/ops/reduce/resolve.py
+++ b/python/hidet/graph/ops/reduce/resolve.py
@@ -13,7 +13,7 @@
 
 from hidet.graph.operator import Operator, Tensor
 from hidet.graph.transforms import ResolveRule, register_resolve_rule
-from hidet.graph.ops.utils import is_contiguous_dims
+from hidet.graph.ops.utils import is_contiguous_dims, normalize_dim
 from hidet.utils import prod
 from .reduce import ReduceBaseOp
 
@@ -56,9 +56,22 @@ def resolve_simplify(self, op: Operator) -> Optional[List[Tensor]]:
 
         return None
 
+    def resolve_decompose(self, op: Operator) -> Optional[List[Tensor]]:
+        dims = op.attrs['dims']
+        x: Tensor = op.inputs[0]
+        shape = x.shape
+        dims = normalize_dim(dims, len(shape))
+        if (len(shape) - 1) not in dims and len(dims) > 1:
+            # Reduce the highest dim first so the remaining dim indices stay valid regardless of keep_dim
+            dims.sort(reverse=True)
+            for dim in dims:
+                x = op.reforward([x], {'dims': [dim]})[0]
+            return [x]
+        return None
+
     def resolve(self, op: Operator) -> Optional[List[Tensor]]:
         assert isinstance(op, ReduceBaseOp)
-        resolve_funcs: List[Callable[[Operator], Any]] = [self.resolve_simplify]
+        resolve_funcs: List[Callable[[Operator], Any]] = [self.resolve_simplify, self.resolve_decompose]
         for resolve_func in resolve_funcs:
             outs = resolve_func(op)
             if outs is not None:
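A standalone NumPy sketch (illustrative only, not part of the patch) of what resolve_decompose does: a reduction over several non-innermost dims is rewritten as a chain of single-dim reductions, applied from the highest dim down so the remaining dim indices stay valid when keep_dim=False.

import numpy as np

x = np.random.rand(4, 5, 6, 7)
dims = [1, 2]                            # innermost dim (3) is not reduced, so the rule applies

expected = x.sum(axis=tuple(dims))

out = x
for dim in sorted(dims, reverse=True):   # reduce dim 2 first, then dim 1
    out = out.sum(axis=dim)

assert np.allclose(out, expected)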