From fd4137871930128f2936e6292390cbf1b4d8aed1 Mon Sep 17 00:00:00 2001 From: Ed J Date: Fri, 27 Sep 2024 23:48:02 +0000 Subject: [PATCH] Primitive::matmult handle bad values, treating like NaN --- Basic/Primitive/primitive.pd | 10 +++++++--- Changes | 1 + Libtmp/Transform/transform.pd | 4 ++-- t/primitive-matmult.t | 17 +++++++++++++++++ 4 files changed, 27 insertions(+), 5 deletions(-) diff --git a/Basic/Primitive/primitive.pd b/Basic/Primitive/primitive.pd index 81d983d85..c7a171409 100644 --- a/Basic/Primitive/primitive.pd +++ b/Basic/Primitive/primitive.pd @@ -197,14 +197,14 @@ L method. EOD pp_def('matmult', - HandleBad=>0, + HandleBad=>1, Pars => 'a(t,h); b(w,t); [o]c(w,h);', GenericTypes => [ppdefs_all], PMCode => pp_line_numbers(__LINE__, <<'EOPM'), sub PDL::matmult { my ($x,$y,$c) = @_; $y = PDL->topdl($y); - $c = PDL->null unless do { local $@; eval { $c->isa('PDL') } }; + $c = PDL->null if !UNIVERSAL::isa($c, 'PDL'); while($x->getndims < 2) {$x = $x->dummy(-1)} while($y->getndims < 2) {$y = $y->dummy(-1)} return ($c .= $x * $y) if( ($x->dim(0)==1 && $x->dim(1)==1) || @@ -231,17 +231,21 @@ loop (h=::tsiz,w=::tsiz) %{ loop (t=::tsiz,h=h_outer:h_outer+tsiz,w=w_outer:w_outer+tsiz) %{ // Cache the accumulated value for the output $GENERIC() cc = $c(); + PDL_IF_BAD(if ($ISBADVAR(cc,c)) continue;,) // Cache data pointers before 't' run through tile $GENERIC() *ad = &($a()); $GENERIC() *bd = &($b()); // Hotspot - run the 't' summation PDL_Indx t_outer = t; + PDL_IF_BAD(char c_isbad = 0;,) loop (t=t_outer:t_outer+tsiz) %{ + PDL_IF_BAD(if ($ISBADVAR(*ad,a) || $ISBADVAR(*bd,b)) { c_isbad = 1; break; },) cc += *ad * *bd; ad += atdi; bd += btdi; %} // put the output back to be further accumulated later + PDL_IF_BAD(if (c_isbad) { $SETBAD(c()); continue; },) $c() = cc; %} %} @@ -264,7 +268,7 @@ footprint within cache as long as possible on most modern CPUs. For usage, see L, a description of the overloaded 'x' operator EOD - ); +); pp_def('innerwt', HandleBad => 1, diff --git a/Changes b/Changes index 3e4699fa6..dff8e91ed 100644 --- a/Changes +++ b/Changes @@ -10,6 +10,7 @@ - IO::GD add OO to_rpic for ndarrays (3,x,y) if truecolour, y=0 at bottom, like rpic - IO::GD stop read_true_png segfaulting with non-true PNG - IO::Pic use IO::GD for JPEG if available, helps Windows with no NetPBM +- Primitive::matmult handle bad values, treating like NaN 2.092 2024-09-07 - add Type::howbig diff --git a/Libtmp/Transform/transform.pd b/Libtmp/Transform/transform.pd index 697420edc..29a816b20 100644 --- a/Libtmp/Transform/transform.pd +++ b/Libtmp/Transform/transform.pd @@ -2092,8 +2092,8 @@ sub compose { $data; }; $me->{inv} = sub { - my($data,$p) = @_; - my($ip) = $data->is_inplace; + my ($data,$p) = @_; + my $ip = $data->is_inplace; for my $t ( @{$p->{clist}} ) { croak("Error: tried to invert a non-invertible PDL::Transform inside a composition!\n offending transform: $t\n") unless(defined($t->{inv}) and ref($t->{inv}) eq 'CODE'); diff --git a/t/primitive-matmult.t b/t/primitive-matmult.t index c5f0f1fb3..528d43e9f 100644 --- a/t/primitive-matmult.t +++ b/t/primitive-matmult.t @@ -77,4 +77,21 @@ ok tapprox( PB() x 2, PB() * 2, 'ndarray x Perl scalar' ); ok tapprox( pdl(3) x PB(), PB() *3 ), '1D ndarray x ndarray'; +subtest 'nans' => sub { + my $A = pdl '[1 nan 0; 0 1 0; 0 0 1]'; + my $B = PDL->sequence(2,3); + my $C = $A x $B; + $C->inplace->setnantobad; + $C->inplace->setbadtoval(6); + ok tapprox($C, pdl '[6 6; 2 3; 4 5]'); +}; + +subtest 'badvals' => sub { + my $A = pdl '[1 BAD 0; 0 1 0; 0 0 1]'; + my $B = PDL->sequence(2,3); + my $C = $A x $B; + $C->inplace->setbadtoval(6); + ok tapprox($C, pdl '[6 6; 2 3; 4 5]'); +}; + done_testing;