Skip to content

Commit

Permalink
modify pir in distribute api test (#2946)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaoguoguo626807 authored Sep 12, 2024
1 parent 7d678d6 commit f61a640
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion distributed/CE_API/case/dist_fleet_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,6 @@ def mlp(input_x, input_y, hid_dim=128, label_dim=2):
step = 5
train_info = []
for i in range(step):
cost_val = exe.run(program=paddle.static.default_main_program(), feed=gen_data(), fetch_list=[cost.name])
cost_val = exe.run(program=paddle.static.default_main_program(), feed=gen_data(), fetch_list=[cost])
train_info.append(cost_val[0])
print(train_info)
2 changes: 1 addition & 1 deletion distributed/CE_API/case/dist_train_launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,6 @@ def mlp(input_x, input_y, hid_dim=128, label_dim=2):
step = 5
train_info = []
for i in range(step):
cost_val = exe.run(program=paddle.static.default_main_program(), feed=gen_data(), fetch_list=[cost.name])
cost_val = exe.run(program=paddle.static.default_main_program(), feed=gen_data(), fetch_list=[cost])
train_info.append(cost_val[0])
print(train_info)

0 comments on commit f61a640

Please sign in to comment.