[mlir][python] add type wrappers (llvm#71218)
makslevental authored Nov 27, 2023
1 parent 6a9613e commit 225648e
Showing 7 changed files with 276 additions and 25 deletions.
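In short, this commit adds a convenience module, mlir.extras.types, with shorthand constructors for common MLIR types, and loosens a few C++ bindings to accept plain sequences of types. A minimal usage sketch, hedged: it only uses names visible in the diff below and assumes an active Context.

    import mlir.extras.types as T
    from mlir.ir import Context, Location

    with Context(), Location.unknown():
        i32 = T.i32()                    # signless i32
        vec = T.vector(4, 8, T.f16())    # vector<4x8xf16>
        mem = T.memref(2, 3, T.f32())    # memref<2x3xf32>
        fn = T.function(inputs=(i32,), results=(T.f64(),))  # (i32) -> f64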
4 changes: 2 additions & 2 deletions mlir/lib/Bindings/Python/IRCore.cpp
@@ -2558,8 +2558,8 @@ void mlir::python::populateIRCore(py::module &m) {
           [](py::object & /*class*/) {
             auto *context = PyThreadContextEntry::getDefaultContext();
             if (!context)
-              throw py::value_error("No current Context");
-            return context;
+              return py::none().cast<py::object>();
+            return py::cast(context);
           },
           "Gets the Context bound to the current thread or raises ValueError")
       .def_property_readonly(
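A hedged Python-side sketch of the IRCore.cpp change above, mirroring the updated context_managers.py test at the bottom of this diff: Context.current now evaluates to None when no context is bound to the current thread, where it previously raised ValueError.

    from mlir.ir import Context

    with Context() as ctx:
        assert Context.current is ctx
    # After exiting the with-block no context is bound to this thread;
    # the property now returns None rather than raising ValueError.
    assert Context.current is None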
24 changes: 8 additions & 16 deletions mlir/lib/Bindings/Python/IRTypes.cpp
@@ -463,7 +463,7 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {

   static void bindDerived(ClassTy &c) {
     c.def_static("get", &PyVectorType::get, py::arg("shape"),
-                 py::arg("elementType"), py::kw_only(),
+                 py::arg("element_type"), py::kw_only(),
                  py::arg("scalable") = py::none(),
                  py::arg("scalable_dims") = py::none(),
                  py::arg("loc") = py::none(), "Create a vector type")
@@ -689,13 +689,9 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
   static void bindDerived(ClassTy &c) {
     c.def_static(
         "get_tuple",
-        [](py::list elementList, DefaultingPyMlirContext context) {
-          intptr_t num = py::len(elementList);
-          // Mapping py::list to SmallVector.
-          SmallVector<MlirType, 4> elements;
-          for (auto element : elementList)
-            elements.push_back(element.cast<PyType>());
-          MlirType t = mlirTupleTypeGet(context->get(), num, elements.data());
+        [](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
+          MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
+                                        elements.data());
           return PyTupleType(context->getRef(), t);
         },
         py::arg("elements"), py::arg("context") = py::none(),
@@ -727,13 +723,11 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
   static void bindDerived(ClassTy &c) {
     c.def_static(
         "get",
-        [](std::vector<PyType> inputs, std::vector<PyType> results,
+        [](std::vector<MlirType> inputs, std::vector<MlirType> results,
            DefaultingPyMlirContext context) {
-          SmallVector<MlirType, 4> inputsRaw(inputs.begin(), inputs.end());
-          SmallVector<MlirType, 4> resultsRaw(results.begin(), results.end());
-          MlirType t = mlirFunctionTypeGet(context->get(), inputsRaw.size(),
-                                           inputsRaw.data(), resultsRaw.size(),
-                                           resultsRaw.data());
+          MlirType t =
+              mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
+                                  results.size(), results.data());
           return PyFunctionType(context->getRef(), t);
         },
         py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(),
@@ -742,7 +736,6 @@
         "inputs",
         [](PyFunctionType &self) {
           MlirType t = self;
-          auto contextRef = self.getContext();
           py::list types;
           for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
                ++i) {
@@ -754,7 +747,6 @@
     c.def_property_readonly(
         "results",
         [](PyFunctionType &self) {
-          auto contextRef = self.getContext();
           py::list types;
           for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
                ++i) {
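A hedged sketch of how the IRTypes.cpp changes surface in Python. Only the renamed element_type keyword and the std::vector<MlirType> parameters come from the diff above; the rest is ordinary pre-existing API.

    from mlir.ir import Context, F32Type, FunctionType, IndexType, TupleType, VectorType

    with Context():
        f32 = F32Type.get()
        # The element-type keyword of VectorType.get is now snake_case.
        v = VectorType.get([2, 3], element_type=f32)
        # get_tuple and FunctionType.get now take the types directly as
        # std::vector<MlirType>, so plain Python lists of Type objects work.
        t = TupleType.get_tuple([f32, IndexType.get()])
        fn = FunctionType.get(inputs=[f32], results=[IndexType.get()])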
1 change: 1 addition & 0 deletions mlir/python/CMakeLists.txt
@@ -21,6 +21,7 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
     _mlir_libs/__init__.py
     ir.py
     passmanager.py
+    extras/types.py
     dialects/_ods_common.py

   # The main _mlir module has submodules: include stubs from each.
Empty file: mlir/python/mlir/extras/__init__.py (package marker for the new extras subpackage).
165 changes: 165 additions & 0 deletions mlir/python/mlir/extras/types.py
@@ -0,0 +1,165 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from functools import partial
from typing import Optional, List

from ..ir import (
    Attribute,
    BF16Type,
    ComplexType,
    F16Type,
    F32Type,
    F64Type,
    Float8E4M3B11FNUZType,
    Float8E4M3FNType,
    Float8E5M2Type,
    FunctionType,
    IndexType,
    IntegerType,
    MemRefType,
    NoneType,
    OpaqueType,
    RankedTensorType,
    StridedLayoutAttr,
    StringAttr,
    TupleType,
    Type,
    UnrankedMemRefType,
    UnrankedTensorType,
    VectorType,
)

index = lambda: IndexType.get()


def i(width):
    return IntegerType.get_signless(width)


def si(width):
    return IntegerType.get_signed(width)


def ui(width):
    return IntegerType.get_unsigned(width)


bool = lambda: i(1)
i8 = lambda: i(8)
i16 = lambda: i(16)
i32 = lambda: i(32)
i64 = lambda: i(64)

si8 = lambda: si(8)
si16 = lambda: si(16)
si32 = lambda: si(32)
si64 = lambda: si(64)

ui8 = lambda: ui(8)
ui16 = lambda: ui(16)
ui32 = lambda: ui(32)
ui64 = lambda: ui(64)

f16 = lambda: F16Type.get()
f32 = lambda: F32Type.get()
f64 = lambda: F64Type.get()
bf16 = lambda: BF16Type.get()

f8E5M2 = lambda: Float8E5M2Type.get()
f8E4M3 = lambda: Float8E4M3FNType.get()
f8E4M3B11FNUZ = lambda: Float8E4M3B11FNUZType.get()

none = lambda: NoneType.get()


def complex(type):
    return ComplexType.get(type)


def opaque(dialect_namespace, type_data):
    return OpaqueType.get(dialect_namespace, type_data)


def _shaped(*shape, element_type: Type = None, type_constructor=None):
    if type_constructor is None:
        raise ValueError("shaped is an abstract base class - cannot be constructed.")
    if (element_type is None and shape and not isinstance(shape[-1], Type)) or (
        shape and isinstance(shape[-1], Type) and element_type is not None
    ):
        raise ValueError(
            f"Either element_type must be provided explicitly XOR last arg to tensor type constructor must be the element type."
        )
    if element_type is not None:
        type = element_type
        sizes = shape
    else:
        type = shape[-1]
        sizes = shape[:-1]
    if sizes:
        return type_constructor(sizes, type)
    else:
        return type_constructor(type)


def vector(
    *shape,
    element_type: Type = None,
    scalable: Optional[List[bool]] = None,
    scalable_dims: Optional[List[int]] = None,
):
    return _shaped(
        *shape,
        element_type=element_type,
        type_constructor=partial(
            VectorType.get, scalable=scalable, scalable_dims=scalable_dims
        ),
    )


def tensor(*shape, element_type: Type = None, encoding: Optional[str] = None):
    if encoding is not None:
        encoding = StringAttr.get(encoding)
    if not shape or (len(shape) == 1 and isinstance(shape[-1], Type)):
        if encoding is not None:
            raise ValueError("UnrankedTensorType does not support encoding.")
        return _shaped(
            *shape, element_type=element_type, type_constructor=UnrankedTensorType.get
        )
    return _shaped(
        *shape,
        element_type=element_type,
        type_constructor=partial(RankedTensorType.get, encoding=encoding),
    )


def memref(
    *shape,
    element_type: Type = None,
    memory_space: Optional[int] = None,
    layout: Optional[StridedLayoutAttr] = None,
):
    if memory_space is not None:
        memory_space = Attribute.parse(str(memory_space))
    if not shape or (len(shape) == 1 and isinstance(shape[-1], Type)):
        return _shaped(
            *shape,
            element_type=element_type,
            type_constructor=partial(UnrankedMemRefType.get, memory_space=memory_space),
        )
    return _shaped(
        *shape,
        element_type=element_type,
        type_constructor=partial(
            MemRefType.get, memory_space=memory_space, layout=layout
        ),
    )


def tuple(*elements):
    return TupleType.get_tuple(elements)


def function(*, inputs, results):
    return FunctionType.get(inputs, results)
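The shaped helpers above accept the element type either as the trailing positional argument or via the element_type keyword, and fall back to the unranked variant when no dimensions are given. A small sketch; the expected types in the comments follow the tests below.

    import mlir.extras.types as T
    from mlir.ir import Context, Location

    with Context(), Location.unknown():
        T.tensor(2, 3, T.f32())               # ranked: tensor<2x3xf32>
        T.tensor(2, 3, element_type=T.f32())  # same type, keyword form
        T.tensor(T.f32())                     # unranked: tensor<*xf32>
        T.memref(4, T.i8(), memory_space=1)   # memref<4xi8, 1>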
99 changes: 99 additions & 0 deletions mlir/test/python/ir/builtin_types.py
@@ -3,6 +3,7 @@
 import gc
 from mlir.ir import *
 from mlir.dialects import arith, tensor, func, memref
+import mlir.extras.types as T


 def run(f):
@@ -772,3 +773,101 @@ def testCustomTypeTypeCaster():
        print(t)
        # CHECK: OperationType(!transform.op<"foo.bar">)
        print(repr(t))


# CHECK-LABEL: TEST: testTypeWrappers
@run
def testTypeWrappers():
    def stride(strides, offset=0):
        return StridedLayoutAttr.get(offset, strides)

    with Context(), Location.unknown():
        ia = T.i(5)
        sia = T.si(6)
        uia = T.ui(7)
        assert repr(ia) == "IntegerType(i5)"
        assert repr(sia) == "IntegerType(si6)"
        assert repr(uia) == "IntegerType(ui7)"

        assert T.i(16) == T.i16()
        assert T.si(16) == T.si16()
        assert T.ui(16) == T.ui16()

        c1 = T.complex(T.f16())
        c2 = T.complex(T.i32())
        assert repr(c1) == "ComplexType(complex<f16>)"
        assert repr(c2) == "ComplexType(complex<i32>)"

        vec_1 = T.vector(2, 3, T.f32())
        vec_2 = T.vector(2, 3, 4, T.f32())
        assert repr(vec_1) == "VectorType(vector<2x3xf32>)"
        assert repr(vec_2) == "VectorType(vector<2x3x4xf32>)"

        m1 = T.memref(2, 3, 4, T.f64())
        assert repr(m1) == "MemRefType(memref<2x3x4xf64>)"

        m2 = T.memref(2, 3, 4, T.f64(), memory_space=1)
        assert repr(m2) == "MemRefType(memref<2x3x4xf64, 1>)"

        m3 = T.memref(2, 3, 4, T.f64(), memory_space=1, layout=stride([5, 7, 13]))
        assert repr(m3) == "MemRefType(memref<2x3x4xf64, strided<[5, 7, 13]>, 1>)"

        m4 = T.memref(2, 3, 4, T.f64(), memory_space=1, layout=stride([5, 7, 13], 42))
        assert (
            repr(m4)
            == "MemRefType(memref<2x3x4xf64, strided<[5, 7, 13], offset: 42>, 1>)"
        )

        S = ShapedType.get_dynamic_size()

        t1 = T.tensor(S, 3, S, T.f64())
        assert repr(t1) == "RankedTensorType(tensor<?x3x?xf64>)"
        ut1 = T.tensor(T.f64())
        assert repr(ut1) == "UnrankedTensorType(tensor<*xf64>)"
        t2 = T.tensor(S, 3, S, element_type=T.f64())
        assert repr(t2) == "RankedTensorType(tensor<?x3x?xf64>)"
        ut2 = T.tensor(element_type=T.f64())
        assert repr(ut2) == "UnrankedTensorType(tensor<*xf64>)"

        t3 = T.tensor(S, 3, S, T.f64(), encoding="encoding")
        assert repr(t3) == 'RankedTensorType(tensor<?x3x?xf64, "encoding">)'

        v = T.vector(3, 3, 3, T.f64())
        assert repr(v) == "VectorType(vector<3x3x3xf64>)"

        m5 = T.memref(S, 3, S, T.f64())
        assert repr(m5) == "MemRefType(memref<?x3x?xf64>)"
        um1 = T.memref(T.f64())
        assert repr(um1) == "UnrankedMemRefType(memref<*xf64>)"
        m6 = T.memref(S, 3, S, element_type=T.f64())
        assert repr(m6) == "MemRefType(memref<?x3x?xf64>)"
        um2 = T.memref(element_type=T.f64())
        assert repr(um2) == "UnrankedMemRefType(memref<*xf64>)"

        m7 = T.memref(S, 3, S, T.f64())
        assert repr(m7) == "MemRefType(memref<?x3x?xf64>)"
        um3 = T.memref(T.f64())
        assert repr(um3) == "UnrankedMemRefType(memref<*xf64>)"

        scalable_1 = T.vector(2, 3, T.f32(), scalable=[False, True])
        scalable_2 = T.vector(2, 3, 4, T.f32(), scalable=[True, False, True])
        assert repr(scalable_1) == "VectorType(vector<2x[3]xf32>)"
        assert repr(scalable_2) == "VectorType(vector<[2]x3x[4]xf32>)"

        scalable_3 = T.vector(2, 3, T.f32(), scalable_dims=[1])
        scalable_4 = T.vector(2, 3, 4, T.f32(), scalable_dims=[0, 2])
        assert scalable_3 == scalable_1
        assert scalable_4 == scalable_2

        opaq = T.opaque("scf", "placeholder")
        assert repr(opaq) == "OpaqueType(!scf.placeholder)"

        tup1 = T.tuple(T.i16(), T.i32(), T.i64())
        tup2 = T.tuple(T.f16(), T.f32(), T.f64())
        assert repr(tup1) == "TupleType(tuple<i16, i32, i64>)"
        assert repr(tup2) == "TupleType(tuple<f16, f32, f64>)"

        func = T.function(
            inputs=(T.i16(), T.i32(), T.i64()), results=(T.f16(), T.f32(), T.f64())
        )
        assert repr(func) == "FunctionType((i16, i32, i64) -> (f16, f32, f64))"
8 changes: 1 addition & 7 deletions mlir/test/python/ir/context_managers.py
@@ -15,13 +15,7 @@ def run(f):
 def testContextEnterExit():
     with Context() as ctx:
         assert Context.current is ctx
-    try:
-        _ = Context.current
-    except ValueError as e:
-        # CHECK: No current Context
-        print(e)
-    else:
-        assert False, "Expected exception"
+    assert Context.current is None


 run(testContextEnterExit)
