Commit

lint
Aalanli committed Aug 24, 2023
1 parent 3ca79cb commit e1862a2
Showing 3 changed files with 16 additions and 120 deletions.
2 changes: 0 additions & 2 deletions python/hidet/graph/ops/matmul/matmul_f16.py
@@ -1,4 +1,3 @@
-# %%
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -367,4 +366,3 @@ def matmul_f16(a: Tensor, b: Tensor, parallel_k_parts=1) -> Tensor:
     if a.dtype != dtypes.float16 or b.dtype != dtypes.float16:
         raise ValueError('BatchMatmulF16Op only support float16, got {} and {}'.format(a.dtype, b.dtype))
     return MatmulF16Op(a, b, parallel_k_parts).outputs[0]
-
126 changes: 12 additions & 114 deletions python/hidet/ir/tools/ir_dumper.py
@@ -570,113 +570,6 @@ def visit_ReferenceType(self, t: ReferenceType):
     def visit_VoidType(self, t: VoidType):
         return Text('void')
 
-    # def visit_FuncType(self, t: FuncType):
-    #     if t.type_infer_func is not None:
-    #         return Text('FuncType[type_infer_func]')
-    #     else:
-    #         return Text('FuncType(params={}, ret={})'.format(self(t.param_types), self(t.ret_type)))
-
-    # def visit_PlaceholderExpr(self, e: PlaceholderExpr):
-    #     if e.required_type:
-    #         type_doc = self(e.required_type) + '_'
-    #     else:
-    #         type_doc = ''
-
-    #     if e.require_const:
-    #         base = 'const'
-    #     elif e.require_non_const:
-    #         base = 'expr'
-    #     else:
-    #         base = 'any'
-
-    #     return Text(type_doc + base)
-
-    # def print_tensor_nodes(self, nodes: List[TensorNode], exclude_nodes: List[TensorNode] = None) -> Doc:
-    #     from hidet.ir.tools import collect # pylint: disable=import-outside-toplevel
-    #     from hidet.utils.structure import DirectedGraph
-
-    #     if exclude_nodes is None:
-    #         exclude_nodes = []
-    #     nodes: List[TensorNode] = collect(nodes, TensorNode)
-    #     dag = DirectedGraph()
-    #     for node in nodes:
-    #         dag.add_node(node)
-    #         if isinstance(node, GridCompute):
-    #             depends = collect(node.value, TensorNode, stop_when_found=True)
-    #         elif isinstance(node, TensorInput):
-    #             depends = []
-    #         else:
-    #             raise NotImplementedError()
-    #         for depend_node in depends:
-    #             dag.add_edge(src=depend_node, dst=node)
-    #     order = dag.topological_order()
-
-    #     doc = Doc()
-    #     for node in order:
-    #         if node in exclude_nodes:
-    #             continue
-    #         if isinstance(node, TensorInput):
-    #             pass
-    #         elif isinstance(node, GridCompute):
-    #             # example
-    #             # y: float32[10, 10] where y[i, j] = x[i, j] + 1
-    #             doc += NewLine()
-    #             doc += self(node) + ': ' + self(node.type.dtype) + '[' + self(node.type.shape) + ']'
-    #             doc += Text(' where ') + self(node) + '[' + self(node.axes) + '] = ' + self(node.value)
-    #         else:
-    #             raise NotImplementedError()
-    #     return doc
-
-    # def visit_Task(self, e: Task):
-    #     lines = [
-    #         Text('name: ') + e.name,
-    #         Text('parameters: ')
-    #         + (
-    #             NewLine()
-    #             + doc_join(['{}: {}'.format(self.namer.get_name(v), self(v.type)) for v in e.params], NewLine())
-    #         ).indent(),
-    #         Text('inputs: ') + '[' + doc_join([self.namer.get_name(v) for v in e.inputs], ', ') + ']',
-    #         Text('outputs: ') + '[' + doc_join([self.namer.get_name(v) for v in e.outputs], ', ') + ']',
-    #         Text('computations: ') + self.print_tensor_nodes(e.outputs).indent(),
-    #         Text('attributes: {') + self({k: str(v) for k, v in e.attrs.items()}) + '}',
-    #     ]
-    #     if len(e.assertions) > 0: # between computations and attributes
-    #         lines.append(
-    #             Text('assertions: ')
-    #             + (
-    #                 NewLine() # self.assertions: List[Tuple[Expr, str]]
-    #                 + doc_join(['assert {}'.format(str(self.visit(v[0]))) for v in e.assertions], NewLine())
-    #             ).indent()
-    #         )
-    #     front_part = doc_join(lines, NewLine())
-    #     inverse_map_doc = Doc()
-    #     if e.inverse_map:
-    #         inverse_map_doc += NewLine() + Text('inverse_map:')
-    #         for tensor, inverse_map in e.inverse_map.items():
-    #             inverse_map_body = 'InverseMap([' + self(inverse_map.axes) + '] => [' + self(inverse_map.indices) + '])'
-    #             inverse_map_doc += (NewLine() + self.namer.get_name(tensor) + ': ' + inverse_map_body).indent()
-    #     return Text('Task(') + (NewLine() + front_part + inverse_map_doc).indent() + NewLine() + ')'
-
-    # def visit_TensorNode(self, e: TensorNode):
-    #     return self.namer.get_name(e)
-
-    # def visit_ScalarInput(self, node: ScalarInput):
-    #     return self.namer.get_name(node)
-
-    # def visit_TensorInput(self, node: TensorInput):
-    #     return self.namer.get_name(node)
-
-    # def visit_GridCompute(self, c: GridCompute):
-    #     return self.namer.get_name(c)
-
-    # def visit_ReduceCompute(self, c: ReduceCompute):
-    #     items = ['[' + self(c.shape) + ']', '(' + self(c.axes) + ') => ' + self(c.value), str(c.reduce_operation)]
-    #     return 'reduce(' + doc_join(items, ', ') + ')'
-
-    # def visit_ArgReduceCompute(self, c: ArgReduceCompute):
-    #     items = ['[' + self(c.extent) + ']', self(c.axis) + ' => ' + self(c.value), str(c.reduce_operation)]
-    #     return 'arg_reduce(' + doc_join(items, ', ') + ')'
-
     def visit_SpatialTaskMapping(self, mapping: SpatialTaskMapping):
         items = ['task_shape=list(' + self(mapping.task_shape) + ')']
         # if not same_list(mapping.ranks, list(range(len(mapping.task_shape)))):
@@ -813,14 +706,14 @@ def preprocess_symbolvar(tree: Tree):
     attr = tree.children[0]
     assert attr.data.value == 'attribute'
     attr_dict = attr.children[1].children
-    for tree in attr_dict:
-        name, value = construct_pair(tree)
+    for tree_ in attr_dict:
+        name, value = construct_pair(tree_)
         if str(name)[1:-1] == 'symbols':
-            symbol_var = {}
+            symbol_var_ = {}
             for symbol in value.children:
                 symbol_name, symbol_type = construct_pair(symbol)
-                symbol_var[str(symbol_name)[1:-1]] = str(symbol_type)[1:-1]
-            return symbol_var
+                symbol_var_[str(symbol_name)[1:-1]] = str(symbol_type)[1:-1]
+            return symbol_var_
     return None


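The renames above (tree to tree_, symbol_var to symbol_var_) are lint fixes: the loop variable tree shadowed the function's own tree: Tree parameter, which pylint reports as redefined-argument-from-local. A minimal, self-contained sketch of the same shadowing and fix, using a toy Node class rather than the parser's Tree:

class Node:
    """Toy stand-in for the parser's Tree; only 'children' matters here."""
    def __init__(self, children=()):
        self.children = list(children)

def count_nodes(tree):
    # Writing 'for tree in tree.children' would shadow the parameter and
    # trip pylint's redefined-argument-from-local; the trailing underscore
    # sidesteps it, just like tree_ in preprocess_symbolvar above.
    total = 1
    for tree_ in tree.children:
        total += count_nodes(tree_)
    return total

assert count_nodes(Node([Node(), Node([Node()])])) == 4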
@@ -898,12 +791,14 @@ def resolve_binary_op(op, a, b):
         return logical_and(a, b)
     elif op == '||':
         return logical_or(a, b)
+    # pylint: disable=eval-used
     return eval(f'a {op} b')
 
 
 def resolve_unary_op(op, a):
     if op == '!':
         return logical_not(a)
+    # pylint: disable=eval-used
     return eval(f'{op} a')


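The two added # pylint: disable=eval-used lines whitelist the operator-dispatch trick rather than remove it: eval(f'a {op} b') works because eval defaults to the calling frame's locals, so a and b resolve to the operands. A standalone sketch of the same pattern (the toy function name is illustrative, not the module's API):

def toy_resolve_binary_op(op: str, a, b):
    # eval() defaults to this frame's locals, so 'a' and 'b' in the
    # evaluated string are the operands; the token in 'op' is the operator.
    # pylint: disable=eval-used
    return eval(f'a {op} b')

assert toy_resolve_binary_op('+', 2, 3) == 5
assert toy_resolve_binary_op('<<', 1, 4) == 16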
@@ -950,6 +845,7 @@ def visit(self, node) -> Node:
             raise NotImplementedError("Unknown node type: {}".format(type(node)))
         else:
             print(type(node), node)
+            return node
 
     def visit_STRING(self, value):
         return str(value[1:-1])
@@ -1024,16 +920,18 @@ def visit_name(self, node):
 
     def visit_unary_op(self, node):
         op, expr = node
+        # pylint: disable=broad-except
         try:
            return resolve_unary_op(op, self.visit(expr))
-        except Exception as e:
+        except Exception:
            return node
 
     def visit_binary_op(self, node):
         a, op, b = node
+        # pylint: disable=broad-except
         try:
            return resolve_binary_op(op, self.visit(a), self.visit(b))
-        except Exception as e:
+        except Exception:
            return node
 
     def visit_keyword_arg(self, node):
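Both except fixes above drop a binding (as e) that was never read, which pylint reports as unused-variable; the deliberately broad catch itself is kept and acknowledged with the block-level # pylint: disable=broad-except. The fold-or-keep pattern in miniature (toy function, assumed name):

def risky():
    raise ValueError('unresolvable operand')

def fold_or_keep(node):
    # pylint: disable=broad-except
    try:
        return risky()
    except Exception:  # was 'except Exception as e:' -- 'e' was unused
        return node    # fall back to the unevaluated node

assert fold_or_keep('x + y') == 'x + y'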
8 changes: 4 additions & 4 deletions tests/ir/parser/test_parser.py
@@ -116,10 +116,10 @@ def generate_ir_modules():
     ]
     for mod in [get_matmul_task(), get_bmatmul_task(), get_softmax_task(), get_attn_task()]:
         for t in transforms:
-            # if hasattr(t, '__name__'):
-            #     print(t.__name__)
-            # else:
-            #     print(t.__class__.__name__)
+            if hasattr(t, '__name__'):
+                print(t.__name__)
+            else:
+                print(t.__class__.__name__)
             mod = t(mod)
             yield mod
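The uncommented prints label each transform as it runs; the hasattr branch is there apparently because the transform list mixes plain functions (which carry __name__) with pass instances (which do not). An equivalent one-liner, shown only for illustration and not part of the commit:

def transform_name(t):
    # functions expose __name__; instances fall back to their class name
    return getattr(t, '__name__', t.__class__.__name__)

assert transform_name(len) == 'len'
assert transform_name(object()) == 'object'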
