Skip to content

Commit

Permalink
Make non_linear interpolation independent of a chosen value type
Browse files Browse the repository at this point in the history
  • Loading branch information
wdeconinck committed Mar 4, 2024
1 parent 4e2a055 commit 92188d8
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 52 deletions.
63 changes: 14 additions & 49 deletions src/atlas/interpolation/nonlinear/Missing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,71 +18,36 @@ namespace interpolation {
namespace nonlinear {


#define B NonLinearFactoryBuilder
#define M1 MissingIfAllMissing
#define M2 MissingIfAnyMissing
#define M3 MissingIfHeaviestMissing
#define T array::DataType::str

static NonLinearFactoryBuilder<MissingIfAllMissing> __nl1(MissingIfAllMissing::static_type());
static NonLinearFactoryBuilder<MissingIfAnyMissing> __nl2(MissingIfAnyMissing::static_type());
static NonLinearFactoryBuilder<MissingIfHeaviestMissing> __nl3(MissingIfHeaviestMissing::static_type());

static B<M1<double>> __nl1(M1<double>::static_type());
static B<M1<double>> __nl2(M1<double>::static_type() + "-" + T<double>());
static B<M1<float>> __nl3(M1<float>::static_type() + "-" + T<float>());
static B<M1<int>> __nl4(M1<int>::static_type() + "-" + T<int>());
static B<M1<long>> __nl5(M1<long>::static_type() + "-" + T<long>());
static B<M1<unsigned long>> __nl6(M1<unsigned long>::static_type() + "-" + T<unsigned long>());

static B<M2<double>> __nl7(M2<double>::static_type());
static B<M2<double>> __nl8(M2<double>::static_type() + "-" + T<double>());
static B<M2<float>> __nl9(M2<float>::static_type() + "-" + T<float>());
static B<M2<int>> __nl10(M2<int>::static_type() + "-" + T<int>());
static B<M2<long>> __nl11(M2<long>::static_type() + "-" + T<long>());
static B<M2<unsigned long>> __nl12(M2<unsigned long>::static_type() + "-" + T<unsigned long>());

static B<M3<double>> __nl13(M3<double>::static_type());
static B<M3<double>> __nl14(M3<double>::static_type() + "-" + T<double>());
static B<M3<float>> __nl15(M3<float>::static_type() + "-" + T<float>());
static B<M3<int>> __nl16(M3<int>::static_type() + "-" + T<int>());
static B<M3<long>> __nl17(M3<long>::static_type() + "-" + T<long>());
static B<M3<unsigned long>> __nl18(M3<unsigned long>::static_type() + "-" + T<unsigned long>());
// Deprecated factory entries with "-real32" and "-real64" suffix for backwards compatibility.
static NonLinearFactoryBuilder<MissingIfAllMissing> __nl1_real32(MissingIfAllMissing::static_type()+"-real32");
static NonLinearFactoryBuilder<MissingIfAnyMissing> __nl2_real32(MissingIfAnyMissing::static_type()+"-real32");
static NonLinearFactoryBuilder<MissingIfHeaviestMissing> __nl3_real32(MissingIfHeaviestMissing::static_type()+"-real32");
static NonLinearFactoryBuilder<MissingIfAllMissing> __nl1_real64(MissingIfAllMissing::static_type()+"-real64");
static NonLinearFactoryBuilder<MissingIfAnyMissing> __nl2_real64(MissingIfAnyMissing::static_type()+"-real64");
static NonLinearFactoryBuilder<MissingIfHeaviestMissing> __nl3_real64(MissingIfHeaviestMissing::static_type()+"-real64");

namespace {
struct force_link {
template <typename M>
void load_builder() {
B<M>("tmp");
NonLinearFactoryBuilder<M>("tmp");
}
force_link() {
load_builder<M1<double>>();
load_builder<M1<float>>();
load_builder<M1<int>>();
load_builder<M1<long>>();
load_builder<M1<unsigned long>>();

load_builder<M2<double>>();
load_builder<M2<float>>();
load_builder<M2<int>>();
load_builder<M2<long>>();
load_builder<M2<unsigned long>>();

load_builder<M3<double>>();
load_builder<M3<float>>();
load_builder<M3<int>>();
load_builder<M3<long>>();
load_builder<M3<unsigned long>>();
load_builder<MissingIfAllMissing>();
load_builder<MissingIfAnyMissing>();
load_builder<MissingIfHeaviestMissing>();
}
};
} // namespace
void force_link_missing() {
static force_link static_linking;
}

#undef T
#undef M3
#undef M2
#undef M1
#undef B


} // namespace nonlinear
} // namespace interpolation
Expand Down
39 changes: 36 additions & 3 deletions src/atlas/interpolation/nonlinear/Missing.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "atlas/field/MissingValue.h"
#include "atlas/interpolation/nonlinear/NonLinear.h"
#include "atlas/util/DataType.h"


