Skip to content

Commit

Permalink
Merge pull request trilinos#11388 from iyamazaki/frosch-symb-clean
Browse files Browse the repository at this point in the history
Frosch : symbolic / compute setups
  • Loading branch information
iyamazaki authored Mar 22, 2024
2 parents a3853f8 + c3429bd commit a4f415f
Show file tree
Hide file tree
Showing 15 changed files with 346 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ namespace FROSch {

// count number of nonzeros per row
UN numLocalRows = rowMap->getLocalNumElements();
rowptr_type Rowptr ("Rowptr", numLocalRows+1);
rowptr_type Rowptr (Kokkos::ViewAllocateWithoutInitializing("Rowptr"), numLocalRows+1);
Kokkos::deep_copy(Rowptr, 0);
Kokkos::parallel_for(
"FROSch_CoarseSpace::countGlobalBasisMatrix", policy_row,
Expand Down Expand Up @@ -273,8 +273,8 @@ namespace FROSch {
#endif

// fill into the local matrix
indices_type Indices ("Indices", nnz);
values_type Values ("Values", nnz);
indices_type Indices (Kokkos::ViewAllocateWithoutInitializing("Indices"), nnz);
values_type Values (Kokkos::ViewAllocateWithoutInitializing("Values"), nnz);
auto AssembledBasisLocalMap = AssembledBasisMap_->getLocalMap();
Kokkos::parallel_for(
"FROSch_CoarseSpace::fillGlobalBasisMatrix", policy_row,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ namespace FROSch {
FROSCH_DETAILTIMER_START_LEVELID(updateLocalOverlappingMatricesTime,"AlgebraicOverlappingOperator::updateLocalOverlappingMatrices");
if (this->ExtractLocalSubdomainMatrix_Symbolic_Done_) {
// using original K_ as input
ExtractLocalSubdomainMatrix_Compute(this->K_, this->subdomainMatrix_, this->localSubdomainMatrix_);
ExtractLocalSubdomainMatrix_Compute(this->subdomainScatter_, this->K_, this->subdomainMatrix_, this->localSubdomainMatrix_);
this->OverlappingMatrix_ = this->localSubdomainMatrix_.getConst();
} else {
if (this->IsComputed_) {
Expand Down Expand Up @@ -325,6 +325,9 @@ namespace FROSch {
RCP<Import<LO,GO,NO> > scatter = ImportFactory<LO,GO,NO>::Build(this->OverlappingMatrix_->getRowMap(), this->OverlappingMap_);
this->subdomainMatrix_->doImport(*(this->OverlappingMatrix_), *scatter, ADD);

// Used to map the original K_ to the overlapping subdomainMatrix
this->subdomainScatter_ = ImportFactory<LO,GO,NO>::Build(this->K_->getRowMap(), this->OverlappingMap_);

// build local subdomain matrix
RCP<const Comm<LO> > SerialComm = rcp(new MpiComm<LO>(MPI_COMM_SELF));
RCP<Map<LO,GO,NO> > localSubdomainMap = MapFactory<LO,GO,NO>::Build(this->OverlappingMap_->lib(), this->OverlappingMap_->getLocalNumElements(), 0, SerialComm);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ namespace FROSch {
bool coarseExtractLocalSubdomainMatrix_Symbolic_Done_ = false;
XMatrixPtr coarseSubdomainMatrix_;
XMatrixPtr coarseLocalSubdomainMatrix_;
XImportPtr coarseScatter_;

// Temp Vectors for apply()
mutable XMultiVectorPtr XTmp_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,21 @@ namespace FROSch {
FROSCH_TIMER_START_LEVELID(applyTime,"CoarseOperator::apply");
static int i = 0;
if (!Phi_.is_null() && this->IsComputed_) {
if (XTmp_.is_null()) XTmp_ = MultiVectorFactory<SC,LO,GO,NO>::Build(x.getMap(),x.getNumVectors());
if (XTmp_.is_null() || XTmp_->getNumVectors() != x.getNumVectors()) {
XTmp_ = MultiVectorFactory<SC,LO,GO,NO>::Build(x.getMap(),x.getNumVectors());
}
*XTmp_ = x;
if (!usePreconditionerOnly && mode == NO_TRANS) {
this->K_->apply(x,*XTmp_,mode,ScalarTraits<SC>::one(),ScalarTraits<SC>::zero());
}
if (XCoarseSolve_.is_null()) XCoarseSolve_ = MultiVectorFactory<SC,LO,GO,NO>::Build(GatheringMaps_[GatheringMaps_.size()-1],x.getNumVectors());
else XCoarseSolve_->replaceMap(GatheringMaps_[GatheringMaps_.size()-1]); // The map is replaced in applyCoarseSolve(). If we do not build it from scratch, we should at least replace the map here. This may be important since the maps live on different communicators.
if (YCoarseSolve_.is_null()) YCoarseSolve_ = MultiVectorFactory<SC,LO,GO,NO>::Build(GatheringMaps_[GatheringMaps_.size()-1],y.getNumVectors());
if (XCoarseSolve_.is_null() || XCoarseSolve_->getNumVectors() != x.getNumVectors()) {
XCoarseSolve_ = MultiVectorFactory<SC,LO,GO,NO>::Build(GatheringMaps_[GatheringMaps_.size()-1],x.getNumVectors());
} else {
XCoarseSolve_->replaceMap(GatheringMaps_[GatheringMaps_.size()-1]); // The map is replaced in applyCoarseSolve(). If we do not build it from scratch, we should at least replace the map here. This may be important since the maps live on different communicators.
}
if (YCoarseSolve_.is_null() || YCoarseSolve_->getNumVectors() != y.getNumVectors()) {
YCoarseSolve_ = MultiVectorFactory<SC,LO,GO,NO>::Build(GatheringMaps_[GatheringMaps_.size()-1],y.getNumVectors());
}
applyPhiT(*XTmp_,*XCoarseSolve_);
applyCoarseSolve(*XCoarseSolve_,*YCoarseSolve_,mode);
applyPhi(*YCoarseSolve_,*XTmp_);
Expand Down Expand Up @@ -192,12 +199,18 @@ namespace FROSch {
FROSCH_DETAILTIMER_START_LEVELID(applyCoarseSolveTime,"CoarseOperator::applyCoarseSolve");
if (OnCoarseSolveComm_) {
x.replaceMap(CoarseSolveMap_);
if (YTmp_.is_null()) YTmp_ = MultiVectorFactory<SC,LO,GO,NO>::Build(CoarseSolveMap_,x.getNumVectors());
else YTmp_->replaceMap(CoarseSolveMap_); // The map is replaced later in this function. If we do not build it from scratch, we should at least replace the map here. This may be important since the maps live on different communicators.
if (YTmp_.is_null() || YTmp_->getNumVectors() != x.getNumVectors()) {
YTmp_ = MultiVectorFactory<SC,LO,GO,NO>::Build(CoarseSolveMap_,x.getNumVectors());
} else {
YTmp_->replaceMap(CoarseSolveMap_); // The map is replaced later in this function. If we do not build it from scratch, we should at least replace the map here. This may be important since the maps live on different communicators.
}
CoarseSolver_->apply(x,*YTmp_,mode);
} else {
if (YTmp_.is_null()) YTmp_ = MultiVectorFactory<SC,LO,GO,NO>::Build(CoarseSolveMap_,x.getNumVectors());
else YTmp_->replaceMap(CoarseSolveMap_); // The map is replaced later in this function. If we do not build it from scratch, we should at least replace the map here. This may be important since the maps live on different communicators.
if (YTmp_.is_null() || YTmp_->getNumVectors() != x.getNumVectors()) {
YTmp_ = MultiVectorFactory<SC,LO,GO,NO>::Build(CoarseSolveMap_,x.getNumVectors());
} else {
YTmp_->replaceMap(CoarseSolveMap_); // The map is replaced later in this function. If we do not build it from scratch, we should at least replace the map here. This may be important since the maps live on different communicators.
}
}
YTmp_->replaceMap(GatheringMaps_[GatheringMaps_.size()-1]);
y = *YTmp_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,122 @@ namespace FROSch {
UN j_out;
SCView data_out;
};


// Empty tag types used to select one of the tagged operator() overloads of
// detectLinearDependenciesFunctor below.

// Selects the overload that writes scale(j) = 1/sqrt(|sum_i localMVBasis(i,j)^2|).
struct ScaleTag {};
// Selects the overload that increments Rowptr(iLocal+1) for every entry of a
// row whose magnitude exceeds the dropping threshold.
struct CountNnzTag {};
// Selects the reduction overload that adds Rowptr[i] into the running sum.
struct TotalNnzTag {};
// Selects the overload that writes the kept entries of a row into Indices/Values.
struct FillNzEntriesTag {};
/// Functor bundling the kernels used to assemble a sparse (CRS-style) matrix
/// from the dense basis view `localMVBasis`. Which kernel runs is selected by
/// the tag type passed as first argument of operator():
///   - ScaleTag:         per column j, scale(j) = 1 / sqrt(|sum_i basis(i,j)^2|)
///   - CountNnzTag:      per row i, count entries with |value| > tresholdDropping
///                       into Rowptr(iLocal+1)
///   - TotalNnzTag:      reduction that sums the entries of Rowptr
///   - FillNzEntriesTag: per row i, write the kept entries (column index and
///                       value multiplied by scale(j)) into Indices / Values,
///                       starting at offset Rowptr(iLocal)
///
/// Each constructor only sets the members needed by the corresponding kernel;
/// the remaining views are left default-constructed.
/// NOTE(review): FillNzEntriesTag reads Rowptr(iLocal) as a start offset, so
/// Rowptr is presumably converted from per-row counts to offsets between the
/// count and fill passes — that step is not part of this functor; confirm
/// against the caller.
template<class indicesView, class SCView, class localRowMapType, class localMVBasisType, class RowptrType, class IndicesType, class ValuesType>
struct detectLinearDependenciesFunctor
{
    using Real = typename Teuchos::ScalarTraits<SC>::magnitudeType;
    using STS  = Kokkos::ArithTraits<SC>;
    using RTS  = Kokkos::ArithTraits<Real>;

    UN               numRows;             // number of rows of localMVBasis traversed by ScaleTag
    UN               numCols;             // number of columns of localMVBasis
    SCView           scale;               // per-column scaling factors
    localMVBasisType localMVBasis;        // dense basis values, indexed (row, column)

    SC               tresholdDropping;    // entries with fabs(value) <= this are dropped
    indicesView      indicesGammaDofsAll; // maps kernel index i to a local ID of localRepeatedMap
    localRowMapType  localRowMap;         // global ID -> local row of the assembled matrix
    localRowMapType  localRepeatedMap;    // local ID -> global ID

    RowptrType       Rowptr;              // per-row counts (count pass) / row offsets (fill pass)
    IndicesType      Indices;             // column indices of the kept entries
    ValuesType       Values;              // scaled values of the kept entries

    // Constructor for ScaleTag
    // tresholdDropping is value-initialized so that copying the functor (it is
    // passed by value to the parallel dispatch) never reads an indeterminate scalar.
    detectLinearDependenciesFunctor(UN numRows_, UN numCols_, SCView scale_, localMVBasisType localMVBasis_) :
    numRows          (numRows_),
    numCols          (numCols_),
    scale            (scale_),
    localMVBasis     (localMVBasis_),
    tresholdDropping ()
    {}

    // Constructor for CountNnzTag
    detectLinearDependenciesFunctor(UN numRows_, UN numCols_, localMVBasisType localMVBasis_, SC tresholdDropping_,
                                    indicesView indicesGammaDofsAll_, localRowMapType localRowMap_, localRowMapType localRepeatedMap_,
                                    RowptrType Rowptr_) :
    numRows             (numRows_),
    numCols             (numCols_),
    scale               (),
    localMVBasis        (localMVBasis_),
    tresholdDropping    (tresholdDropping_),
    indicesGammaDofsAll (indicesGammaDofsAll_),
    localRowMap         (localRowMap_),
    localRepeatedMap    (localRepeatedMap_),
    Rowptr              (Rowptr_)
    {}

    // Constructor for FillNzEntriesTag
    detectLinearDependenciesFunctor(UN numRows_, UN numCols_, SCView scale_, localMVBasisType localMVBasis_,
                                    SC tresholdDropping_, indicesView indicesGammaDofsAll_,
                                    localRowMapType localRowMap_, localRowMapType localRepeatedMap_,
                                    RowptrType Rowptr_, IndicesType Indices_, ValuesType Values_) :
    numRows             (numRows_),
    numCols             (numCols_),
    scale               (scale_),
    localMVBasis        (localMVBasis_),
    tresholdDropping    (tresholdDropping_),
    indicesGammaDofsAll (indicesGammaDofsAll_),
    localRowMap         (localRowMap_),
    localRepeatedMap    (localRepeatedMap_),
    Rowptr              (Rowptr_),
    Indices             (Indices_),
    Values              (Values_)
    {}

    // Computes the scaling factor of column j.
    // The sum of squares is accumulated in a local variable (same element type
    // as the view) and written back once, instead of doing a read-modify-write
    // on the view element in every iteration.
    // NOTE(review): a column whose sum of squares is zero divides by zero here.
    KOKKOS_INLINE_FUNCTION
    void operator()(const ScaleTag &, const int j) const {
        typename SCView::non_const_value_type sumOfSquares = STS::zero();
        for (UN i = 0; i < numRows; i++) {
            sumOfSquares += localMVBasis(i,j)*localMVBasis(i,j);
        }
        scale(j) = RTS::one()/RTS::sqrt(STS::abs(sumOfSquares));
    }

    // Counts the entries of row i that survive dropping and adds the count to
    // Rowptr(iLocal+1). The count is kept in a local and added once per row.
    KOKKOS_INLINE_FUNCTION
    void operator()(const CountNnzTag &, const int i) const {
        LO rowID = indicesGammaDofsAll[i];
        GO iGlobal = localRepeatedMap.getGlobalElement(rowID);
        LO iLocal = localRowMap.getLocalElement(iGlobal);
        if (iLocal!=-1) { // This should prevent duplicate entries on the interface
            typename RowptrType::non_const_value_type rowCount = 0;
            for (UN j=0; j<numCols; j++) {
                SC valueTmp=localMVBasis(i,j);
                if (fabs(valueTmp)>tresholdDropping) {
                    rowCount ++;
                }
            }
            Rowptr(iLocal+1) += rowCount;
        }
    }

    // Reduction body: adds Rowptr[i] to the running sum lsum.
    KOKKOS_INLINE_FUNCTION
    void operator()(const TotalNnzTag&, const size_t i, UN &lsum) const {
        lsum += Rowptr[i];
    }

    // Writes the kept entries of row i: the column index is the local column j,
    // the value is multiplied by scale(j). Output starts at offset Rowptr(iLocal).
    // The keep/drop test must match the one used by CountNnzTag.
    KOKKOS_INLINE_FUNCTION
    void operator()(const FillNzEntriesTag &, const int i) const {
        LO rowID = indicesGammaDofsAll[i];
        GO iGlobal = localRepeatedMap.getGlobalElement(rowID);
        LO iLocal = localRowMap.getLocalElement(iGlobal);
        if (iLocal!=-1) { // This should prevent duplicate entries on the interface
            UN nnz_i = Rowptr(iLocal);
            for (UN j=0; j<numCols; j++) {
                SC valueTmp=localMVBasis(i,j);
                if (fabs(valueTmp)>tresholdDropping) {
                    Indices(nnz_i) = j; //localBasisMap.getGlobalElement(j);
                    Values(nnz_i) = valueTmp*scale(j);
                    nnz_i ++;
                }
            }
        }
    }
};
#endif

protected:
Expand Down
Loading

0 comments on commit a4f415f

Please sign in to comment.