Skip to content

Commit

Permalink
[Fixbug] Fix a performance bug in auto-scheduler (#402)
Browse files Browse the repository at this point in the history
Introduced in #399
  • Loading branch information
yaoyaoding authored Dec 31, 2023
1 parent c378ad3 commit 7c71965
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 47 deletions.
5 changes: 5 additions & 0 deletions python/hidet/drivers/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from hidet.ir.target import Target
from hidet.transforms import lower, PassContext, SaveIRInstrument, ProfileInstrument
from hidet.utils.multiprocess import parallel_imap
from hidet.utils.stack_limit import set_stack_limit

logger = logging.Logger(__name__)
logger.setLevel(logging.INFO)
Expand Down Expand Up @@ -100,6 +101,10 @@ def build_ir_module(ir_module: IRModule, output_dir: str, *, target: str, output
else:
raise ValueError(f'Invalid target: {target}')

# set the recursion limit before every lowering, because some other packages might change this value to a lower
# value that we need
set_stack_limit()

# lower ir module
instruments = []
if hidet.option.get_save_lower_ir():
Expand Down
34 changes: 3 additions & 31 deletions python/hidet/ir/schedulers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union, List, Dict, Sequence, Tuple, Set, Optional
from collections import defaultdict

from hidet.ir.type import DataType, tensor_pointer_type
from hidet.ir.expr import TensorElement, Expr, Var, SymbolVar, Constant, scalar_var, convert, cast
Expand All @@ -19,7 +18,7 @@
from hidet.ir.func import Function
from hidet.ir.module import IRModule
from hidet.ir.builders import FunctionBuilder, StmtBuilder
from hidet.ir.functors import ExprRewriter, ExprVisitor, ComputeVisitor, ComputeRewriter, TypeRewriter, IRVisitor
from hidet.ir.functors import ExprRewriter, ExprVisitor, ComputeVisitor, ComputeRewriter, TypeRewriter
from hidet.ir.tools import IRPrinter, collect, rewrite, infer_type, simplify, collect_free_vars
from hidet.ir.compute import ScalarInput, TensorInput, GridCompute, ReduceCompute, ArgReduceCompute
from hidet.ir.compute import TensorNode, ScalarNode
Expand All @@ -33,26 +32,6 @@ class ScalarComputeFound(Exception):
pass


class UsageCounter(IRVisitor):
def __init__(self):
super().__init__()
self.usage: Dict[Union[TensorNode, ScalarNode], int] = defaultdict(int)

def visit_TensorInput(self, node: TensorInput):
self.usage[node] += 1

def visit_ScalarInput(self, node: ScalarInput):
self.usage[node] += 1

def visit_GridCompute(self, node: GridCompute):
self.usage[node] += 1
super().visit_GridCompute(node)

def visit_ReduceCompute(self, node: ReduceCompute):
self.usage[node] += 1
super().visit_ReduceCompute(node)


class GridComputeInlineChecker(ExprVisitor, ComputeVisitor):
def check(self, gc: GridCompute) -> bool:
"""Check whether the grid compute can be inlined.
Expand Down Expand Up @@ -96,10 +75,6 @@ def can_inline_grid_compute(gc: GridCompute) -> bool:


class GridComputeInliner(ExprRewriter, ComputeRewriter):
def __init__(self, usage_count: Dict[Union[TensorNode, ScalarNode], int]):
super().__init__()
self.usage_count: Dict[Union[TensorNode, ScalarNode], int] = usage_count

def inline(self, node: TensorNode):
return self.visit(node)

Expand All @@ -111,7 +86,7 @@ def visit_TensorElement(self, e: TensorElement):
indices = [self(index) for index in e.indices]
if isinstance(base, GridCompute):
assert isinstance(e.base, TensorNode)
if can_inline_grid_compute(base) or self.usage_count[e.base] == 1:
if can_inline_grid_compute(base):
return rewrite(base.value, {axis: index for axis, index in zip(base.axes, indices)})
return ExprRewriter.visit_TensorElement(self, e)

Expand Down Expand Up @@ -143,10 +118,7 @@ def inline_grid_compute(nodes: List[TensorNode]) -> List[TensorNode]:
ret: List[TensorNode]
The nodes after inlining.
"""
usage_counter = UsageCounter()
usage_counter.visit(nodes)

inliner = GridComputeInliner(usage_counter.usage)
inliner = GridComputeInliner()

return [inliner.inline(node) for node in nodes]

Expand Down
55 changes: 40 additions & 15 deletions python/hidet/utils/stack_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,43 @@
import sys
import resource

# allow up to 512 MiB stack space
expected_stack_size = 2**29 # 512 MiB
stack_limit: Tuple[int, int] = resource.getrlimit(resource.RLIMIT_STACK)
if stack_limit[1] != resource.RLIM_INFINITY and stack_limit[1] < expected_stack_size:
warnings.warn(
f'The hard limit for stack size is too small ({stack_limit[1] / 2**20:.1f} MiB), '
f'we recommend to increase it to {expected_stack_size / 2**20:.1f} MiB. '
'If you are the root user on Linux OS, you could refer to `man limits.conf` to increase this limit.'
)
resource.setrlimit(resource.RLIMIT_STACK, (stack_limit[1], stack_limit[1]))
else:
resource.setrlimit(resource.RLIMIT_STACK, (expected_stack_size, stack_limit[1]))

# allow up to 10^5 recursive python calls, increase this when needed
sys.setrecursionlimit(100000)

def set_stack_size(size: int = 2**29): # 512 MiB
"""
Set the stack size for python.
Parameters
----------
size: int
The stack size for python, in bytes.
"""
expected_stack_size = size
stack_limit: Tuple[int, int] = resource.getrlimit(resource.RLIMIT_STACK)
if stack_limit[1] != resource.RLIM_INFINITY and stack_limit[1] < expected_stack_size:
warnings.warn(
f'The hard limit for stack size is too small ({stack_limit[1] / 2**20:.1f} MiB), '
f'we recommend to increase it to {expected_stack_size / 2**20:.1f} MiB. '
'If you are the root user on Linux OS, you could refer to `man limits.conf` to increase this limit.'
)
resource.setrlimit(resource.RLIMIT_STACK, (stack_limit[1], stack_limit[1]))
else:
resource.setrlimit(resource.RLIMIT_STACK, (expected_stack_size, stack_limit[1]))


def set_stack_limit(limit: int = 100000): # 10^5 recursive calls
"""
Set the stack limit for python.
Parameters
----------
limit: int
The stack limit for python.
"""
sys.setrecursionlimit(max(sys.getrecursionlimit(), limit))


# allow more recursive python calls
set_stack_limit()

# allow more stack space
set_stack_size()
2 changes: 1 addition & 1 deletion tests/frontends/torch/models/test_torch_densenet121.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
def test_densenet121(shape):
model = torch.hub.load('pytorch/vision:v0.6.0', 'densenet121', pretrained=True).cuda().eval().to(torch.float16)
x = torch.randn(*shape).cuda().to(torch.float16) * 0.1796 + 0.5491
check_module(model, [x], atol=2e-2, rtol=2e-2, dynamic=False)
check_module(model, [x], atol=4e-2, rtol=4e-2, dynamic=False)


if __name__ == '__main__':
Expand Down

0 comments on commit 7c71965

Please sign in to comment.