Skip to content

Commit

Permalink
change all_dynamic to dynamic (#10423)
Browse files Browse the repository at this point in the history
  • Loading branch information
linzs148 authored Jan 29, 2024
1 parent 8f055f3 commit 8250384
Showing 1 changed file with 9 additions and 11 deletions.
20 changes: 9 additions & 11 deletions python/oneflow/framework/infer_compiler/with_oneflow_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,12 +277,12 @@ def get_graph(self):
size = self._deployable_module_options["size"]
else:
size = 9
if "all_dynamic" in self._deployable_module_options:
all_dynamic = self._deployable_module_options["all_dynamic"]
if "dynamic" in self._deployable_module_options:
dynamic = self._deployable_module_options["dynamic"]
else:
all_dynamic = False
dynamic = True
self._deployable_module_dpl_graph = get_oneflow_graph(
self._deployable_module_model.oneflow_module, size, all_dynamic
self._deployable_module_model.oneflow_module, size, dynamic
)
if "debug" in self._deployable_module_options:
self._deployable_module_dpl_graph.debug(
Expand Down Expand Up @@ -421,10 +421,10 @@ def save_graph(self, file_path):
flow.save(state_dict, file_path)


def get_oneflow_graph(model, size=9, all_dynamic=False):
def get_oneflow_graph(model, size=9, dynamic=True):
g = OneflowGraph(model)
g._dynamic_input_graph_cache.set_cache_size(size)
g._dynamic_input_graph_cache.enable_shared(not all_dynamic)
g._dynamic_input_graph_cache.enable_shared(dynamic)
return g


Expand Down Expand Up @@ -486,14 +486,12 @@ def compile_from_torch(
Note:
Map from torch to oneflow should be registered by `infer_compiler.register(torch2oflow_class_map={TorchModule: OneflowModule})` before `compile_from_torch` be called.
Args:s
Args:
torch_module (torch.nn.Module): Torch module to be compiled.
use_graph (bool, optional): If `True`, graph of compiled module can be saved and loaded to speedup the compile process.
Defaults to `True`.
use_graph (bool, optional): If `True`, graph of compiled module can be saved and loaded to speedup the compile process. Defaults to `True`.
options (dict, optional):
size (int, optional): graph cache size. Defaults to `9`.
all_dynamic (bool, optional): If `True`, graph of compiled module can't be shared with other modules.
Defaults to `False`.
dynamic (bool, optional): If `True`, graph of compiled module can be shared with other modules. Defaults to `True`.
debug (int, optional): debug level. Defaults to `-1`.
Returns:
Expand Down

0 comments on commit 8250384

Please sign in to comment.