From 4980a7c4b2fc193c179596a6fb359eb1cbf2c58b Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Wed, 27 Dec 2023 20:49:31 -0500 Subject: [PATCH] Support advanced options for pooling operators 1. Support count_include_pad and divisor_override in average pool. 2. Refactor avg and max pool operators computation definition --- .../frontend/torch/register_functions.py | 26 +- .../graph/frontend/torch/register_modules.py | 7 + python/hidet/graph/ops/pool.py | 604 +++++++++++------- .../graph/transforms/subgraph_rewrite.py | 3 +- python/hidet/ir/schedulers/base.py | 35 +- tests/operators/test_pool.py | 59 +- 6 files changed, 445 insertions(+), 289 deletions(-) diff --git a/python/hidet/graph/frontend/torch/register_functions.py b/python/hidet/graph/frontend/torch/register_functions.py index 95a187b37..8b66104b3 100644 --- a/python/hidet/graph/frontend/torch/register_functions.py +++ b/python/hidet/graph/frontend/torch/register_functions.py @@ -328,27 +328,23 @@ def unsqueeze(x: Tensor, dim: int): def avg_pool2d( x: Tensor, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None ): - if ceil_mode: - raise NotImplementedError("ceil_mode=True") - if not count_include_pad: - raise NotImplementedError("count_include_pad=False") - if divisor_override is not None: - raise NotImplementedError("divisor_override is not None") if stride is None: stride = kernel_size - y = ops.avg_pool2d(x, kernel_size, stride, padding) + y = ops.avg_pool2d( + x, + kernel_size, + stride, + padding, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + divisor_override=divisor_override, + ) return y @register_function(torch.nn.functional.avg_pool3d) def avg_pool3d(x: Tensor, kernel_size, stride, padding, ceil_mode=False, count_include_pad=True, divisor_override=None): - if ceil_mode: - raise NotImplementedError("ceil_mode=True") - if not count_include_pad: - raise NotImplementedError("count_include_pad=False") - if divisor_override is not None: - raise NotImplementedError("divisor_override is not None") - y = ops.avg_pool3d(x, kernel_size, stride, padding) + y = ops.avg_pool3d(x, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) return y @@ -1238,7 +1234,7 @@ def isinf(x: Tensor) -> Tensor: @register_function(torch.nn.functional.pad) -def torch_pad(x: Tensor, pad: Union[Tuple[int], List[int]], mode: str = 'constant', value=0): +def torch_pad(x: Tensor, pad: Union[Tuple[int, ...], List[int]], mode: str = 'constant', value=0): if isinstance(pad, tuple): pad = list(pad) # Torch's pad list has form [p2left, p2right, p1left, p1right, p0left, p0right] diff --git a/python/hidet/graph/frontend/torch/register_modules.py b/python/hidet/graph/frontend/torch/register_modules.py index df3055216..dfb5f14d0 100644 --- a/python/hidet/graph/frontend/torch/register_modules.py +++ b/python/hidet/graph/frontend/torch/register_modules.py @@ -162,6 +162,13 @@ def __call__(self, x: Tensor) -> Tensor: ) +@register_module(torch.nn.ZeroPad2d) +class HidetZeroPad2d(HidetModule): + def __call__(self, x: Tensor) -> Tensor: + assert isinstance(self.mod, torch.nn.ZeroPad2d) + return regs.torch_pad(x=x, pad=self.mod.padding, mode='constant', value=0.0) + + @register_module(torch.nn.Linear) class HidetLinear(HidetModule): def __init__(self, torch_module: torch.nn.Module): diff --git a/python/hidet/graph/ops/pool.py b/python/hidet/graph/ops/pool.py index e41b4546b..ae3feaef9 100644 --- a/python/hidet/graph/ops/pool.py +++ b/python/hidet/graph/ops/pool.py @@ -9,130 +9,183 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union, Sequence, List, Dict, Any, Optional +from typing import Union, Sequence, List, Optional from hidet.ir.expr import Expr, Int, convert, if_then_else, logical_and +from hidet.ir.dtypes import boolean, int32 +from hidet.ir import primitives from .utils import Task, Operator, Tensor, TensorNode, compute, reduce, input_like, normalize_stride, normalize_kernel from .utils import normalize_padding, normalize_output from ..transforms import ResolveRule, register_resolve_rule -class Pool2dChannelLastTask(Task): - def __init__(self, x: TensorNode, kernel, strides, padding, ceil_mode: bool, reduce_type: str): - assert reduce_type in ['max', 'avg'] - kernel = normalize_kernel(kernel) - strides = normalize_stride(strides) - padding = normalize_padding(padding) - batch_size, height, width, channels = x.shape - if ceil_mode: - out_height = (height + padding[0] + padding[2] - kernel[0] + strides[0] - 1) // strides[0] + 1 - out_width = (width + padding[1] + padding[3] - kernel[1] + strides[1] - 1) // strides[1] + 1 +class PoolNdBaseTask(Task): + @staticmethod + def preprocess( + x: TensorNode, + kernel: Union[int, Sequence[int]], + stride: Union[int, Sequence[int]], + padding: Union[int, Sequence[int]], + ceil_mode: bool, + channel_last: bool, + reduce_type: str, + ): + assert len(x.shape) >= 3 + + in_shape: List[Expr] = list(x.shape) + if not channel_last: + batch_dim: int = 0 + channel_dim: int = 1 + spatial_dims: List[int] = list(range(2, len(in_shape))) else: - out_height = (height + padding[0] + padding[2] - kernel[0]) // strides[0] + 1 - out_width = (width + padding[1] + padding[3] - kernel[1]) // strides[1] + 1 - pad_value = convert(0.0 if reduce_type == 'avg' else -1e30, dtype=x.type.dtype) - pad = compute( - name='pad', - shape=[batch_size, height + padding[0] + padding[2], width + padding[1] + padding[3], channels], - fcompute=lambda n, h, w, c: if_then_else( - logical_and(padding[0] <= h, h < height + padding[0], padding[1] <= w, w < width + padding[1]), - x[n, h - padding[0], w - padding[1], c], - pad_value, - ), - ) - y = compute( - name='y', - shape=[batch_size, out_height, out_width, channels], - fcompute=lambda n, h, w, c: reduce( - shape=[kernel[0], kernel[1]], - fcompute=lambda rx, ry: pad[n, h * strides[0] + rx, w * strides[1] + ry, c], - reduce_type=reduce_type, - ), + batch_dim: int = 0 + channel_dim: int = len(in_shape) - 1 + spatial_dims: List[int] = list(range(1, len(in_shape) - 1)) + + kernel = normalize_kernel(kernel, dim=len(spatial_dims)) + stride = normalize_stride(stride, dim=len(spatial_dims)) + padding = normalize_padding(padding, dim=len(spatial_dims)) + + # calculate output shape + out_shape: List[Expr] = [int32.zero] * len(in_shape) + out_shape[batch_dim] = in_shape[batch_dim] + out_shape[channel_dim] = in_shape[channel_dim] + for i, dim in enumerate(spatial_dims): + if ceil_mode: + out_shape[dim] = ( + in_shape[dim] + padding[i] + padding[i + len(spatial_dims)] - kernel[i] + stride[i] - 1 + ) // stride[i] + 1 + else: + out_shape[dim] = (in_shape[dim] + padding[i] + padding[i + len(spatial_dims)] - kernel[i]) // stride[ + i + ] + 1 + + # calculate padding shape + pad_shape: List[Expr] = [int32.zero] * len(in_shape) + pad_shape[batch_dim] = in_shape[batch_dim] + pad_shape[channel_dim] = in_shape[channel_dim] + for i, dim in enumerate(spatial_dims): + pad_shape[dim] = in_shape[dim] + padding[i] + padding[i + len(spatial_dims)] + + def f_pad_compute(*indices: Expr) -> Expr: + if reduce_type == 'max': + pad_value = x.type.dtype.min_value + else: + assert reduce_type == 'avg' + pad_value = x.type.dtype.zero + cond = boolean.true + x_indices: List[Expr] = [int32.zero] * len(in_shape) + x_indices[batch_dim] = indices[batch_dim] + x_indices[channel_dim] = indices[channel_dim] + for i, dim in enumerate(spatial_dims): + cond = logical_and(cond, padding[i] <= indices[dim], indices[dim] < padding[i] + in_shape[dim]) + x_indices[dim] = indices[dim] - padding[i] + return if_then_else(cond, x[x_indices], pad_value) + + pad = compute(name='pad', shape=pad_shape, fcompute=f_pad_compute) + return kernel, stride, padding, batch_dim, channel_dim, spatial_dims, in_shape, out_shape, pad + + +class AvgPoolNdTask(PoolNdBaseTask): + def __init__( + self, + x: TensorNode, + kernel: Union[int, Sequence[int]], + stride: Union[int, Sequence[int]], + padding: Union[int, Sequence[int]], + ceil_mode: bool, + count_include_pad: bool, + divisor_override: Optional[int], + channel_last: bool, + ): + kernel, stride, padding, batch_dim, channel_dim, spatial_dims, in_shape, out_shape, pad = self.preprocess( + x, kernel, stride, padding, ceil_mode, channel_last, 'avg' ) - super().__init__(name='{}_pool2d_channel_last'.format(reduce_type), inputs=[x], outputs=[y]) - -class Pool2dTask(Task): - def __init__(self, x: TensorNode, kernel, strides, padding, ceil_mode: bool, reduce_type: str): - assert reduce_type in ['max', 'avg'] - kernel = normalize_kernel(kernel) - strides = normalize_stride(strides) - padding = normalize_padding(padding) - batch_size, channels, height, width = x.shape - if ceil_mode: - out_height = (height + padding[0] + padding[2] - kernel[0] + strides[0] - 1) // strides[0] + 1 - out_width = (width + padding[1] + padding[3] - kernel[1] + strides[1] - 1) // strides[1] + 1 - else: - out_height = (height + padding[0] + padding[2] - kernel[0]) // strides[0] + 1 - out_width = (width + padding[1] + padding[3] - kernel[1]) // strides[1] + 1 - pad_value = convert(0.0 if reduce_type == 'avg' else -1e30, dtype=x.type.dtype) - pad = compute( - name='pad', - shape=[batch_size, channels, height + padding[0] + padding[2], width + padding[1] + padding[3]], - fcompute=lambda n, c, h, w: if_then_else( - logical_and(padding[0] <= h, h < height + padding[0], padding[1] <= w, w < width + padding[1]), - x[n, c, h - padding[0], w - padding[1]], - pad_value, + # calculate the sum of the pooling region + def f_sum_compute(out_indices: List[Expr], reduce_indices: List[int]) -> Expr: + pad_indices: List[Expr] = [int32.zero] * len(in_shape) + pad_indices[batch_dim] = out_indices[batch_dim] + pad_indices[channel_dim] = out_indices[channel_dim] + for i, dim in enumerate(spatial_dims): + pad_indices[dim] = out_indices[dim] * stride[i] + reduce_indices[i] + return pad[pad_indices] + + s = compute( + name='s', + shape=out_shape, + fcompute=lambda *out_indices: reduce( + shape=kernel, + fcompute=lambda *reduce_indices: f_sum_compute(out_indices, reduce_indices), + reduce_type='sum', ), ) - y = compute( - name='y', - shape=[batch_size, channels, out_height, out_width], - fcompute=lambda n, c, h, w: reduce( - shape=[kernel[0], kernel[1]], - fcompute=lambda rx, ry: pad[n, c, h * strides[0] + rx, w * strides[1] + ry], - reduce_type=reduce_type, - ), + + # calculate the output value by dividing the sum by the number of pooling region elements + def f_average_compute(*indices): + if divisor_override is not None: + area = int32(int(divisor_override)) + else: + area = 1 + for i, dim in enumerate(spatial_dims): + if count_include_pad: + if ceil_mode: + start = indices[dim] * stride[i] + end = primitives.min( + indices[dim] * stride[i] + kernel[i], + convert(in_shape[dim] + padding[i] + padding[i + len(spatial_dims)]), + ) + num_elements = end - start + area = area * num_elements + else: + area = area * kernel[i] + else: + start = primitives.max(indices[dim] * stride[i], convert(padding[i])) + end = primitives.min(indices[dim] * stride[i] + kernel[i], convert(in_shape[dim] + padding[i])) + num_elements = end - start + area = area * num_elements + return s[indices] / area + + y = compute(name='y', shape=out_shape, fcompute=f_average_compute) + super().__init__( + name='max_pool{}d'.format(len(spatial_dims)), + inputs=[x], + outputs=[y], + attributes={ + 'kernel': kernel, + 'strides': stride, + 'padding': padding, + 'ceil_mode': ceil_mode, + 'count_include_pad': count_include_pad, + 'divisor_override': divisor_override, + 'channel_last': channel_last, + }, ) - super().__init__(name='{}_pool2d'.format(reduce_type), inputs=[x], outputs=[y]) -class Pool3dTask(Task): - def __init__(self, x: TensorNode, kernel, strides, padding, reduce_type: str): - assert reduce_type in ['max', 'avg'] - kernel = normalize_kernel(kernel, dim=3) - strides = normalize_stride(strides, dim=3) - padding = normalize_padding(padding, dim=3) - batch_size, channels, depth, height, width = x.shape - out_depth = (depth + padding[0] + padding[3] - kernel[0]) // strides[0] + 1 - out_height = (height + padding[1] + padding[4] - kernel[1]) // strides[1] + 1 - out_width = (width + padding[2] + padding[5] - kernel[2]) // strides[2] + 1 - pad_value = convert(0.0 if reduce_type == 'avg' else -1e30, dtype=x.type.dtype) - pad = compute( - name='pad', - shape=[ - batch_size, - channels, - depth + padding[0] + padding[3], - height + padding[1] + padding[4], - width + padding[2] + padding[5], - ], - fcompute=lambda n, c, d, h, w: ( - if_then_else( - logical_and( - padding[0] <= d, - d < depth + padding[0], - padding[1] <= h, - h < height + padding[1], - padding[2] <= w, - w < width + padding[2], - ), - x[n, c, d - padding[0], h - padding[1], w - padding[2]], - pad_value, - ) - ), +class MaxPoolNdTask(PoolNdBaseTask): + def __init__(self, x: TensorNode, kernel, stride, padding, ceil_mode: bool, channel_last: bool): + kernel, stride, padding, batch_dim, channel_dim, spatial_dims, in_shape, out_shape, pad = self.preprocess( + x, kernel, stride, padding, ceil_mode, channel_last, 'max' ) + + def f_compute(out_indices: List[Expr], reduce_indices: List[Expr]) -> Expr: + pad_indices: List[Expr] = [int32.zero] * len(in_shape) + pad_indices[batch_dim] = out_indices[batch_dim] + pad_indices[channel_dim] = out_indices[channel_dim] + for i, dim in enumerate(spatial_dims): + pad_indices[dim] = out_indices[dim] * stride[i] + reduce_indices[i] + return pad[pad_indices] + y = compute( name='y', - shape=[batch_size, channels, out_depth, out_height, out_width], - fcompute=lambda n, c, d, h, w: reduce( - shape=[kernel[0], kernel[1], kernel[2]], - fcompute=lambda rz, rx, ry: pad[n, c, d * strides[0] + rz, h * strides[1] + rx, w * strides[2] + ry], - reduce_type=reduce_type, + shape=out_shape, + fcompute=lambda *out_indices: reduce( + shape=kernel, fcompute=lambda *reduce_indices: f_compute(out_indices, reduce_indices), reduce_type='max' ), ) - super().__init__(name='{}_pool3d'.format(reduce_type), inputs=[x], outputs=[y]) + super().__init__(name='max_pool{}d'.format(len(spatial_dims)), inputs=[x], outputs=[y]) class AdaptivePoolTask(Task): @@ -220,191 +273,251 @@ def reduce_compute(*reduce_indices: Expr) -> Expr: ) -class MaxPool2dOp(Operator): - def __init__( - self, - x: Tensor, - kernel: Union[int, Sequence[int]], - stride: Union[int, Sequence[int]], - padding: Union[int, Sequence[int]], - ceil_mode: bool, - ): - super().__init__( - inputs=[x], - attributes={'kernel': kernel, 'stride': stride, 'padding': padding, 'ceil_mode': ceil_mode}, - task=Pool2dTask(input_like(x, 'x'), kernel, stride, padding, ceil_mode, reduce_type='max'), - ) - +class AvgPoolNdOp(Operator): + ndim: Optional[int] = None + last_channel: bool = False -class MaxPool2dChannelLastOp(Operator): def __init__( self, x: Tensor, kernel: Union[int, Sequence[int]], stride: Union[int, Sequence[int]], padding: Union[int, Sequence[int]], - ceil_mode: bool, + ceil_mode: bool = False, + count_include_pad: bool = True, + divisor_override: Optional[int] = None, ): - super().__init__( - inputs=[x], - attributes={'kernel': kernel, 'stride': stride, 'padding': padding, 'ceil_mode': ceil_mode}, - task=Pool2dChannelLastTask(input_like(x, 'x'), kernel, stride, padding, ceil_mode, reduce_type='max'), - ) - + if len(x.shape) != self.ndim + 2: + raise ValueError( + 'AvgPool{}d expects {}D input, got {}D one.'.format(self.ndim, self.ndim + 2, len(x.shape)) + ) -class MaxPool3dOp(Operator): - def __init__( - self, - x: Tensor, - kernel: Union[int, Sequence[int]], - stride: Union[int, Sequence[int]], - padding: Union[int, Sequence[int]], - ): super().__init__( inputs=[x], - attributes={'kernel': kernel, 'stride': stride, 'padding': padding}, - task=Pool3dTask(input_like(x, 'x'), kernel, stride, padding, reduce_type='max'), + attributes={ + 'kernel': kernel, + 'stride': stride, + 'padding': padding, + 'ceil_mode': ceil_mode, + 'count_include_pad': count_include_pad, + 'divisor_override': divisor_override, + }, + task=AvgPoolNdTask( + x=input_like(x, 'x'), + kernel=kernel, + stride=stride, + padding=padding, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + divisor_override=divisor_override, + channel_last=self.last_channel, + ), ) -class AvgPool2dOp(Operator): - def __init__( - self, - x: Tensor, - kernel: Union[int, Sequence[int]], - stride: Union[int, Sequence[int]], - padding: Union[int, Sequence[int]], - ceil_mode: bool, - ): - super().__init__( - inputs=[x], - attributes={'kernel': kernel, 'stride': stride, 'padding': padding, 'ceil_mode': ceil_mode}, - task=Pool2dTask(input_like(x, 'x'), kernel, stride, padding, ceil_mode, reduce_type='avg'), - ) - +class MaxPoolNdOp(Operator): + ndim: Optional[int] = None + last_channel: bool = False -class AvgPool2dChannelLastOp(Operator): def __init__( self, x: Tensor, kernel: Union[int, Sequence[int]], stride: Union[int, Sequence[int]], padding: Union[int, Sequence[int]], - ceil_mode: bool, + ceil_mode: bool = False, ): + if len(x.shape) != self.ndim + 2: + raise ValueError( + 'MaxPool{}d expects {}D input, got {}D one.'.format(self.ndim, self.ndim + 2, len(x.shape)) + ) + super().__init__( inputs=[x], attributes={'kernel': kernel, 'stride': stride, 'padding': padding, 'ceil_mode': ceil_mode}, - task=Pool2dChannelLastTask(input_like(x, 'x'), kernel, stride, padding, ceil_mode, reduce_type='avg'), + task=MaxPoolNdTask( + x=input_like(x, 'x'), + kernel=kernel, + stride=stride, + padding=padding, + ceil_mode=ceil_mode, + channel_last=self.last_channel, + ), ) -class AvgPool3dOp(Operator): - def __init__( - self, - x: Tensor, - kernel: Union[int, Sequence[int]], - stride: Union[int, Sequence[int]], - padding: Union[int, Sequence[int]], - ): - super().__init__( - inputs=[x], - attributes={'kernel': kernel, 'stride': stride, 'padding': padding}, - task=Pool3dTask(input_like(x, 'x'), kernel, stride, padding, reduce_type='avg'), - ) - +class AdaptivePoolNdOp(Operator): + spatial_ndim: Optional[int] = None + reduce_type: Optional[str] = None + last_channel_layout: Optional[bool] = None -class AdaptivePoolOp(Operator): - def __init__(self, x: Tensor, output_size, reduce_type: str, attrs: Dict[str, Any], spatial_ndim: int): - if len(x.shape) != spatial_ndim + 2: + def __init__(self, x: Tensor, output_size): + if len(x.shape) != self.spatial_ndim + 2: raise ValueError( 'Adaptive{}Pool{}d expects {}D input, got {}D one.'.format( - reduce_type.capitalize(), spatial_ndim, spatial_ndim + 2, len(x.shape) + self.reduce_type.capitalize(), self.spatial_ndim, self.spatial_ndim + 2, len(x.shape) ) ) - output_size = normalize_output(output_size, spatial_ndim) - self.reduce_type = reduce_type - super().__init__( - inputs=[x], - attributes=attrs, - task=AdaptivePoolTask(input_like(x, 'x'), output_size, reduce_type=reduce_type), - ) + output_size = normalize_output(output_size, self.spatial_ndim) + self.reduce_type = self.reduce_type + # todo: merge AdaptivePoolTask and AdaptivePoolChannelLastTask into one class + if self.last_channel_layout: + task = AdaptivePoolChannelLastTask(input_like(x, 'x'), output_size, reduce_type=self.reduce_type) + else: + task = AdaptivePoolTask(input_like(x, 'x'), output_size, reduce_type=self.reduce_type) -class AdaptivePoolChannelLastOp(Operator): - def __init__(self, x: Tensor, output_size, reduce_type: str, attrs: Dict[str, Any], spatial_ndim: int): - if len(x.shape) != spatial_ndim + 2: - raise ValueError( - 'Adaptive{}Pool{}d expects {}D input, got {}D one.'.format( - reduce_type.capitalize(), spatial_ndim, spatial_ndim + 2, len(x.shape) - ) - ) - output_size = normalize_output(output_size, spatial_ndim) - self.reduce_type = reduce_type - super().__init__( - inputs=[x], - attributes=attrs, - task=AdaptivePoolChannelLastTask(input_like(x, 'x'), output_size, reduce_type=reduce_type), - ) + super().__init__(inputs=[x], attributes={'output_size': output_size}, task=task) -class AdaptiveAvgPool1dOp(AdaptivePoolOp): - def __init__(self, x: Tensor, output_size: Union[int, Sequence[int]]): - super().__init__(x, output_size, reduce_type='avg', attrs={'output_size': output_size}, spatial_ndim=1) +class MaxPool1dOp(MaxPoolNdOp): + ndim: int = 1 + last_channel: bool = False -class AdaptiveAvgPool2dOp(AdaptivePoolOp): - def __init__(self, x: Tensor, output_size: Union[int, Sequence[int]]): - super().__init__(x, output_size, reduce_type='avg', attrs={'output_size': output_size}, spatial_ndim=2) +class MaxPool1dChannelLastOp(MaxPoolNdOp): + ndim: int = 1 + last_channel: bool = True -class AdaptiveAvgPool2dChannelLastOp(AdaptivePoolChannelLastOp): - def __init__(self, x: Tensor, output_size: Union[int, Sequence[int]]): - super().__init__(x, output_size, reduce_type='avg', attrs={'output_size': output_size}, spatial_ndim=2) +class MaxPool2dOp(MaxPoolNdOp): + ndim: int = 2 + last_channel: bool = False -class AdaptiveAvgPool3dOp(AdaptivePoolOp): - def __init__(self, x: Tensor, output_size: Union[int, Sequence[int]]): - super().__init__(x, output_size, reduce_type='avg', attrs={'output_size': output_size}, spatial_ndim=3) +class MaxPool2dChannelLastOp(MaxPoolNdOp): + ndim: int = 2 + last_channel: bool = True -class AdaptiveMaxPool1dOp(AdaptivePoolOp): - def __init__(self, x: Tensor, output_size: Union[int, Sequence[int]]): - super().__init__(x, output_size, reduce_type='max', attrs={'output_size': output_size}, spatial_ndim=1) +class MaxPool3dOp(MaxPoolNdOp): + ndim: int = 3 + last_channel: bool = False -class AdaptiveMaxPool2dOp(AdaptivePoolOp): - def __init__(self, x: Tensor, output_size: Union[int, Sequence[int]]): - super().__init__(x, output_size, reduce_type='max', attrs={'output_size': output_size}, spatial_ndim=2) +class MaxPool3dChannelLastOp(MaxPoolNdOp): + ndim: int = 3 + last_channel: bool = True -class AdaptiveMaxPool3dOp(AdaptivePoolOp): - def __init__(self, x: Tensor, output_size: Union[int, Sequence[int]]): - super().__init__(x, output_size, reduce_type='max', attrs={'output_size': output_size}, spatial_ndim=3) +class AvgPool1dOp(AvgPoolNdOp): + ndim: int = 1 + last_channel: bool = False -def max_pool2d(x: Tensor, kernel, stride, padding, ceil_mode=False) -> Tensor: - return MaxPool2dOp(x, kernel, stride, padding, ceil_mode).outputs[0] +class AvgPool1dChannelLastOp(AvgPoolNdOp): + ndim: int = 1 + last_channel: bool = True -def max_pool2d_channel_last(x: Tensor, kernel, stride, padding, ceil_mode=False) -> Tensor: - return MaxPool2dChannelLastOp(x, kernel, stride, padding, ceil_mode).outputs[0] +class AvgPool2dOp(AvgPoolNdOp): + ndim: int = 2 + last_channel: bool = False -def max_pool3d(x: Tensor, kernel, stride, padding) -> Tensor: - return MaxPool3dOp(x, kernel, stride, padding).outputs[0] +class AvgPool2dChannelLastOp(AvgPoolNdOp): + ndim: int = 2 + last_channel: bool = True -def avg_pool2d(x: Tensor, kernel, stride, padding, ceil_mode=False) -> Tensor: - return AvgPool2dOp(x, kernel, stride, padding, ceil_mode).outputs[0] +class AvgPool3dOp(AvgPoolNdOp): + ndim: int = 3 + last_channel: bool = False -def avg_pool2d_channel_last(x: Tensor, kernel, stride, padding, ceil_mode=False) -> Tensor: - return AvgPool2dChannelLastOp(x, kernel, stride, padding, ceil_mode).outputs[0] +class AvgPool3dChannelLastOp(AvgPoolNdOp): + ndim: int = 3 + last_channel: bool = True + + +class AdaptiveAvgPool1dOp(AdaptivePoolNdOp): + reduce_type = 'avg' + spatial_ndim = 1 + last_channel_layout = False + + +class AdaptiveAvgPool2dOp(AdaptivePoolNdOp): + reduce_type = 'avg' + spatial_ndim = 2 + last_channel_layout = False + + +class AdaptiveAvgPool3dOp(AdaptivePoolNdOp): + reduce_type = 'avg' + spatial_ndim = 3 + last_channel_layout = False + + +class AdaptiveAvgPool2dChannelLastOp(AdaptivePoolNdOp): + reduce_type = 'avg' + spatial_ndim = 2 + last_channel_layout = True + +class AdaptiveMaxPool1dOp(AdaptivePoolNdOp): + reduce_type = 'max' + spatial_ndim = 1 + last_channel_layout = False -def avg_pool3d(x: Tensor, kernel, stride, padding) -> Tensor: - return AvgPool3dOp(x, kernel, stride, padding).outputs[0] + +class AdaptiveMaxPool2dOp(AdaptivePoolNdOp): + reduce_type = 'max' + spatial_ndim = 2 + last_channel_layout = False + + +class AdaptiveMaxPool3dOp(AdaptivePoolNdOp): + reduce_type = 'max' + spatial_ndim = 3 + last_channel_layout = False + + +def max_pool1d(x: Tensor, kernel, stride, padding, ceil_mode=False) -> Tensor: + return MaxPool1dOp(x, kernel, stride, padding, ceil_mode).outputs[0] + + +def max_pool2d(x: Tensor, kernel, stride, padding, ceil_mode=False) -> Tensor: + return MaxPool2dOp(x, kernel, stride, padding, ceil_mode).outputs[0] + + +def max_pool3d(x: Tensor, kernel, stride, padding) -> Tensor: + return MaxPool3dOp(x, kernel, stride, padding).outputs[0] + + +def avg_pool1d( + x: Tensor, + kernel, + stride, + padding, + ceil_mode=False, + count_include_pad: bool = True, + divisor_override: Optional[int] = None, +) -> Tensor: + return AvgPool1dOp(x, kernel, stride, padding, ceil_mode).outputs[0] + + +def avg_pool2d( + x: Tensor, + kernel, + stride, + padding, + ceil_mode=False, + count_include_pad: bool = True, + divisor_override: Optional[int] = None, +) -> Tensor: + return AvgPool2dOp(x, kernel, stride, padding, ceil_mode, count_include_pad, divisor_override).outputs[0] + + +def avg_pool3d( + x: Tensor, + kernel, + stride, + padding, + ceil_mode=False, + count_include_pad: bool = True, + divisor_override: Optional[int] = None, +) -> Tensor: + return AvgPool3dOp(x, kernel, stride, padding, ceil_mode, count_include_pad, divisor_override).outputs[0] def adaptive_avg_pool1d(x: Tensor, output_size: Union[int, Sequence[int]]) -> Tensor: @@ -415,10 +528,6 @@ def adaptive_avg_pool2d(x: Tensor, output_size: Union[int, Sequence[int]]) -> Te return AdaptiveAvgPool2dOp(x, output_size).outputs[0] -def adaptive_avg_pool2d_channel_last(x: Tensor, output_size: Union[int, Sequence[int]]) -> Tensor: - return AdaptiveAvgPool2dChannelLastOp(x, output_size).outputs[0] - - def adaptive_avg_pool3d(x: Tensor, output_size: Union[int, Sequence[int]]) -> Tensor: return AdaptiveAvgPool3dOp(x, output_size).outputs[0] @@ -435,10 +544,39 @@ def adaptive_max_pool3d(x: Tensor, output_size: Union[int, Sequence[int]]) -> Te return AdaptiveMaxPool3dOp(x, output_size).outputs[0] -@register_resolve_rule(AdaptivePoolOp) +# channel last operators +def max_pool1d_channel_last(x: Tensor, kernel, stride, padding, ceil_mode=False) -> Tensor: + return MaxPool1dChannelLastOp(x, kernel, stride, padding, ceil_mode).outputs[0] + + +def max_pool2d_channel_last(x: Tensor, kernel, stride, padding, ceil_mode=False) -> Tensor: + return MaxPool2dChannelLastOp(x, kernel, stride, padding, ceil_mode).outputs[0] + + +def max_pool3d_channel_last(x: Tensor, kernel, stride, padding, ceil_mode=False) -> Tensor: + return MaxPool2dChannelLastOp(x, kernel, stride, padding, ceil_mode).outputs[0] + + +def avg_pool1d_channel_last(x: Tensor, kernel, stride, padding, ceil_mode=False) -> Tensor: + return AvgPool1dChannelLastOp(x, kernel, stride, padding, ceil_mode).outputs[0] + + +def avg_pool2d_channel_last(x: Tensor, kernel, stride, padding, ceil_mode=False) -> Tensor: + return AvgPool2dChannelLastOp(x, kernel, stride, padding, ceil_mode).outputs[0] + + +def avg_pool3d_channel_last(x: Tensor, kernel, stride, padding, ceil_mode=False) -> Tensor: + return AvgPool3dChannelLastOp(x, kernel, stride, padding, ceil_mode).outputs[0] + + +def adaptive_avg_pool2d_channel_last(x: Tensor, output_size: Union[int, Sequence[int]]) -> Tensor: + return AdaptiveAvgPool2dChannelLastOp(x, output_size).outputs[0] + + +@register_resolve_rule(AdaptivePoolNdOp) class AdaptivePoolResolveRule(ResolveRule): def resolve(self, op: Operator) -> Optional[List[Tensor]]: - assert isinstance(op, AdaptivePoolOp) + assert isinstance(op, AdaptivePoolNdOp) x: Tensor = op.inputs[0] output_size = op.attrs['output_size'] reduce_type = op.reduce_type diff --git a/python/hidet/graph/transforms/subgraph_rewrite.py b/python/hidet/graph/transforms/subgraph_rewrite.py index 1c670b951..88cdf511b 100644 --- a/python/hidet/graph/transforms/subgraph_rewrite.py +++ b/python/hidet/graph/transforms/subgraph_rewrite.py @@ -12,6 +12,7 @@ # pylint: disable=unused-import from typing import List, Optional, Dict, Tuple, Set import logging +import warnings from hidet.graph.flow_graph import FlowGraph, Operator, Tensor from hidet.graph.transforms import GraphPass, PassContext @@ -53,7 +54,7 @@ def process_graph(self, graph: FlowGraph) -> FlowGraph: if not updated: graph.update_nodes() return graph - print('Exceeded maximum number of transforms {}, stop early.'.format(self.max_num_transforms)) + warnings.warn('Exceeded maximum number of sub-graph transforms {}, stop early.'.format(self.max_num_transforms)) graph.update_nodes() return graph diff --git a/python/hidet/ir/schedulers/base.py b/python/hidet/ir/schedulers/base.py index 8916bc074..2a4331416 100644 --- a/python/hidet/ir/schedulers/base.py +++ b/python/hidet/ir/schedulers/base.py @@ -10,6 +10,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Union, List, Dict, Sequence, Tuple, Set, Optional +from collections import defaultdict from hidet.ir.type import DataType, tensor_pointer_type from hidet.ir.expr import TensorElement, Expr, Var, SymbolVar, Constant, scalar_var, convert, cast @@ -18,7 +19,7 @@ from hidet.ir.func import Function from hidet.ir.module import IRModule from hidet.ir.builders import FunctionBuilder, StmtBuilder -from hidet.ir.functors import ExprRewriter, ExprVisitor, ComputeVisitor, ComputeRewriter, TypeRewriter +from hidet.ir.functors import ExprRewriter, ExprVisitor, ComputeVisitor, ComputeRewriter, TypeRewriter, IRVisitor from hidet.ir.tools import IRPrinter, collect, rewrite, infer_type, simplify, collect_free_vars from hidet.ir.compute import ScalarInput, TensorInput, GridCompute, ReduceCompute, ArgReduceCompute from hidet.ir.compute import TensorNode, ScalarNode @@ -32,6 +33,26 @@ class ScalarComputeFound(Exception): pass +class UsageCounter(IRVisitor): + def __init__(self): + super().__init__() + self.usage: Dict[Union[TensorNode, ScalarNode], int] = defaultdict(int) + + def visit_TensorInput(self, node: TensorInput): + self.usage[node] += 1 + + def visit_ScalarInput(self, node: ScalarInput): + self.usage[node] += 1 + + def visit_GridCompute(self, node: GridCompute): + self.usage[node] += 1 + super().visit_GridCompute(node) + + def visit_ReduceCompute(self, node: ReduceCompute): + self.usage[node] += 1 + super().visit_ReduceCompute(node) + + class GridComputeInlineChecker(ExprVisitor, ComputeVisitor): def check(self, gc: GridCompute) -> bool: """Check whether the grid compute can be inlined. @@ -75,8 +96,9 @@ def can_inline_grid_compute(gc: GridCompute) -> bool: class GridComputeInliner(ExprRewriter, ComputeRewriter): - def __init__(self): + def __init__(self, usage_count: Dict[Union[TensorNode, ScalarNode], int]): super().__init__() + self.usage_count: Dict[Union[TensorNode, ScalarNode], int] = usage_count def inline(self, node: TensorNode): return self.visit(node) @@ -88,7 +110,8 @@ def visit_TensorElement(self, e: TensorElement): base = self(e.base) indices = [self(index) for index in e.indices] if isinstance(base, GridCompute): - if can_inline_grid_compute(base): + assert isinstance(e.base, TensorNode) + if can_inline_grid_compute(base) or self.usage_count[e.base] == 1: return rewrite(base.value, {axis: index for axis, index in zip(base.axes, indices)}) return ExprRewriter.visit_TensorElement(self, e) @@ -120,7 +143,11 @@ def inline_grid_compute(nodes: List[TensorNode]) -> List[TensorNode]: ret: List[TensorNode] The nodes after inlining. """ - inliner = GridComputeInliner() + usage_counter = UsageCounter() + usage_counter.visit(nodes) + + inliner = GridComputeInliner(usage_counter.usage) + return [inliner.inline(node) for node in nodes] diff --git a/tests/operators/test_pool.py b/tests/operators/test_pool.py index 7fb80786b..cd909fe07 100644 --- a/tests/operators/test_pool.py +++ b/tests/operators/test_pool.py @@ -9,40 +9,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple - -import numpy as np +from typing import Tuple, Optional import pytest import torch.nn.functional from hidet import ops -from hidet.testing import check_unary, check_torch_unary - - -def numpy_pool2d( - data: np.ndarray, kernel: Tuple[int, int], stride: Tuple[int, int], padding: Tuple[int, int, int, int], reduce_type -) -> np.ndarray: - assert reduce_type in ['max', 'avg'] - n, c, h, w = data.shape - kx, ky = kernel - sx, sy = stride - ph, pw = h + padding[0] + padding[2], w + padding[1] + padding[3] - padded = np.full_like(data, fill_value=0.0 if reduce_type == 'avg' else -1e30, shape=(n, c, ph, pw)) - padded[:, :, padding[0] : padding[0] + h, padding[1] : padding[1] + w] = data - oh, ow = (ph - kx) // sx + 1, (pw - ky) // sy + 1 - output = np.empty_like(data, shape=(n, c, oh, ow)) - for nn in range(n): - for cc in range(c): - for p in range(oh): - for q in range(ow): - if reduce_type == 'max': - output[nn, cc, p, q] = np.max(padded[nn, cc, p * sx : p * sx + kx, q * sy : q * sy + ky]) - elif reduce_type == 'avg': - output[nn, cc, p, q] = np.sum(padded[nn, cc, p * sx : p * sx + kx, q * sy : q * sy + ky]) / ( - kx * ky - ) - - return output +from hidet.testing import check_torch_unary @pytest.mark.parametrize( @@ -56,9 +28,11 @@ def numpy_pool2d( ], ) def test_max_pool2d(shape, kernel, stride, padding): - check_unary( + check_torch_unary( shape, - lambda x: numpy_pool2d(x, kernel, stride, padding, 'max'), + lambda x: torch.nn.functional.max_pool2d( + x, kernel_size=kernel, stride=stride, padding=[padding[0], padding[1]] + ), lambda x: ops.max_pool2d(x, kernel, stride, padding), atol=1e-6, rtol=1e-6, @@ -97,11 +71,24 @@ def test_max_pool3d(shape, kernel, stride, padding): [[1, 3, 32, 32], [7, 7], [2, 2], [3, 3, 3, 3]], # kernel 3, stride 2 ], ) -def test_avg_pool2d(shape, kernel, stride, padding): - check_unary( +@pytest.mark.parametrize('ceil_mode', [True, False]) +@pytest.mark.parametrize('count_include_pad', [True, False]) +@pytest.mark.parametrize('divisor_override', [None, 2]) +def test_avg_pool2d( + shape, kernel, stride, padding, ceil_mode, count_include_pad: bool, divisor_override: Optional[int] +): + check_torch_unary( shape, - lambda x: numpy_pool2d(x, kernel, stride, padding, 'avg'), - lambda x: ops.avg_pool2d(x, kernel, stride, padding), + lambda x: torch.nn.functional.avg_pool2d( + x, + kernel_size=kernel, + stride=stride, + padding=[padding[0], padding[1]], + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + divisor_override=divisor_override, + ), + lambda x: ops.avg_pool2d(x, kernel, stride, padding, ceil_mode, count_include_pad, divisor_override), atol=1e-5, rtol=1e-5, )