Skip to content

Commit

Permalink
plt update net instance, test=model
Browse files Browse the repository at this point in the history
  • Loading branch information
Zeref996 committed Aug 30, 2024
1 parent 430e1ea commit b3efb14
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions framework/e2e/PaddleLT_new/strategy/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,13 @@ def base_compare(result, expect, res_name, exp_name, logger, delta=1e-10, rtol=1
if isinstance(expect, str):
raise Exception("expect is exception !!!")

if isinstance(expect, eval(f"{framework}.Tensor")) or isinstance(expect, np.ndarray):
if expect is None or result is None:
if expect is None:
Logger("PLT_compare").get_log().info(f"{exp_name} 结果为None, 所以跳过 {exp_name}{res_name} 精度对比")
if result is None:
Logger("PLT_compare").get_log().info(f"{res_name} 结果为None, 所以跳过 {exp_name}{res_name} 精度对比")
pass
elif isinstance(expect, eval(f"{framework}.Tensor")) or isinstance(expect, np.ndarray):
if isinstance(result, eval(f"{framework}.Tensor")):
if framework == "torch":
result = result.detach().numpy()
Expand Down Expand Up @@ -136,12 +142,6 @@ def base_compare(result, expect, res_name, exp_name, logger, delta=1e-10, rtol=1
)
elif isinstance(expect, (bool, int, float)):
assert expect == result
elif expect is None or result is None:
if expect is None:
Logger("PLT_compare").get_log().info(f"{exp_name} 结果为None, 所以跳过 {exp_name}{res_name} 精度对比")
if result is None:
Logger("PLT_compare").get_log().info(f"{res_name} 结果为None, 所以跳过 {exp_name}{res_name} 精度对比")
pass
else:
raise Exception("expect is unknown data struction in compare_tool!!!")

Expand Down

0 comments on commit b3efb14

Please sign in to comment.