diff --git a/pytorch_blade/torch_blade/clustering/support_fusion_algorithm.py b/pytorch_blade/torch_blade/clustering/support_fusion_algorithm.py
index e8f31b08575..ea30233befe 100644
--- a/pytorch_blade/torch_blade/clustering/support_fusion_algorithm.py
+++ b/pytorch_blade/torch_blade/clustering/support_fusion_algorithm.py
@@ -12,6 +12,7 @@ import torch
 
 from torch_blade.logging import logger
 from torch_blade.algorithm import UnionSet, NxGraph
+from torch_blade.config import Config
 
 
 class NoCycleFusedGraphBuilder(object):
@@ -189,7 +190,10 @@ def can_merge(u, v):
         return True
 
     graph_topolist = graph_builder.group_topolist()
-    max_iter_count = 10
+    # some graphs need more than 10 union iterations to converge, so make the limit configurable
+    # TODO: refine cluster policy
+    cfg = Config.get_current_context_or_new()
+    max_iter_count = cfg.disc_cluster_max_iter_count
     while max_iter_count > 0:
         max_iter_count -= 1
         # TODO: merge brother group nodes that not construct a cycle
diff --git a/pytorch_blade/torch_blade/config.py b/pytorch_blade/torch_blade/config.py
index 2a2f645f483..99a809a3345 100644
--- a/pytorch_blade/torch_blade/config.py
+++ b/pytorch_blade/torch_blade/config.py
@@ -151,6 +151,10 @@ def __init__(self):
         # Level 3: Level 2 + NoNaNs + NoSignedZeros
         # Level 4: Level 3 + fully llvm fast math
         self._disc_cpu_fast_math_level = 4
+        # Note that the default cluster max_iter_count risks stopping the clustering
+        # early, before it fully converges. If that happens and it noticeably hurts
+        # latency, users can try enlarging this value.
+        self._disc_cluster_max_iter_count = 10
         # min/max/opt settings for tuning trt engines with dynamic input shapes
         # looks like:
         # {
@@ -321,6 +325,20 @@ def disc_cpu_fast_math_level(self, val):
         assert isinstance(val, int), "disc_cpu_fast_math_level should be int, got {}".format(type(val))
         self._disc_cpu_fast_math_level = val
 
+    @property
+    def disc_cluster_max_iter_count(self):
+        """The maximum number of iterations for the cluster fusion algorithm.
+
+        :type: int
+        :default: 10
+        """
+        return self._disc_cluster_max_iter_count
+
+    @disc_cluster_max_iter_count.setter
+    def disc_cluster_max_iter_count(self, val):
+        assert isinstance(val, int), "disc_cluster_max_iter_count should be int, got {}".format(type(val))
+        self._disc_cluster_max_iter_count = val
+
     @property
     def dynamic_tuning_shapes(self):
         """The dynamic shapes configuration for TensorRT & TVM tuning
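
For reference, a minimal usage sketch (not part of the patch) of how a user could raise the new limit through the `Config` context, assuming the usual `torch_blade.optimize` entry point; the model and inputs below are hypothetical placeholders:

```python
import torch
import torch_blade
from torch_blade.config import Config

# Hypothetical model and example inputs, only for illustration.
model = torch.nn.Linear(4, 4).eval()
inputs = (torch.randn(1, 4),)

cfg = Config()
# Default is 10; enlarge if clustering stops before converging and latency suffers.
cfg.disc_cluster_max_iter_count = 20
with cfg:
    # Inside the context, the fusion pass picks this Config up via
    # Config.get_current_context_or_new(), as shown in the diff above.
    opt_model = torch_blade.optimize(model, allow_tracing=True, model_inputs=inputs)
```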