diff --git a/framework/e2e/PaddleLT_new/strategy/compare.py b/framework/e2e/PaddleLT_new/strategy/compare.py index 039c0547cf..9bec4389dd 100644 --- a/framework/e2e/PaddleLT_new/strategy/compare.py +++ b/framework/e2e/PaddleLT_new/strategy/compare.py @@ -40,7 +40,13 @@ def base_compare(result, expect, res_name, exp_name, logger, delta=1e-10, rtol=1 if isinstance(expect, str): raise Exception("expect is exception !!!") - if isinstance(expect, eval(f"{framework}.Tensor")) or isinstance(expect, np.ndarray): + if expect is None or result is None: + if expect is None: + Logger("PLT_compare").get_log().info(f"{exp_name} 结果为None, 所以跳过 {exp_name} 和 {res_name} 精度对比") + if result is None: + Logger("PLT_compare").get_log().info(f"{res_name} 结果为None, 所以跳过 {exp_name} 和 {res_name} 精度对比") + pass + elif isinstance(expect, eval(f"{framework}.Tensor")) or isinstance(expect, np.ndarray): if isinstance(result, eval(f"{framework}.Tensor")): if framework == "torch": result = result.detach().numpy() @@ -136,12 +142,6 @@ def base_compare(result, expect, res_name, exp_name, logger, delta=1e-10, rtol=1 ) elif isinstance(expect, (bool, int, float)): assert expect == result - elif expect is None or result is None: - if expect is None: - Logger("PLT_compare").get_log().info(f"{exp_name} 结果为None, 所以跳过 {exp_name} 和 {res_name} 精度对比") - if result is None: - Logger("PLT_compare").get_log().info(f"{res_name} 结果为None, 所以跳过 {exp_name} 和 {res_name} 精度对比") - pass else: raise Exception("expect is unknown data struction in compare_tool!!!")