Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small: Changes to DataCollection::Add and MeshBlockData::Intialize #1187

Merged
merged 5 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
- [[PR 1161]](https://github.com/parthenon-hpc-lab/parthenon/pull/1161) Make flux field Metadata accessible, add Metadata::CellMemAligned flag, small perfomance upgrades

### Changed (changing behavior/API/variables/...)
- [[PR 1187]](https://github.com/parthenon-hpc-lab/parthenon/pull/1187) Make DataCollection::Add safer and generalize MeshBlockData::Initialize
- [[PR 1186]](https://github.com/parthenon-hpc-lab/parthenon/pull/1186) Bump Kokkos submodule to 4.4.1
- [[PR 1171]](https://github.com/parthenon-hpc-lab/parthenon/pull/1171) Add PARTHENON_USE_SYSTEM_PACKAGES build option
- [[PR 1172]](https://github.com/parthenon-hpc-lab/parthenon/pull/1172) Make parthenon manager robust against external MPI init and finalize calls
Expand Down
2 changes: 1 addition & 1 deletion src/interface/data_collection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class DataCollection {
auto key = GetKey(name, src);
auto it = containers_.find(key);
if (it != containers_.end()) {
if (fields.size() && !(it->second)->Contains(fields)) {
if (fields.size() && !(it->second)->CreatedFrom(fields)) {
PARTHENON_THROW(key + " already exists in collection but fields do not match.");
}
return it->second;
Expand Down
8 changes: 8 additions & 0 deletions src/interface/mesh_data.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,14 @@ class MeshData {
[this, vars](const auto &b) { return b->ContainsExactly(vars); });
}

// Checks that the same set of variables was requested to create this container
// (which may be different than the set of variables in the container because of fluxes)
template <typename Vars_t>
bool CreatedFrom(const Vars_t &vars) const noexcept {
return std::all_of(block_data_.begin(), block_data_.end(),
[this, vars](const auto &b) { return b->CreatedFrom(vars); });
}

std::shared_ptr<SwarmContainer> GetSwarmData(int n) {
PARTHENON_REQUIRE(n >= 0 && n < block_data_.size(),
"MeshData::GetSwarmData requires n within [0, block_data_.size()]");
Expand Down
46 changes: 43 additions & 3 deletions src/interface/meshblock_data.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,17 @@ class MeshBlockData {
resolved_packages = resolved_packages_in;
is_shallow_ = shallow_copy;

// Store the list of variables used to create this container
// so we can compare to it when searching the cache
varUidIn_.clear();
if constexpr (std::is_same_v<ID_t, std::string>) {
for (const auto &var : vars)
varUidIn_.insert(Variable<Real>::GetUniqueID(var));
} else {
for (const auto &var : vars)
varUidIn_.insert(var);
}

// clear all variables, maps, and pack caches
varVector_.clear();
varMap_.clear();
Expand Down Expand Up @@ -185,9 +196,24 @@ class MeshBlockData {
if (!found) add_var(src->GetVarPtr(flx_name));
}
}
} else {
PARTHENON_FAIL(
"Variable subset selection not yet implemented for MeshBlock input.");
} else if constexpr (std::is_same_v<SRC_t, MeshBlock>) {
for (const auto &v : vars) {
const auto &vid = resolved_packages->GetFieldVarID(v);
const auto &md = resolved_packages->GetFieldMetadata(v);
AddField(vid.base_name, md, vid.sparse_id);
// Add the associated flux as well if not explicitly
// asked for
if (md.IsSet(Metadata::WithFluxes)) {
auto flx_vid = resolved_packages->GetFieldVarID(md.GetFluxName());
bool found = false;
for (const auto &v2 : vars)
if (resolved_packages->GetFieldVarID(v2) == flx_vid) found = true;
if (!found) {
const auto &flx_md = resolved_packages->GetFieldMetadata(flx_vid);
AddField(flx_vid.base_name, flx_md, flx_vid.sparse_id);
}
}
}
}
}

Expand Down Expand Up @@ -525,6 +551,18 @@ class MeshBlockData {
return Contains(vars) && (vars.size() == varVector_.size());
}

bool CreatedFrom(const std::vector<Uid_t> &vars) {
return (vars.size() == varUidIn_.size()) &&
std::all_of(vars.begin(), vars.end(),
[this](const auto &v) { return this->varUidIn_.count(v); });
}
bool CreatedFrom(const std::vector<std::string> &vars) {
return (vars.size() == varUidIn_.size()) &&
std::all_of(vars.begin(), vars.end(), [this](const auto &v) {
return this->varUidIn_.count(Variable<Real>::GetUniqueID(v));
});
}

void SetAllVariablesToInitialized() {
std::for_each(varVector_.begin(), varVector_.end(),
[](auto &sp_var) { sp_var->data.initialized = true; });
Expand Down Expand Up @@ -561,6 +599,8 @@ class MeshBlockData {

VariableVector<T> varVector_; ///< the saved variable array
std::map<Uid_t, std::shared_ptr<Variable<T>>> varUidMap_;
std::set<Uid_t> varUidIn_; // Uid list from which this MeshBlockData was created,
// empty implies all variables were included

MapToVars<T> varMap_;
MetadataFlagToVariableMap<T> flagsToVars_;
Expand Down
1 change: 1 addition & 0 deletions src/interface/state_descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ bool StateDescriptor::AddFieldImpl_(const VarID &vid, const Metadata &m_in,
AddFieldImpl_(fId, *(m.GetSPtrFluxMetadata()), control_vid);
m.SetFluxName(fId.label());
}
labelToVidMap_.insert({vid.label(), vid});
metadataMap_.insert({vid, m});
refinementFuncMaps_.Register(m, vid.label());
allocControllerReverseMap_.insert({vid, control_vid});
Expand Down
15 changes: 15 additions & 0 deletions src/interface/state_descriptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,20 @@ class StateDescriptor {
// retrieve all swarm names
std::vector<std::string> Swarms() noexcept;

const auto GetFieldVarID(const VarID &id) const {
PARTHENON_REQUIRE_THROWS(
metadataMap_.count(id),
"Asking for a variable that is not in this StateDescriptor.");
return id;
}

const auto &GetFieldVarID(const std::string &label) const {
return labelToVidMap_.at(label);
}
const auto &GetFieldMetadata(const std::string &label) const {
return metadataMap_.at(labelToVidMap_.at(label));
}
const auto &GetFieldMetadata(const VarID &id) const { return metadataMap_.at(id); }
const auto &AllFields() const noexcept { return metadataMap_; }
const auto &AllSparsePools() const noexcept { return sparsePoolMap_; }
const auto &AllSwarms() const noexcept { return swarmMetadataMap_; }
Expand Down Expand Up @@ -397,6 +411,7 @@ class StateDescriptor {
const std::string label_;

// for each variable label (full label for sparse variables) hold metadata
std::unordered_map<std::string, VarID> labelToVidMap_;
std::unordered_map<VarID, Metadata, VarIDHasher> metadataMap_;
std::unordered_map<VarID, VarID, VarIDHasher> allocControllerReverseMap_;
std::unordered_map<std::string, std::vector<std::string>> allocControllerMap_;
Expand Down
5 changes: 3 additions & 2 deletions src/mesh/mesh-amr_loadbalance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -979,8 +979,9 @@ void Mesh::RedistributeAndRefineMeshBlocks(ParameterInput *pin, ApplicationInput
auto &md_noncc = mesh_data.AddShallow(noncc, md, noncc_names);
}

CommunicateBoundaries(noncc); // Called to make sure shared values are correct,
// ghosts of non-cell centered vars may get some junk
CommunicateBoundaries(
noncc, noncc_names); // Called to make sure shared values are correct,
// ghosts of non-cell centered vars may get some junk
// Now there is the correct data for prolongating on un-shared topological elements
// on the new fine blocks
if (nprolong > 0) {
Expand Down
11 changes: 6 additions & 5 deletions src/mesh/mesh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,8 @@ void Mesh::BuildTagMapAndBoundaryBuffers() {
}
}

void Mesh::CommunicateBoundaries(std::string md_name) {
void Mesh::CommunicateBoundaries(std::string md_name,
const std::vector<std::string> &fields) {
const int num_partitions = DefaultNumPartitions();
const int nmb = GetNumMeshBlocksThisRank(Globals::my_rank);
constexpr std::int64_t max_it = 1e10;
Expand All @@ -656,7 +657,7 @@ void Mesh::CommunicateBoundaries(std::string md_name) {
do {
all_sent = true;
for (int i = 0; i < partitions.size(); ++i) {
auto &md = mesh_data.Add(md_name, partitions[i]);
auto &md = mesh_data.Add(md_name, partitions[i], fields);
if (!sent[i]) {
if (SendBoundaryBuffers(md) != TaskStatus::complete) {
all_sent = false;
Expand All @@ -680,7 +681,7 @@ void Mesh::CommunicateBoundaries(std::string md_name) {
do {
all_received = true;
for (int i = 0; i < partitions.size(); ++i) {
auto &md = mesh_data.Add(md_name, partitions[i]);
auto &md = mesh_data.Add(md_name, partitions[i], fields);
if (!received[i]) {
if (ReceiveBoundaryBuffers(md) != TaskStatus::complete) {
all_received = false;
Expand All @@ -696,14 +697,14 @@ void Mesh::CommunicateBoundaries(std::string md_name) {
"Too many iterations waiting to receive boundary communication buffers.");

for (auto &partition : partitions) {
auto &md = mesh_data.Add(md_name, partition);
auto &md = mesh_data.Add(md_name, partition, fields);
// unpack FillGhost variables
SetBoundaries(md);
}

// Now do prolongation, compute primitives, apply BCs
for (auto &partition : partitions) {
auto &md = mesh_data.Add(md_name, partition);
auto &md = mesh_data.Add(md_name, partition, fields);
if (multilevel) {
ApplyBoundaryConditionsOnCoarseOrFineMD(md, true);
ProlongateBoundaries(md);
Expand Down
3 changes: 2 additions & 1 deletion src/mesh/mesh.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,8 @@ class Mesh {

void SetupMPIComms();
void BuildTagMapAndBoundaryBuffers();
void CommunicateBoundaries(std::string md_name = "base");
void CommunicateBoundaries(std::string md_name = "base",
const std::vector<std::string> &fields = {});
void PreCommFillDerived();
void FillDerived();

Expand Down
Loading