Skip to content

Commit

Permalink
Refined compare_asce [ci skip]
Browse files Browse the repository at this point in the history
  • Loading branch information
micheles committed Jul 30, 2024
1 parent 8578170 commit b0e7ea0
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions openquake/commands/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,9 +355,10 @@ def delta(a, b):
return res


def compare_column_values(array0, array1, what, rtol=1E-5):
def compare_column_values(array0, array1, what, atol=0, rtol=1E-5):
if isinstance(array0[0], (float, numpy.float32, numpy.float64)):
diff_idxs = numpy.where(delta(array0, array1) > rtol)[0]
diff = numpy.abs(array0 - array1)
diff_idxs = numpy.where(diff > atol + (array0+array1)/2 * rtol)[0]
else:
diff_idxs = numpy.where(array0 != array1)[0]
if len(diff_idxs) == 0:
Expand Down Expand Up @@ -476,7 +477,7 @@ def read_org_df(fname):
return df.rename(columns=dict(zip(df.columns, strip(df.columns))))


def compare_asce(file1_org: str, file2_org: str):
def compare_asce(file1_org: str, file2_org: str, atol=1E-3, rtol=1E-3):
"""
compare_asce('asce07.org', 'asce07_expected.org') exits with 0
if all values are equal within the tolerance, otherwise with 1.
Expand All @@ -487,7 +488,7 @@ def compare_asce(file1_org: str, file2_org: str):
for col in df1.columns:
ok = compare_column_values(strip(df1[col].to_numpy()),
strip(df2[col].to_numpy()),
col, rtol=1E-2)
col, atol, rtol)
equal.append(ok)
sys.exit(not all(equal))

Expand Down

0 comments on commit b0e7ea0

Please sign in to comment.