namespace atlas {
Expand All @@ -29,9 +30,19 @@ struct Missing : NonLinear {
};


template <typename T>
struct MissingIfAllMissing : Missing {
bool execute(NonLinear::Matrix& W, const Field& field) const {
switch(field.datatype().kind()) {
case (DataType::kind<double>()): return executeT<double>(W,field);
case (DataType::kind<float>()): return executeT<float>(W,field);
case (DataType::kind<int>()): return executeT<int>(W,field);
case (DataType::kind<long>()): return executeT<long>(W,field);
case (DataType::kind<unsigned long>()): return executeT<unsigned long>(W,field);
default: ATLAS_NOTIMPLEMENTED;
}
}
template<typename T>
bool executeT(NonLinear::Matrix& W, const Field& field) const {
field::MissingValue mv(field);
auto& missingValue = mv.ref();

Expand Down Expand Up @@ -104,9 +115,20 @@ struct MissingIfAllMissing : Missing {
};


template <typename T>
struct MissingIfAnyMissing : Missing {
bool execute(NonLinear::Matrix& W, const Field& field) const {
switch(field.datatype().kind()) {
case (DataType::kind<double>()): return executeT<double>(W,field);
case (DataType::kind<float>()): return executeT<float>(W,field);
case (DataType::kind<int>()): return executeT<int>(W,field);
case (DataType::kind<long>()): return executeT<long>(W,field);
case (DataType::kind<unsigned long>()): return executeT<unsigned long>(W,field);
default: ATLAS_NOTIMPLEMENTED;
}
}

template<typename T>
bool executeT(NonLinear::Matrix& W, const Field& field) const {
field::MissingValue mv(field);
auto& missingValue = mv.ref();

Expand Down Expand Up @@ -165,9 +187,20 @@ struct MissingIfAnyMissing : Missing {
};


template <typename T>
struct MissingIfHeaviestMissing : Missing {
bool execute(NonLinear::Matrix& W, const Field& field) const {
switch(field.datatype().kind()) {
case (DataType::kind<double>()): return executeT<double>(W,field);
case (DataType::kind<float>()): return executeT<float>(W,field);
case (DataType::kind<int>()): return executeT<int>(W,field);
case (DataType::kind<long>()): return executeT<long>(W,field);
case (DataType::kind<unsigned long>()): return executeT<unsigned long>(W,field);
default: ATLAS_NOTIMPLEMENTED;
}
}

template<typename T>
bool executeT(NonLinear::Matrix& W, const Field& field) const {
field::MissingValue mv(field);
auto& missingValue = mv.ref();

Expand Down
75 changes: 75 additions & 0 deletions src/tests/interpolation/test_interpolation_non_linear.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,81 @@ CASE("Interpolation of rank 2 field with MissingValue") {
}


CASE("Interpolation with MissingValue on fieldset with heterogeneous type") {

auto init_field = [](Field& field){
field.metadata().set("missing_value", missingValue);
field.metadata().set("missing_value_epsilon", missingValueEps);
field.metadata().set("missing_value_type", "equals");

auto init_view = [](auto&& view) {
for (idx_t j = 0; j < view.shape(0); ++j) {
view(j) = 1;
}
view(4) = missingValue;
};
if (field.datatype().kind() == DataType::KIND_REAL32) {
init_view(array::make_view<float, 1>(field));
}
if (field.datatype().kind() == DataType::KIND_REAL64) {
init_view(array::make_view<double, 1>(field));
}
return false;
};

/*
Set input field full of 1's, with 9 nodes
1 ... 1 ... 1
: : :
1-----m ... 1 m: missing value
|i i| : i: interpolation on two points, this quadrilateral only
1-----1 ... 1
*/
RectangularDomain domain({0, 2}, {0, 2}, "degrees");
Grid gridA("L90", domain);

const idx_t nbNodes = 9;
ATLAS_ASSERT(gridA.size() == nbNodes);

Mesh meshA = MeshGenerator("structured").generate(gridA);

functionspace::NodeColumns fsA(meshA);
FieldSet fieldsA;
fieldsA.add(fsA.createField<double>(option::name("A_r64")));
fieldsA.add(fsA.createField<float>(option::name("A_r32")));

init_field(fieldsA["A_r64"]);
init_field(fieldsA["A_r32"]);


// Set output field (2 points)
functionspace::PointCloud fsB({PointLonLat{0.1, 0.1}, PointLonLat{0.9, 0.9}});
FieldSet fieldsB;
fieldsB.add(fsB.createField<double>(option::name("B_r64")));
fieldsB.add(fsB.createField<float>(option::name("B_r32")));

auto interpolate = [&](const std::string& missing_type) {
Config config;
config.set("type", "finite-element");
config.set("non_linear", missing_type);
Interpolation interpolation(config, fsA, fsB);
interpolation.execute(fieldsA, fieldsB);
};

SECTION( "missing-if-any-missing" ) {
interpolate("missing-if-any-missing");
}

SECTION( "missing-if-all-missing" ) {
interpolate("missing-if-all-missing");
}

SECTION( "missing-if-heaviest-missing" ) {
interpolate("missing-if-heaviest-missing");
}

}

} // namespace test
} // namespace atlas

Expand Down

0 comments on commit 92188d8

Please sign in to comment.