diff --git a/framework/e2e/PaddleLT_new/strategy/compare.py b/framework/e2e/PaddleLT_new/strategy/compare.py index 1a32c36f6d..039c0547cf 100644 --- a/framework/e2e/PaddleLT_new/strategy/compare.py +++ b/framework/e2e/PaddleLT_new/strategy/compare.py @@ -107,7 +107,7 @@ def base_compare(result, expect, res_name, exp_name, logger, delta=1e-10, rtol=1 exc_dict=exc_dict, ) else: - Logger("PLT_compare").get_log().info(f"expect有 {k}, 但是result没有 {k}, 所以跳过 {k} 精度对比") + Logger("PLT_compare").get_log().info(f"{exp_name} 有 {k}, 但是 {res_name} 没有 {k}, 所以跳过 {k} 精度对比") elif isinstance(expect, list) or isinstance(expect, tuple): for i, element in enumerate(expect): if isinstance(result, (np.generic, np.ndarray)) or isinstance(result, eval(f"{framework}.Tensor")): @@ -136,7 +136,11 @@ def base_compare(result, expect, res_name, exp_name, logger, delta=1e-10, rtol=1 ) elif isinstance(expect, (bool, int, float)): assert expect == result - elif expect is None: + elif expect is None or result is None: + if expect is None: + Logger("PLT_compare").get_log().info(f"{exp_name} 结果为None, 所以跳过 {exp_name} 和 {res_name} 精度对比") + if result is None: + Logger("PLT_compare").get_log().info(f"{res_name} 结果为None, 所以跳过 {exp_name} 和 {res_name} 精度对比") pass else: raise Exception("expect is unknown data struction in compare_tool!!!")