Skip to content

Commit

Permalink
Fix testing tools NaN comparison and correct new failing tests. (#4794
Browse files Browse the repository at this point in the history
)

* Initial commit of emstd bug fix and test.

* Initial commit of NaN test tooling bug-fix.

* Correct post-merge test failure.

* Some PR comments addressed

* Numeric class modified to better handle NaN and Inf values, new tests to verify standard behavior.

* Expanded the NaN and Inf test code to document our expectations.

* Added new tests to cover some short circuit behavior.

* Update engine/function/src/templates/Numeric.ftl

Co-authored-by: Chip Kent <5250374+chipkent@users.noreply.github.com>

* Update engine/function/src/templates/Numeric.ftl

Co-authored-by: Chip Kent <5250374+chipkent@users.noreply.github.com>

---------

Co-authored-by: Chip Kent <5250374+chipkent@users.noreply.github.com>
  • Loading branch information
lbooker42 and chipkent authored Nov 16, 2023
1 parent 8123f7c commit cd416bd
Show file tree
Hide file tree
Showing 24 changed files with 495 additions and 144 deletions.
169 changes: 86 additions & 83 deletions engine/function/src/templates/Numeric.ftl
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,9 @@ public class Numeric {
try ( final ${pt.vectorIterator} vi = values.iterator() ) {
while ( vi.hasNext() ) {
final ${pt.primitive} c = vi.${pt.iteratorNext}();
if (isNaN(c)) {
return Double.NaN;
}
if (!isNull(c)) {
sum += c;
count++;
Expand Down Expand Up @@ -416,6 +419,12 @@ public class Numeric {
try ( final ${pt.vectorIterator} vi = values.iterator() ) {
while ( vi.hasNext() ) {
final ${pt.primitive} c = vi.${pt.iteratorNext}();
if (isNaN(c)) {
return Double.NaN;
}
if (isInf(c)) {
return Double.POSITIVE_INFINITY;
}
if (!isNull(c)) {
sum += Math.abs(c);
count++;
Expand Down Expand Up @@ -472,11 +481,13 @@ public class Numeric {

double sum = 0;
double sum2 = 0;
double count = 0;

long count = 0;
try ( final ${pt.vectorIterator} vi = values.iterator() ) {
while ( vi.hasNext() ) {
final ${pt.primitive} c = vi.${pt.iteratorNext}();
if (isNaN(c) || isInf(c)) {
return Double.NaN;
}
if (!isNull(c)) {
sum += (double)c;
sum2 += (double)c * (double)c;
Expand All @@ -485,19 +496,19 @@ public class Numeric {
}
}

// Return NaN if poisoned or too few values to compute sample variance.
if (count <= 1 || Double.isNaN(sum) || Double.isNaN(sum2)) {
// Return NaN if overflow or too few values to compute variance.
if (count <= 1 || Double.isInfinite(sum) || Double.isInfinite(sum2)) {
return Double.NaN;
}

// Perform the calculation in a way that minimizes the impact of floating point error.
final double eps = Math.ulp(sum2);
final double vs2bar = sum * (sum / count);
final double vs2bar = sum * (sum / (double)count);
final double delta = sum2 - vs2bar;
final double rel_eps = delta / eps;

// Return zero when the sample variance is leq the floating point error.
return Math.abs(rel_eps) > 1.0 ? delta / (count - 1) : 0.0;
return Math.abs(rel_eps) > 1.0 ? delta / ((double)count - 1) : 0.0;
}

<#list primitiveTypes as pt2>
Expand Down Expand Up @@ -590,7 +601,12 @@ public class Numeric {
while (vi.hasNext()) {
final ${pt.primitive} c = vi.${pt.iteratorNext}();
final ${pt2.primitive} w = wi.${pt2.iteratorNext}();

if (isNaN(c) || isInf(c)) {
return Double.NaN;
}
if (isNaN(w) || isInf(w)) {
return Double.NaN;
}
if (!isNull(c) && !isNull(w)) {
sum += w * c;
sum2 += w * c * c;
Expand All @@ -600,8 +616,8 @@ public class Numeric {
}
}

// Return NaN if poisoned or too few values to compute sample variance.
if (count <= 1 || Double.isNaN(sum) || Double.isNaN(sum2) || Double.isNaN(count) || Double.isNaN(count2)) {
// Return NaN if overflow or too few values to compute variance.
if (count <= 1 || Double.isInfinite(sum) || Double.isInfinite(sum2)) {
return Double.NaN;
}

Expand Down Expand Up @@ -1333,6 +1349,12 @@ public class Numeric {
while (v0i.hasNext()) {
final ${pt.primitive} v0 = v0i.${pt.iteratorNext}();
final ${pt2.primitive} v1 = v1i.${pt2.iteratorNext}();
if (isNaN(v0) || isInf(v0)) {
return Double.NaN;
}
if (isNaN(v1) || isInf(v1)) {
return Double.NaN;
}

if (!isNull(v0) && !isNull(v1)) {
sum0 += v0;
Expand Down Expand Up @@ -1421,6 +1443,12 @@ public class Numeric {
while (v0i.hasNext()) {
final ${pt.primitive} v0 = v0i.${pt.iteratorNext}();
final ${pt2.primitive} v1 = v1i.${pt2.iteratorNext}();
if (isNaN(v0) || isInf(v0)) {
return Double.NaN;
}
if (isNaN(v1) || isInf(v1)) {
return Double.NaN;
}

if (!isNull(v0) && !isNull(v1)) {
sum0 += v0;
Expand Down Expand Up @@ -1460,6 +1488,11 @@ public class Numeric {
try ( final ${pt.vectorIterator} vi = values.iterator() ) {
while ( vi.hasNext() ) {
final ${pt.primitive} c = vi.${pt.iteratorNext}();
<#if pt.valueType.isFloat >
if (isNaN(c)) {
return ${pt.boxed}.NaN;
}
</#if>
if (!isNull(c)) {
sum += c;
}
Expand Down Expand Up @@ -1496,10 +1529,33 @@ public class Numeric {

${pt.primitive} prod = 1;
int count = 0;
<#if pt.valueType.isFloat >
long zeroCount = 0;
long infCount = 0;
</#if>

try ( final ${pt.vectorIterator} vi = values.iterator() ) {
while ( vi.hasNext() ) {
final ${pt.primitive} c = vi.${pt.iteratorNext}();
<#if pt.valueType.isFloat >
if (isNaN(c)) {
return ${pt.boxed}.NaN;
} else if (Double.isInfinite(c)) {
if (zeroCount > 0) {
return ${pt.boxed}.NaN;
}
infCount++;
} else if (c == 0) {
if (infCount > 0) {
return ${pt.boxed}.NaN;
}
zeroCount++;
}
<#else>
if (c == 0) {
return 0;
}
</#if>
if (!isNull(c)) {
count++;
prod *= c;
Expand All @@ -1511,7 +1567,11 @@ public class Numeric {
return ${pt.null};
}

<#if pt.valueType.isFloat >
return zeroCount > 0 ? 0 : (${pt.primitive}) (prod);
<#else>
return (${pt.primitive}) (prod);
</#if>
}

/**
Expand Down Expand Up @@ -1549,24 +1609,7 @@ public class Numeric {
return null;
}

if (values.length == 0) {
return new ${pt.primitive}[0];
}

${pt.primitive}[] result = new ${pt.primitive}[values.length];
result[0] = values[0];

for (int i = 1; i < values.length; i++) {
if (isNull(result[i - 1])) {
result[i] = values[i];
} else if (isNull(values[i])) {
result[i] = result[i - 1];
} else {
result[i] = (${pt.primitive})Math.min(result[i - 1], values[i]);
}
}

return result;
return cummin(new ${pt.vectorDirect}(values));
}

/**
Expand Down Expand Up @@ -1630,24 +1673,7 @@ public class Numeric {
return null;
}

if (values.length == 0) {
return new ${pt.primitive}[0];
}

${pt.primitive}[] result = new ${pt.primitive}[values.length];
result[0] = values[0];

for (int i = 1; i < values.length; i++) {
if (isNull(result[i - 1])) {
result[i] = values[i];
} else if (isNull(values[i])) {
result[i] = result[i - 1];
} else {
result[i] = (${pt.primitive})Math.max(result[i - 1], values[i]);
}
}

return result;
return cummax(new ${pt.vectorDirect}(values));
}

/**
Expand Down Expand Up @@ -1711,24 +1737,7 @@ public class Numeric {
return null;
}

if (values.length == 0) {
return new ${pt.primitive}[0];
}

${pt.primitive}[] result = new ${pt.primitive}[values.length];
result[0] = values[0];

for (int i = 1; i < values.length; i++) {
if (isNull(result[i - 1])) {
result[i] = values[i];
} else if (isNull(values[i])) {
result[i] = result[i - 1];
} else {
result[i] = (${pt.primitive}) (result[i - 1] + values[i]);
}
}

return result;
return cumsum(new ${pt.vectorDirect}(values));
}

/**
Expand Down Expand Up @@ -1792,24 +1801,7 @@ public class Numeric {
return null;
}

if (values.length == 0) {
return new ${pt.primitive}[0];
}

${pt.primitive}[] result = new ${pt.primitive}[values.length];
result[0] = values[0];

for (int i = 1; i < values.length; i++) {
if (isNull(result[i - 1])) {
result[i] = values[i];
} else if (isNull(values[i])) {
result[i] = result[i - 1];
} else {
result[i] = (${pt.primitive}) (result[i - 1] * values[i]);
}
}

return result;
return cumprod(new ${pt.vectorDirect}(values));
}

/**
Expand Down Expand Up @@ -2322,7 +2314,13 @@ public class Numeric {
while (vi.hasNext()) {
final ${pt.primitive} c = vi.${pt.iteratorNext}();
final ${pt2.primitive} w = wi.${pt2.iteratorNext}();

if (isNaN(c)) {
return Double.NaN;
}
if (isNaN(w)) {
return Double.NaN;
}

if (!isNull(c) && !isNull(w)) {
vsum += c * w;
}
Expand Down Expand Up @@ -2405,7 +2403,12 @@ public class Numeric {
while (vi.hasNext()) {
final ${pt.primitive} c = vi.${pt.iteratorNext}();
final ${pt2.primitive} w = wi.${pt2.iteratorNext}();

if (isNaN(c)) {
return Double.NaN;
}
if (isNaN(w)) {
return Double.NaN;
}
if (!isNull(c) && !isNull(w)) {
vsum += c * w;
wsum += w;
Expand Down
Loading

0 comments on commit cd416bd

Please sign in to comment.