Skip to content

Commit

Permalink
plt update net instance, test=model
Browse files Browse the repository at this point in the history
  • Loading branch information
Zeref996 committed Sep 2, 2024
1 parent b3efb14 commit fb02edd
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 2 deletions.
8 changes: 6 additions & 2 deletions framework/e2e/PaddleLT_new/layertest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import traceback

# from engine.engine_map import engine_map
from strategy.compare import base_compare
from strategy.compare import base_compare, infer_compare
from tools.yaml_loader import YamlLoader
from tools.logger import Logger
from tools.res_save import save_tensor, load_tensor, save_pickle
Expand Down Expand Up @@ -163,7 +163,11 @@ def _case_run(self):
compare_res_list.append(tmp)
else:
precision = comparing.get("precision")
compare_res = base_compare(
if comparing.get("compare_method") == "infer_compare":
compare_methon = infer_compare
else:
compare_methon = base_compare
compare_res = compare_methon(
result=result,
expect=expect,
res_name=latest,
Expand Down
37 changes: 37 additions & 0 deletions framework/e2e/PaddleLT_new/strategy/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,43 @@ def base_compare(result, expect, res_name, exp_name, logger, delta=1e-10, rtol=1
return exc_dict


def infer_compare(result, expect, res_name, exp_name, logger, delta=1e-10, rtol=1e-10, exc_dict={}):
"""
比较函数
:param result: 待测值
:param expect: 基线值
:param delta: 误差值
:param rtol: 相对误差
:return:
"""
# 去除反向结果的数据
forward_handled_result = {"logit": []}
forward_handled_expect = {"logit": []}

# 去除非tensor数值的影响
if isinstance(expect["logit"], (tuple, list)):
for item in expect["logit"]:
if not isinstance(item, (int, bool, float)):
forward_handled_expect["logit"].append(item)

if isinstance(result["logit"], (tuple, list)):
for item in result["logit"]:
if not isinstance(item, (int, bool, float)):
forward_handled_result["logit"].append(item)

exc_dict = base_compare(
result=result,
expect=expect,
res_name=res_name,
exp_name=exp_name,
logger=logger,
delta=1e-10,
rtol=1e-10,
exc_dict=exc_dict,
)
return exc_dict


def perf_compare_legacy(baseline, latest):
"""
比较函数
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,21 @@ compare:
-
baseline: 'dy_train'
latest: 'dy2st_train_static_inputspec'
compare_method: "base_compare"
precision:
delta: 0.00001
rtol: 0.000001
-
baseline: 'dy_train'
latest: 'paddle_infer_new_exc_pir'
compare_method: "infer_compare"
precision:
delta: 0.00001
rtol: 0.000001
-
baseline: 'dy2st_train_static_inputspec'
latest: 'paddle_infer_new_exc_pir'
compare_method: "infer_compare"
precision:
delta: 0.00001
rtol: 0.000001
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,21 @@ compare:
-
baseline: 'dy_train'
latest: 'dy2st_train_cinn'
compare_method: "base_compare"
precision:
delta: 0.00001
rtol: 0.000001
-
baseline: 'dy_train'
latest: 'paddle_infer_new_exc_pir'
compare_method: "infer_compare"
precision:
delta: 0.00001
rtol: 0.000001
-
baseline: 'dy2st_train_cinn'
latest: 'paddle_infer_new_exc_pir'
compare_method: "infer_compare"
precision:
delta: 0.00001
rtol: 0.000001

0 comments on commit fb02edd

Please sign in to comment.