From 7c71965a220e2696dba123d060cd4ad049405563 Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Sat, 30 Dec 2023 21:58:41 -0500 Subject: [PATCH] [Fixbug] Fix a performance bug in auto-scheduler (#402) Introduced in #399 --- python/hidet/drivers/build_module.py | 5 ++ python/hidet/ir/schedulers/base.py | 34 +----------- python/hidet/utils/stack_limit.py | 55 ++++++++++++++----- .../torch/models/test_torch_densenet121.py | 2 +- 4 files changed, 49 insertions(+), 47 deletions(-) diff --git a/python/hidet/drivers/build_module.py b/python/hidet/drivers/build_module.py index 1e10fab34..99fa64ec3 100644 --- a/python/hidet/drivers/build_module.py +++ b/python/hidet/drivers/build_module.py @@ -26,6 +26,7 @@ from hidet.ir.target import Target from hidet.transforms import lower, PassContext, SaveIRInstrument, ProfileInstrument from hidet.utils.multiprocess import parallel_imap +from hidet.utils.stack_limit import set_stack_limit logger = logging.Logger(__name__) logger.setLevel(logging.INFO) @@ -100,6 +101,10 @@ def build_ir_module(ir_module: IRModule, output_dir: str, *, target: str, output else: raise ValueError(f'Invalid target: {target}') + # set the recursion limit before every lowering, because some other packages might change this value to a lower + # value that we need + set_stack_limit() + # lower ir module instruments = [] if hidet.option.get_save_lower_ir(): diff --git a/python/hidet/ir/schedulers/base.py b/python/hidet/ir/schedulers/base.py index 2a4331416..a250a14cf 100644 --- a/python/hidet/ir/schedulers/base.py +++ b/python/hidet/ir/schedulers/base.py @@ -10,7 +10,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Union, List, Dict, Sequence, Tuple, Set, Optional -from collections import defaultdict from hidet.ir.type import DataType, tensor_pointer_type from hidet.ir.expr import TensorElement, Expr, Var, SymbolVar, Constant, scalar_var, convert, cast @@ -19,7 +18,7 @@ from hidet.ir.func import Function from hidet.ir.module import IRModule from hidet.ir.builders import FunctionBuilder, StmtBuilder -from hidet.ir.functors import ExprRewriter, ExprVisitor, ComputeVisitor, ComputeRewriter, TypeRewriter, IRVisitor +from hidet.ir.functors import ExprRewriter, ExprVisitor, ComputeVisitor, ComputeRewriter, TypeRewriter from hidet.ir.tools import IRPrinter, collect, rewrite, infer_type, simplify, collect_free_vars from hidet.ir.compute import ScalarInput, TensorInput, GridCompute, ReduceCompute, ArgReduceCompute from hidet.ir.compute import TensorNode, ScalarNode @@ -33,26 +32,6 @@ class ScalarComputeFound(Exception): pass -class UsageCounter(IRVisitor): - def __init__(self): - super().__init__() - self.usage: Dict[Union[TensorNode, ScalarNode], int] = defaultdict(int) - - def visit_TensorInput(self, node: TensorInput): - self.usage[node] += 1 - - def visit_ScalarInput(self, node: ScalarInput): - self.usage[node] += 1 - - def visit_GridCompute(self, node: GridCompute): - self.usage[node] += 1 - super().visit_GridCompute(node) - - def visit_ReduceCompute(self, node: ReduceCompute): - self.usage[node] += 1 - super().visit_ReduceCompute(node) - - class GridComputeInlineChecker(ExprVisitor, ComputeVisitor): def check(self, gc: GridCompute) -> bool: """Check whether the grid compute can be inlined. @@ -96,10 +75,6 @@ def can_inline_grid_compute(gc: GridCompute) -> bool: class GridComputeInliner(ExprRewriter, ComputeRewriter): - def __init__(self, usage_count: Dict[Union[TensorNode, ScalarNode], int]): - super().__init__() - self.usage_count: Dict[Union[TensorNode, ScalarNode], int] = usage_count - def inline(self, node: TensorNode): return self.visit(node) @@ -111,7 +86,7 @@ def visit_TensorElement(self, e: TensorElement): indices = [self(index) for index in e.indices] if isinstance(base, GridCompute): assert isinstance(e.base, TensorNode) - if can_inline_grid_compute(base) or self.usage_count[e.base] == 1: + if can_inline_grid_compute(base): return rewrite(base.value, {axis: index for axis, index in zip(base.axes, indices)}) return ExprRewriter.visit_TensorElement(self, e) @@ -143,10 +118,7 @@ def inline_grid_compute(nodes: List[TensorNode]) -> List[TensorNode]: ret: List[TensorNode] The nodes after inlining. """ - usage_counter = UsageCounter() - usage_counter.visit(nodes) - - inliner = GridComputeInliner(usage_counter.usage) + inliner = GridComputeInliner() return [inliner.inline(node) for node in nodes] diff --git a/python/hidet/utils/stack_limit.py b/python/hidet/utils/stack_limit.py index 97b0622c5..543732aee 100644 --- a/python/hidet/utils/stack_limit.py +++ b/python/hidet/utils/stack_limit.py @@ -17,18 +17,43 @@ import sys import resource -# allow up to 512 MiB stack space -expected_stack_size = 2**29 # 512 MiB -stack_limit: Tuple[int, int] = resource.getrlimit(resource.RLIMIT_STACK) -if stack_limit[1] != resource.RLIM_INFINITY and stack_limit[1] < expected_stack_size: - warnings.warn( - f'The hard limit for stack size is too small ({stack_limit[1] / 2**20:.1f} MiB), ' - f'we recommend to increase it to {expected_stack_size / 2**20:.1f} MiB. ' - 'If you are the root user on Linux OS, you could refer to `man limits.conf` to increase this limit.' - ) - resource.setrlimit(resource.RLIMIT_STACK, (stack_limit[1], stack_limit[1])) -else: - resource.setrlimit(resource.RLIMIT_STACK, (expected_stack_size, stack_limit[1])) - -# allow up to 10^5 recursive python calls, increase this when needed -sys.setrecursionlimit(100000) + +def set_stack_size(size: int = 2**29): # 512 MiB + """ + Set the stack size for python. + + Parameters + ---------- + size: int + The stack size for python, in bytes. + """ + expected_stack_size = size + stack_limit: Tuple[int, int] = resource.getrlimit(resource.RLIMIT_STACK) + if stack_limit[1] != resource.RLIM_INFINITY and stack_limit[1] < expected_stack_size: + warnings.warn( + f'The hard limit for stack size is too small ({stack_limit[1] / 2**20:.1f} MiB), ' + f'we recommend to increase it to {expected_stack_size / 2**20:.1f} MiB. ' + 'If you are the root user on Linux OS, you could refer to `man limits.conf` to increase this limit.' + ) + resource.setrlimit(resource.RLIMIT_STACK, (stack_limit[1], stack_limit[1])) + else: + resource.setrlimit(resource.RLIMIT_STACK, (expected_stack_size, stack_limit[1])) + + +def set_stack_limit(limit: int = 100000): # 10^5 recursive calls + """ + Set the stack limit for python. + + Parameters + ---------- + limit: int + The stack limit for python. + """ + sys.setrecursionlimit(max(sys.getrecursionlimit(), limit)) + + +# allow more recursive python calls +set_stack_limit() + +# allow more stack space +set_stack_size() diff --git a/tests/frontends/torch/models/test_torch_densenet121.py b/tests/frontends/torch/models/test_torch_densenet121.py index 219800979..61504967e 100644 --- a/tests/frontends/torch/models/test_torch_densenet121.py +++ b/tests/frontends/torch/models/test_torch_densenet121.py @@ -18,7 +18,7 @@ def test_densenet121(shape): model = torch.hub.load('pytorch/vision:v0.6.0', 'densenet121', pretrained=True).cuda().eval().to(torch.float16) x = torch.randn(*shape).cuda().to(torch.float16) * 0.1796 + 0.5491 - check_module(model, [x], atol=2e-2, rtol=2e-2, dynamic=False) + check_module(model, [x], atol=4e-2, rtol=4e-2, dynamic=False) if __name__ == '__main__':