Skip to content

Commit

Permalink
[Ir] support integer subbyte
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaocenxiaocen committed Jan 2, 2024
1 parent 7c71965 commit baffdb2
Show file tree
Hide file tree
Showing 10 changed files with 613 additions and 8 deletions.
2 changes: 2 additions & 0 deletions python/hidet/backend/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,8 @@ def visit_DataType(self, t: DataType):
'float32x8': '__m256',
'int8x4': 'char4',
'uint8x4': 'uint4',
'int4bx8': 'uint32_t',
'uint4bx8': 'uint32_t',
}

self.require_complex = self.require_complex or t.name in ['complex64', 'complex128']
Expand Down
26 changes: 24 additions & 2 deletions python/hidet/ir/dtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
from hidet.ir.type import DataType
from .integer import int8, int16, int32, int64, uint8, uint16, uint32, uint64
from .integer import i8, i16, i32, i64, u8, u16, u32, u64
from .integer_subbyte import int4b, int3b, int2b, int1b, uint4b, uint3b, uint2b, uint1b
from .integer_subbyte import i4, i3, i2, i1, u4, u3, u2, u1
from .floats import float16, float32, float64, bfloat16, tfloat32
from .floats import f16, f32, f64, bf16, tf32
from .boolean import boolean
from .vector import float16x2, float32x4, float32x8, int8x4, uint8x4, vectorize
from .vector import f16x2, f32x4, f32x8
from .vector import float16x2, float32x4, float32x8, int8x4, uint8x4, int4bx8, uint4bx8, vectorize
from .vector import f16x2, f32x4, f32x8, i4x8, u4x8
from .complex import complex64, complex128
from .promotion import promote_type
from .utils import dtype_to_numpy, finfo, iinfo
Expand All @@ -43,6 +45,16 @@
'float16x2': float16x2,
'int8x4': int8x4,
'uint8x4': uint8x4,
'int4b': int4b,
'int3b': int3b,
'int2b': int2b,
'int1b': int1b,
'uint4b': uint4b,
'uint3b': uint3b,
'uint2b': uint2b,
'uint1b': uint1b,
'int4bx8': int4bx8,
'uint4bx8': uint4bx8,
}

sname2dtype = {
Expand All @@ -66,6 +78,16 @@
'f32x8': f32x8,
'f16x2': f16x2,
'i8x4': int8x4,
'i4': int4b,
'i3': int3b,
'i2': int2b,
'i1': int1b,
'u4': uint4b,
'u3': uint3b,
'u2': uint2b,
'u1': uint1b,
'i4x8': int4bx8,
'u4x8': uint4bx8,
}


Expand Down
48 changes: 48 additions & 0 deletions python/hidet/ir/dtypes/integer_subbyte.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from hidet.ir.type import DataType
from .integer import IntegerType, IntInfo, uint8, uint32


class IntegerSubbyteType(IntegerType):
    """Integer type whose logical width is given in bits and which is held in a wider storage type.

    Parameters
    ----------
    name, short_name:
        Long and short names forwarded to the base integer type.
    storage:
        The data type that physically holds the value; its ``nbytes`` is what the
        base class receives as the byte size.
    nbits:
        Logical bit width of the type.
    signed:
        Whether the type is signed; decides whether a sign mask is recorded.
    min_value, max_value:
        Value bounds forwarded to the base integer type.
    """

    def __init__(self, name, short_name, storage, nbits, signed, min_value, max_value):
        # The base constructor is given the byte size of the storage type, not of
        # the logical type (which is smaller than a byte).
        super().__init__(name, short_name, storage.nbytes, min_value, max_value)
        # Overwrite the defaults installed by the base constructors.
        self._storage: DataType = storage
        self._nbits: int = nbits
        self._signed: bool = signed
        # Mask with the low `nbits` bits set.
        self._bits_mask: int = (1 << nbits) - 1
        # Mask selecting the top (sign) bit; zero for unsigned types.
        if signed:
            self._sign_mask: int = 1 << (nbits - 1)
        else:
            self._sign_mask: int = 0

    def iinfo(self) -> IntInfo:
        """Return an ``IntInfo`` built from the logical bit width and the stored bounds."""
        # NOTE(review): _max_value / _min_value are presumably set by IntegerType.__init__ - confirm.
        return IntInfo(self._nbits, self._max_value, self._min_value, self)


