Commit

Fix bug when compiling faster_rcnn's backbone (#10414)
haoyang9804 authored Jan 22, 2024
1 parent 7ddeacf commit 0320ed0
Showing 2 changed files with 9 additions and 7 deletions.
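
For context, the failure surfaced when a faster_rcnn-style backbone was compiled as a static graph, presumably via the dynamic-input graph cache that re-specializes a shared graph for new inputs (see UpdateSharedGraphForNewInput below). A hedged sketch of that usage pattern with OneFlow's nn.Graph API; the class and variable names are illustrative, not taken from the issue:

    import oneflow as flow

    class BackboneGraph(flow.nn.Graph):
        """Wraps an eager backbone so OneFlow can compile it as a static graph."""

        def __init__(self, backbone):
            super().__init__()
            self.backbone = backbone

        def build(self, x):
            # Traced and compiled lazily on the first call.
            return self.backbone(x)

    # backbone = ...             # a faster_rcnn-style feature extractor (assumed)
    # graph = BackboneGraph(backbone)
    # features = graph(images)   # first call triggers graph compilation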
2 changes: 2 additions & 0 deletions oneflow/core/job_rewriter/job_completer.cpp
@@ -216,6 +216,8 @@ Maybe<void> JobCompleter::UpdateSharedGraphForNewInput(
         *pair.second.mutable_at_shape() = attr_iter->second.at_shape();
       } else if (pair.second.has_at_double()) {
         pair.second.set_at_double(attr_iter->second.at_double());
+      } else if (pair.second.has_at_list_int64()) {
+        pair.second.mutable_at_list_int64()->CopyFrom(attr_iter->second.at_list_int64());
       }
     }
   }
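
The C++ change above is the substantive fix: UpdateSharedGraphForNewInput copies each operator attribute from the new job according to the variant the attribute holds, and the list-of-int64 variant was missing, so such attributes were silently skipped. Convolution-heavy backbones typically carry list-typed attributes such as strides and dilations, which is a plausible (but unstated) reason faster_rcnn tripped this. A minimal Python sketch of the dispatch, with a hypothetical dataclass standing in for OneFlow's protobuf attribute message:

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class AttrValue:
        """Hypothetical stand-in for OneFlow's protobuf attribute message,
        which is a oneof over many variants; only three are sketched here."""
        at_shape: Optional[List[int]] = None
        at_double: Optional[float] = None
        at_list_int64: Optional[List[int]] = None

    def update_attr(dst: AttrValue, src: AttrValue) -> None:
        # Copy from src whichever variant dst already holds, mirroring the
        # has_at_* / set_at_* dispatch in job_completer.cpp.
        if dst.at_shape is not None:
            dst.at_shape = list(src.at_shape)
        elif dst.at_double is not None:
            dst.at_double = src.at_double
        elif dst.at_list_int64 is not None:
            # The added branch: without it, list-of-int64 attributes were
            # never refreshed when the shared graph got new input.
            dst.at_list_int64 = list(src.at_list_int64)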
14 changes: 7 additions & 7 deletions python/oneflow/nn/graph/graph.py
@@ -277,7 +277,7 @@ def __call__(self, *args, **kwargs):
         Donot override this function.
         """
         # For cache cache graphs with dynamic input shape.
-        if self._run_with_cache == True:
+        if self._run_with_cache:
             return self._dynamic_input_graph_cache(*args, **kwargs)
 
         if not self._is_compiled:
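
The remaining hunks in graph.py are the same mechanical cleanup: PEP 8 recommends "if flag:" over "if flag == True:", and the two are not equivalent in general, since == rejects truthy values other than True and 1. For genuine booleans like these flags the behavior is unchanged, so the edit is stylistic. A small illustration, not taken from the repository:

    # "if flag:" tests truthiness; "flag == True" tests equality with True.
    # They agree only when the flag is exactly True (or 1, since 1 == True).
    assert bool(2) and not (2 == True)          # truthy, yet rejected by ==
    assert bool("yes") and not ("yes" == True)  # non-empty string: same
    assert bool([0]) and not ([0] == True)      # non-empty list: same
    assert (1 == True) and (True == True)       # these pass both tests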
@@ -844,7 +844,7 @@ def build(self, *args, **kwargs):
         return a_graph
 
     def _compile(self, *args, **kwargs):
-        if self._run_with_cache == True:
+        if self._run_with_cache:
             return self._dynamic_input_graph_cache._compile(*args, **kwargs)
 
         if not self._is_compiled:
@@ -909,7 +909,7 @@ def _compile_from_shared(self, *args, **kwargs):
         # Filter to get unique states in graph
         state_op_names = self._filter_states()
         # Generate new config.
-        if self._shared_graph._is_from_runtime_state_dict == True:
+        if self._shared_graph._is_from_runtime_state_dict:
             # To avoid same graph name with the loaded graphs.
             self._name = (
                 self._name + "_of_shared_from_loaded_" + self._shared_graph.name
@@ -1046,7 +1046,7 @@ def runtime_state_dict(
         Dict[str, Union[Dict[str, Tensor], str]],
         Dict[str, Dict[str, Union[Dict[str, Tensor], str]]],
     ]:
-        if self._run_with_cache == True:
+        if self._run_with_cache:
             return self._dynamic_input_graph_cache.runtime_state_dict(
                 with_eager=with_eager
             )
@@ -1135,7 +1135,7 @@ def gen_index_in_tuple(item):
         destination["states"] = states_sub_destination
 
         destination["exe_plan"] = self._c_nn_graph.plan
-        if self._enable_shared_from_this == True:
+        if self._enable_shared_from_this:
             destination["forward_graph"] = self._forward_job_proto
             destination["compile_graph"] = self._compiled_job_proto
 
@@ -1152,7 +1152,7 @@ def load_runtime_state_dict(
         *,
         warmup_with_run: bool = True,
     ) -> None:
-        if self._run_with_cache == True:
+        if self._run_with_cache:
             return self._dynamic_input_graph_cache.load_runtime_state_dict(
                 state_dict, warmup_with_run=warmup_with_run
             )
@@ -1225,7 +1225,7 @@ def get_tensor_in_tuple(tensor_tuple, map_item):
         self._eager_outputs = _eager_outputs
 
         # The base graph need extra info to create new shared graph
-        if self._enable_shared_from_this == True:
+        if self._enable_shared_from_this:
             self._forward_job_proto = state_dict["forward_graph"]
             self._compiled_job_proto = state_dict["compile_graph"]
             self._build_eager_outputs = self._eager_outputs
