diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5b373f6473..8ad84a7174 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -482,7 +482,7 @@ jobs: - name: Create & configure TTK build directory shell: cmd run: | - set CMAKE_PREFIX_PATH=%CONDA_ROOT%\Library\lib\cmake;%CONDA_ROOT%\Library\share\eigen3\cmake;%CONDA_ROOT%\Library\share\Qull\cmake;%CONDA_ROOT%\Library\cmake;%ProgramFiles%\TTK-ParaView\lib\cmake + set CMAKE_PREFIX_PATH=%CONDA_ROOT%\Library\lib\cmake;%CONDA_ROOT%\Library\share\eigen3\cmake;%CONDA_ROOT%\Library\share\Qull\cmake;%CONDA_ROOT%\Library\cmake;%ProgramFiles%\TTK-ParaView\lib\cmake; set CC=clang-cl.exe set CXX=clang-cl.exe call "%ProgramFiles%\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvars64.bat" diff --git a/core/base/persistenceDiagramClustering/PersistenceDiagramClustering.h b/core/base/persistenceDiagramClustering/PersistenceDiagramClustering.h index 7788a50301..e6c29f942f 100644 --- a/core/base/persistenceDiagramClustering/PersistenceDiagramClustering.h +++ b/core/base/persistenceDiagramClustering/PersistenceDiagramClustering.h @@ -53,6 +53,46 @@ namespace ttk { return this->distances; } + void setTimeLimit(double timeLimit) { + this->TimeLimit = timeLimit; + } + + void setForceUseOfAlgorithm(bool forceUseOfAlgorithm) { + this->ForceUseOfAlgorithm = forceUseOfAlgorithm; + } + + void setDeltaLim(double DeltaLimNew) { + this->DeltaLim = DeltaLimNew; + } + + void setUseAdditionalPrecision(bool Precision) { + this->UseAdditionalPrecision = Precision; + } + + void setUseProgressive(bool UseProgressive_) { + this->UseProgressive = UseProgressive_; + } + + void setUseInterruptible(bool UseInterruptible_) { + this->UseInterruptible = UseInterruptible_; + } + + void setDeterministic(bool Deterministic_) { + this->Deterministic = Deterministic_; + } + + void setAlpha(double Alpha_) { + this->Alpha = Alpha_; + } + + void setUseAccelerated(bool UseAccelerated_) { + this->UseAccelerated 
= UseAccelerated_; + } + + void setUseKmeansppInit(bool UseKmeansppInit_) { + this->UseKmeansppInit = UseKmeansppInit_; + } + protected: // Critical pairs used for clustering // 0:min-saddles ; 1:saddles-saddles ; 2:sad-max ; else : all diff --git a/core/base/topologicalOptimization/CMakeLists.txt b/core/base/topologicalOptimization/CMakeLists.txt new file mode 100644 index 0000000000..b6a8a573fb --- /dev/null +++ b/core/base/topologicalOptimization/CMakeLists.txt @@ -0,0 +1,17 @@ +ttk_add_base_library(topologicalOptimization + SOURCES + TopologicalOptimization.cpp + HEADERS + TopologicalOptimization.h + DEPENDS + triangulation + persistenceDiagram + persistenceDiagramClustering + ) + +if(TTK_ENABLE_TORCH) + target_include_directories(topologicalOptimization PUBLIC ${TORCH_INCLUDE_DIRS}) + target_compile_options(topologicalOptimization PUBLIC "${TORCH_CXX_FLAGS}") + target_link_libraries(topologicalOptimization PUBLIC "${TORCH_LIBRARIES}") + target_compile_definitions(topologicalOptimization PUBLIC TTK_ENABLE_TORCH) +endif() diff --git a/core/base/topologicalOptimization/TopologicalOptimization.cpp b/core/base/topologicalOptimization/TopologicalOptimization.cpp new file mode 100644 index 0000000000..999b055934 --- /dev/null +++ b/core/base/topologicalOptimization/TopologicalOptimization.cpp @@ -0,0 +1,5 @@ +#include + +ttk::TopologicalOptimization::TopologicalOptimization() { + this->setDebugMsgPrefix("TopologicalOptimization"); +} diff --git a/core/base/topologicalOptimization/TopologicalOptimization.h b/core/base/topologicalOptimization/TopologicalOptimization.h new file mode 100644 index 0000000000..616a94399f --- /dev/null +++ b/core/base/topologicalOptimization/TopologicalOptimization.h @@ -0,0 +1,1672 @@ +/// \ingroup base +/// \class ttk::TopologicalOptimization +/// \author Julien Tierny +/// \author Mohamed Amine Kissi +/// \date March 2024 + +#pragma once + +#ifdef TTK_ENABLE_TORCH +#include +#include +#endif + +// base code includes +#include +#include 
+#include +#include + +namespace ttk { + + class TopologicalOptimization : virtual public Debug { + public: + TopologicalOptimization(); + + template + int execute(const dataType *const inputScalars, + dataType *const outputScalars, + SimplexId *const inputOffsets, + triangulationType *triangulation, + const ttk::DiagramType &constraintDiagram) const; + + inline int preconditionTriangulation(AbstractTriangulation *triangulation) { + if(triangulation) { + vertexNumber_ = triangulation->getNumberOfVertices(); + triangulation->preconditionVertexNeighbors(); + } + return 0; + } + + /* + This function allows us to retrieve the indices of the critical points + that we must modify in order to match our current diagram to our target + diagram. + */ + template + void getIndices( + triangulationType *triangulation, + SimplexId *&inputOffsets, + dataType *const inputScalars, + const ttk::DiagramType &constraintDiagram, + int epoch, + std::vector &listAllIndicesToChange, + std::vector> &pair2MatchedPair, + std::vector> &pair2Delete, + std::vector &pairChangeMatchingPair, + std::vector &birthPairToDeleteCurrentDiagram, + std::vector &birthPairToDeleteTargetDiagram, + std::vector &deathPairToDeleteCurrentDiagram, + std::vector &deathPairToDeleteTargetDiagram, + std::vector &birthPairToChangeCurrentDiagram, + std::vector &birthPairToChangeTargetDiagram, + std::vector &deathPairToChangeCurrentDiagram, + std::vector &deathPairToChangeTargetDiagram, + std::vector> ¤tVertex2PairsCurrentDiagram, + std::vector &vertexInHowManyPairs) const; + +/* + This function allows you to copy the values of a pytorch tensor + to a vector in an optimized way. 
+*/ +#ifdef TTK_ENABLE_TORCH + int tensorToVectorFast(const torch::Tensor &tensor, + std::vector &result) const; +#endif + + inline void setUseFastPersistenceUpdate(bool UseFastPersistenceUpdate) { + useFastPersistenceUpdate_ = UseFastPersistenceUpdate; + } + + inline void setFastAssignmentUpdate(bool FastAssignmentUpdate) { + fastAssignmentUpdate_ = FastAssignmentUpdate; + } + + inline void setEpochNumber(int EpochNumber) { + epochNumber_ = EpochNumber; + } + + inline void setPDCMethod(int PDCMethod) { + pdcMethod_ = PDCMethod; + } + + inline void setMethodOptimization(int methodOptimization) { + methodOptimization_ = methodOptimization; + } + + inline void setFinePairManagement(int finePairManagement) { + finePairManagement_ = finePairManagement; + } + + inline void setChooseLearningRate(int chooseLearningRate) { + chooseLearningRate_ = chooseLearningRate; + } + + inline void setLearningRate(double learningRate) { + learningRate_ = learningRate; + } + + inline void setAlpha(double alpha) { + alpha_ = alpha; + } + + inline void setCoefStopCondition(double coefStopCondition) { + coefStopCondition_ = coefStopCondition; + } + + inline void + setOptimizationWithoutMatching(bool optimizationWithoutMatching) { + optimizationWithoutMatching_ = optimizationWithoutMatching; + } + + inline void setThresholdMethod(int thresholdMethod) { + thresholdMethod_ = thresholdMethod; + } + + inline void setThresholdPersistence(double thresholdPersistence) { + thresholdPersistence_ = thresholdPersistence; + } + + inline void setLowerThreshold(int lowerThreshold) { + lowerThreshold_ = lowerThreshold; + } + + inline void setUpperThreshold(int upperThreshold) { + upperThreshold_ = upperThreshold; + } + + inline void setPairTypeToDelete(int pairTypeToDelete) { + pairTypeToDelete_ = pairTypeToDelete; + } + + inline void setConstraintAveraging(bool ConstraintAveraging) { + constraintAveraging_ = ConstraintAveraging; + } + + inline void setPrintFrequency(int printFrequency) { + 
printFrequency_ = printFrequency; + } + + protected: + SimplexId vertexNumber_{}; + int epochNumber_; + + // enable the fast update of the persistence diagram + bool useFastPersistenceUpdate_; + + // enable the fast update of the pair assignments between the target diagram + bool fastAssignmentUpdate_; + + // if pdcMethod_ == 0 then we use Progressive approach + // if pdcMethod_ == 1 then we use Classical Auction approach + int pdcMethod_; + + // if methodOptimization_ == 0 then we use Direct gradient descent + // if methodOptimization_ == 1 then we use Adam + int methodOptimization_; + + // if finePairManagement_ == 0 then we let the algorithm choose + // if finePairManagement_ == 1 then we fill the domain + // if finePairManagement_ == 2 then we cut the domain + int finePairManagement_; + + // Adam + bool chooseLearningRate_; + double learningRate_; + + // Direct gradient descent + // alpha_ : the gradient step size + double alpha_; + + // Stopping criterion: when the loss becomes less than a percentage + // coefStopCondition_ (e.g. coefStopCondition_ = 0.01 => 1%) of the original + // loss (between input diagram and simplified diagram) + double coefStopCondition_; + + // Optimization without matching (OWM) + bool optimizationWithoutMatching_; + + // [OWM] if thresholdMethod_ == 0 : threshold on persistence + // [OWM] if thresholdMethod_ == 1 : threshold on pair type + int thresholdMethod_; + + // [OWM] thresholdPersistence_ : The threshold value on persistence. 
+ double thresholdPersistence_; + + // [OWM] lowerThreshold_ : The lower threshold on pair type + int lowerThreshold_; + + // [OWM] upperThreshold_ : The upper threshold on pair type + int upperThreshold_; + + // [OWM] pairTypeToDelete_ : Remove only pairs of type pairTypeToDelete_ + int pairTypeToDelete_; + + bool constraintAveraging_; + + int printFrequency_{10}; + }; + +} // namespace ttk + +#ifdef TTK_ENABLE_TORCH +class PersistenceGradientDescent : public torch::nn::Module, + public ttk::TopologicalOptimization { +public: + PersistenceGradientDescent(torch::Tensor X_tensor) : torch::nn::Module() { + X = register_parameter("X", X_tensor, true); + } + torch::Tensor X; +}; + +#endif + +/* + This function allows us to retrieve the indices of the critical points + that we must modify in order to match our current diagram to our target + diagram. +*/ +template +void ttk::TopologicalOptimization::getIndices( + triangulationType *triangulation, + SimplexId *&inputOffsets, + dataType *const inputScalars, + const ttk::DiagramType &constraintDiagram, + int epoch, + std::vector &listAllIndicesToChange, + std::vector> &pair2MatchedPair, + std::vector> &pair2Delete, + std::vector &pairChangeMatchingPair, + std::vector &birthPairToDeleteCurrentDiagram, + std::vector &birthPairToDeleteTargetDiagram, + std::vector &deathPairToDeleteCurrentDiagram, + std::vector &deathPairToDeleteTargetDiagram, + std::vector &birthPairToChangeCurrentDiagram, + std::vector &birthPairToChangeTargetDiagram, + std::vector &deathPairToChangeCurrentDiagram, + std::vector &deathPairToChangeTargetDiagram, + std::vector> ¤tVertex2PairsCurrentDiagram, + std::vector &vertexInHowManyPairs) const { + + //========================================= + // Lazy Gradient + //========================================= + + bool needUpdateDefaultValue + = (useFastPersistenceUpdate_ ? (epoch == 0 || epoch < 0 ? 
true : false) + : true); + std::vector needUpdate(vertexNumber_, needUpdateDefaultValue); + if(useFastPersistenceUpdate_) { + /* + There is a 10% loss of performance + */ + this->printMsg( + "Get Indices | UseFastPersistenceUpdate_", debug::Priority::DETAIL); + + if(not(epoch == 0 || epoch < 0)) { +#ifdef TTK_ENABLE_OPENMP +#pragma omp parallel for num_threads(threadNumber_) +#endif + for(size_t index = 0; index < listAllIndicesToChange.size(); index++) { + if(listAllIndicesToChange[index] == 1) { + needUpdate[index] = true; + + // Find all the neighbors of the vertex + int vertexNumber = triangulation->getVertexNeighborNumber(index); + for(int i = 0; i < vertexNumber; i++) { + SimplexId vertexNeighborId = -1; + triangulation->getVertexNeighbor(index, i, vertexNeighborId); + needUpdate[vertexNeighborId] = true; + } + } + } + } + } + + SimplexId count = std::count(needUpdate.begin(), needUpdate.end(), true); + + this->printMsg( + "Get Indices | The number of vertices that need to be updated is: " + + std::to_string(count), + debug::Priority::DETAIL); + + //========================================= + // Compute the persistence diagram + //========================================= + ttk::Timer timePersistenceDiagram; + + ttk::PersistenceDiagram diagram; + std::vector diagramOutput; + ttk::preconditionOrderArray( + vertexNumber_, inputScalars, inputOffsets, threadNumber_); + diagram.setDebugLevel(debugLevel_); + diagram.setThreadNumber(threadNumber_); + diagram.preconditionTriangulation(triangulation); + + if(useFastPersistenceUpdate_) { + diagram.execute( + diagramOutput, inputScalars, 0, inputOffsets, triangulation, &needUpdate); + } else { + diagram.execute( + diagramOutput, inputScalars, epoch, inputOffsets, triangulation); + } + + //===================================== + // Matching Pairs + //===================================== + + if(optimizationWithoutMatching_) { + for(SimplexId i = 0; i < (SimplexId)diagramOutput.size(); i++) { + auto pair = 
diagramOutput[i]; + if((thresholdMethod_ == 0) + && (pair.persistence() < thresholdPersistence_)) { + birthPairToDeleteCurrentDiagram.push_back( + static_cast(pair.birth.id)); + birthPairToDeleteTargetDiagram.push_back( + (pair.birth.sfValue + pair.death.sfValue) / 2); + deathPairToDeleteCurrentDiagram.push_back( + static_cast(pair.death.id)); + deathPairToDeleteTargetDiagram.push_back( + (pair.birth.sfValue + pair.death.sfValue) / 2); + } else if((thresholdMethod_ == 1) + && ((pair.dim < lowerThreshold_) + || (pair.dim > upperThreshold_))) { + birthPairToDeleteCurrentDiagram.push_back( + static_cast(pair.birth.id)); + birthPairToDeleteTargetDiagram.push_back( + (pair.birth.sfValue + pair.death.sfValue) / 2); + deathPairToDeleteCurrentDiagram.push_back( + static_cast(pair.death.id)); + deathPairToDeleteTargetDiagram.push_back( + (pair.birth.sfValue + pair.death.sfValue) / 2); + } else if((thresholdMethod_ == 2) && (pair.dim == pairTypeToDelete_)) { + birthPairToDeleteCurrentDiagram.push_back( + static_cast(pair.birth.id)); + birthPairToDeleteTargetDiagram.push_back( + (pair.birth.sfValue + pair.death.sfValue) / 2); + deathPairToDeleteCurrentDiagram.push_back( + static_cast(pair.death.id)); + deathPairToDeleteTargetDiagram.push_back( + (pair.birth.sfValue + pair.death.sfValue) / 2); + } + } + } else if(fastAssignmentUpdate_) { + + std::vector> vertex2PairsCurrentDiagram( + vertexNumber_, std::vector()); + for(SimplexId i = 0; i < (SimplexId)diagramOutput.size(); i++) { + auto &pair = diagramOutput[i]; + vertex2PairsCurrentDiagram[pair.birth.id].push_back(i); + vertex2PairsCurrentDiagram[pair.death.id].push_back(i); + vertexInHowManyPairs[pair.birth.id]++; + vertexInHowManyPairs[pair.death.id]++; + } + + std::vector> vertex2PairsTargetDiagram( + vertexNumber_, std::vector()); + for(SimplexId i = 0; i < (SimplexId)constraintDiagram.size(); i++) { + auto &pair = constraintDiagram[i]; + vertex2PairsTargetDiagram[pair.birth.id].push_back(i); + 
vertex2PairsTargetDiagram[pair.death.id].push_back(i); + } + + std::vector> matchedPairs; + for(SimplexId i = 0; i < (SimplexId)constraintDiagram.size(); i++) { + auto &pair = constraintDiagram[i]; + + SimplexId birthId = -1; + SimplexId deathId = -1; + + if(pairChangeMatchingPair[i] == 1) { + birthId = pair2MatchedPair[i][0]; + deathId = pair2MatchedPair[i][1]; + } else { + birthId = pair.birth.id; + deathId = pair.death.id; + } + + if(epoch == 0) { + for(auto &idPairBirth : vertex2PairsCurrentDiagram[birthId]) { + for(auto &idPairDeath : vertex2PairsCurrentDiagram[deathId]) { + if(idPairBirth == idPairDeath) { + matchedPairs.push_back({i, idPairBirth}); + } + } + } + } else if((vertex2PairsCurrentDiagram[birthId].size() == 1) + && (vertex2PairsCurrentDiagram[deathId].size() == 1)) { + if(vertex2PairsCurrentDiagram[birthId][0] + == vertex2PairsCurrentDiagram[deathId][0]) { + matchedPairs.push_back({i, vertex2PairsCurrentDiagram[deathId][0]}); + } + } + } + + std::vector matchingPairCurrentDiagram( + (SimplexId)diagramOutput.size(), -1); + std::vector matchingPairTargetDiagram( + (SimplexId)constraintDiagram.size(), -1); + + for(auto &match : matchedPairs) { + auto &indicePairTargetDiagram = match[0]; + auto &indicePairCurrentDiagram = match[1]; + + auto &pairCurrentDiagram = diagramOutput[indicePairCurrentDiagram]; + auto &pairTargetDiagram = constraintDiagram[indicePairTargetDiagram]; + + pair2MatchedPair[indicePairTargetDiagram][0] + = pairCurrentDiagram.birth.id; + pair2MatchedPair[indicePairTargetDiagram][1] + = pairCurrentDiagram.death.id; + + matchingPairCurrentDiagram[indicePairCurrentDiagram] = 1; + matchingPairTargetDiagram[indicePairTargetDiagram] = 1; + + SimplexId valueBirthPairToChangeCurrentDiagram + = (SimplexId)(pairCurrentDiagram.birth.id); + SimplexId valueDeathPairToChangeCurrentDiagram + = (SimplexId)(pairCurrentDiagram.death.id); + + double valueBirthPairToChangeTargetDiagram + = pairTargetDiagram.birth.sfValue; + double 
valueDeathPairToChangeTargetDiagram + = pairTargetDiagram.death.sfValue; + + birthPairToChangeCurrentDiagram.push_back( + valueBirthPairToChangeCurrentDiagram); + birthPairToChangeTargetDiagram.push_back( + valueBirthPairToChangeTargetDiagram); + deathPairToChangeCurrentDiagram.push_back( + valueDeathPairToChangeCurrentDiagram); + deathPairToChangeTargetDiagram.push_back( + valueDeathPairToChangeTargetDiagram); + } + + ttk::DiagramType thresholdCurrentDiagram{}; + for(SimplexId i = 0; i < (SimplexId)diagramOutput.size(); i++) { + auto &pair = diagramOutput[i]; + + if((pair2Delete[pair.birth.id].size() == 1) + && (pair2Delete[pair.death.id].size() == 1) + && (pair2Delete[pair.birth.id] == pair2Delete[pair.death.id])) { + + birthPairToDeleteCurrentDiagram.push_back( + static_cast(pair.birth.id)); + birthPairToDeleteTargetDiagram.push_back( + (pair.birth.sfValue + pair.death.sfValue) / 2); + deathPairToDeleteCurrentDiagram.push_back( + static_cast(pair.death.id)); + deathPairToDeleteTargetDiagram.push_back( + (pair.birth.sfValue + pair.death.sfValue) / 2); + continue; + } + if(matchingPairCurrentDiagram[i] == -1) { + thresholdCurrentDiagram.push_back(pair); + } + } + + ttk::DiagramType thresholdConstraintDiagram{}; + std::vector pairIndiceLocal2Global{}; + for(SimplexId i = 0; i < (SimplexId)constraintDiagram.size(); i++) { + auto &pair = constraintDiagram[i]; + + if(matchingPairTargetDiagram[i] == -1) { + thresholdConstraintDiagram.push_back(pair); + pairIndiceLocal2Global.push_back(i); + } + } + + this->printMsg("Get Indices | thresholdCurrentDiagram.size(): " + + std::to_string(thresholdCurrentDiagram.size()), + debug::Priority::DETAIL); + + this->printMsg("Get Indices | thresholdConstraintDiagram.size(): " + + std::to_string(thresholdConstraintDiagram.size()), + debug::Priority::DETAIL); + + if(thresholdConstraintDiagram.size() == 0) { + for(SimplexId i = 0; i < (SimplexId)thresholdCurrentDiagram.size(); i++) { + auto &pair = thresholdCurrentDiagram[i]; + + 
if(!constraintAveraging_) { + + // If the point pair.birth.id is in a signal pair + // AND If the point pair.death.id is not in a signal pair + // Then we only modify the pair.death.id + if((vertex2PairsTargetDiagram[pair.birth.id].size() >= 1) + && (vertex2PairsTargetDiagram[pair.death.id].size() == 0)) { + deathPairToDeleteCurrentDiagram.push_back( + static_cast(pair.death.id)); + deathPairToDeleteTargetDiagram.push_back( + (pair.birth.sfValue + pair.death.sfValue) / 2); + continue; + } + + // If the point pair.death.id is in a signal pair + // AND If the point pair.birth.id is not in a signal pair + // Then we only modify the pair.birth.id + if((vertex2PairsTargetDiagram[pair.birth.id].size() == 0) + && (vertex2PairsTargetDiagram[pair.death.id].size() >= 1)) { + birthPairToDeleteCurrentDiagram.push_back( + static_cast(pair.birth.id)); + birthPairToDeleteTargetDiagram.push_back( + (pair.birth.sfValue + pair.death.sfValue) / 2); + continue; + } + + // If the point pair.birth.id is in a signal pair + // AND If the point pair.death.id is in a signal pair + // Then we do not modify either point + if((vertex2PairsTargetDiagram[pair.birth.id].size() >= 1) + || (vertex2PairsTargetDiagram[pair.death.id].size() >= 1)) { + continue; + } + } + + birthPairToDeleteCurrentDiagram.push_back( + static_cast(pair.birth.id)); + birthPairToDeleteTargetDiagram.push_back( + (pair.birth.sfValue + pair.death.sfValue) / 2); + deathPairToDeleteCurrentDiagram.push_back( + static_cast(pair.death.id)); + deathPairToDeleteTargetDiagram.push_back( + (pair.birth.sfValue + pair.death.sfValue) / 2); + + pair2Delete[pair.birth.id].push_back(i); + pair2Delete[pair.death.id].push_back(i); + } + } else { + + ttk::Timer timePersistenceDiagramClustering; + + ttk::PersistenceDiagramClustering persistenceDiagramClustering; + PersistenceDiagramBarycenter pdBarycenter{}; + std::vector intermediateDiagrams{ + thresholdConstraintDiagram, thresholdCurrentDiagram}; + std::vector>> allMatchings; + std::vector 
centroids{}; + + if(pdcMethod_ == 0) { + persistenceDiagramClustering.setDebugLevel(debugLevel_); + persistenceDiagramClustering.setThreadNumber(threadNumber_); + // setDeterministic ==> Deterministic algorithm + persistenceDiagramClustering.setDeterministic(true); + // setUseProgressive ==> Compute Progressive Barycenter + persistenceDiagramClustering.setUseProgressive(true); + // setUseInterruptible ==> Interruptible algorithm + persistenceDiagramClustering.setUseInterruptible(false); + // // setTimeLimit ==> Maximal computation time (s) + persistenceDiagramClustering.setTimeLimit(0.01); + // setUseAdditionalPrecision ==> Force minimum precision on matchings + persistenceDiagramClustering.setUseAdditionalPrecision(true); + // setDeltaLim ==> Minimal relative precision + persistenceDiagramClustering.setDeltaLim(1e-5); + // setUseAccelerated ==> Use Accelerated KMeans + persistenceDiagramClustering.setUseAccelerated(false); + // setUseKmeansppInit ==> KMeanspp Initialization + persistenceDiagramClustering.setUseKmeansppInit(false); + + std::vector clusterIds = persistenceDiagramClustering.execute( + intermediateDiagrams, centroids, allMatchings); + } else { + + centroids.resize(1); + const auto wassersteinMetric = std::to_string(2); + pdBarycenter.setWasserstein(wassersteinMetric); + pdBarycenter.setMethod(2); + pdBarycenter.setNumberOfInputs(2); + pdBarycenter.setDeterministic(true); + pdBarycenter.setUseProgressive(true); + pdBarycenter.setDebugLevel(debugLevel_); + pdBarycenter.setThreadNumber(threadNumber_); + pdBarycenter.setAlpha(1); + pdBarycenter.setLambda(1); + pdBarycenter.execute(intermediateDiagrams, centroids[0], allMatchings); + } + + std::vector> allPairsSelected{}; + std::vector> matchingsBlockPairs( + centroids[0].size()); + + for(auto i = 1; i >= 0; --i) { + std::vector &matching = allMatchings[0][i]; + + const auto &diag{intermediateDiagrams[i]}; + + for(SimplexId j = 0; j < (SimplexId)matching.size(); j++) { + + const auto &m{matching[j]}; + 
const auto &bidderId{std::get<0>(m)}; + const auto &goodId{std::get<1>(m)}; + + if((goodId == -1) | (bidderId == -1)) { + continue; + } + + if(diag[bidderId].persistence() != 0) { + if(i == 1) { + matchingsBlockPairs[goodId].push_back(bidderId); + } else if(matchingsBlockPairs[goodId].size() > 0) { + matchingsBlockPairs[goodId].push_back(bidderId); + } + allPairsSelected.push_back( + {diag[bidderId].birth.id, diag[bidderId].death.id}); + } + } + } + + std::vector pairsToErase{}; + + std::map, SimplexId> currentToTarget; + for(auto &pair : allPairsSelected) { + currentToTarget[{pair[0], pair[1]}] = 1; + } + + for(auto &pair : intermediateDiagrams[1]) { + if(pair.isFinite != 0) { + if(!(currentToTarget.count({pair.birth.id, pair.death.id}) > 0)) { + pairsToErase.push_back(pair); + } + } + } + + for(auto &pair : pairsToErase) { + + if(!constraintAveraging_) { + + // If the point pair.birth.id is in a signal pair + // AND If the point pair.death.id is not in a signal pair + // Then we only modify the pair.death.id + if((vertex2PairsTargetDiagram[pair.birth.id].size() >= 1) + && (vertex2PairsTargetDiagram[pair.death.id].size() == 0)) { + deathPairToDeleteCurrentDiagram.push_back( + static_cast(pair.death.id)); + deathPairToDeleteTargetDiagram.push_back( + (pair.birth.sfValue + pair.death.sfValue) / 2); + continue; + } + + // If the point pair.death.id is in a signal pair + // AND If the point pair.birth.id is not in a signal pair + // Then we only modify the pair.birth.id + if((vertex2PairsTargetDiagram[pair.birth.id].size() == 0) + && (vertex2PairsTargetDiagram[pair.death.id].size() >= 1)) { + birthPairToDeleteCurrentDiagram.push_back( + static_cast(pair.birth.id)); + birthPairToDeleteTargetDiagram.push_back( + (pair.birth.sfValue + pair.death.sfValue) / 2); + continue; + } + + // If the point pair.birth.id is in a signal pair + // AND If the point pair.death.id is in a signal pair + // Then we do not modify either point + 
if((vertex2PairsTargetDiagram[pair.birth.id].size() >= 1) + || (vertex2PairsTargetDiagram[pair.death.id].size() >= 1)) { + continue; + } + } + + birthPairToDeleteCurrentDiagram.push_back( + static_cast(pair.birth.id)); + birthPairToDeleteTargetDiagram.push_back( + (pair.birth.sfValue + pair.death.sfValue) / 2); + deathPairToDeleteCurrentDiagram.push_back( + static_cast(pair.death.id)); + deathPairToDeleteTargetDiagram.push_back( + (pair.birth.sfValue + pair.death.sfValue) / 2); + } + + for(const auto &entry : matchingsBlockPairs) { + // Delete pairs that have no equivalence + if(entry.size() == 1) { + + if(!constraintAveraging_) { + // If the point thresholdCurrentDiagram[entry[0]].birth.id is in a + // signal pair AND If the point + // thresholdCurrentDiagram[entry[0]].death.id is not in a signal + // pair Then we only modify the + // thresholdCurrentDiagram[entry[0]].death.id + if((vertex2PairsTargetDiagram[thresholdCurrentDiagram[entry[0]] + .birth.id] + .size() + >= 1) + && (vertex2PairsTargetDiagram[thresholdCurrentDiagram[entry[0]] + .death.id] + .size() + == 0)) { + deathPairToDeleteCurrentDiagram.push_back(static_cast( + thresholdCurrentDiagram[entry[0]].death.id)); + deathPairToDeleteTargetDiagram.push_back( + (thresholdCurrentDiagram[entry[0]].birth.sfValue + + thresholdCurrentDiagram[entry[0]].death.sfValue) + / 2); + continue; + } + + // If the point thresholdCurrentDiagram[entry[0]].death.id is in a + // signal pair AND If the point + // thresholdCurrentDiagram[entry[0]].birth.id is not in a signal + // pair Then we only modify the + // thresholdCurrentDiagram[entry[0]].birth.id + if((vertex2PairsTargetDiagram[thresholdCurrentDiagram[entry[0]] + .birth.id] + .size() + == 0) + && (vertex2PairsTargetDiagram[thresholdCurrentDiagram[entry[0]] + .death.id] + .size() + >= 1)) { + birthPairToDeleteCurrentDiagram.push_back(static_cast( + thresholdCurrentDiagram[entry[0]].birth.id)); + birthPairToDeleteTargetDiagram.push_back( + 
(thresholdCurrentDiagram[entry[0]].birth.sfValue + + thresholdCurrentDiagram[entry[0]].death.sfValue) + / 2); + continue; + } + + // If the point thresholdCurrentDiagram[entry[0]].birth.id is in a + // signal pair AND If the point + // thresholdCurrentDiagram[entry[0]].death.id is in a signal pair + // Then we do not modify either point + if((vertex2PairsTargetDiagram[thresholdCurrentDiagram[entry[0]] + .birth.id] + .size() + >= 1) + || (vertex2PairsTargetDiagram[thresholdCurrentDiagram[entry[0]] + .death.id] + .size() + >= 1)) { + continue; + } + } + + birthPairToDeleteCurrentDiagram.push_back( + static_cast(thresholdCurrentDiagram[entry[0]].birth.id)); + birthPairToDeleteTargetDiagram.push_back( + (thresholdCurrentDiagram[entry[0]].birth.sfValue + + thresholdCurrentDiagram[entry[0]].death.sfValue) + / 2); + deathPairToDeleteCurrentDiagram.push_back( + static_cast(thresholdCurrentDiagram[entry[0]].death.id)); + deathPairToDeleteTargetDiagram.push_back( + (thresholdCurrentDiagram[entry[0]].birth.sfValue + + thresholdCurrentDiagram[entry[0]].death.sfValue) + / 2); + continue; + } else if(entry.empty()) + continue; + + SimplexId valueBirthPairToChangeCurrentDiagram + = static_cast(thresholdCurrentDiagram[entry[0]].birth.id); + SimplexId valueDeathPairToChangeCurrentDiagram + = static_cast(thresholdCurrentDiagram[entry[0]].death.id); + + double valueBirthPairToChangeTargetDiagram + = thresholdConstraintDiagram[entry[1]].birth.sfValue; + double valueDeathPairToChangeTargetDiagram + = thresholdConstraintDiagram[entry[1]].death.sfValue; + + pair2MatchedPair[pairIndiceLocal2Global[entry[1]]][0] + = thresholdCurrentDiagram[entry[0]].birth.id; + pair2MatchedPair[pairIndiceLocal2Global[entry[1]]][1] + = thresholdCurrentDiagram[entry[0]].death.id; + + pairChangeMatchingPair[pairIndiceLocal2Global[entry[1]]] = 1; + + birthPairToChangeCurrentDiagram.push_back( + valueBirthPairToChangeCurrentDiagram); + birthPairToChangeTargetDiagram.push_back( + 
valueBirthPairToChangeTargetDiagram); + deathPairToChangeCurrentDiagram.push_back( + valueDeathPairToChangeCurrentDiagram); + deathPairToChangeTargetDiagram.push_back( + valueDeathPairToChangeTargetDiagram); + } + } + } + //=====================================// + // Basic Matching // + //=====================================// + else { + this->printMsg( + "Get Indices | Compute Wasserstein distance: ", debug::Priority::DETAIL); + + if(epoch == 0) { + for(SimplexId i = 0; i < (SimplexId)diagramOutput.size(); i++) { + auto &pair = diagramOutput[i]; + currentVertex2PairsCurrentDiagram[pair.birth.id].push_back(i); + currentVertex2PairsCurrentDiagram[pair.death.id].push_back(i); + } + } else { + std::vector> newVertex2PairsCurrentDiagram( + vertexNumber_, std::vector()); + + for(SimplexId i = 0; i < (SimplexId)diagramOutput.size(); i++) { + auto &pair = diagramOutput[i]; + newVertex2PairsCurrentDiagram[pair.birth.id].push_back(i); + newVertex2PairsCurrentDiagram[pair.death.id].push_back(i); + } + + currentVertex2PairsCurrentDiagram = newVertex2PairsCurrentDiagram; + } + + std::vector> vertex2PairsCurrentDiagram( + vertexNumber_, std::vector()); + for(SimplexId i = 0; i < (SimplexId)diagramOutput.size(); i++) { + auto &pair = diagramOutput[i]; + vertex2PairsCurrentDiagram[pair.birth.id].push_back(i); + vertex2PairsCurrentDiagram[pair.death.id].push_back(i); + vertexInHowManyPairs[pair.birth.id]++; + vertexInHowManyPairs[pair.death.id]++; + } + + std::vector> vertex2PairsTargetDiagram( + vertexNumber_, std::vector()); + for(SimplexId i = 0; i < (SimplexId)constraintDiagram.size(); i++) { + auto &pair = constraintDiagram[i]; + vertex2PairsTargetDiagram[pair.birth.id].push_back(i); + vertex2PairsTargetDiagram[pair.death.id].push_back(i); + } + + //========================================= + // Compute wasserstein distance + //========================================= + ttk::Timer timePersistenceDiagramClustering; + + ttk::PersistenceDiagramClustering 
persistenceDiagramClustering; + PersistenceDiagramBarycenter pdBarycenter{}; + std::vector intermediateDiagrams{ + constraintDiagram, diagramOutput}; + std::vector centroids; + std::vector>> allMatchings; + + if(pdcMethod_ == 0) { + persistenceDiagramClustering.setDebugLevel(debugLevel_); + persistenceDiagramClustering.setThreadNumber(threadNumber_); + // SetForceUseOfAlgorithm ==> Force the progressive approch if 2 inputs + persistenceDiagramClustering.setForceUseOfAlgorithm(false); + // setDeterministic ==> Deterministic algorithm + persistenceDiagramClustering.setDeterministic(true); + // setUseProgressive ==> Compute Progressive Barycenter + persistenceDiagramClustering.setUseProgressive(true); + // setUseInterruptible ==> Interruptible algorithm + // persistenceDiagramClustering.setUseInterruptible(true); + persistenceDiagramClustering.setUseInterruptible(false); + // // setTimeLimit ==> Maximal computation time (s) + persistenceDiagramClustering.setTimeLimit(0.01); + // setUseAdditionalPrecision ==> Force minimum precision on matchings + persistenceDiagramClustering.setUseAdditionalPrecision(true); + // setDeltaLim ==> Minimal relative precision + persistenceDiagramClustering.setDeltaLim(0.00000001); + // setUseAccelerated ==> Use Accelerated KMeans + persistenceDiagramClustering.setUseAccelerated(false); + // setUseKmeansppInit ==> KMeanspp Initialization + persistenceDiagramClustering.setUseKmeansppInit(false); + + std::vector clusterIds = persistenceDiagramClustering.execute( + intermediateDiagrams, centroids, allMatchings); + } else { + centroids.resize(1); + const auto wassersteinMetric = std::to_string(2); + pdBarycenter.setWasserstein(wassersteinMetric); + pdBarycenter.setMethod(2); + pdBarycenter.setNumberOfInputs(2); + pdBarycenter.setDeterministic(true); + pdBarycenter.setUseProgressive(true); + pdBarycenter.setDebugLevel(debugLevel_); + pdBarycenter.setThreadNumber(threadNumber_); + pdBarycenter.setAlpha(1); + pdBarycenter.setLambda(1); + 
pdBarycenter.execute(intermediateDiagrams, centroids[0], allMatchings); + } + + this->printMsg( + "Get Indices | Persistence Diagram Clustering Time: " + + std::to_string(timePersistenceDiagramClustering.getElapsedTime()), + debug::Priority::DETAIL); + + //========================================= + // Find matched pairs + //========================================= + + std::vector> allPairsSelected{}; + std::vector>> matchingsBlock( + centroids[0].size()); + std::vector> matchingsBlockPairs( + centroids[0].size()); + + for(auto i = 1; i >= 0; --i) { + std::vector &matching = allMatchings[0][i]; + + const auto &diag{intermediateDiagrams[i]}; + + for(SimplexId j = 0; j < (SimplexId)matching.size(); j++) { + + const auto &m{matching[j]}; + const auto &bidderId{std::get<0>(m)}; + const auto &goodId{std::get<1>(m)}; + + if((goodId == -1) | (bidderId == -1)) + continue; + + if(diag[bidderId].persistence() != 0) { + matchingsBlock[goodId].push_back( + {static_cast(diag[bidderId].birth.id), + static_cast(diag[bidderId].death.id), + diag[bidderId].persistence()}); + if(i == 1) { + matchingsBlockPairs[goodId].push_back(diag[bidderId]); + } else if(matchingsBlockPairs[goodId].size() > 0) { + matchingsBlockPairs[goodId].push_back(diag[bidderId]); + } + allPairsSelected.push_back( + {diag[bidderId].birth.id, diag[bidderId].death.id}); + } + } + } + + std::vector pairsToErase{}; + + std::map, SimplexId> currentToTarget; + for(auto &pair : allPairsSelected) { + currentToTarget[{pair[0], pair[1]}] = 1; + } + + for(auto &pair : intermediateDiagrams[1]) { + if(pair.isFinite != 0) { + if(!(currentToTarget.count({pair.birth.id, pair.death.id}) > 0)) { + pairsToErase.push_back(pair); + } + } + } + + for(auto &pair : pairsToErase) { + + if(!constraintAveraging_) { + + // If the point pair.birth.id is in a signal pair + // AND If the point pair.death.id is not in a signal pair + // Then we only modify the pair.death.id + if((vertex2PairsTargetDiagram[pair.birth.id].size() >= 1) + && 
(vertex2PairsTargetDiagram[pair.death.id].size() == 0)) { + deathPairToDeleteCurrentDiagram.push_back( + static_cast(pair.death.id)); + deathPairToDeleteTargetDiagram.push_back( + (pair.birth.sfValue + pair.death.sfValue) / 2); + continue; + } + + // If the point pair.death.id is in a signal pair + // AND If the point pair.birth.id is not in a signal pair + // Then we only modify the pair.birth.id + if((vertex2PairsTargetDiagram[pair.birth.id].size() == 0) + && (vertex2PairsTargetDiagram[pair.death.id].size() >= 1)) { + birthPairToDeleteCurrentDiagram.push_back( + static_cast(pair.birth.id)); + birthPairToDeleteTargetDiagram.push_back( + (pair.birth.sfValue + pair.death.sfValue) / 2); + continue; + } + + // If the point pair.birth.id is in a signal pair + // AND If the point pair.death.id is in a signal pair + // Then we do not modify either point + if((vertex2PairsTargetDiagram[pair.birth.id].size() >= 1) + || (vertex2PairsTargetDiagram[pair.death.id].size() >= 1)) { + continue; + } + } + + birthPairToDeleteCurrentDiagram.push_back( + static_cast(pair.birth.id)); + birthPairToDeleteTargetDiagram.push_back( + (pair.birth.sfValue + pair.death.sfValue) / 2); + deathPairToDeleteCurrentDiagram.push_back( + static_cast(pair.death.id)); + deathPairToDeleteTargetDiagram.push_back( + (pair.birth.sfValue + pair.death.sfValue) / 2); + } + + for(const auto &entry : matchingsBlockPairs) { + // Delete pairs that have no equivalence + if(entry.size() == 1) { + birthPairToDeleteCurrentDiagram.push_back( + static_cast(entry[0].birth.id)); + birthPairToDeleteTargetDiagram.push_back( + (entry[0].birth.sfValue + entry[0].death.sfValue) / 2); + deathPairToDeleteCurrentDiagram.push_back( + static_cast(entry[0].death.id)); + deathPairToDeleteTargetDiagram.push_back( + (entry[0].birth.sfValue + entry[0].death.sfValue) / 2); + continue; + } else if(entry.empty()) + continue; + + SimplexId valueBirthPairToChangeCurrentDiagram + = static_cast(entry[0].birth.id); + SimplexId 
valueDeathPairToChangeCurrentDiagram + = static_cast(entry[0].death.id); + + double valueBirthPairToChangeTargetDiagram = entry[1].birth.sfValue; + double valueDeathPairToChangeTargetDiagram = entry[1].death.sfValue; + + birthPairToChangeCurrentDiagram.push_back( + valueBirthPairToChangeCurrentDiagram); + birthPairToChangeTargetDiagram.push_back( + valueBirthPairToChangeTargetDiagram); + deathPairToChangeCurrentDiagram.push_back( + valueDeathPairToChangeCurrentDiagram); + deathPairToChangeTargetDiagram.push_back( + valueDeathPairToChangeTargetDiagram); + } + } +} + +/* + This function allows you to copy the values of a pytorch tensor + to a vector in an optimized way. +*/ +#ifdef TTK_ENABLE_TORCH +int ttk::TopologicalOptimization::tensorToVectorFast( + const torch::Tensor &tensor, std::vector &result) const { + TORCH_CHECK( + tensor.dtype() == torch::kDouble, "The tensor must be of double type"); + const double *dataPtr = tensor.data_ptr(); + result.assign(dataPtr, dataPtr + tensor.numel()); + + return 0; +} +#endif + +template +int ttk::TopologicalOptimization::execute( + const dataType *const inputScalars, + dataType *const outputScalars, + SimplexId *const inputOffsets, + triangulationType *triangulation, + const ttk::DiagramType &constraintDiagram) const { + + Timer t; + double stoppingCondition = 0; + bool enableTorch = true; + + if(methodOptimization_ == 1) { +#ifndef TTK_ENABLE_TORCH + this->printWrn("Adam unavailable (Torch not found)."); + this->printWrn("Using direct gradient descent."); + enableTorch = false; +#endif + } + + //======================= + // Copy input data + //======================= + std::vector dataVector(vertexNumber_); + SimplexId *inputOffsetsCopie = inputOffsets; + +#ifdef TTK_ENABLE_OPENMP +#pragma omp parallel for num_threads(threadNumber_) +#endif + for(SimplexId k = 0; k < vertexNumber_; ++k) { + outputScalars[k] = inputScalars[k]; + dataVector[k] = inputScalars[k]; + if(std::isnan((double)outputScalars[k])) + outputScalars[k] = 
0; + } + + //=============================== + // Normalize the data + //=============================== + + dataType minVal = *std::min_element(dataVector.begin(), dataVector.end()); + dataType maxVal = *std::max_element(dataVector.begin(), dataVector.end()); + +#ifdef TTK_ENABLE_OPENMP +#pragma omp parallel for num_threads(threadNumber_) +#endif + for(size_t i = 0; i < dataVector.size(); ++i) { + dataVector[i] = (dataVector[i] - minVal) / (maxVal - minVal); + } + + ttk::DiagramType normalizedConstraintDiagram(constraintDiagram.size()); + +#ifdef TTK_ENABLE_OPENMP +#pragma omp parallel for num_threads(threadNumber_) +#endif + for(SimplexId i = 0; i < (SimplexId)constraintDiagram.size(); i++) { + auto pair = constraintDiagram[i]; + pair.birth.sfValue = (pair.birth.sfValue - minVal) / (maxVal - minVal); + pair.death.sfValue = (pair.death.sfValue - minVal) / (maxVal - minVal); + normalizedConstraintDiagram[i] = pair; + } + + std::vector losses; + std::vector inputScalarsX(vertexNumber_); + + //======================================== + // Direct gradient descent + //======================================== + if((methodOptimization_ == 0) || !(enableTorch)) { + std::vector listAllIndicesToChangeSmoothing(vertexNumber_, 0); + std::vector> pair2MatchedPair( + constraintDiagram.size(), std::vector(2)); + std::vector pairChangeMatchingPair(constraintDiagram.size(), -1); + std::vector> pair2Delete( + vertexNumber_, std::vector()); + std::vector> currentVertex2PairsCurrentDiagram( + vertexNumber_, std::vector()); + + for(int it = 0; it < epochNumber_; it++) { + + if(it % printFrequency_ == 0) { + debugLevel_ = 3; + } else { + debugLevel_ = 0; + } + + this->printMsg("DirectGradientDescent - iteration #" + std::to_string(it), + debug::Priority::PERFORMANCE); + + // pairs to change + std::vector birthPairToChangeCurrentDiagram{}; + std::vector birthPairToChangeTargetDiagram{}; + std::vector deathPairToChangeCurrentDiagram{}; + std::vector deathPairToChangeTargetDiagram{}; + + 
// pairs to delete + std::vector birthPairToDeleteCurrentDiagram{}; + std::vector birthPairToDeleteTargetDiagram{}; + std::vector deathPairToDeleteCurrentDiagram{}; + std::vector deathPairToDeleteTargetDiagram{}; + + std::vector vertexInHowManyPairs(vertexNumber_, 0); + + getIndices( + triangulation, inputOffsetsCopie, dataVector.data(), + normalizedConstraintDiagram, it, listAllIndicesToChangeSmoothing, + pair2MatchedPair, pair2Delete, pairChangeMatchingPair, + birthPairToDeleteCurrentDiagram, birthPairToDeleteTargetDiagram, + deathPairToDeleteCurrentDiagram, deathPairToDeleteTargetDiagram, + birthPairToChangeCurrentDiagram, birthPairToChangeTargetDiagram, + deathPairToChangeCurrentDiagram, deathPairToChangeTargetDiagram, + currentVertex2PairsCurrentDiagram, vertexInHowManyPairs); + std::fill(listAllIndicesToChangeSmoothing.begin(), + listAllIndicesToChangeSmoothing.end(), 0); + + //========================================================================== + // Retrieve the indices for the pairs that we want to send diagonally + //========================================================================== + double lossDeletePairs = 0; + + std::vector &indexBirthPairToDelete + = birthPairToDeleteCurrentDiagram; + std::vector &targetValueBirthPairToDelete + = birthPairToDeleteTargetDiagram; + std::vector &indexDeathPairToDelete + = deathPairToDeleteCurrentDiagram; + std::vector &targetValueDeathPairToDelete + = deathPairToDeleteTargetDiagram; + + this->printMsg("DirectGradientDescent - Number of pairs to delete: " + + std::to_string(indexBirthPairToDelete.size()), + debug::Priority::DETAIL); + + std::vector vertexInCellMultiple(vertexNumber_, -1); + std::vector> vertexToTargetValue( + vertexNumber_, std::vector()); + + if(indexBirthPairToDelete.size() == indexDeathPairToDelete.size()) { + for(size_t i = 0; i < indexBirthPairToDelete.size(); i++) { + lossDeletePairs += std::pow(dataVector[indexBirthPairToDelete[i]] + - targetValueBirthPairToDelete[i], + 2) + + 
std::pow(dataVector[indexDeathPairToDelete[i]] + - targetValueDeathPairToDelete[i], + 2); + SimplexId indexMax = indexBirthPairToDelete[i]; + SimplexId indexSelle = indexDeathPairToDelete[i]; + + if(!(finePairManagement_ == 2) && !(finePairManagement_ == 1)) { + if(constraintAveraging_) { + if(vertexInHowManyPairs[indexMax] == 1) { + dataVector[indexMax] + = dataVector[indexMax] + - alpha_ * 2 + * (dataVector[indexMax] + - targetValueBirthPairToDelete[i]); + listAllIndicesToChangeSmoothing[indexMax] = 1; + } else { + vertexInCellMultiple[indexMax] = 1; + vertexToTargetValue[indexMax].push_back( + targetValueBirthPairToDelete[i]); + } + + if(vertexInHowManyPairs[indexSelle] == 1) { + dataVector[indexSelle] + = dataVector[indexSelle] + - alpha_ * 2 + * (dataVector[indexSelle] + - targetValueDeathPairToDelete[i]); + listAllIndicesToChangeSmoothing[indexSelle] = 1; + } else { + vertexInCellMultiple[indexSelle] = 1; + vertexToTargetValue[indexSelle].push_back( + targetValueDeathPairToDelete[i]); + } + } else { + dataVector[indexMax] = dataVector[indexMax] + - alpha_ * 2 + * (dataVector[indexMax] + - targetValueBirthPairToDelete[i]); + dataVector[indexSelle] + = dataVector[indexSelle] + - alpha_ * 2 + * (dataVector[indexSelle] + - targetValueDeathPairToDelete[i]); + listAllIndicesToChangeSmoothing[indexMax] = 1; + listAllIndicesToChangeSmoothing[indexSelle] = 1; + } + } else if(finePairManagement_ == 1) { + if(constraintAveraging_) { + if(vertexInHowManyPairs[indexSelle] == 1) { + dataVector[indexSelle] + = dataVector[indexSelle] + - alpha_ * 2 + * (dataVector[indexSelle] + - targetValueDeathPairToDelete[i]); + listAllIndicesToChangeSmoothing[indexSelle] = 1; + } else { + vertexInCellMultiple[indexSelle] = 1; + vertexToTargetValue[indexSelle].push_back( + targetValueDeathPairToDelete[i]); + } + } else { + dataVector[indexSelle] + = dataVector[indexSelle] + - alpha_ * 2 + * (dataVector[indexSelle] + - targetValueDeathPairToDelete[i]); + 
listAllIndicesToChangeSmoothing[indexSelle] = 1; + } + } else if(finePairManagement_ == 2) { + if(constraintAveraging_) { + if(vertexInHowManyPairs[indexMax] == 1) { + dataVector[indexMax] + = dataVector[indexMax] + - alpha_ * 2 + * (dataVector[indexMax] + - targetValueBirthPairToDelete[i]); + listAllIndicesToChangeSmoothing[indexMax] = 1; + } else { + vertexInCellMultiple[indexMax] = 1; + vertexToTargetValue[indexMax].push_back( + targetValueBirthPairToDelete[i]); + } + } else { + dataVector[indexMax] = dataVector[indexMax] + - alpha_ * 2 + * (dataVector[indexMax] + - targetValueBirthPairToDelete[i]); + listAllIndicesToChangeSmoothing[indexMax] = 1; + } + } + } + } else { + for(size_t i = 0; i < indexBirthPairToDelete.size(); i++) { + lossDeletePairs += std::pow(dataVector[indexBirthPairToDelete[i]] + - targetValueBirthPairToDelete[i], + 2); + SimplexId indexMax = indexBirthPairToDelete[i]; + + if(!(finePairManagement_ == 1)) { + if(constraintAveraging_) { + if(vertexInHowManyPairs[indexMax] == 1) { + dataVector[indexMax] + = dataVector[indexMax] + - alpha_ * 2 + * (dataVector[indexMax] + - targetValueBirthPairToDelete[i]); + listAllIndicesToChangeSmoothing[indexMax] = 1; + } else { + vertexInCellMultiple[indexMax] = 1; + vertexToTargetValue[indexMax].push_back( + targetValueBirthPairToDelete[i]); + } + } else { + dataVector[indexMax] = dataVector[indexMax] + - alpha_ * 2 + * (dataVector[indexMax] + - targetValueBirthPairToDelete[i]); + listAllIndicesToChangeSmoothing[indexMax] = 1; + } + } else { // finePairManagement_ == 1 + continue; + } + } + + for(size_t i = 0; i < indexDeathPairToDelete.size(); i++) { + lossDeletePairs += std::pow(dataVector[indexDeathPairToDelete[i]] + - targetValueDeathPairToDelete[i], + 2); + SimplexId indexSelle = indexDeathPairToDelete[i]; + + if(!(finePairManagement_ == 2)) { + if(constraintAveraging_) { + if(vertexInHowManyPairs[indexSelle] == 1) { + dataVector[indexSelle] + = dataVector[indexSelle] + - alpha_ * 2 + * 
(dataVector[indexSelle] + - targetValueDeathPairToDelete[i]); + listAllIndicesToChangeSmoothing[indexSelle] = 1; + } else { + vertexInCellMultiple[indexSelle] = 1; + vertexToTargetValue[indexSelle].push_back( + targetValueDeathPairToDelete[i]); + } + } else { + dataVector[indexSelle] + = dataVector[indexSelle] + - alpha_ * 2 + * (dataVector[indexSelle] + - targetValueDeathPairToDelete[i]); + listAllIndicesToChangeSmoothing[indexSelle] = 1; + } + } else { // finePairManagement_ == 2 + continue; + } + } + } + + this->printMsg("DirectGradientDescent - Loss Delete Pairs: " + + std::to_string(lossDeletePairs), + debug::Priority::PERFORMANCE); + //========================================================================== + // Retrieve the indices for the pairs that we want to change + //========================================================================== + double lossChangePairs = 0; + + std::vector &indexBirthPairToChange + = birthPairToChangeCurrentDiagram; + std::vector &targetValueBirthPairToChange + = birthPairToChangeTargetDiagram; + std::vector &indexDeathPairToChange + = deathPairToChangeCurrentDiagram; + std::vector &targetValueDeathPairToChange + = deathPairToChangeTargetDiagram; + + for(size_t i = 0; i < indexBirthPairToChange.size(); i++) { + lossChangePairs += std::pow(dataVector[indexBirthPairToChange[i]] + - targetValueBirthPairToChange[i], + 2) + + std::pow(dataVector[indexDeathPairToChange[i]] + - targetValueDeathPairToChange[i], + 2); + + SimplexId indexMax = indexBirthPairToChange[i]; + SimplexId indexSelle = indexDeathPairToChange[i]; + + if(constraintAveraging_) { + if(vertexInHowManyPairs[indexMax] == 1) { + dataVector[indexMax] + = dataVector[indexMax] + - alpha_ * 2 + * (dataVector[indexMax] - targetValueBirthPairToChange[i]); + listAllIndicesToChangeSmoothing[indexMax] = 1; + } else { + vertexInCellMultiple[indexMax] = 1; + vertexToTargetValue[indexMax].push_back( + targetValueBirthPairToChange[i]); + } + + 
if(vertexInHowManyPairs[indexSelle] == 1) { + dataVector[indexSelle] = dataVector[indexSelle] + - alpha_ * 2 + * (dataVector[indexSelle] + - targetValueDeathPairToChange[i]); + listAllIndicesToChangeSmoothing[indexSelle] = 1; + } else { + vertexInCellMultiple[indexSelle] = 1; + vertexToTargetValue[indexSelle].push_back( + targetValueDeathPairToChange[i]); + } + } else { + dataVector[indexMax] + = dataVector[indexMax] + - alpha_ * 2 + * (dataVector[indexMax] - targetValueBirthPairToChange[i]); + dataVector[indexSelle] + = dataVector[indexSelle] + - alpha_ * 2 + * (dataVector[indexSelle] - targetValueDeathPairToChange[i]); + listAllIndicesToChangeSmoothing[indexMax] = 1; + listAllIndicesToChangeSmoothing[indexSelle] = 1; + } + } + + this->printMsg("DirectGradientDescent - Loss Change Pairs: " + + std::to_string(lossChangePairs), + debug::Priority::PERFORMANCE); + + if(constraintAveraging_) { + for(SimplexId i = 0; i < (SimplexId)vertexInCellMultiple.size(); i++) { + double averageTargetValue = 0; + + if(vertexInCellMultiple[i] == 1) { + for(auto targetValue : vertexToTargetValue[i]) { + averageTargetValue += targetValue; + } + averageTargetValue + = averageTargetValue / (int)vertexToTargetValue[i].size(); + + dataVector[i] = dataVector[i] + - alpha_ * 2 * (dataVector[i] - averageTargetValue); + listAllIndicesToChangeSmoothing[i] = 1; + } + } + } + + //================================== + // Stop Condition + //================================== + + if(it == 0) { + stoppingCondition + = coefStopCondition_ * (lossDeletePairs + lossChangePairs); + } + + if(((lossDeletePairs + lossChangePairs) <= stoppingCondition)) + break; + } + +//======================================================== +// De-normalize data & Update output data +//======================================================== +#ifdef TTK_ENABLE_OPENMP +#pragma omp parallel for num_threads(threadNumber_) +#endif + for(SimplexId k = 0; k < vertexNumber_; ++k) { + outputScalars[k] = dataVector[k] * (maxVal - 
minVal) + minVal; + } + } + +//======================================= +// Adam Optimization +//======================================= +#ifdef TTK_ENABLE_TORCH + else if(methodOptimization_ == 1) { + //===================================================== + // Initialization of model parameters + //===================================================== + torch::Tensor F + = torch::from_blob(dataVector.data(), {SimplexId(dataVector.size())}, + torch::dtype(torch::kDouble)) + .to(torch::kDouble); + PersistenceGradientDescent model(F); + + torch::optim::Adam optimizer(model.parameters(), learningRate_); + + //======================================= + // Optimization + //======================================= + + std::vector> pair2MatchedPair( + constraintDiagram.size(), std::vector(2)); + std::vector pairChangeMatchingPair(constraintDiagram.size(), -1); + std::vector listAllIndicesToChange(vertexNumber_, 0); + std::vector> pair2Delete( + vertexNumber_, std::vector()); + std::vector> currentVertex2PairsCurrentDiagram( + vertexNumber_, std::vector()); + + for(int i = 0; i < epochNumber_; i++) { + + if(i % printFrequency_ == 0) { + debugLevel_ = 3; + } else { + debugLevel_ = 0; + } + + this->printMsg( + "Adam - epoch: " + std::to_string(i), debug::Priority::PERFORMANCE); + + ttk::Timer timeOneIteration; + + // Update the tensor with the new optimized values + tensorToVectorFast(model.X.to(torch::kDouble), inputScalarsX); + + // pairs to change + std::vector birthPairToChangeCurrentDiagram{}; + std::vector birthPairToChangeTargetDiagram{}; + std::vector deathPairToChangeCurrentDiagram{}; + std::vector deathPairToChangeTargetDiagram{}; + + // pairs to delete + std::vector birthPairToDeleteCurrentDiagram{}; + std::vector birthPairToDeleteTargetDiagram{}; + std::vector deathPairToDeleteCurrentDiagram{}; + std::vector deathPairToDeleteTargetDiagram{}; + + std::vector vertexInHowManyPairs(vertexNumber_, 0); + + // Retrieve the indices of the critical points that we must 
modify in + // order to match our current diagram to our target diagram. + getIndices( + triangulation, inputOffsetsCopie, inputScalarsX.data(), + normalizedConstraintDiagram, i, listAllIndicesToChange, + pair2MatchedPair, pair2Delete, pairChangeMatchingPair, + birthPairToDeleteCurrentDiagram, birthPairToDeleteTargetDiagram, + deathPairToDeleteCurrentDiagram, deathPairToDeleteTargetDiagram, + birthPairToChangeCurrentDiagram, birthPairToChangeTargetDiagram, + deathPairToChangeCurrentDiagram, deathPairToChangeTargetDiagram, + currentVertex2PairsCurrentDiagram, vertexInHowManyPairs); + + std::fill( + listAllIndicesToChange.begin(), listAllIndicesToChange.end(), 0); + //========================================================================== + // Retrieve the indices for the pairs that we want to send diagonally + //========================================================================== + + torch::Tensor valueOfXDeleteBirth = torch::index_select( + model.X, 0, torch::tensor(birthPairToDeleteCurrentDiagram)); + auto valueDeleteBirth = torch::from_blob( + birthPairToDeleteTargetDiagram.data(), + {static_cast(birthPairToDeleteTargetDiagram.size())}, + torch::kDouble); + torch::Tensor valueOfXDeleteDeath = torch::index_select( + model.X, 0, torch::tensor(deathPairToDeleteCurrentDiagram)); + auto valueDeleteDeath = torch::from_blob( + deathPairToDeleteTargetDiagram.data(), + {static_cast(deathPairToDeleteTargetDiagram.size())}, + torch::kDouble); + + torch::Tensor lossDeletePairs = torch::zeros({1}, torch::kDouble); + if(!(finePairManagement_ == 2) && !(finePairManagement_ == 1)) { + lossDeletePairs + = torch::sum(torch::pow(valueOfXDeleteBirth - valueDeleteBirth, 2)); + lossDeletePairs + = lossDeletePairs + + torch::sum(torch::pow(valueOfXDeleteDeath - valueDeleteDeath, 2)); + } else if(finePairManagement_ == 1) { + lossDeletePairs + = torch::sum(torch::pow(valueOfXDeleteDeath - valueDeleteDeath, 2)); + } else if(finePairManagement_ == 2) { + lossDeletePairs + = 
torch::sum(torch::pow(valueOfXDeleteBirth - valueDeleteBirth, 2)); + } + + this->printMsg("Adam - Loss Delete Pairs: " + + std::to_string(lossDeletePairs.item()), + debug::Priority::PERFORMANCE); + + //========================================================================== + // Retrieve the indices for the pairs that we want to change + //========================================================================== + + torch::Tensor valueOfXChangeBirth = torch::index_select( + model.X, 0, torch::tensor(birthPairToChangeCurrentDiagram)); + auto valueChangeBirth = torch::from_blob( + birthPairToChangeTargetDiagram.data(), + {static_cast(birthPairToChangeTargetDiagram.size())}, + torch::kDouble); + torch::Tensor valueOfXChangeDeath = torch::index_select( + model.X, 0, torch::tensor(deathPairToChangeCurrentDiagram)); + auto valueChangeDeath = torch::from_blob( + deathPairToChangeTargetDiagram.data(), + {static_cast(deathPairToChangeTargetDiagram.size())}, + torch::kDouble); + + auto lossChangePairs + = torch::sum((torch::pow(valueOfXChangeBirth - valueChangeBirth, 2) + + torch::pow(valueOfXChangeDeath - valueChangeDeath, 2))); + + this->printMsg("Adam - Loss Change Pairs: " + + std::to_string(lossChangePairs.item()), + debug::Priority::PERFORMANCE); + + //==================================== + // Definition of final loss + //==================================== + + auto loss = lossDeletePairs + lossChangePairs; + + this->printMsg("Adam - Loss: " + std::to_string(loss.item()), + debug::Priority::PERFORMANCE); + + //========================================== + // Back Propagation + //========================================== + + losses.push_back(loss.item()); + + ttk::Timer timeBackPropagation; + optimizer.zero_grad(); + loss.backward(); + optimizer.step(); + + //========================================== + // Modified index checking + //========================================== + + // On trouve les indices qui ont changĂ© + std::vector NewinputScalarsX(vertexNumber_); 
+ tensorToVectorFast(model.X.to(torch::kDouble), NewinputScalarsX); + +#ifdef TTK_ENABLE_OPENMP +#pragma omp parallel for num_threads(threadNumber_) +#endif + for(SimplexId k = 0; k < vertexNumber_; ++k) { + double diff = NewinputScalarsX[k] - inputScalarsX[k]; + if(diff != 0) { + listAllIndicesToChange[k] = 1; + } + } + + //======================================= + // Stop condition + //======================================= + if(i == 0) { + stoppingCondition = coefStopCondition_ * loss.item(); + } + + if(loss.item() < stoppingCondition) + break; + } + +//============================================ +// De-normalize data & Update output data +//============================================ +#ifdef TTK_ENABLE_OPENMP +#pragma omp parallel for num_threads(threadNumber_) +#endif + for(SimplexId k = 0; k < vertexNumber_; ++k) { + outputScalars[k] + = model.X[k].item().to() * (maxVal - minVal) + minVal; + if(std::isnan((double)outputScalars[k])) + outputScalars[k] = 0; + } + } +#endif + //======================================== + // Information display + //======================================== + debugLevel_ = 3; + + // Total execution time + double time = t.getElapsedTime(); + + // Number Pairs Constraint Diagram + SimplexId numberPairsConstraintDiagram = (SimplexId)constraintDiagram.size(); + this->printMsg("Number of constrained pairs: " + + std::to_string(numberPairsConstraintDiagram), + debug::Priority::PERFORMANCE); + + this->printMsg("Stopping condition: " + std::to_string(stoppingCondition), + debug::Priority::PERFORMANCE); + + this->printMsg("Scalar field optimized", 1.0, time, this->threadNumber_); + + return 0; +} diff --git a/core/base/topologicalSimplification/CMakeLists.txt b/core/base/topologicalSimplification/CMakeLists.txt index 96a445acbd..74883e8f4b 100644 --- a/core/base/topologicalSimplification/CMakeLists.txt +++ b/core/base/topologicalSimplification/CMakeLists.txt @@ -7,4 +7,5 @@ ttk_add_base_library(topologicalSimplification triangulation 
legacyTopologicalSimplification localizedTopologicalSimplification + topologicalOptimization ) diff --git a/core/base/topologicalSimplification/TopologicalSimplification.h b/core/base/topologicalSimplification/TopologicalSimplification.h index d0c817d764..2b518ee768 100644 --- a/core/base/topologicalSimplification/TopologicalSimplification.h +++ b/core/base/topologicalSimplification/TopologicalSimplification.h @@ -16,13 +16,18 @@ /// \b Related \b publications \n /// "Generalized Topological Simplification of Scalar Fields on Surfaces" \n /// Julien Tierny, Valerio Pascucci \n -/// Proc. of IEEE VIS 2012.\n -/// IEEE Transactions on Visualization and Computer Graphics, 2012. +/// IEEE Transactions on Visualization and Computer Graphics.\n +/// Proc. of IEEE VIS 2012. /// -/// "Localized Topological Simplification of Scalar Data" -/// Jonas Lukasczyk, Christoph Garth, Ross Maciejewski, Julien Tierny +/// "Localized Topological Simplification of Scalar Data" \n +/// Jonas Lukasczyk, Christoph Garth, Ross Maciejewski, Julien Tierny \n +/// IEEE Transactions on Visualization and Computer Graphics.\n /// Proc. of IEEE VIS 2020. -/// IEEE Transactions on Visualization and Computer Graphics +/// +/// "A Practical Solver for Scalar Data Topological Simplification"\n +/// Mohamed Kissi, Mathieu Pont, Joshua A. Levine, Julien Tierny\n +/// IEEE Transactions on Visualization and Computer Graphics.\n +/// Proc. of IEEE VIS 2024. /// /// \sa ttkTopologicalSimplification.cpp %for a usage example. /// @@ -92,9 +97,11 @@ #pragma once // base code includes + #include #include #include +#include #include #include @@ -108,7 +115,7 @@ namespace ttk { public: TopologicalSimplification(); - enum class BACKEND { LEGACY, LTS }; + enum class BACKEND { LEGACY, LTS, TO }; /* * Either execute this file "legacy" algorithm, or the * lts algorithm. The choice depends on the value of the variable backend_. 
@@ -122,7 +129,8 @@ namespace ttk { SimplexId *const offsets, const SimplexId constraintNumber, const bool addPerturbation, - const triangulationType &triangulation); + triangulationType &triangulation, + const ttk::DiagramType &constraintDiagram = {}); inline void setBackend(const BACKEND arg) { backend_ = arg; @@ -142,6 +150,12 @@ namespace ttk { ltsObject_.preconditionTriangulation(triangulation); break; + case BACKEND::TO: + topologyOptimizer_.setDebugLevel(debugLevel_); + topologyOptimizer_.setThreadNumber(threadNumber_); + topologyOptimizer_.preconditionTriangulation(triangulation); + break; + default: this->printErr( "Error, the backend for topological simplification is invalid"); @@ -154,6 +168,48 @@ namespace ttk { BACKEND backend_{BACKEND::LTS}; LegacyTopologicalSimplification legacyObject_; lts::LocalizedTopologicalSimplification ltsObject_; + ttk::TopologicalOptimization topologyOptimizer_; + + SimplexId vertexNumber_{}; + bool UseFastPersistenceUpdate{true}; + bool FastAssignmentUpdate{true}; + int EpochNumber{1000}; + + // if PDCMethod == 0 then we use Progressive approach + // if PDCMethod == 1 then we use Classical Auction approach + int PDCMethod{1}; + + // if MethodOptimization == 0 then we use direct optimization + // if MethodOptimization == 1 then we use Adam + int MethodOptimization{0}; + + // if FinePairManagement == 0 then we let the algorithm choose + // if FinePairManagement == 1 then we fill the domain + // if FinePairManagement == 2 then we cut the domain + int FinePairManagement{0}; + + // Adam + bool ChooseLearningRate{false}; + double LearningRate{0.0001}; + + // Direct Optimization : Gradient Step Size + double Alpha{0.5}; + + // Stopping criterion: when the loss becomes less than a percentage (e.g. 
+ // 1%) of the original loss (between input diagram and simplified diagram) + double CoefStopCondition{0.01}; + + // + bool OptimizationWithoutMatching{false}; + int ThresholdMethod{1}; + double Threshold{0.01}; + int LowerThreshold{-1}; + int UpperThreshold{2}; + int PairTypeToDelete{1}; + + bool ConstraintAveraging{true}; + + int PrintFrequency{10}; }; } // namespace ttk @@ -166,7 +222,8 @@ int ttk::TopologicalSimplification::execute( SimplexId *const offsets, const SimplexId constraintNumber, const bool addPerturbation, - const triangulationType &triangulation) { + triangulationType &triangulation, + const ttk::DiagramType &constraintDiagram) { switch(backend_) { case BACKEND::LTS: return ltsObject_ @@ -178,6 +235,29 @@ int ttk::TopologicalSimplification::execute( inputOffsets, offsets, constraintNumber, triangulation); + case BACKEND::TO: + topologyOptimizer_.setUseFastPersistenceUpdate(UseFastPersistenceUpdate); + topologyOptimizer_.setFastAssignmentUpdate(FastAssignmentUpdate); + topologyOptimizer_.setEpochNumber(EpochNumber); + topologyOptimizer_.setPDCMethod(PDCMethod); + topologyOptimizer_.setMethodOptimization(MethodOptimization); + topologyOptimizer_.setFinePairManagement(FinePairManagement); + topologyOptimizer_.setChooseLearningRate(ChooseLearningRate); + topologyOptimizer_.setLearningRate(LearningRate); + topologyOptimizer_.setAlpha(Alpha); + topologyOptimizer_.setCoefStopCondition(CoefStopCondition); + topologyOptimizer_.setOptimizationWithoutMatching( + OptimizationWithoutMatching); + topologyOptimizer_.setThresholdMethod(ThresholdMethod); + topologyOptimizer_.setThresholdPersistence(Threshold); + topologyOptimizer_.setLowerThreshold(LowerThreshold); + topologyOptimizer_.setUpperThreshold(UpperThreshold); + topologyOptimizer_.setPairTypeToDelete(PairTypeToDelete); + topologyOptimizer_.setConstraintAveraging(ConstraintAveraging); + topologyOptimizer_.setPrintFrequency(PrintFrequency); + + return topologyOptimizer_.execute(inputScalars, 
outputScalars, offsets, + &triangulation, constraintDiagram); default: this->printErr( "Error, the backend for topological simplification is invalid"); diff --git a/core/vtk/ttkTopologicalSimplification/ttk.module b/core/vtk/ttkTopologicalSimplification/ttk.module index 70ca33cdcc..a800dd40f3 100644 --- a/core/vtk/ttkTopologicalSimplification/ttk.module +++ b/core/vtk/ttkTopologicalSimplification/ttk.module @@ -7,3 +7,4 @@ HEADERS DEPENDS topologicalSimplification ttkAlgorithm + ttkPersistenceDiagram \ No newline at end of file diff --git a/core/vtk/ttkTopologicalSimplification/ttkTopologicalSimplification.cpp b/core/vtk/ttkTopologicalSimplification/ttkTopologicalSimplification.cpp index ceb3695e2f..34429f0b2f 100644 --- a/core/vtk/ttkTopologicalSimplification/ttkTopologicalSimplification.cpp +++ b/core/vtk/ttkTopologicalSimplification/ttkTopologicalSimplification.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include @@ -25,7 +26,7 @@ int ttkTopologicalSimplification::FillInputPortInformation( info->Set(vtkAlgorithm::INPUT_REQUIRED_DATA_TYPE(), "vtkDataSet"); return 1; } else if(port == 1) { - info->Set(vtkAlgorithm::INPUT_REQUIRED_DATA_TYPE(), "vtkPointSet"); + info->Set(vtkAlgorithm::INPUT_REQUIRED_DATA_TYPE(), "vtkUnstructuredGrid"); return 1; } return 0; @@ -47,13 +48,16 @@ int ttkTopologicalSimplification::RequestData( using ttk::SimplexId; - // Warning: this needs to be done before the preconditioning. 
- if(!this->UseLTS) { + if(this->Method == 0) { + this->setBackend(BACKEND::LTS); + } else if(this->Method == 1) { this->setBackend(BACKEND::LEGACY); + } else if(this->Method == 2) { + this->setBackend(BACKEND::TO); } const auto domain = vtkDataSet::GetData(inputVector[0]); - const auto constraints = vtkPointSet::GetData(inputVector[1]); + const auto constraints = vtkUnstructuredGrid::GetData(inputVector[1]); if(!domain || !constraints) return !this->printErr("Unable to retrieve required input data objects."); @@ -100,6 +104,11 @@ int ttkTopologicalSimplification::RequestData( return -1; } + // Constraints + ttk::DiagramType constraintDiagram; + const ttk::Debug dbg; + VTUToDiagram(constraintDiagram, constraints, dbg); + // create output arrays auto outputScalars = vtkSmartPointer::Take(inputScalars->NewInstance()); @@ -124,7 +133,7 @@ int ttkTopologicalSimplification::RequestData( ttkUtils::GetPointer(inputOrder), ttkUtils::GetPointer(outputOrder), numberOfConstraints, this->AddPerturbation, - *triangulation->getData())); + *triangulation->getData(), constraintDiagram)); } // something wrong in baseCode diff --git a/core/vtk/ttkTopologicalSimplification/ttkTopologicalSimplification.h b/core/vtk/ttkTopologicalSimplification/ttkTopologicalSimplification.h index 63509c13f3..c4c666eb9c 100644 --- a/core/vtk/ttkTopologicalSimplification/ttkTopologicalSimplification.h +++ b/core/vtk/ttkTopologicalSimplification/ttkTopologicalSimplification.h @@ -38,13 +38,18 @@ /// \b Related \b publications \n /// "Generalized Topological Simplification of Scalar Fields on Surfaces" \n /// Julien Tierny, Valerio Pascucci \n -/// Proc. of IEEE VIS 2012.\n -/// IEEE Transactions on Visualization and Computer Graphics, 2012. +/// IEEE Transactions on Visualization and Computer Graphics.\n +/// Proc. of IEEE VIS 2012. 
/// -/// "Localized Topological Simplification of Scalar Data" -/// Jonas Lukasczyk, Christoph Garth, Ross Maciejewski, Julien Tierny +/// "Localized Topological Simplification of Scalar Data" \n +/// Jonas Lukasczyk, Christoph Garth, Ross Maciejewski, Julien Tierny \n +/// IEEE Transactions on Visualization and Computer Graphics.\n /// Proc. of IEEE VIS 2020. -/// IEEE Transactions on Visualization and Computer Graphics +/// +/// "A Practical Solver for Scalar Data Topological Simplification"\n +/// Mohamed Kissi, Mathieu Pont, Joshua A. Levine, Julien Tierny\n +/// IEEE Transactions on Visualization and Computer Graphics.\n +/// Proc. of IEEE VIS 2024. /// /// \sa ttkTopologicalSimplificationByPersistence /// \sa ttkScalarFieldCriticalPoints @@ -126,6 +131,8 @@ // ttk code includes #include #include +#include +#include class vtkDataArray; @@ -149,12 +156,66 @@ class TTKTOPOLOGICALSIMPLIFICATION_EXPORT ttkTopologicalSimplification vtkSetMacro(ForceInputVertexScalarField, bool); vtkGetMacro(ForceInputVertexScalarField, bool); - vtkSetMacro(UseLTS, bool); - vtkGetMacro(UseLTS, bool); + vtkSetMacro(Method, int); + vtkGetMacro(Method, int); vtkSetMacro(PersistenceThreshold, double); vtkGetMacro(PersistenceThreshold, double); + vtkSetMacro(UseFastPersistenceUpdate, bool); + vtkGetMacro(UseFastPersistenceUpdate, bool); + + vtkSetMacro(FastAssignmentUpdate, bool); + vtkGetMacro(FastAssignmentUpdate, bool); + + vtkSetMacro(EpochNumber, int); + vtkGetMacro(EpochNumber, int); + + vtkSetMacro(PDCMethod, int); + vtkGetMacro(PDCMethod, int); + + vtkSetMacro(MethodOptimization, int); + vtkGetMacro(MethodOptimization, int); + + vtkSetMacro(FinePairManagement, int); + vtkGetMacro(FinePairManagement, int); + + vtkSetMacro(ChooseLearningRate, bool); + vtkGetMacro(ChooseLearningRate, bool); + + vtkSetMacro(LearningRate, double); + vtkGetMacro(LearningRate, double); + + vtkSetMacro(Alpha, double); + vtkGetMacro(Alpha, double); + + vtkSetMacro(CoefStopCondition, double); + 
vtkGetMacro(CoefStopCondition, double); + + vtkSetMacro(OptimizationWithoutMatching, bool); + vtkGetMacro(OptimizationWithoutMatching, bool); + + vtkSetMacro(ThresholdMethod, int); + vtkGetMacro(ThresholdMethod, int); + + vtkSetMacro(Threshold, double); + vtkGetMacro(Threshold, double); + + vtkSetMacro(LowerThreshold, int); + vtkGetMacro(LowerThreshold, int); + + vtkSetMacro(UpperThreshold, int); + vtkGetMacro(UpperThreshold, int); + + vtkSetMacro(PairTypeToDelete, int); + vtkGetMacro(PairTypeToDelete, int); + + vtkSetMacro(ConstraintAveraging, bool); + vtkGetMacro(ConstraintAveraging, bool); + + vtkSetMacro(PrintFrequency, int); + vtkGetMacro(PrintFrequency, int); + protected: ttkTopologicalSimplification(); @@ -169,6 +230,6 @@ class TTKTOPOLOGICALSIMPLIFICATION_EXPORT ttkTopologicalSimplification bool ForceInputOffsetScalarField{false}; bool ConsiderIdentifierAsBlackList{false}; bool AddPerturbation{false}; - bool UseLTS{true}; + int Method{0}; double PersistenceThreshold{0}; }; diff --git a/paraview/xmls/TopologicalSimplification.xml b/paraview/xmls/TopologicalSimplification.xml index 61c13ee7bb..dd8dd1e8b3 100644 --- a/paraview/xmls/TopologicalSimplification.xml +++ b/paraview/xmls/TopologicalSimplification.xml @@ -30,13 +30,18 @@ Related publications: "Generalized Topological Simplification of Scalar Fields on Surfaces" Julien Tierny, Valerio Pascucci + IEEE Transactions on Visualization and Computer Graphics. Proc. of IEEE VIS 2012. - IEEE Transactions on Visualization and Computer Graphics, 2012. "Localized Topological Simplification of Scalar Data" Jonas Lukasczyk, Christoph Garth, Ross Maciejewski, Julien Tierny + IEEE Transactions on Visualization and Computer Graphics. Proc. of IEEE VIS 2020. - IEEE Transactions on Visualization and Computer Graphics + + "A Practical Solver for Scalar Data Topological Simplification" + Mohamed Kissi, Mathieu Pont, Joshua A. Levine, Julien Tierny + IEEE Transactions on Visualization and Computer Graphics. + Proc. 
of IEEE VIS 2024. See also ScalarFieldCriticalPoints, IntegralLines, ContourForests, Identifiers. @@ -52,7 +57,7 @@ Identifiers. - https://topology-tool-kit.github.io/examples/BuiltInExample1/ - https://topology-tool-kit.github.io/examples/contourTreeAlignment/ - + - https://topology-tool-kit.github.io/examples/ctBones/ - https://topology-tool-kit.github.io/examples/dragon/ @@ -66,7 +71,7 @@ Identifiers. - https://topology-tool-kit.github.io/examples/karhunenLoveDigits64Dimensions/ - https://topology-tool-kit.github.io/examples/morsePersistence/ - + - https://topology-tool-kit.github.io/examples/morseSmaleQuadrangulation/ - https://topology-tool-kit.github.io/examples/persistenceClustering0/ @@ -82,7 +87,7 @@ Identifiers. - https://topology-tool-kit.github.io/examples/tectonicPuzzle/ - https://topology-tool-kit.github.io/examples/tribute/ - + - https://topology-tool-kit.github.io/examples/uncertainStartingVortex/ @@ -143,6 +148,23 @@ Identifiers. + + + + + + + + Choose the simplification algorithm. + + + - + + + + + Check this box to force the usage of a specific input scalar field as vertex offset (used to disambiguate flat plateaus). @@ -185,6 +214,7 @@ Identifiers. + + + + + @@ -229,34 +266,483 @@ Identifiers. command="SetConsiderIdentifierAsBlackList" label="Remove selected extrema" number_of_elements="1" - default_values="0"> + default_values="0" + panel_visibility="advanced" + > + + + + Check this box to remove the selected extrema (instead of removing all non-selected extrema). + + + + + + + + + + + Employed backend for gradient descent. Direct gradient descent provides superior time performance with regard to automatic differentiation with Adam. + + + + + + + + + + + + + + Backend for Wasserstein distance computation. + The Auction algorithm is computationally more expensive than + the progressive approach, but more accurate. + + + + + + + + + + + Check this box to use fast persistence update (i.e. 
the persistence diagram will not be completely recomputed from scratch but only the required information will be updated). + + + + + + + + + + + Check this box to use the fast assignment update (i.e. persistence pairs which are still present between consecutive iterations will maintain their assignments). + + + + + + + + + + + Coefficient used in the stopping condition of the algorithm: if the fraction between the current loss and the original loss (between the input and simplified diagrams) is smaller than this coefficient, the algorithm stops. + + + + + + + + + + + + + Maximum Iteration Number (if the stopping condition has not been satisfied yet). + + + + + + + + + + + + + + + + + Select the persistence pair cancellation primitive. For illustration, for pairs associated to topological handles of the sublevel sets, the primitive Fill-only will destroy a handle by filling a disc in its inside (only the death gradient is used). Cut-only will cut the handle (only the birth gradient is used). Fill and Cut will produce a compromise between the two (both birth and death gradients are used). + + + + + + + + + + + + + + Check this box to choose learning rate. + + + + + + + + + + + + + + Learning Rate. + + + + + + + + + + + + + Choose the gradient step size. + + + + + + + + + + If a given vertex is involved in both signal pairs (i.e. pairs to maintain) and non-signal pairs (i.e. pairs to remove), average the contributions of the constraints (otherwise, the vertex value will not change). + + + + + + + + + + + Enable on-line ad-hoc simplification (i.e., specify the non-signal pairs and the non-signal pairs will be re-evaluated at each iteration). Faster but less precise and more restrictive. + + + + + + + + + + + + + + - Use the Localized Topological Simplification algorithm. + . + + + + + + + + + + + Threshold value. + + + + + + + + + + + + + + Lower Threshold value. + + + + + + + + + + + + + + Upper Threshold value. + + + + + + + + + + + + + + Pair type to delete value. 
+ + + + + + + + + + A print is made every PrintFrequency iterations. + + + + default_values="0" + panel_visibility="advanced" + > + + + + Numerically perturb the output (to avoid the usage of an output offset field for flat plateau disambiguation). @@ -267,6 +753,7 @@ Identifiers. + @@ -274,8 +761,23 @@ Identifiers. + + + + + + + + + + + + + + + + -