Skip to content

Commit

Permalink
reduce memory requirements in the sandbox tool -atlas-interpolation-eoc-
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrdar committed Aug 2, 2023
1 parent 5902e18 commit 8af09d5
Showing 1 changed file with 36 additions and 42 deletions.
78 changes: 36 additions & 42 deletions src/sandbox/interpolation/atlas-interpolation-eoc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class AtlasEOAComputation : public AtlasTool {
add_option(new SimpleOption<std::string>("type","Type of interpolation: bilinear, cubic, ..."));
add_option(new SimpleOption<std::string>("order","Order of interpolation, when applicable"));
add_option(new SimpleOption<std::string>("interpolation.structured","Is the interpolation for structured grids"));
add_option(new SimpleOption<long>("k-nearest-neighbours", "The number of neighbours in k-nearest-neighbours"));

add_option(new Separator("Initial data"));
add_option(new SimpleOption<bool>("init_via_highres", "Get initial data by remapping a highres grid data"));
Expand Down Expand Up @@ -135,24 +136,24 @@ std::function<double(const PointLonLat&)> get_init(const AtlasTool::Args& args)
else if (init == "solid_body_rotation_wind_magnitude") {
double beta;
args.get("solid_body_rotation.angle", beta = 0.);
util::function::SolidBodyRotation sbr(beta);
return [sbr](const PointLonLat& p) { return sbr.windMagnitude(p.lon(), p.lat()); };
util::function::SolidBodyRotation func(beta);
return [func](const PointLonLat& p) { return func.windMagnitude(p.lon(), p.lat()); };
}
else if (init == "MDPI_sinusoid") {
auto sbr = util::function::MDPI_sinusoid;
return [sbr](const PointLonLat& p) { return sbr(p.lon(), p.lat()); };
auto func = util::function::MDPI_sinusoid;
return [func](const PointLonLat& p) { return func(p.lon(), p.lat()); };
}
else if (init == "MDPI_harmonic") {
auto sbr = util::function::MDPI_harmonic;
return [sbr](const PointLonLat& p) { return sbr(p.lon(), p.lat()); };
auto func = util::function::MDPI_harmonic;
return [func](const PointLonLat& p) { return func(p.lon(), p.lat()); };
}
else if (init == "MDPI_vortex") {
auto sbr = util::function::MDPI_vortex;
return [sbr](const PointLonLat& p) { return sbr(p.lon(), p.lat()); };
auto func = util::function::MDPI_vortex;
return [func](const PointLonLat& p) { return func(p.lon(), p.lat()); };
}
else if (init == "MDPI_gulfstream") {
auto sbr = util::function::MDPI_gulfstream;
return [sbr](const PointLonLat& p) { return sbr(p.lon(), p.lat()); };
auto func = util::function::MDPI_gulfstream;
return [func](const PointLonLat& p) { return func(p.lon(), p.lat()); };
}
else {
if (args.has("init")) {
Expand Down Expand Up @@ -191,7 +192,7 @@ FunctionSpace create_functionspace(Mesh& mesh, int halo, std::string type, bool

void compute_errors(const Field source, const Field target,
std::function<double(const PointLonLat&)> func,
Mesh src_mesh, Mesh tgt_mesh, util::Config& stats) {
Mesh src_mesh, Mesh tgt_mesh, int eoc_cycles, util::Config& stats) {
auto src_vals = array::make_view<double, 1>(source);
auto tgt_vals = array::make_view<double, 1>(target);
const auto src_node_ghost = array::make_view<int, 1>(src_mesh.nodes().ghost());
Expand All @@ -201,8 +202,12 @@ void compute_errors(const Field source, const Field target,

// get tgt_points and tgt_areas
// TODO: no need for polygon intersections here, we just need src_points and src_areas
int src_halo = 2;
while(eoc_cycles-- > 1) {
src_halo *= 2;
}
auto src_fs = create_functionspace(src_mesh, 2, "NodeColumns", 0);
auto tgt_fs = create_functionspace(tgt_mesh, 2, "NodeColumns", 0);
auto tgt_fs = create_functionspace(tgt_mesh, 1, "NodeColumns", 0);
auto interpolation = CSPInterpolation();
interpolation.do_setup(src_fs, tgt_fs);

Expand Down Expand Up @@ -281,18 +286,6 @@ void compute_errors(const Field source, const Field target,
double err_cons2 = tgt_mass - src_mass;
double relcons_src = 100. * err_cons/src_mass;
double relcons_tgt = 100. * err_cons/tgt_mass;
/*
std::cout << "src l2 error : " << serr_remap_l2 << std::endl;
std::cout << "src l_inf error : " << serr_remap_linf << std::endl;
std::cout << "tgt l2 error : " << terr_remap_l2 << std::endl;
std::cout << "tgt l_inf error : " << terr_remap_linf << std::endl;
std::cout << "conservation error : " << err_cons2 << std::endl;
std::cout << "total src mass : " << src_mass << std::endl;
std::cout << "total tgt mass : " << tgt_mass << std::endl;
std::cout << "rel. cons. error on src : " << 100. * err_cons/src_mass << " %" << std::endl;
std::cout << "rel. cons. error on tgt : " << 100. * err_cons/tgt_mass << " %" << std::endl;
std::cout << "==================" << std::endl;
*/
stats.set("err.ana2src_l2", serr_remap_l2);
stats.set("err.ana2src_linf", serr_remap_linf);
stats.set("err.ana2tgt_l2", terr_remap_l2);
Expand Down Expand Up @@ -328,17 +321,19 @@ int AtlasEOAComputation::execute(const AtlasTool::Args& args) {
src_grid = StructuredGrid(sstream.str());
}

Grid highres_grid;
std::string highres_gridname = "";
Mesh highres_mesh;
FunctionSpace highres_src_fs;
Field highres_src_field;
if (args.getBool("init_via_highres", false)) {
Grid highres_grid;
// project solution to the finest mesh for intialising the source fields
sstream.str("");
sstream << sgrid_type << 2 * eoc_maxres;
highres_grid = Grid(sstream.str());
highres_gridname = highres_grid.name();
highres_mesh = Mesh(highres_grid);
highres_src_fs = create_functionspace(highres_mesh, 2, "NodeColumns", 0);
highres_src_fs = create_functionspace(highres_mesh, 2, "CellColumns", 0);
const auto lonlat = array::make_view<double, 2>(highres_src_fs.lonlat());
highres_src_field = highres_src_fs.createField<double>();
auto highres_src_view = array::make_view<double, 1>(highres_src_field);
Expand All @@ -347,6 +342,7 @@ int AtlasEOAComputation::execute(const AtlasTool::Args& args) {
highres_src_view(n) = f(PointLonLat{lonlat(n, LON), lonlat(n, LAT)});
}
highres_src_field.set_dirty(true);
highres_src_field.haloExchange();
}

double err[eoc_cycles];
Expand All @@ -363,30 +359,22 @@ int AtlasEOAComputation::execute(const AtlasTool::Args& args) {
}
std::cout << src_grid.name() << " --> " << tgt_grid.name() << std::endl;

timers.target_setup.start();
auto tgt_mesh = Mesh{tgt_grid, grid::Partitioner(args.getString("target.partitioner", "serial"))};
auto tgt_fs =
create_functionspace(tgt_mesh, 2, args.getString("target.functionspace", ""), args.getBool("interpolation.structured", false));
auto tgt_field = tgt_fs.createField<double>();
timers.target_setup.stop();

timers.source_setup.start();
auto src_meshgenerator =
MeshGenerator{src_grid.meshgenerator() | option::halo(2) | util::Config("pole_elements", "")};
auto src_partitioner = grid::MatchingPartitioner{tgt_mesh, util::Config("partitioner",args.getString("source.partitioner", "spherical-polygon"))};
auto src_mesh = src_meshgenerator.generate(src_grid, src_partitioner);
auto src_mesh = src_meshgenerator.generate(src_grid);
auto src_fs =
create_functionspace(src_mesh, 2, args.getString("source.functionspace", ""), args.getBool("interpolation.structured", false));
create_functionspace(src_mesh, 4, args.getString("source.functionspace", ""), args.getBool("interpolation.structured", false));
auto src_field = src_fs.createField<double>();
timers.source_setup.stop();

timers.initial_condition.start();
if (args.getBool("init_via_highres", false) and (refine_source or gres == eoc_startres)) {
std::cout << "Prepare the initial data on " << src_grid.name() << " from " << highres_grid.name() << std::endl;
std::cout << "Prepare the initial data on " << src_grid.name() << " from " << highres_gridname << std::endl;
auto init_interpolation = Interpolation(option::type("conservative-spherical-polygon") | args, highres_src_fs, src_fs);
init_interpolation.execute(highres_src_field, src_field);
util::Config stats;
compute_errors(highres_src_field, src_field, get_init(args), highres_mesh, src_mesh, stats);
compute_errors(highres_src_field, src_field, get_init(args), highres_mesh, src_mesh, eoc_cycles, stats);
double errcons;
stats.get("err.cons", errcons);
std::cout << "highres_src -> src :: cons : " << errcons << std::endl;
Expand All @@ -398,24 +386,30 @@ int AtlasEOAComputation::execute(const AtlasTool::Args& args) {
for (idx_t n = 0; n < lonlat.shape(0); ++n) {
src_view(n) = f(PointLonLat{lonlat(n, LON), lonlat(n, LAT)});
}
src_field.set_dirty(true);
}
src_field.set_dirty(true);
src_field.haloExchange();
timers.initial_condition.stop();

timers.target_setup.start();
auto tgt_mesh = Mesh{tgt_grid, grid::Partitioner(args.getString("target.partitioner", "serial"))};
auto tgt_fs =
create_functionspace(tgt_mesh, 1, args.getString("target.functionspace", ""), args.getBool("interpolation.structured", false));
auto tgt_field = tgt_fs.createField<double>();
timers.target_setup.stop();

timers.interpolation_setup.start();
auto interpolation =
Interpolation(args, src_fs, tgt_fs);
//Log::info() << interpolation << std::endl;
timers.interpolation_setup.stop();


timers.interpolation_execute.start();
auto metadata = interpolation.execute(src_field, tgt_field);
timers.interpolation_execute.stop();

// compute EOCs
util::Config stats;
compute_errors(src_field, tgt_field, get_init(args), src_mesh, tgt_mesh, stats);
compute_errors(src_field, tgt_field, get_init(args), src_mesh, tgt_mesh, eoc_cycles, stats);
stats.get("err.ana2tgt_l2", err[counter]);
std::cout << "l2-error to analytical solution : " << err[counter] << std::endl;
if (counter > 1) {
Expand Down

0 comments on commit 8af09d5

Please sign in to comment.