-
Notifications
You must be signed in to change notification settings - Fork 2
/
tvm_pretty_print.py
executable file
·132 lines (101 loc) · 3.48 KB
/
tvm_pretty_print.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#!/usr/bin/env python3
"""Pretty-prints any TVM ObjectRef objects
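Usage:
- Source this file in your `.gdbinit`:
  `source ~/path/to/the/tvm_pretty_print.py`
- Then enable the printers, e.g.
  `python PrettyPrinter.register(PrettyPrintLevel.Default)`
  (use `PrettyPrintLevel.All` to also enable the ObjectRef printer).
TODO:
- Handle longer string results from gdb.parse_and_eval; the output for some
  objects gets truncated.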
"""
import enum
from abc import abstractmethod
from typing import Optional
import gdb
class PrettyPrintLevel(enum.Flag):
    """Bit flags choosing which groups of pretty-printers are registered."""

    # No printers at all.
    Disabled = 0
    # The tvm::runtime::DataType printer.
    DataType = 1
    # The tvm::runtime::ObjectRef printer.
    ObjectRef = 2
    # Alias: the default selection is just the DataType printer.
    Default = DataType
    # Every printer group combined.
    All = DataType | ObjectRef
class PrettyPrinter:
    """Base class and registry for the gdb pretty-printers in this file.

    Each subclass must be declared with a ``pprint_level`` class keyword
    (a ``PrettyPrintLevel``). Defining the subclass records it in the
    registry, and ``register`` later installs its ``lookup`` into gdb.
    """

    # Every subclass defined so far, in definition order (shared registry).
    _printers = []

    def __init_subclass__(cls, /, pprint_level, **kwargs):
        # Forward remaining class keywords so cooperative base classes still
        # receive their own __init_subclass__ call.
        super().__init_subclass__(**kwargs)
        cls.pprint_level = pprint_level
        PrettyPrinter._printers.append(cls)

    @classmethod
    def register(cls, pprint_level):
        """Install into gdb the lookup of every subclass enabled by *pprint_level*.

        A subclass is enabled when its ``pprint_level`` shares at least one
        flag with *pprint_level*. Calling this repeatedly does not install
        the same lookup function twice.
        """
        for subclass in cls._printers:
            if not (pprint_level & subclass.pprint_level):
                continue
            lookup = subclass.lookup
            # Bound classmethods compare equal when they wrap the same
            # function and class, so this detects an earlier registration.
            if lookup not in gdb.pretty_printers:
                gdb.pretty_printers.append(lookup)

    @classmethod
    @abstractmethod
    def lookup(cls, val) -> Optional["Self"]:
        """Return the printer that should print the value, or None."""

    @abstractmethod
    def to_string(self) -> str:
        """Convert the value to a string"""
class TVM_ObjectRef(PrettyPrinter, pprint_level=PrettyPrintLevel.ObjectRef):
    """Pretty-printer for ``tvm::runtime::ObjectRef`` and derived values.

    Matches values of (or pointers to) any type that gdb can cast to
    ``const tvm::runtime::ObjectRef *``. It renders them by calling
    ``tvm::PrettyPrint`` inside the inferior process.
    """

    @classmethod
    def lookup(cls, val):
        """Return a printer for *val*, or None if it is not an ObjectRef.

        *val* may be an ObjectRef-derived struct or a pointer to one.
        """
        # Look through typedefs so aliased struct/pointer types match too.
        type_code = val.type.strip_typedefs().code
        try:
            if type_code == gdb.TYPE_CODE_PTR:
                obj = val.referenced_value()
            elif type_code == gdb.TYPE_CODE_STRUCT:
                obj = val
            else:
                return None
        except gdb.error:
            # The pointer cannot be dereferenced (e.g. ``void *``).
            return None

        try:
            object_ref_type = gdb.lookup_type("::tvm::runtime::ObjectRef")
        except gdb.error:
            # TVM not loaded, so don't use this printer
            return None

        address = obj.address
        if address is None:
            # Values not stored in inferior memory have no address to cast.
            return None

        ptr_type = object_ref_type.const().pointer()
        try:
            as_objref_pointer = address.dynamic_cast(ptr_type)
        except gdb.error:
            # Not a subclass of ObjectRef, so don't use this printer
            return None
        # A null input pointer, or a failed cast of a polymorphic type, can
        # come back as a null pointer rather than an error. Passing it to
        # TVM would dereference null in the inferior.
        if int(as_objref_pointer) == 0:
            return None
        return cls(as_objref_pointer)

    def __init__(self, pointer):
        # gdb.Value holding a non-null ``const tvm::runtime::ObjectRef *``.
        self.pointer = pointer

    def to_string(self):
        """Render the object via ``tvm::PrettyPrint``, wrapped in double quotes."""
        command = "::tvm::PrettyPrint(*(::tvm::runtime::ObjectRef*){}).c_str()".format(
            int(self.pointer)
        )
        # TODO: Figure out better handling during segfaults, not safe
        # to make calls into tvm at that point.
        output = gdb.parse_and_eval(command)
        # Read the C string straight from inferior memory instead of parsing
        # gdb's display text. The display is cut off at the `print elements`
        # limit, may contain `<repeats N times>` runs, and its non-ASCII
        # characters broke the old ascii/unicode_escape round trip.
        text = output.string(errors="replace")
        # Keep the surrounding quotes that the printed form has always had.
        return '"{}"'.format(text)
class TVM_DataType(PrettyPrinter, pprint_level=PrettyPrintLevel.DataType):
    """Pretty-printer for ``tvm::runtime::DataType`` values.

    If the ``tvm`` package can be imported inside gdb, the text is the repr
    of TVM's Python ``tvm.DataType``. Otherwise it is built from the raw
    code/bits/lanes fields.
    """

    @classmethod
    def lookup(cls, val):
        """Return a printer for *val* if its type is ``tvm::runtime::DataType``."""
        try:
            datatype_type = gdb.lookup_type("::tvm::runtime::DataType")
        except gdb.error:
            # TVM not loaded, so don't use this printer
            return None
        # Ignore typedefs and const/volatile qualifiers so that, e.g., values
        # seen through `const DataType&` parameters are matched as well.
        if val.type.strip_typedefs().unqualified() == datatype_type.strip_typedefs():
            return cls(val)
        return None

    def __init__(self, val):
        # gdb.Value of type tvm::runtime::DataType.
        self.val = val

    def to_string(self):
        """Return a textual representation of the wrapped DataType."""
        # The first member of DataType holds the code/bits/lanes struct.
        data = self.val[self.val.type.fields()[0]]
        # int() converts each field using its declared type. The previous
        # format_string(format="d") printed signed decimal, which turns
        # unsigned fields (e.g. uint8) with values >= 128 negative.
        values = {field.name: int(data[field]) for field in data.type.fields()}
        try:
            import tvm
        except ImportError:
            # gdb's embedded Python may not have TVM installed; show the
            # raw fields instead of failing to print.
            return "DataType(code={}, bits={}, lanes={})".format(
                values["code"], values["bits"], values["lanes"]
            )
        # Can't construct directly from the type_code/bits/lanes, but
        # can set them afterwards.
        dtype = tvm.DataType("int")
        dtype.type_code = values["code"]
        dtype.bits = values["bits"]
        dtype.lanes = values["lanes"]
        return repr(dtype)