# Signed sub-byte integers. Bounds are inclusive: [-2**(nbits - 1), 2**(nbits - 1) - 1].
int4b = IntegerSubbyteType('int4b', 'i4', uint8, 4, True, -8, 7)
int3b = IntegerSubbyteType('int3b', 'i3', uint32, 3, True, -4, 3)
int2b = IntegerSubbyteType('int2b', 'i2', uint8, 2, True, -2, 1)
int1b = IntegerSubbyteType('int1b', 'i1', uint8, 1, True, -1, 0)

# Unsigned sub-byte integers. Bounds are inclusive: [0, 2**nbits - 1].
# The maxima were previously 16 / 8 / 4, i.e. one past the largest representable
# value; they now follow the same inclusive convention as uint1b and the signed types.
uint4b = IntegerSubbyteType('uint4b', 'u4', uint8, 4, False, 0, 15)
uint3b = IntegerSubbyteType('uint3b', 'u3', uint32, 3, False, 0, 7)
uint2b = IntegerSubbyteType('uint2b', 'u2', uint8, 2, False, 0, 3)
uint1b = IntegerSubbyteType('uint1b', 'u1', uint8, 1, False, 0, 1)

# Short aliases for the signed types.
i4 = int4b
i3 = int3b
i2 = int2b
i1 = int1b

# Short aliases for the unsigned types.
u4 = uint4b
u3 = uint3b
u2 = uint2b
u1 = uint1b
15 changes: 14 additions & 1 deletion python/hidet/ir/dtypes/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
from hidet.ir.type import DataType
from .floats import float32, float16
from .integer import int8, uint8
from .integer_subbyte import int4b, uint4b


class VectorType(DataType):
def __init__(self, lane_type: DataType, num_lanes: int):
    """Create a vector type of ``num_lanes`` lanes of ``lane_type``.

    The name and short name are ``<lane name>x<num_lanes>``. The byte size is
    ``lane_type.nbytes * num_lanes`` for ordinary lane types; for sub-byte integer
    lanes it is derived from the lane bit width, because ``nbytes`` is not
    available on types narrower than one byte.
    """
    name = '{}x{}'.format(lane_type.name, num_lanes)
    short_name = '{}x{}'.format(lane_type.short_name, num_lanes)
    # Decide on the sub-byte case *before* touching lane_type.nbytes: that property
    # raises TypeError for sub-byte types, so the old unconditional
    # `lane_type.nbytes * num_lanes` assignment must not run first.
    if lane_type.is_integer_subbyte():
        # Floor division: exact when nbits * num_lanes is a multiple of 8.
        nbytes = lane_type.nbits * num_lanes // 8
    else:
        nbytes = lane_type.nbytes * num_lanes
    super().__init__(name, short_name, nbytes)
    self._num_lanes: int = num_lanes
    self._lane_type: DataType = lane_type
Expand Down Expand Up @@ -87,6 +88,18 @@ def max_value(self):
# Each vector type is bound to its long name and its short alias in one statement.
float16x2 = f16x2 = VectorType(float16, 2)

# 4-bit integer vectors with 2 lanes.
int4bx2 = i4x2 = VectorType(int4b, 2)
uint4bx2 = u4x2 = VectorType(uint4b, 2)

# 4-bit integer vectors with 8 lanes.
int4bx8 = i4x8 = VectorType(int4b, 8)
uint4bx8 = u4x8 = VectorType(uint4b, 8)


def vectorize(base_dtype: DataType, num_lanes: int) -> VectorType:
table = {
Expand Down
12 changes: 12 additions & 0 deletions python/hidet/ir/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,18 @@ def from_str(name):
else:
return DeclareScope.Default

def is_global(self):
    """Return whether this scope equals ``DeclareScope.Global``."""
    return self == DeclareScope.Global

def is_shared(self):
    """Return whether this scope equals ``DeclareScope.Shared``."""
    return self == DeclareScope.Shared

def is_register(self):
    """Return whether this scope equals ``DeclareScope.Register``."""
    return self == DeclareScope.Register

def is_memory(self):
    """Return whether this scope is anything other than ``DeclareScope.Register``."""
    # Defined as the negation of is_register(), so every non-register scope counts.
    return not self.is_register()


class ForStmtAttr:
def __init__(self, unroll=False, unroll_factor=None, unroll_explicit=False, parallel=False, parallel_threads=None):
Expand Down
27 changes: 26 additions & 1 deletion python/hidet/ir/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,21 @@ def as_data_type(self) -> Optional[DataType]:
class DataType(BaseType):
"""
The data type that defines how to interpret the data in memory.
Note:
1. The _storage field for non-subbyte types is the type itself, while the _storage
for subbyte types is the type of its actual storage, e.g., the storage for int4b is uint8.
2. The _storage field will be overwritten during the construction of subbyte types.
3. The _nbits field set in this constructor denotes the bit length of the storage, and
it will be overwritten in the constructor of subbyte types.
"""

def __init__(self, name: str, short_name: str, nbytes: int):
    """Record the names and byte size of the type and install default storage info.

    Parameters
    ----------
    name:
        Long name of the type.
    short_name:
        Short name of the type.
    nbytes:
        Size of the type in bytes.
    """
    self._name: str = name
    self._short_name: str = short_name
    self._nbytes: int = nbytes
    # Defaults: the bit width is the full byte size and the type is its own storage.
    self._nbits: int = nbytes * 8
    self._storage = self

def __str__(self):
return 'hidet.{}'.format(self.name)
Expand Down Expand Up @@ -129,8 +138,21 @@ def short_name(self) -> str:

@property
def nbytes(self) -> int:
    """Size of the type in bytes.

    Raises
    ------
    TypeError
        If the type is narrower than 8 bits, for which a byte size is not meaningful.
    """
    if self._nbits < 8:
        # The message previously lacked the closing parenthesis after the type name.
        raise TypeError(f"Cannot access nbytes property for the type({self})")
    return self._nbytes

@property
def nbits(self) -> int:
    """Bit width recorded in ``_nbits``."""
    return self._nbits

@property
def storage(self) -> DataType:
    """Data type recorded in ``_storage`` as the storage of this type."""
    return self._storage

def is_integer_subbyte(self) -> bool:
    """Return whether this is an integer type whose bit width is below 8."""
    return self.is_integer() and self._nbits < 8

def is_float(self) -> bool:
raise NotImplementedError()

Expand Down Expand Up @@ -187,7 +209,10 @@ def __invert__(self):
return TensorPointerType.from_tensor_type(self)

def storage_bytes(self) -> Expr:
    """Return the number of bytes needed to store ``layout.size`` elements of ``dtype``.

    Sub-byte integer dtypes are sized from their bit width, since ``nbytes`` raises
    for them; every other dtype uses ``nbytes`` directly.
    """
    # The former unconditional `return self.layout.size * self.dtype.nbytes` ran
    # before this check, making the sub-byte branch unreachable.
    if self.dtype.is_integer_subbyte():
        # Floor division: exact when size * nbits is a multiple of 8.
        return self.layout.size * self.dtype.nbits // 8
    return self.layout.size * self.dtype.nbytes

def const_shape(self) -> List[int]:
return [int(v) for v in self.shape]
Expand Down
2 changes: 2 additions & 0 deletions python/hidet/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from .check_launch_configuration import check_launch_configuration_pass
from .lower_special_cast import lower_special_cast_pass
from .annotate_header_and_libs import annotate_header_and_libs_pass
from .lower_integer_subbyte import lower_integer_subbyte_pass


def lower_with(ir_module: IRModule, transforms: Sequence[Pass]) -> IRModule:
Expand All @@ -63,6 +64,7 @@ def lower(ir_module: IRModule) -> IRModule:
declare_to_let_pass(),
rule_based_simplify_pass(), # make ir more readable
flatten_tensor_index_pass(),
lower_integer_subbyte_pass(),
lower_special_cast_pass(),
inline_function_pass(),
resolve_primitive_func_pass(),
Expand Down
33 changes: 29 additions & 4 deletions python/hidet/transforms/flatten_tensor_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from hidet.ir.type import TensorType, tensor_type, tensor_pointer_type, PointerType, TensorPointerType, ArrayType
from typing import Dict

from hidet.ir.type import TensorType, tensor_type, tensor_pointer_type, PointerType, TensorPointerType, ArrayType, FuncType, func_type
from hidet.ir.expr import Var, TensorElement, TensorSlice, tensor_element
from hidet.ir.stmt import BufferStoreStmt, DeclareStmt
from hidet.ir.layout import row_major
from hidet.ir.func import Function
from hidet.ir.module import IRModule
from hidet.ir.functors import IRRewriter
from hidet.ir.tools import simplify, TypeInfer
from hidet.transforms import Pass
Expand All @@ -28,6 +31,15 @@ class FlattenTensorAccessRewriter(IRRewriter):
def __init__(self):
    """Set up the type inferer and the table of rebuilt function types."""
    super().__init__()
    self.type_infer = TypeInfer()
    # Function name -> function type of the rewritten function; filled by
    # visit_Function and consulted by visit_Var.
    self.func2func_type: Dict[str, FuncType] = {}

def visit_Var(self, v: Var):
    """Rewrite a function-typed variable whose recorded function type has changed.

    When ``v`` has a function type and ``func2func_type`` holds a different type
    object under ``v.name``, a new ``Var`` with the same hint and name but the
    recorded type is returned. Every other variable goes to the base visitor.
    """
    if isinstance(v.type, FuncType) and v.name in self.func2func_type:
        recorded_type = self.func2func_type[v.name]
        # Identity comparison: only rebuild when the recorded type is a different object.
        if recorded_type is not v.type:
            return Var(v.hint, recorded_type, v.name)
    return super().visit_Var(v)

def visit_Function(self, func: Function):
for var in func.params:
Expand All @@ -39,7 +51,13 @@ def visit_Function(self, func: Function):
self.memo[var] = Var(var.hint, tensor_pointer_type(var.type.tensor_type.dtype, [size]))
body = self(func.body)
params = [self(p) for p in func.params]
return Function(func.name, params, body, func.ret_type, kind=func.kind, attrs=func.attrs)
if body is func.body and all([p is p1 for p, p1 in zip(params, func.params)]):
return func
else:
new_func = Function(func.name, params, body, func.ret_type, kind=func.kind, attrs=func.attrs)
param_types = [p.type for p in params]
self.func2func_type[func.name] = func_type(param_types, func.ret_type)
return new_func

def get_layout(self, e) -> DataLayout:
e_type = self.type_infer(e)
Expand Down Expand Up @@ -103,9 +121,16 @@ def visit_TensorSlice(self, e: TensorSlice):


class FlattenTensorIndexPass(Pass):
    """Pass that applies one ``FlattenTensorAccessRewriter`` to every function of a module."""

    def process_module(self, ir_module: IRModule) -> IRModule:
        """Rewrite all functions of ``ir_module`` and return the resulting module.

        The input module object itself is returned when no function was changed;
        otherwise a copy of the module with its functions reset to the rewritten
        ones (and the original global variables) is returned.
        """
        # A single rewriter is shared across functions so that the function types it
        # records in func2func_type while rewriting one function are available when
        # it visits the functions processed after it.
        flatten_index = FlattenTensorAccessRewriter()

        # Dict order is preserved, so functions are rewritten in module order.
        new_funcs = {name: flatten_index(func) for name, func in ir_module.functions.items()}
        if all(new_funcs[name] is func for name, func in ir_module.functions.items()):
            return ir_module
        return ir_module.copy().reset_funcs(new_funcs, ir_module.global_vars)


def flatten_tensor_index_pass():
Expand Down
Loading

0 comments on commit baffdb2

Please sign in to comment.