From 0e2a86035b8be2ee4f0b82ca683db4f79585e2dd Mon Sep 17 00:00:00 2001 From: MatPont Date: Tue, 17 Sep 2024 15:20:36 +0200 Subject: [PATCH] [MTNN] merge tree neural network refactor --- core/base/ftmTree/FTMNode.h | 16 + core/base/ftmTree/FTMTreeUtils.cpp | 78 +- core/base/ftmTree/FTMTreeUtils.h | 6 +- core/base/ftmTree/FTMTreeUtils_Template.h | 57 +- core/base/ftmTree/FTMTree_MT.h | 133 +- core/base/mergeTreeAutoencoder/CMakeLists.txt | 5 +- .../MergeTreeAutoencoder.cpp | 2446 ++--------------- .../MergeTreeAutoencoder.h | 432 +-- .../MergeTreeAutoencoderUtils.cpp | 394 --- .../MergeTreeAutoencoderUtils.h | 66 - .../MergeTreeAutoencoderDecoding.cpp | 35 +- .../mergeTreeClustering/MergeTreeBarycenter.h | 147 +- core/base/mergeTreeClustering/MergeTreeBase.h | 92 +- .../mergeTreeClustering/MergeTreeDistance.h | 66 +- .../base/mergeTreeClustering/MergeTreeUtils.h | 12 +- .../mergeTreeNeuralNetwork/CMakeLists.txt | 22 + .../MergeTreeNeuralBase.cpp | 319 +++ .../MergeTreeNeuralBase.h | 175 ++ .../MergeTreeNeuralLayer.cpp | 1317 +++++++++ .../MergeTreeNeuralLayer.h | 604 ++++ .../MergeTreeNeuralNetwork.cpp | 1324 +++++++++ .../MergeTreeNeuralNetwork.h | 880 ++++++ .../MergeTreeTorchUtils.cpp | 42 +- .../MergeTreeTorchUtils.h | 32 +- .../MergeTreeAxesAlgorithmBase.cpp | 30 +- .../MergeTreeAxesAlgorithmBase.h | 154 +- .../MergeTreeAxesAlgorithmUtils.cpp | 14 + .../MergeTreeAxesAlgorithmUtils.h | 74 + .../MergeTreePrincipalGeodesics.h | 5 +- core/vtk/ttkMergeTreeAutoencoder/ttk.module | 4 +- .../ttkMergeTreeAutoencoder.cpp | 346 +-- .../ttkMergeTreeAutoencoder.h | 9 - ...cpp => ttkMergeTreeNeuralNetworkUtils.cpp} | 381 ++- ...ils.h => ttkMergeTreeNeuralNetworkUtils.h} | 92 +- .../ttk.module | 2 +- .../ttkMergeTreeAutoencoderDecoding.cpp | 40 +- .../ttkMergeTreePrincipalGeodesics.cpp | 2 +- ...ttkMergeTreePrincipalGeodesicsDecoding.cpp | 6 +- 38 files changed, 6135 insertions(+), 3724 deletions(-) delete mode 100644 core/base/mergeTreeAutoencoder/MergeTreeAutoencoderUtils.cpp delete mode 100644 core/base/mergeTreeAutoencoder/MergeTreeAutoencoderUtils.h create mode 100644 core/base/mergeTreeNeuralNetwork/CMakeLists.txt create mode 100644 core/base/mergeTreeNeuralNetwork/MergeTreeNeuralBase.cpp create mode 100644 core/base/mergeTreeNeuralNetwork/MergeTreeNeuralBase.h create mode 100644 core/base/mergeTreeNeuralNetwork/MergeTreeNeuralLayer.cpp create mode 100644 core/base/mergeTreeNeuralNetwork/MergeTreeNeuralLayer.h create mode 100644 core/base/mergeTreeNeuralNetwork/MergeTreeNeuralNetwork.cpp create mode 100644 core/base/mergeTreeNeuralNetwork/MergeTreeNeuralNetwork.h rename core/base/{mergeTreeAutoencoder => mergeTreeNeuralNetwork}/MergeTreeTorchUtils.cpp (85%) rename core/base/{mergeTreeAutoencoder => mergeTreeNeuralNetwork}/MergeTreeTorchUtils.h (95%) rename core/vtk/ttkMergeTreeAutoencoder/{ttkMergeTreeAutoencoderUtils.cpp => ttkMergeTreeNeuralNetworkUtils.cpp} (50%) rename core/vtk/ttkMergeTreeAutoencoder/{ttkMergeTreeAutoencoderUtils.h => ttkMergeTreeNeuralNetworkUtils.h} (73%) diff --git a/core/base/ftmTree/FTMNode.h b/core/base/ftmTree/FTMNode.h index a15b07243f..c8eefa679f 100644 --- a/core/base/ftmTree/FTMNode.h +++ b/core/base/ftmTree/FTMNode.h @@ -154,6 +154,22 @@ namespace ttk { } } + inline void removeDownSuperArcs(std::vector &idSa) { + if(idSa.empty()) + return; + std::vector toDelete( + (*std::max_element(idSa.begin(), idSa.end())) + 1, false); + for(auto &id : idSa) + toDelete[id] = true; + vect_downSuperArcList_.erase( + std::remove_if(vect_downSuperArcList_.begin(), + vect_downSuperArcList_.end(), + [&toDelete](const idSuperArc &i) { + return i < toDelete.size() and toDelete[i]; + }), + vect_downSuperArcList_.end()); + } + // Find and remove the arc inline void removeUpSuperArc(idSuperArc idSa) { for(idSuperArc i = 0; i < vect_upSuperArcList_.size(); ++i) { diff --git a/core/base/ftmTree/FTMTreeUtils.cpp b/core/base/ftmTree/FTMTreeUtils.cpp index 24d44a9e9d..dff0bb508c 100644 --- a/core/base/ftmTree/FTMTreeUtils.cpp +++ b/core/base/ftmTree/FTMTreeUtils.cpp @@ -13,34 +13,34 @@ namespace ttk { // -------------------- // Is // -------------------- - bool FTMTree_MT::isNodeOriginDefined(idNode nodeId) { + bool FTMTree_MT::isNodeOriginDefined(idNode nodeId) const { unsigned int const origin = (unsigned int)this->getNode(nodeId)->getOrigin(); return origin != nullNodes && origin < this->getNumberOfNodes(); } - bool FTMTree_MT::isRoot(idNode nodeId) { + bool FTMTree_MT::isRoot(idNode nodeId) const { return this->getNode(nodeId)->getNumberOfUpSuperArcs() == 0; } - bool FTMTree_MT::isLeaf(idNode nodeId) { + bool FTMTree_MT::isLeaf(idNode nodeId) const { return this->getNode(nodeId)->getNumberOfDownSuperArcs() == 0; } - bool FTMTree_MT::isNodeAlone(idNode nodeId) { + bool FTMTree_MT::isNodeAlone(idNode nodeId) const { return this->isRoot(nodeId) and this->isLeaf(nodeId); } - bool FTMTree_MT::isFullMerge() { + bool FTMTree_MT::isFullMerge() const { idNode const treeRoot = this->getRoot(); return (unsigned int)this->getNode(treeRoot)->getOrigin() == treeRoot; } - bool FTMTree_MT::isBranchOrigin(idNode nodeId) { + bool FTMTree_MT::isBranchOrigin(idNode nodeId) const { return this->getParentSafe(this->getNode(nodeId)->getOrigin()) != nodeId; } - bool FTMTree_MT::isNodeMerged(idNode nodeId) { + bool FTMTree_MT::isNodeMerged(idNode nodeId) const { bool merged = this->isNodeAlone(nodeId) or this->isNodeAlone(this->getNode(nodeId)->getOrigin()); auto nodeIdOrigin = this->getNode(nodeId)->getOrigin(); @@ -49,11 +49,11 @@ namespace ttk { return merged; } - bool FTMTree_MT::isNodeIdInconsistent(idNode nodeId) { + bool FTMTree_MT::isNodeIdInconsistent(idNode nodeId) const { return nodeId >= this->getNumberOfNodes(); } - bool FTMTree_MT::isThereOnlyOnePersistencePair() { + bool FTMTree_MT::isThereOnlyOnePersistencePair() const { idNode const treeRoot = this->getRoot(); unsigned int cptNodeAlone = 0; idNode otherNode = treeRoot; @@ -74,7 +74,7 @@ namespace ttk { } // Do not normalize node is if root or son of a merged root - bool FTMTree_MT::notNeedToNormalize(idNode nodeId) { + bool FTMTree_MT::notNeedToNormalize(idNode nodeId) const { auto nodeIdParent = this->getParentSafe(nodeId); return this->isRoot(nodeId) or (this->isRoot(nodeIdParent) @@ -84,7 +84,7 @@ namespace ttk { // and nodeIdOrigin == nodeIdParent) ) } - bool FTMTree_MT::isMultiPersPair(idNode nodeId) { + bool FTMTree_MT::isMultiPersPair(idNode nodeId) const { auto nodeOriginOrigin = (unsigned int)this->getNode(this->getNode(nodeId)->getOrigin()) ->getOrigin(); @@ -94,14 +94,14 @@ namespace ttk { // -------------------- // Get // -------------------- - idNode FTMTree_MT::getRoot() { + idNode FTMTree_MT::getRoot() const { for(idNode node = 0; node < this->getNumberOfNodes(); ++node) if(this->isRoot(node) and !this->isLeaf(node)) return node; return nullNodes; } - idNode FTMTree_MT::getParentSafe(idNode nodeId) { + idNode FTMTree_MT::getParentSafe(idNode nodeId) const { if(!this->isRoot(nodeId)) { // _ Nodes in merge trees should have only one parent idSuperArc const arcId = this->getNode(nodeId)->getUpSuperArcId(0); @@ -112,7 +112,7 @@ namespace ttk { } void FTMTree_MT::getChildren(idNode nodeId, - std::vector &childrens) { + std::vector &childrens) const { childrens.clear(); for(idSuperArc i = 0; i < this->getNode(nodeId)->getNumberOfDownSuperArcs(); ++i) { @@ -121,7 +121,7 @@ namespace ttk { } } - void FTMTree_MT::getLeavesFromTree(std::vector &treeLeaves) { + void FTMTree_MT::getLeavesFromTree(std::vector &treeLeaves) const { treeLeaves.clear(); for(idNode i = 0; i < this->getNumberOfNodes(); ++i) { if(this->isLeaf(i) and !this->isRoot(i)) @@ -129,25 +129,26 @@ namespace ttk { } } - int FTMTree_MT::getNumberOfLeavesFromTree() { + int FTMTree_MT::getNumberOfLeavesFromTree() const { std::vector leaves; this->getLeavesFromTree(leaves); return leaves.size(); } - int FTMTree_MT::getNumberOfNodeAlone() { + int FTMTree_MT::getNumberOfNodeAlone() const { int cpt = 0; for(idNode i = 0; i < this->getNumberOfNodes(); ++i) cpt += this->isNodeAlone(i) ? 1 : 0; return cpt; } - int FTMTree_MT::getRealNumberOfNodes() { + int FTMTree_MT::getRealNumberOfNodes() const { return this->getNumberOfNodes() - this->getNumberOfNodeAlone(); } void FTMTree_MT::getBranchOriginsFromThisBranch( - idNode node, std::tuple, std::vector> &res) { + idNode node, + std::tuple, std::vector> &res) const { std::vector branchOrigins, nonBranchOrigins; idNode const nodeOrigin = this->getNode(node)->getOrigin(); @@ -166,7 +167,7 @@ namespace ttk { void FTMTree_MT::getTreeBranching( std::vector &branching, std::vector &branchingID, - std::vector> &nodeBranching) { + std::vector> &nodeBranching) const { branching = std::vector(this->getNumberOfNodes()); branchingID = std::vector(this->getNumberOfNodes(), -1); nodeBranching @@ -200,19 +201,19 @@ namespace ttk { } void FTMTree_MT::getTreeBranching(std::vector &branching, - std::vector &branchingID) { + std::vector &branchingID) const { std::vector> nodeBranching; this->getTreeBranching(branching, branchingID, nodeBranching); } - void FTMTree_MT::getAllRoots(std::vector &roots) { + void FTMTree_MT::getAllRoots(std::vector &roots) const { roots.clear(); for(idNode node = 0; node < this->getNumberOfNodes(); ++node) if(this->isRoot(node) and !this->isLeaf(node)) roots.push_back(node); } - int FTMTree_MT::getNumberOfRoot() { + int FTMTree_MT::getNumberOfRoot() const { int noRoot = 0; for(idNode node = 0; node < this->getNumberOfNodes(); ++node) if(this->isRoot(node) and !this->isLeaf(node)) @@ -220,11 +221,11 @@ namespace ttk { return noRoot; } - int FTMTree_MT::getNumberOfChildren(idNode nodeId) { + int FTMTree_MT::getNumberOfChildren(idNode nodeId) const { return this->getNode(nodeId)->getNumberOfDownSuperArcs(); } - int FTMTree_MT::getTreeDepth() { + int FTMTree_MT::getTreeDepth() const { int maxDepth = 0; std::queue> queue; queue.emplace(this->getRoot(), 0); @@ -242,7 +243,7 @@ namespace ttk { return maxDepth; } - int FTMTree_MT::getNodeLevel(idNode nodeId) { + int FTMTree_MT::getNodeLevel(idNode nodeId) const { int level = 0; auto root = this->getRoot(); int const noRoot = this->getNumberOfRoot(); @@ -261,7 +262,7 @@ namespace ttk { return level; } - void FTMTree_MT::getAllNodeLevel(std::vector &allNodeLevel) { + void FTMTree_MT::getAllNodeLevel(std::vector &allNodeLevel) const { allNodeLevel = std::vector(this->getNumberOfNodes()); std::queue> queue; queue.emplace(this->getRoot(), 0); @@ -279,7 +280,7 @@ namespace ttk { } void FTMTree_MT::getLevelToNode( - std::vector> &levelToNode) { + std::vector> &levelToNode) const { std::vector allNodeLevel; this->getAllNodeLevel(allNodeLevel); int const maxLevel @@ -290,9 +291,10 @@ namespace ttk { } } - void FTMTree_MT::getBranchSubtree(std::vector &branching, - idNode branchRoot, - std::vector &branchSubtree) { + void + FTMTree_MT::getBranchSubtree(std::vector &branching, + idNode branchRoot, + std::vector &branchSubtree) const { branchSubtree.clear(); std::queue queue; queue.push(branchRoot); @@ -316,7 +318,7 @@ namespace ttk { // Persistence // -------------------- void FTMTree_MT::getMultiPersOriginsVectorFromTree( - std::vector> &treeMultiPers) { + std::vector> &treeMultiPers) const { treeMultiPers = std::vector>(this->getNumberOfNodes()); for(unsigned int i = 0; i < this->getNumberOfNodes(); ++i) @@ -398,7 +400,7 @@ namespace ttk { // -------------------- // Create/Delete/Modify Tree // -------------------- - void FTMTree_MT::copyMergeTreeStructure(FTMTree_MT *tree) { + void FTMTree_MT::copyMergeTreeStructure(const FTMTree_MT *tree) { // Add Nodes for(unsigned int i = 0; i < tree->getNumberOfNodes(); ++i) this->makeNode(i); @@ -418,7 +420,7 @@ namespace ttk { // -------------------- // Utils // -------------------- - void FTMTree_MT::printNodeSS(idNode node, std::stringstream &ss) { + void FTMTree_MT::printNodeSS(idNode node, std::stringstream &ss) const { ss << "(" << node << ") \\ "; std::vector children; @@ -431,7 +433,7 @@ namespace ttk { ss << std::endl; } - std::stringstream FTMTree_MT::printSubTree(idNode subRoot) { + std::stringstream FTMTree_MT::printSubTree(idNode subRoot) const { std::stringstream ss; ss << "Nodes----------" << std::endl; std::queue queue; @@ -450,7 +452,7 @@ namespace ttk { return ss; } - std::stringstream FTMTree_MT::printTree(bool doPrint) { + std::stringstream FTMTree_MT::printTree(bool doPrint) const { std::stringstream ss; std::vector allRoots; this->getAllRoots(allRoots); @@ -471,7 +473,7 @@ namespace ttk { return ss; } - std::stringstream FTMTree_MT::printTreeStats(bool doPrint) { + std::stringstream FTMTree_MT::printTreeStats(bool doPrint) const { auto noNodesT = this->getNumberOfNodes(); auto noNodes = this->getRealNumberOfNodes(); std::stringstream ss; @@ -483,7 +485,7 @@ namespace ttk { } std::stringstream - FTMTree_MT::printMultiPersOriginsVectorFromTree(bool doPrint) { + FTMTree_MT::printMultiPersOriginsVectorFromTree(bool doPrint) const { std::stringstream ss; std::vector> vec; this->getMultiPersOriginsVectorFromTree(vec); diff --git a/core/base/ftmTree/FTMTreeUtils.h b/core/base/ftmTree/FTMTreeUtils.h index f3dacda04e..c04b80c0b9 100644 --- a/core/base/ftmTree/FTMTreeUtils.h +++ b/core/base/ftmTree/FTMTreeUtils.h @@ -148,7 +148,7 @@ namespace ttk { } template - void getTreeScalars(ftm::FTMTree_MT *tree, + void getTreeScalars(const ftm::FTMTree_MT *tree, std::vector &scalarsVector) { scalarsVector.clear(); for(unsigned int i = 0; i < tree->getNumberOfNodes(); ++i) @@ -162,7 +162,7 @@ namespace ttk { } template - MergeTree copyMergeTree(ftm::FTMTree_MT *tree, + MergeTree copyMergeTree(const ftm::FTMTree_MT *tree, bool doSplitMultiPersPairs = false) { std::vector scalarsVector; getTreeScalars(tree, scalarsVector); @@ -201,7 +201,7 @@ namespace ttk { } template - MergeTree copyMergeTree(MergeTree &mergeTree, + MergeTree copyMergeTree(const MergeTree &mergeTree, bool doSplitMultiPersPairs = false) { return copyMergeTree(&(mergeTree.tree), doSplitMultiPersPairs); } diff --git a/core/base/ftmTree/FTMTreeUtils_Template.h b/core/base/ftmTree/FTMTreeUtils_Template.h index 119befe4d0..7c2659121f 100644 --- a/core/base/ftmTree/FTMTreeUtils_Template.h +++ b/core/base/ftmTree/FTMTreeUtils_Template.h @@ -16,7 +16,7 @@ namespace ttk { // Is // -------------------- template - bool FTMTree_MT::isJoinTree() { + bool FTMTree_MT::isJoinTree() const { auto root = this->getRoot(); std::vector rootChildren; this->getChildren(root, rootChildren); @@ -38,7 +38,7 @@ namespace ttk { bool FTMTree_MT::isImportantPair(idNode nodeId, double threshold, std::vector &excludeLower, - std::vector &excludeHigher) { + std::vector &excludeHigher) const { dataType rootPers = this->getNodePersistence(this->getRoot()); if(threshold > 1) threshold /= 100.0; @@ -57,14 +57,14 @@ namespace ttk { } template - bool FTMTree_MT::isImportantPair(idNode nodeId, double threshold) { + bool FTMTree_MT::isImportantPair(idNode nodeId, double threshold) const { std::vector excludeLower, excludeHigher; return this->isImportantPair( nodeId, threshold, excludeLower, excludeHigher); } template - bool FTMTree_MT::isParentInconsistent(idNode nodeId) { + bool FTMTree_MT::isParentInconsistent(idNode nodeId) const { auto parentBirthDeath = this->getBirthDeath(this->getParentSafe(nodeId)); dataType parentBirth = std::get<0>(parentBirthDeath); @@ -78,7 +78,7 @@ namespace ttk { } template - bool FTMTree_MT::verifyBranchDecompositionInconsistency() { + bool FTMTree_MT::verifyBranchDecompositionInconsistency() const { bool inconsistency = false; std::queue queue; queue.emplace(this->getRoot()); @@ -103,7 +103,7 @@ namespace ttk { // Get // -------------------- template - idNode FTMTree_MT::getMergedRootOrigin() { + idNode FTMTree_MT::getMergedRootOrigin() const { dataType maxPers = std::numeric_limits::lowest(); int maxIndex = -1; auto root = this->getRoot(); @@ -121,7 +121,7 @@ namespace ttk { } template - idNode FTMTree_MT::getLowestNode(idNode nodeStart) { + idNode FTMTree_MT::getLowestNode(idNode nodeStart) const { idNode lowestNode = nodeStart; bool isJT = this->isJoinTree(); dataType bestVal = isJT ? std::numeric_limits::max() @@ -149,7 +149,7 @@ namespace ttk { // -------------------- template std::tuple - FTMTree_MT::getBirthDeathFromIds(idNode nodeId1, idNode nodeId2) { + FTMTree_MT::getBirthDeathFromIds(idNode nodeId1, idNode nodeId2) const { dataType scalar1 = this->getValue(nodeId1); dataType scalar2 = this->getValue(nodeId2); dataType birth = std::min(scalar1, scalar2); @@ -159,7 +159,8 @@ namespace ttk { template std::tuple - FTMTree_MT::getBirthDeathNodeFromIds(idNode nodeId1, idNode nodeId2) { + FTMTree_MT::getBirthDeathNodeFromIds(idNode nodeId1, + idNode nodeId2) const { auto nodeValue = this->getValue(nodeId1); auto node2Value = this->getValue(nodeId2); auto nodeBirth = (nodeValue < node2Value ? nodeId1 : nodeId2); @@ -168,7 +169,8 @@ namespace ttk { } template - std::tuple FTMTree_MT::getBirthDeath(idNode nodeId) { + std::tuple + FTMTree_MT::getBirthDeath(idNode nodeId) const { // Avoid error if origin is not defined if(this->isNodeOriginDefined(nodeId)) { return this->getBirthDeathFromIds( @@ -179,7 +181,7 @@ namespace ttk { template std::tuple - FTMTree_MT::getBirthDeathNode(idNode nodeId) { + FTMTree_MT::getBirthDeathNode(idNode nodeId) const { if(this->isNodeOriginDefined(nodeId)) { return this->getBirthDeathNodeFromIds( nodeId, this->getNode(nodeId)->getOrigin()); @@ -188,7 +190,7 @@ namespace ttk { } template - std::tuple FTMTree_MT::getMergedRootBirthDeath() { + std::tuple FTMTree_MT::getMergedRootBirthDeath() const { if(!this->isFullMerge()) return this->getBirthDeath(this->getRoot()); return this->getBirthDeathFromIds( @@ -197,7 +199,7 @@ namespace ttk { template std::tuple - FTMTree_MT::getMergedRootBirthDeathNode() { + FTMTree_MT::getMergedRootBirthDeathNode() const { if(!this->isFullMerge()) return this->getBirthDeathNode(this->getRoot()); return this->getBirthDeathNodeFromIds( @@ -205,19 +207,19 @@ namespace ttk { } template - dataType FTMTree_MT::getBirth(idNode nodeId) { + dataType FTMTree_MT::getBirth(idNode nodeId) const { return std::get<0>(this->getBirthDeath(nodeId)); } template - dataType FTMTree_MT::getNodePersistence(idNode nodeId) { + dataType FTMTree_MT::getNodePersistence(idNode nodeId) const { std::tuple birthDeath = this->getBirthDeath(nodeId); return std::get<1>(birthDeath) - std::get<0>(birthDeath); } template - dataType FTMTree_MT::getMaximumPersistence() { + dataType FTMTree_MT::getMaximumPersistence() const { idNode const root = this->getRoot(); bool const fullMerge = this->isFullMerge(); @@ -236,7 +238,7 @@ namespace ttk { } template - ftm::idNode FTMTree_MT::getSecondMaximumPersistenceNode() { + ftm::idNode FTMTree_MT::getSecondMaximumPersistenceNode() const { idNode const root = this->getRoot(); dataType pers = std::numeric_limits::lowest(); ftm::idNode nodeSecMax = -1; @@ -258,14 +260,15 @@ namespace ttk { } template - dataType FTMTree_MT::getSecondMaximumPersistence() { + dataType FTMTree_MT::getSecondMaximumPersistence() const { return this->getNodePersistence( this->getSecondMaximumPersistenceNode()); } template void FTMTree_MT::getPersistencePairsFromTree( - std::vector> &pairs, bool useBD) { + std::vector> &pairs, + bool useBD) const { std::vector nodes; if(useBD) { for(unsigned int i = 0; i < this->getNumberOfNodes(); ++i) @@ -286,7 +289,7 @@ namespace ttk { } template - std::vector FTMTree_MT::getMultiPersOrigins(bool useBD) { + std::vector FTMTree_MT::getMultiPersOrigins(bool useBD) const { std::vector multiPersOrigins; std::vector> pairs; @@ -314,7 +317,8 @@ namespace ttk { // Utils // -------------------- template - std::stringstream FTMTree_MT::printNode2(idNode nodeId, bool doPrint) { + std::stringstream FTMTree_MT::printNode2(idNode nodeId, + bool doPrint) const { auto origin = this->getNode(nodeId)->getOrigin(); std::stringstream ss; ss << "nodeId = " << nodeId << " (" << this->getValue(nodeId) @@ -327,7 +331,7 @@ namespace ttk { } template - std::stringstream FTMTree_MT::printMergedRoot(bool doPrint) { + std::stringstream FTMTree_MT::printMergedRoot(bool doPrint) const { std::stringstream ss; ss << this->getRoot() << " (" << this->getValue(this->getRoot()) << ") _ "; @@ -346,7 +350,7 @@ namespace ttk { template std::stringstream FTMTree_MT::printTreeScalars(bool printNodeAlone, - bool doPrint) { + bool doPrint) const { std::stringstream wholeSS; std::streamsize const sSize = std::cout.precision(); for(unsigned int i = 0; i < this->getNumberOfNodes(); ++i) { @@ -373,7 +377,7 @@ namespace ttk { template std::stringstream FTMTree_MT::printPairsFromTree(bool useBD, bool printPairs, - bool doPrint) { + bool doPrint) const { std::stringstream ss; std::vector> pairs; this->getPersistencePairsFromTree(pairs, useBD); @@ -395,9 +399,8 @@ namespace ttk { } template - std::stringstream FTMTree_MT::printMultiPersPairsFromTree(bool useBD, - bool printPairs, - bool doPrint) { + std::stringstream FTMTree_MT::printMultiPersPairsFromTree( + bool useBD, bool printPairs, bool doPrint) const { std::vector> pairs; this->getPersistencePairsFromTree(pairs, useBD); std::vector noOrigin(this->getNumberOfNodes(), 0); diff --git a/core/base/ftmTree/FTMTree_MT.h b/core/base/ftmTree/FTMTree_MT.h index 64196f00c4..c8cf470dc7 100644 --- a/core/base/ftmTree/FTMTree_MT.h +++ b/core/base/ftmTree/FTMTree_MT.h @@ -390,7 +390,7 @@ namespace ttk { return mt_data_.nodes->size(); } - inline Node *getNode(idNode nodeId) { + inline Node *getNode(idNode nodeId) const { return &((*mt_data_.nodes)[nodeId]); } @@ -599,146 +599,149 @@ namespace ttk { // -------------------- // Is // -------------------- - bool isNodeOriginDefined(idNode nodeId); + bool isNodeOriginDefined(idNode nodeId) const; - bool isRoot(idNode nodeId); + bool isRoot(idNode nodeId) const; - bool isLeaf(idNode nodeId); + bool isLeaf(idNode nodeId) const; - bool isNodeAlone(idNode nodeId); + bool isNodeAlone(idNode nodeId) const; - bool isFullMerge(); + bool isFullMerge() const; - bool isBranchOrigin(idNode nodeId); + bool isBranchOrigin(idNode nodeId) const; template - bool isJoinTree(); + bool isJoinTree() const; template bool isImportantPair(idNode nodeId, double threshold, std::vector &excludeLower, - std::vector &excludeHigher); + std::vector &excludeHigher) const; template - bool isImportantPair(idNode nodeId, double threshold); + bool isImportantPair(idNode nodeId, double threshold) const; - bool isNodeMerged(idNode nodeId); + bool isNodeMerged(idNode nodeId) const; - bool isNodeIdInconsistent(idNode nodeId); + bool isNodeIdInconsistent(idNode nodeId) const; - bool isThereOnlyOnePersistencePair(); + bool isThereOnlyOnePersistencePair() const; // Do not normalize node is if root or son of a merged root - bool notNeedToNormalize(idNode nodeId); + bool notNeedToNormalize(idNode nodeId) const; - bool isMultiPersPair(idNode nodeId); + bool isMultiPersPair(idNode nodeId) const; template - bool isParentInconsistent(ftm::idNode nodeId); + bool isParentInconsistent(ftm::idNode nodeId) const; template - bool verifyBranchDecompositionInconsistency(); + bool verifyBranchDecompositionInconsistency() const; // -------------------- // Get // -------------------- - idNode getRoot(); + idNode getRoot() const; - idNode getParentSafe(idNode nodeId); + idNode getParentSafe(idNode nodeId) const; - void getChildren(idNode nodeId, std::vector &res); + void getChildren(idNode nodeId, std::vector &res) const; - void getLeavesFromTree(std::vector &res); + void getLeavesFromTree(std::vector &res) const; - int getNumberOfLeavesFromTree(); + int getNumberOfLeavesFromTree() const; - int getNumberOfNodeAlone(); + int getNumberOfNodeAlone() const; - int getRealNumberOfNodes(); + int getRealNumberOfNodes() const; template - idNode getMergedRootOrigin(); + idNode getMergedRootOrigin() const; void getBranchOriginsFromThisBranch( - idNode node, std::tuple, std::vector> &res); + idNode node, + std::tuple, std::vector> &res) const; - void getTreeBranching(std::vector &branching, - std::vector &branchingID, - std::vector> &nodeBranching); + void + getTreeBranching(std::vector &branching, + std::vector &branchingID, + std::vector> &nodeBranching) const; void getTreeBranching(std::vector &branching, - std::vector &branchingID); + std::vector &branchingID) const; - void getAllRoots(std::vector &res); + void getAllRoots(std::vector &res) const; - int getNumberOfRoot(); + int getNumberOfRoot() const; - int getNumberOfChildren(idNode nodeId); + int getNumberOfChildren(idNode nodeId) const; - int getTreeDepth(); + int getTreeDepth() const; - int getNodeLevel(idNode nodeId); + int getNodeLevel(idNode nodeId) const; - void getAllNodeLevel(std::vector &res); + void getAllNodeLevel(std::vector &res) const; - void getLevelToNode(std::vector> &res); + void getLevelToNode(std::vector> &res) const; void getBranchSubtree(std::vector &branching, idNode branchRoot, - std::vector &res); + std::vector &res) const; template - idNode getLowestNode(idNode nodeStart); + idNode getLowestNode(idNode nodeStart) const; // -------------------- // Persistence // -------------------- template std::tuple getBirthDeathFromIds(idNode nodeId1, - idNode nodeId2); + idNode nodeId2) const; template - std::tuple getBirthDeathNodeFromIds(idNode nodeId1, - idNode nodeId2); + std::tuple + getBirthDeathNodeFromIds(idNode nodeId1, idNode nodeId2) const; template - std::tuple getBirthDeath(idNode nodeId); + std::tuple getBirthDeath(idNode nodeId) const; template - std::tuple getBirthDeathNode(idNode nodeId); + std::tuple + getBirthDeathNode(idNode nodeId) const; template - std::tuple getMergedRootBirthDeath(); + std::tuple getMergedRootBirthDeath() const; template - std::tuple getMergedRootBirthDeathNode(); + std::tuple getMergedRootBirthDeathNode() const; template - dataType getBirth(idNode nodeId); + dataType getBirth(idNode nodeId) const; template - dataType getNodePersistence(idNode nodeId); + dataType getNodePersistence(idNode nodeId) const; template - dataType getMaximumPersistence(); + dataType getMaximumPersistence() const; template - ftm::idNode getSecondMaximumPersistenceNode(); + ftm::idNode getSecondMaximumPersistenceNode() const; template - dataType getSecondMaximumPersistence(); + dataType getSecondMaximumPersistence() const; template void getPersistencePairsFromTree( std::vector> &pairs, - bool useBD); + bool useBD) const; template - std::vector getMultiPersOrigins(bool useBD); + std::vector getMultiPersOrigins(bool useBD) const; void getMultiPersOriginsVectorFromTree( - std::vector> &res); + std::vector> &res) const; // -------------------- // Set @@ -762,41 +765,41 @@ namespace ttk { // -------------------- // Create/Delete/Modify Tree // -------------------- - void copyMergeTreeStructure(FTMTree_MT *tree); + void copyMergeTreeStructure(const FTMTree_MT *tree); // -------------------- // Utils // -------------------- - void printNodeSS(idNode node, std::stringstream &ss); + void printNodeSS(idNode node, std::stringstream &ss) const; template - std::stringstream printNode2(idNode nodeId, bool doPrint = true); + std::stringstream printNode2(idNode nodeId, bool doPrint = true) const; template - std::stringstream printMergedRoot(bool doPrint = true); + std::stringstream printMergedRoot(bool doPrint = true) const; - std::stringstream printSubTree(idNode subRoot); + std::stringstream printSubTree(idNode subRoot) const; - std::stringstream printTree(bool doPrint = true); + std::stringstream printTree(bool doPrint = true) const; - std::stringstream printTreeStats(bool doPrint = true); + std::stringstream printTreeStats(bool doPrint = true) const; template std::stringstream printTreeScalars(bool printNodeAlone = true, - bool doPrint = true); + bool doPrint = true) const; template std::stringstream printPairsFromTree(bool useBD = false, bool printPairs = true, - bool doPrint = true); + bool doPrint = true) const; std::stringstream printMultiPersOriginsVectorFromTree(bool doPrint - = true); + = true) const; template std::stringstream printMultiPersPairsFromTree(bool useBD = false, bool printPairs = true, - bool doPrint = true); + bool doPrint = true) const; // ---------------------------------------- // End of utils functions diff --git a/core/base/mergeTreeAutoencoder/CMakeLists.txt b/core/base/mergeTreeAutoencoder/CMakeLists.txt index ed4ddd9dfb..6c9bbd0327 100644 --- a/core/base/mergeTreeAutoencoder/CMakeLists.txt +++ b/core/base/mergeTreeAutoencoder/CMakeLists.txt @@ -1,14 +1,11 @@ ttk_add_base_library(mergeTreeAutoencoder SOURCES MergeTreeAutoencoder.cpp - MergeTreeAutoencoderUtils.cpp - MergeTreeTorchUtils.cpp HEADERS MergeTreeAutoencoder.h - MergeTreeAutoencoderUtils.h - MergeTreeTorchUtils.h DEPENDS mergeTreePrincipalGeodesics + mergeTreeNeuralNetwork geometry ) diff --git a/core/base/mergeTreeAutoencoder/MergeTreeAutoencoder.cpp b/core/base/mergeTreeAutoencoder/MergeTreeAutoencoder.cpp index 05bd234cff..3f729fc4a8 100644 --- a/core/base/mergeTreeAutoencoder/MergeTreeAutoencoder.cpp +++ b/core/base/mergeTreeAutoencoder/MergeTreeAutoencoder.cpp @@ -1,5 +1,4 @@ #include -#include #include #ifdef TTK_ENABLE_TORCH @@ -12,419 +11,6 @@ ttk::MergeTreeAutoencoder::MergeTreeAutoencoder() { } #ifdef TTK_ENABLE_TORCH -// --------------------------------------------------------------------------- -// --- Init -// --------------------------------------------------------------------------- -void ttk::MergeTreeAutoencoder::initOutputBasisTreeStructure( - mtu::TorchMergeTree &originPrime, - bool isJT, - mtu::TorchMergeTree &baseOrigin) { - // ----- Create scalars vector - std::vector scalarsVector( - originPrime.tensor.data_ptr(), - originPrime.tensor.data_ptr() + originPrime.tensor.numel()); - unsigned int noNodes = scalarsVector.size() / 2; - std::vector> childrenFinal(noNodes); - - // ----- Init tree structure and modify scalars if necessary - if(isPersistenceDiagram_) { - for(unsigned int i = 2; i < scalarsVector.size(); i += 2) - childrenFinal[0].emplace_back(i / 2); - } else { - // --- Fix or swap min-max pair - float maxPers = std::numeric_limits::lowest(); - unsigned int indMax = 0; - for(unsigned int i = 0; i < scalarsVector.size(); i += 2) { - if(maxPers < (scalarsVector[i + 1] - scalarsVector[i])) { - maxPers = (scalarsVector[i + 1] - scalarsVector[i]); - indMax = i; - } - } - if(indMax != 0) { - float temp = scalarsVector[0]; - scalarsVector[0] = scalarsVector[indMax]; - scalarsVector[indMax] = temp; - temp = scalarsVector[1]; - scalarsVector[1] = scalarsVector[indMax + 1]; - scalarsVector[indMax + 1] = temp; - } - ftm::idNode refNode = 0; - for(unsigned int i = 2; i < scalarsVector.size(); i += 2) { - ftm::idNode node = i / 2; - wae::adjustNestingScalars(scalarsVector, node, refNode); - } - - if(not initOriginPrimeStructByCopy_ - or (int) noNodes > baseOrigin.mTree.tree.getRealNumberOfNodes()) { - // --- Get possible children and parent relations - std::vector> parents(noNodes), children(noNodes); - for(unsigned int i = 0; i < scalarsVector.size(); i += 2) { - for(unsigned int j = i; j < scalarsVector.size(); j += 2) { - if(i == j) - continue; - unsigned int iN = i / 2, jN = j / 2; - if(scalarsVector[i] <= scalarsVector[j] - and scalarsVector[i + 1] >= scalarsVector[j + 1]) { - // - i is parent of j - parents[jN].emplace_back(iN); - children[iN].emplace_back(jN); - } else if(scalarsVector[i] >= scalarsVector[j] - and scalarsVector[i + 1] <= scalarsVector[j + 1]) { - // - j is parent of i - parents[iN].emplace_back(jN); - children[jN].emplace_back(iN); - } - } - } - wae::createBalancedBDT( - parents, children, scalarsVector, childrenFinal, this->threadNumber_); - } else { - ftm::MergeTree mTreeTemp - = ftm::copyMergeTree(baseOrigin.mTree); - bool useBD = true; - keepMostImportantPairs(&(mTreeTemp.tree), noNodes, useBD); - torch::Tensor reshaped = torch::tensor(scalarsVector).reshape({-1, 2}); - torch::Tensor order = torch::argsort( - (reshaped.index({Slice(), 1}) - reshaped.index({Slice(), 0})), -1, - true); - std::vector nodeCorr(mTreeTemp.tree.getNumberOfNodes(), 0); - unsigned int nodeNum = 1; - std::queue queue; - queue.emplace(mTreeTemp.tree.getRoot()); - while(!queue.empty()) { - ftm::idNode node = queue.front(); - queue.pop(); - std::vector children; - mTreeTemp.tree.getChildren(node, children); - for(auto &child : children) { - queue.emplace(child); - unsigned int tNode = nodeCorr[node]; - nodeCorr[child] = order[nodeNum].item(); - ++nodeNum; - unsigned int tChild = nodeCorr[child]; - childrenFinal[tNode].emplace_back(tChild); - wae::adjustNestingScalars(scalarsVector, tChild, tNode); - } - } - } - } - - // ----- Create new tree - originPrime.mTree = ftm::createEmptyMergeTree(scalarsVector.size()); - ftm::FTMTree_MT *tree = &(originPrime.mTree.tree); - if(isJT) { - for(unsigned int i = 0; i < scalarsVector.size(); i += 2) { - float temp = scalarsVector[i]; - scalarsVector[i] = scalarsVector[i + 1]; - scalarsVector[i + 1] = temp; - } - } - ftm::setTreeScalars(originPrime.mTree, scalarsVector); - - // ----- Create tree structure - originPrime.nodeCorr.clear(); - originPrime.nodeCorr.assign( - scalarsVector.size(), std::numeric_limits::max()); - for(unsigned int i = 0; i < scalarsVector.size(); i += 2) { - tree->makeNode(i); - tree->makeNode(i + 1); - tree->getNode(i)->setOrigin(i + 1); - tree->getNode(i + 1)->setOrigin(i); - originPrime.nodeCorr[i] = (unsigned int)(i / 2); - } - for(unsigned int i = 0; i < scalarsVector.size(); i += 2) { - unsigned int node = i / 2; - for(auto &child : childrenFinal[node]) - tree->makeSuperArc(child * 2, i); - } - mtu::getParentsVector(originPrime.mTree, originPrime.parentsOri); - - if(isTreeHasBigValues(originPrime.mTree, bigValuesThreshold_)) { - std::stringstream ss; - ss << originPrime.mTree.tree.printPairsFromTree(true).str() - << std::endl; - ss << "isTreeHasBigValues(originPrime.mTree)" << std::endl; - ss << "pause" << std::endl; - printMsg(ss.str()); - std::cin.get(); - } -} - -void ttk::MergeTreeAutoencoder::initOutputBasis(unsigned int l, - unsigned int dim, - unsigned int dim2) { - unsigned int originSize = origins_[l].tensor.sizes()[0]; - unsigned int origin2Size = 0; - if(useDoubleInput_) - origin2Size = origins2_[l].tensor.sizes()[0]; - - // --- Compute output basis origin - printMsg("Compute output basis origin", debug::Priority::DETAIL); - auto initOutputBasisOrigin = [this, &l](torch::Tensor &w, - mtu::TorchMergeTree &tmt, - mtu::TorchMergeTree &baseTmt) { - // - Create scalars - torch::nn::init::xavier_normal_(w); - torch::Tensor baseTmtTensor = baseTmt.tensor; - if(normalizedWasserstein_) - // Work on unnormalized tensor - mtu::mergeTreeToTorchTensor(baseTmt.mTree, baseTmtTensor, false); - torch::Tensor b = torch::fill(torch::zeros({w.sizes()[0], 1}), 0.01); - tmt.tensor = (torch::matmul(w, baseTmtTensor) + b); - // - Shift to keep mean birth and max pers - mtu::meanBirthMaxPersShift(tmt.tensor, baseTmtTensor); - // - Shift to avoid diagonal points - mtu::belowDiagonalPointsShift(tmt.tensor, baseTmtTensor); - // - auto endLayer - = (trackingLossDecoding_ ? noLayers_ : getLatentLayerIndex() + 1); - if(trackingLossWeight_ != 0 and l < endLayer) { - auto baseTensor - = (l == 0 ? origins_[0].tensor : originsPrime_[l - 1].tensor); - auto baseTensorDiag = baseTensor.reshape({-1, 2}); - auto basePersDiag = (baseTensorDiag.index({Slice(), 1}) - - baseTensorDiag.index({Slice(), 0})); - auto tmtTensorDiag = tmt.tensor.reshape({-1, 2}); - auto persDiag = (tmtTensorDiag.index({Slice(1, None), 1}) - - tmtTensorDiag.index({Slice(1, None), 0})); - int noK = std::min(baseTensorDiag.sizes()[0], tmtTensorDiag.sizes()[0]); - auto topVal = baseTensorDiag.index({std::get<1>(basePersDiag.topk(noK))}); - auto indexes = std::get<1>(persDiag.topk(noK - 1)) + 1; - indexes = torch::cat({torch::zeros(1), indexes}).to(torch::kLong); - if(trackingLossInitRandomness_ != 0) { - topVal = (1 - trackingLossInitRandomness_) * topVal - + trackingLossInitRandomness_ * tmtTensorDiag.index({indexes}); - } - tmtTensorDiag.index_put_({indexes}, topVal); - } - // - Create tree structure - initOutputBasisTreeStructure( - tmt, baseTmt.mTree.tree.isJoinTree(), baseTmt); - if(normalizedWasserstein_) - // Normalize tensor - mtu::mergeTreeToTorchTensor(tmt.mTree, tmt.tensor, true); - // - Projection - interpolationProjection(tmt); - }; - torch::Tensor w = torch::zeros({dim, originSize}); - initOutputBasisOrigin(w, originsPrime_[l], origins_[l]); - torch::Tensor w2; - if(useDoubleInput_) { - w2 = torch::zeros({dim2, origin2Size}); - initOutputBasisOrigin(w2, origins2Prime_[l], origins2_[l]); - } - - // --- Compute output basis vectors - printMsg("Compute output basis vectors", debug::Priority::DETAIL); - initOutputBasisVectors(l, w, w2); -} - -void ttk::MergeTreeAutoencoder::initOutputBasisVectors(unsigned int l, - torch::Tensor &w, - torch::Tensor &w2) { - vSPrimeTensor_[l] = torch::matmul(w, vSTensor_[l]); - if(useDoubleInput_) - vS2PrimeTensor_[l] = torch::matmul(w2, vS2Tensor_[l]); - if(normalizedWasserstein_) { - mtu::normalizeVectors(originsPrime_[l].tensor, vSPrimeTensor_[l]); - if(useDoubleInput_) - mtu::normalizeVectors(origins2Prime_[l].tensor, vS2PrimeTensor_[l]); - } -} - -void ttk::MergeTreeAutoencoder::initOutputBasisVectors(unsigned int l, - unsigned int dim, - unsigned int dim2) { - unsigned int originSize = origins_[l].tensor.sizes()[0]; - unsigned int origin2Size = 0; - if(useDoubleInput_) - origin2Size = origins2_[l].tensor.sizes()[0]; - torch::Tensor w = torch::zeros({dim, originSize}); - torch::nn::init::xavier_normal_(w); - torch::Tensor w2 = torch::zeros({dim2, origin2Size}); - torch::nn::init::xavier_normal_(w2); - initOutputBasisVectors(l, w, w2); -} - -void ttk::MergeTreeAutoencoder::initInputBasisOrigin( - std::vector> &treesToUse, - std::vector> &trees2ToUse, - double barycenterSizeLimitPercent, - unsigned int barycenterMaxNoPairs, - unsigned int barycenterMaxNoPairs2, - mtu::TorchMergeTree &origin, - mtu::TorchMergeTree &origin2, - std::vector &inputToBaryDistances, - std::vector>> - &baryMatchings, - std::vector>> - &baryMatchings2) { - computeOneBarycenter(treesToUse, origin.mTree, baryMatchings, - inputToBaryDistances, barycenterSizeLimitPercent, - useDoubleInput_); - if(barycenterMaxNoPairs > 0) - keepMostImportantPairs( - &(origin.mTree.tree), barycenterMaxNoPairs, true); - if(useDoubleInput_) { - std::vector baryDistances2; - computeOneBarycenter(trees2ToUse, origin2.mTree, baryMatchings2, - baryDistances2, barycenterSizeLimitPercent, - useDoubleInput_, false); - if(barycenterMaxNoPairs2 > 0) - keepMostImportantPairs( - &(origin2.mTree.tree), barycenterMaxNoPairs2, true); - for(unsigned int i = 0; i < inputToBaryDistances.size(); ++i) - inputToBaryDistances[i] - = mixDistances(inputToBaryDistances[i], baryDistances2[i]); - } - - mtu::getParentsVector(origin.mTree, origin.parentsOri); - mtu::mergeTreeToTorchTensor( - origin.mTree, origin.tensor, origin.nodeCorr, normalizedWasserstein_); - if(useDoubleInput_) { - mtu::getParentsVector(origin2.mTree, origin2.parentsOri); - mtu::mergeTreeToTorchTensor( - origin2.mTree, origin2.tensor, origin2.nodeCorr, normalizedWasserstein_); - } -} - -void ttk::MergeTreeAutoencoder::initInputBasisVectors( - std::vector> &tmTreesToUse, - std::vector> &tmTrees2ToUse, - std::vector> &treesToUse, - std::vector> &trees2ToUse, - mtu::TorchMergeTree &origin, - mtu::TorchMergeTree &origin2, - unsigned int noVectors, - std::vector> &allAlphasInit, - unsigned int l, - std::vector &inputToBaryDistances, - std::vector>> - &baryMatchings, - std::vector>> - &baryMatchings2, - torch::Tensor &vSTensor, - torch::Tensor &vS2Tensor) { - // --- Initialized vectors projection function to avoid collinearity - auto initializedVectorsProjection - = [=](int ttkNotUsed(_axeNumber), - ftm::MergeTree &ttkNotUsed(_barycenter), - std::vector> &_v, - std::vector> &ttkNotUsed(_v2), - std::vector>> &_vS, - std::vector>> &ttkNotUsed(_v2s), - ftm::MergeTree &ttkNotUsed(_barycenter2), - std::vector> &ttkNotUsed(_trees2V), - std::vector> &ttkNotUsed(_trees2V2), - std::vector>> &ttkNotUsed(_trees2Vs), - std::vector>> &ttkNotUsed(_trees2V2s), - bool ttkNotUsed(_useSecondInput), - unsigned int ttkNotUsed(_noProjectionStep)) { - std::vector scaledV, scaledVSi; - Geometry::flattenMultiDimensionalVector(_v, scaledV); - Geometry::scaleVector( - scaledV, 1.0 / Geometry::magnitude(scaledV), scaledV); - for(unsigned int i = 0; i < _vS.size(); ++i) { - Geometry::flattenMultiDimensionalVector(_vS[i], scaledVSi); - Geometry::scaleVector( - scaledVSi, 1.0 / Geometry::magnitude(scaledVSi), scaledVSi); - auto prod = Geometry::dotProduct(scaledV, scaledVSi); - double tol = 0.01; - if(prod <= -1.0 + tol or prod >= 1.0 - tol) { - // Reset vector to initialize it again - for(unsigned int j = 0; j < _v.size(); ++j) - for(unsigned int k = 0; k < _v[j].size(); ++k) - _v[j][k] = 0; - break; - } - } - return 0; - }; - - // --- Init vectors - std::vector> inputToAxesDistances; - std::vector>> vS, v2s, trees2Vs, trees2V2s; - std::stringstream ss; - for(unsigned int vecNum = 0; vecNum < noVectors; ++vecNum) { - ss.str(""); - ss << "Compute vectors " << vecNum; - printMsg(ss.str(), debug::Priority::VERBOSE); - std::vector> v1, v2, trees2V1, trees2V2; - int newVectorOffset = 0; - bool projectInitializedVectors = true; - int bestIndex = MergeTreeAxesAlgorithmBase::initVectors( - vecNum, origin.mTree, treesToUse, origin2.mTree, trees2ToUse, v1, v2, - trees2V1, trees2V2, newVectorOffset, inputToBaryDistances, baryMatchings, - baryMatchings2, inputToAxesDistances, vS, v2s, trees2Vs, trees2V2s, - projectInitializedVectors, initializedVectorsProjection); - vS.emplace_back(v1); - v2s.emplace_back(v2); - trees2Vs.emplace_back(trees2V1); - trees2V2s.emplace_back(trees2V2); - - ss.str(""); - ss << "bestIndex = " << bestIndex; - printMsg(ss.str(), debug::Priority::VERBOSE); - - // Update inputToAxesDistances - printMsg("Update inputToAxesDistances", debug::Priority::VERBOSE); - inputToAxesDistances.resize(1, std::vector(treesToUse.size())); - if(bestIndex == -1 and normalizedWasserstein_) { - mtu::normalizeVectors(origin, vS[vS.size() - 1]); - if(useDoubleInput_) - mtu::normalizeVectors(origin2, trees2Vs[vS.size() - 1]); - } - mtu::axisVectorsToTorchTensor(origin.mTree, vS, vSTensor); - if(useDoubleInput_) { - mtu::axisVectorsToTorchTensor(origin2.mTree, trees2Vs, vS2Tensor); - } - mtu::TorchMergeTree dummyTmt; - std::vector> - dummyBaryMatching2; -#ifdef TTK_ENABLE_OPENMP -#pragma omp parallel for schedule(dynamic) \ - num_threads(this->threadNumber_) if(parallelize_) -#endif - for(unsigned int i = 0; i < treesToUse.size(); ++i) { - auto &tmt2ToUse = (not useDoubleInput_ ? dummyTmt : tmTrees2ToUse[i]); - if(not euclideanVectorsInit_) { - unsigned int k = k_; - auto newAlpha = torch::ones({1, 1}); - if(bestIndex == -1) { - newAlpha = torch::zeros({1, 1}); - } - allAlphasInit[i][l] = (allAlphasInit[i][l].defined() - ? torch::cat({allAlphasInit[i][l], newAlpha}) - : newAlpha); - torch::Tensor bestAlphas; - bool isCalled = true; - inputToAxesDistances[0][i] = assignmentOneData( - tmTreesToUse[i], origin, vSTensor, tmt2ToUse, origin2, vS2Tensor, k, - allAlphasInit[i][l], bestAlphas, isCalled); - allAlphasInit[i][l] = bestAlphas.detach(); - } else { - auto &baryMatching2ToUse - = (not useDoubleInput_ ? dummyBaryMatching2 : baryMatchings2[i]); - torch::Tensor alphas; - computeAlphas(tmTreesToUse[i], origin, vSTensor, origin, - baryMatchings[i], tmt2ToUse, origin2, vS2Tensor, origin2, - baryMatching2ToUse, alphas); - mtu::TorchMergeTree interpolated, interpolated2; - getMultiInterpolation(origin, vSTensor, alphas, interpolated); - if(useDoubleInput_) - getMultiInterpolation(origin2, vS2Tensor, alphas, interpolated2); - torch::Tensor tensorDist; - bool doSqrt = true; - getDifferentiableDistanceFromMatchings( - interpolated, tmTreesToUse[i], interpolated2, tmt2ToUse, - baryMatchings[i], baryMatching2ToUse, tensorDist, doSqrt); - inputToAxesDistances[0][i] = tensorDist.item(); - allAlphasInit[i][l] = alphas.detach(); - } - } - } -} - void ttk::MergeTreeAutoencoder::initClusteringLossParameters() { unsigned int l = getLatentLayerIndex(); unsigned int noCentroids @@ -458,10 +44,64 @@ void ttk::MergeTreeAutoencoder::initClusteringLossParameters() { } } +bool ttk::MergeTreeAutoencoder::initResetOutputBasis( + unsigned int l, + unsigned int layerNoAxes, + double layerOriginPrimeSizePercent, + std::vector> &trees, + std::vector> &trees2, + std::vector &isTrain) { + printMsg("Reset output basis", debug::Priority::DETAIL); + if((noLayers_ == 2 and l == 1) or noLayers_ == 1) { + initOutputBasisSpecialCase(l, layerNoAxes, trees, trees2); + } else if(l < (unsigned int)(noLayers_ / 2)) { + initOutputBasis(l, layerOriginPrimeSizePercent, trees, trees2, isTrain); + } else { + printErr("recs[i].mTree.tree.getRealNumberOfNodes() == 0"); + std::stringstream ssT; + ssT << "layer " << l; + printWrn(ssT.str()); + return true; + } + return false; +} + +void ttk::MergeTreeAutoencoder::initOutputBasisSpecialCase( + unsigned int l, + unsigned int layerNoAxes, + std::vector> &trees, + std::vector> &trees2) { + // - Compute Origin + printMsg("Compute output basis origin", debug::Priority::DETAIL); + layers_[l].setOriginPrime(layers_[0].getOrigin()); + if(useDoubleInput_) + layers_[l].setOrigin2Prime(layers_[0].getOrigin2()); + // - Compute vectors + printMsg("Compute output basis vectors", debug::Priority::DETAIL); + if(layerNoAxes != layers_[0].getVSTensor().sizes()[1]) { + // TODO is there a way to avoid copy of merge trees? + std::vector> treesToUse, trees2ToUse; + for(unsigned int i = 0; i < trees.size(); ++i) { + treesToUse.emplace_back(trees[i].mTree); + if(useDoubleInput_) + trees2ToUse.emplace_back(trees2[i].mTree); + } + std::vector allAlphasInitT(trees.size()); + layers_[l].initInputBasisVectors( + trees, trees2, treesToUse, trees2ToUse, layerNoAxes, allAlphasInitT, + inputToBaryDistances_L0_, baryMatchings_L0_, baryMatchings2_L0_, false); + } else { + layers_[l].setVSPrimeTensor(layers_[0].getVSTensor()); + if(useDoubleInput_) + layers_[l].setVS2PrimeTensor(layers_[0].getVS2Tensor()); + } +} + float ttk::MergeTreeAutoencoder::initParameters( std::vector> &trees, std::vector> &trees2, - bool computeReconstructionError) { + std::vector &isTrain, + bool computeError) { // ----- Init variables // noLayers_ = number of encoder layers + number of decoder layers + the // latent layer + the output layer @@ -496,34 +136,18 @@ float ttk::MergeTreeAutoencoder::initParameters( / 2.0; } - std::vector ftmTrees(trees.size()), - ftmTrees2(trees2.size()); - for(unsigned int i = 0; i < trees.size(); ++i) - ftmTrees[i] = &(trees[i].mTree.tree); - for(unsigned int i = 0; i < trees2.size(); ++i) - ftmTrees2[i] = &(trees2[i].mTree.tree); - auto sizeMetric = getSizeLimitMetric(ftmTrees); - auto sizeMetric2 = getSizeLimitMetric(ftmTrees2); - auto getDim = [](double _sizeMetric, double _percent) { - unsigned int dim = std::max((int)(_sizeMetric * _percent / 100.0), 2) * 2; - return dim; - }; - // ----- Resize parameters - origins_.resize(noLayers_); - originsPrime_.resize(noLayers_); - vSTensor_.resize(noLayers_); - vSPrimeTensor_.resize(noLayers_); - if(trees2.size() != 0) { - origins2_.resize(noLayers_); - origins2Prime_.resize(noLayers_); - vS2Tensor_.resize(noLayers_); - vS2PrimeTensor_.resize(noLayers_); + layers_.resize(noLayers_); + for(unsigned int l = 0; l < layers_.size(); ++l) { + initOriginPrimeValuesByCopy_ + = trackingLossWeight_ != 0 + and l < (trackingLossDecoding_ ? noLayers_ : getLatentLayerIndex() + 1); + initOriginPrimeValuesByCopyRandomness_ = trackingLossInitRandomness_; + passLayerParameters(layers_[l]); } // ----- Compute parameters of each layer bool fullSymmetricAE = fullSymmetricAE_; - bool outputBasisActivation = activateOutputInit_; std::vector> recs, recs2; std::vector> allAlphasInit( @@ -537,68 +161,22 @@ float ttk::MergeTreeAutoencoder::initParameters( // --- Init Input Basis if(l < (unsigned int)(noLayers_ / 2) or not fullSymmetricAE or (noLayers_ <= 2 and not fullSymmetricAE)) { - // TODO is there a way to avoid copy of merge trees? - std::vector> treesToUse, trees2ToUse; - for(unsigned int i = 0; i < trees.size(); ++i) { - treesToUse.emplace_back((l == 0 ? trees[i].mTree : recs[i].mTree)); - if(trees2.size() != 0) - trees2ToUse.emplace_back((l == 0 ? trees2[i].mTree : recs2[i].mTree)); - } - - // - Compute origin - printMsg("Compute origin...", debug::Priority::DETAIL); - Timer t_origin; - std::vector inputToBaryDistances; - std::vector>> - baryMatchings, baryMatchings2; - if(l != 0 or not origins_[0].tensor.defined()) { - double sizeLimit = (l == 0 ? barycenterSizeLimitPercent_ : 0); - unsigned int maxNoPairs - = (l == 0 ? 0 : originsPrime_[l - 1].tensor.sizes()[0] / 2); - unsigned int maxNoPairs2 - = (l == 0 or not useDoubleInput_ - ? 0 - : origins2Prime_[l - 1].tensor.sizes()[0] / 2); - initInputBasisOrigin(treesToUse, trees2ToUse, sizeLimit, maxNoPairs, - maxNoPairs2, origins_[l], origins2_[l], - inputToBaryDistances, baryMatchings, - baryMatchings2); - if(l == 0) { - baryMatchings_L0_ = baryMatchings; - baryMatchings2_L0_ = baryMatchings2; - inputToBaryDistances_L0_ = inputToBaryDistances; - } - } else { - baryMatchings = baryMatchings_L0_; - baryMatchings2 = baryMatchings2_L0_; - inputToBaryDistances = inputToBaryDistances_L0_; - } - printMsg("Compute origin time", 1, t_origin.getElapsedTime(), - threadNumber_, debug::LineMode::NEW, debug::Priority::DETAIL); - - // - Compute vectors - printMsg("Compute vectors...", debug::Priority::DETAIL); - Timer t_vectors; - auto &tmTreesToUse = (l == 0 ? trees : recs); - auto &tmTrees2ToUse = (l == 0 ? trees2 : recs2); - initInputBasisVectors( - tmTreesToUse, tmTrees2ToUse, treesToUse, trees2ToUse, origins_[l], - origins2_[l], layersNoAxes[l], allAlphasInit, l, inputToBaryDistances, - baryMatchings, baryMatchings2, vSTensor_[l], vS2Tensor_[l]); - printMsg("Compute vectors time", 1, t_vectors.getElapsedTime(), - threadNumber_, debug::LineMode::NEW, debug::Priority::DETAIL); + auto &treesToUse = (l == 0 ? trees : recs); + auto &trees2ToUse = (l == 0 ? trees2 : recs2); + initInputBasis( + l, layersNoAxes[l], treesToUse, trees2ToUse, isTrain, allAlphasInit); } else { // - Copy output tensors of the opposite layer (full symmetric init) printMsg( "Copy output tensors of the opposite layer", debug::Priority::DETAIL); unsigned int middle = noLayers_ / 2; unsigned int l_opp = middle - (l - middle + 1); - mtu::copyTorchMergeTree(originsPrime_[l_opp], origins_[l]); - mtu::copyTensor(vSPrimeTensor_[l_opp], vSTensor_[l]); + layers_[l].setOrigin(layers_[l_opp].getOriginPrime()); + layers_[l].setVSTensor(layers_[l_opp].getVSPrimeTensor()); if(trees2.size() != 0) { if(fullSymmetricAE) { - mtu::copyTorchMergeTree(origins2Prime_[l_opp], origins2_[l]); - mtu::copyTensor(vS2PrimeTensor_[l_opp], vS2Tensor_[l]); + layers_[l].setOrigin2(layers_[l_opp].getOrigin2Prime()); + layers_[l].setVS2Tensor(layers_[l_opp].getVS2PrimeTensor()); } } for(unsigned int i = 0; i < trees.size(); ++i) @@ -606,102 +184,40 @@ float ttk::MergeTreeAutoencoder::initParameters( } // --- Init Output Basis - auto initOutputBasisSpecialCase - = [this, &l, &layersNoAxes, &trees, &trees2]() { - // - Compute Origin - printMsg("Compute output basis origin", debug::Priority::DETAIL); - mtu::copyTorchMergeTree(origins_[0], originsPrime_[l]); - if(useDoubleInput_) - mtu::copyTorchMergeTree(origins2_[0], origins2Prime_[l]); - // - Compute vectors - printMsg("Compute output basis vectors", debug::Priority::DETAIL); - if(layersNoAxes[l] != layersNoAxes[0]) { - // TODO is there a way to avoid copy of merge trees? - std::vector> treesToUse, trees2ToUse; - for(unsigned int i = 0; i < trees.size(); ++i) { - treesToUse.emplace_back(trees[i].mTree); - if(useDoubleInput_) - trees2ToUse.emplace_back(trees2[i].mTree); - } - std::vector> allAlphasInitT( - trees.size(), std::vector(noLayers_)); - initInputBasisVectors( - trees, trees2, treesToUse, trees2ToUse, originsPrime_[l], - origins2Prime_[l], layersNoAxes[l], allAlphasInitT, l, - inputToBaryDistances_L0_, baryMatchings_L0_, baryMatchings2_L0_, - vSPrimeTensor_[l], vS2PrimeTensor_[l]); - } else { - mtu::copyTensor(vSTensor_[0], vSPrimeTensor_[l]); - if(useDoubleInput_) - mtu::copyTensor(vS2Tensor_[0], vS2PrimeTensor_[l]); - } - }; - if((noLayers_ == 2 and l == 1) or noLayers_ == 1) { // -- Special case - initOutputBasisSpecialCase(); + initOutputBasisSpecialCase(l, layersNoAxes[l], trees, trees2); } else if(l < (unsigned int)(noLayers_ / 2)) { - unsigned int dim = getDim(sizeMetric, layersOriginPrimeSizePercent[l]); - dim = std::min(dim, (unsigned int)origins_[l].tensor.sizes()[0]); - unsigned int dim2 = getDim(sizeMetric2, layersOriginPrimeSizePercent[l]); - if(trees2.size() != 0) - dim2 = std::min(dim2, (unsigned int)origins2_[l].tensor.sizes()[0]); - initOutputBasis(l, dim, dim2); + initOutputBasis( + l, layersOriginPrimeSizePercent[l], trees, trees2, isTrain); } else { // - Copy input tensors of the opposite layer (symmetric init) printMsg( "Copy input tensors of the opposite layer", debug::Priority::DETAIL); unsigned int middle = noLayers_ / 2; unsigned int l_opp = middle - (l - middle + 1); - mtu::copyTorchMergeTree(origins_[l_opp], originsPrime_[l]); + layers_[l].setOriginPrime(layers_[l_opp].getOrigin()); if(trees2.size() != 0) - mtu::copyTorchMergeTree(origins2_[l_opp], origins2Prime_[l]); + layers_[l].setOrigin2Prime(layers_[l_opp].getOrigin2()); if(l == (unsigned int)(noLayers_) / 2 and scaleLayerAfterLatent_) { unsigned int dim2 - = (trees2.size() != 0 ? origins2Prime_[l].tensor.sizes()[0] : 0); - initOutputBasisVectors(l, originsPrime_[l].tensor.sizes()[0], dim2); + = (trees2.size() != 0 ? layers_[l].getOrigin2Prime().tensor.sizes()[0] + : 0); + layers_[l].initOutputBasisVectors( + layers_[l].getOriginPrime().tensor.sizes()[0], dim2); } else { - mtu::copyTensor(vSTensor_[l_opp], vSPrimeTensor_[l]); + layers_[l].setVSPrimeTensor(layers_[l_opp].getVSTensor()); if(trees2.size() != 0) - mtu::copyTensor(vS2Tensor_[l_opp], vS2PrimeTensor_[l]); + layers_[l].setVS2PrimeTensor(layers_[l_opp].getVS2Tensor()); } } // --- Get reconstructed - printMsg("Get reconstructed", debug::Priority::DETAIL); - recs.resize(trees.size()); - recs2.resize(trees.size()); - unsigned int i = 0; - unsigned int noReset = 0; - while(i < trees.size()) { - outputBasisReconstruction(originsPrime_[l], vSPrimeTensor_[l], - origins2Prime_[l], vS2PrimeTensor_[l], - allAlphasInit[i][l], recs[i], recs2[i], - outputBasisActivation); - if(recs[i].mTree.tree.getRealNumberOfNodes() == 0) { - printMsg("Reset output basis", debug::Priority::DETAIL); - if((noLayers_ == 2 and l == 1) or noLayers_ == 1) { - initOutputBasisSpecialCase(); - } else if(l < (unsigned int)(noLayers_ / 2)) { - initOutputBasis(l, - getDim(sizeMetric, layersOriginPrimeSizePercent[l]), - getDim(sizeMetric2, layersOriginPrimeSizePercent[l])); - } else { - printErr("recs[i].mTree.tree.getRealNumberOfNodes() == 0"); - std::stringstream ssT; - ssT << "layer " << l; - printWrn(ssT.str()); - return std::numeric_limits::max(); - } - i = 0; - ++noReset; - if(noReset >= 100) { - printWrn("[initParameters] noReset >= 100"); - return std::numeric_limits::max(); - } - } - ++i; - } + bool fullReset = initGetReconstructed( + l, layersNoAxes[l], layersOriginPrimeSizePercent[l], trees, trees2, + isTrain, recs, recs2, allAlphasInit); + if(fullReset) + return std::numeric_limits::max(); } allAlphas_ = allAlphasInit; @@ -711,7 +227,7 @@ float ttk::MergeTreeAutoencoder::initParameters( // Compute error float error = 0.0, recLoss = 0.0; - if(computeReconstructionError) { + if(computeError) { printMsg("Compute error", debug::Priority::DETAIL); std::vector indexes(trees.size()); std::iota(indexes.begin(), indexes.end(), 0); @@ -722,10 +238,9 @@ float ttk::MergeTreeAutoencoder::initParameters( layersOuts2; std::vector>> matchings, matchings2; - bool reset - = forwardStep(trees, trees2, indexes, k, allAlphasInit, - computeReconstructionError, recs, recs2, bestAlphas, - layersOuts, layersOuts2, matchings, matchings2, recLoss); + bool reset = forwardStep(trees, trees2, indexes, k, allAlphasInit, + computeError, recs, recs2, bestAlphas, layersOuts, + layersOuts2, matchings, matchings2, recLoss); if(reset) { printWrn("[initParameters] forwardStep reset"); return std::numeric_limits::max(); @@ -758,568 +273,6 @@ float ttk::MergeTreeAutoencoder::initParameters( return error; } -void ttk::MergeTreeAutoencoder::initStep( - std::vector> &trees, - std::vector> &trees2) { - origins_.clear(); - originsPrime_.clear(); - vSTensor_.clear(); - vSPrimeTensor_.clear(); - origins2_.clear(); - origins2Prime_.clear(); - vS2Tensor_.clear(); - vS2PrimeTensor_.clear(); - - float bestError = std::numeric_limits::max(); - std::vector bestVSTensor, bestVSPrimeTensor, bestVS2Tensor, - bestVS2PrimeTensor, bestLatentCentroids; - std::vector> bestOrigins, bestOriginsPrime, - bestOrigins2, bestOrigins2Prime; - std::vector> bestAlphasInit; - for(unsigned int n = 0; n < noInit_; ++n) { - // Init parameters - float error = initParameters(trees, trees2, (noInit_ != 1)); - // Save best parameters - if(noInit_ != 1) { - std::stringstream ss; - ss << "Init error = " << error; - printMsg(ss.str()); - if(error < bestError) { - bestError = error; - copyParams(origins_, originsPrime_, vSTensor_, vSPrimeTensor_, - origins2_, origins2Prime_, vS2Tensor_, vS2PrimeTensor_, - allAlphas_, bestOrigins, bestOriginsPrime, bestVSTensor, - bestVSPrimeTensor, bestOrigins2, bestOrigins2Prime, - bestVS2Tensor, bestVS2PrimeTensor, bestAlphasInit); - bestLatentCentroids.resize(latentCentroids_.size()); - for(unsigned int i = 0; i < latentCentroids_.size(); ++i) - mtu::copyTensor(latentCentroids_[i], bestLatentCentroids[i]); - } - } - } - // TODO this copy can be avoided if initParameters takes dummy tensors to fill - // as parameters and then copy to the member tensors when a better init is - // found. - if(noInit_ != 1) { - // Put back best parameters - std::stringstream ss; - ss << "Best init error = " << bestError; - printMsg(ss.str()); - copyParams(bestOrigins, bestOriginsPrime, bestVSTensor, bestVSPrimeTensor, - bestOrigins2, bestOrigins2Prime, bestVS2Tensor, - bestVS2PrimeTensor, bestAlphasInit, origins_, originsPrime_, - vSTensor_, vSPrimeTensor_, origins2_, origins2Prime_, vS2Tensor_, - vS2PrimeTensor_, allAlphas_); - latentCentroids_.resize(bestLatentCentroids.size()); - for(unsigned int i = 0; i < bestLatentCentroids.size(); ++i) - mtu::copyTensor(bestLatentCentroids[i], latentCentroids_[i]); - } - - for(unsigned int l = 0; l < noLayers_; ++l) { - origins_[l].tensor.requires_grad_(true); - originsPrime_[l].tensor.requires_grad_(true); - vSTensor_[l].requires_grad_(true); - vSPrimeTensor_[l].requires_grad_(true); - if(trees2.size() != 0) { - origins2_[l].tensor.requires_grad_(true); - origins2Prime_[l].tensor.requires_grad_(true); - vS2Tensor_[l].requires_grad_(true); - vS2PrimeTensor_[l].requires_grad_(true); - } - - // Print - printMsg(debug::Separator::L2); - std::stringstream ss; - ss << "Layer " << l; - printMsg(ss.str()); - if(isTreeHasBigValues(origins_[l].mTree, bigValuesThreshold_)) { - ss.str(""); - ss << "origins_[" << l << "] has big values!" << std::endl; - printMsg(ss.str()); - wae::printPairs(origins_[l].mTree); - } - if(isTreeHasBigValues(originsPrime_[l].mTree, bigValuesThreshold_)) { - ss.str(""); - ss << "originsPrime_[" << l << "] has big values!" << std::endl; - printMsg(ss.str()); - wae::printPairs(originsPrime_[l].mTree); - } - ss.str(""); - ss << "vS size = " << vSTensor_[l].sizes(); - printMsg(ss.str()); - ss.str(""); - ss << "vS' size = " << vSPrimeTensor_[l].sizes(); - printMsg(ss.str()); - if(trees2.size() != 0) { - ss.str(""); - ss << "vS2 size = " << vS2Tensor_[l].sizes(); - printMsg(ss.str()); - ss.str(""); - ss << "vS2' size = " << vS2PrimeTensor_[l].sizes(); - printMsg(ss.str()); - } - } - - // Init Clustering Loss Parameters - if(clusteringLossWeight_ != 0) - initClusteringLossParameters(); -} - -// --------------------------------------------------------------------------- -// --- Interpolation -// --------------------------------------------------------------------------- -void ttk::MergeTreeAutoencoder::interpolationDiagonalProjection( - mtu::TorchMergeTree &interpolation) { - torch::Tensor diagTensor = interpolation.tensor.reshape({-1, 2}); - if(interpolation.tensor.requires_grad()) - diagTensor = diagTensor.detach(); - - torch::Tensor birthTensor = diagTensor.index({Slice(), 0}); - torch::Tensor deathTensor = diagTensor.index({Slice(), 1}); - - torch::Tensor indexer = (birthTensor > deathTensor); - - torch::Tensor allProj = (birthTensor + deathTensor) / 2.0; - allProj = allProj.index({indexer}); - allProj = allProj.reshape({-1, 1}); - - diagTensor.index_put_({indexer}, allProj); -} - -void ttk::MergeTreeAutoencoder::interpolationNestingProjection( - mtu::TorchMergeTree &interpolation) { - torch::Tensor diagTensor = interpolation.tensor.reshape({-1, 2}); - if(interpolation.tensor.requires_grad()) - diagTensor = diagTensor.detach(); - - torch::Tensor birthTensor = diagTensor.index({Slice(1, None), 0}); - torch::Tensor deathTensor = diagTensor.index({Slice(1, None), 1}); - - torch::Tensor birthIndexer = (birthTensor < 0); - torch::Tensor deathIndexer = (deathTensor < 0); - birthTensor.index_put_( - {birthIndexer}, torch::zeros_like(birthTensor.index({birthIndexer}))); - deathTensor.index_put_( - {deathIndexer}, torch::zeros_like(deathTensor.index({deathIndexer}))); - - birthIndexer = (birthTensor > 1); - deathIndexer = (deathTensor > 1); - birthTensor.index_put_( - {birthIndexer}, torch::ones_like(birthTensor.index({birthIndexer}))); - deathTensor.index_put_( - {deathIndexer}, torch::ones_like(deathTensor.index({deathIndexer}))); -} - -void ttk::MergeTreeAutoencoder::interpolationProjection( - mtu::TorchMergeTree &interpolation) { - interpolationDiagonalProjection(interpolation); - if(normalizedWasserstein_) - interpolationNestingProjection(interpolation); - - ftm::MergeTree interpolationNew; - bool noRoot = mtu::torchTensorToMergeTree( - interpolation, normalizedWasserstein_, interpolationNew); - if(noRoot) - printWrn("[interpolationProjection] no root found"); - interpolation.mTree = copyMergeTree(interpolationNew); - - persistenceThresholding(&(interpolation.mTree.tree), 0.001); - - if(isThereMissingPairs(interpolation) and isPersistenceDiagram_) - printWrn("[getMultiInterpolation] missing pairs"); -} - -void ttk::MergeTreeAutoencoder::getMultiInterpolation( - mtu::TorchMergeTree &origin, - torch::Tensor &vS, - torch::Tensor &alphas, - mtu::TorchMergeTree &interpolation) { - mtu::copyTorchMergeTree(origin, interpolation); - interpolation.tensor = origin.tensor + torch::matmul(vS, alphas); - interpolationProjection(interpolation); -} - -// --------------------------------------------------------------------------- -// --- Forward -// --------------------------------------------------------------------------- -void ttk::MergeTreeAutoencoder::getAlphasOptimizationTensors( - mtu::TorchMergeTree &tree, - mtu::TorchMergeTree &origin, - torch::Tensor &vSTensor, - mtu::TorchMergeTree &interpolated, - std::vector> &matching, - torch::Tensor &reorderedTreeTensor, - torch::Tensor &deltaOrigin, - torch::Tensor &deltaA, - torch::Tensor &originTensor_f, - torch::Tensor &vSTensor_f) { - // Create matching indexing - std::vector tensorMatching; - mtu::getTensorMatching(interpolated, tree, matching, tensorMatching); - - torch::Tensor indexes = torch::tensor(tensorMatching); - torch::Tensor projIndexer = (indexes == -1).reshape({-1, 1}); - - dataReorderingGivenMatching( - origin, tree, projIndexer, indexes, reorderedTreeTensor, deltaOrigin); - - // Create axes projection given matching - deltaA = vSTensor.transpose(0, 1).reshape({vSTensor.sizes()[1], -1, 2}); - deltaA = (deltaA.index({Slice(), Slice(), 0}) - + deltaA.index({Slice(), Slice(), 1})) - / 2.0; - deltaA = torch::stack({deltaA, deltaA}, 2); - deltaA = deltaA * projIndexer; - deltaA = deltaA.reshape({vSTensor.sizes()[1], -1}).transpose(0, 1); - - // - originTensor_f = origin.tensor; - vSTensor_f = vSTensor; -} - -void ttk::MergeTreeAutoencoder::computeAlphas( - mtu::TorchMergeTree &tree, - mtu::TorchMergeTree &origin, - torch::Tensor &vSTensor, - mtu::TorchMergeTree &interpolated, - std::vector> &matching, - mtu::TorchMergeTree &tree2, - mtu::TorchMergeTree &origin2, - torch::Tensor &vS2Tensor, - mtu::TorchMergeTree &interpolated2, - std::vector> &matching2, - torch::Tensor &alphasOut) { - torch::Tensor reorderedTreeTensor, deltaOrigin, deltaA, originTensor_f, - vSTensor_f; - getAlphasOptimizationTensors(tree, origin, vSTensor, interpolated, matching, - reorderedTreeTensor, deltaOrigin, deltaA, - originTensor_f, vSTensor_f); - - if(useDoubleInput_) { - torch::Tensor reorderedTree2Tensor, deltaOrigin2, deltaA2, origin2Tensor_f, - vS2Tensor_f; - getAlphasOptimizationTensors(tree2, origin2, vS2Tensor, interpolated2, - matching2, reorderedTree2Tensor, deltaOrigin2, - deltaA2, origin2Tensor_f, vS2Tensor_f); - vSTensor_f = torch::cat({vSTensor_f, vS2Tensor_f}); - deltaA = torch::cat({deltaA, deltaA2}); - reorderedTreeTensor - = torch::cat({reorderedTreeTensor, reorderedTree2Tensor}); - originTensor_f = torch::cat({originTensor_f, origin2Tensor_f}); - deltaOrigin = torch::cat({deltaOrigin, deltaOrigin2}); - } - - torch::Tensor r_axes = vSTensor_f - deltaA; - torch::Tensor r_data = reorderedTreeTensor - originTensor_f + deltaOrigin; - - // Pseudo inverse - auto driver = "gelsd"; - alphasOut - = std::get<0>(torch::linalg::lstsq(r_axes, r_data, c10::nullopt, driver)); - - alphasOut.reshape({-1, 1}); -} - -float ttk::MergeTreeAutoencoder::assignmentOneData( - mtu::TorchMergeTree &tree, - mtu::TorchMergeTree &origin, - torch::Tensor &vSTensor, - mtu::TorchMergeTree &tree2, - mtu::TorchMergeTree &origin2, - torch::Tensor &vS2Tensor, - unsigned int k, - torch::Tensor &alphasInit, - std::vector> &bestMatching, - std::vector> &bestMatching2, - torch::Tensor &bestAlphas, - bool isCalled) { - torch::Tensor alphas, oldAlphas; - std::vector> matching, matching2; - float bestDistance = std::numeric_limits::max(); - mtu::TorchMergeTree interpolated, interpolated2; - unsigned int i = 0; - auto reset = [&]() { - alphasInit = torch::randn_like(alphas); - i = 0; - }; - unsigned int noUpdate = 0; - unsigned int noReset = 0; - while(i < k) { - if(i == 0) { - if(alphasInit.defined()) - alphas = alphasInit; - else - alphas = torch::zeros({vSTensor.sizes()[1], 1}); - } else { - computeAlphas(tree, origin, vSTensor, interpolated, matching, tree2, - origin2, vS2Tensor, interpolated2, matching2, alphas); - if(oldAlphas.defined() and alphas.defined() and alphas.equal(oldAlphas) - and i != 1) { - break; - } - } - mtu::copyTensor(alphas, oldAlphas); - getMultiInterpolation(origin, vSTensor, alphas, interpolated); - if(useDoubleInput_) - getMultiInterpolation(origin2, vS2Tensor, alphas, interpolated2); - if(interpolated.mTree.tree.getRealNumberOfNodes() == 0 - or (useDoubleInput_ - and interpolated2.mTree.tree.getRealNumberOfNodes() == 0)) { - ++noReset; - if(noReset >= 100) - printWrn("[assignmentOneData] noReset >= 100"); - reset(); - continue; - } - float distance; - computeOneDistance(interpolated.mTree, tree.mTree, matching, - distance, isCalled, useDoubleInput_); - if(useDoubleInput_) { - float distance2; - computeOneDistance(interpolated2.mTree, tree2.mTree, matching2, - distance2, isCalled, useDoubleInput_, false); - distance = mixDistances(distance, distance2); - } - if(distance < bestDistance and i != 0) { - bestDistance = distance; - bestMatching = matching; - bestMatching2 = matching2; - bestAlphas = alphas; - noUpdate += 1; - } - i += 1; - } - if(noUpdate == 0) - printErr("[assignmentOneData] noUpdate == 0"); - return bestDistance; -} - -float ttk::MergeTreeAutoencoder::assignmentOneData( - mtu::TorchMergeTree &tree, - mtu::TorchMergeTree &origin, - torch::Tensor &vSTensor, - mtu::TorchMergeTree &tree2, - mtu::TorchMergeTree &origin2, - torch::Tensor &vS2Tensor, - unsigned int k, - torch::Tensor &alphasInit, - torch::Tensor &bestAlphas, - bool isCalled) { - std::vector> bestMatching, - bestMatching2; - return assignmentOneData(tree, origin, vSTensor, tree2, origin2, vS2Tensor, k, - alphasInit, bestMatching, bestMatching2, bestAlphas, - isCalled); -} - -torch::Tensor ttk::MergeTreeAutoencoder::activation(torch::Tensor &in) { - torch::Tensor act; - switch(activationFunction_) { - case 1: - act = torch::nn::LeakyReLU()(in); - break; - case 0: - default: - act = torch::nn::ReLU()(in); - } - return act; -} - -void ttk::MergeTreeAutoencoder::outputBasisReconstruction( - mtu::TorchMergeTree &originPrime, - torch::Tensor &vSPrimeTensor, - mtu::TorchMergeTree &origin2Prime, - torch::Tensor &vS2PrimeTensor, - torch::Tensor &alphas, - mtu::TorchMergeTree &out, - mtu::TorchMergeTree &out2, - bool activate) { - if(not activate_) - activate = false; - torch::Tensor act = (activate ? activation(alphas) : alphas); - getMultiInterpolation(originPrime, vSPrimeTensor, act, out); - if(useDoubleInput_) - getMultiInterpolation(origin2Prime, vS2PrimeTensor, act, out2); -} - -bool ttk::MergeTreeAutoencoder::forwardOneLayer( - mtu::TorchMergeTree &tree, - mtu::TorchMergeTree &origin, - torch::Tensor &vSTensor, - mtu::TorchMergeTree &originPrime, - torch::Tensor &vSPrimeTensor, - mtu::TorchMergeTree &tree2, - mtu::TorchMergeTree &origin2, - torch::Tensor &vS2Tensor, - mtu::TorchMergeTree &origin2Prime, - torch::Tensor &vS2PrimeTensor, - unsigned int k, - torch::Tensor &alphasInit, - mtu::TorchMergeTree &out, - mtu::TorchMergeTree &out2, - torch::Tensor &bestAlphas, - float &bestDistance) { - bool goodOutput = false; - int noReset = 0; - while(not goodOutput) { - bool isCalled = true; - bestDistance - = assignmentOneData(tree, origin, vSTensor, tree2, origin2, vS2Tensor, k, - alphasInit, bestAlphas, isCalled); - outputBasisReconstruction(originPrime, vSPrimeTensor, origin2Prime, - vS2PrimeTensor, bestAlphas, out, out2); - goodOutput = (out.mTree.tree.getRealNumberOfNodes() != 0 - and (not useDoubleInput_ - or out2.mTree.tree.getRealNumberOfNodes() != 0)); - if(not goodOutput) { - ++noReset; - if(noReset >= 100) { - printWrn("[forwardOneLayer] noReset >= 100"); - return true; - } - alphasInit = torch::randn_like(alphasInit); - } - } - return false; -} - -bool ttk::MergeTreeAutoencoder::forwardOneLayer( - mtu::TorchMergeTree &tree, - mtu::TorchMergeTree &origin, - torch::Tensor &vSTensor, - mtu::TorchMergeTree &originPrime, - torch::Tensor &vSPrimeTensor, - mtu::TorchMergeTree &tree2, - mtu::TorchMergeTree &origin2, - torch::Tensor &vS2Tensor, - mtu::TorchMergeTree &origin2Prime, - torch::Tensor &vS2PrimeTensor, - unsigned int k, - torch::Tensor &alphasInit, - mtu::TorchMergeTree &out, - mtu::TorchMergeTree &out2, - torch::Tensor &bestAlphas) { - float bestDistance; - return forwardOneLayer(tree, origin, vSTensor, originPrime, vSPrimeTensor, - tree2, origin2, vS2Tensor, origin2Prime, - vS2PrimeTensor, k, alphasInit, out, out2, bestAlphas, - bestDistance); -} - -bool ttk::MergeTreeAutoencoder::forwardOneData( - mtu::TorchMergeTree &tree, - mtu::TorchMergeTree &tree2, - unsigned int treeIndex, - unsigned int k, - std::vector &alphasInit, - mtu::TorchMergeTree &out, - mtu::TorchMergeTree &out2, - std::vector &dataAlphas, - std::vector> &outs, - std::vector> &outs2) { - outs.resize(noLayers_ - 1); - outs2.resize(noLayers_ - 1); - dataAlphas.resize(noLayers_); - for(unsigned int l = 0; l < noLayers_; ++l) { - auto &treeToUse = (l == 0 ? tree : outs[l - 1]); - auto &tree2ToUse = (l == 0 ? tree2 : outs2[l - 1]); - auto &outToUse = (l != noLayers_ - 1 ? outs[l] : out); - auto &out2ToUse = (l != noLayers_ - 1 ? outs2[l] : out2); - bool reset = forwardOneLayer( - treeToUse, origins_[l], vSTensor_[l], originsPrime_[l], vSPrimeTensor_[l], - tree2ToUse, origins2_[l], vS2Tensor_[l], origins2Prime_[l], - vS2PrimeTensor_[l], k, alphasInit[l], outToUse, out2ToUse, dataAlphas[l]); - if(reset) - return true; - // Update recs - auto updateRecs - = [this, &treeIndex, &l]( - std::vector>> &recs, - mtu::TorchMergeTree &outT) { - if(recs[treeIndex].size() > noLayers_) - mtu::copyTorchMergeTree(outT, recs[treeIndex][l + 1]); - else { - mtu::TorchMergeTree tmt; - mtu::copyTorchMergeTree(outT, tmt); - recs[treeIndex].emplace_back(tmt); - } - }; - updateRecs(recs_, outToUse); - if(useDoubleInput_) - updateRecs(recs2_, out2ToUse); - } - return false; -} - -bool ttk::MergeTreeAutoencoder::forwardStep( - std::vector> &trees, - std::vector> &trees2, - std::vector &indexes, - unsigned int k, - std::vector> &allAlphasInit, - bool computeReconstructionError, - std::vector> &outs, - std::vector> &outs2, - std::vector> &bestAlphas, - std::vector>> &layersOuts, - std::vector>> &layersOuts2, - std::vector>> - &matchings, - std::vector>> - &matchings2, - float &loss) { - loss = 0; - outs.resize(trees.size()); - outs2.resize(trees.size()); - bestAlphas.resize(trees.size()); - layersOuts.resize(trees.size()); - layersOuts2.resize(trees.size()); - matchings.resize(trees.size()); - if(useDoubleInput_) - matchings2.resize(trees2.size()); - mtu::TorchMergeTree dummyTMT; - bool reset = false; -#ifdef TTK_ENABLE_OPENMP -#pragma omp parallel for schedule(dynamic) num_threads(this->threadNumber_) \ - if(parallelize_) reduction(||: reset) reduction(+:loss) -#endif - for(unsigned int ind = 0; ind < indexes.size(); ++ind) { - unsigned int i = indexes[ind]; - auto &tree2ToUse = (trees2.size() == 0 ? dummyTMT : trees2[i]); - bool dReset - = forwardOneData(trees[i], tree2ToUse, i, k, allAlphasInit[i], outs[i], - outs2[i], bestAlphas[i], layersOuts[i], layersOuts2[i]); - if(computeReconstructionError) { - float iLoss = computeOneLoss( - trees[i], outs[i], trees2[i], outs2[i], matchings[i], matchings2[i]); - loss += iLoss; - } - if(dReset) - reset = reset || dReset; - } - loss /= indexes.size(); - return reset; -} - -bool ttk::MergeTreeAutoencoder::forwardStep( - std::vector> &trees, - std::vector> &trees2, - std::vector &indexes, - unsigned int k, - std::vector> &allAlphasInit, - std::vector> &outs, - std::vector> &outs2, - std::vector> &bestAlphas) { - std::vector>> layersOuts, layersOuts2; - std::vector>> - matchings, matchings2; - bool computeReconstructionError = false; - float loss; - return forwardStep(trees, trees2, indexes, k, allAlphasInit, - computeReconstructionError, outs, outs2, bestAlphas, - layersOuts, layersOuts2, matchings, matchings2, loss); -} - // --------------------------------------------------------------------------- // --- Backward // --------------------------------------------------------------------------- @@ -1332,11 +285,11 @@ bool ttk::MergeTreeAutoencoder::backwardStep( std::vector> &outs2, std::vector>> &matchings2, + std::vector> &ttkNotUsed(alphas), torch::optim::Optimizer &optimizer, std::vector &indexes, - torch::Tensor &metricLoss, - torch::Tensor &clusteringLoss, - torch::Tensor &trackingLoss) { + std::vector &ttkNotUsed(isTrain), + std::vector &torchCustomLoss) { double totalLoss = 0; bool retainGraph = (metricLossWeight_ != 0 or clusteringLossWeight_ != 0 or trackingLossWeight_ != 0); @@ -1378,76 +331,31 @@ bool ttk::MergeTreeAutoencoder::backwardStep( if(metricLossWeight_ != 0) { bool retainGraphMetricLoss = (clusteringLossWeight_ != 0 or trackingLossWeight_ != 0); - metricLoss *= metricLossWeight_ - * getCustomLossDynamicWeight( - totalLoss / indexes.size(), baseRecLoss2_); - metricLoss.backward({}, retainGraphMetricLoss); + torchCustomLoss[0] *= metricLossWeight_ + * getCustomLossDynamicWeight( + totalLoss / indexes.size(), baseRecLoss2_); + torchCustomLoss[0].backward({}, retainGraphMetricLoss); } if(clusteringLossWeight_ != 0) { bool retainGraphClusteringLoss = (trackingLossWeight_ != 0); - clusteringLoss *= clusteringLossWeight_ - * getCustomLossDynamicWeight( - totalLoss / indexes.size(), baseRecLoss2_); - clusteringLoss.backward({}, retainGraphClusteringLoss); - } - if(trackingLossWeight_ != 0) { - trackingLoss *= trackingLossWeight_; - trackingLoss.backward(); + torchCustomLoss[1] *= clusteringLossWeight_ + * getCustomLossDynamicWeight( + totalLoss / indexes.size(), baseRecLoss2_); + torchCustomLoss[1].backward({}, retainGraphClusteringLoss); } - - for(unsigned int l = 0; l < noLayers_; ++l) { - if(not origins_[l].tensor.grad().defined() - or not origins_[l].tensor.grad().count_nonzero().is_nonzero()) - ++originsNoZeroGrad_[l]; - if(not originsPrime_[l].tensor.grad().defined() - or not originsPrime_[l].tensor.grad().count_nonzero().is_nonzero()) - ++originsPrimeNoZeroGrad_[l]; - if(not vSTensor_[l].grad().defined() - or not vSTensor_[l].grad().count_nonzero().is_nonzero()) - ++vSNoZeroGrad_[l]; - if(not vSPrimeTensor_[l].grad().defined() - or not vSPrimeTensor_[l].grad().count_nonzero().is_nonzero()) - ++vSPrimeNoZeroGrad_[l]; - if(useDoubleInput_) { - if(not origins2_[l].tensor.grad().defined() - or not origins2_[l].tensor.grad().count_nonzero().is_nonzero()) - ++origins2NoZeroGrad_[l]; - if(not origins2Prime_[l].tensor.grad().defined() - or not origins2Prime_[l].tensor.grad().count_nonzero().is_nonzero()) - ++origins2PrimeNoZeroGrad_[l]; - if(not vS2Tensor_[l].grad().defined() - or not vS2Tensor_[l].grad().count_nonzero().is_nonzero()) - ++vS2NoZeroGrad_[l]; - if(not vS2PrimeTensor_[l].grad().defined() - or not vS2PrimeTensor_[l].grad().count_nonzero().is_nonzero()) - ++vS2PrimeNoZeroGrad_[l]; - } + if(trackingLossWeight_ != 0) { + torchCustomLoss[2] *= trackingLossWeight_; + torchCustomLoss[2].backward(); } + for(unsigned int l = 0; l < noLayers_; ++l) + checkZeroGrad(l); + optimizer.step(); optimizer.zero_grad(); return false; } -// --------------------------------------------------------------------------- -// --- Projection -// --------------------------------------------------------------------------- -void ttk::MergeTreeAutoencoder::projectionStep() { - auto projectTree = [this](mtu::TorchMergeTree &tmt) { - interpolationProjection(tmt); - tmt.tensor = tmt.tensor.detach(); - tmt.tensor.requires_grad_(true); - }; - for(unsigned int l = 0; l < noLayers_; ++l) { - projectTree(origins_[l]); - projectTree(originsPrime_[l]); - if(useDoubleInput_) { - projectTree(origins2_[l]); - projectTree(origins2Prime_[l]); - } - } -} - // --------------------------------------------------------------------------- // --- Convergence // --------------------------------------------------------------------------- @@ -1457,7 +365,9 @@ float ttk::MergeTreeAutoencoder::computeOneLoss( mtu::TorchMergeTree &tree2, mtu::TorchMergeTree &out2, std::vector> &matching, - std::vector> &matching2) { + std::vector> &matching2, + std::vector &ttkNotUsed(alphas), + unsigned int ttkNotUsed(treeIndex)) { float loss = 0; bool isCalled = true; float distance; @@ -1473,563 +383,162 @@ float ttk::MergeTreeAutoencoder::computeOneLoss( return loss; } -float ttk::MergeTreeAutoencoder::computeLoss( - std::vector> &trees, - std::vector> &outs, - std::vector> &trees2, - std::vector> &outs2, - std::vector &indexes, - std::vector>> - &matchings, - std::vector>> - &matchings2) { - float loss = 0; - matchings.resize(trees.size()); - if(useDoubleInput_) - matchings2.resize(trees2.size()); -#ifdef TTK_ENABLE_OPENMP -#pragma omp parallel for schedule(dynamic) num_threads(this->threadNumber_) \ - if(parallelize_) reduction(+:loss) -#endif - for(unsigned int ind = 0; ind < indexes.size(); ++ind) { - unsigned int i = indexes[ind]; - float iLoss = computeOneLoss( - trees[i], outs[i], trees2[i], outs2[i], matchings[i], matchings2[i]); - loss += iLoss; - } - return loss / indexes.size(); -} - -bool ttk::MergeTreeAutoencoder::isBestLoss(float loss, - float &minLoss, - unsigned int &cptBlocked) { - bool isBestEnergy = false; - if(loss + ENERGY_COMPARISON_TOLERANCE < minLoss) { - minLoss = loss; - cptBlocked = 0; - isBestEnergy = true; - } - return isBestEnergy; -} - -bool ttk::MergeTreeAutoencoder::convergenceStep(float loss, - float &oldLoss, - float &minLoss, - unsigned int &cptBlocked) { - double tol = oldLoss / 125.0; - bool converged = std::abs(loss - oldLoss) < std::abs(tol); - oldLoss = loss; - if(not converged) { - cptBlocked += (minLoss < loss) ? 1 : 0; - converged = (cptBlocked >= 10 * 10); - if(converged) - printMsg("Blocked!", debug::Priority::DETAIL); - } - return converged; -} - // --------------------------------------------------------------------------- // --- Main Functions // --------------------------------------------------------------------------- -void ttk::MergeTreeAutoencoder::fit( - std::vector> &trees, - std::vector> &trees2) { - torch::set_num_threads(1); - // ----- Determinism - if(deterministic_) { - int m_seed = 0; - bool m_torch_deterministic = true; - srand(m_seed); - torch::manual_seed(m_seed); - at::globalContext().setDeterministicCuDNN(m_torch_deterministic ? true - : false); - at::globalContext().setDeterministicAlgorithms( - m_torch_deterministic ? true : false, true); - } - - // ----- Testing - for(unsigned int i = 0; i < trees.size(); ++i) { - for(unsigned int n = 0; n < trees[i].tree.getNumberOfNodes(); ++n) { - if(trees[i].tree.isNodeAlone(n)) - continue; - auto birthDeath = trees[i].tree.template getBirthDeath(n); - bigValuesThreshold_ - = std::max(std::abs(std::get<0>(birthDeath)), bigValuesThreshold_); - bigValuesThreshold_ - = std::max(std::abs(std::get<1>(birthDeath)), bigValuesThreshold_); - } - } - bigValuesThreshold_ *= 100; - - // ----- Convert MergeTree to TorchMergeTree - std::vector> torchTrees, torchTrees2; - mergeTreesToTorchTrees(trees, torchTrees, normalizedWasserstein_); - mergeTreesToTorchTrees(trees2, torchTrees2, normalizedWasserstein_); - - auto initRecs = [](std::vector>> &recs, - std::vector> &torchTreesT) { - recs.clear(); - recs.resize(torchTreesT.size()); - for(unsigned int i = 0; i < torchTreesT.size(); ++i) { - mtu::TorchMergeTree tmt; - mtu::copyTorchMergeTree(torchTreesT[i], tmt); - recs[i].emplace_back(tmt); - } - }; - initRecs(recs_, torchTrees); - if(useDoubleInput_) - initRecs(recs2_, torchTrees2); - +void ttk::MergeTreeAutoencoder::customInit( + std::vector> &torchTrees, + std::vector> &torchTrees2) { + baseRecLoss_ = std::numeric_limits::max(); + baseRecLoss2_ = std::numeric_limits::max(); // ----- Init Metric Loss if(metricLossWeight_ != 0) getDistanceMatrix(torchTrees, torchTrees2, distanceMatrix_); +} - // ----- Init Model Parameters - Timer t_init; - initStep(torchTrees, torchTrees2); - printMsg("Init", 1, t_init.getElapsedTime(), threadNumber_); - - // --- Init optimizer - std::vector parameters; - for(unsigned int l = 0; l < noLayers_; ++l) { - parameters.emplace_back(origins_[l].tensor); - parameters.emplace_back(originsPrime_[l].tensor); - parameters.emplace_back(vSTensor_[l]); - parameters.emplace_back(vSPrimeTensor_[l]); - if(trees2.size() != 0) { - parameters.emplace_back(origins2_[l].tensor); - parameters.emplace_back(origins2Prime_[l].tensor); - parameters.emplace_back(vS2Tensor_[l]); - parameters.emplace_back(vS2PrimeTensor_[l]); - } - } +void ttk::MergeTreeAutoencoder::addCustomParameters( + std::vector ¶meters) { if(clusteringLossWeight_ != 0) for(unsigned int i = 0; i < latentCentroids_.size(); ++i) parameters.emplace_back(latentCentroids_[i]); +} - torch::optim::Optimizer *optimizer; - // - Init Adam - auto adamOptions = torch::optim::AdamOptions(gradientStepSize_); - adamOptions.betas(std::make_tuple(beta1_, beta2_)); - auto adamOptimizer = torch::optim::Adam(parameters, adamOptions); - // - Init SGD optimizer - auto sgdOptions = torch::optim::SGDOptions(gradientStepSize_); - auto sgdOptimizer = torch::optim::SGD(parameters, sgdOptions); - // -Init RMSprop optimizer - auto rmspropOptions = torch::optim::RMSpropOptions(gradientStepSize_); - auto rmspropOptimizer = torch::optim::RMSprop(parameters, rmspropOptions); - // - Set optimizer pointer - switch(optimizer_) { - case 1: - optimizer = &sgdOptimizer; - break; - case 2: - optimizer = &rmspropOptimizer; - break; - case 0: - default: - optimizer = &adamOptimizer; +void ttk::MergeTreeAutoencoder::computeCustomLosses( + std::vector>> &layersOuts, + std::vector>> &layersOuts2, + std::vector> &bestAlphas, + std::vector &indexes, + std::vector &ttkNotUsed(isTrain), + unsigned int ttkNotUsed(iteration), + std::vector> &gapCustomLosses, + std::vector> &iterationCustomLosses, + std::vector &torchCustomLoss) { + if(gapCustomLosses.empty()) + gapCustomLosses.resize(3); + if(iterationCustomLosses.empty()) + iterationCustomLosses.resize(3); + torchCustomLoss.resize(3); + // - Metric Loss + if(metricLossWeight_ != 0) { + computeMetricLoss(layersOuts, layersOuts2, bestAlphas, distanceMatrix_, + indexes, torchCustomLoss[0]); + float metricLossF = torchCustomLoss[0].item(); + gapCustomLosses[0].emplace_back(metricLossF); + iterationCustomLosses[0].emplace_back(metricLossF); } - - // --- Init batches indexes - unsigned int batchSize = std::min( - std::max((int)(trees.size() * batchSize_), 1), (int)trees.size()); - std::stringstream ssBatch; - ssBatch << "batchSize = " << batchSize; - printMsg(ssBatch.str()); - unsigned int noBatch - = trees.size() / batchSize + ((trees.size() % batchSize) != 0 ? 1 : 0); - std::vector> allIndexes(noBatch); - if(noBatch == 1) { - allIndexes[0].resize(trees.size()); - std::iota(allIndexes[0].begin(), allIndexes[0].end(), 0); + // - Clustering Loss + if(clusteringLossWeight_ != 0) { + torch::Tensor asgn; + computeClusteringLoss(bestAlphas, indexes, torchCustomLoss[1], asgn); + float clusteringLossF = torchCustomLoss[1].item(); + gapCustomLosses[1].emplace_back(clusteringLossF); + iterationCustomLosses[1].emplace_back(clusteringLossF); + } + // - Tracking Loss + if(trackingLossWeight_ != 0) { + computeTrackingLoss(torchCustomLoss[2]); + float trackingLossF = torchCustomLoss[2].item(); + gapCustomLosses[2].emplace_back(trackingLossF); + iterationCustomLosses[2].emplace_back(trackingLossF); } - auto rng = std::default_random_engine{}; +} - // ----- Testing - originsNoZeroGrad_.resize(noLayers_); - originsPrimeNoZeroGrad_.resize(noLayers_); - vSNoZeroGrad_.resize(noLayers_); - vSPrimeNoZeroGrad_.resize(noLayers_); - for(unsigned int l = 0; l < noLayers_; ++l) { - originsNoZeroGrad_[l] = 0; - originsPrimeNoZeroGrad_[l] = 0; - vSNoZeroGrad_[l] = 0; - vSPrimeNoZeroGrad_[l] = 0; +float ttk::MergeTreeAutoencoder::computeIterationTotalLoss( + float iterationLoss, + std::vector> &iterationCustomLosses, + std::vector &iterationCustomLoss) { + iterationCustomLoss.emplace_back(iterationLoss); + float iterationTotalLoss = reconstructionLossWeight_ * iterationLoss; + // Metric + float iterationMetricLoss = 0; + if(metricLossWeight_ != 0) { + iterationMetricLoss + = torch::tensor(iterationCustomLosses[0]).mean().item(); + iterationTotalLoss + += metricLossWeight_ + * getCustomLossDynamicWeight(iterationLoss, baseRecLoss_) + * iterationMetricLoss; + } + iterationCustomLoss.emplace_back(iterationMetricLoss); + // Clustering + float iterationClusteringLoss = 0; + if(clusteringLossWeight_ != 0) { + iterationClusteringLoss + = torch::tensor(iterationCustomLosses[1]).mean().item(); + iterationTotalLoss + += clusteringLossWeight_ + * getCustomLossDynamicWeight(iterationLoss, baseRecLoss_) + * iterationClusteringLoss; + } + iterationCustomLoss.emplace_back(iterationClusteringLoss); + // Tracking + float iterationTrackingLoss = 0; + if(trackingLossWeight_ != 0) { + iterationTrackingLoss + = torch::tensor(iterationCustomLosses[2]).mean().item(); + iterationTotalLoss += trackingLossWeight_ * iterationTrackingLoss; + } + iterationCustomLoss.emplace_back(iterationTrackingLoss); + return iterationTotalLoss; +} + +void ttk::MergeTreeAutoencoder::printCustomLosses( + std::vector &customLoss, + std::stringstream &prefix, + const debug::Priority &priority) { + if(priority != debug::Priority::VERBOSE) + prefix.str(""); + std::stringstream ssBestLoss; + if(metricLossWeight_ != 0 or clusteringLossWeight_ != 0 + or trackingLossWeight_ != 0) { + ssBestLoss.str(""); + ssBestLoss << "- Rec. " << prefix.str() << "loss = " << customLoss[0]; + printMsg(ssBestLoss.str(), priority); } - if(useDoubleInput_) { - origins2NoZeroGrad_.resize(noLayers_); - origins2PrimeNoZeroGrad_.resize(noLayers_); - vS2NoZeroGrad_.resize(noLayers_); - vS2PrimeNoZeroGrad_.resize(noLayers_); - for(unsigned int l = 0; l < noLayers_; ++l) { - origins2NoZeroGrad_[l] = 0; - origins2PrimeNoZeroGrad_[l] = 0; - vS2NoZeroGrad_[l] = 0; - vS2PrimeNoZeroGrad_[l] = 0; - } + if(metricLossWeight_ != 0) { + ssBestLoss.str(""); + ssBestLoss << "- Metric " << prefix.str() << "loss = " << customLoss[1]; + printMsg(ssBestLoss.str(), priority); } - - // ----- Init Variables - baseRecLoss_ = std::numeric_limits::max(); - baseRecLoss2_ = std::numeric_limits::max(); - unsigned int k = k_; - float oldLoss, minLoss, minRecLoss, minMetricLoss, minClustLoss, minTrackLoss; - unsigned int cptBlocked, iteration = 0; - auto initLoop = [&]() { - oldLoss = -1; - minLoss = std::numeric_limits::max(); - minRecLoss = minLoss; - minMetricLoss = minLoss; - minClustLoss = minLoss; - minTrackLoss = minLoss; - cptBlocked = 0; - iteration = 0; - }; - initLoop(); - int convWinSize = 5; - int noConverged = 0, noConvergedToGet = 10; - std::vector losses, metricLosses, clusteringLosses, trackingLosses; - float windowLoss = 0; - - double assignmentTime = 0.0, updateTime = 0.0, projectionTime = 0.0, - lossTime = 0.0; - - int bestIteration = 0; - std::vector bestVSTensor, bestVSPrimeTensor, bestVS2Tensor, - bestVS2PrimeTensor; - std::vector> bestOrigins, bestOriginsPrime, - bestOrigins2, bestOrigins2Prime; - std::vector> bestAlphasInit; - std::vector>> bestRecs, bestRecs2; - double bestTime = 0; - - auto printLoss - = [this](float loss, float recLoss, float metricLoss, float clustLoss, - float trackLoss, int iterationT, int iterationTT, double time, - const debug::Priority &priority = debug::Priority::INFO) { - std::stringstream prefix; - prefix << (priority == debug::Priority::VERBOSE ? "Iter " : "Best "); - std::stringstream ssBestLoss; - ssBestLoss << prefix.str() << "loss is " << loss << " (iteration " - << iterationT << " / " << iterationTT << ") at time " - << time; - printMsg(ssBestLoss.str(), priority); - if(priority != debug::Priority::VERBOSE) - prefix.str(""); - if(metricLossWeight_ != 0 or clusteringLossWeight_ != 0 - or trackingLossWeight_ != 0) { - ssBestLoss.str(""); - ssBestLoss << "- Rec. " << prefix.str() << "loss = " << recLoss; - printMsg(ssBestLoss.str(), priority); - } - if(metricLossWeight_ != 0) { - ssBestLoss.str(""); - ssBestLoss << "- Metric " << prefix.str() << "loss = " << metricLoss; - printMsg(ssBestLoss.str(), priority); - } - if(clusteringLossWeight_ != 0) { - ssBestLoss.str(""); - ssBestLoss << "- Clust. " << prefix.str() << "loss = " << clustLoss; - printMsg(ssBestLoss.str(), priority); - } - if(trackingLossWeight_ != 0) { - ssBestLoss.str(""); - ssBestLoss << "- Track. " << prefix.str() << "loss = " << trackLoss; - printMsg(ssBestLoss.str(), priority); - } - }; - - // ----- Algorithm - Timer t_alg; - bool converged = false; - while(not converged) { - if(iteration % iterationGap_ == 0) { - std::stringstream ss; - ss << "Iteration " << iteration; - printMsg(debug::Separator::L2); - printMsg(ss.str()); - } - - bool forwardReset = false; - std::vector iterationLosses, iterationMetricLosses, - iterationClusteringLosses, iterationTrackingLosses; - if(noBatch != 1) { - std::vector indexes(trees.size()); - std::iota(indexes.begin(), indexes.end(), 0); - std::shuffle(std::begin(indexes), std::end(indexes), rng); - for(unsigned int i = 0; i < allIndexes.size(); ++i) { - unsigned int noProcessed = batchSize * i; - unsigned int remaining = trees.size() - noProcessed; - unsigned int size = std::min(batchSize, remaining); - allIndexes[i].resize(size); - for(unsigned int j = 0; j < size; ++j) - allIndexes[i][j] = indexes[noProcessed + j]; - } - } - for(unsigned batchNum = 0; batchNum < allIndexes.size(); ++batchNum) { - auto &indexes = allIndexes[batchNum]; - - // --- Assignment - Timer t_assignment; - std::vector> outs, outs2; - std::vector> bestAlphas; - std::vector>> layersOuts, - layersOuts2; - std::vector>> - matchings, matchings2; - float loss; - bool computeReconstructionError = reconstructionLossWeight_ != 0; - forwardReset - = forwardStep(torchTrees, torchTrees2, indexes, k, allAlphas_, - computeReconstructionError, outs, outs2, bestAlphas, - layersOuts, layersOuts2, matchings, matchings2, loss); - if(forwardReset) - break; - for(unsigned int ind = 0; ind < indexes.size(); ++ind) { - unsigned int i = indexes[ind]; - for(unsigned int j = 0; j < bestAlphas[i].size(); ++j) - mtu::copyTensor(bestAlphas[i][j], allAlphas_[i][j]); - } - assignmentTime += t_assignment.getElapsedTime(); - - // --- Loss - Timer t_loss; - losses.emplace_back(loss); - iterationLosses.emplace_back(loss); - // - Metric Loss - torch::Tensor metricLoss; - if(metricLossWeight_ != 0) { - computeMetricLoss(layersOuts, layersOuts2, bestAlphas, distanceMatrix_, - indexes, metricLoss); - float metricLossF = metricLoss.item(); - metricLosses.emplace_back(metricLossF); - iterationMetricLosses.emplace_back(metricLossF); - } - // - Clustering Loss - torch::Tensor clusteringLoss; - if(clusteringLossWeight_ != 0) { - torch::Tensor asgn; - computeClusteringLoss(bestAlphas, indexes, clusteringLoss, asgn); - float clusteringLossF = clusteringLoss.item(); - clusteringLosses.emplace_back(clusteringLossF); - iterationClusteringLosses.emplace_back(clusteringLossF); - } - // - Tracking Loss - torch::Tensor trackingLoss; - if(trackingLossWeight_ != 0) { - computeTrackingLoss(trackingLoss); - float trackingLossF = trackingLoss.item(); - trackingLosses.emplace_back(trackingLossF); - iterationTrackingLosses.emplace_back(trackingLossF); - } - lossTime += t_loss.getElapsedTime(); - - // --- Update - Timer t_update; - backwardStep(torchTrees, outs, matchings, torchTrees2, outs2, matchings2, - *optimizer, indexes, metricLoss, clusteringLoss, - trackingLoss); - updateTime += t_update.getElapsedTime(); - - // --- Projection - Timer t_projection; - projectionStep(); - projectionTime += t_projection.getElapsedTime(); - } - - if(forwardReset) { - // TODO better manage reset by init new parameters and start again for - // example (should not happen anymore) - printWrn("Forward reset!"); - break; - } - - // --- Get iteration loss - // TODO an approximation is made here if batch size != 1 because the - // iteration loss will not be exact, we need to do a forward step and - // compute loss with the whole dataset - /*if(batchSize_ != 1) - printWrn("iteration loss approximation (batchSize_ != 1)");*/ - float iterationRecLoss - = torch::tensor(iterationLosses).mean().item(); - float iterationLoss = reconstructionLossWeight_ * iterationRecLoss; - float iterationMetricLoss = 0; - if(metricLossWeight_ != 0) { - iterationMetricLoss - = torch::tensor(iterationMetricLosses).mean().item(); - iterationLoss - += metricLossWeight_ - * getCustomLossDynamicWeight(iterationRecLoss, baseRecLoss_) - * iterationMetricLoss; - } - float iterationClusteringLoss = 0; - if(clusteringLossWeight_ != 0) { - iterationClusteringLoss - = torch::tensor(iterationClusteringLosses).mean().item(); - iterationLoss - += clusteringLossWeight_ - * getCustomLossDynamicWeight(iterationRecLoss, baseRecLoss_) - * iterationClusteringLoss; - } - float iterationTrackingLoss = 0; - if(trackingLossWeight_ != 0) { - iterationTrackingLoss - = torch::tensor(iterationTrackingLosses).mean().item(); - iterationLoss += trackingLossWeight_ * iterationTrackingLoss; - } - printLoss(iterationLoss, iterationRecLoss, iterationMetricLoss, - iterationClusteringLoss, iterationTrackingLoss, iteration, - iteration, t_alg.getElapsedTime() - t_allVectorCopy_time_, - debug::Priority::VERBOSE); - - // --- Update best parameters - bool isBest = isBestLoss(iterationLoss, minLoss, cptBlocked); - if(isBest) { - Timer t_copy; - bestIteration = iteration; - copyParams(origins_, originsPrime_, vSTensor_, vSPrimeTensor_, origins2_, - origins2Prime_, vS2Tensor_, vS2PrimeTensor_, allAlphas_, - bestOrigins, bestOriginsPrime, bestVSTensor, bestVSPrimeTensor, - bestOrigins2, bestOrigins2Prime, bestVS2Tensor, - bestVS2PrimeTensor, bestAlphasInit); - copyParams(recs_, bestRecs); - copyParams(recs2_, bestRecs2); - t_allVectorCopy_time_ += t_copy.getElapsedTime(); - bestTime = t_alg.getElapsedTime() - t_allVectorCopy_time_; - minRecLoss = iterationRecLoss; - minMetricLoss = iterationMetricLoss; - minClustLoss = iterationClusteringLoss; - minTrackLoss = iterationTrackingLoss; - printLoss(minLoss, minRecLoss, minMetricLoss, minClustLoss, minTrackLoss, - bestIteration, iteration, bestTime, debug::Priority::DETAIL); - } - - // --- Convergence - windowLoss += iterationLoss; - if((iteration + 1) % convWinSize == 0) { - windowLoss /= convWinSize; - converged = convergenceStep(windowLoss, oldLoss, minLoss, cptBlocked); - windowLoss = 0; - if(converged) { - ++noConverged; - } else - noConverged = 0; - converged = noConverged >= noConvergedToGet; - if(converged and iteration < minIteration_) - printMsg("convergence is detected but iteration < minIteration_", - debug::Priority::DETAIL); - if(iteration < minIteration_) - converged = false; - if(converged) - break; - } - - // --- Print - if(iteration % iterationGap_ == 0) { - printMsg("Assignment", 1, assignmentTime, threadNumber_); - printMsg("Loss", 1, lossTime, threadNumber_); - printMsg("Update", 1, updateTime, threadNumber_); - printMsg("Projection", 1, projectionTime, threadNumber_); - assignmentTime = 0.0; - lossTime = 0.0; - updateTime = 0.0; - projectionTime = 0.0; - std::stringstream ss; - float loss = torch::tensor(losses).mean().item(); - losses.clear(); - ss << "Rec. loss = " << loss; - printMsg(ss.str()); - if(metricLossWeight_ != 0) { - float metricLoss = torch::tensor(metricLosses).mean().item(); - metricLosses.clear(); - ss.str(""); - ss << "Metric loss = " << metricLoss; - printMsg(ss.str()); - } - if(clusteringLossWeight_ != 0) { - float clusteringLoss - = torch::tensor(clusteringLosses).mean().item(); - clusteringLosses.clear(); - ss.str(""); - ss << "Clust. loss = " << clusteringLoss; - printMsg(ss.str()); - } - if(trackingLossWeight_ != 0) { - float trackingLoss = torch::tensor(trackingLosses).mean().item(); - trackingLosses.clear(); - ss.str(""); - ss << "Track. loss = " << trackingLoss; - printMsg(ss.str()); - } - - // Verify grad and big values (testing) - for(unsigned int l = 0; l < noLayers_; ++l) { - ss.str(""); - if(originsNoZeroGrad_[l] != 0) - ss << originsNoZeroGrad_[l] << " originsNoZeroGrad_[" << l << "]" - << std::endl; - if(originsPrimeNoZeroGrad_[l] != 0) - ss << originsPrimeNoZeroGrad_[l] << " originsPrimeNoZeroGrad_[" << l - << "]" << std::endl; - if(vSNoZeroGrad_[l] != 0) - ss << vSNoZeroGrad_[l] << " vSNoZeroGrad_[" << l << "]" << std::endl; - if(vSPrimeNoZeroGrad_[l] != 0) - ss << vSPrimeNoZeroGrad_[l] << " vSPrimeNoZeroGrad_[" << l << "]" - << std::endl; - originsNoZeroGrad_[l] = 0; - originsPrimeNoZeroGrad_[l] = 0; - vSNoZeroGrad_[l] = 0; - vSPrimeNoZeroGrad_[l] = 0; - if(useDoubleInput_) { - if(origins2NoZeroGrad_[l] != 0) - ss << origins2NoZeroGrad_[l] << " origins2NoZeroGrad_[" << l << "]" - << std::endl; - if(origins2PrimeNoZeroGrad_[l] != 0) - ss << origins2PrimeNoZeroGrad_[l] << " origins2PrimeNoZeroGrad_[" - << l << "]" << std::endl; - if(vS2NoZeroGrad_[l] != 0) - ss << vS2NoZeroGrad_[l] << " vS2NoZeroGrad_[" << l << "]" - << std::endl; - if(vS2PrimeNoZeroGrad_[l] != 0) - ss << vS2PrimeNoZeroGrad_[l] << " vS2PrimeNoZeroGrad_[" << l << "]" - << std::endl; - origins2NoZeroGrad_[l] = 0; - origins2PrimeNoZeroGrad_[l] = 0; - vS2NoZeroGrad_[l] = 0; - vS2PrimeNoZeroGrad_[l] = 0; - } - if(isTreeHasBigValues(origins_[l].mTree, bigValuesThreshold_)) - ss << "origins_[" << l << "] has big values!" << std::endl; - if(isTreeHasBigValues(originsPrime_[l].mTree, bigValuesThreshold_)) - ss << "originsPrime_[" << l << "] has big values!" << std::endl; - if(ss.rdbuf()->in_avail() != 0) - printMsg(ss.str(), debug::Priority::DETAIL); - } - } - - ++iteration; - if(maxIteration_ != 0 and iteration >= maxIteration_) { - printMsg("iteration >= maxIteration_", debug::Priority::DETAIL); - break; - } + if(clusteringLossWeight_ != 0) { + ssBestLoss.str(""); + ssBestLoss << "- Clust. " << prefix.str() << "loss = " << customLoss[2]; + printMsg(ssBestLoss.str(), priority); + } + if(trackingLossWeight_ != 0) { + ssBestLoss.str(""); + ssBestLoss << "- Track. " << prefix.str() << "loss = " << customLoss[3]; + printMsg(ssBestLoss.str(), priority); } - printMsg(debug::Separator::L2); - printLoss(minLoss, minRecLoss, minMetricLoss, minClustLoss, minTrackLoss, - bestIteration, iteration, bestTime); - printMsg(debug::Separator::L2); - bestLoss_ = minLoss; +} - Timer t_copy; - copyParams(bestOrigins, bestOriginsPrime, bestVSTensor, bestVSPrimeTensor, - bestOrigins2, bestOrigins2Prime, bestVS2Tensor, bestVS2PrimeTensor, - bestAlphasInit, origins_, originsPrime_, vSTensor_, vSPrimeTensor_, - origins2_, origins2Prime_, vS2Tensor_, vS2PrimeTensor_, - allAlphas_); - copyParams(bestRecs, recs_); - copyParams(bestRecs2, recs2_); - t_allVectorCopy_time_ += t_copy.getElapsedTime(); - printMsg("Copy time", 1, t_allVectorCopy_time_, threadNumber_); +void ttk::MergeTreeAutoencoder::printGapLoss( + float loss, std::vector> &gapCustomLosses) { + std::stringstream ss; + ss << "Rec. loss = " << loss; + printMsg(ss.str()); + if(metricLossWeight_ != 0) { + float metricLoss = torch::tensor(gapCustomLosses[0]).mean().item(); + gapCustomLosses[0].clear(); + ss.str(""); + ss << "Metric loss = " << metricLoss; + printMsg(ss.str()); + } + if(clusteringLossWeight_ != 0) { + float clusteringLoss + = torch::tensor(gapCustomLosses[1]).mean().item(); + gapCustomLosses[1].clear(); + ss.str(""); + ss << "Clust. loss = " << clusteringLoss; + printMsg(ss.str()); + } + if(trackingLossWeight_ != 0) { + float trackingLoss = torch::tensor(gapCustomLosses[2]).mean().item(); + gapCustomLosses[2].clear(); + ss.str(""); + ss << "Track. loss = " << trackingLoss; + printMsg(ss.str()); + } } // --------------------------------------------------------------------------- @@ -2044,165 +553,6 @@ double ttk::MergeTreeAutoencoder::getCustomLossDynamicWeight(double recLoss, return 1.0; } -void ttk::MergeTreeAutoencoder::getDistanceMatrix( - std::vector> &tmts, - std::vector> &distanceMatrix, - bool useDoubleInput, - bool isFirstInput) { - distanceMatrix.clear(); - distanceMatrix.resize(tmts.size(), std::vector(tmts.size(), 0)); -#ifdef TTK_ENABLE_OPENMP -#pragma omp parallel num_threads(this->threadNumber_) if(parallelize_) \ - shared(distanceMatrix, tmts) - { -#pragma omp single nowait - { -#endif - for(unsigned int i = 0; i < tmts.size(); ++i) { - for(unsigned int j = i + 1; j < tmts.size(); ++j) { -#ifdef TTK_ENABLE_OPENMP -#pragma omp task UNTIED() shared(distanceMatrix, tmts) firstprivate(i, j) - { -#endif - std::vector> matching; - float distance; - bool isCalled = true; - computeOneDistance(tmts[i].mTree, tmts[j].mTree, matching, distance, - isCalled, useDoubleInput, isFirstInput); - distance = distance * distance; - distanceMatrix[i][j] = distance; - distanceMatrix[j][i] = distance; -#ifdef TTK_ENABLE_OPENMP - } // pragma omp task -#endif - } - } -#ifdef TTK_ENABLE_OPENMP -#pragma omp taskwait - } // pragma omp single nowait - } // pragma omp parallel -#endif -} - -void ttk::MergeTreeAutoencoder::getDistanceMatrix( - std::vector> &tmts, - std::vector> &tmts2, - std::vector> &distanceMatrix) { - getDistanceMatrix(tmts, distanceMatrix, useDoubleInput_); - if(useDoubleInput_) { - std::vector> distanceMatrix2; - getDistanceMatrix(tmts2, distanceMatrix2, useDoubleInput_, false); - mixDistancesMatrix(distanceMatrix, distanceMatrix2); - } -} - -void ttk::MergeTreeAutoencoder::getDifferentiableDistanceFromMatchings( - mtu::TorchMergeTree &tree1, - mtu::TorchMergeTree &tree2, - mtu::TorchMergeTree &tree1_2, - mtu::TorchMergeTree &tree2_2, - std::vector> &matchings, - std::vector> &matchings2, - torch::Tensor &tensorDist, - bool doSqrt) { - torch::Tensor reorderedITensor, reorderedJTensor; - dataReorderingGivenMatching( - tree1, tree2, matchings, reorderedITensor, reorderedJTensor); - if(useDoubleInput_) { - torch::Tensor reorderedI2Tensor, reorderedJ2Tensor; - dataReorderingGivenMatching( - tree1_2, tree2_2, matchings2, reorderedI2Tensor, reorderedJ2Tensor); - reorderedITensor = torch::cat({reorderedITensor, reorderedI2Tensor}); - reorderedJTensor = torch::cat({reorderedJTensor, reorderedJ2Tensor}); - } - tensorDist = (reorderedITensor - reorderedJTensor).pow(2).sum(); - if(doSqrt) - tensorDist = tensorDist.sqrt(); -} - -void ttk::MergeTreeAutoencoder::getDifferentiableDistance( - mtu::TorchMergeTree &tree1, - mtu::TorchMergeTree &tree2, - mtu::TorchMergeTree &tree1_2, - mtu::TorchMergeTree &tree2_2, - torch::Tensor &tensorDist, - bool isCalled, - bool doSqrt) { - std::vector> matchings, - matchings2; - float distance; - computeOneDistance( - tree1.mTree, tree2.mTree, matchings, distance, isCalled, useDoubleInput_); - if(useDoubleInput_) { - float distance2; - computeOneDistance(tree1_2.mTree, tree2_2.mTree, matchings2, - distance2, isCalled, useDoubleInput_, false); - } - getDifferentiableDistanceFromMatchings( - tree1, tree2, tree1_2, tree2_2, matchings, matchings2, tensorDist, doSqrt); -} - -void ttk::MergeTreeAutoencoder::getDifferentiableDistance( - mtu::TorchMergeTree &tree1, - mtu::TorchMergeTree &tree2, - torch::Tensor &tensorDist, - bool isCalled, - bool doSqrt) { - mtu::TorchMergeTree tree1_2, tree2_2; - getDifferentiableDistance( - tree1, tree2, tree1_2, tree2_2, tensorDist, isCalled, doSqrt); -} - -void ttk::MergeTreeAutoencoder::getDifferentiableDistanceMatrix( - std::vector *> &trees, - std::vector *> &trees2, - std::vector> &outDistMat) { - outDistMat.resize(trees.size(), std::vector(trees.size())); -#ifdef TTK_ENABLE_OPENMP -#pragma omp parallel num_threads(this->threadNumber_) if(parallelize_) \ - shared(trees, trees2, outDistMat) - { -#pragma omp single nowait - { -#endif - for(unsigned int i = 0; i < trees.size(); ++i) { - outDistMat[i][i] = torch::tensor(0); - for(unsigned int j = i + 1; j < trees.size(); ++j) { -#ifdef TTK_ENABLE_OPENMP -#pragma omp task UNTIED() shared(trees, trees2, outDistMat) firstprivate(i, j) - { -#endif - bool isCalled = true; - bool doSqrt = false; - torch::Tensor tensorDist; - getDifferentiableDistance(*(trees[i]), *(trees[j]), *(trees2[i]), - *(trees2[j]), tensorDist, isCalled, - doSqrt); - outDistMat[i][j] = tensorDist; - outDistMat[j][i] = tensorDist; -#ifdef TTK_ENABLE_OPENMP - } // pragma omp task -#endif - } - } -#ifdef TTK_ENABLE_OPENMP -#pragma omp taskwait - } // pragma omp single nowait - } // pragma omp parallel -#endif -} - -void ttk::MergeTreeAutoencoder::getAlphasTensor( - std::vector> &alphas, - std::vector &indexes, - unsigned int layerIndex, - torch::Tensor &alphasOut) { - alphasOut = alphas[indexes[0]][layerIndex].transpose(0, 1); - for(unsigned int ind = 1; ind < indexes.size(); ++ind) - alphasOut = torch::cat( - {alphasOut, alphas[indexes[ind]][layerIndex].transpose(0, 1)}); -} - void ttk::MergeTreeAutoencoder::computeMetricLoss( std::vector>> &layersOuts, std::vector>> &layersOuts2, @@ -2228,7 +578,7 @@ void ttk::MergeTreeAutoencoder::computeMetricLoss( getDifferentiableDistanceMatrix(trees, trees2, outDistMat); } else { std::vector> scaledAlphas; - createScaledAlphas(alphas, vSTensor_, scaledAlphas); + createScaledAlphas(alphas, scaledAlphas); torch::Tensor latentAlphas; getAlphasTensor(scaledAlphas, indexes, layerIndex, latentAlphas); if(customLossActivate_) @@ -2300,8 +650,10 @@ void ttk::MergeTreeAutoencoder::computeTrackingLoss( num_threads(this->threadNumber_) if(parallelize_) #endif for(unsigned int l = 0; l < endLayer; ++l) { - auto &tree1 = (l == 0 ? origins_[0] : originsPrime_[l - 1]); - auto &tree2 = (l == 0 ? originsPrime_[0] : originsPrime_[l]); + auto &tree1 + = (l == 0 ? layers_[0].getOrigin() : layers_[l - 1].getOriginPrime()); + auto &tree2 + = (l == 0 ? layers_[0].getOriginPrime() : layers_[l].getOriginPrime()); torch::Tensor tensorDist; bool isCalled = true, doSqrt = false; getDifferentiableDistance(tree1, tree2, tensorDist, isCalled, doSqrt); @@ -2315,9 +667,7 @@ void ttk::MergeTreeAutoencoder::computeTrackingLoss( // --------------------------------------------------------------------------- // --- End Functions // --------------------------------------------------------------------------- -void ttk::MergeTreeAutoencoder::createCustomRecs( - std::vector> &origins, - std::vector> &originsPrime) { +void ttk::MergeTreeAutoencoder::createCustomRecs() { if(customAlphas_.empty()) return; @@ -2358,10 +708,7 @@ void ttk::MergeTreeAutoencoder::createCustomRecs( outs.resize(noOuts); outs2.resize(noOuts); mtu::TorchMergeTree out, out2; - outputBasisReconstruction(originsPrime[latLayer], vSPrimeTensor_[latLayer], - origins2Prime_[latLayer], - vS2PrimeTensor_[latLayer], alphas, outs[0], - outs2[0]); + layers_[latLayer].outputBasisReconstruction(alphas, outs[0], outs2[0]); // Decoding unsigned int k = 32; for(unsigned int l = latLayer + 1; l < noLayers_; ++l) { @@ -2369,7 +716,8 @@ void ttk::MergeTreeAutoencoder::createCustomRecs( std::vector allAlphasInit(noIter); torch::Tensor maxNorm; for(unsigned int j = 0; j < allAlphasInit.size(); ++j) { - allAlphasInit[j] = torch::randn({vSTensor_[l].sizes()[1], 1}); + allAlphasInit[j] + = torch::randn({layers_[l].getVSTensor().sizes()[1], 1}); auto norm = torch::linalg::vector_norm( allAlphasInit[j], 2, 0, false, c10::nullopt); if(j == 0 or maxNorm.item() < norm.item()) @@ -2389,11 +737,9 @@ void ttk::MergeTreeAutoencoder::createCustomRecs( alphasInit = allAlphasInit[j]; } float distance; - forwardOneLayer(outs[outIndex - 1], origins[l], vSTensor_[l], - originsPrime[l], vSPrimeTensor_[l], outs2[outIndex - 1], - origins2_[l], vS2Tensor_[l], origins2Prime_[l], - vS2PrimeTensor_[l], k, alphasInit, outToUse, - outs2[outIndex], dataAlphas, distance); + layers_[l].forward(outs[outIndex - 1], outs2[outIndex - 1], k, + alphasInit, outToUse, outs2[outIndex], dataAlphas, + distance); if(distance < bestDistance) { bestDistance = distance; mtu::copyTorchMergeTree( @@ -2411,16 +757,16 @@ void ttk::MergeTreeAutoencoder::createCustomRecs( for(unsigned int i = 0; i < customRecs_.size(); ++i) { bool isCalled = true; float distance; - computeOneDistance(origins[0].mTree, customRecs_[i].mTree, - customMatchings_[i], distance, isCalled, - useDoubleInput_); + computeOneDistance(layers_[0].getOrigin().mTree, + customRecs_[i].mTree, customMatchings_[i], + distance, isCalled, useDoubleInput_); } mtu::TorchMergeTree originCopy; - mtu::copyTorchMergeTree(origins[0], originCopy); + mtu::copyTorchMergeTree(layers_[0].getOrigin(), originCopy); postprocessingPipeline(&(originCopy.mTree.tree)); for(unsigned int i = 0; i < customRecs_.size(); ++i) { - wae::fixTreePrecisionScalars(customRecs_[i].mTree); + fixTreePrecisionScalars(customRecs_[i].mTree); postprocessingPipeline(&(customRecs_[i].mTree.tree)); if(not isPersistenceDiagram_) { convertBranchDecompositionMatching(&(originCopy.mTree.tree), @@ -2430,144 +776,9 @@ void ttk::MergeTreeAutoencoder::createCustomRecs( } } -void ttk::MergeTreeAutoencoder::computeTrackingInformation() { - unsigned int latentLayerIndex = getLatentLayerIndex() + 1; - originsMatchings_.resize(latentLayerIndex); -#ifdef TTK_ENABLE_OPENMP -#pragma omp parallel for schedule(dynamic) \ - num_threads(this->threadNumber_) if(parallelize_) -#endif - for(unsigned int l = 0; l < latentLayerIndex; ++l) { - auto &tree1 = (l == 0 ? origins_[0] : originsPrime_[l - 1]); - auto &tree2 = (l == 0 ? originsPrime_[0] : originsPrime_[l]); - bool isCalled = true; - float distance; - computeOneDistance(tree1.mTree, tree2.mTree, originsMatchings_[l], - distance, isCalled, useDoubleInput_); - } - - // Data matchings - ++latentLayerIndex; - dataMatchings_.resize(latentLayerIndex); - for(unsigned int l = 0; l < latentLayerIndex; ++l) { - dataMatchings_[l].resize(recs_.size()); -#ifdef TTK_ENABLE_OPENMP -#pragma omp parallel for schedule(dynamic) \ - num_threads(this->threadNumber_) if(parallelize_) -#endif - for(unsigned int i = 0; i < recs_.size(); ++i) { - bool isCalled = true; - float distance; - auto &origin = (l == 0 ? origins_[0] : originsPrime_[l - 1]); - computeOneDistance(origin.mTree, recs_[i][l].mTree, - dataMatchings_[l][i], distance, isCalled, - useDoubleInput_); - } - } - - // Reconst matchings - reconstMatchings_.resize(recs_.size()); -#ifdef TTK_ENABLE_OPENMP -#pragma omp parallel for schedule(dynamic) \ - num_threads(this->threadNumber_) if(parallelize_) -#endif - for(unsigned int i = 0; i < recs_.size(); ++i) { - bool isCalled = true; - float distance; - auto l = recs_[i].size() - 1; - computeOneDistance(recs_[i][0].mTree, recs_[i][l].mTree, - reconstMatchings_[i], distance, isCalled, - useDoubleInput_); - } -} - -void ttk::MergeTreeAutoencoder::createScaledAlphas( - std::vector> &alphas, - std::vector &vSTensor, - std::vector> &scaledAlphas) { - scaledAlphas.clear(); - scaledAlphas.resize( - alphas.size(), std::vector(alphas[0].size())); - for(unsigned int l = 0; l < alphas[0].size(); ++l) { - torch::Tensor scale = vSTensor[l].pow(2).sum(0).sqrt(); - for(unsigned int i = 0; i < alphas.size(); ++i) { - scaledAlphas[i][l] = alphas[i][l] * scale.reshape({-1, 1}); - } - } -} - -void ttk::MergeTreeAutoencoder::createScaledAlphas() { - createScaledAlphas(allAlphas_, vSTensor_, allScaledAlphas_); -} - -void ttk::MergeTreeAutoencoder::createActivatedAlphas() { - allActAlphas_ = allAlphas_; - for(unsigned int i = 0; i < allActAlphas_.size(); ++i) - for(unsigned int j = 0; j < allActAlphas_[i].size(); ++j) - allActAlphas_[i][j] = activation(allActAlphas_[i][j]); - createScaledAlphas(allActAlphas_, vSTensor_, allActScaledAlphas_); -} - // --------------------------------------------------------------------------- // --- Utils // --------------------------------------------------------------------------- -void ttk::MergeTreeAutoencoder::copyParams( - std::vector> &srcOrigins, - std::vector> &srcOriginsPrime, - std::vector &srcVS, - std::vector &srcVSPrime, - std::vector> &srcOrigins2, - std::vector> &srcOrigins2Prime, - std::vector &srcVS2, - std::vector &srcVS2Prime, - std::vector> &srcAlphas, - std::vector> &dstOrigins, - std::vector> &dstOriginsPrime, - std::vector &dstVS, - std::vector &dstVSPrime, - std::vector> &dstOrigins2, - std::vector> &dstOrigins2Prime, - std::vector &dstVS2, - std::vector &dstVS2Prime, - std::vector> &dstAlphas) { - dstOrigins.resize(noLayers_); - dstOriginsPrime.resize(noLayers_); - dstVS.resize(noLayers_); - dstVSPrime.resize(noLayers_); - dstAlphas.resize(srcAlphas.size(), std::vector(noLayers_)); - if(useDoubleInput_) { - dstOrigins2.resize(noLayers_); - dstOrigins2Prime.resize(noLayers_); - dstVS2.resize(noLayers_); - dstVS2Prime.resize(noLayers_); - } - for(unsigned int l = 0; l < noLayers_; ++l) { - mtu::copyTorchMergeTree(srcOrigins[l], dstOrigins[l]); - mtu::copyTorchMergeTree(srcOriginsPrime[l], dstOriginsPrime[l]); - mtu::copyTensor(srcVS[l], dstVS[l]); - mtu::copyTensor(srcVSPrime[l], dstVSPrime[l]); - if(useDoubleInput_) { - mtu::copyTorchMergeTree(srcOrigins2[l], dstOrigins2[l]); - mtu::copyTorchMergeTree(srcOrigins2Prime[l], dstOrigins2Prime[l]); - mtu::copyTensor(srcVS2[l], dstVS2[l]); - mtu::copyTensor(srcVS2Prime[l], dstVS2Prime[l]); - } - for(unsigned int i = 0; i < srcAlphas.size(); ++i) - mtu::copyTensor(srcAlphas[i][l], dstAlphas[i][l]); - } -} - -void ttk::MergeTreeAutoencoder::copyParams( - std::vector>> &src, - std::vector>> &dst) { - dst.resize(src.size()); - for(unsigned int i = 0; i < src.size(); ++i) { - dst[i].resize(src[i].size()); - for(unsigned int j = 0; j < src[i].size(); ++j) - mtu::copyTorchMergeTree(src[i][j], dst[i][j]); - } -} - unsigned int ttk::MergeTreeAutoencoder::getLatentLayerIndex() { unsigned int idx = noLayers_ / 2 - 1; if(idx > noLayers_) // unsigned negativeness @@ -2575,128 +786,25 @@ unsigned int ttk::MergeTreeAutoencoder::getLatentLayerIndex() { return idx; } -// --------------------------------------------------------------------------- -// --- Testing -// --------------------------------------------------------------------------- -bool ttk::MergeTreeAutoencoder::isTreeHasBigValues(ftm::MergeTree &mTree, - float threshold) { - bool found = false; - for(unsigned int n = 0; n < mTree.tree.getNumberOfNodes(); ++n) { - if(mTree.tree.isNodeAlone(n)) - continue; - auto birthDeath = mTree.tree.template getBirthDeath(n); - if(std::abs(std::get<0>(birthDeath)) > threshold - or std::abs(std::get<1>(birthDeath)) > threshold) { - found = true; - break; - } - } - return found; +void ttk::MergeTreeAutoencoder::copyCustomParams(bool get) { + auto &srcLatentCentroids = (get ? latentCentroids_ : bestLatentCentroids_); + auto &dstLatentCentroids = (!get ? latentCentroids_ : bestLatentCentroids_); + dstLatentCentroids.resize(srcLatentCentroids.size()); + for(unsigned int i = 0; i < dstLatentCentroids.size(); ++i) + mtu::copyTensor(srcLatentCentroids[i], dstLatentCentroids[i]); } #endif // --------------------------------------------------------------------------- // --- Main Functions // --------------------------------------------------------------------------- - -void ttk::MergeTreeAutoencoder::execute( +void ttk::MergeTreeAutoencoder::executeEndFunction( std::vector> &trees, - std::vector> &trees2) { -#ifndef TTK_ENABLE_TORCH - TTK_FORCE_USE(trees); - TTK_FORCE_USE(trees2); - printErr("This module requires Torch."); -#else -#ifdef TTK_ENABLE_OPENMP - int ompNested = omp_get_nested(); - omp_set_nested(1); -#endif - // --- Preprocessing - Timer t_preprocess; - preprocessingTrees(trees, treesNodeCorr_); - if(trees2.size() != 0) - preprocessingTrees(trees2, trees2NodeCorr_); - printMsg("Preprocessing", 1, t_preprocess.getElapsedTime(), threadNumber_); - useDoubleInput_ = (trees2.size() != 0); - - // --- Fit autoencoder - Timer t_total; - fit(trees, trees2); - auto totalTime = t_total.getElapsedTime() - t_allVectorCopy_time_; - printMsg(debug::Separator::L1); - printMsg("Total time", 1, totalTime, threadNumber_); - hasComputedOnce_ = true; - - // --- End functions - createScaledAlphas(); - createActivatedAlphas(); - computeTrackingInformation(); + std::vector> &ttkNotUsed(trees2)) { + // Tracking + computeTrackingInformation(getLatentLayerIndex() + 1); // Correlation - auto latLayer = getLatentLayerIndex(); - std::vector> allTs; - auto noGeod = allAlphas_[0][latLayer].sizes()[0]; - allTs.resize(noGeod); - for(unsigned int i = 0; i < noGeod; ++i) { - allTs[i].resize(allAlphas_.size()); - for(unsigned int j = 0; j < allAlphas_.size(); ++j) - allTs[i][j] = allAlphas_[j][latLayer][i].item(); - } - computeBranchesCorrelationMatrix(origins_[0].mTree, trees, dataMatchings_[0], - allTs, branchesCorrelationMatrix_, - persCorrelationMatrix_); + computeCorrelationMatrix(trees, getLatentLayerIndex()); // Custom recs - originsCopy_.resize(origins_.size()); - originsPrimeCopy_.resize(originsPrime_.size()); - for(unsigned int l = 0; l < origins_.size(); ++l) { - mtu::copyTorchMergeTree(origins_[l], originsCopy_[l]); - mtu::copyTorchMergeTree(originsPrime_[l], originsPrimeCopy_[l]); - } - createCustomRecs(originsCopy_, originsPrimeCopy_); - - // --- Postprocessing - if(createOutput_) { - for(unsigned int i = 0; i < trees.size(); ++i) - postprocessingPipeline(&(trees[i].tree)); - for(unsigned int i = 0; i < trees2.size(); ++i) - postprocessingPipeline(&(trees2[i].tree)); - for(unsigned int l = 0; l < origins_.size(); ++l) { - fillMergeTreeStructure(origins_[l]); - postprocessingPipeline(&(origins_[l].mTree.tree)); - fillMergeTreeStructure(originsPrime_[l]); - postprocessingPipeline(&(originsPrime_[l].mTree.tree)); - } - for(unsigned int j = 0; j < recs_[0].size(); ++j) { - for(unsigned int i = 0; i < recs_.size(); ++i) { - wae::fixTreePrecisionScalars(recs_[i][j].mTree); - postprocessingPipeline(&(recs_[i][j].mTree.tree)); - } - } - } - - if(not isPersistenceDiagram_) { - for(unsigned int l = 0; l < originsMatchings_.size(); ++l) { - auto &tree1 = (l == 0 ? origins_[0] : originsPrime_[l - 1]); - auto &tree2 = (l == 0 ? originsPrime_[0] : originsPrime_[l]); - convertBranchDecompositionMatching( - &(tree1.mTree.tree), &(tree2.mTree.tree), originsMatchings_[l]); - } - for(unsigned int l = 0; l < dataMatchings_.size(); ++l) { - for(unsigned int i = 0; i < recs_.size(); ++i) { - auto &origin = (l == 0 ? origins_[0] : originsPrime_[l - 1]); - convertBranchDecompositionMatching(&(origin.mTree.tree), - &(recs_[i][l].mTree.tree), - dataMatchings_[l][i]); - } - } - for(unsigned int i = 0; i < reconstMatchings_.size(); ++i) { - auto l = recs_[i].size() - 1; - convertBranchDecompositionMatching(&(recs_[i][0].mTree.tree), - &(recs_[i][l].mTree.tree), - reconstMatchings_[i]); - } - } -#ifdef TTK_ENABLE_OPENMP - omp_set_nested(ompNested); -#endif -#endif + createCustomRecs(); } diff --git a/core/base/mergeTreeAutoencoder/MergeTreeAutoencoder.h b/core/base/mergeTreeAutoencoder/MergeTreeAutoencoder.h index 4c74c92cde..9beb8d95cb 100644 --- a/core/base/mergeTreeAutoencoder/MergeTreeAutoencoder.h +++ b/core/base/mergeTreeAutoencoder/MergeTreeAutoencoder.h @@ -24,7 +24,8 @@ // ttk common includes #include #include -#include +#include +#include #include #ifdef TTK_ENABLE_TORCH @@ -38,26 +39,15 @@ namespace ttk { * of merge trees or persistence diagrams. */ class MergeTreeAutoencoder : virtual public Debug, - public MergeTreeAxesAlgorithmBase { + public MergeTreeNeuralNetwork { protected: - bool doCompute_; - bool hasComputedOnce_ = false; - // Model hyper-parameters; int encoderNoLayers_ = 1; bool scaleLayerAfterLatent_ = false; unsigned int inputNumberOfAxes_ = 16; double inputOriginPrimeSizePercent_ = 15; double latentSpaceOriginPrimeSizePercent_ = 10; - unsigned int minIteration_ = 0; - unsigned int maxIteration_ = 0; - unsigned int iterationGap_ = 100; - double batchSize_ = 1; - int optimizer_ = 0; - double gradientStepSize_ = 0.1; - double beta1_ = 0.9; - double beta2_ = 0.999; double reconstructionLossWeight_ = 1; double trackingLossWeight_ = 0; double metricLossWeight_ = 0; @@ -67,66 +57,24 @@ namespace ttk { bool customLossSpace_ = false; bool customLossActivate_ = false; bool normalizeMetricLoss_ = false; - unsigned int noInit_ = 4; - bool euclideanVectorsInit_ = false; - bool initOriginPrimeStructByCopy_ = true; bool trackingLossDecoding_ = false; double trackingLossInitRandomness_ = 0.0; - bool activate_ = true; - unsigned int activationFunction_ = 1; - bool activateOutputInit_ = false; - - bool createOutput_ = true; // Old hyper-parameters bool fullSymmetricAE_ = false; #ifdef TTK_ENABLE_TORCH // Model optimized parameters - std::vector vSTensor_, vSPrimeTensor_, vS2Tensor_, - vS2PrimeTensor_, latentCentroids_; - std::vector> origins_, originsPrime_, origins2_, - origins2Prime_; + std::vector bestLatentCentroids_, latentCentroids_; - std::vector> originsCopy_, originsPrimeCopy_; + std::vector vSTensorCopy_, vSPrimeTensorCopy_; - // Filled by the algorithm - std::vector> allAlphas_, allScaledAlphas_, - allActAlphas_, allActScaledAlphas_; - std::vector>> recs_, recs2_; std::vector> customRecs_; #endif // Filled by the algorithm - unsigned noLayers_; double baseRecLoss_, baseRecLoss2_; - float bestLoss_; - std::vector clusterAsgn_; std::vector> distanceMatrix_, customAlphas_; - std::vector>> - baryMatchings_L0_, baryMatchings2_L0_; - std::vector inputToBaryDistances_L0_; - - // Tracking matchings - std::vector>> - originsMatchings_, reconstMatchings_, customMatchings_; - std::vector< - std::vector>>> - dataMatchings_; - std::vector> branchesCorrelationMatrix_, - persCorrelationMatrix_; - - // Testing - double t_allVectorCopy_time_ = 0.0; - std::vector originsNoZeroGrad_, originsPrimeNoZeroGrad_, - vSNoZeroGrad_, vSPrimeNoZeroGrad_, origins2NoZeroGrad_, - origins2PrimeNoZeroGrad_, vS2NoZeroGrad_, vS2PrimeNoZeroGrad_; - bool outputInit_ = true; -#ifdef TTK_ENABLE_TORCH - std::vector> initOrigins_, initOriginsPrime_, - initRecs_; -#endif - float bigValuesThreshold_ = 0; public: MergeTreeAutoencoder(); @@ -135,211 +83,25 @@ namespace ttk { // ----------------------------------------------------------------------- // --- Init // ----------------------------------------------------------------------- - void initOutputBasisTreeStructure(mtu::TorchMergeTree &originPrime, - bool isJT, - mtu::TorchMergeTree &baseOrigin); - - void initOutputBasis(unsigned int l, unsigned int dim, unsigned int dim2); - - void initOutputBasisVectors(unsigned int l, - torch::Tensor &w, - torch::Tensor &w2); - - void initOutputBasisVectors(unsigned int l, - unsigned int dim, - unsigned int dim2); - - void initInputBasisOrigin( - std::vector> &treesToUse, - std::vector> &trees2ToUse, - double barycenterSizeLimitPercent, - unsigned int barycenterMaxNoPairs, - unsigned int barycenterMaxNoPairs2, - mtu::TorchMergeTree &origin, - mtu::TorchMergeTree &origin2, - std::vector &inputToBaryDistances, - std::vector>> - &baryMatchings, - std::vector>> - &baryMatchings2); - - void initInputBasisVectors( - std::vector> &tmTreesToUse, - std::vector> &tmTrees2ToUse, - std::vector> &treesToUse, - std::vector> &trees2ToUse, - mtu::TorchMergeTree &origin, - mtu::TorchMergeTree &origin2, - unsigned int noVectors, - std::vector> &allAlphasInit, - unsigned int l, - std::vector &inputToBaryDistances, - std::vector>> - &baryMatchings, - std::vector>> - &baryMatchings2, - torch::Tensor &vSTensor, - torch::Tensor &vS2Tensor); - void initClusteringLossParameters(); - float initParameters(std::vector> &trees, - std::vector> &trees2, - bool computeReconstructionError = false); - - void initStep(std::vector> &trees, - std::vector> &trees2); - - // ----------------------------------------------------------------------- - // --- Interpolation - // ----------------------------------------------------------------------- - void interpolationDiagonalProjection( - mtu::TorchMergeTree &interpolationTensor); + bool initResetOutputBasis(unsigned int l, + unsigned int layerNoAxes, + double layerOriginPrimeSizePercent, + std::vector> &trees, + std::vector> &trees2, + std::vector &isTrain) override; - void - interpolationNestingProjection(mtu::TorchMergeTree &interpolation); - - void interpolationProjection(mtu::TorchMergeTree &interpolation); - - void getMultiInterpolation(mtu::TorchMergeTree &origin, - torch::Tensor &vS, - torch::Tensor &alphas, - mtu::TorchMergeTree &interpolation); - - // ----------------------------------------------------------------------- - // --- Forward - // ----------------------------------------------------------------------- - void getAlphasOptimizationTensors( - mtu::TorchMergeTree &tree, - mtu::TorchMergeTree &origin, - torch::Tensor &vSTensor, - mtu::TorchMergeTree &interpolated, - std::vector> &matching, - torch::Tensor &reorderedTreeTensor, - torch::Tensor &deltaOrigin, - torch::Tensor &deltaA, - torch::Tensor &originTensor_f, - torch::Tensor &vSTensor_f); - - void computeAlphas( - mtu::TorchMergeTree &tree, - mtu::TorchMergeTree &origin, - torch::Tensor &vSTensor, - mtu::TorchMergeTree &interpolated, - std::vector> &matching, - mtu::TorchMergeTree &tree2, - mtu::TorchMergeTree &origin2, - torch::Tensor &vS2Tensor, - mtu::TorchMergeTree &interpolated2, - std::vector> &matching2, - torch::Tensor &alphasOut); - - float assignmentOneData( - mtu::TorchMergeTree &tree, - mtu::TorchMergeTree &origin, - torch::Tensor &vSTensor, - mtu::TorchMergeTree &tree2, - mtu::TorchMergeTree &origin2, - torch::Tensor &vS2Tensor, - unsigned int k, - torch::Tensor &alphasInit, - std::vector> &bestMatching, - std::vector> &bestMatching2, - torch::Tensor &bestAlphas, - bool isCalled = false); - - float assignmentOneData(mtu::TorchMergeTree &tree, - mtu::TorchMergeTree &origin, - torch::Tensor &vSTensor, - mtu::TorchMergeTree &tree2, - mtu::TorchMergeTree &origin2, - torch::Tensor &vS2Tensor, - unsigned int k, - torch::Tensor &alphasInit, - torch::Tensor &bestAlphas, - bool isCalled = false); - - torch::Tensor activation(torch::Tensor &in); - - void outputBasisReconstruction(mtu::TorchMergeTree &originPrime, - torch::Tensor &vSPrimeTensor, - mtu::TorchMergeTree &origin2Prime, - torch::Tensor &vS2PrimeTensor, - torch::Tensor &alphas, - mtu::TorchMergeTree &out, - mtu::TorchMergeTree &out2, - bool activate = true); - - bool forwardOneLayer(mtu::TorchMergeTree &tree, - mtu::TorchMergeTree &origin, - torch::Tensor &vSTensor, - mtu::TorchMergeTree &originPrime, - torch::Tensor &vSPrimeTensor, - mtu::TorchMergeTree &tree2, - mtu::TorchMergeTree &origin2, - torch::Tensor &vS2Tensor, - mtu::TorchMergeTree &origin2Prime, - torch::Tensor &vS2PrimeTensor, - unsigned int k, - torch::Tensor &alphasInit, - mtu::TorchMergeTree &out, - mtu::TorchMergeTree &out2, - torch::Tensor &bestAlphas, - float &bestDistance); - - bool forwardOneLayer(mtu::TorchMergeTree &tree, - mtu::TorchMergeTree &origin, - torch::Tensor &vSTensor, - mtu::TorchMergeTree &originPrime, - torch::Tensor &vSPrimeTensor, - mtu::TorchMergeTree &tree2, - mtu::TorchMergeTree &origin2, - torch::Tensor &vS2Tensor, - mtu::TorchMergeTree &origin2Prime, - torch::Tensor &vS2PrimeTensor, - unsigned int k, - torch::Tensor &alphasInit, - mtu::TorchMergeTree &out, - mtu::TorchMergeTree &out2, - torch::Tensor &bestAlphas); - - bool forwardOneData(mtu::TorchMergeTree &tree, - mtu::TorchMergeTree &tree2, - unsigned int treeIndex, - unsigned int k, - std::vector &alphasInit, - mtu::TorchMergeTree &out, - mtu::TorchMergeTree &out2, - std::vector &dataAlphas, - std::vector> &outs, - std::vector> &outs2); - - bool forwardStep( + void initOutputBasisSpecialCase( + unsigned int l, + unsigned int layerNoAxes, std::vector> &trees, - std::vector> &trees2, - std::vector &indexes, - unsigned int k, - std::vector> &allAlphasInit, - bool computeReconstructionError, - std::vector> &outs, - std::vector> &outs2, - std::vector> &bestAlphas, - std::vector>> &layersOuts, - std::vector>> &layersOuts2, - std::vector>> - &matchings, - std::vector>> - &matchings2, - float &loss); + std::vector> &trees2); - bool forwardStep(std::vector> &trees, - std::vector> &trees2, - std::vector &indexes, - unsigned int k, - std::vector> &allAlphasInit, - std::vector> &outs, - std::vector> &outs2, - std::vector> &bestAlphas); + float initParameters(std::vector> &trees, + std::vector> &trees2, + std::vector &isTrain, + bool computeError = false) override; // ----------------------------------------------------------------------- // --- Backward @@ -353,16 +115,11 @@ namespace ttk { std::vector> &outs2, std::vector>> &matchings2, + std::vector> &alphas, torch::optim::Optimizer &optimizer, std::vector &indexes, - torch::Tensor &metricLoss, - torch::Tensor &clusteringLoss, - torch::Tensor &trackingLoss); - - // ----------------------------------------------------------------------- - // --- Projection - // ----------------------------------------------------------------------- - void projectionStep(); + std::vector &isTrain, + std::vector &torchCustomLoss) override; // ----------------------------------------------------------------------- // --- Convergence @@ -373,80 +130,49 @@ namespace ttk { mtu::TorchMergeTree &tree2, mtu::TorchMergeTree &out2, std::vector> &matching, - std::vector> &matching2); - - float computeLoss( - std::vector> &trees, - std::vector> &outs, - std::vector> &trees2, - std::vector> &outs2, - std::vector &indexes, - std::vector>> - &matchings, - std::vector>> - &matchings2); - - bool isBestLoss(float loss, float &minLoss, unsigned int &cptBlocked); - - bool convergenceStep(float loss, - float &oldLoss, - float &minLoss, - unsigned int &cptBlocked); + std::vector> &matching2, + std::vector &alphas, + unsigned int treeIndex) override; // ----------------------------------------------------------------------- // --- Main Functions // ----------------------------------------------------------------------- - void fit(std::vector> &trees, - std::vector> &trees2); + void + customInit(std::vector> &torchTrees, + std::vector> &torchTrees2) override; + + void addCustomParameters(std::vector ¶meters) override; + + void computeCustomLosses( + std::vector>> &layersOuts, + std::vector>> &layersOuts2, + std::vector> &bestAlphas, + std::vector &indexes, + std::vector &isTrain, + unsigned int iteration, + std::vector> &gapCustomLosses, + std::vector> &iterationCustomLosses, + std::vector &torchCustomLoss) override; + + float computeIterationTotalLoss( + float iterationLoss, + std::vector> &iterationCustomLosses, + std::vector &iterationCustomLoss) override; + + void printCustomLosses(std::vector &customLoss, + std::stringstream &prefix, + const debug::Priority &priority + = debug::Priority::INFO) override; + + void + printGapLoss(float loss, + std::vector> &gapCustomLosses) override; // ----------------------------------------------------------------------- // --- Custom Losses // ----------------------------------------------------------------------- double getCustomLossDynamicWeight(double recLoss, double &baseLoss); - void getDistanceMatrix(std::vector> &tmts, - std::vector> &distanceMatrix, - bool useDoubleInput = false, - bool isFirstInput = true); - - void getDistanceMatrix(std::vector> &tmts, - std::vector> &tmts2, - std::vector> &distanceMatrix); - - void getDifferentiableDistanceFromMatchings( - mtu::TorchMergeTree &tree1, - mtu::TorchMergeTree &tree2, - mtu::TorchMergeTree &tree1_2, - mtu::TorchMergeTree &tree2_2, - std::vector> &matchings, - std::vector> &matchings2, - torch::Tensor &tensorDist, - bool doSqrt); - - void getDifferentiableDistance(mtu::TorchMergeTree &tree1, - mtu::TorchMergeTree &tree2, - mtu::TorchMergeTree &tree1_2, - mtu::TorchMergeTree &tree2_2, - torch::Tensor &tensorDist, - bool isCalled, - bool doSqrt); - - void getDifferentiableDistance(mtu::TorchMergeTree &tree1, - mtu::TorchMergeTree &tree2, - torch::Tensor &tensorDist, - bool isCalled, - bool doSqrt); - - void getDifferentiableDistanceMatrix( - std::vector *> &trees, - std::vector *> &trees2, - std::vector> &outDistMat); - - void getAlphasTensor(std::vector> &alphas, - std::vector &indexes, - unsigned int layerIndex, - torch::Tensor &alphasOut); - void computeMetricLoss( std::vector>> &layersOuts, std::vector>> &layersOuts2, @@ -465,60 +191,22 @@ namespace ttk { // --------------------------------------------------------------------------- // --- End Functions // --------------------------------------------------------------------------- - void - createCustomRecs(std::vector> &origins, - std::vector> &originsPrime); - - void computeTrackingInformation(); - - void - createScaledAlphas(std::vector> &alphas, - std::vector &vSTensor, - std::vector> &scaledAlphas); - - void createScaledAlphas(); - - void createActivatedAlphas(); + void createCustomRecs(); // ----------------------------------------------------------------------- // --- Utils // ----------------------------------------------------------------------- - void copyParams(std::vector> &srcOrigins, - std::vector> &srcOriginsPrime, - std::vector &srcVS, - std::vector &srcVSPrime, - std::vector> &srcOrigins2, - std::vector> &srcOrigins2Prime, - std::vector &srcVS2, - std::vector &srcVS2Prime, - std::vector> &srcAlphas, - std::vector> &dstOrigins, - std::vector> &dstOriginsPrime, - std::vector &dstVS, - std::vector &dstVSPrime, - std::vector> &dstOrigins2, - std::vector> &dstOrigins2Prime, - std::vector &dstVS2, - std::vector &dstVS2Prime, - std::vector> &dstAlphas); - - void copyParams(std::vector>> &src, - std::vector>> &dst); - unsigned int getLatentLayerIndex(); - // ----------------------------------------------------------------------- - // --- Testing - // ----------------------------------------------------------------------- - bool isTreeHasBigValues(ftm::MergeTree &mTree, - float threshold = 10000); + void copyCustomParams(bool get) override; #endif // --------------------------------------------------------------------------- // --- Main Functions // --------------------------------------------------------------------------- - void execute(std::vector> &trees, - std::vector> &trees2); + void + executeEndFunction(std::vector> &trees, + std::vector> &trees2) override; }; // MergeTreeAutoencoder class } // namespace ttk diff --git a/core/base/mergeTreeAutoencoder/MergeTreeAutoencoderUtils.cpp b/core/base/mergeTreeAutoencoder/MergeTreeAutoencoderUtils.cpp deleted file mode 100644 index ee4eaae0a0..0000000000 --- a/core/base/mergeTreeAutoencoder/MergeTreeAutoencoderUtils.cpp +++ /dev/null @@ -1,394 +0,0 @@ -#include - -void ttk::wae::fixTreePrecisionScalars(ftm::MergeTree &mTree) { - double eps = 1e-6; - auto shiftSubtree - = [&mTree, &eps](ftm::idNode node, ftm::idNode birthNodeParent, - ftm::idNode deathNodeParent, std::vector &scalars, - bool invalidBirth, bool invalidDeath) { - std::queue queue; - queue.emplace(node); - while(!queue.empty()) { - ftm::idNode nodeT = queue.front(); - queue.pop(); - auto birthDeathNode = mTree.tree.getBirthDeathNode(node); - auto birthNode = std::get<0>(birthDeathNode); - auto deathNode = std::get<1>(birthDeathNode); - if(invalidBirth) - scalars[birthNode] = scalars[birthNodeParent] + 2 * eps; - if(invalidDeath) - scalars[deathNode] = scalars[deathNodeParent] - 2 * eps; - std::vector children; - mTree.tree.getChildren(nodeT, children); - for(auto &child : children) - queue.emplace(child); - } - }; - std::vector scalars; - getTreeScalars(mTree, scalars); - std::queue queue; - auto root = mTree.tree.getRoot(); - queue.emplace(root); - while(!queue.empty()) { - ftm::idNode node = queue.front(); - queue.pop(); - auto birthDeathNode = mTree.tree.getBirthDeathNode(node); - auto birthNode = std::get<0>(birthDeathNode); - auto deathNode = std::get<1>(birthDeathNode); - auto birthDeathNodeParent - = mTree.tree.getBirthDeathNode(mTree.tree.getParentSafe(node)); - auto birthNodeParent = std::get<0>(birthDeathNodeParent); - auto deathNodeParent = std::get<1>(birthDeathNodeParent); - bool invalidBirth = (scalars[birthNode] <= scalars[birthNodeParent] + eps); - bool invalidDeath = (scalars[deathNode] >= scalars[deathNodeParent] - eps); - if(!mTree.tree.isRoot(node) and (invalidBirth or invalidDeath)) - shiftSubtree(node, birthNodeParent, deathNodeParent, scalars, - invalidBirth, invalidDeath); - std::vector children; - mTree.tree.getChildren(node, children); - for(auto &child : children) - queue.emplace(child); - } - ftm::setTreeScalars(mTree, scalars); -} - -void ttk::wae::adjustNestingScalars(std::vector &scalarsVector, - ftm::idNode node, - ftm::idNode refNode) { - float birth = scalarsVector[refNode * 2]; - float death = scalarsVector[refNode * 2 + 1]; - auto getSign = [](float v) { return (v > 0 ? 1 : -1); }; - auto getPrecValue = [&getSign](float v, bool opp = false) { - return v * (1 + (opp ? -1 : 1) * getSign(v) * 1e-6); - }; - // Shift scalars - if(scalarsVector[node * 2 + 1] > getPrecValue(death, true)) { - float diff = scalarsVector[node * 2 + 1] - getPrecValue(death, true); - scalarsVector[node * 2] -= diff; - scalarsVector[node * 2 + 1] -= diff; - } else if(scalarsVector[node * 2] < getPrecValue(birth)) { - float diff = getPrecValue(birth) - scalarsVector[node * 2]; - scalarsVector[node * 2] += getPrecValue(diff); - scalarsVector[node * 2 + 1] += getPrecValue(diff); - } - // Cut scalars - if(scalarsVector[node * 2] < getPrecValue(birth)) - scalarsVector[node * 2] = getPrecValue(birth); - if(scalarsVector[node * 2 + 1] > getPrecValue(death, true)) - scalarsVector[node * 2 + 1] = getPrecValue(death, true); -} - -void ttk::wae::createBalancedBDT( - std::vector> &parents, - std::vector> &children, - std::vector &scalarsVector, - std::vector> &childrenFinal, - int threadNumber) { - // ----- Some variables - unsigned int noNodes = scalarsVector.size() / 2; - childrenFinal.resize(noNodes); - int mtLevel = ceil(log(noNodes * 2) / log(2)) + 1; - int bdtLevel = mtLevel - 1; - int noDim = bdtLevel; - - // ----- Get node levels - std::vector nodeLevels(noNodes, -1); - std::queue queueLevels; - std::vector noChildDone(noNodes, 0); - for(unsigned int i = 0; i < children.size(); ++i) { - if(children[i].size() == 0) { - queueLevels.emplace(i); - nodeLevels[i] = 1; - } - } - while(!queueLevels.empty()) { - ftm::idNode node = queueLevels.front(); - queueLevels.pop(); - for(auto &parent : parents[node]) { - ++noChildDone[parent]; - nodeLevels[parent] = std::max(nodeLevels[parent], nodeLevels[node] + 1); - if(noChildDone[parent] >= (int)children[parent].size()) - queueLevels.emplace(parent); - } - } - - // ----- Sort heuristic lambda - auto sortChildren = [&parents, &scalarsVector, &noNodes, &threadNumber]( - ftm::idNode nodeOrigin, std::vector &nodeDone, - std::vector> &childrenT) { - double refPers = scalarsVector[1] - scalarsVector[0]; - auto getRemaining = [&nodeDone](std::vector &vec) { - unsigned int remaining = 0; - for(auto &e : vec) - remaining += (not nodeDone[e]); - return remaining; - }; - std::vector parentsRemaining(noNodes, 0), - childrenRemaining(noNodes, 0); - for(auto &child : childrenT[nodeOrigin]) { - parentsRemaining[child] = getRemaining(parents[child]); - childrenRemaining[child] = getRemaining(childrenT[child]); - } - TTK_FORCE_USE(threadNumber); - TTK_PSORT( - threadNumber, childrenT[nodeOrigin].begin(), childrenT[nodeOrigin].end(), - [&](ftm::idNode nodeI, ftm::idNode nodeJ) { - double persI = scalarsVector[nodeI * 2 + 1] - scalarsVector[nodeI * 2]; - double persJ = scalarsVector[nodeJ * 2 + 1] - scalarsVector[nodeJ * 2]; - return parentsRemaining[nodeI] + childrenRemaining[nodeI] - - persI / refPers * noNodes - < parentsRemaining[nodeJ] + childrenRemaining[nodeJ] - - persJ / refPers * noNodes; - }); - }; - - // ----- Greedy approach to find balanced BDT structures - const auto findStructGivenDim = - [&children, &noNodes, &nodeLevels]( - ftm::idNode _nodeOrigin, int _dimToFound, bool _searchMaxDim, - std::vector &_nodeDone, std::vector &_dimFound, - std::vector> &_childrenFinalOut) { - // --- Recursive lambda - auto findStructGivenDimImpl = - [&children, &noNodes, &nodeLevels]( - ftm::idNode nodeOrigin, int dimToFound, bool searchMaxDim, - std::vector &nodeDone, std::vector &dimFound, - std::vector> &childrenFinalOut, - auto &findStructGivenDimRef) mutable { - childrenFinalOut.resize(noNodes); - // - Find structures - int dim = (searchMaxDim ? dimToFound - 1 : 0); - unsigned int i = 0; - // - auto searchMaxDimReset = [&i, &dim, &nodeDone]() { - --dim; - i = 0; - unsigned int noDone = 0; - for(auto done : nodeDone) - if(done) - ++noDone; - return noDone == nodeDone.size() - 1; // -1 for root - }; - while(i < children[nodeOrigin].size()) { - auto child = children[nodeOrigin][i]; - // Skip if child was already processed - if(nodeDone[child]) { - // If we have processed all children while searching for max - // dim then restart at the beginning to find a lower dim - if(searchMaxDim and i == children[nodeOrigin].size() - 1) { - if(searchMaxDimReset()) - break; - } else - ++i; - continue; - } - if(dim == 0) { - // Base case - childrenFinalOut[nodeOrigin].emplace_back(child); - nodeDone[child] = true; - dimFound[0] = true; - if(dimToFound <= 1 or searchMaxDim) - return true; - ++dim; - } else { - // General case - std::vector> childrenFinalDim; - std::vector nodeDoneDim; - std::vector dimFoundDim(dim); - bool found = false; - if(nodeLevels[child] > dim) { - nodeDoneDim = nodeDone; - found = findStructGivenDimRef(child, dim, false, nodeDoneDim, - dimFoundDim, childrenFinalDim, - findStructGivenDimRef); - } - if(found) { - dimFound[dim] = true; - childrenFinalOut[nodeOrigin].emplace_back(child); - for(unsigned int j = 0; j < childrenFinalDim.size(); ++j) - for(auto &e : childrenFinalDim[j]) - childrenFinalOut[j].emplace_back(e); - nodeDone[child] = true; - for(unsigned int j = 0; j < nodeDoneDim.size(); ++j) - nodeDone[j] = nodeDone[j] || nodeDoneDim[j]; - // Return if it is the last dim to found - if(dim == dimToFound - 1 and not searchMaxDim) - return true; - // Reset index if we search for the maximum dim - if(searchMaxDim) { - if(searchMaxDimReset()) - break; - } else { - ++dim; - } - continue; - } else if(searchMaxDim and i == children[nodeOrigin].size() - 1) { - // If we have processed all children while searching for max - // dim then restart at the beginning to find a lower dim - if(searchMaxDimReset()) - break; - continue; - } - } - ++i; - } - return false; - }; - return findStructGivenDimImpl(_nodeOrigin, _dimToFound, _searchMaxDim, - _nodeDone, _dimFound, _childrenFinalOut, - findStructGivenDimImpl); - }; - std::vector dimFound(noDim - 1, false); - std::vector nodeDone(noNodes, false); - for(unsigned int i = 0; i < children.size(); ++i) - sortChildren(i, nodeDone, children); - Timer t_find; - ftm::idNode startNode = 0; - findStructGivenDim(startNode, noDim, true, nodeDone, dimFound, childrenFinal); - - // ----- Greedy approach to create non found structures - const auto createStructGivenDim = - [&children, &noNodes, &findStructGivenDim, &nodeLevels]( - int _nodeOrigin, int _dimToCreate, std::vector &_nodeDone, - ftm::idNode &_structOrigin, std::vector &_scalarsVectorOut, - std::vector> &_childrenFinalOut) { - // --- Recursive lambda - auto createStructGivenDimImpl = - [&children, &noNodes, &findStructGivenDim, &nodeLevels]( - int nodeOrigin, int dimToCreate, std::vector &nodeDoneImpl, - ftm::idNode &structOrigin, std::vector &scalarsVectorOut, - std::vector> &childrenFinalOut, - auto &createStructGivenDimRef) mutable { - // Deduction of auto lambda type - if(false) - return; - // - Find structures of lower dimension - int dimToFound = dimToCreate - 1; - std::vector>> childrenFinalT(2); - std::array structOrigins; - for(unsigned int n = 0; n < 2; ++n) { - bool found = false; - for(unsigned int i = 0; i < children[nodeOrigin].size(); ++i) { - auto child = children[nodeOrigin][i]; - if(nodeDoneImpl[child]) - continue; - if(dimToFound != 0) { - if(nodeLevels[child] > dimToFound) { - std::vector dimFoundT(dimToFound, false); - childrenFinalT[n].clear(); - childrenFinalT[n].resize(noNodes); - std::vector nodeDoneImplFind = nodeDoneImpl; - found = findStructGivenDim(child, dimToFound, false, - nodeDoneImplFind, dimFoundT, - childrenFinalT[n]); - } - } else - found = true; - if(found) { - structOrigins[n] = child; - nodeDoneImpl[child] = true; - for(unsigned int j = 0; j < childrenFinalT[n].size(); ++j) { - for(auto &e : childrenFinalT[n][j]) { - childrenFinalOut[j].emplace_back(e); - nodeDoneImpl[e] = true; - } - } - break; - } - } // end for children[nodeOrigin] - if(not found) { - if(dimToFound <= 0) { - structOrigins[n] = std::numeric_limits::max(); - continue; - } - childrenFinalT[n].clear(); - childrenFinalT[n].resize(noNodes); - createStructGivenDimRef( - nodeOrigin, dimToFound, nodeDoneImpl, structOrigins[n], - scalarsVectorOut, childrenFinalT[n], createStructGivenDimRef); - for(unsigned int j = 0; j < childrenFinalT[n].size(); ++j) { - for(auto &e : childrenFinalT[n][j]) { - if(e == structOrigins[n]) - continue; - childrenFinalOut[j].emplace_back(e); - } - } - } - } // end for n - // - Combine both structures - if(structOrigins[0] == std::numeric_limits::max() - and structOrigins[1] == std::numeric_limits::max()) { - structOrigin = std::numeric_limits::max(); - return; - } - bool firstIsParent = true; - if(structOrigins[0] == std::numeric_limits::max()) - firstIsParent = false; - else if(structOrigins[1] == std::numeric_limits::max()) - firstIsParent = true; - else if(scalarsVectorOut[structOrigins[1] * 2 + 1] - - scalarsVectorOut[structOrigins[1] * 2] - > scalarsVectorOut[structOrigins[0] * 2 + 1] - - scalarsVectorOut[structOrigins[0] * 2]) - firstIsParent = false; - structOrigin = (firstIsParent ? structOrigins[0] : structOrigins[1]); - ftm::idNode modOrigin - = (firstIsParent ? structOrigins[1] : structOrigins[0]); - childrenFinalOut[nodeOrigin].emplace_back(structOrigin); - if(modOrigin != std::numeric_limits::max()) { - childrenFinalOut[structOrigin].emplace_back(modOrigin); - std::queue> queue; - queue.emplace(std::array{modOrigin, structOrigin}); - while(!queue.empty()) { - auto &nodeAndParent = queue.front(); - ftm::idNode node = nodeAndParent[0]; - ftm::idNode parent = nodeAndParent[1]; - queue.pop(); - adjustNestingScalars(scalarsVectorOut, node, parent); - // Push children - for(auto &child : childrenFinalOut[node]) - queue.emplace(std::array{child, node}); - } - } - return; - }; - return createStructGivenDimImpl( - _nodeOrigin, _dimToCreate, _nodeDone, _structOrigin, _scalarsVectorOut, - _childrenFinalOut, createStructGivenDimImpl); - }; - for(unsigned int i = 0; i < children.size(); ++i) - sortChildren(i, nodeDone, children); - Timer t_create; - for(unsigned int i = 0; i < dimFound.size(); ++i) { - if(dimFound[i]) - continue; - ftm::idNode structOrigin; - createStructGivenDim( - startNode, i, nodeDone, structOrigin, scalarsVector, childrenFinal); - } -} - -void ttk::wae::printPairs(ftm::MergeTree &mTree, bool useBD) { - std::stringstream ss; - if(mTree.tree.getRealNumberOfNodes() != 0) - ss = mTree.tree.template printPairsFromTree(useBD); - else { - std::vector nodeDone(mTree.tree.getNumberOfNodes(), false); - for(unsigned int i = 0; i < mTree.tree.getNumberOfNodes(); ++i) { - if(nodeDone[i]) - continue; - std::tuple pair - = std::make_tuple(i, mTree.tree.getNode(i)->getOrigin(), - mTree.tree.getNodePersistence(i)); - ss << std::get<0>(pair) << " (" - << mTree.tree.getValue(std::get<0>(pair)) << ") _ "; - ss << std::get<1>(pair) << " (" - << mTree.tree.getValue(std::get<1>(pair)) << ") _ "; - ss << std::get<2>(pair) << std::endl; - nodeDone[i] = true; - nodeDone[mTree.tree.getNode(i)->getOrigin()] = true; - } - } - ss << std::endl; - std::cout << ss.str(); -} diff --git a/core/base/mergeTreeAutoencoder/MergeTreeAutoencoderUtils.h b/core/base/mergeTreeAutoencoder/MergeTreeAutoencoderUtils.h deleted file mode 100644 index 8862ba0e3d..0000000000 --- a/core/base/mergeTreeAutoencoder/MergeTreeAutoencoderUtils.h +++ /dev/null @@ -1,66 +0,0 @@ -/// \ingroup base -/// \author Mathieu Pont -/// \date 2023. -/// -/// \brief Provide utils methods related to Merge Trees and Persistence Diagrams -/// Auto-Encoders. - -#pragma once - -#include -#include -#include - -namespace ttk { - - namespace wae { - - /** - * @brief Fix the scalars of a merge tree to ensure that the nesting - * condition is respected. - * - * @param[in] mTree Merge tree to process. - */ - void fixTreePrecisionScalars(ftm::MergeTree &mTree); - - /** - * @brief Fix the scalars of a merge tree to ensure that the nesting - * condition is respected. - * - * @param[in] scalarsVector scalars array to process. - * @param[in] node node to adjust. - * @param[in] refNode reference node. - */ - void adjustNestingScalars(std::vector &scalarsVector, - ftm::idNode node, - ftm::idNode refNode); - - /** - * @brief Create a balanced BDT structure (for output basis initialization). - * - * @param[in] parents vector containing the possible parents for each node. - * @param[in] children vector containing the possible children for each - * node. - * @param[in] scalarsVector vector containing the scalars value. - * @param[out] childrenFinal output vector containing the children of each - * node, representing the tree structure. - * @param[in] threadNumber number of threads for parallel sort. - */ - void createBalancedBDT(std::vector> &parents, - std::vector> &children, - std::vector &scalarsVector, - std::vector> &childrenFinal, - int threadNumber = 1); - - /** - * @brief Util function to print pairs of a merge tree. - * - * @param[in] mTree merge tree to process. - * @param[in] useBD if the merge tree is in branch decomposition mode or - * not. - */ - void printPairs(ftm::MergeTree &mTree, bool useBD = true); - - } // namespace wae - -} // namespace ttk diff --git a/core/base/mergeTreeAutoencoderDecoding/MergeTreeAutoencoderDecoding.cpp b/core/base/mergeTreeAutoencoderDecoding/MergeTreeAutoencoderDecoding.cpp index f47c22a055..ba5c933c6a 100644 --- a/core/base/mergeTreeAutoencoderDecoding/MergeTreeAutoencoderDecoding.cpp +++ b/core/base/mergeTreeAutoencoderDecoding/MergeTreeAutoencoderDecoding.cpp @@ -1,5 +1,5 @@ #include -#include +#include ttk::MergeTreeAutoencoderDecoding::MergeTreeAutoencoderDecoding() { // inherited from Debug: prefix will be printed at the beginning of every msg @@ -37,21 +37,33 @@ void ttk::MergeTreeAutoencoderDecoding::execute( pt, nodeCorr, false); } } - mergeTreesToTorchTrees(originsTrees, origins_, normalizedWasserstein_, + mergeTreesToTorchTrees(originsTrees, originsCopy_, normalizedWasserstein_, allRevNodeCorr, allRevNodeCorrSize); - mergeTreesToTorchTrees(originsPrimeTrees, originsPrime_, + mergeTreesToTorchTrees(originsPrimeTrees, originsPrimeCopy_, normalizedWasserstein_, allRevNodeCorrPrime, allRevNodeCorrPrimeSize); + layers_.resize(noLayers_); + for(unsigned int l = 0; l < layers_.size(); ++l) { + layers_[l].setOrigin(originsCopy_[l]); + layers_[l].setVSTensor(vSTensorCopy_[l]); + layers_[l].setOriginPrime(originsPrimeCopy_[l]); + layers_[l].setVSPrimeTensor(vSPrimeTensorCopy_[l]); + initOriginPrimeValuesByCopy_ + = trackingLossWeight_ != 0 + and l < (trackingLossDecoding_ ? noLayers_ : getLatentLayerIndex() + 1); + initOriginPrimeValuesByCopyRandomness_ = trackingLossInitRandomness_; + passLayerParameters(layers_[l]); + } // --- Execute - if(allAlphas_[0].size() != originsPrime_.size()) { + if(allAlphas_[0].size() != originsPrimeCopy_.size()) { customAlphas_.resize(allAlphas_.size()); for(unsigned int i = 0; i < customAlphas_.size(); ++i) customAlphas_[i] = std::vector( allAlphas_[i][0].data_ptr(), allAlphas_[i][0].data_ptr() + allAlphas_[i][0].numel()); allAlphas_.clear(); - createCustomRecs(origins_, originsPrime_); + createCustomRecs(); } else { recs_.resize(allAlphas_.size()); for(unsigned int i = 0; i < recs_.size(); ++i) { @@ -59,21 +71,22 @@ void ttk::MergeTreeAutoencoderDecoding::execute( for(unsigned int l = 0; l < allAlphas_[i].size(); ++l) { torch::Tensor act = (activate_ ? activation(allAlphas_[i][l]) : allAlphas_[i][l]); - getMultiInterpolation( - originsPrime_[l], vSPrimeTensor_[l], act, recs_[i][l]); + layers_[l].getMultiInterpolation(layers_[l].getOriginPrime(), + layers_[l].getVSPrimeTensor(), act, + recs_[i][l]); } } } // --- Postprocessing - for(unsigned int l = 0; l < origins_.size(); ++l) { - postprocessingPipeline(&(origins_[l].mTree.tree)); - postprocessingPipeline(&(originsPrime_[l].mTree.tree)); + for(unsigned int l = 0; l < originsCopy_.size(); ++l) { + postprocessingPipeline(&(originsCopy_[l].mTree.tree)); + postprocessingPipeline(&(originsPrimeCopy_[l].mTree.tree)); } if(!recs_.empty()) { for(unsigned int j = 0; j < recs_[0].size(); ++j) { for(unsigned int i = 0; i < recs_.size(); ++i) { - wae::fixTreePrecisionScalars(recs_[i][j].mTree); + fixTreePrecisionScalars(recs_[i][j].mTree); postprocessingPipeline(&(recs_[i][j].mTree.tree)); } } diff --git a/core/base/mergeTreeClustering/MergeTreeBarycenter.h b/core/base/mergeTreeClustering/MergeTreeBarycenter.h index f74f1f6a28..ba74dc45ca 100644 --- a/core/base/mergeTreeClustering/MergeTreeBarycenter.h +++ b/core/base/mergeTreeClustering/MergeTreeBarycenter.h @@ -35,6 +35,8 @@ namespace ttk { double tol_ = 0.0; bool addNodes_ = true; bool deterministic_ = true; + int barycenterInitIndex_ = -1; + int barycenterMaxIter_ = -1; bool isCalled_ = false; bool progressiveBarycenter_ = false; double progressiveSpeedDivisor_ = 4.0; @@ -75,6 +77,14 @@ namespace ttk { deterministic_ = deterministicT; } + void setBarycenterInitIndex(int barycenterInitIndex) { + barycenterInitIndex_ = barycenterInitIndex; + } + + void setBarycenterMaxIter(int barycenterMaxIter) { + barycenterMaxIter_ = barycenterMaxIter; + } + void setProgressiveBarycenter(bool progressive) { progressiveBarycenter_ = progressive; } @@ -164,6 +174,10 @@ namespace ttk { double sizeLimitPercent, std::vector> &mTreesLimited) { mTreesLimited.resize(trees.size()); +#ifdef TTK_ENABLE_OPENMP4 +#pragma omp parallel for schedule(dynamic) \ + num_threads(this->threadNumber_) if(parallelize_) +#endif for(unsigned int i = 0; i < trees.size(); ++i) { mTreesLimited[i] = ftm::copyMergeTree(trees[i]); limitSizeBarycenter(mTreesLimited[i], trees, @@ -212,6 +226,8 @@ namespace ttk { unsigned int barycenterMaximumNumberOfPairs, double sizeLimitPercent, bool distMinimizer = true) { + if(barycenterInitIndex_ != -1) + return barycenterInitIndex_; std::vector> distanceMatrix, distanceMatrix2; bool const useDoubleInput = (trees2.size() != 0); getParametrizedDistanceMatrix(trees, distanceMatrix, @@ -282,11 +298,23 @@ namespace ttk { // ------------------------------------------------------------------------ // Update // ------------------------------------------------------------------------ + /** + * @brief Get information about the nodes to add in the barycenter. + * + * @param[in] nodeId1 node in the barycenter. + * @param[in] tree tree. + * @param[in] nodeId2 node (and its subtree) of tree to add as a children of + * nodeId1 in the barycenter. + * @param[out] newScalarsVector scalar values of the added nodes. + * @param[out] nodesToProcess vector of tuples containing for each node in + * tree its future parent in the barycenter (and the index of the tree). + * @param[in] nodeCpt number of nodes in barycenter. + * @param[in] i tree index. + */ template ftm::idNode getNodesAndScalarsToAdd( - ftm::MergeTree &ttkNotUsed(mTree1), ftm::idNode nodeId1, - ftm::FTMTree_MT *tree2, + ftm::FTMTree_MT *tree, ftm::idNode nodeId2, std::vector &newScalarsVector, std::vector> &nodesToProcess, @@ -297,16 +325,16 @@ namespace ttk { queue.emplace(nodeId2, nodeId1); nodesToProcess.emplace_back(nodeId2, nodeId1, i); while(!queue.empty()) { - auto queueTuple = queue.front(); + auto &queueTuple = queue.front(); queue.pop(); ftm::idNode const node = std::get<0>(queueTuple); // Get scalars newScalarsVector.push_back( - tree2->getValue(tree2->getNode(node)->getOrigin())); - newScalarsVector.push_back(tree2->getValue(node)); + tree->getValue(tree->getNode(node)->getOrigin())); + newScalarsVector.push_back(tree->getValue(node)); // Process children std::vector children; - tree2->getChildren(node, children); + tree->getChildren(node, children); for(auto child : children) { queue.emplace(child, nodeCpt + 1); nodesToProcess.emplace_back(child, nodeCpt + 1, i); @@ -329,7 +357,7 @@ namespace ttk { // Add nodes nodesProcessed.clear(); nodesProcessed.resize(noTrees); - for(auto processTuple : nodesToProcess) { + for(auto &processTuple : nodesToProcess) { ftm::idNode const parent = std::get<1>(processTuple); ftm::idNode const nodeTree1 = tree1->getNumberOfNodes(); int const index = std::get<2>(processTuple); @@ -385,10 +413,10 @@ namespace ttk { std::vector> matrixMatchings(trees.size()); std::vector baryMatched(baryTree->getNumberOfNodes(), false); for(unsigned int i = 0; i < matchings.size(); ++i) { - auto matching = matchings[i]; + auto &matching = matchings[i]; matrixMatchings[i].resize(trees[i]->getNumberOfNodes(), std::numeric_limits::max()); - for(auto match : matching) { + for(auto &match : matching) { matrixMatchings[i][std::get<1>(match)] = std::get<0>(match); baryMatched[std::get<0>(match)] = true; } @@ -396,6 +424,10 @@ namespace ttk { // Iterate through trees to get the nodes to add in the barycenter std::vector> nodesToAdd(trees.size()); +#ifdef TTK_ENABLE_OPENMP4 +#pragma omp parallel for schedule(dynamic) \ + num_threads(this->threadNumber_) if(parallelize_) +#endif for(unsigned int i = 0; i < trees.size(); ++i) { ftm::idNode const root = trees[i]->getRoot(); std::queue queue; @@ -471,8 +503,7 @@ namespace ttk { parent = baryTree->getRoot();*/ std::vector addedScalars; nodeCpt = getNodesAndScalarsToAdd( - baryMergeTree, parent, trees[i], node, addedScalars, - nodesToProcess, nodeCpt, i); + parent, trees[i], node, addedScalars, nodesToProcess, nodeCpt, i); newScalarsVector.insert( newScalarsVector.end(), addedScalars.begin(), addedScalars.end()); } @@ -486,7 +517,7 @@ namespace ttk { for(unsigned int i = 0; i < matchings.size(); ++i) { std::vector> nodesProcessedT; - for(auto tup : nodesProcessed[i]) + for(auto &tup : nodesProcessed[i]) nodesProcessedT.emplace_back( std::get<0>(tup), std::get<1>(tup), -1); matchings[i].insert(matchings[i].end(), nodesProcessedT.begin(), @@ -521,11 +552,6 @@ namespace ttk { std::vector &trees, std::vector &nodes, std::vector &alphas) { - ftm::FTMTree_MT *baryTree = &(baryMergeTree.tree); - dataType mu_max = getMinMaxLocalFromVector( - baryTree, nodeId, newScalarsVector, false); - dataType mu_min = getMinMaxLocalFromVector( - baryTree, nodeId, newScalarsVector); dataType newBirth = 0, newDeath = 0; // Compute projection @@ -562,6 +588,11 @@ namespace ttk { newDeath += alphas[i] * iDeath; } if(normalizedWasserstein_) { + ftm::FTMTree_MT *baryTree = &(baryMergeTree.tree); + dataType mu_max = getMinMaxLocalFromVector( + baryTree, nodeId, newScalarsVector, false); + dataType mu_min = getMinMaxLocalFromVector( + baryTree, nodeId, newScalarsVector); // Forbid compiler optimization to have same results on different // computers volatile dataType tempBirthT = newBirth * (mu_max - mu_min); @@ -581,12 +612,6 @@ namespace ttk { ftm::MergeTree &baryMergeTree, ftm::idNode nodeB, std::vector &newScalarsVector) { - ftm::FTMTree_MT *baryTree = &(baryMergeTree.tree); - dataType mu_max = getMinMaxLocalFromVector( - baryTree, nodeB, newScalarsVector, false); - dataType mu_min - = getMinMaxLocalFromVector(baryTree, nodeB, newScalarsVector); - auto birthDeath = getParametrizedBirthDeath(tree, nodeId); dataType newBirth = std::get<0>(birthDeath); dataType newDeath = std::get<1>(birthDeath); @@ -596,6 +621,11 @@ namespace ttk { newDeath = alpha * newDeath + (1 - alpha) * projec; if(normalizedWasserstein_) { + ftm::FTMTree_MT *baryTree = &(baryMergeTree.tree); + dataType mu_max = getMinMaxLocalFromVector( + baryTree, nodeB, newScalarsVector, false); + dataType mu_min = getMinMaxLocalFromVector( + baryTree, nodeB, newScalarsVector); // Forbid compiler optimization to have same results on different // computers volatile dataType tempBirthT = newBirth * (mu_max - mu_min); @@ -622,17 +652,20 @@ namespace ttk { // m[i][j] contains the node in trees[j] matched to the node i in the // barycenter std::vector> baryMatching( - baryTree->getNumberOfNodes(), + indexAddedNodes, std::vector( trees.size(), std::numeric_limits::max())); - std::vector nodesAddedTree(baryTree->getNumberOfNodes(), -1); + std::vector> nodesAddedTree( + baryTree->getNumberOfNodes(), std::make_tuple(-1, -1)); for(unsigned int i = 0; i < matchings.size(); ++i) { - auto matching = matchings[i]; - for(auto match : matching) { - baryMatching[std::get<0>(match)][i] = std::get<1>(match); - if(std::get<0>(match) - >= indexAddedNodes) // get the tree of this added node - nodesAddedTree[std::get<0>(match)] = i; + auto &matching = matchings[i]; + for(auto &match : matching) { + if(std::get<0>(match) >= indexAddedNodes) + // get the tree of this added node + nodesAddedTree[std::get<0>(match)] + = std::make_tuple(i, std::get<1>(match)); + else + baryMatching[std::get<0>(match)][i] = std::get<1>(match); } } @@ -650,8 +683,8 @@ namespace ttk { = interpolation(baryMergeTree, node, newScalarsVector, trees, baryMatching[node], alphas); } else { - int const i = nodesAddedTree[node]; - ftm::idNode const nodeT = baryMatching[node][i]; + int const i = std::get<0>(nodesAddedTree[node]); + ftm::idNode const nodeT = std::get<1>(nodesAddedTree[node]); newBirthDeath = interpolationAdded( trees[i], nodeT, alphas[i], baryMergeTree, node, newScalarsVector); } @@ -679,7 +712,6 @@ namespace ttk { } setTreeScalars(baryMergeTree, newScalarsVector); - std::vector deletedNodesT; persistenceThresholding( &(baryMergeTree.tree), 0, deletedNodesT); @@ -941,6 +973,8 @@ namespace ttk { int NoIteration = 0; while(not converged) { ++NoIteration; + if(barycenterMaxIter_ != -1 and NoIteration > barycenterMaxIter_) + break; printMsg(debug::Separator::L2); std::stringstream ss; @@ -1049,11 +1083,12 @@ namespace ttk { // --- Preprocessing if(preprocess_) { treesNodeCorr_.resize(trees.size()); - for(unsigned int i = 0; i < trees.size(); ++i) + for(unsigned int i = 0; i < trees.size(); ++i) { preprocessingPipeline(trees[i], epsilonTree2_, epsilon2Tree2_, epsilon3Tree2_, branchDecomposition_, useMinMaxPair_, cleanTree_, treesNodeCorr_[i]); + } printTreesStats(trees); } @@ -1106,36 +1141,27 @@ namespace ttk { // ------------------------------------------------------------------------ // Preprocessing // ------------------------------------------------------------------------ - template - void limitSizePercent(ftm::MergeTree &bary, - std::vector &trees, - double percent, - bool useBD) { - auto metric = getSizeLimitMetric(trees); - unsigned int const newNoNodes = metric * percent / 100.0; - keepMostImportantPairs(&(bary.tree), newNoNodes, useBD); - - unsigned int const noNodesAfter = bary.tree.getRealNumberOfNodes(); - if(bary.tree.isFullMerge() and noNodesAfter > newNoNodes * 1.1 + 1 - and noNodesAfter > 3) { - std::cout << "metric = " << metric << std::endl; - std::cout << "newNoNodes = " << newNoNodes << std::endl; - std::cout << "noNodesAfter = " << noNodesAfter << std::endl; - } - } - template void limitSizeBarycenter(ftm::MergeTree &bary, std::vector &trees, unsigned int barycenterMaximumNumberOfPairs, double percent, bool useBD = true) { - if(barycenterMaximumNumberOfPairs > 0) - keepMostImportantPairs( - &(bary.tree), barycenterMaximumNumberOfPairs, useBD); - if(percent > 0) - limitSizePercent(bary, trees, percent, useBD); + auto metric = getSizeLimitMetric(trees); + unsigned int percentMaxPairs = metric * percent / 100.0; + + unsigned int newNoNodes; + if(barycenterMaximumNumberOfPairs > 0 and percent > 0) + newNoNodes = std::min(barycenterMaximumNumberOfPairs, percentMaxPairs); + else if(barycenterMaximumNumberOfPairs > 0) + newNoNodes = barycenterMaximumNumberOfPairs; + else if(percent > 0) + newNoNodes = percentMaxPairs; + else + return; + keepMostImportantPairs(&(bary.tree), newNoNodes, useBD); } + template void limitSizeBarycenter(ftm::MergeTree &bary, std::vector &trees, @@ -1144,6 +1170,7 @@ namespace ttk { limitSizeBarycenter( bary, trees, barycenterMaximumNumberOfPairs_, percent, useBD); } + template void limitSizeBarycenter(ftm::MergeTree &bary, std::vector &trees, @@ -1161,7 +1188,7 @@ namespace ttk { return; ftm::FTMTree_MT *tree = &(barycenter.tree); - auto tup = fixMergedRootOrigin(tree); + auto &tup = fixMergedRootOrigin(tree); int maxIndex = std::get<0>(tup); dataType oldOriginValue = std::get<1>(tup); @@ -1225,7 +1252,7 @@ namespace ttk { std::vector( trees.size(), std::numeric_limits::max())); for(unsigned int i = 0; i < finalMatchings.size(); ++i) - for(auto match : finalMatchings[i]) + for(auto &match : finalMatchings[i]) baryMatched[std::get<0>(match)][i] = std::get<1>(match); std::queue queue; diff --git a/core/base/mergeTreeClustering/MergeTreeBase.h b/core/base/mergeTreeClustering/MergeTreeBase.h index 64a1b0b49f..92bf729e37 100644 --- a/core/base/mergeTreeClustering/MergeTreeBase.h +++ b/core/base/mergeTreeClustering/MergeTreeBase.h @@ -151,6 +151,14 @@ namespace ttk { isPersistenceDiagram_ = isPD; } + void setJoinSplitMixtureCoefficient(const double mixtureCoefficient) { + mixtureCoefficient_ = mixtureCoefficient; + } + + void setUseDoubleInput(const bool useDoubleInput) { + useDoubleInput_ = useDoubleInput; + } + std::vector> getTreesNodeCorr() { return treesNodeCorr_; } @@ -202,9 +210,9 @@ namespace ttk { treeNodeMerged.clear(); treeNodeMerged.resize(tree->getNumberOfNodes()); + // need to have the pairing (if merge by persistence) if(mergeByPersistence) - ftm::computePersistencePairs( - tree); // need to have the pairing (if merge by persistence) + ftm::computePersistencePairs(tree); // Compute epsilon value dataType maxValue = tree->getValue(0); @@ -352,15 +360,47 @@ namespace ttk { } } + void deletePersistenceDiagramsPairs(ftm::FTMTree_MT *tree, + std::vector &nodes) { + std::vector arcs; + for(unsigned int i = 0; i < nodes.size(); ++i) { + ftm::idNode node = nodes[i]; + if(!tree->isRoot(node)) + arcs.emplace_back(tree->getNode(node)->getUpSuperArcId(0)); + tree->getNode(node)->clearDownSuperArcs(); + tree->getNode(node)->clearUpSuperArcs(); + ftm::idNode const nodeOrigin = tree->getNode(node)->getOrigin(); + if(tree->isNodeOriginDefined(node) + and tree->getNode(nodeOrigin)->getOrigin() == (int)node) { + if(!tree->isRoot(nodeOrigin)) + arcs.emplace_back(tree->getNode(nodeOrigin)->getUpSuperArcId(0)); + tree->getNode(nodeOrigin)->clearDownSuperArcs(); + tree->getNode(nodeOrigin)->clearUpSuperArcs(); + } + } + tree->getNode(tree->getRoot())->removeDownSuperArcs(arcs); + } + template void keepMostImportantPairs(ftm::FTMTree_MT *tree, int n, bool useBD) { std::vector> pairs; tree->getPersistencePairsFromTree(pairs, useBD); n = std::max(n, 2); // keep at least 2 pairs - int const index = std::max((int)(pairs.size() - n), 0); - dataType threshold = std::get<2>(pairs[index]) * (1.0 - 1e-6) - / tree->getMaximumPersistence() * 100.0; - persistenceThresholding(tree, threshold); + unsigned int const index = std::max((int)(pairs.size() - n), 0); + if(isPersistenceDiagram_) { + std::vector nodes(index); + for(unsigned int i = 0; i < index; ++i) + nodes[i] = std::get<0>(pairs[i]); + deletePersistenceDiagramsPairs(tree, nodes); + } else { + for(unsigned int i = 0; i < index; ++i) { + ftm::idNode node = std::get<0>(pairs[i]); + ftm::idNode nodeOrigin = std::get<1>(pairs[i]); + tree->deleteNode(node); + if(tree->getNode(nodeOrigin)->getOrigin() == (int)node) + tree->deleteNode(nodeOrigin); + } + } } template @@ -376,6 +416,7 @@ namespace ttk { if(threshold >= secondMax) threshold = (1.0 - 1e-6) * secondMax; + std::vector nodes; for(unsigned int i = 0; i < tree->getNumberOfNodes(); ++i) { if(tree->isRoot(i)) continue; @@ -387,16 +428,23 @@ namespace ttk { } if((nodePers == 0 or nodePers <= threshold or not tree->isNodeOriginDefined(i))) { - tree->deleteNode(i); + if(not isPersistenceDiagram_) + tree->deleteNode(i); + else + nodes.emplace_back(i); deletedNodes.push_back(i); ftm::idNode const nodeOrigin = tree->getNode(i)->getOrigin(); if(tree->isNodeOriginDefined(i) and tree->getNode(nodeOrigin)->getOrigin() == (int)i) { - tree->deleteNode(nodeOrigin); + if(not isPersistenceDiagram_) + tree->deleteNode(nodeOrigin); deletedNodes.push_back(nodeOrigin); } } } + + if(isPersistenceDiagram_) + deletePersistenceDiagramsPairs(tree, nodes); } template @@ -448,15 +496,13 @@ namespace ttk { if(deleteInconsistentNodes) { // Manage inconsistent critical points // Critical points with same scalar value than parent - for(unsigned int i = 0; i < tree->getNumberOfNodes(); ++i) - if(!tree->isNodeAlone(i) and !tree->isRoot(i) - and tree->getValue(tree->getParentSafe(i)) - == tree->getValue(i)) { - /*printMsg("[preprocessTree] " + std::to_string(i) - + " has same scalar value than parent (will be - deleted).");*/ - tree->deleteNode(i); - } + if(not isPersistenceDiagram_) + for(unsigned int i = 0; i < tree->getNumberOfNodes(); ++i) + if(!tree->isNodeAlone(i) and !tree->isRoot(i) + and tree->getValue(tree->getParentSafe(i)) + == tree->getValue(i)) { + tree->deleteNode(i); + } // Valence 2 nodes for(unsigned int i = 0; i < tree->getNumberOfNodes(); ++i) if(tree->getNode(i)->getNumberOfUpSuperArcs() == 1 @@ -1116,7 +1162,7 @@ namespace ttk { } template - dataType deleteCost(ftm::FTMTree_MT *tree, ftm::idNode nodeId) { + dataType deleteCost(const ftm::FTMTree_MT *tree, ftm::idNode nodeId) { dataType cost = 0; dataType newMin = 0.0, newMax = 1.0; // Get birth/death @@ -1139,14 +1185,14 @@ namespace ttk { } template - dataType insertCost(ftm::FTMTree_MT *tree, ftm::idNode nodeId) { + dataType insertCost(const ftm::FTMTree_MT *tree, ftm::idNode nodeId) { return deleteCost(tree, nodeId); } template - dataType relabelCostOnly(ftm::FTMTree_MT *tree1, + dataType relabelCostOnly(const ftm::FTMTree_MT *tree1, ftm::idNode nodeId1, - ftm::FTMTree_MT *tree2, + const ftm::FTMTree_MT *tree2, ftm::idNode nodeId2) { dataType cost = 0; dataType newMin = 0.0, newMax = 1.0; @@ -1175,9 +1221,9 @@ namespace ttk { } template - dataType relabelCost(ftm::FTMTree_MT *tree1, + dataType relabelCost(const ftm::FTMTree_MT *tree1, ftm::idNode nodeId1, - ftm::FTMTree_MT *tree2, + const ftm::FTMTree_MT *tree2, ftm::idNode nodeId2) { // Full merge case and only one persistence pair case if(tree1->getNode(nodeId1)->getOrigin() == (int)nodeId1 diff --git a/core/base/mergeTreeClustering/MergeTreeDistance.h b/core/base/mergeTreeClustering/MergeTreeDistance.h index 0ce6bb929f..fbbf82bab3 100644 --- a/core/base/mergeTreeClustering/MergeTreeDistance.h +++ b/core/base/mergeTreeClustering/MergeTreeDistance.h @@ -205,8 +205,8 @@ namespace ttk { template dataType forestAssignmentProblem( - ftm::FTMTree_MT *ttkNotUsed(tree1), - ftm::FTMTree_MT *ttkNotUsed(tree2), + const ftm::FTMTree_MT *ttkNotUsed(tree1), + const ftm::FTMTree_MT *ttkNotUsed(tree2), std::vector> &treeTable, std::vector &children1, std::vector &children2, @@ -232,8 +232,8 @@ namespace ttk { template void computeForestsDistance( - ftm::FTMTree_MT *tree1, - ftm::FTMTree_MT *tree2, + const ftm::FTMTree_MT *tree1, + const ftm::FTMTree_MT *tree2, int i, int j, std::vector> &treeTable, @@ -299,7 +299,7 @@ namespace ttk { // ------------------------------------------------------------------------ template void computeForestToEmptyDistance( - ftm::FTMTree_MT *tree1, + const ftm::FTMTree_MT *tree1, ftm::idNode nodeI, int i, std::vector> &treeTable, @@ -313,7 +313,7 @@ namespace ttk { template void computeSubtreeToEmptyDistance( - ftm::FTMTree_MT *tree1, + const ftm::FTMTree_MT *tree1, ftm::idNode nodeI, int i, std::vector> &treeTable, @@ -323,7 +323,7 @@ namespace ttk { template void computeEmptyToForestDistance( - ftm::FTMTree_MT *tree2, + const ftm::FTMTree_MT *tree2, ftm::idNode nodeJ, int j, std::vector> &treeTable, @@ -337,7 +337,7 @@ namespace ttk { template void computeEmptyToSubtreeDistance( - ftm::FTMTree_MT *tree2, + const ftm::FTMTree_MT *tree2, ftm::idNode nodeJ, int j, std::vector> &treeTable, @@ -374,8 +374,8 @@ namespace ttk { template void computeSubtreesDistance( - ftm::FTMTree_MT *tree1, - ftm::FTMTree_MT *tree2, + const ftm::FTMTree_MT *tree1, + const ftm::FTMTree_MT *tree2, int i, int j, ftm::idNode nodeI, @@ -427,8 +427,8 @@ namespace ttk { // -------------------------------------------------------------------------------- template void computeMatching( - ftm::FTMTree_MT *tree1, - ftm::FTMTree_MT *tree2, + const ftm::FTMTree_MT *tree1, + const ftm::FTMTree_MT *tree2, std::vector>> &treeBackTable, std::vector>>> &forestBackTable, @@ -479,8 +479,8 @@ namespace ttk { // ------------------------------------------------------------------------ template dataType - computeDistance(ftm::FTMTree_MT *tree1, - ftm::FTMTree_MT *tree2, + computeDistance(const ftm::FTMTree_MT *tree1, + const ftm::FTMTree_MT *tree2, std::vector> &outputMatching) { // --------------------- @@ -544,8 +544,8 @@ namespace ttk { template dataType computeDistance( - ftm::FTMTree_MT *tree1, - ftm::FTMTree_MT *tree2, + const ftm::FTMTree_MT *tree1, + const ftm::FTMTree_MT *tree2, std::vector> &outputMatching) { std::vector> realOutputMatching; @@ -651,8 +651,8 @@ namespace ttk { template void computeEditDistance( - ftm::FTMTree_MT *tree1, - ftm::FTMTree_MT *tree2, + const ftm::FTMTree_MT *tree1, + const ftm::FTMTree_MT *tree2, std::vector> &treeTable, std::vector> &forestTable, std::vector>> &treeBackTable, @@ -694,8 +694,8 @@ namespace ttk { template void classicEditDistance( - ftm::FTMTree_MT *tree1, - ftm::FTMTree_MT *tree2, + const ftm::FTMTree_MT *tree1, + const ftm::FTMTree_MT *tree2, bool processTree1, bool computeEmptyTree, ftm::idNode nodeI, @@ -769,8 +769,8 @@ namespace ttk { // ------------------------------------------------------------------------ template void parallelEditDistance( - ftm::FTMTree_MT *tree1, - ftm::FTMTree_MT *tree2, + const ftm::FTMTree_MT *tree1, + const ftm::FTMTree_MT *tree2, std::vector> &treeTable, std::vector> &forestTable, std::vector>> &treeBackTable, @@ -816,8 +816,8 @@ namespace ttk { // Forests and subtrees distances template void parallelTreeDistance_v2( - ftm::FTMTree_MT *tree1, - ftm::FTMTree_MT *tree2, + const ftm::FTMTree_MT *tree1, + const ftm::FTMTree_MT *tree2, bool isTree1, int i, std::vector &tree1Leaves, @@ -831,7 +831,7 @@ namespace ttk { &forestBackTable, bool firstCall = false) { ftm::idNode const nodeT = -1; - ftm::FTMTree_MT *treeT = (isTree1) ? tree1 : tree2; + const ftm::FTMTree_MT *treeT = (isTree1) ? tree1 : tree2; std::vector treeChildDone(treeT->getNumberOfNodes(), 0); std::vector treeNodeDone(treeT->getNumberOfNodes(), false); std::queue treeQueue; @@ -863,8 +863,8 @@ namespace ttk { // (isCalled_=false) template void parallelTreeDistancePara( - ftm::FTMTree_MT *tree1, - ftm::FTMTree_MT *tree2, + const ftm::FTMTree_MT *tree1, + const ftm::FTMTree_MT *tree2, bool isTree1, int i, std::vector &tree1Leaves, @@ -900,8 +900,8 @@ namespace ttk { template void parallelTreeDistanceTask( - ftm::FTMTree_MT *tree1, - ftm::FTMTree_MT *tree2, + const ftm::FTMTree_MT *tree1, + const ftm::FTMTree_MT *tree2, bool isTree1, int i, std::vector &tree1Leaves, @@ -933,7 +933,7 @@ namespace ttk { treeChildDone, treeNodeDone) if(isTree1) { #endif - ftm::FTMTree_MT *treeT = (isTree1) ? tree1 : tree2; + const ftm::FTMTree_MT *treeT = (isTree1) ? tree1 : tree2; // while(nodeT != -1){ while(!taskQueue.empty()) { nodeT = taskQueue.front(); @@ -1006,7 +1006,7 @@ namespace ttk { // Subtree/Forest with empty tree distances template void parallelEmptyTreeDistance_v2( - ftm::FTMTree_MT *tree, + const ftm::FTMTree_MT *tree, bool isTree1, std::vector &treeLeaves, std::vector &treeNodeChildSize, @@ -1035,7 +1035,7 @@ namespace ttk { template void parallelEmptyTreeDistancePara( - ftm::FTMTree_MT *tree, + const ftm::FTMTree_MT *tree, bool isTree1, std::vector &treeLeaves, std::vector &treeNodeChildSize, @@ -1064,7 +1064,7 @@ namespace ttk { template void parallelEmptyTreeDistanceTask( - ftm::FTMTree_MT *tree, + const ftm::FTMTree_MT *tree, bool isTree1, std::vector &ttkNotUsed(treeLeaves), std::vector &treeNodeChildSize, diff --git a/core/base/mergeTreeClustering/MergeTreeUtils.h b/core/base/mergeTreeClustering/MergeTreeUtils.h index 56d97b19e2..45c71ffde7 100644 --- a/core/base/mergeTreeClustering/MergeTreeUtils.h +++ b/core/base/mergeTreeClustering/MergeTreeUtils.h @@ -27,7 +27,7 @@ namespace ttk { // Normalized Wasserstein // -------------------- template - dataType getMinMaxLocal(ftm::FTMTree_MT *tree, + dataType getMinMaxLocal(const ftm::FTMTree_MT *tree, ftm::idNode nodeId, bool getMin = true) { auto nodeIdParent = tree->getParentSafe(nodeId); @@ -61,11 +61,11 @@ namespace ttk { } template - std::tuple getNormalizedBirthDeath(ftm::FTMTree_MT *tree, - ftm::idNode nodeId, - dataType newMin = 0.0, - dataType newMax - = 1.0) { + std::tuple + getNormalizedBirthDeath(const ftm::FTMTree_MT *tree, + ftm::idNode nodeId, + dataType newMin = 0.0, + dataType newMax = 1.0) { auto birthDeath = tree->getBirthDeath(nodeId); dataType birth = std::get<0>(birthDeath); dataType death = std::get<1>(birthDeath); diff --git a/core/base/mergeTreeNeuralNetwork/CMakeLists.txt b/core/base/mergeTreeNeuralNetwork/CMakeLists.txt new file mode 100644 index 0000000000..a14a79d022 --- /dev/null +++ b/core/base/mergeTreeNeuralNetwork/CMakeLists.txt @@ -0,0 +1,22 @@ +ttk_add_base_library(mergeTreeNeuralNetwork + SOURCES + MergeTreeNeuralBase.cpp + MergeTreeNeuralLayer.cpp + MergeTreeNeuralNetwork.cpp + MergeTreeTorchUtils.cpp + HEADERS + MergeTreeNeuralBase.h + MergeTreeNeuralLayer.h + MergeTreeNeuralNetwork.h + MergeTreeTorchUtils.h + DEPENDS + mergeTreePrincipalGeodesics + geometry +) + +if(TTK_ENABLE_TORCH) + target_include_directories(mergeTreeNeuralNetwork PUBLIC ${TORCH_INCLUDE_DIRS}) + target_compile_options(mergeTreeNeuralNetwork PUBLIC "${TORCH_CXX_FLAGS}") + target_link_libraries(mergeTreeNeuralNetwork PUBLIC "${TORCH_LIBRARIES}") + target_compile_definitions(mergeTreeNeuralNetwork PUBLIC TTK_ENABLE_TORCH) +endif() diff --git a/core/base/mergeTreeNeuralNetwork/MergeTreeNeuralBase.cpp b/core/base/mergeTreeNeuralNetwork/MergeTreeNeuralBase.cpp new file mode 100644 index 0000000000..2cd3f1606d --- /dev/null +++ b/core/base/mergeTreeNeuralNetwork/MergeTreeNeuralBase.cpp @@ -0,0 +1,319 @@ +#include +#include + +#ifdef TTK_ENABLE_TORCH +using namespace torch::indexing; +#endif + +ttk::MergeTreeNeuralBase::MergeTreeNeuralBase() { + // inherited from Debug: prefix will be printed at the beginning of every msg + this->setDebugMsgPrefix("MergeTreeNeuralBase"); +} + +#ifdef TTK_ENABLE_TORCH +// ----------------------------------------------------------------------- +// --- Setter +// ----------------------------------------------------------------------- +void ttk::MergeTreeNeuralBase::setDropout(const double dropout) { + dropout_ = dropout; +} + +void ttk::MergeTreeNeuralBase::setEuclideanVectorsInit( + const bool euclideanVectorsInit) { + euclideanVectorsInit_ = euclideanVectorsInit; +} + +void ttk::MergeTreeNeuralBase::setRandomAxesInit(const bool randomAxesInit) { + randomAxesInit_ = randomAxesInit; +} + +void ttk::MergeTreeNeuralBase::setInitBarycenterRandom( + const bool initBarycenterRandom) { + initBarycenterRandom_ = initBarycenterRandom; +} + +void ttk::MergeTreeNeuralBase::setInitBarycenterOneIter( + const bool initBarycenterOneIter) { + initBarycenterOneIter_ = initBarycenterOneIter; +} + +void ttk::MergeTreeNeuralBase::setInitOriginPrimeStructByCopy( + const bool initOriginPrimeStructByCopy) { + initOriginPrimeStructByCopy_ = initOriginPrimeStructByCopy; +} + +void ttk::MergeTreeNeuralBase::setInitOriginPrimeValuesByCopy( + const bool initOriginPrimeValuesByCopy) { + initOriginPrimeValuesByCopy_ = initOriginPrimeValuesByCopy; +} + +void ttk::MergeTreeNeuralBase::setInitOriginPrimeValuesByCopyRandomness( + const double initOriginPrimeValuesByCopyRandomness) { + initOriginPrimeValuesByCopyRandomness_ + = initOriginPrimeValuesByCopyRandomness; +} + +void ttk::MergeTreeNeuralBase::setActivate(const bool activate) { + activate_ = activate; +} + +void ttk::MergeTreeNeuralBase::setActivationFunction( + const unsigned int activationFunction) { + activationFunction_ = activationFunction; +} + +void ttk::MergeTreeNeuralBase::setUseGpu(const bool useGpu) { + useGpu_ = useGpu; +} + +void ttk::MergeTreeNeuralBase::setBigValuesThreshold( + const float bigValuesThreshold) { + bigValuesThreshold_ = bigValuesThreshold; +} + +// ----------------------------------------------------------------------- +// --- Utils +// ----------------------------------------------------------------------- +torch::Tensor ttk::MergeTreeNeuralBase::activation(torch::Tensor &in) { + torch::Tensor act; + switch(activationFunction_) { + case 1: + act = torch::nn::LeakyReLU()(in); + break; + case 0: + default: + act = torch::nn::ReLU()(in); + } + return act; +} + +void ttk::MergeTreeNeuralBase::fixTreePrecisionScalars( + ftm::MergeTree &mTree) { + double eps = 1e-6; + auto shiftSubtree + = [&mTree, &eps](ftm::idNode node, ftm::idNode birthNodeParent, + ftm::idNode deathNodeParent, std::vector &scalars, + bool invalidBirth, bool invalidDeath) { + std::queue queue; + queue.emplace(node); + while(!queue.empty()) { + ftm::idNode nodeT = queue.front(); + queue.pop(); + auto birthDeathNode = mTree.tree.getBirthDeathNode(node); + auto birthNode = std::get<0>(birthDeathNode); + auto deathNode = std::get<1>(birthDeathNode); + if(invalidBirth) + scalars[birthNode] = scalars[birthNodeParent] + 2 * eps; + if(invalidDeath) + scalars[deathNode] = scalars[deathNodeParent] - 2 * eps; + std::vector children; + mTree.tree.getChildren(nodeT, children); + for(auto &child : children) + queue.emplace(child); + } + }; + std::vector scalars; + getTreeScalars(mTree, scalars); + std::queue queue; + auto root = mTree.tree.getRoot(); + queue.emplace(root); + while(!queue.empty()) { + ftm::idNode node = queue.front(); + queue.pop(); + auto birthDeathNode = mTree.tree.getBirthDeathNode(node); + auto birthNode = std::get<0>(birthDeathNode); + auto deathNode = std::get<1>(birthDeathNode); + auto birthDeathNodeParent + = mTree.tree.getBirthDeathNode(mTree.tree.getParentSafe(node)); + auto birthNodeParent = std::get<0>(birthDeathNodeParent); + auto deathNodeParent = std::get<1>(birthDeathNodeParent); + bool invalidBirth = (scalars[birthNode] <= scalars[birthNodeParent] + eps); + bool invalidDeath = (scalars[deathNode] >= scalars[deathNodeParent] - eps); + if(!mTree.tree.isRoot(node) and (invalidBirth or invalidDeath)) + shiftSubtree(node, birthNodeParent, deathNodeParent, scalars, + invalidBirth, invalidDeath); + std::vector children; + mTree.tree.getChildren(node, children); + for(auto &child : children) + queue.emplace(child); + } + ftm::setTreeScalars(mTree, scalars); +} + +void ttk::MergeTreeNeuralBase::printPairs(const ftm::MergeTree &mTree, + bool useBD) { + std::stringstream ss; + if(mTree.tree.getRealNumberOfNodes() != 0) + ss = mTree.tree.template printPairsFromTree(useBD); + else { + std::vector nodeDone(mTree.tree.getNumberOfNodes(), false); + for(unsigned int i = 0; i < mTree.tree.getNumberOfNodes(); ++i) { + if(nodeDone[i]) + continue; + std::tuple pair + = std::make_tuple(i, mTree.tree.getNode(i)->getOrigin(), + mTree.tree.getNodePersistence(i)); + ss << std::get<0>(pair) << " (" + << mTree.tree.getValue(std::get<0>(pair)) << ") _ "; + ss << std::get<1>(pair) << " (" + << mTree.tree.getValue(std::get<1>(pair)) << ") _ "; + ss << std::get<2>(pair) << std::endl; + nodeDone[i] = true; + nodeDone[mTree.tree.getNode(i)->getOrigin()] = true; + } + } + ss << std::endl; + std::cout << ss.str(); +} + +// ----------------------------------------------------------------------- +// --- Distance +// ----------------------------------------------------------------------- +void ttk::MergeTreeNeuralBase::getDistanceMatrix( + const std::vector> &tmts, + std::vector> &distanceMatrix, + bool useDoubleInput, + bool isFirstInput) { + distanceMatrix.clear(); + distanceMatrix.resize(tmts.size(), std::vector(tmts.size(), 0)); +#ifdef TTK_ENABLE_OPENMP +#pragma omp parallel num_threads(this->threadNumber_) if(parallelize_) \ + shared(distanceMatrix, tmts) + { +#pragma omp single nowait + { +#endif + for(unsigned int i = 0; i < tmts.size(); ++i) { + for(unsigned int j = i + 1; j < tmts.size(); ++j) { +#ifdef TTK_ENABLE_OPENMP +#pragma omp task UNTIED() shared(distanceMatrix, tmts) firstprivate(i, j) + { +#endif + std::vector> matching; + float distance; + bool isCalled = true; + computeOneDistance(tmts[i].mTree, tmts[j].mTree, matching, distance, + isCalled, useDoubleInput, isFirstInput); + distance = distance * distance; + distanceMatrix[i][j] = distance; + distanceMatrix[j][i] = distance; +#ifdef TTK_ENABLE_OPENMP + } // pragma omp task +#endif + } + } +#ifdef TTK_ENABLE_OPENMP +#pragma omp taskwait + } // pragma omp single nowait + } // pragma omp parallel +#endif +} + +void ttk::MergeTreeNeuralBase::getDistanceMatrix( + const std::vector> &tmts, + const std::vector> &tmts2, + std::vector> &distanceMatrix) { + getDistanceMatrix(tmts, distanceMatrix, useDoubleInput_); + if(useDoubleInput_) { + std::vector> distanceMatrix2; + getDistanceMatrix(tmts2, distanceMatrix2, useDoubleInput_, false); + mixDistancesMatrix(distanceMatrix, distanceMatrix2); + } +} + +void ttk::MergeTreeNeuralBase::getDifferentiableDistanceFromMatchings( + const mtu::TorchMergeTree &tree1, + const mtu::TorchMergeTree &tree2, + const mtu::TorchMergeTree &tree1_2, + const mtu::TorchMergeTree &tree2_2, + std::vector> &matchings, + std::vector> &matchings2, + torch::Tensor &tensorDist, + bool doSqrt) { + torch::Tensor reorderedITensor, reorderedJTensor; + dataReorderingGivenMatching( + tree1, tree2, matchings, reorderedITensor, reorderedJTensor); + if(useDoubleInput_) { + torch::Tensor reorderedI2Tensor, reorderedJ2Tensor; + dataReorderingGivenMatching( + tree1_2, tree2_2, matchings2, reorderedI2Tensor, reorderedJ2Tensor); + reorderedITensor = torch::cat({reorderedITensor, reorderedI2Tensor}); + reorderedJTensor = torch::cat({reorderedJTensor, reorderedJ2Tensor}); + } + tensorDist = (reorderedITensor - reorderedJTensor).pow(2).sum(); + if(doSqrt) + tensorDist = tensorDist.sqrt(); +} + +void ttk::MergeTreeNeuralBase::getDifferentiableDistance( + const mtu::TorchMergeTree &tree1, + const mtu::TorchMergeTree &tree2, + const mtu::TorchMergeTree &tree1_2, + const mtu::TorchMergeTree &tree2_2, + torch::Tensor &tensorDist, + bool isCalled, + bool doSqrt) { + std::vector> matchings, + matchings2; + float distance; + computeOneDistance( + tree1.mTree, tree2.mTree, matchings, distance, isCalled, useDoubleInput_); + if(useDoubleInput_) { + float distance2; + computeOneDistance(tree1_2.mTree, tree2_2.mTree, matchings2, + distance2, isCalled, useDoubleInput_, false); + } + getDifferentiableDistanceFromMatchings( + tree1, tree2, tree1_2, tree2_2, matchings, matchings2, tensorDist, doSqrt); +} + +void ttk::MergeTreeNeuralBase::getDifferentiableDistance( + const mtu::TorchMergeTree &tree1, + const mtu::TorchMergeTree &tree2, + torch::Tensor &tensorDist, + bool isCalled, + bool doSqrt) { + mtu::TorchMergeTree tree1_2, tree2_2; + getDifferentiableDistance( + tree1, tree2, tree1_2, tree2_2, tensorDist, isCalled, doSqrt); +} + +void ttk::MergeTreeNeuralBase::getDifferentiableDistanceMatrix( + const std::vector *> &trees, + const std::vector *> &trees2, + std::vector> &outDistMat) { + outDistMat.resize(trees.size(), std::vector(trees.size())); +#ifdef TTK_ENABLE_OPENMP +#pragma omp parallel num_threads(this->threadNumber_) if(parallelize_) \ + shared(trees, trees2, outDistMat) + { +#pragma omp single nowait + { +#endif + for(unsigned int i = 0; i < trees.size(); ++i) { + outDistMat[i][i] = torch::tensor(0); + for(unsigned int j = i + 1; j < trees.size(); ++j) { +#ifdef TTK_ENABLE_OPENMP +#pragma omp task UNTIED() shared(trees, trees2, outDistMat) firstprivate(i, j) + { +#endif + bool isCalled = true; + bool doSqrt = false; + torch::Tensor tensorDist; + getDifferentiableDistance(*(trees[i]), *(trees[j]), *(trees2[i]), + *(trees2[j]), tensorDist, isCalled, + doSqrt); + outDistMat[i][j] = tensorDist; + outDistMat[j][i] = tensorDist; +#ifdef TTK_ENABLE_OPENMP + } // pragma omp task +#endif + } + } +#ifdef TTK_ENABLE_OPENMP +#pragma omp taskwait + } // pragma omp single nowait + } // pragma omp parallel +#endif +} +#endif diff --git a/core/base/mergeTreeNeuralNetwork/MergeTreeNeuralBase.h b/core/base/mergeTreeNeuralNetwork/MergeTreeNeuralBase.h new file mode 100644 index 0000000000..f426b65e0c --- /dev/null +++ b/core/base/mergeTreeNeuralNetwork/MergeTreeNeuralBase.h @@ -0,0 +1,175 @@ +/// \ingroup base +/// \class ttk::MergeTreeNeuralBase +/// \author Mathieu Pont +/// \date 2024. +/// +/// This module defines the %MergeTreeNeuralNetwork abstract class providing +/// functions to define a neural network processing merge trees or persistence +/// diagrams from end to end. +/// +/// \b Related \b publication: \n +/// "Wasserstein Auto-Encoders of Merge Trees (and Persistence Diagrams)" \n +/// Mathieu Pont, Julien Tierny.\n +/// IEEE Transactions on Visualization and Computer Graphics, 2023 +/// + +#pragma once + +// ttk common includes +#include +#include +#include +#include + +#ifdef TTK_ENABLE_TORCH +#include +#endif + +namespace ttk { + + /** + * This module defines the %MergeTreeNeuralNetwork abstract class providing + * functions to define a neural network processing merge trees or persistence + * diagrams from end to end. + */ + class MergeTreeNeuralBase : virtual public Debug, + public MergeTreeAxesAlgorithmBase { + + protected: + // ======== Model hyper-parameters + // Dropout to use when training. + double dropout_ = 0.0; + // If the vectors should be initialized using euclidean distance (between + // vectors representing the topological abstractions ordered given the + // assignment to the barycenter), faster but less accurate than using + // Wasserstein distance. + bool euclideanVectorsInit_ = false; + // If the vectors should be initialized randomly. + bool randomAxesInit_ = false; + // When computing the origin of the input basis, if the barycenter algorihm + // should be initialized randomly (instead to the topological representation + // minimizing the distance to the set), faster but less accurate. + bool initBarycenterRandom_ = false; + // When computing the origin of the input basis, if the barycenter algorithm + // should run for only one iteration, faster but less accurate. + bool initBarycenterOneIter_ = false; + // If the structure of the origin of the output basis should be initialized + // by copying the structure of the input basis. + bool initOriginPrimeStructByCopy_ = true; + // If the scalar values of the origin of the output basis should be + // initialized by copying the values of the input basis. + bool initOriginPrimeValuesByCopy_ = true; + // Value between 0 and 1 allowing to add some randomness to the values of + // the origin of the output basis when initOriginPrimeValuesByCopy_ is set + // to true. + double initOriginPrimeValuesByCopyRandomness_ = 0.0; + // If activation functions should be used. + bool activate_ = true; + // Choice of the activation function + // 0 : ReLU + // 1 : Leaky ReLU + unsigned int activationFunction_ = 1; + + bool useGpu_ = false; + + // ======== Testing + float bigValuesThreshold_ = 0; + + public: + MergeTreeNeuralBase(); + +#ifdef TTK_ENABLE_TORCH + // ----------------------------------------------------------------------- + // --- Setter + // ----------------------------------------------------------------------- + void setDropout(const double dropout); + + void setEuclideanVectorsInit(const bool euclideanVectorsInit); + + void setRandomAxesInit(const bool randomAxesInit); + + void setInitBarycenterRandom(const bool initBarycenterRandom); + + void setInitBarycenterOneIter(const bool initBarycenterOneIter); + + void setInitOriginPrimeStructByCopy(const bool initOriginPrimeStructByCopy); + + void setInitOriginPrimeValuesByCopy(const bool initOriginPrimeValuesByCopy); + + void setInitOriginPrimeValuesByCopyRandomness( + const double initOriginPrimeValuesByCopyRandomness); + + void setActivate(const bool activate); + + void setActivationFunction(const unsigned int activationFunction); + + void setUseGpu(const bool useGpu); + + void setBigValuesThreshold(const float bigValuesThreshold); + + // ----------------------------------------------------------------------- + // --- Utils + // ----------------------------------------------------------------------- + torch::Tensor activation(torch::Tensor &in); + + /** + * @brief Fix the scalars of a merge tree to ensure that the nesting + * condition is respected. + * + * @param[in] mTree Merge tree to process. + */ + void fixTreePrecisionScalars(ftm::MergeTree &mTree); + + /** + * @brief Util function to print pairs of a merge tree. + * + * @param[in] mTree merge tree to process. + * @param[in] useBD if the merge tree is in branch decomposition mode or + * not. + */ + void printPairs(const ftm::MergeTree &mTree, bool useBD = true); + + // ----------------------------------------------------------------------- + // --- Distance + // ----------------------------------------------------------------------- + void getDistanceMatrix(const std::vector> &tmts, + std::vector> &distanceMatrix, + bool useDoubleInput = false, + bool isFirstInput = true); + + void getDistanceMatrix(const std::vector> &tmts, + const std::vector> &tmts2, + std::vector> &distanceMatrix); + + void getDifferentiableDistanceFromMatchings( + const mtu::TorchMergeTree &tree1, + const mtu::TorchMergeTree &tree2, + const mtu::TorchMergeTree &tree1_2, + const mtu::TorchMergeTree &tree2_2, + std::vector> &matchings, + std::vector> &matchings2, + torch::Tensor &tensorDist, + bool doSqrt); + + void getDifferentiableDistance(const mtu::TorchMergeTree &tree1, + const mtu::TorchMergeTree &tree2, + const mtu::TorchMergeTree &tree1_2, + const mtu::TorchMergeTree &tree2_2, + torch::Tensor &tensorDist, + bool isCalled, + bool doSqrt); + + void getDifferentiableDistance(const mtu::TorchMergeTree &tree1, + const mtu::TorchMergeTree &tree2, + torch::Tensor &tensorDist, + bool isCalled, + bool doSqrt); + + void getDifferentiableDistanceMatrix( + const std::vector *> &trees, + const std::vector *> &trees2, + std::vector> &outDistMat); +#endif + }; // MergeTreeNeuralBase class + +} // namespace ttk diff --git a/core/base/mergeTreeNeuralNetwork/MergeTreeNeuralLayer.cpp b/core/base/mergeTreeNeuralNetwork/MergeTreeNeuralLayer.cpp new file mode 100644 index 0000000000..017d88758c --- /dev/null +++ b/core/base/mergeTreeNeuralNetwork/MergeTreeNeuralLayer.cpp @@ -0,0 +1,1317 @@ +#include +#include + +#ifdef TTK_ENABLE_TORCH +using namespace torch::indexing; +#endif + +ttk::MergeTreeNeuralLayer::MergeTreeNeuralLayer() { + // inherited from Debug: prefix will be printed at the beginning of every msg + this->setDebugMsgPrefix("MergeTreeNeuralLayer"); +} + +#ifdef TTK_ENABLE_TORCH +// ----------------------------------------------------------------------- +// --- Getter/Setter +// ----------------------------------------------------------------------- +const ttk::mtu::TorchMergeTree & + ttk::MergeTreeNeuralLayer::getOrigin() const { + return origin_; +} + +const ttk::mtu::TorchMergeTree & + ttk::MergeTreeNeuralLayer::getOriginPrime() const { + return originPrime_; +} + +const ttk::mtu::TorchMergeTree & + ttk::MergeTreeNeuralLayer::getOrigin2() const { + return origin2_; +} + +const ttk::mtu::TorchMergeTree & + ttk::MergeTreeNeuralLayer::getOrigin2Prime() const { + return origin2Prime_; +} + +const torch::Tensor &ttk::MergeTreeNeuralLayer::getVSTensor() const { + return vSTensor_; +} + +const torch::Tensor &ttk::MergeTreeNeuralLayer::getVSPrimeTensor() const { + return vSPrimeTensor_; +} + +const torch::Tensor &ttk::MergeTreeNeuralLayer::getVS2Tensor() const { + return vS2Tensor_; +} + +const torch::Tensor &ttk::MergeTreeNeuralLayer::getVS2PrimeTensor() const { + return vS2PrimeTensor_; +} + +void ttk::MergeTreeNeuralLayer::setOrigin( + const mtu::TorchMergeTree &tmt) { + mtu::copyTorchMergeTree(tmt, origin_); +} + +void ttk::MergeTreeNeuralLayer::setOriginPrime( + const mtu::TorchMergeTree &tmt) { + mtu::copyTorchMergeTree(tmt, originPrime_); +} + +void ttk::MergeTreeNeuralLayer::setOrigin2( + const mtu::TorchMergeTree &tmt) { + mtu::copyTorchMergeTree(tmt, origin2_); +} + +void ttk::MergeTreeNeuralLayer::setOrigin2Prime( + const mtu::TorchMergeTree &tmt) { + mtu::copyTorchMergeTree(tmt, origin2Prime_); +} + +void ttk::MergeTreeNeuralLayer::setVSTensor(const torch::Tensor &vS) { + mtu::copyTensor(vS, vSTensor_); +} + +void ttk::MergeTreeNeuralLayer::setVSPrimeTensor(const torch::Tensor &vS) { + mtu::copyTensor(vS, vSPrimeTensor_); +} + +void ttk::MergeTreeNeuralLayer::setVS2Tensor(const torch::Tensor &vS) { + mtu::copyTensor(vS, vS2Tensor_); +} + +void ttk::MergeTreeNeuralLayer::setVS2PrimeTensor(const torch::Tensor &vS) { + mtu::copyTensor(vS, vS2PrimeTensor_); +} + +// --------------------------------------------------------------------------- +// --- Init +// --------------------------------------------------------------------------- +void ttk::MergeTreeNeuralLayer::initOutputBasisTreeStructure( + mtu::TorchMergeTree &originPrime, + bool isJT, + mtu::TorchMergeTree &baseOrigin) { + // ----- Create scalars vector + torch::Tensor originTensor = originPrime.tensor; + if(!originTensor.device().is_cpu()) + originTensor = originTensor.cpu(); + std::vector scalarsVector( + originTensor.data_ptr(), + originTensor.data_ptr() + originTensor.numel()); + unsigned int noNodes = scalarsVector.size() / 2; + std::vector> childrenFinal(noNodes); + + // ----- Init tree structure and modify scalars if necessary + if(isPersistenceDiagram_) { + for(unsigned int i = 2; i < scalarsVector.size(); i += 2) + childrenFinal[0].emplace_back(i / 2); + } else { + // --- Fix or swap min-max pair + float maxPers = std::numeric_limits::lowest(); + unsigned int indMax = 0; + for(unsigned int i = 0; i < scalarsVector.size(); i += 2) { + if(maxPers < (scalarsVector[i + 1] - scalarsVector[i])) { + maxPers = (scalarsVector[i + 1] - scalarsVector[i]); + indMax = i; + } + } + if(indMax != 0) { + float temp = scalarsVector[0]; + scalarsVector[0] = scalarsVector[indMax]; + scalarsVector[indMax] = temp; + temp = scalarsVector[1]; + scalarsVector[1] = scalarsVector[indMax + 1]; + scalarsVector[indMax + 1] = temp; + } + ftm::idNode refNode = 0; + for(unsigned int i = 2; i < scalarsVector.size(); i += 2) { + ftm::idNode node = i / 2; + adjustNestingScalars(scalarsVector, node, refNode); + } + + if(not initOriginPrimeStructByCopy_ + or (int) noNodes > baseOrigin.mTree.tree.getRealNumberOfNodes()) { + // --- Get possible children and parent relations + std::vector> parents(noNodes), children(noNodes); + for(unsigned int i = 0; i < scalarsVector.size(); i += 2) { + for(unsigned int j = i; j < scalarsVector.size(); j += 2) { + if(i == j) + continue; + unsigned int iN = i / 2, jN = j / 2; + if(scalarsVector[i] <= scalarsVector[j] + and scalarsVector[i + 1] >= scalarsVector[j + 1]) { + // - i is parent of j + parents[jN].emplace_back(iN); + children[iN].emplace_back(jN); + } else if(scalarsVector[i] >= scalarsVector[j] + and scalarsVector[i + 1] <= scalarsVector[j + 1]) { + // - j is parent of i + parents[iN].emplace_back(jN); + children[jN].emplace_back(iN); + } + } + } + createBalancedBDT(parents, children, scalarsVector, childrenFinal); + } else { + ftm::MergeTree mTreeTemp + = ftm::copyMergeTree(baseOrigin.mTree); + bool useBD = true; + keepMostImportantPairs(&(mTreeTemp.tree), noNodes, useBD); + torch::Tensor reshaped = torch::tensor(scalarsVector).reshape({-1, 2}); + torch::Tensor order = torch::argsort( + (reshaped.index({Slice(), 1}) - reshaped.index({Slice(), 0})), -1, + true); + std::vector nodeCorr(mTreeTemp.tree.getNumberOfNodes(), 0); + unsigned int nodeNum = 1; + std::queue queue; + queue.emplace(mTreeTemp.tree.getRoot()); + while(!queue.empty()) { + ftm::idNode node = queue.front(); + queue.pop(); + std::vector children; + mTreeTemp.tree.getChildren(node, children); + for(auto &child : children) { + queue.emplace(child); + unsigned int tNode = nodeCorr[node]; + nodeCorr[child] = order[nodeNum].item(); + ++nodeNum; + unsigned int tChild = nodeCorr[child]; + childrenFinal[tNode].emplace_back(tChild); + adjustNestingScalars(scalarsVector, tChild, tNode); + } + } + } + } + + // ----- Create new tree + originPrime.mTree = ftm::createEmptyMergeTree(scalarsVector.size()); + ftm::FTMTree_MT *tree = &(originPrime.mTree.tree); + if(isJT) { + for(unsigned int i = 0; i < scalarsVector.size(); i += 2) { + float temp = scalarsVector[i]; + scalarsVector[i] = scalarsVector[i + 1]; + scalarsVector[i + 1] = temp; + } + } + ftm::setTreeScalars(originPrime.mTree, scalarsVector); + + // ----- Create tree structure + originPrime.nodeCorr.clear(); + originPrime.nodeCorr.assign( + scalarsVector.size(), std::numeric_limits::max()); + for(unsigned int i = 0; i < scalarsVector.size(); i += 2) { + tree->makeNode(i); + tree->makeNode(i + 1); + tree->getNode(i)->setOrigin(i + 1); + tree->getNode(i + 1)->setOrigin(i); + originPrime.nodeCorr[i] = (unsigned int)(i / 2); + } + for(unsigned int i = 0; i < scalarsVector.size(); i += 2) { + unsigned int node = i / 2; + for(auto &child : childrenFinal[node]) + tree->makeSuperArc(child * 2, i); + } + mtu::getParentsVector(originPrime.mTree, originPrime.parentsOri); + + if(isTreeHasBigValues(originPrime.mTree, bigValuesThreshold_)) { + std::stringstream ss; + ss << originPrime.mTree.tree.printPairsFromTree(true).str() + << std::endl; + ss << "isTreeHasBigValues(originPrime.mTree)" << std::endl; + ss << "pause" << std::endl; + printMsg(ss.str()); + std::cin.get(); + } +} + +void ttk::MergeTreeNeuralLayer::initOutputBasis( + const unsigned int dim, + const unsigned int dim2, + const torch::Tensor &baseTensor) { + unsigned int originSize = origin_.tensor.sizes()[0]; + unsigned int origin2Size = 0; + if(useDoubleInput_) + origin2Size = origin2_.tensor.sizes()[0]; + + // --- Compute output basis origin + printMsg("Compute output basis origin", debug::Priority::DETAIL); + auto initOutputBasisOrigin = [this, &baseTensor]( + torch::Tensor &w, + mtu::TorchMergeTree &tmt, + mtu::TorchMergeTree &baseTmt) { + // - Create scalars + torch::nn::init::xavier_normal_(w); + torch::Tensor baseTmtTensor = baseTmt.tensor; + if(normalizedWasserstein_) + // Work on unnormalized tensor + mtu::mergeTreeToTorchTensor(baseTmt.mTree, baseTmtTensor, false); + torch::Tensor b + = torch::full({w.sizes()[0], 1}, 0.01, + torch::TensorOptions().device(baseTmtTensor.device())); + tmt.tensor = (torch::matmul(w, baseTmtTensor) + b); + // - Shift to keep mean birth and max pers + mtu::meanBirthMaxPersShift(tmt.tensor, baseTmtTensor); + // - Shift to avoid diagonal points + mtu::belowDiagonalPointsShift(tmt.tensor, baseTmtTensor); + // + if(initOriginPrimeValuesByCopy_) { + auto baseTensorDiag = baseTensor.reshape({-1, 2}); + auto basePersDiag = (baseTensorDiag.index({Slice(), 1}) + - baseTensorDiag.index({Slice(), 0})); + auto tmtTensorDiag = tmt.tensor.reshape({-1, 2}); + auto persDiag = (tmtTensorDiag.index({Slice(1, None), 1}) + - tmtTensorDiag.index({Slice(1, None), 0})); + int noK = std::min(baseTensorDiag.sizes()[0], tmtTensorDiag.sizes()[0]); + auto topVal = baseTensorDiag.index({std::get<1>(basePersDiag.topk(noK))}); + auto indexes = std::get<1>(persDiag.topk(noK - 1)) + 1; + auto zeros + = torch::zeros(1, torch::TensorOptions().device(indexes.device())); + indexes = torch::cat({zeros, indexes}).to(torch::kLong); + if(initOriginPrimeValuesByCopyRandomness_ != 0) { + topVal = (1 - initOriginPrimeValuesByCopyRandomness_) * topVal + + initOriginPrimeValuesByCopyRandomness_ + * tmtTensorDiag.index({indexes}); + } + tmtTensorDiag.index_put_({indexes}, topVal); + } + // - Create tree structure + initOutputBasisTreeStructure( + tmt, baseTmt.mTree.tree.isJoinTree(), baseTmt); + if(normalizedWasserstein_) + // Normalize tensor + mtu::mergeTreeToTorchTensor(tmt.mTree, tmt.tensor, true); + // - Projection + interpolationProjection(tmt); + }; + torch::Tensor w = torch::zeros( + {dim, originSize}, torch::TensorOptions().device(origin_.tensor.device())); + initOutputBasisOrigin(w, originPrime_, origin_); + torch::Tensor w2; + if(useDoubleInput_) { + w2 = torch::zeros({dim2, origin2Size}, + torch::TensorOptions().device(origin2_.tensor.device())); + initOutputBasisOrigin(w2, origin2Prime_, origin2_); + } + + // --- Compute output basis vectors + printMsg("Compute output basis vectors", debug::Priority::DETAIL); + initOutputBasisVectors(w, w2); +} + +void ttk::MergeTreeNeuralLayer::initOutputBasisVectors(torch::Tensor &w, + torch::Tensor &w2) { + vSPrimeTensor_ = torch::matmul(w, vSTensor_); + if(useDoubleInput_) + vS2PrimeTensor_ = torch::matmul(w2, vS2Tensor_); + if(normalizedWasserstein_) { + mtu::normalizeVectors(originPrime_.tensor, vSPrimeTensor_); + if(useDoubleInput_) + mtu::normalizeVectors(origin2Prime_.tensor, vS2PrimeTensor_); + } +} + +void ttk::MergeTreeNeuralLayer::initOutputBasisVectors(unsigned int dim, + unsigned int dim2) { + unsigned int originSize = origin_.tensor.sizes()[0]; + unsigned int origin2Size = 0; + if(useDoubleInput_) + origin2Size = origin2_.tensor.sizes()[0]; + torch::Tensor w = torch::zeros({dim, originSize}); + torch::nn::init::xavier_normal_(w); + torch::Tensor w2 = torch::zeros({dim2, origin2Size}); + torch::nn::init::xavier_normal_(w2); + initOutputBasisVectors(w, w2); +} + +void ttk::MergeTreeNeuralLayer::initInputBasisOrigin( + std::vector> &treesToUse, + std::vector> &trees2ToUse, + double barycenterSizeLimitPercent, + unsigned int barycenterMaxNoPairs, + unsigned int barycenterMaxNoPairs2, + std::vector &inputToBaryDistances, + std::vector>> + &baryMatchings, + std::vector>> + &baryMatchings2) { + int barycenterInitIndex = -1; + if(initBarycenterRandom_) { + std::random_device rd; + std::default_random_engine rng(deterministic_ ? 0 : rd()); + barycenterInitIndex + = std::uniform_int_distribution<>(0, treesToUse.size() - 1)(rng); + } + int maxNoPairs = (initBarycenterRandom_ ? barycenterMaxNoPairs : 0); + computeOneBarycenter(treesToUse, origin_.mTree, baryMatchings, + inputToBaryDistances, barycenterSizeLimitPercent, + maxNoPairs, barycenterInitIndex, + initBarycenterOneIter_, useDoubleInput_, true); + if(not initBarycenterRandom_ and barycenterMaxNoPairs > 0) + keepMostImportantPairs( + &(origin_.mTree.tree), barycenterMaxNoPairs, true); + if(useDoubleInput_) { + std::vector baryDistances2; + int maxNoPairs2 = (initBarycenterRandom_ ? barycenterMaxNoPairs2 : 0); + computeOneBarycenter(trees2ToUse, origin2_.mTree, baryMatchings2, + baryDistances2, barycenterSizeLimitPercent, + maxNoPairs2, barycenterInitIndex, + initBarycenterOneIter_, useDoubleInput_, false); + if(not initBarycenterRandom_ and barycenterMaxNoPairs2 > 0) + keepMostImportantPairs( + &(origin2_.mTree.tree), barycenterMaxNoPairs2, true); + for(unsigned int i = 0; i < inputToBaryDistances.size(); ++i) + inputToBaryDistances[i] + = mixDistances(inputToBaryDistances[i], baryDistances2[i]); + } + + mtu::getParentsVector(origin_.mTree, origin_.parentsOri); + mtu::mergeTreeToTorchTensor( + origin_.mTree, origin_.tensor, origin_.nodeCorr, normalizedWasserstein_); + if(useGpu_) + origin_.tensor = origin_.tensor.cuda(); + if(useDoubleInput_) { + mtu::getParentsVector(origin2_.mTree, origin2_.parentsOri); + mtu::mergeTreeToTorchTensor(origin2_.mTree, origin2_.tensor, + origin2_.nodeCorr, + normalizedWasserstein_); + if(useGpu_) + origin2_.tensor = origin2_.tensor.cuda(); + } +} + +void ttk::MergeTreeNeuralLayer::initInputBasisVectors( + std::vector> &tmTrees, + std::vector> &tmTrees2, + std::vector> &trees, + std::vector> &trees2, + unsigned int noVectors, + std::vector &allAlphasInit, + std::vector &inputToBaryDistances, + std::vector>> + &baryMatchings, + std::vector>> + &baryMatchings2, + mtu::TorchMergeTree &origin, + mtu::TorchMergeTree &origin2, + torch::Tensor &vSTensor, + torch::Tensor &vS2Tensor, + bool useInputBasis) { + if(randomAxesInit_) { + auto initRandomAxes = [&noVectors](mtu::TorchMergeTree &originT, + torch::Tensor &axes) { + torch::Tensor w = torch::zeros({noVectors, originT.tensor.sizes()[0]}); + torch::nn::init::xavier_normal_(w); + axes = torch::linalg::pinv(w); + }; + initRandomAxes(origin, vSTensor); + if(useGpu_) + vSTensor = vSTensor.cuda(); + if(useDoubleInput_) { + initRandomAxes(origin2, vS2Tensor); + if(useGpu_) + vS2Tensor = vS2Tensor.cuda(); + } +#ifdef TTK_ENABLE_OPENMP +#pragma omp parallel for schedule(dynamic) \ + num_threads(this->threadNumber_) if(parallelize_) +#endif + for(unsigned int i = 0; i < trees.size(); ++i) + allAlphasInit[i] = torch::randn({noVectors, 1}); + return; + } + + // --- Initialized vectors projection function to avoid collinearity + auto initializedVectorsProjection + = [=](int ttkNotUsed(_axeNumber), + ftm::MergeTree &ttkNotUsed(_barycenter), + std::vector> &_v, + std::vector> &ttkNotUsed(_v2), + std::vector>> &_vS, + std::vector>> &ttkNotUsed(_v2s), + ftm::MergeTree &ttkNotUsed(_barycenter2), + std::vector> &ttkNotUsed(_trees2V), + std::vector> &ttkNotUsed(_trees2V2), + std::vector>> &ttkNotUsed(_trees2Vs), + std::vector>> &ttkNotUsed(_trees2V2s), + bool ttkNotUsed(_useSecondInput), + unsigned int ttkNotUsed(_noProjectionStep)) { + std::vector scaledV, scaledVSi; + Geometry::flattenMultiDimensionalVector(_v, scaledV); + Geometry::scaleVector( + scaledV, 1.0 / Geometry::magnitude(scaledV), scaledV); + for(unsigned int i = 0; i < _vS.size(); ++i) { + Geometry::flattenMultiDimensionalVector(_vS[i], scaledVSi); + Geometry::scaleVector( + scaledVSi, 1.0 / Geometry::magnitude(scaledVSi), scaledVSi); + auto prod = Geometry::dotProduct(scaledV, scaledVSi); + double tol = 0.01; + if(prod <= -1.0 + tol or prod >= 1.0 - tol) { + // Reset vector to initialize it again + for(unsigned int j = 0; j < _v.size(); ++j) + for(unsigned int k = 0; k < _v[j].size(); ++k) + _v[j][k] = 0; + break; + } + } + return 0; + }; + + // --- Init vectors + std::vector> inputToAxesDistances; + std::vector>> vS, v2s, trees2Vs, trees2V2s; + std::stringstream ss; + for(unsigned int vecNum = 0; vecNum < noVectors; ++vecNum) { + ss.str(""); + ss << "Compute vectors " << vecNum; + printMsg(ss.str(), debug::Priority::VERBOSE); + std::vector> v1, v2, trees2V1, trees2V2; + int newVectorOffset = 0; + bool projectInitializedVectors = true; + int bestIndex = MergeTreeAxesAlgorithmBase::initVectors( + vecNum, origin.mTree, trees, origin2.mTree, trees2, v1, v2, trees2V1, + trees2V2, newVectorOffset, inputToBaryDistances, baryMatchings, + baryMatchings2, inputToAxesDistances, vS, v2s, trees2Vs, trees2V2s, + projectInitializedVectors, initializedVectorsProjection); + vS.emplace_back(v1); + v2s.emplace_back(v2); + trees2Vs.emplace_back(trees2V1); + trees2V2s.emplace_back(trees2V2); + + ss.str(""); + ss << "bestIndex = " << bestIndex; + printMsg(ss.str(), debug::Priority::VERBOSE); + + // Update inputToAxesDistances + printMsg("Update inputToAxesDistances", debug::Priority::VERBOSE); + inputToAxesDistances.resize(1, std::vector(trees.size())); + if(bestIndex == -1 and normalizedWasserstein_) { + mtu::normalizeVectors(origin, vS[vS.size() - 1]); + if(useDoubleInput_) + mtu::normalizeVectors(origin2, trees2Vs[vS.size() - 1]); + } + mtu::axisVectorsToTorchTensor(origin.mTree, vS, vSTensor); + if(useGpu_) + vSTensor = vSTensor.cuda(); + if(useDoubleInput_) { + mtu::axisVectorsToTorchTensor(origin2.mTree, trees2Vs, vS2Tensor); + if(useGpu_) + vS2Tensor = vS2Tensor.cuda(); + } + mtu::TorchMergeTree dummyTmt; + std::vector> + dummyBaryMatching2; +#ifdef TTK_ENABLE_OPENMP +#pragma omp parallel for schedule(dynamic) \ + num_threads(this->threadNumber_) if(parallelize_) +#endif + for(unsigned int i = 0; i < trees.size(); ++i) { + auto &tmt2ToUse = (not useDoubleInput_ ? dummyTmt : tmTrees2[i]); + if(not euclideanVectorsInit_) { + unsigned int k = k_; + auto newAlpha = torch::ones({1, 1}); + if(bestIndex == -1) { + newAlpha = torch::zeros({1, 1}); + } + allAlphasInit[i] = (allAlphasInit[i].defined() + ? torch::cat({allAlphasInit[i], newAlpha}) + : newAlpha); + torch::Tensor bestAlphas; + bool isCalled = true; + inputToAxesDistances[0][i] + = assignmentOneData(tmTrees[i], tmt2ToUse, k, allAlphasInit[i], + bestAlphas, isCalled, useInputBasis); + allAlphasInit[i] = bestAlphas.detach(); + } else { + auto &baryMatching2ToUse + = (not useDoubleInput_ ? dummyBaryMatching2 : baryMatchings2[i]); + torch::Tensor alphas; + computeAlphas(tmTrees[i], origin, vSTensor, origin, baryMatchings[i], + tmt2ToUse, origin2, vS2Tensor, origin2, + baryMatching2ToUse, alphas); + mtu::TorchMergeTree interpolated, interpolated2; + getMultiInterpolation(origin, vSTensor, alphas, interpolated); + if(useDoubleInput_) + getMultiInterpolation(origin2, vS2Tensor, alphas, interpolated2); + torch::Tensor tensorDist; + bool doSqrt = true; + getDifferentiableDistanceFromMatchings( + interpolated, tmTrees[i], interpolated2, tmt2ToUse, baryMatchings[i], + baryMatching2ToUse, tensorDist, doSqrt); + inputToAxesDistances[0][i] = tensorDist.item(); + allAlphasInit[i] = alphas.detach(); + } + } + } +} + +void ttk::MergeTreeNeuralLayer::initInputBasisVectors( + std::vector> &tmTrees, + std::vector> &tmTrees2, + std::vector> &trees, + std::vector> &trees2, + unsigned int noVectors, + std::vector &allAlphasInit, + std::vector &inputToBaryDistances, + std::vector>> + &baryMatchings, + std::vector>> + &baryMatchings2, + bool useInputBasis) { + mtu::TorchMergeTree &origin = (useInputBasis ? origin_ : originPrime_); + mtu::TorchMergeTree &origin2 + = (useInputBasis ? origin2_ : origin2Prime_); + torch::Tensor &vSTensor = (useInputBasis ? vSTensor_ : vSPrimeTensor_); + torch::Tensor &vS2Tensor = (useInputBasis ? vS2Tensor_ : vS2PrimeTensor_); + + initInputBasisVectors(tmTrees, tmTrees2, trees, trees2, noVectors, + allAlphasInit, inputToBaryDistances, baryMatchings, + baryMatchings2, origin, origin2, vSTensor, vS2Tensor, + useInputBasis); +} + +void ttk::MergeTreeNeuralLayer::requires_grad(const bool requireGrad) { + origin_.tensor.requires_grad_(requireGrad); + originPrime_.tensor.requires_grad_(requireGrad); + vSTensor_.requires_grad_(requireGrad); + vSPrimeTensor_.requires_grad_(requireGrad); + if(useDoubleInput_) { + origin2_.tensor.requires_grad_(requireGrad); + origin2Prime_.tensor.requires_grad_(requireGrad); + vS2Tensor_.requires_grad_(requireGrad); + vS2PrimeTensor_.requires_grad_(requireGrad); + } +} + +void ttk::MergeTreeNeuralLayer::cuda() { + origin_.tensor = origin_.tensor.cuda(); + originPrime_.tensor = originPrime_.tensor.cuda(); + vSTensor_ = vSTensor_.cuda(); + vSPrimeTensor_ = vSPrimeTensor_.cuda(); + if(useDoubleInput_) { + origin2_.tensor = origin2_.tensor.cuda(); + origin2Prime_.tensor = origin2Prime_.tensor.cuda(); + vS2Tensor_ = vS2Tensor_.cuda(); + vS2PrimeTensor_ = vS2PrimeTensor_.cuda(); + } +} + +// --------------------------------------------------------------------------- +// --- Interpolation +// --------------------------------------------------------------------------- +void ttk::MergeTreeNeuralLayer::interpolationDiagonalProjection( + mtu::TorchMergeTree &interpolation) { + torch::Tensor diagTensor = interpolation.tensor.reshape({-1, 2}); + if(interpolation.tensor.requires_grad()) + diagTensor = diagTensor.detach(); + + torch::Tensor birthTensor = diagTensor.index({Slice(), 0}); + torch::Tensor deathTensor = diagTensor.index({Slice(), 1}); + + torch::Tensor indexer = (birthTensor > deathTensor); + + torch::Tensor allProj = (birthTensor + deathTensor) / 2.0; + allProj = allProj.index({indexer}); + allProj = allProj.reshape({-1, 1}); + + diagTensor.index_put_({indexer}, allProj); +} + +void ttk::MergeTreeNeuralLayer::interpolationNestingProjection( + mtu::TorchMergeTree &interpolation) { + torch::Tensor diagTensor = interpolation.tensor.reshape({-1, 2}); + if(interpolation.tensor.requires_grad()) + diagTensor = diagTensor.detach(); + + torch::Tensor birthTensor = diagTensor.index({Slice(1, None), 0}); + torch::Tensor deathTensor = diagTensor.index({Slice(1, None), 1}); + + torch::Tensor birthIndexer = (birthTensor < 0); + torch::Tensor deathIndexer = (deathTensor < 0); + birthTensor.index_put_( + {birthIndexer}, torch::zeros_like(birthTensor.index({birthIndexer}))); + deathTensor.index_put_( + {deathIndexer}, torch::zeros_like(deathTensor.index({deathIndexer}))); + + birthIndexer = (birthTensor > 1); + deathIndexer = (deathTensor > 1); + birthTensor.index_put_( + {birthIndexer}, torch::ones_like(birthTensor.index({birthIndexer}))); + deathTensor.index_put_( + {deathIndexer}, torch::ones_like(deathTensor.index({deathIndexer}))); +} + +void ttk::MergeTreeNeuralLayer::interpolationProjection( + mtu::TorchMergeTree &interpolation) { + interpolationDiagonalProjection(interpolation); + if(normalizedWasserstein_) + interpolationNestingProjection(interpolation); + + ftm::MergeTree interpolationNew; + bool noRoot = mtu::torchTensorToMergeTree( + interpolation, normalizedWasserstein_, interpolationNew); + if(noRoot) + printWrn("[interpolationProjection] no root found"); + interpolation.mTree = copyMergeTree(interpolationNew); + + persistenceThresholding(&(interpolation.mTree.tree), 0.001); + + if(isPersistenceDiagram_ and isThereMissingPairs(interpolation)) + printWrn("[getMultiInterpolation] missing pairs"); +} + +void ttk::MergeTreeNeuralLayer::getMultiInterpolation( + const mtu::TorchMergeTree &origin, + const torch::Tensor &vS, + torch::Tensor &alphas, + mtu::TorchMergeTree &interpolation) { + mtu::copyTorchMergeTree(origin, interpolation); + interpolation.tensor = origin.tensor + torch::matmul(vS, alphas); + interpolationProjection(interpolation); +} + +// --------------------------------------------------------------------------- +// --- Forward +// --------------------------------------------------------------------------- +void ttk::MergeTreeNeuralLayer::getAlphasOptimizationTensors( + mtu::TorchMergeTree &tree, + mtu::TorchMergeTree &origin, + torch::Tensor &vSTensor, + mtu::TorchMergeTree &interpolated, + std::vector> &matching, + torch::Tensor &reorderedTreeTensor, + torch::Tensor &deltaOrigin, + torch::Tensor &deltaA, + torch::Tensor &originTensor_f, + torch::Tensor &vSTensor_f) { + // Create matching indexing + std::vector tensorMatching; + mtu::getTensorMatching(interpolated, tree, matching, tensorMatching); + + torch::Tensor indexes = torch::tensor(tensorMatching); + torch::Tensor projIndexer = (indexes == -1).reshape({-1, 1}); + + dataReorderingGivenMatching( + origin, tree, projIndexer, indexes, reorderedTreeTensor, deltaOrigin); + + // Create axes projection given matching + deltaA = vSTensor.transpose(0, 1).reshape({vSTensor.sizes()[1], -1, 2}); + deltaA = (deltaA.index({Slice(), Slice(), 0}) + + deltaA.index({Slice(), Slice(), 1})) + / 2.0; + deltaA = torch::stack({deltaA, deltaA}, 2); + if(!deltaA.device().is_cpu()) + projIndexer = projIndexer.to(deltaA.device()); + deltaA = deltaA * projIndexer; + deltaA = deltaA.reshape({vSTensor.sizes()[1], -1}).transpose(0, 1); + + // + originTensor_f = origin.tensor; + vSTensor_f = vSTensor; +} + +void ttk::MergeTreeNeuralLayer::computeAlphas( + mtu::TorchMergeTree &tree, + mtu::TorchMergeTree &origin, + torch::Tensor &vSTensor, + mtu::TorchMergeTree &interpolated, + std::vector> &matching, + mtu::TorchMergeTree &tree2, + mtu::TorchMergeTree &origin2, + torch::Tensor &vS2Tensor, + mtu::TorchMergeTree &interpolated2, + std::vector> &matching2, + torch::Tensor &alphasOut) { + torch::Tensor reorderedTreeTensor, deltaOrigin, deltaA, originTensor_f, + vSTensor_f; + getAlphasOptimizationTensors(tree, origin, vSTensor, interpolated, matching, + reorderedTreeTensor, deltaOrigin, deltaA, + originTensor_f, vSTensor_f); + + if(useDoubleInput_) { + torch::Tensor reorderedTree2Tensor, deltaOrigin2, deltaA2, origin2Tensor_f, + vS2Tensor_f; + getAlphasOptimizationTensors(tree2, origin2, vS2Tensor, interpolated2, + matching2, reorderedTree2Tensor, deltaOrigin2, + deltaA2, origin2Tensor_f, vS2Tensor_f); + vSTensor_f = torch::cat({vSTensor_f, vS2Tensor_f}); + deltaA = torch::cat({deltaA, deltaA2}); + reorderedTreeTensor + = torch::cat({reorderedTreeTensor, reorderedTree2Tensor}); + originTensor_f = torch::cat({originTensor_f, origin2Tensor_f}); + deltaOrigin = torch::cat({deltaOrigin, deltaOrigin2}); + } + + torch::Tensor r_axes = vSTensor_f - deltaA; + torch::Tensor r_data = reorderedTreeTensor - originTensor_f + deltaOrigin; + + // Pseudo inverse + auto driver = "gelsd"; + bool is_cpu = r_axes.device().is_cpu(); + auto device = r_axes.device(); + if(!is_cpu) { + r_axes = r_axes.cpu(); + r_data = r_data.cpu(); + } + alphasOut + = std::get<0>(torch::linalg::lstsq(r_axes, r_data, c10::nullopt, driver)); + if(!is_cpu) + alphasOut = alphasOut.to(device); + + alphasOut.reshape({-1, 1}); +} + +float ttk::MergeTreeNeuralLayer::assignmentOneData( + mtu::TorchMergeTree &tree, + mtu::TorchMergeTree &tree2, + unsigned int k, + torch::Tensor &alphasInit, + std::vector> &bestMatching, + std::vector> &bestMatching2, + torch::Tensor &bestAlphas, + bool isCalled, + bool useInputBasis) { + mtu::TorchMergeTree &origin = (useInputBasis ? origin_ : originPrime_); + mtu::TorchMergeTree &origin2 + = (useInputBasis ? origin2_ : origin2Prime_); + torch::Tensor &vSTensor = (useInputBasis ? vSTensor_ : vSPrimeTensor_); + torch::Tensor &vS2Tensor = (useInputBasis ? vS2Tensor_ : vS2PrimeTensor_); + + torch::Tensor alphas, oldAlphas; + std::vector> matching, matching2; + float bestDistance = std::numeric_limits::max(); + mtu::TorchMergeTree interpolated, interpolated2; + unsigned int i = 0; + auto reset = [&]() { + alphasInit = torch::randn_like(alphas); + i = 0; + }; + unsigned int noUpdate = 0; + unsigned int noReset = 0; + while(i < k) { + if(i == 0) { + if(alphasInit.defined()) + alphas = alphasInit; + else + alphas = torch::zeros({vSTensor.sizes()[1], 1}); + } else { + computeAlphas(tree, origin, vSTensor, interpolated, matching, tree2, + origin2, vS2Tensor, interpolated2, matching2, alphas); + if(oldAlphas.defined() and alphas.defined() and alphas.equal(oldAlphas) + and i != 1) { + break; + } + } + mtu::copyTensor(alphas, oldAlphas); + getMultiInterpolation(origin, vSTensor, alphas, interpolated); + if(useDoubleInput_) + getMultiInterpolation(origin2, vS2Tensor, alphas, interpolated2); + if(interpolated.mTree.tree.getRealNumberOfNodes() == 0 + or (useDoubleInput_ + and interpolated2.mTree.tree.getRealNumberOfNodes() == 0)) { + ++noReset; + if(noReset >= 100) + printWrn("[assignmentOneData] noReset >= 100"); + reset(); + continue; + } + float distance; + computeOneDistance(interpolated.mTree, tree.mTree, matching, + distance, isCalled, useDoubleInput_); + if(useDoubleInput_) { + float distance2; + computeOneDistance(interpolated2.mTree, tree2.mTree, matching2, + distance2, isCalled, useDoubleInput_, false); + distance = mixDistances(distance, distance2); + } + if(distance < bestDistance and i != 0) { + bestDistance = distance; + bestMatching = matching; + bestMatching2 = matching2; + bestAlphas = alphas; + noUpdate += 1; + } + i += 1; + } + if(noUpdate == 0) + printErr("[assignmentOneData] noUpdate == 0"); + return bestDistance; +} + +float ttk::MergeTreeNeuralLayer::assignmentOneData( + mtu::TorchMergeTree &tree, + mtu::TorchMergeTree &tree2, + unsigned int k, + torch::Tensor &alphasInit, + torch::Tensor &bestAlphas, + bool isCalled, + bool useInputBasis) { + std::vector> bestMatching, + bestMatching2; + return assignmentOneData(tree, tree2, k, alphasInit, bestMatching, + bestMatching2, bestAlphas, isCalled, useInputBasis); +} + +void ttk::MergeTreeNeuralLayer::outputBasisReconstruction( + torch::Tensor &alphas, + mtu::TorchMergeTree &out, + mtu::TorchMergeTree &out2, + bool activate, + bool train) { + if(not activate_) + activate = false; + torch::Tensor act = (activate ? activation(alphas) : alphas); + if(dropout_ != 0.0 and train) { + torch::nn::Dropout model(torch::nn::DropoutOptions().p(dropout_)); + act = model(act); + } + getMultiInterpolation(originPrime_, vSPrimeTensor_, act, out); + if(useDoubleInput_) + getMultiInterpolation(origin2Prime_, vS2PrimeTensor_, act, out2); +} + +bool ttk::MergeTreeNeuralLayer::forward(mtu::TorchMergeTree &tree, + mtu::TorchMergeTree &tree2, + unsigned int k, + torch::Tensor &alphasInit, + mtu::TorchMergeTree &out, + mtu::TorchMergeTree &out2, + torch::Tensor &bestAlphas, + float &bestDistance, + bool train) { + bool goodOutput = false; + int noReset = 0; + while(not goodOutput) { + bool isCalled = true; + bestDistance + = assignmentOneData(tree, tree2, k, alphasInit, bestAlphas, isCalled); + outputBasisReconstruction(bestAlphas, out, out2, true, train); + goodOutput = (out.mTree.tree.getRealNumberOfNodes() != 0 + and (not useDoubleInput_ + or out2.mTree.tree.getRealNumberOfNodes() != 0)); + if(not goodOutput) { + ++noReset; + if(noReset >= 100) { + printWrn("[forwardOneLayer] noReset >= 100"); + return true; + } + alphasInit = torch::randn_like(alphasInit); + } + } + return false; +} + +bool ttk::MergeTreeNeuralLayer::forward(mtu::TorchMergeTree &tree, + mtu::TorchMergeTree &tree2, + unsigned int k, + torch::Tensor &alphasInit, + mtu::TorchMergeTree &out, + mtu::TorchMergeTree &out2, + torch::Tensor &bestAlphas, + bool train) { + float bestDistance; + return forward( + tree, tree2, k, alphasInit, out, out2, bestAlphas, bestDistance, train); +} + +// --------------------------------------------------------------------------- +// --- Projection +// --------------------------------------------------------------------------- +void ttk::MergeTreeNeuralLayer::projectionStep() { + auto projectTree = [this](mtu::TorchMergeTree &tmt) { + interpolationProjection(tmt); + tmt.tensor = tmt.tensor.detach(); + tmt.tensor.requires_grad_(true); + }; + projectTree(origin_); + projectTree(originPrime_); + if(useDoubleInput_) { + projectTree(origin2_); + projectTree(origin2Prime_); + } +} + +// --------------------------------------------------------------------------- +// --- Utils +// --------------------------------------------------------------------------- +void ttk::MergeTreeNeuralLayer::copyParams( + mtu::TorchMergeTree &origin, + mtu::TorchMergeTree &originPrime, + torch::Tensor &vS, + torch::Tensor &vSPrime, + mtu::TorchMergeTree &origin2, + mtu::TorchMergeTree &origin2Prime, + torch::Tensor &vS2, + torch::Tensor &vS2Prime, + bool get) { + + // Source + mtu::TorchMergeTree &srcOrigin = (get ? origin_ : origin); + mtu::TorchMergeTree &srcOriginPrime + = (get ? originPrime_ : originPrime); + torch::Tensor &srcVS = (get ? vSTensor_ : vS); + torch::Tensor &srcVSPrime = (get ? vSPrimeTensor_ : vSPrime); + mtu::TorchMergeTree &srcOrigin2 = (get ? origin2_ : origin2); + mtu::TorchMergeTree &srcOrigin2Prime + = (get ? origin2Prime_ : origin2Prime); + torch::Tensor &srcVS2 = (get ? vS2Tensor_ : vS2); + torch::Tensor &srcVS2Prime = (get ? vS2PrimeTensor_ : vS2Prime); + + // Destination + mtu::TorchMergeTree &dstOrigin = (!get ? origin_ : origin); + mtu::TorchMergeTree &dstOriginPrime + = (!get ? originPrime_ : originPrime); + torch::Tensor &dstVS = (!get ? vSTensor_ : vS); + torch::Tensor &dstVSPrime = (!get ? vSPrimeTensor_ : vSPrime); + mtu::TorchMergeTree &dstOrigin2 = (!get ? origin2_ : origin2); + mtu::TorchMergeTree &dstOrigin2Prime + = (!get ? origin2Prime_ : origin2Prime); + torch::Tensor &dstVS2 = (!get ? vS2Tensor_ : vS2); + torch::Tensor &dstVS2Prime = (!get ? vS2PrimeTensor_ : vS2Prime); + + // Copy + mtu::copyTorchMergeTree(srcOrigin, dstOrigin); + mtu::copyTorchMergeTree(srcOriginPrime, dstOriginPrime); + mtu::copyTensor(srcVS, dstVS); + mtu::copyTensor(srcVSPrime, dstVSPrime); + if(useDoubleInput_) { + mtu::copyTorchMergeTree(srcOrigin2, dstOrigin2); + mtu::copyTorchMergeTree(srcOrigin2Prime, dstOrigin2Prime); + mtu::copyTensor(srcVS2, dstVS2); + mtu::copyTensor(srcVS2Prime, dstVS2Prime); + } +} + +void ttk::MergeTreeNeuralLayer::adjustNestingScalars( + std::vector &scalarsVector, ftm::idNode node, ftm::idNode refNode) { + float birth = scalarsVector[refNode * 2]; + float death = scalarsVector[refNode * 2 + 1]; + auto getSign = [](float v) { return (v > 0 ? 1 : -1); }; + auto getPrecValue = [&getSign](float v, bool opp = false) { + return v * (1 + (opp ? -1 : 1) * getSign(v) * 1e-6); + }; + // Shift scalars + if(scalarsVector[node * 2 + 1] > getPrecValue(death, true)) { + float diff = scalarsVector[node * 2 + 1] - getPrecValue(death, true); + scalarsVector[node * 2] -= diff; + scalarsVector[node * 2 + 1] -= diff; + } else if(scalarsVector[node * 2] < getPrecValue(birth)) { + float diff = getPrecValue(birth) - scalarsVector[node * 2]; + scalarsVector[node * 2] += getPrecValue(diff); + scalarsVector[node * 2 + 1] += getPrecValue(diff); + } + // Cut scalars + if(scalarsVector[node * 2] < getPrecValue(birth)) + scalarsVector[node * 2] = getPrecValue(birth); + if(scalarsVector[node * 2 + 1] > getPrecValue(death, true)) + scalarsVector[node * 2 + 1] = getPrecValue(death, true); +} + +void ttk::MergeTreeNeuralLayer::createBalancedBDT( + std::vector> &parents, + std::vector> &children, + std::vector &scalarsVector, + std::vector> &childrenFinal) { + // ----- Some variables + unsigned int noNodes = scalarsVector.size() / 2; + childrenFinal.resize(noNodes); + int mtLevel = ceil(log(noNodes * 2) / log(2)) + 1; + int bdtLevel = mtLevel - 1; + int noDim = bdtLevel; + + // ----- Get node levels + std::vector nodeLevels(noNodes, -1); + std::queue queueLevels; + std::vector noChildDone(noNodes, 0); + for(unsigned int i = 0; i < children.size(); ++i) { + if(children[i].size() == 0) { + queueLevels.emplace(i); + nodeLevels[i] = 1; + } + } + while(!queueLevels.empty()) { + ftm::idNode node = queueLevels.front(); + queueLevels.pop(); + for(auto &parent : parents[node]) { + ++noChildDone[parent]; + nodeLevels[parent] = std::max(nodeLevels[parent], nodeLevels[node] + 1); + if(noChildDone[parent] >= (int)children[parent].size()) + queueLevels.emplace(parent); + } + } + + // ----- Sort heuristic lambda + auto sortChildren = [this, &parents, &scalarsVector, &noNodes]( + ftm::idNode nodeOrigin, std::vector &nodeDone, + std::vector> &childrenT) { + double refPers = scalarsVector[1] - scalarsVector[0]; + auto getRemaining = [&nodeDone](std::vector &vec) { + unsigned int remaining = 0; + for(auto &e : vec) + remaining += (not nodeDone[e]); + return remaining; + }; + std::vector parentsRemaining(noNodes, 0), + childrenRemaining(noNodes, 0); + for(auto &child : childrenT[nodeOrigin]) { + parentsRemaining[child] = getRemaining(parents[child]); + childrenRemaining[child] = getRemaining(childrenT[child]); + } + TTK_PSORT( + threadNumber_, childrenT[nodeOrigin].begin(), childrenT[nodeOrigin].end(), + [&](ftm::idNode nodeI, ftm::idNode nodeJ) { + double persI = scalarsVector[nodeI * 2 + 1] - scalarsVector[nodeI * 2]; + double persJ = scalarsVector[nodeJ * 2 + 1] - scalarsVector[nodeJ * 2]; + return parentsRemaining[nodeI] + childrenRemaining[nodeI] + - persI / refPers * noNodes + < parentsRemaining[nodeJ] + childrenRemaining[nodeJ] + - persJ / refPers * noNodes; + }); + }; + + // ----- Greedy approach to find balanced BDT structures + const auto findStructGivenDim = + [&children, &noNodes, &nodeLevels]( + ftm::idNode _nodeOrigin, int _dimToFound, bool _searchMaxDim, + std::vector &_nodeDone, std::vector &_dimFound, + std::vector> &_childrenFinalOut) { + // --- Recursive lambda + auto findStructGivenDimImpl = + [&children, &noNodes, &nodeLevels]( + ftm::idNode nodeOrigin, int dimToFound, bool searchMaxDim, + std::vector &nodeDone, std::vector &dimFound, + std::vector> &childrenFinalOut, + auto &findStructGivenDimRef) mutable { + childrenFinalOut.resize(noNodes); + // - Find structures + int dim = (searchMaxDim ? dimToFound - 1 : 0); + unsigned int i = 0; + // + auto searchMaxDimReset = [&i, &dim, &nodeDone]() { + --dim; + i = 0; + unsigned int noDone = 0; + for(auto done : nodeDone) + if(done) + ++noDone; + return noDone == nodeDone.size() - 1; // -1 for root + }; + while(i < children[nodeOrigin].size()) { + auto child = children[nodeOrigin][i]; + // Skip if child was already processed + if(nodeDone[child]) { + // If we have processed all children while searching for max + // dim then restart at the beginning to find a lower dim + if(searchMaxDim and i == children[nodeOrigin].size() - 1) { + if(searchMaxDimReset()) + break; + } else + ++i; + continue; + } + if(dim == 0) { + // Base case + childrenFinalOut[nodeOrigin].emplace_back(child); + nodeDone[child] = true; + dimFound[0] = true; + if(dimToFound <= 1 or searchMaxDim) + return true; + ++dim; + } else { + // General case + std::vector> childrenFinalDim; + std::vector nodeDoneDim; + std::vector dimFoundDim(dim); + bool found = false; + if(nodeLevels[child] > dim) { + nodeDoneDim = nodeDone; + found = findStructGivenDimRef(child, dim, false, nodeDoneDim, + dimFoundDim, childrenFinalDim, + findStructGivenDimRef); + } + if(found) { + dimFound[dim] = true; + childrenFinalOut[nodeOrigin].emplace_back(child); + for(unsigned int j = 0; j < childrenFinalDim.size(); ++j) + for(auto &e : childrenFinalDim[j]) + childrenFinalOut[j].emplace_back(e); + nodeDone[child] = true; + for(unsigned int j = 0; j < nodeDoneDim.size(); ++j) + nodeDone[j] = nodeDone[j] || nodeDoneDim[j]; + // Return if it is the last dim to found + if(dim == dimToFound - 1 and not searchMaxDim) + return true; + // Reset index if we search for the maximum dim + if(searchMaxDim) { + if(searchMaxDimReset()) + break; + } else { + ++dim; + } + continue; + } else if(searchMaxDim and i == children[nodeOrigin].size() - 1) { + // If we have processed all children while searching for max dim + // then restart at the beginning to find a lower dim + if(searchMaxDimReset()) + break; + continue; + } + } + ++i; + } + return false; + }; + return findStructGivenDimImpl(_nodeOrigin, _dimToFound, _searchMaxDim, + _nodeDone, _dimFound, _childrenFinalOut, + findStructGivenDimImpl); + }; + std::vector dimFound(noDim - 1, false); + std::vector nodeDone(noNodes, false); + for(unsigned int i = 0; i < children.size(); ++i) + sortChildren(i, nodeDone, children); + Timer t_find; + ftm::idNode startNode = 0; + findStructGivenDim(startNode, noDim, true, nodeDone, dimFound, childrenFinal); + + // ----- Greedy approach to create non found structures + const auto createStructGivenDim = + [this, &children, &noNodes, &findStructGivenDim, &nodeLevels]( + int _nodeOrigin, int _dimToCreate, std::vector &_nodeDone, + ftm::idNode &_structOrigin, std::vector &_scalarsVectorOut, + std::vector> &_childrenFinalOut) { + // --- Recursive lambda + auto createStructGivenDimImpl = + [this, &children, &noNodes, &findStructGivenDim, &nodeLevels]( + int nodeOrigin, int dimToCreate, std::vector &nodeDoneImpl, + ftm::idNode &structOrigin, std::vector &scalarsVectorOut, + std::vector> &childrenFinalOut, + auto &createStructGivenDimRef) mutable { + // Deduction of auto lambda type + if(false) + return; + // - Find structures of lower dimension + int dimToFound = dimToCreate - 1; + std::vector>> childrenFinalT(2); + std::array structOrigins; + for(unsigned int n = 0; n < 2; ++n) { + bool found = false; + for(unsigned int i = 0; i < children[nodeOrigin].size(); ++i) { + auto child = children[nodeOrigin][i]; + if(nodeDoneImpl[child]) + continue; + if(dimToFound != 0) { + if(nodeLevels[child] > dimToFound) { + std::vector dimFoundT(dimToFound, false); + childrenFinalT[n].clear(); + childrenFinalT[n].resize(noNodes); + std::vector nodeDoneImplFind = nodeDoneImpl; + found = findStructGivenDim(child, dimToFound, false, + nodeDoneImplFind, dimFoundT, + childrenFinalT[n]); + } + } else + found = true; + if(found) { + structOrigins[n] = child; + nodeDoneImpl[child] = true; + for(unsigned int j = 0; j < childrenFinalT[n].size(); ++j) { + for(auto &e : childrenFinalT[n][j]) { + childrenFinalOut[j].emplace_back(e); + nodeDoneImpl[e] = true; + } + } + break; + } + } // end for children[nodeOrigin] + if(not found) { + if(dimToFound <= 0) { + structOrigins[n] = std::numeric_limits::max(); + continue; + } + childrenFinalT[n].clear(); + childrenFinalT[n].resize(noNodes); + createStructGivenDimRef( + nodeOrigin, dimToFound, nodeDoneImpl, structOrigins[n], + scalarsVectorOut, childrenFinalT[n], createStructGivenDimRef); + for(unsigned int j = 0; j < childrenFinalT[n].size(); ++j) { + for(auto &e : childrenFinalT[n][j]) { + if(e == structOrigins[n]) + continue; + childrenFinalOut[j].emplace_back(e); + } + } + } + } // end for n + // - Combine both structures + if(structOrigins[0] == std::numeric_limits::max() + and structOrigins[1] == std::numeric_limits::max()) { + structOrigin = std::numeric_limits::max(); + return; + } + bool firstIsParent = true; + if(structOrigins[0] == std::numeric_limits::max()) + firstIsParent = false; + else if(structOrigins[1] == std::numeric_limits::max()) + firstIsParent = true; + else if(scalarsVectorOut[structOrigins[1] * 2 + 1] + - scalarsVectorOut[structOrigins[1] * 2] + > scalarsVectorOut[structOrigins[0] * 2 + 1] + - scalarsVectorOut[structOrigins[0] * 2]) + firstIsParent = false; + structOrigin = (firstIsParent ? structOrigins[0] : structOrigins[1]); + ftm::idNode modOrigin + = (firstIsParent ? structOrigins[1] : structOrigins[0]); + childrenFinalOut[nodeOrigin].emplace_back(structOrigin); + if(modOrigin != std::numeric_limits::max()) { + childrenFinalOut[structOrigin].emplace_back(modOrigin); + std::queue> queue; + queue.emplace(std::array{modOrigin, structOrigin}); + while(!queue.empty()) { + auto &nodeAndParent = queue.front(); + ftm::idNode node = nodeAndParent[0]; + ftm::idNode parent = nodeAndParent[1]; + queue.pop(); + adjustNestingScalars(scalarsVectorOut, node, parent); + // Push children + for(auto &child : childrenFinalOut[node]) + queue.emplace(std::array{child, node}); + } + } + return; + }; + return createStructGivenDimImpl( + _nodeOrigin, _dimToCreate, _nodeDone, _structOrigin, _scalarsVectorOut, + _childrenFinalOut, createStructGivenDimImpl); + }; + for(unsigned int i = 0; i < children.size(); ++i) + sortChildren(i, nodeDone, children); + Timer t_create; + for(unsigned int i = 0; i < dimFound.size(); ++i) { + if(dimFound[i]) + continue; + ftm::idNode structOrigin; + createStructGivenDim( + startNode, i, nodeDone, structOrigin, scalarsVector, childrenFinal); + } +} + +// --------------------------------------------------------------------------- +// --- Testing +// --------------------------------------------------------------------------- +bool ttk::MergeTreeNeuralLayer::isTreeHasBigValues(ftm::MergeTree &mTree, + float threshold) { + bool found = false; + for(unsigned int n = 0; n < mTree.tree.getNumberOfNodes(); ++n) { + if(mTree.tree.isNodeAlone(n)) + continue; + auto birthDeath = mTree.tree.template getBirthDeath(n); + if(std::abs(std::get<0>(birthDeath)) > threshold + or std::abs(std::get<1>(birthDeath)) > threshold) { + found = true; + break; + } + } + return found; +} +#endif diff --git a/core/base/mergeTreeNeuralNetwork/MergeTreeNeuralLayer.h b/core/base/mergeTreeNeuralNetwork/MergeTreeNeuralLayer.h new file mode 100644 index 0000000000..bf6d708d15 --- /dev/null +++ b/core/base/mergeTreeNeuralNetwork/MergeTreeNeuralLayer.h @@ -0,0 +1,604 @@ +/// \ingroup base +/// \class ttk::MergeTreeNeuralLayer +/// \author Mathieu Pont +/// \date 2023. +/// +/// This module defines the %MergeTreeNeuralLayer class that provide methods to +/// define and use a wasserstein layer able to process merge trees or +/// persistence diagrams. +/// +/// To initialize the layer you can use the following functions: +/// "initInputBasisOrigin" and "initInputBasisVectors" for the input basis, then +/// "initOutputBasis" for the output basis. Please refer to the documentation of +/// these functions for how to use them. +/// +/// Then, you should call the "requires_grad" function, with a parameter sets to +/// true, to enable torch to compute gradient for this layer for it to be +/// optimized. +/// +/// The layer can then be used with the "forward" function to pass a topological +/// representation as input and get the topological representation at the output +/// of the layer (the transformed input). +/// +/// \b Related \b publication: \n +/// "Wasserstein Auto-Encoders of Merge Trees (and Persistence Diagrams)" \n +/// Mathieu Pont, Julien Tierny.\n +/// IEEE Transactions on Visualization and Computer Graphics, 2023 +/// + +#pragma once + +// ttk common includes +#include +#include +#include +#include + +#ifdef TTK_ENABLE_TORCH +#include +#endif + +namespace ttk { + + /** + * The MergeTreeNeuralLayer class provides methods to define and use a + * wasserstein layer able to process merge trees or persistence diagrams. + */ + class MergeTreeNeuralLayer : virtual public Debug, + public MergeTreeNeuralBase { + +#ifdef TTK_ENABLE_TORCH + // Layer parameters + torch::Tensor vSTensor_, vSPrimeTensor_, vS2Tensor_, vS2PrimeTensor_; + mtu::TorchMergeTree origin_, originPrime_, origin2_, origin2Prime_; +#endif + + public: + MergeTreeNeuralLayer(); + +#ifdef TTK_ENABLE_TORCH + // ----------------------------------------------------------------------- + // --- Getter/Setter + // ----------------------------------------------------------------------- + const mtu::TorchMergeTree &getOrigin() const; + + const mtu::TorchMergeTree &getOriginPrime() const; + + const mtu::TorchMergeTree &getOrigin2() const; + + const mtu::TorchMergeTree &getOrigin2Prime() const; + + const torch::Tensor &getVSTensor() const; + + const torch::Tensor &getVSPrimeTensor() const; + + const torch::Tensor &getVS2Tensor() const; + + const torch::Tensor &getVS2PrimeTensor() const; + + void setOrigin(const mtu::TorchMergeTree &tmt); + + void setOriginPrime(const mtu::TorchMergeTree &tmt); + + void setOrigin2(const mtu::TorchMergeTree &tmt); + + void setOrigin2Prime(const mtu::TorchMergeTree &tmt); + + void setVSTensor(const torch::Tensor &vS); + + void setVSPrimeTensor(const torch::Tensor &vS); + + void setVS2Tensor(const torch::Tensor &vS); + + void setVS2PrimeTensor(const torch::Tensor &vS); + + // ----------------------------------------------------------------------- + // --- Init + // ----------------------------------------------------------------------- + /** + * @brief Initialize the tree structure of the origin in an output basis + * whose scalars have already been initialized. + * + * @param[out] originPrime origin merge tree of an output basis. + * @param[in] isJT if the tree is a join tree. + * @param[in] baseOrigin a merge tree whose tree structure can be used to + * initialize the tree structure of originPrime (typically, it can be the + * origin of the input basis). + */ + void initOutputBasisTreeStructure(mtu::TorchMergeTree &originPrime, + bool isJT, + mtu::TorchMergeTree &baseOrigin); + + /** + * @brief Initialize the output basis. + * + * @param[in] dim the number of nodes in the origin of the output basis + * (corresponds to twice the number of persistence pairs). + * @param[in] dim2 same as dim but for second input, if any (i.e. when join + * trees and split trees are given). + * @param[in] baseTensor the scalars of a merge tree that can be used to + * initialize the scalars of the origin in the outuput basis (typically, it + * can be the scalars of the origin of the previous output basis or the + * origin of the input basis of this layer if this is the first one). + */ + void initOutputBasis(const unsigned int dim, + const unsigned int dim2, + const torch::Tensor &baseTensor); + + /** + * @brief Initialize the axes of the output basis. + * + * @param[in] w a matrix that will be used to compute the axes of the output + * basis, if B is a matrix corresponding to the axes of the input basis then + * the axes of the output basis will be initialized as wB. + * @param[in] w2 same as w but for second input, if any (i.e. when join + * trees and split trees are given). + */ + void initOutputBasisVectors(torch::Tensor &w, torch::Tensor &w2); + + /** + * @brief Initialize the axes of the output basis. + * + * @param[in] dim the number of nodes in the origin of the output basis + * (corresponds to twice the number of persistence pairs). + * @param[in] dim2 same as dim but for second input, if any (i.e. when join + * trees and split trees are given). + */ + void initOutputBasisVectors(unsigned int dim, unsigned int dim2); + + /** + * @brief Initialize the origin of the input basis as the barycenter of an + * ensemble of trees. + * + * @param[in] treesToUse the trees to use for the initialization (typically, + * the input trees of this layer). + * @param[in] trees2ToUse same as treesToUse but for second input, if any + * (i.e. when join trees and split trees are given). + * @param[in] barycenterSizeLimitPercent the maximum number of nodes allowed + * for the barycenter as a percentage of the total number of nodes in the + * input trees (0 for no effect). + * @param[in] barycenterMaxNoPairs the maximum number of nodes in the + * barycenter (0 for no effect). + * @param[in] barycenterMaxNoPairs2 same as barycenterMaxNoPairs but for + * second input, if any (i.e. when join trees and split trees are given). + * @param[out] inputToBaryDistances the distances of the input trees to the + * origin of the basis. + * @param[out] baryMatchings the matchings between the input trees and the + * origin of the basis. + * @param[out] baryMatchings2 same as baryMatchings but for second input, if + * any (i.e. when join trees and split trees are given). + */ + void initInputBasisOrigin( + std::vector> &treesToUse, + std::vector> &trees2ToUse, + double barycenterSizeLimitPercent, + unsigned int barycenterMaxNoPairs, + unsigned int barycenterMaxNoPairs2, + std::vector &inputToBaryDistances, + std::vector>> + &baryMatchings, + std::vector>> + &baryMatchings2); + + /** + * @brief Initialize the axes of the input basis. + * + * @param[in] tmTrees the trees to use for the initialization as + * TorchMergeTree objects (typically, the input trees of this layer). + * @param[in] tmTrees2 same as tmTrees but for second input, if any + * (i.e. when join trees and split trees are given). + * @param[in] trees the trees to use for the initialization as MergeTree + * objects (typically, the input trees of this layer). + * @param[in] trees2 same as tmTrees but for second input, if any + * (i.e. when join trees and split trees are given). + * @param[in] noVectors number of axes in the basis. + * @param[out] allAlphasInit the coordinates of each input tree in the + * basis. + * @param[in] inputToBaryDistances the distances of the input trees to the + * origin of the basis. + * @param[in] baryMatchings the matchings between the input trees and the + * origin of the basis. + * @param[in] baryMatchings2 same as baryMatchings but for second input, if + * any (i.e. when join trees and split trees are given). + * @param[in] origin the origin of the basis. + * @param[in] origin2 same as origin but for second input, if any (i.e. when + * join trees and split trees are given). + * @param[out] vSTensor the tensor representing the basis. + * @param[out] vS2Tensor same as vS2Tensor but for second input, if any + * (i.e. when join trees and split trees are given). + * @param[in] useInputBasis this boolean allows this function to also be + * used for initializing the output basis (by setting this parameter to + * false). + */ + void initInputBasisVectors( + std::vector> &tmTrees, + std::vector> &tmTrees2, + std::vector> &trees, + std::vector> &trees2, + unsigned int noVectors, + std::vector &allAlphasInit, + std::vector &inputToBaryDistances, + std::vector>> + &baryMatchings, + std::vector>> + &baryMatchings2, + mtu::TorchMergeTree &origin, + mtu::TorchMergeTree &origin2, + torch::Tensor &vSTensor, + torch::Tensor &vS2Tensor, + bool useInputBasis = true); + + /** + * @brief Overloaded function that initialize the axes of the input basis of + * this instantiated MergeTreeNeuralLayer object. + * + * @param[in] tmTrees the trees to use for the initialization as + * TorchMergeTree objects (typically, the input trees of this layer). + * @param[in] tmTrees2 same as tmTrees but for second input, if any + * (i.e. when join trees and split trees are given). + * @param[in] trees the trees to use for the initialization as MergeTree + * objects (typically, the input trees of this layer). + * @param[in] trees2 same as tmTrees but for second input, if any + * (i.e. when join trees and split trees are given). + * @param[in] noVectors number of axes in the basis. + * @param[out] allAlphasInit the coordinates of each input tree in the + * basis. + * @param[in] inputToBaryDistances the distances of the input trees to the + * origin of the basis. + * @param[in] baryMatchings the matchings between the input trees and the + * origin of the basis. + * @param[out] baryMatchings2 same as baryMatchings but for second input, if + * any (i.e. when join trees and split trees are given). + * @param[in] useInputBasis this boolean allows this function to also be + * used for initializing the output basis (by setting this parameter to + * false). + */ + void initInputBasisVectors( + std::vector> &tmTrees, + std::vector> &tmTrees2, + std::vector> &trees, + std::vector> &trees2, + unsigned int noVectors, + std::vector &allAlphasInit, + std::vector &inputToBaryDistances, + std::vector>> + &baryMatchings, + std::vector>> + &baryMatchings2, + bool useInputBasis = true); + + void requires_grad(const bool requireGrad); + + void cuda(); + + // ----------------------------------------------------------------------- + // --- Interpolation + // ----------------------------------------------------------------------- + /** + * @brief Projection ensuring that no pairs are below diagonal. + * Warning: this function only updates the Tensor object and not the scalars + * of the MergeTree object, for this please call interpolationProjection. + * + * @param[in,out] interpolationTensor merge tree to process. + */ + void interpolationDiagonalProjection( + mtu::TorchMergeTree &interpolationTensor); + + /** + * @brief Projection ensuring the nesting condition. + * Warning: this function only updates the Tensor object and not the scalars + * of the MergeTree object, for this please call interpolationProjection. + * + * @param[in,out] interpolation merge tree to process. + */ + void + interpolationNestingProjection(mtu::TorchMergeTree &interpolation); + + /** + * @brief Projection ensuring the elder rule (no pairs below diagonal and + * nesting condition), updates the Tensor object AND the scalars of the + * MergeTree object. + * + * @param[in,out] interpolation merge tree to process. + */ + void interpolationProjection(mtu::TorchMergeTree &interpolation); + + /** + * @brief Creates the merge tree at coordinates alphas of the basis with + * origin as origin and vS as axes, followed by a projection ensuring the + * elder rule. + * + * @param[in] origin origin of the basis. + * @param[in] vS axes of the basis. + * @param[in] alphas coordinates on the basis to evaluate. + * @param[out] interpolation output merge tree. + */ + void getMultiInterpolation(const mtu::TorchMergeTree &origin, + const torch::Tensor &vS, + torch::Tensor &alphas, + mtu::TorchMergeTree &interpolation); + + // ----------------------------------------------------------------------- + // --- Forward + // ----------------------------------------------------------------------- + /** + * @brief Computes the necessary tensors used in the computation of the best + * coordinates in the input basis of the input merge tree at constant + * assignment. + * According Appendix B of the reference "Wasserstein Auto-Encoders of Merge + * Trees (and Persistence Diagrams), reorderedTreeTensor corresponds to + * Beta_1', deltaOrigin to Beta_3', deltaA to B_2', originTensor_f to O' and + * vSTensor_f to (B(O'))'. + * + * @param[in] tree input merge tree. + * @param[in] origin origin of the basis. + * @param[in] vSTensor axes of the basis. + * @param[in] interpolated current estimation of the input tree in the + * basis. + * @param[in] matching matching between the input tree and the current + * estimation in the basis. + * @param[out] reorderedTreeTensor reordered tensor of the input tree given + * the matching to its estimation on the basis (with zero on non-matched + * pairs). + * @param[out] deltaOrigin tensor of the projected pairs on the + * diagonal of the origin in the input tree. + * @param[out] deltaA tensor corresponding to the linear combination of the + * axes of the basis given the coordinates for the projected pairs on the + * diagonal of the estimated tree in the input tree. + * @param[out] originTensor_f tensor of the origin. + * @param[out] vSTensor_f tensor of the basis. + */ + void getAlphasOptimizationTensors( + mtu::TorchMergeTree &tree, + mtu::TorchMergeTree &origin, + torch::Tensor &vSTensor, + mtu::TorchMergeTree &interpolated, + std::vector> &matching, + torch::Tensor &reorderedTreeTensor, + torch::Tensor &deltaOrigin, + torch::Tensor &deltaA, + torch::Tensor &originTensor_f, + torch::Tensor &vSTensor_f); + + /** + * @brief Computes the best coordinates in the input basis of the input + * merge tree at constant assignment. + * + * @param[in] tree input merge tree. + * @param[in] origin origin of the basis. + * @param[in] vSTensor axes of the basis. + * @param[in] interpolated current estimation of the input tree in the + * basis. + * @param[in] matching matching between the input tree and the current + * estimation in the basis. + * @param[in] tree2 same as tree but for second input, if any (i.e. when + * join trees and split trees are given). + * @param[in] origin2 same as origin but for second input, if any (i.e. when + * join trees and split trees are given). + * @param[in] vSTensor2 same as vSTensor but for second input, if any (i.e. + * when join trees and split trees are given). + * @param[in] interpolated2 same as interpolation but for second input, if + * any (i.e. when join trees and split trees are given). + * @param[in] matching2 same as matching but for second input, if any (i.e. + * when join trees and split trees are given). + * @param[out] alphasOut best coordinates of the input tree in the basis at + * constant assignment. + */ + void computeAlphas( + mtu::TorchMergeTree &tree, + mtu::TorchMergeTree &origin, + torch::Tensor &vSTensor, + mtu::TorchMergeTree &interpolated, + std::vector> &matching, + mtu::TorchMergeTree &tree2, + mtu::TorchMergeTree &origin2, + torch::Tensor &vS2Tensor, + mtu::TorchMergeTree &interpolated2, + std::vector> &matching2, + torch::Tensor &alphasOut); + + /** + * @brief Estimates the coordinates in the input basis of the input merge + * tree. + * + * @param[in] tree input merge tree. + * @param[in] tree2 same as tree but for second input, if any (i.e. when + * join trees and split trees are given). + * @param[in] k number of projection steps to do when estimating the + * coordinates in the input basis of the input merge tree. + * @param[in] alphasInit the initial coordinates to use when estimating the + * coordinates in the input basis of the input merge tree. + * @param[in] bestMatching matching between the input tree and the tree at + * the best estimation of the coordinates in the basis. + * @param[in] bestMatching2 same as bestMatching but for second input, if + * any (i.e. when join trees and split trees are given). + * @param[out] bestAlphas the best estimation of the coordinates in the + * input basis of the input merge tree. + * @param[in] isCalled true if this function is called from a parallalized + * context, i.e. if a team of threads has already been created and that + * therefore it is not needed to create one. + * @param[in] useInputBasis this boolean allows this function to also be + * used for the output basis (by setting this parameter to false). + * + * @return the distance between the input merge tree and its best estimated + * projection in the input basis. + */ + float assignmentOneData( + mtu::TorchMergeTree &tree, + mtu::TorchMergeTree &tree2, + unsigned int k, + torch::Tensor &alphasInit, + std::vector> &bestMatching, + std::vector> &bestMatching2, + torch::Tensor &bestAlphas, + bool isCalled = false, + bool useInputBasis = true); + + /** + * @brief Estimates the coordinates in the input basis of the input merge + * tree. + * + * @param[in] tree input merge tree. + * @param[in] tree2 same as tree but for second input, if any (i.e. when + * join trees and split trees are given). + * @param[in] k number of projection steps to do when estimating the + * coordinates in the input basis of the input merge tree. + * @param[in] alphasInit the initial coordinates to use when estimating the + * coordinates in the input basis of the input merge tree. + * @param[out] bestAlphas the best estimation of the coordinates in the + * input basis of the input merge tree. + * @param[in] isCalled true if this function is called from a parallalized + * context, i.e. if a team of threads has already been created and that + * therefore it is not needed to create one. + * @param[in] useInputBasis this boolean allows this function to also be + * used for the output basis (by setting this parameter to false). + * + * @return the distance between the input merge tree and its best estimated + * projection in the input basis. + */ + float assignmentOneData(mtu::TorchMergeTree &tree, + mtu::TorchMergeTree &tree2, + unsigned int k, + torch::Tensor &alphasInit, + torch::Tensor &bestAlphas, + bool isCalled = false, + bool useInputBasis = true); + + /** + * @brief Reconstruct an ouput merge tree given coordinates. + * + * @param[in] alphas coordinates to use in the output basis. + * @param[out] out the output merge tree. + * @param[out] out2 same as out but for second input, if any (i.e. when join + * trees and split trees are given). + * @param[in] activate true if activation function should be used, false + * otherwise. + * @param[in] train true if the input merge tree is in the training set + * (false if validation/testing set). + */ + void outputBasisReconstruction(torch::Tensor &alphas, + mtu::TorchMergeTree &out, + mtu::TorchMergeTree &out2, + bool activate = true, + bool train = false); + + /** + * @brief Pass a merge tree through the layer and get the output. + * + * @param[in] tree input merge tree. + * @param[in] tree2 same as tree but for second input, if any (i.e. when + * join trees and split trees are given). + * @param[in] k number of projection steps to do when estimating the + * coordinates in the input basis of the input merge tree. + * @param[in] alphasInit the initial coordinates to use when estimating the + * coordinates in the input basis of the input merge tree. + * @param[out] out the output merge tree. + * @param[out] out2 same as out but for second input, if any (i.e. when join + * trees and split trees are given). + * @param[out] bestAlphas the best estimation of the coordinates in the + * input basis of the input merge tree. + * @param[out] bestDistance the distance between the input merge tree and + * its best estimated projection in the input basis. + * @param[in] train true if the input merge tree is in the training set + * (false if validation/testing set). + * + * @return true if the output merge tree has no nodes. + */ + bool forward(mtu::TorchMergeTree &tree, + mtu::TorchMergeTree &tree2, + unsigned int k, + torch::Tensor &alphasInit, + mtu::TorchMergeTree &out, + mtu::TorchMergeTree &out2, + torch::Tensor &bestAlphas, + float &bestDistance, + bool train = false); + + /** + * @brief Pass a merge tree through the layer and get the output. + * + * @param[in] tree input merge tree. + * @param[in] tree2 same as tree but for second input, if any (i.e. when + * join trees and split trees are given). + * @param[in] k number of projection steps to do when estimating the + * coordinates in the input basis of the input merge tree. + * @param[in] alphasInit the initial coordinates to use when estimating the + * coordinates in the input basis of the input merge tree. + * @param[out] out the output merge tree. + * @param[out] out2 same as out but for second input, if any (i.e. when join + * trees and split trees are given). + * @param[out] bestAlphas the best estimation of the coordinates in the + * input basis of the input merge tree. + * @param[in] train true if the input merge tree is in the training set + * (false if validation/testing set). + * + * @return true if the output merge tree has no nodes. + */ + bool forward(mtu::TorchMergeTree &tree, + mtu::TorchMergeTree &tree2, + unsigned int k, + torch::Tensor &alphasInit, + mtu::TorchMergeTree &out, + mtu::TorchMergeTree &out2, + torch::Tensor &bestAlphas, + bool train = false); + + // ----------------------------------------------------------------------- + // --- Projection + // ----------------------------------------------------------------------- + /** + * @brief Projection that ensures that the origins of the input and output + * bases respect the elder rule. + */ + void projectionStep(); + + // ----------------------------------------------------------------------- + // --- Utils + // ----------------------------------------------------------------------- + void copyParams(mtu::TorchMergeTree &origin, + mtu::TorchMergeTree &originPrime, + torch::Tensor &vS, + torch::Tensor &vSPrime, + mtu::TorchMergeTree &origin2, + mtu::TorchMergeTree &origin2Prime, + torch::Tensor &vS2, + torch::Tensor &vS2Prime, + bool get); + + /** + * @brief Fix the scalars of a merge tree to ensure that the nesting + * condition is respected. + * + * @param[in] scalarsVector scalars array to process. + * @param[in] node node to adjust. + * @param[in] refNode reference node. + */ + void adjustNestingScalars(std::vector &scalarsVector, + ftm::idNode node, + ftm::idNode refNode); + + /** + * @brief Create a balanced BDT structure (for output basis initialization). + * + * @param[in] parents vector containing the possible parents for each node. + * @param[in] children vector containing the possible children for each + * node. + * @param[in] scalarsVector vector containing the scalars value. + * @param[out] childrenFinal output vector containing the children of each + * node, representing the tree structure. + */ + void + createBalancedBDT(std::vector> &parents, + std::vector> &children, + std::vector &scalarsVector, + std::vector> &childrenFinal); + + // ----------------------------------------------------------------------- + // --- Testing + // ----------------------------------------------------------------------- + bool isTreeHasBigValues(ftm::MergeTree &mTree, + float threshold = 10000); +#endif + }; // MergeTreeNeuralLayer class + +} // namespace ttk diff --git a/core/base/mergeTreeNeuralNetwork/MergeTreeNeuralNetwork.cpp b/core/base/mergeTreeNeuralNetwork/MergeTreeNeuralNetwork.cpp new file mode 100644 index 0000000000..87fb5d9822 --- /dev/null +++ b/core/base/mergeTreeNeuralNetwork/MergeTreeNeuralNetwork.cpp @@ -0,0 +1,1324 @@ +#include +#include + +#ifdef TTK_ENABLE_TORCH +using namespace torch::indexing; +#endif + +ttk::MergeTreeNeuralNetwork::MergeTreeNeuralNetwork() { + // inherited from Debug: prefix will be printed at the beginning of every msg + this->setDebugMsgPrefix("MergeTreeNeuralNetwork"); +} + +#ifdef TTK_ENABLE_TORCH +// ----------------------------------------------------------------------- +// --- Init +// ----------------------------------------------------------------------- +void ttk::MergeTreeNeuralNetwork::initInputBasis( + unsigned int l, + unsigned int layerNoAxes, + std::vector> &tmTrees, + std::vector> &tmTrees2, + std::vector &ttkNotUsed(isTrain), + std::vector> &allAlphasInit) { + // TODO is there a way to avoid copy of merge trees? + std::vector> trees, trees2; + for(unsigned int i = 0; i < tmTrees.size(); ++i) { + trees.emplace_back(tmTrees[i].mTree); + if(useDoubleInput_) + trees2.emplace_back(tmTrees2[i].mTree); + } + + // - Compute origin + printMsg("Compute origin...", debug::Priority::DETAIL); + Timer t_origin; + std::vector inputToBaryDistances; + std::vector>> + baryMatchings, baryMatchings2; + if(l != 0 or not layers_[0].getOrigin().tensor.defined()) { + double sizeLimit = (l == 0 ? barycenterSizeLimitPercent_ : 0); + unsigned int maxNoPairs + = (l == 0 ? 0 : layers_[l - 1].getOriginPrime().tensor.sizes()[0] / 2); + unsigned int maxNoPairs2 + = (l == 0 or not useDoubleInput_ + ? 0 + : layers_[l - 1].getOrigin2Prime().tensor.sizes()[0] / 2); + layers_[l].initInputBasisOrigin(trees, trees2, sizeLimit, maxNoPairs, + maxNoPairs2, inputToBaryDistances, + baryMatchings, baryMatchings2); + if(l == 0) { + baryMatchings_L0_ = baryMatchings; + baryMatchings2_L0_ = baryMatchings2; + inputToBaryDistances_L0_ = inputToBaryDistances; + } + } else { + baryMatchings = baryMatchings_L0_; + baryMatchings2 = baryMatchings2_L0_; + inputToBaryDistances = inputToBaryDistances_L0_; + } + printMsg("Compute origin time", 1, t_origin.getElapsedTime(), threadNumber_, + debug::LineMode::NEW, debug::Priority::DETAIL); + + // - Compute vectors + printMsg("Compute vectors...", debug::Priority::DETAIL); + Timer t_vectors; + std::vector allAlphasInitT(tmTrees.size()); + layers_[l].initInputBasisVectors( + tmTrees, tmTrees2, trees, trees2, layerNoAxes, allAlphasInitT, + inputToBaryDistances, baryMatchings, baryMatchings2); + for(unsigned int i = 0; i < allAlphasInitT.size(); ++i) + allAlphasInit[i][l] = allAlphasInitT[i]; + printMsg("Compute vectors time", 1, t_vectors.getElapsedTime(), threadNumber_, + debug::LineMode::NEW, debug::Priority::DETAIL); +} + +void ttk::MergeTreeNeuralNetwork::initOutputBasis( + unsigned int l, + double layerOriginPrimeSizePercent, + std::vector> &tmTrees, + std::vector> &tmTrees2, + std::vector &ttkNotUsed(isTrain)) { + std::vector ftmTrees(tmTrees.size()), + ftmTrees2(tmTrees2.size()); + for(unsigned int i = 0; i < tmTrees.size(); ++i) + ftmTrees[i] = &(tmTrees[i].mTree.tree); + for(unsigned int i = 0; i < tmTrees2.size(); ++i) + ftmTrees2[i] = &(tmTrees2[i].mTree.tree); + auto sizeMetric = getSizeLimitMetric(ftmTrees); + auto sizeMetric2 = getSizeLimitMetric(ftmTrees2); + auto getDim = [](double _sizeMetric, double _percent) { + unsigned int dim = std::max((int)(_sizeMetric * _percent / 100.0), 2) * 2; + return dim; + }; + + unsigned int dim = getDim(sizeMetric, layerOriginPrimeSizePercent); + dim = std::min(dim, (unsigned int)layers_[l].getOrigin().tensor.sizes()[0]); + unsigned int dim2 = getDim(sizeMetric2, layerOriginPrimeSizePercent); + if(useDoubleInput_) + dim2 + = std::min(dim2, (unsigned int)layers_[l].getOrigin2().tensor.sizes()[0]); + auto baseTensor = (l == 0 ? layers_[0].getOrigin().tensor + : layers_[l - 1].getOriginPrime().tensor); + layers_[l].initOutputBasis(dim, dim2, baseTensor); +} + +bool ttk::MergeTreeNeuralNetwork::initGetReconstructed( + unsigned int l, + unsigned int layerNoAxes, + double layerOriginPrimeSizePercent, + std::vector> &trees, + std::vector> &trees2, + std::vector &isTrain, + std::vector> &recs, + std::vector> &recs2, + std::vector> &allAlphasInit) { + printMsg("Get reconstructed", debug::Priority::DETAIL); + recs.resize(trees.size()); + recs2.resize(trees.size()); + unsigned int i = 0; + unsigned int noReset = 0; + while(i < trees.size()) { + layers_[l].outputBasisReconstruction( + allAlphasInit[i][l], recs[i], recs2[i], activateOutputInit_); + if(recs[i].mTree.tree.getRealNumberOfNodes() == 0) { + bool fullReset = initResetOutputBasis( + l, layerNoAxes, layerOriginPrimeSizePercent, trees, trees2, isTrain); + if(fullReset) + return true; + i = 0; + ++noReset; + if(noReset >= 100) { + printWrn("[initParameters] noReset >= 100"); + return true; + } + } + ++i; + } + return false; +} + +void ttk::MergeTreeNeuralNetwork::initStep( + std::vector> &trees, + std::vector> &trees2, + std::vector &isTrain) { + layers_.clear(); + + float bestError = std::numeric_limits::max(); + std::vector bestVSTensor, bestVSPrimeTensor, bestVS2Tensor, + bestVS2PrimeTensor, bestLatentCentroids; + std::vector> bestOrigins, bestOriginsPrime, + bestOrigins2, bestOrigins2Prime; + std::vector> bestAlphasInit; + for(unsigned int n = 0; n < noInit_; ++n) { + // Init parameters + float error = initParameters(trees, trees2, isTrain, (noInit_ != 1)); + // Save best parameters + if(noInit_ != 1) { + std::stringstream ss; + ss << "Init error = " << error; + printMsg(ss.str()); + if(error < bestError) { + bestError = error; + copyParams(bestOrigins, bestOriginsPrime, bestVSTensor, + bestVSPrimeTensor, bestOrigins2, bestOrigins2Prime, + bestVS2Tensor, bestVS2PrimeTensor, allAlphas_, + bestAlphasInit, true); + copyCustomParams(true); + } + } + } + // TODO this copy can be avoided if initParameters takes dummy tensors to fill + // as parameters and then copy to the member tensors when a better init is + // found. + if(noInit_ != 1) { + // Put back best parameters + std::stringstream ss; + ss << "Best init error = " << bestError; + printMsg(ss.str()); + copyParams(bestOrigins, bestOriginsPrime, bestVSTensor, bestVSPrimeTensor, + bestOrigins2, bestOrigins2Prime, bestVS2Tensor, + bestVS2PrimeTensor, bestAlphasInit, allAlphas_, false); + copyCustomParams(false); + } + + for(unsigned int l = 0; l < noLayers_; ++l) { + layers_[l].requires_grad(true); + + // Print + printMsg(debug::Separator::L2); + std::stringstream ss; + ss << "Layer " << l; + printMsg(ss.str()); + if(isTreeHasBigValues(layers_[l].getOrigin().mTree, bigValuesThreshold_)) { + ss.str(""); + ss << "origins_[" << l << "] has big values!" << std::endl; + printMsg(ss.str()); + printPairs(layers_[l].getOrigin().mTree); + } + if(isTreeHasBigValues( + layers_[l].getOriginPrime().mTree, bigValuesThreshold_)) { + ss.str(""); + ss << "originsPrime_[" << l << "] has big values!" << std::endl; + printMsg(ss.str()); + printPairs(layers_[l].getOriginPrime().mTree); + } + ss.str(""); + ss << "vS size = " << layers_[l].getVSTensor().sizes(); + printMsg(ss.str()); + ss.str(""); + ss << "vS' size = " << layers_[l].getVSPrimeTensor().sizes(); + printMsg(ss.str()); + if(trees2.size() != 0) { + ss.str(""); + ss << "vS2 size = " << layers_[l].getVS2Tensor().sizes(); + printMsg(ss.str()); + ss.str(""); + ss << "vS2' size = " << layers_[l].getVS2PrimeTensor().sizes(); + printMsg(ss.str()); + } + } +} + +void ttk::MergeTreeNeuralNetwork::passLayerParameters( + MergeTreeNeuralLayer &layer) { + layer.setDropout(dropout_); + layer.setEuclideanVectorsInit(euclideanVectorsInit_); + layer.setRandomAxesInit(randomAxesInit_); + layer.setInitBarycenterRandom(initBarycenterRandom_); + layer.setInitBarycenterOneIter(initBarycenterOneIter_); + layer.setInitOriginPrimeStructByCopy(initOriginPrimeStructByCopy_); + layer.setInitOriginPrimeValuesByCopy(initOriginPrimeValuesByCopy_); + layer.setInitOriginPrimeValuesByCopyRandomness( + initOriginPrimeValuesByCopyRandomness_); + layer.setActivate(activate_); + layer.setActivationFunction(activationFunction_); + layer.setUseGpu(useGpu_); + layer.setBigValuesThreshold(bigValuesThreshold_); + + layer.setDeterministic(deterministic_); + layer.setNumberOfProjectionSteps(k_); + layer.setBarycenterSizeLimitPercent(barycenterSizeLimitPercent_); + layer.setProbabilisticVectorsInit(probabilisticVectorsInit_); + + layer.setNormalizedWasserstein(normalizedWasserstein_); + layer.setAssignmentSolver(assignmentSolverID_); + layer.setNodePerTask(nodePerTask_); + layer.setUseDoubleInput(useDoubleInput_); + layer.setJoinSplitMixtureCoefficient(mixtureCoefficient_); + layer.setIsPersistenceDiagram(isPersistenceDiagram_); + + layer.setDebugLevel(debugLevel_); + layer.setThreadNumber(threadNumber_); +} + +// --------------------------------------------------------------------------- +// --- Forward +// --------------------------------------------------------------------------- +bool ttk::MergeTreeNeuralNetwork::forwardOneData( + mtu::TorchMergeTree &tree, + mtu::TorchMergeTree &tree2, + unsigned int treeIndex, + unsigned int k, + std::vector &alphasInit, + mtu::TorchMergeTree &out, + mtu::TorchMergeTree &out2, + std::vector &dataAlphas, + std::vector> &outs, + std::vector> &outs2, + bool train) { + outs.resize(noLayers_ - 1); + outs2.resize(noLayers_ - 1); + dataAlphas.resize(noLayers_); + for(unsigned int l = 0; l < noLayers_; ++l) { + auto &treeToUse = (l == 0 ? tree : outs[l - 1]); + auto &tree2ToUse = (l == 0 ? tree2 : outs2[l - 1]); + auto &outToUse = (l != noLayers_ - 1 ? outs[l] : out); + auto &out2ToUse = (l != noLayers_ - 1 ? outs2[l] : out2); + bool reset = layers_[l].forward(treeToUse, tree2ToUse, k, alphasInit[l], + outToUse, out2ToUse, dataAlphas[l], train); + if(reset) + return true; + // Update recs + auto updateRecs + = [this, &treeIndex, &l]( + std::vector>> &recs, + mtu::TorchMergeTree &outT) { + if(recs[treeIndex].size() > noLayers_) + mtu::copyTorchMergeTree(outT, recs[treeIndex][l + 1]); + else { + mtu::TorchMergeTree tmt; + mtu::copyTorchMergeTree(outT, tmt); + recs[treeIndex].emplace_back(tmt); + } + }; + updateRecs(recs_, outToUse); + if(useDoubleInput_) + updateRecs(recs2_, out2ToUse); + } + return false; +} + +bool ttk::MergeTreeNeuralNetwork::forwardStep( + std::vector> &trees, + std::vector> &trees2, + std::vector &indexes, + std::vector &isTrain, + unsigned int k, + std::vector> &allAlphasInit, + bool computeError, + std::vector> &outs, + std::vector> &outs2, + std::vector> &bestAlphas, + std::vector>> &layersOuts, + std::vector>> &layersOuts2, + std::vector>> + &matchings, + std::vector>> + &matchings2, + float &loss, + float &testLoss) { + loss = 0; + testLoss = 0; + outs.resize(trees.size()); + outs2.resize(trees.size()); + bestAlphas.resize(trees.size()); + layersOuts.resize(trees.size()); + layersOuts2.resize(trees.size()); + matchings.resize(trees.size()); + if(useDoubleInput_) + matchings2.resize(trees2.size()); + mtu::TorchMergeTree dummyTMT; + bool reset = false; + unsigned int noTrainLoss = 0, noTestLoss = 0; +#ifdef TTK_ENABLE_OPENMP +#pragma omp parallel for schedule(dynamic) \ + num_threads(this->threadNumber_) if(parallelize_) reduction(|| : reset) \ + reduction(+ : loss) +#endif + for(unsigned int ind = 0; ind < indexes.size(); ++ind) { + unsigned int i = indexes[ind]; + auto &tree2ToUse = (trees2.size() == 0 ? dummyTMT : trees2[i]); + bool dReset = forwardOneData(trees[i], tree2ToUse, i, k, allAlphasInit[i], + outs[i], outs2[i], bestAlphas[i], + layersOuts[i], layersOuts2[i], isTrain[i]); + if(computeError) { + float iLoss + = computeOneLoss(trees[i], outs[i], trees2[i], outs2[i], matchings[i], + matchings2[i], bestAlphas[i], i); + if(isTrain[i]) { + loss += iLoss; + ++noTrainLoss; + } else { + testLoss += iLoss; + ++noTestLoss; + } + } + if(dReset) + reset = reset || dReset; + } + if(noTrainLoss != 0) + loss /= noTrainLoss; + if(noTestLoss != 0) + testLoss /= noTestLoss; + return reset; +} + +bool ttk::MergeTreeNeuralNetwork::forwardStep( + std::vector> &trees, + std::vector> &trees2, + std::vector &indexes, + unsigned int k, + std::vector> &allAlphasInit, + bool computeError, + std::vector> &outs, + std::vector> &outs2, + std::vector> &bestAlphas, + std::vector>> &layersOuts, + std::vector>> &layersOuts2, + std::vector>> + &matchings, + std::vector>> + &matchings2, + float &loss) { + std::vector isTrain(trees.size(), false); + float tempLoss; + return forwardStep(trees, trees2, indexes, isTrain, k, allAlphasInit, + computeError, outs, outs2, bestAlphas, layersOuts, + layersOuts2, matchings, matchings2, tempLoss, loss); +} + +// --------------------------------------------------------------------------- +// --- Projection +// --------------------------------------------------------------------------- +void ttk::MergeTreeNeuralNetwork::projectionStep() { + for(unsigned int l = 0; l < noLayers_; ++l) + layers_[l].projectionStep(); +} + +// ----------------------------------------------------------------------- +// --- Convergence +// ----------------------------------------------------------------------- +bool ttk::MergeTreeNeuralNetwork::isBestLoss(float loss, + float &minLoss, + unsigned int &cptBlocked) { + bool isBestEnergy = false; + if(loss + ENERGY_COMPARISON_TOLERANCE < minLoss) { + minLoss = loss; + cptBlocked = 0; + isBestEnergy = true; + } + return isBestEnergy; +} + +bool ttk::MergeTreeNeuralNetwork::convergenceStep(float loss, + float &oldLoss, + float &minLoss, + unsigned int &cptBlocked) { + double tol = oldLoss / 125.0; + bool converged = std::abs(loss - oldLoss) < std::abs(tol); + oldLoss = loss; + if(not converged) { + cptBlocked += (minLoss < loss) ? 1 : 0; + converged = (cptBlocked >= 10 * 10); + if(converged) + printMsg("Blocked!", debug::Priority::DETAIL); + } + return converged; +} + +// ----------------------------------------------------------------------- +// --- Main Functions +// ----------------------------------------------------------------------- +void ttk::MergeTreeNeuralNetwork::fit( + std::vector> &trees, + std::vector> &trees2) { + torch::set_num_threads(1); + if(useGpu_) { + if(torch::cuda::device_count() > 0 and torch::cuda::is_available()) + printMsg("Computation with GPU support."); + else { + printMsg("Disabling GPU support because no device were found."); + useGpu_ = false; + // TODO cache useGpu parameter to be in accordance with ParaView GUI + } + } else { + printMsg("Computation without GPU support."); + } + // ----- Determinism + if(deterministic_) { + int m_seed = 0; + bool m_torch_deterministic = true; + srand(m_seed); + torch::manual_seed(m_seed); + at::globalContext().setDeterministicCuDNN(m_torch_deterministic ? true + : false); + if(not useGpu_) + at::globalContext().setDeterministicAlgorithms( + m_torch_deterministic ? true : false, true); + } + + // ----- Testing + for(unsigned int i = 0; i < trees.size(); ++i) { + for(unsigned int n = 0; n < trees[i].tree.getNumberOfNodes(); ++n) { + if(trees[i].tree.isNodeAlone(n)) + continue; + auto birthDeath = trees[i].tree.template getBirthDeath(n); + bigValuesThreshold_ + = std::max(std::abs(std::get<0>(birthDeath)), bigValuesThreshold_); + bigValuesThreshold_ + = std::max(std::abs(std::get<1>(birthDeath)), bigValuesThreshold_); + } + } + bigValuesThreshold_ *= 100; + + // ----- Convert MergeTree to TorchMergeTree + std::vector> torchTrees, torchTrees2; + mergeTreesToTorchTrees(trees, torchTrees, normalizedWasserstein_); + mergeTreesToTorchTrees(trees2, torchTrees2, normalizedWasserstein_); + if(useGpu_) { + for(unsigned i = 0; i < torchTrees.size(); ++i) + torchTrees[i].tensor = torchTrees[i].tensor.cuda(); + for(unsigned i = 0; i < torchTrees2.size(); ++i) + torchTrees2[i].tensor = torchTrees2[i].tensor.cuda(); + } + + auto initRecs = [](std::vector>> &recs, + std::vector> &torchTreesT) { + recs.clear(); + recs.resize(torchTreesT.size()); + for(unsigned int i = 0; i < torchTreesT.size(); ++i) { + mtu::TorchMergeTree tmt; + mtu::copyTorchMergeTree(torchTreesT[i], tmt); + recs[i].emplace_back(tmt); + } + }; + initRecs(recs_, torchTrees); + if(useDoubleInput_) + initRecs(recs2_, torchTrees2); + + // --- Train/Test Split + unsigned int trainSize = std::min( + std::max((int)(trees.size() * trainTestSplit_), 1), (int)trees.size()); + std::vector trainIndexes(trees.size()), testIndexes; + std::iota(trainIndexes.begin(), trainIndexes.end(), 0); + std::random_device rd; + std::default_random_engine rng(deterministic_ ? 0 : rd()); + bool trainTestSplitted = trainSize != trees.size(); + if(trainTestSplitted) { + if(shuffleBeforeSplit_) + std::shuffle(trainIndexes.begin(), trainIndexes.end(), rng); + testIndexes.insert( + testIndexes.end(), trainIndexes.begin() + trainSize, trainIndexes.end()); + trainIndexes.resize(trainSize); + } + std::vector isTrain(trees.size(), true); + for(auto &ind : testIndexes) + isTrain[ind] = false; + + // ----- Custom Init + customInit(torchTrees, torchTrees2); + + // ----- Init Model Parameters + Timer t_init; + initStep(torchTrees, torchTrees2, isTrain); + printMsg("Init", 1, t_init.getElapsedTime(), threadNumber_); + + // --- Init optimizer + std::vector parameters; + for(unsigned int l = 0; l < noLayers_; ++l) { + parameters.emplace_back(layers_[l].getOrigin().tensor); + parameters.emplace_back(layers_[l].getOriginPrime().tensor); + parameters.emplace_back(layers_[l].getVSTensor()); + parameters.emplace_back(layers_[l].getVSPrimeTensor()); + if(trees2.size() != 0) { + parameters.emplace_back(layers_[l].getOrigin2().tensor); + parameters.emplace_back(layers_[l].getOrigin2Prime().tensor); + parameters.emplace_back(layers_[l].getVS2Tensor()); + parameters.emplace_back(layers_[l].getVS2PrimeTensor()); + } + } + addCustomParameters(parameters); + + torch::optim::Optimizer *optimizer; + // - Init Adam + auto adamOptions = torch::optim::AdamOptions(gradientStepSize_); + adamOptions.betas(std::make_tuple(beta1_, beta2_)); + auto adamOptimizer = torch::optim::Adam(parameters, adamOptions); + // - Init SGD optimizer + auto sgdOptions = torch::optim::SGDOptions(gradientStepSize_); + auto sgdOptimizer = torch::optim::SGD(parameters, sgdOptions); + // -Init RMSprop optimizer + auto rmspropOptions = torch::optim::RMSpropOptions(gradientStepSize_); + auto rmspropOptimizer = torch::optim::RMSprop(parameters, rmspropOptions); + // - Set optimizer pointer + switch(optimizer_) { + case 1: + optimizer = &sgdOptimizer; + break; + case 2: + optimizer = &rmspropOptimizer; + break; + case 0: + default: + optimizer = &adamOptimizer; + } + + // --- Print train/test split + if(trainTestSplitted) { + std::stringstream ss; + ss << "trainSize = " << trainIndexes.size() << " / " << trees.size(); + printMsg(ss.str()); + ss.str(""); + ss << "testSize = " << testIndexes.size() << " / " << trees.size(); + printMsg(ss.str()); + } + + // --- Init batches indexes + unsigned int batchSize + = std::min(std::max((int)(trainIndexes.size() * batchSize_), 1), + (int)trainIndexes.size()); + std::stringstream ssBatch; + ssBatch << "batchSize = " << batchSize; + printMsg(ssBatch.str()); + unsigned int noBatch = trainIndexes.size() / batchSize + + ((trainIndexes.size() % batchSize) != 0 ? 1 : 0); + std::vector> allIndexes(noBatch); + if(noBatch == 1) { + // Yes, trees.size() below is correct and it is not trainIndexes.size(), the + // goal is to forward everyone (even test data) if noBatch == 1 to benefit + // from full parallelism, but only train data will be used for backward. + allIndexes[0].resize(trees.size()); + std::iota(allIndexes[0].begin(), allIndexes[0].end(), 0); + } + + // ----- Testing + originsNoZeroGrad_.resize(noLayers_); + originsPrimeNoZeroGrad_.resize(noLayers_); + vSNoZeroGrad_.resize(noLayers_); + vSPrimeNoZeroGrad_.resize(noLayers_); + for(unsigned int l = 0; l < noLayers_; ++l) { + originsNoZeroGrad_[l] = 0; + originsPrimeNoZeroGrad_[l] = 0; + vSNoZeroGrad_[l] = 0; + vSPrimeNoZeroGrad_[l] = 0; + } + if(useDoubleInput_) { + origins2NoZeroGrad_.resize(noLayers_); + origins2PrimeNoZeroGrad_.resize(noLayers_); + vS2NoZeroGrad_.resize(noLayers_); + vS2PrimeNoZeroGrad_.resize(noLayers_); + for(unsigned int l = 0; l < noLayers_; ++l) { + origins2NoZeroGrad_[l] = 0; + origins2PrimeNoZeroGrad_[l] = 0; + vS2NoZeroGrad_[l] = 0; + vS2PrimeNoZeroGrad_[l] = 0; + } + } + + // ----- Init Variables + unsigned int k = k_; + float oldLoss, minLoss, minTestLoss; + std::vector minCustomLoss; + unsigned int cptBlocked, iteration = 0; + auto initLoop = [&]() { + oldLoss = -1; + minLoss = std::numeric_limits::max(); + minTestLoss = std::numeric_limits::max(); + cptBlocked = 0; + iteration = 0; + }; + initLoop(); + int convWinSize = 5; + int noConverged = 0, noConvergedToGet = 10; + std::vector gapLosses, gapTestLosses; + std::vector> gapCustomLosses; + float windowLoss = 0; + + double assignmentTime = 0.0, updateTime = 0.0, projectionTime = 0.0, + lossTime = 0.0; + + int bestIteration = 0; + std::vector bestVSTensor, bestVSPrimeTensor, bestVS2Tensor, + bestVS2PrimeTensor; + std::vector> bestOrigins, bestOriginsPrime, + bestOrigins2, bestOrigins2Prime; + std::vector> bestAlphasInit; + std::vector>> bestRecs, bestRecs2; + double bestTime = 0; + + auto printLoss = [this, trainTestSplitted]( + float loss, float testLoss, std::vector &customLoss, + int iterationT, int iterationTT, double time, + const debug::Priority &priority = debug::Priority::INFO) { + std::stringstream prefix; + prefix << (priority == debug::Priority::VERBOSE ? "Iter " : "Best "); + std::stringstream ssBestLoss; + ssBestLoss << prefix.str() << "loss is " << loss << " (iteration " + << iterationT << " / " << iterationTT << ") at time " << time; + printMsg(ssBestLoss.str(), priority); + if(trainTestSplitted) { + ssBestLoss.str(""); + ssBestLoss << prefix.str() << "test loss is " << testLoss; + printMsg(ssBestLoss.str(), priority); + } + printCustomLosses(customLoss, prefix, priority); + }; + + auto copyAlphas = [this](std::vector> &alphas, + std::vector &indexes) { + for(unsigned int ind = 0; ind < indexes.size(); ++ind) { + unsigned int i = indexes[ind]; + for(unsigned int j = 0; j < alphas[i].size(); ++j) + mtu::copyTensor(alphas[i][j], allAlphas_[i][j]); + } + }; + + // ----- Algorithm + Timer t_alg; + bool converged = false; + while(not converged) { + if(iteration % iterationGap_ == 0) { + std::stringstream ss; + ss << "Iteration " << iteration; + printMsg(debug::Separator::L2); + printMsg(ss.str()); + } + + bool forwardReset = false; + std::vector iterationLosses, iterationTestLosses; + std::vector> iterationCustomLosses; + if(noBatch != 1) { + std::vector indexes = trainIndexes; + std::shuffle(std::begin(indexes), std::end(indexes), rng); + for(unsigned int i = 0; i < allIndexes.size(); ++i) { + unsigned int noProcessed = batchSize * i; + unsigned int remaining = trainIndexes.size() - noProcessed; + unsigned int size = std::min(batchSize, remaining); + allIndexes[i].resize(size); + for(unsigned int j = 0; j < size; ++j) + allIndexes[i][j] = indexes[noProcessed + j]; + } + } + for(unsigned batchNum = 0; batchNum < allIndexes.size(); ++batchNum) { + auto &indexes = allIndexes[batchNum]; + + // --- Forward + Timer t_assignment; + std::vector> outs, outs2; + std::vector> bestAlphas; + std::vector>> layersOuts, + layersOuts2; + std::vector>> + matchings, matchings2; + float loss, testLoss; + bool computeError = true; + forwardReset + = forwardStep(torchTrees, torchTrees2, indexes, isTrain, k, allAlphas_, + computeError, outs, outs2, bestAlphas, layersOuts, + layersOuts2, matchings, matchings2, loss, testLoss); + if(forwardReset) + break; + copyAlphas(bestAlphas, indexes); + assignmentTime += t_assignment.getElapsedTime(); + + // --- Loss + Timer t_loss; + gapLosses.emplace_back(loss); + iterationLosses.emplace_back(loss); + if(noBatch == 1 and trainTestSplitted) { + gapTestLosses.emplace_back(testLoss); + iterationTestLosses.emplace_back(testLoss); + } + std::vector torchCustomLoss; + computeCustomLosses(layersOuts, layersOuts2, bestAlphas, indexes, isTrain, + iteration, gapCustomLosses, iterationCustomLosses, + torchCustomLoss); + lossTime += t_loss.getElapsedTime(); + + // --- Backward + Timer t_update; + backwardStep(torchTrees, outs, matchings, torchTrees2, outs2, matchings2, + bestAlphas, *optimizer, indexes, isTrain, torchCustomLoss); + updateTime += t_update.getElapsedTime(); + + // --- Projection + Timer t_projection; + projectionStep(); + projectionTime += t_projection.getElapsedTime(); + } // end batch + + if(noBatch != 1 and trainTestSplitted) { + std::vector> outs, outs2; + std::vector> bestAlphas; + std::vector>> layersOuts, + layersOuts2; + std::vector>> + matchings, matchings2; + float loss, testLoss; + bool computeError = true; + forwardStep(torchTrees, torchTrees2, testIndexes, isTrain, k, allAlphas_, + computeError, outs, outs2, bestAlphas, layersOuts, + layersOuts2, matchings, matchings2, loss, testLoss); + copyAlphas(bestAlphas, testIndexes); + gapTestLosses.emplace_back(testLoss); + iterationTestLosses.emplace_back(testLoss); + std::vector torchCustomLoss; + computeCustomLosses(layersOuts, layersOuts2, bestAlphas, testIndexes, + isTrain, iteration, gapCustomLosses, + iterationCustomLosses, torchCustomLoss); + } + + if(forwardReset) { + // TODO better manage reset by init new parameters and start again for + // example (should not happen anymore) + printWrn("Forward reset!"); + break; + } + + // --- Get iteration loss + // TODO an approximation is made here if batch size != 1 because the + // iteration loss that will be printed will not be exact, we need to do a + // forward step and compute loss with the whole dataset + float iterationLoss = torch::tensor(iterationLosses).mean().item(); + float iterationTestLoss + = torch::tensor(iterationTestLosses).mean().item(); + std::vector iterationCustomLoss; + float iterationTotalLoss = computeIterationTotalLoss( + iterationLoss, iterationCustomLosses, iterationCustomLoss); + printLoss(iterationTotalLoss, iterationTestLoss, iterationCustomLoss, + iteration, iteration, + t_alg.getElapsedTime() - t_allVectorCopy_time_, + debug::Priority::VERBOSE); + + // --- Update best parameters + bool isBest = false; + if(not trainTestSplitted) + isBest = isBestLoss(iterationTotalLoss, minLoss, cptBlocked); + else { + // TODO generalize these lines when accuracy is not the metric computed or + // evaluated + if(minCustomLoss.empty()) + isBest = true; + else { + float minusAcc = -iterationCustomLoss[1]; + float minMinusAcc = -minCustomLoss[1]; + isBest = isBestLoss(minusAcc, minMinusAcc, cptBlocked); + } + } + if(isBest) { + Timer t_copy; + bestIteration = iteration; + copyParams(bestOrigins, bestOriginsPrime, bestVSTensor, bestVSPrimeTensor, + bestOrigins2, bestOrigins2Prime, bestVS2Tensor, + bestVS2PrimeTensor, allAlphas_, bestAlphasInit, true); + copyCustomParams(true); + copyParams(recs_, bestRecs); + copyParams(recs2_, bestRecs2); + t_allVectorCopy_time_ += t_copy.getElapsedTime(); + bestTime = t_alg.getElapsedTime() - t_allVectorCopy_time_; + minCustomLoss = iterationCustomLoss; + if(trainTestSplitted) { + minLoss = iterationTotalLoss; + minTestLoss = iterationTestLoss; + } + printLoss(minLoss, minTestLoss, minCustomLoss, bestIteration, iteration, + bestTime, debug::Priority::DETAIL); + } + + // --- Convergence + windowLoss += iterationTotalLoss; + if((iteration + 1) % convWinSize == 0) { + windowLoss /= convWinSize; + converged = convergenceStep(windowLoss, oldLoss, minLoss, cptBlocked); + windowLoss = 0; + if(converged) { + ++noConverged; + } else + noConverged = 0; + converged = noConverged >= noConvergedToGet; + if(converged and iteration < minIteration_) + printMsg("convergence is detected but iteration < minIteration_", + debug::Priority::DETAIL); + if(iteration < minIteration_) + converged = false; + if(converged) + break; + } + + // --- Print + if(iteration % iterationGap_ == 0) { + printMsg("Assignment", 1, assignmentTime, threadNumber_); + printMsg("Loss", 1, lossTime, threadNumber_); + printMsg("Update", 1, updateTime, threadNumber_); + printMsg("Projection", 1, projectionTime, threadNumber_); + assignmentTime = 0.0; + lossTime = 0.0; + updateTime = 0.0; + projectionTime = 0.0; + float loss = torch::tensor(gapLosses).mean().item(); + gapLosses.clear(); + float testLoss = torch::tensor(gapTestLosses).mean().item(); + gapTestLosses.clear(); + if(trainTestSplitted) { + std::stringstream ss; + ss << "Test Loss = " << testLoss; + printMsg(ss.str()); + } + printGapLoss(loss, gapCustomLosses); + + // Verify grad and big values (testing) + for(unsigned int l = 0; l < noLayers_; ++l) { + std::stringstream ss; + if(originsNoZeroGrad_[l] != 0) + ss << originsNoZeroGrad_[l] << " originsNoZeroGrad_[" << l << "]" + << std::endl; + if(originsPrimeNoZeroGrad_[l] != 0) + ss << originsPrimeNoZeroGrad_[l] << " originsPrimeNoZeroGrad_[" << l + << "]" << std::endl; + if(vSNoZeroGrad_[l] != 0) + ss << vSNoZeroGrad_[l] << " vSNoZeroGrad_[" << l << "]" << std::endl; + if(vSPrimeNoZeroGrad_[l] != 0) + ss << vSPrimeNoZeroGrad_[l] << " vSPrimeNoZeroGrad_[" << l << "]" + << std::endl; + originsNoZeroGrad_[l] = 0; + originsPrimeNoZeroGrad_[l] = 0; + vSNoZeroGrad_[l] = 0; + vSPrimeNoZeroGrad_[l] = 0; + if(useDoubleInput_) { + if(origins2NoZeroGrad_[l] != 0) + ss << origins2NoZeroGrad_[l] << " origins2NoZeroGrad_[" << l << "]" + << std::endl; + if(origins2PrimeNoZeroGrad_[l] != 0) + ss << origins2PrimeNoZeroGrad_[l] << " origins2PrimeNoZeroGrad_[" + << l << "]" << std::endl; + if(vS2NoZeroGrad_[l] != 0) + ss << vS2NoZeroGrad_[l] << " vS2NoZeroGrad_[" << l << "]" + << std::endl; + if(vS2PrimeNoZeroGrad_[l] != 0) + ss << vS2PrimeNoZeroGrad_[l] << " vS2PrimeNoZeroGrad_[" << l << "]" + << std::endl; + origins2NoZeroGrad_[l] = 0; + origins2PrimeNoZeroGrad_[l] = 0; + vS2NoZeroGrad_[l] = 0; + vS2PrimeNoZeroGrad_[l] = 0; + } + if(isTreeHasBigValues( + layers_[l].getOrigin().mTree, bigValuesThreshold_)) + ss << "origins_[" << l << "] has big values!" << std::endl; + if(isTreeHasBigValues( + layers_[l].getOriginPrime().mTree, bigValuesThreshold_)) + ss << "originsPrime_[" << l << "] has big values!" << std::endl; + if(ss.rdbuf()->in_avail() != 0) + printMsg(ss.str(), debug::Priority::DETAIL); + } + } + + ++iteration; + if(maxIteration_ != 0 and iteration >= maxIteration_) { + printMsg("iteration >= maxIteration_", debug::Priority::DETAIL); + break; + } + } + printMsg(debug::Separator::L2); + printLoss( + minLoss, minTestLoss, minCustomLoss, bestIteration, iteration, bestTime); + printMsg(debug::Separator::L2); + bestLoss_ = (trainTestSplitted ? minTestLoss : minLoss); + + Timer t_copy; + copyParams(bestOrigins, bestOriginsPrime, bestVSTensor, bestVSPrimeTensor, + bestOrigins2, bestOrigins2Prime, bestVS2Tensor, bestVS2PrimeTensor, + bestAlphasInit, allAlphas_, false); + copyCustomParams(false); + copyParams(bestRecs, recs_); + copyParams(bestRecs2, recs2_); + t_allVectorCopy_time_ += t_copy.getElapsedTime(); + printMsg("Copy time", 1, t_allVectorCopy_time_, threadNumber_); +} + +// --------------------------------------------------------------------------- +// --- End Functions +// --------------------------------------------------------------------------- +void ttk::MergeTreeNeuralNetwork::computeTrackingInformation( + unsigned int endLayer) { + originsMatchings_.resize(endLayer); +#ifdef TTK_ENABLE_OPENMP +#pragma omp parallel for schedule(dynamic) \ + num_threads(this->threadNumber_) if(parallelize_) +#endif + for(unsigned int l = 0; l < endLayer; ++l) { + auto &tree1 + = (l == 0 ? layers_[0].getOrigin() : layers_[l - 1].getOriginPrime()); + auto &tree2 + = (l == 0 ? layers_[0].getOriginPrime() : layers_[l].getOriginPrime()); + bool isCalled = true; + float distance; + computeOneDistance(tree1.mTree, tree2.mTree, originsMatchings_[l], + distance, isCalled, useDoubleInput_); + } + + // Data matchings + ++endLayer; + dataMatchings_.resize(endLayer); + for(unsigned int l = 0; l < endLayer; ++l) { + dataMatchings_[l].resize(recs_.size()); +#ifdef TTK_ENABLE_OPENMP +#pragma omp parallel for schedule(dynamic) \ + num_threads(this->threadNumber_) if(parallelize_) +#endif + for(unsigned int i = 0; i < recs_.size(); ++i) { + bool isCalled = true; + float distance; + auto &origin + = (l == 0 ? layers_[0].getOrigin() : layers_[l - 1].getOriginPrime()); + computeOneDistance(origin.mTree, recs_[i][l].mTree, + dataMatchings_[l][i], distance, isCalled, + useDoubleInput_); + } + } + + // Reconst matchings + reconstMatchings_.resize(recs_.size()); +#ifdef TTK_ENABLE_OPENMP +#pragma omp parallel for schedule(dynamic) \ + num_threads(this->threadNumber_) if(parallelize_) +#endif + for(unsigned int i = 0; i < recs_.size(); ++i) { + bool isCalled = true; + float distance; + auto l = recs_[i].size() - 1; + computeOneDistance(recs_[i][0].mTree, recs_[i][l].mTree, + reconstMatchings_[i], distance, isCalled, + useDoubleInput_); + } +} + +void ttk::MergeTreeNeuralNetwork::computeCorrelationMatrix( + std::vector> &trees, unsigned int layer) { + std::vector> allTs; + auto noGeod = allAlphas_[0][layer].sizes()[0]; + allTs.resize(noGeod); + for(unsigned int i = 0; i < noGeod; ++i) { + allTs[i].resize(allAlphas_.size()); + for(unsigned int j = 0; j < allAlphas_.size(); ++j) + allTs[i][j] = allAlphas_[j][layer][i].item(); + } + computeBranchesCorrelationMatrix( + layers_[0].getOrigin().mTree, trees, dataMatchings_[0], allTs, + branchesCorrelationMatrix_, persCorrelationMatrix_); +} + +void ttk::MergeTreeNeuralNetwork::createScaledAlphas( + std::vector> &alphas, + std::vector> &scaledAlphas) { + scaledAlphas.clear(); + scaledAlphas.resize( + alphas.size(), std::vector(alphas[0].size())); + for(unsigned int l = 0; l < alphas[0].size(); ++l) { + torch::Tensor scale = layers_[l].getVSTensor().pow(2).sum(0).sqrt(); +#ifdef TTK_ENABLE_OPENMP +#pragma omp parallel for schedule(dynamic) \ + num_threads(this->threadNumber_) if(parallelize_) +#endif + for(unsigned int i = 0; i < alphas.size(); ++i) { + scaledAlphas[i][l] = alphas[i][l] * scale.reshape({-1, 1}); + } + } +} + +void ttk::MergeTreeNeuralNetwork::createScaledAlphas() { + createScaledAlphas(allAlphas_, allScaledAlphas_); +} + +void ttk::MergeTreeNeuralNetwork::createActivatedAlphas() { + allActAlphas_ = allAlphas_; +#ifdef TTK_ENABLE_OPENMP +#pragma omp parallel for schedule(dynamic) \ + num_threads(this->threadNumber_) if(parallelize_) +#endif + for(unsigned int i = 0; i < allActAlphas_.size(); ++i) + for(unsigned int j = 0; j < allActAlphas_[i].size(); ++j) + allActAlphas_[i][j] = activation(allActAlphas_[i][j]); + createScaledAlphas(allActAlphas_, allActScaledAlphas_); +} + +// --------------------------------------------------------------------------- +// --- Utils +// --------------------------------------------------------------------------- +void ttk::MergeTreeNeuralNetwork::copyParams( + std::vector> &origins, + std::vector> &originsPrime, + std::vector &vS, + std::vector &vSPrime, + std::vector> &origins2, + std::vector> &origins2Prime, + std::vector &vS2, + std::vector &vS2Prime, + std::vector> &srcAlphas, + std::vector> &dstAlphas, + bool get) { + dstAlphas.resize(srcAlphas.size(), std::vector(noLayers_)); + if(get) { + origins.resize(noLayers_); + originsPrime.resize(noLayers_); + vS.resize(noLayers_); + vSPrime.resize(noLayers_); + if(useDoubleInput_) { + origins2.resize(noLayers_); + origins2Prime.resize(noLayers_); + vS2.resize(noLayers_); + vS2Prime.resize(noLayers_); + } + } + for(unsigned int l = 0; l < noLayers_; ++l) { + layers_[l].copyParams(origins[l], originsPrime[l], vS[l], vSPrime[l], + origins2[l], origins2Prime[l], vS2[l], vS2Prime[l], + get); +#ifdef TTK_ENABLE_OPENMP +#pragma omp parallel for schedule(dynamic) \ + num_threads(this->threadNumber_) if(parallelize_) +#endif + for(unsigned int i = 0; i < srcAlphas.size(); ++i) + mtu::copyTensor(srcAlphas[i][l], dstAlphas[i][l]); + } +} + +void ttk::MergeTreeNeuralNetwork::copyParams( + std::vector>> &src, + std::vector>> &dst) { + dst.resize(src.size()); +#ifdef TTK_ENABLE_OPENMP +#pragma omp parallel for schedule(dynamic) \ + num_threads(this->threadNumber_) if(parallelize_) +#endif + for(unsigned int i = 0; i < src.size(); ++i) { + dst[i].resize(src[i].size()); + for(unsigned int j = 0; j < src[i].size(); ++j) + mtu::copyTorchMergeTree(src[i][j], dst[i][j]); + } +} + +void ttk::MergeTreeNeuralNetwork::getAlphasTensor( + std::vector> &alphas, + std::vector &indexes, + std::vector &toGet, + unsigned int layerIndex, + torch::Tensor &alphasOut) { + unsigned int beg = 0; + while(not toGet[indexes[beg]]) + ++beg; + alphasOut = alphas[indexes[beg]][layerIndex].transpose(0, 1); + for(unsigned int ind = beg + 1; ind < indexes.size(); ++ind) { + if(not toGet[indexes[ind]]) + continue; + alphasOut = torch::cat( + {alphasOut, alphas[indexes[ind]][layerIndex].transpose(0, 1)}); + } +} + +void ttk::MergeTreeNeuralNetwork::getAlphasTensor( + std::vector> &alphas, + std::vector &indexes, + unsigned int layerIndex, + torch::Tensor &alphasOut) { + std::vector toGet(indexes.size(), true); + getAlphasTensor(alphas, indexes, toGet, layerIndex, alphasOut); +} + +void ttk::MergeTreeNeuralNetwork::getAlphasTensor( + std::vector> &alphas, + unsigned int layerIndex, + torch::Tensor &alphasOut) { + std::vector indexes(alphas.size()); + std::iota(indexes.begin(), indexes.end(), 0); + getAlphasTensor(alphas, indexes, layerIndex, alphasOut); +} + +// --------------------------------------------------------------------------- +// --- Testing +// --------------------------------------------------------------------------- +void ttk::MergeTreeNeuralNetwork::checkZeroGrad(unsigned int l, + bool checkOutputBasis) { + if(not layers_[l].getOrigin().tensor.grad().defined() + or not layers_[l].getOrigin().tensor.grad().count_nonzero().is_nonzero()) + ++originsNoZeroGrad_[l]; + if(not layers_[l].getVSTensor().grad().defined() + or not layers_[l].getVSTensor().grad().count_nonzero().is_nonzero()) + ++vSNoZeroGrad_[l]; + if(checkOutputBasis) { + if(not layers_[l].getOriginPrime().tensor.grad().defined() + or not layers_[l] + .getOriginPrime() + .tensor.grad() + .count_nonzero() + .is_nonzero()) + ++originsPrimeNoZeroGrad_[l]; + if(not layers_[l].getVSPrimeTensor().grad().defined() + or not layers_[l].getVSPrimeTensor().grad().count_nonzero().is_nonzero()) + ++vSPrimeNoZeroGrad_[l]; + } + if(useDoubleInput_) { + if(not layers_[l].getOrigin2().tensor.grad().defined() + or not layers_[l] + .getOrigin2() + .tensor.grad() + .count_nonzero() + .is_nonzero()) + ++origins2NoZeroGrad_[l]; + if(not layers_[l].getVS2Tensor().grad().defined() + or not layers_[l].getVS2Tensor().grad().count_nonzero().is_nonzero()) + ++vS2NoZeroGrad_[l]; + if(checkOutputBasis) { + if(not layers_[l].getOrigin2Prime().tensor.grad().defined() + or not layers_[l] + .getOrigin2Prime() + .tensor.grad() + .count_nonzero() + .is_nonzero()) + ++origins2PrimeNoZeroGrad_[l]; + if(not layers_[l].getVS2PrimeTensor().grad().defined() + or not layers_[l] + .getVS2PrimeTensor() + .grad() + .count_nonzero() + .is_nonzero()) + ++vS2PrimeNoZeroGrad_[l]; + } + } +} + +bool ttk::MergeTreeNeuralNetwork::isTreeHasBigValues( + const ftm::MergeTree &mTree, float threshold) { + bool found = false; + for(unsigned int n = 0; n < mTree.tree.getNumberOfNodes(); ++n) { + if(mTree.tree.isNodeAlone(n)) + continue; + auto birthDeath = mTree.tree.template getBirthDeath(n); + if(std::abs(std::get<0>(birthDeath)) > threshold + or std::abs(std::get<1>(birthDeath)) > threshold) { + found = true; + break; + } + } + return found; +} +#endif + +// --------------------------------------------------------------------------- +// --- Main Functions +// --------------------------------------------------------------------------- +void ttk::MergeTreeNeuralNetwork::execute( + std::vector> &trees, + std::vector> &trees2) { +#ifndef TTK_ENABLE_TORCH + TTK_FORCE_USE(trees); + TTK_FORCE_USE(trees2); + printErr("This module requires Torch."); +#else +#ifdef TTK_ENABLE_OPENMP + int ompNested = omp_get_nested(); + omp_set_nested(1); +#endif + // makeExponentialExample(trees, trees2); + + // --- Preprocessing + Timer t_preprocess; + preprocessingTrees(trees, treesNodeCorr_); + if(trees2.size() != 0) + preprocessingTrees(trees2, trees2NodeCorr_); + printMsg("Preprocessing", 1, t_preprocess.getElapsedTime(), threadNumber_); + useDoubleInput_ = (trees2.size() != 0); + + // --- Fit neural network + Timer t_total; + fit(trees, trees2); + auto totalTime = t_total.getElapsedTime() - t_allVectorCopy_time_; + printMsg(debug::Separator::L1); + printMsg("Total time", 1, totalTime, threadNumber_); + + // --- End functions + Timer t_end; + createScaledAlphas(); + createActivatedAlphas(); + executeEndFunction(trees, trees2); + printMsg("End functions", 1, t_end.getElapsedTime(), threadNumber_); + + // --- Postprocessing + if(createOutput_) { +#ifdef TTK_ENABLE_OPENMP +#pragma omp parallel for schedule(dynamic) \ + num_threads(this->threadNumber_) if(parallelize_) +#endif + for(unsigned int i = 0; i < trees.size(); ++i) + postprocessingPipeline(&(trees[i].tree)); +#ifdef TTK_ENABLE_OPENMP +#pragma omp parallel for schedule(dynamic) \ + num_threads(this->threadNumber_) if(parallelize_) +#endif + for(unsigned int i = 0; i < trees2.size(); ++i) + postprocessingPipeline(&(trees2[i].tree)); + + originsCopy_.resize(layers_.size()); + originsPrimeCopy_.resize(layers_.size()); +#ifdef TTK_ENABLE_OPENMP +#pragma omp parallel for schedule(dynamic) \ + num_threads(this->threadNumber_) if(parallelize_) +#endif + for(unsigned int l = 0; l < layers_.size(); ++l) { + mtu::copyTorchMergeTree(layers_[l].getOrigin(), originsCopy_[l]); + mtu::copyTorchMergeTree( + layers_[l].getOriginPrime(), originsPrimeCopy_[l]); + } +#ifdef TTK_ENABLE_OPENMP +#pragma omp parallel for schedule(dynamic) \ + num_threads(this->threadNumber_) if(parallelize_) +#endif + for(unsigned int l = 0; l < originsCopy_.size(); ++l) { + fillMergeTreeStructure(originsCopy_[l]); + postprocessingPipeline(&(originsCopy_[l].mTree.tree)); + fillMergeTreeStructure(originsPrimeCopy_[l]); + postprocessingPipeline(&(originsPrimeCopy_[l].mTree.tree)); + } +#ifdef TTK_ENABLE_OPENMP +#pragma omp parallel for schedule(dynamic) \ + num_threads(this->threadNumber_) if(parallelize_) +#endif + for(unsigned int i = 0; i < recs_.size(); ++i) { + for(unsigned int j = 0; j < recs_[i].size(); ++j) { + fixTreePrecisionScalars(recs_[i][j].mTree); + postprocessingPipeline(&(recs_[i][j].mTree.tree)); + } + } + } + + if(not isPersistenceDiagram_) { + for(unsigned int l = 0; l < originsMatchings_.size(); ++l) { + auto &tree1 = (l == 0 ? originsCopy_[0] : originsPrimeCopy_[l - 1]); + auto &tree2 = (l == 0 ? originsPrimeCopy_[0] : originsPrimeCopy_[l]); + convertBranchDecompositionMatching( + &(tree1.mTree.tree), &(tree2.mTree.tree), originsMatchings_[l]); + } +#ifdef TTK_ENABLE_OPENMP +#pragma omp parallel for schedule(dynamic) \ + num_threads(this->threadNumber_) if(parallelize_) +#endif + for(unsigned int i = 0; i < recs_.size(); ++i) { + for(unsigned int l = 0; l < dataMatchings_.size(); ++l) { + auto &origin = (l == 0 ? originsCopy_[0] : originsPrimeCopy_[l - 1]); + convertBranchDecompositionMatching(&(origin.mTree.tree), + &(recs_[i][l].mTree.tree), + dataMatchings_[l][i]); + } + } + for(unsigned int i = 0; i < reconstMatchings_.size(); ++i) { + auto l = recs_[i].size() - 1; + convertBranchDecompositionMatching(&(recs_[i][0].mTree.tree), + &(recs_[i][l].mTree.tree), + reconstMatchings_[i]); + } + } +#ifdef TTK_ENABLE_OPENMP + omp_set_nested(ompNested); +#endif +#endif +} diff --git a/core/base/mergeTreeNeuralNetwork/MergeTreeNeuralNetwork.h b/core/base/mergeTreeNeuralNetwork/MergeTreeNeuralNetwork.h new file mode 100644 index 0000000000..9c41c3391b --- /dev/null +++ b/core/base/mergeTreeNeuralNetwork/MergeTreeNeuralNetwork.h @@ -0,0 +1,880 @@ +/// \ingroup base +/// \class ttk::MergeTreeNeuralNetwork +/// \author Mathieu Pont +/// \date 2023. +/// +/// This module defines the %MergeTreeNeuralNetwork class providing functions to +/// define a neural network able to process merge trees or persistence diagrams. +/// +/// This is an abstract class, to implement a derived class you need to define +/// the following functions: +/// +/// - "initParameters" : initializes the network, like the different layers. +/// A strategy to initialize a sequence of layers consist in initializing a +/// first layer with the input topological representations, then pass them +/// through this layer to initialize the second one and so on. +/// A simple loop (whose number of iterations corresponds to the number of +/// layers) can do this using the "initInputBasis" and the "initOutputBasis" +/// function, then the "initGetReconstructed" function to pass the +/// representations to the layer that just have been initialized. +/// +/// - "initResetOutputBasis" : please refer to the documentation of this +/// function. +/// +/// - "customInit" : called just before the "initStep" function (that call the +/// "initParameters" function), is is intended to do custom operations depending +/// on the architecture and the optimization you want to define (such as +/// computing the distance matrix for the metric loss in the autoencoder case). +/// This function can be empty. +/// +/// - "backwardStep" : optimizes the parameters of the network. A loss using +/// differentiable torch operations should be computed using the output of some +/// layers of the network (usually the output of the last layer but it can also +/// be any other layers). You can either use the torch coordinates of the +/// representations in a layer or their torch tensors to compute the loss. Then +/// use the torch::Tensor "backward" function to compute the gradients, then the +/// torch::optim::Optimizer "step" function to update the model parameters, +/// after this, the torch::optim::Optimizer "zero_grad" function should be +/// called to reset the gradient. If you have correctly created the +/// MergeTreeNeuralLayer objects (refer to the corresponding class +/// documentation), basically by calling the "requires_grad" function (with true +/// as parameter) for each layer after initializing its parameters, then +/// everything would be automatically handled to backpropagate the gradient of +/// the loss through the layers. +/// +/// - "addCustomParameters" : adds custom parameters to the parameter list that +/// will be given to the optimizer, depending on the architecture and the +/// optimization you want to define (such as the centroids for the cluster loss +/// in the autoencoder case). This function can be empty. +/// +/// - "computeOneLoss" : computes the loss for one input topological +/// representation, the loss computed here does not need to be differentiable +/// because it will only be used to print it in the console and to check +/// convergence of the method (i.e. it is not called in the "backwardStep" +/// function). +/// +/// - "computeCustomLosses" : computes custom losses for all input topological +/// representations depending on the architecture and the optimization you want +/// to define (such as the clustering and the metric loss in the autoendoer +/// case). Like "computeOneLoss", the losses do not need to be differentiable +/// because they will only be used to print them in the console and to check +/// convergence of the method. This function can be empty. +/// +/// - "computeIterationTotalLoss" +/// +/// - "printCustomLosses" : prints the custom loss depending on the architecture +/// and the optimization you want to define (such as the clustering and the +/// metric loss in the autoendoer case). This function can be empty. +/// +/// - "printGapLoss" : prints the "gap" loss, the aggregated loss over +/// iterationGap_ iterations. +/// +/// - "copyCustomParams" : copy the custom parameters (for instance to save them +/// when a better loss is reached during the optimization) depending on the +/// architecture and the optimization you want to define (such as the centroids +/// for the cluster loss in the autoencoder case). This function can be empty. +/// +/// - "executeEndFunction" : does specific operations at the end of the +/// optimization, such as calling the "computeTrackingInformation" and the +/// "computeCorrelationMatrix" functions. +/// +/// \b Related \b publication: \n +/// "Wasserstein Auto-Encoders of Merge Trees (and Persistence Diagrams)" \n +/// Mathieu Pont, Julien Tierny.\n +/// IEEE Transactions on Visualization and Computer Graphics, 2023 +/// + +#pragma once + +// ttk common includes +#include +#include +#include +#include +#include + +#ifdef TTK_ENABLE_TORCH +#include +#endif + +namespace ttk { + + /** + * The MergeTreeNeuralNetwork class provides methods to define a neural + * network able to process merge trees or persistence diagrams. + */ + class MergeTreeNeuralNetwork : virtual public Debug, + public MergeTreeNeuralBase { + + protected: + // Minimum number of iterations to run + unsigned int minIteration_ = 0; + // Maximum number of iterations to run + unsigned int maxIteration_ = 0; + // Number of iterations between each print + unsigned int iterationGap_ = 100; + // Batch size between 0 and 1 + double batchSize_ = 1; + // Optimizer + // 0 : Adam + // 1 : Stochastic Gradient Descent + // 2 : RMS Prop + int optimizer_ = 0; + // Gradient Step/Learning rate + double gradientStepSize_ = 0.1; + // Adam parameters + double beta1_ = 0.9; + double beta2_ = 0.999; + // Number of initializations to do (the better will be kept) + unsigned int noInit_ = 4; + // If activation functions should be used during the initialization + bool activateOutputInit_ = false; + // Limit in the size of the origin in output basis as a percentage of the + // input total number of nodes + double originPrimeSizePercent_ = 15; + // Proportion between the train set and the validation/test set + double trainTestSplit_ = 1.0; + // If the input data should be shuffled before splitted + bool shuffleBeforeSplit_ = true; + + bool createOutput_ = true; + +#ifdef TTK_ENABLE_TORCH + // Model optimized parameters + std::vector layers_; + + // Filled by the algorithm + std::vector> allAlphas_, allScaledAlphas_, + allActAlphas_, allActScaledAlphas_; + std::vector>> recs_, recs2_; + + std::vector> originsCopy_, originsPrimeCopy_; +#endif + + // Tracking matchings + std::vector>> + originsMatchings_, reconstMatchings_, customMatchings_; + std::vector< + std::vector>>> + dataMatchings_; + + // Filled by the algorithm + unsigned noLayers_; + float bestLoss_; + std::vector clusterAsgn_; + std::vector>> + baryMatchings_L0_, baryMatchings2_L0_; + std::vector inputToBaryDistances_L0_; + std::vector> branchesCorrelationMatrix_, + persCorrelationMatrix_; + + // Testing + double t_allVectorCopy_time_ = 0.0; + std::vector originsNoZeroGrad_, originsPrimeNoZeroGrad_, + vSNoZeroGrad_, vSPrimeNoZeroGrad_, origins2NoZeroGrad_, + origins2PrimeNoZeroGrad_, vS2NoZeroGrad_, vS2PrimeNoZeroGrad_; + + public: + MergeTreeNeuralNetwork(); + +#ifdef TTK_ENABLE_TORCH + // ----------------------------------------------------------------------- + // --- Init + // ----------------------------------------------------------------------- + /** + * @brief Initialize an input basis. + * + * @param[in] l index of the layer to initialize. + * @param[in] layerNoAxes number of axes in the basis. + * @param[in] trees input trees. + * @param[in] trees2 same as trees but for second input, if any (i.e. when + * join trees and split trees are given). + * @param[in] isTrain vector stating for each input tree if it is in the + * training set (true) or not (false). + * @param[out] allAlphasInit best estimation of the coordinates of each + * input tree in the basis. + */ + void initInputBasis(unsigned int l, + unsigned int layerNoAxes, + std::vector> &trees, + std::vector> &trees2, + std::vector &isTrain, + std::vector> &allAlphasInit); + + /** + * @brief Initialize an output basis. + * + * @param[in] l index of the layer to initialize. + * @param[in] layerOriginPrimeSizePercent the maximum number of nodes + * allowed for the origin as a percentage of the total number of nodes + * in the input trees (0 for no effect). + * @param[in] trees input trees. + * @param[in] trees2 same as trees but for second input, if any (i.e. when + * join trees and split trees are given). + * @param[in] isTrain vector stating for each input tree if it is in the + * training set (true) or not (false). + */ + void initOutputBasis(unsigned int l, + double layerOriginPrimeSizePercent, + std::vector> &trees, + std::vector> &trees2, + std::vector &isTrain); + + /** + * @brief It is possible for the output basis of a layer to be badly + * initialized such a merge tree has no nodes after being passed through it. + * This function is intended to initialize a new output basis to avoid this + * problem. It is called in the initGetReconstructed function during the + * initialization procedure. + * A simple call to initOutputBasis could be enough. + * + * This is a pure virtual function to define in derived classes. + * + * @param[in] l index of the layer to initialize. + * @param[in] layerNoAxes number of axes in the basis. + * @param[in] layerOriginPrimeSizePercent the maximum number of nodes + * allowed for the origin as a percentage of the total number of nodes + * in the input trees (0 for no effect). + * @param[in] trees input trees. + * @param[in] trees2 same as trees but for second input, if any (i.e. when + * join trees and split trees are given). + * @param[in] isTrain vector stating for each input tree if it is in the + * training set (true) or not (false). + * + * @return true if it is not possible to initialize a new output basis, + * false otherwise. + */ + virtual bool + initResetOutputBasis(unsigned int l, + unsigned int layerNoAxes, + double layerOriginPrimeSizePercent, + std::vector> &trees, + std::vector> &trees2, + std::vector &isTrain) + = 0; + + /** + * @brief Pass trees through a layer that just have been initialized. + * + * @param[in] l index of the layer being initialized. + * @param[in] layerNoAxes number of axes in the basis. + * @param[in] layerOriginPrimeSizePercent the maximum number of nodes + * allowed for the origin as a percentage of the total number of nodes + * in the input trees (0 for no effect). + * @param[in] trees input trees. + * @param[in] trees2 same as trees but for second input, if any (i.e. when + * join trees and split trees are given). + * @param[in] isTrain vector stating for each input tree if it is in the + * training set (true) or not (false). + * @param[out] recs the input trees after being passed through the layer. + * @param[out] recs2 same as recs but for second input, if any (i.e. when + * join trees and split trees are given). + * @param[out] allAlphasInit best estimation of the coordinates of each + * input tree in the basis. + * + * @return true if one output merge tree has no nodes. + */ + bool initGetReconstructed( + unsigned int l, + unsigned int layerNoAxes, + double layerOriginPrimeSizePercent, + std::vector> &trees, + std::vector> &trees2, + std::vector &isTrain, + std::vector> &recs, + std::vector> &recs2, + std::vector> &allAlphasInit); + + /** + * @brief Initialize the parameters of the network. + * + * This is a pure virtual function to define in derived classes. + * + * @param[in] trees input trees. + * @param[in] trees2 same as trees but for second input, if any (i.e. when + * join trees and split trees are given). + * @param[in] isTrain vector stating for each input tree if it is in the + * training set (true) or not (false). + * @param[in] computeError boolean stating whether or not the initialization + * error should be computed. + * + * @return return an initialization error, a value stating how bad the + * initialization is (the higher the worst). + */ + virtual float + initParameters(std::vector> &trees, + std::vector> &trees2, + std::vector &isTrain, + bool computeError = false) + = 0; + + /** + * @brief Initialize the parameters of the network a specific number of + * times and keep the best one. + * + * @param[in] trees input trees. + * @param[in] trees2 same as trees but for second input, if any (i.e. when + * join trees and split trees are given). + * @param[in] isTrain vector stating for each input tree if it is in the + * training set (true) or not (false). + */ + void initStep(std::vector> &trees, + std::vector> &trees2, + std::vector &isTrain); + + /** + * @brief Pass all the parameters from this class to a MergeTreeNeuralLayer + * object. + * Should be called just after creating a layer and before using it. + * + * @param[in] layer layer to process. + */ + void passLayerParameters(MergeTreeNeuralLayer &layer); + + // ----------------------------------------------------------------------- + // --- Forward + // ----------------------------------------------------------------------- + /** + * @brief Pass one tree through all layers of the network. + * + * @param[in] tree input tree. + * @param[in] trees2 same as tree but for second input, if any (i.e. when + * join trees and split trees are given). + * @param[in] treeIndex index of the input tree in the ensemble. + * @param[in] k number of projection steps to do when estimating the + * coordinates in the input basis of the input merge tree. + * @param[in] alphasInit the initial coordinates (for each layer) to use + * when estimating the coordinates in the input basis of the input merge + * tree. + * @param[out] out the final output merge tree. + * @param[out] out2 same as out but for second input, if any (i.e. when join + * trees and split trees are given). + * @param[out] dataAlphas the best estimated coordinates of the input tree + * at each layer. + * @param[out] outs the output merge tree of each layer (except the last + * one). + * @param[out] outs2 same as outs but for second input, if any (i.e. when + * join trees and split trees are given). + * @param[in] train true if the input merge tree is in the training set + * (false if validation/testing set). + * + * @return true if the output merge tree of a layer has no nodes. + */ + bool forwardOneData(mtu::TorchMergeTree &tree, + mtu::TorchMergeTree &tree2, + unsigned int treeIndex, + unsigned int k, + std::vector &alphasInit, + mtu::TorchMergeTree &out, + mtu::TorchMergeTree &out2, + std::vector &dataAlphas, + std::vector> &outs, + std::vector> &outs2, + bool train = false); + + /** + * @brief Pass all trees through all layers of the network. + * + * @param[in] trees input trees. + * @param[in] trees2 same as trees but for second input, if any (i.e. when + * join trees and split trees are given). + * @param[in] indexes batch indexes of the input trees to process. + * @param[in] isTrain vector stating for each input tree if it is in the + * training set (true) or not (false). + * @param[in] k number of projection steps to do when estimating the + * coordinates in the input basis of the input merge tree. + * @param[in] allAlphasInit the initial coordinates for each input tree for + * each layer to use when estimating the coordinates in the input basis. + * @param[in] computeError true if the loss for each processed input tree + * should be computed, false otherwise. + * @param[out] outs the final output of each merge tree. + * @param[out] outs2 same as out but for second input, if any (i.e. when + * join trees and split trees are given). + * @param[out] bestAlphas the best estimated coordinates for each input tree + * at each layer. + * @param[out] layersOuts the output for each input merge tree of each layer + * (except the last one). + * @param[out] layersOuts2 same as layersOuts but for second input, if any + * (i.e. when join trees and split trees are given). + * @param[out] matchings stores a matching for each processed input tree if + * the loss involves an assignment problem. + * @param[out] matchings2 same as matchings but for second input, if any + * (i.e. when join trees and split trees are given). + * @param[out] loss the loss of the trees in the training set. + * @param[out] testLoss the loss of the trees not in the training set. + * + * @return true if an output merge tree of a layer has no nodes. + */ + bool forwardStep( + std::vector> &trees, + std::vector> &trees2, + std::vector &indexes, + std::vector &isTrain, + unsigned int k, + std::vector> &allAlphasInit, + bool computeError, + std::vector> &outs, + std::vector> &outs2, + std::vector> &bestAlphas, + std::vector>> &layersOuts, + std::vector>> &layersOuts2, + std::vector>> + &matchings, + std::vector>> + &matchings2, + float &loss, + float &testLoss); + + /** + * @brief Pass all trees through all layers of the network. + * + * @param[in] trees input trees. + * @param[in] trees2 same as trees but for second input, if any (i.e. when + * join trees and split trees are given). + * @param[in] indexes batch indexes of the input trees to process. + * @param[in] k number of projection steps to do when estimating the + * coordinates in the input basis of the input merge tree. + * @param[in] allAlphasInit the initial coordinates for each input tree for + * each layer to use when estimating the coordinates in the input basis. + * @param[in] computeError true if the loss for each processed input tree + * should be computed, false otherwise. + * @param[out] outs the final output of each merge tree. + * @param[out] outs2 same as out but for second input, if any (i.e. when + * join trees and split trees are given). + * @param[out] bestAlphas the best estimated coordinates for each input tree + * at each layer. + * @param[out] layersOuts the output for each input merge tree of each layer + * (except the last one). + * @param[out] layersOuts2 same as layersOuts but for second input, if any + * (i.e. when join trees and split trees are given). + * @param[out] matchings stores a matching for each processed input tree if + * the loss involves an assignment problem. + * @param[out] matchings2 same as matchings but for second input, if any + * (i.e. when join trees and split trees are given). + * @param[out] loss the loss of the trees in the training set. + * + * @return true if an output merge tree of a layer has no nodes. + */ + bool forwardStep( + std::vector> &trees, + std::vector> &trees2, + std::vector &indexes, + unsigned int k, + std::vector> &allAlphasInit, + bool computeError, + std::vector> &outs, + std::vector> &outs2, + std::vector> &bestAlphas, + std::vector>> &layersOuts, + std::vector>> &layersOuts2, + std::vector>> + &matchings, + std::vector>> + &matchings2, + float &loss); + + // ----------------------------------------------------------------------- + // --- Backward + // ----------------------------------------------------------------------- + /** + * @brief Updates the parameters of the network to minimize the error. + * + * This is a pure virtual function to define in derived classes. + * + * @param[in] trees input trees. + * @param[in] outs the final output of each merge tree. + * @param[in] matchings stores a matching for each processed input tree if + * the loss involves an assignment problem. + * @param[in] trees2 same as trees but for second input, if any (i.e. when + * join trees and split trees are given). + * @param[in] outs2 same as out but for second input, if any (i.e. when + * join trees and split trees are given). + * @param[in] matchings2 same as matchings but for second input, if any + * (i.e. when join trees and split trees are given). + * @param[in] alphas the best estimated coordinates for each input tree + * at each layer. + * @param[in] optimizer optimizer to use to modify the parameters. + * @param[in] indexes batch indexes of the input trees to process. + * @param[in] isTrain vector stating for each input tree if it is in the + * training set (true) or not (false). + * @param[in] torchCustomLoss custom losses that can be added to the + * optimization (such as the loss ensuring the preservation of the clusters + * or the distances in the autoencoder case). + * + * @return not used (false). + */ + virtual bool backwardStep( + std::vector> &trees, + std::vector> &outs, + std::vector>> + &matchings, + std::vector> &trees2, + std::vector> &outs2, + std::vector>> + &matchings2, + std::vector> &alphas, + torch::optim::Optimizer &optimizer, + std::vector &indexes, + std::vector &isTrain, + std::vector &torchCustomLoss) + = 0; + + // ----------------------------------------------------------------------- + // --- Projection + // ----------------------------------------------------------------------- + /** + * @brief Projection that ensures that the origins of the input and output + * bases of each layer respect the elder rule. + */ + void projectionStep(); + + // ----------------------------------------------------------------------- + // --- Convergence + // ----------------------------------------------------------------------- + /** + * @brief Computes the loss for one input tree. + * + * This is a pure virtual function to define in derived classes. + * + * @param[in] tree input tree. + * @param[in] out the final output of the input merge tree. + * @param[in] tree2 same as trees but for second input, if any (i.e. when + * join trees and split trees are given). + * @param[in] out2 same as out but for second input, if any (i.e. when + * join trees and split trees are given). + * @param[in] matchings stores a matching involving the input tree if + * the loss involves an assignment problem. + * @param[in] matchings2 same as matchings but for second input, if any + * (i.e. when join trees and split trees are given). + * @param[in] alphas the best estimated coordinates for the input tree + * at each layer. + * @param[in] treeIndex index of the input tree in the ensemble. + * + * @return loss value. + */ + virtual float computeOneLoss( + mtu::TorchMergeTree &tree, + mtu::TorchMergeTree &out, + mtu::TorchMergeTree &tree2, + mtu::TorchMergeTree &out2, + std::vector> &matching, + std::vector> &matching2, + std::vector &alphas, + unsigned int treeIndex) + = 0; + + /** + * @brief Tests if the current loss is the lowest one. + * + * @param[in] loss the current loss. + * @param[in,out] minLoss the minimum loss that was reached, will be updated + * if it is higher than the current loss. + * @param[in,out] cptBlocked value that will be reset to 0 if the current + * loss is better than the best one. + * + * @return true if the current loss is the lowest one. + */ + bool isBestLoss(float loss, float &minLoss, unsigned int &cptBlocked); + + /** + * @brief Tests if the optimization is done. + * + * @param[in] loss the current loss. + * @param[in,out] oldLoss the loss of the previous iteration, will be set to + * the current loss at the end of this function. + * @param[in] minLoss the minimum loss that was reached. + * @param[in,out] cptBlocked number of iterations during the minimum loss + * was not updated. + * + * @return true if the optimization is done. + */ + bool convergenceStep(float loss, + float &oldLoss, + float &minLoss, + unsigned int &cptBlocked); + + // ----------------------------------------------------------------------- + // --- Main Functions + // ----------------------------------------------------------------------- + /** + * @brief Custom operations that need to be done before starting the + * optimization. + * + * This is a pure virtual function to define in derived classes. + * + * @param[in] torchTrees the input trees. + * @param[in] torchTrees2 same as torchTrees but for second input, if any + * (i.e. when join trees and split trees are given). + */ + virtual void + customInit(std::vector> &torchTrees, + std::vector> &torchTrees2) + = 0; + + /** + * @brief This function adds parameters to the parameter list that will be + * given to the optimizer. The vector should NOT be reinitialized in this + * function, only use emplace_back. + * + * This is a pure virtual function to define in derived classes. + * + * @param[in,out] parameters list of parameters to optimize that will be + * given to the optimizer, should not be reset. + */ + virtual void addCustomParameters(std::vector ¶meters) + = 0; + + /** + * @brief Compute the custom losses (such as the metric or cluster loss in + * the autoencoder case). + * + * This is a pure virtual function to define in derived classes. + * + * @param[in] layersOuts the output for each input merge tree of each layer + * (except the last one). + * @param[in] layersOuts2 same as layersOuts but for second input, if any + * (i.e. when join trees and split trees are given). + * @param[in] bestAlphas the best estimated coordinates for each input tree + * at each layer. + * @param[in] indexes batch indexes of the input trees to process. + * @param[in] isTrain vector stating for each input tree if it is in the + * training set (true) or not (false). + * @param[in] iteration iteration number. + * @param[out] gapCustomLosses vector needing to be resized in order to have + * a size corresponding to the number of different custom losses and should + * be appended each custom loss computed here. + * @param[out] iterationCustomLosses vector needing to be resized in order + * to have a size corresponding to the number of different custom losses and + * should be appended each custom loss computed here. + * @param[out] torchCustomLoss vector needing to be resized in order to have + * a size corresponding to the number of different custom losses and should + * be appended each custom loss computed here. + */ + virtual void computeCustomLosses( + std::vector>> &layersOuts, + std::vector>> &layersOuts2, + std::vector> &bestAlphas, + std::vector &indexes, + std::vector &isTrain, + unsigned int iteration, + std::vector> &gapCustomLosses, + std::vector> &iterationCustomLosses, + std::vector &torchCustomLoss) + = 0; + + /** + * @brief Computes the total loss of the current iteration. + * + * This is a pure virtual function to define in derived classes. + * + * @param[in] iterationLoss loss of the current iteration. + * @param[in] iterationCustomLosses all the custom losses of this iteration, + * one for each different custom loss for each batch. + * @param[out] iterationCustomLoss a vector containing the aggregated value + * of each custom loss (usually the mean over each batch). + * + * @return the total loss of the current iteration. + */ + virtual float computeIterationTotalLoss( + float iterationLoss, + std::vector> &iterationCustomLosses, + std::vector &iterationCustomLoss) + = 0; + + /** + * @brief Print the custom losses. + * + * This is a pure virtual function to define in derived classes. + * + * @param[in] customLoss a vector containing the aggregated value + * of each custom loss (usually the mean over each batch). + * @param[in] prefix the prefix of each message. + * @param[in] priority the priority of the TTK message. + */ + virtual void printCustomLosses(std::vector &customLoss, + std::stringstream &prefix, + const debug::Priority &priority + = debug::Priority::INFO) + = 0; + + /** + * @brief Print the gap loss (the aggregated loss over iterationGap_ + * iterations). + * + * This is a pure virtual function to define in derived classes. + * + * @param[in] loss the loss to print. + * @param[in] gapCustomLosses the custom losses to print, a vector + * containing a vector for each custom loss. + */ + virtual void printGapLoss(float loss, + std::vector> &gapCustomLosses) + = 0; + + /** + * @brief Initiliazes the network and trains it. + * + * @param[in] trees the input trees. + * @param[in] trees2 same as torchTrees but for second input, if any (i.e. + * when join trees and split trees are given). + */ + void fit(std::vector> &trees, + std::vector> &trees2); + + // --------------------------------------------------------------------------- + // --- End Functions + // --------------------------------------------------------------------------- + /** + * @brief Computes the tracking information, i.e, the matchings between the + * origin of the output basis of two consecutive layers (from the first one + * to the one specified by the endLayer parameter), the matchings + * between the input representations and the origin of the input basis of + * the first layer, and the matching between the latter and the origin of + * the output basis of the first layer. + * + * @param[in] endLayer layer number to stop the computation. + */ + void computeTrackingInformation(unsigned int endLayer); + + /** + * @brief Computes the correlation matrix between the pairs of the input + * trees and the basis at the layer specified in parameter. The + * "computeTrackingInformation" should have been called before. + * + * @param[in] trees the input trees. + * @param[in] layer the layer at which the correlation should be computed. + */ + void computeCorrelationMatrix(std::vector> &trees, + unsigned int layer); + + /** + * @brief Scales the coordinates given in input by the norm of the basis at + * the corresponding layer. + * + * @param[in] alphas coordinates for each input topological representation + * for each layer. + * @param[out] scaledAlphas scaled coordinates. + */ + void + createScaledAlphas(std::vector> &alphas, + std::vector> &scaledAlphas); + + /** + * @brief Scales the coordinates of the input topological representations by + * the norm of the basis at the corresponding layer. + */ + void createScaledAlphas(); + + /** + * @brief Scales the activated coordinates (the coordinates passed through + * the activation function) of the input topological representations by the + * norm of the basis at the corresponding layer. + */ + void createActivatedAlphas(); + + // ----------------------------------------------------------------------- + // --- Utils + // ----------------------------------------------------------------------- + void copyParams(std::vector> &origins, + std::vector> &originsPrime, + std::vector &vS, + std::vector &vSPrime, + std::vector> &origins2, + std::vector> &origins2Prime, + std::vector &vS2, + std::vector &vS2Prime, + std::vector> &srcAlphas, + std::vector> &dstAlphas, + bool get); + + void copyParams(std::vector>> &src, + std::vector>> &dst); + + /** + * @brief Set/Get custom parameters. + * + * This is a pure virtual function to define in derived classes. + * + * @param[in] get if the internal parameters should set from a copy (true) + * or copied (false). + */ + virtual void copyCustomParams(bool get) = 0; + + /** + * @brief Construct a matrix as a torch tensor object containing all the + * coordinates of the input topological representations in a specific layer. + * + * @param[in] alphas all the coordinates of the input topological + * representations. + * @param[in] indexes indexes of the topological representations to get. + * @param[in] toGet boolean vector stating if a topological representation + * should be processed or not. + * @param[in] layerIndex index of the layer to process. + * @param[in] alphasOut output torch tensor. + */ + void getAlphasTensor(std::vector> &alphas, + std::vector &indexes, + std::vector &toGet, + unsigned int layerIndex, + torch::Tensor &alphasOut); + + /** + * @brief Construct a matrix as a torch tensor object containing all the + * coordinates of the input topological representations in a specific layer. + * + * @param[in] alphas all the coordinates of the input topological + * representations. + * @param[in] indexes indexes of the topological representations to get. + * @param[in] layerIndex index of the layer to process. + * @param[in] alphasOut output torch tensor. + */ + void getAlphasTensor(std::vector> &alphas, + std::vector &indexes, + unsigned int layerIndex, + torch::Tensor &alphasOut); + + /** + * @brief Construct a matrix as a torch tensor object containing all the + * coordinates of the input topological representations in a specific layer. + * + * @param[in] alphas all the coordinates of the input topological + * representations. + * @param[in] layerIndex index of the layer to process. + * @param[in] alphasOut output torch tensor. + */ + void getAlphasTensor(std::vector> &alphas, + unsigned int layerIndex, + torch::Tensor &alphasOut); + + // ----------------------------------------------------------------------- + // --- Testing + // ----------------------------------------------------------------------- + void checkZeroGrad(unsigned int l, bool checkOutputBasis = true); + + bool isTreeHasBigValues(const ftm::MergeTree &mTree, + float threshold = 10000); +#endif + + // --------------------------------------------------------------------------- + // --- Main Functions + // --------------------------------------------------------------------------- + /** + * @brief Specific operations that can be done at the end of the + * optimization. Like calling the "computeTrackingInformation" and the + * "computeCorrelationMatrix" functions. + * + * This is a pure virtual function to define in derived classes. + * + * @param[in] trees the input trees. + * @param[in] trees2 same as torchTrees but for second input, if any (i.e. + * when join trees and split trees are given). + */ + virtual void executeEndFunction(std::vector> &trees, + std::vector> &trees2) + = 0; + + void execute(std::vector> &trees, + std::vector> &trees2); + }; // MergeTreeNeuralNetwork class + +} // namespace ttk diff --git a/core/base/mergeTreeAutoencoder/MergeTreeTorchUtils.cpp b/core/base/mergeTreeNeuralNetwork/MergeTreeTorchUtils.cpp similarity index 85% rename from core/base/mergeTreeAutoencoder/MergeTreeTorchUtils.cpp rename to core/base/mergeTreeNeuralNetwork/MergeTreeTorchUtils.cpp index 249ca81213..888fe39d28 100644 --- a/core/base/mergeTreeAutoencoder/MergeTreeTorchUtils.cpp +++ b/core/base/mergeTreeNeuralNetwork/MergeTreeTorchUtils.cpp @@ -6,7 +6,7 @@ using namespace ttk; #ifdef TTK_ENABLE_TORCH using namespace torch::indexing; -void mtu::copyTensor(torch::Tensor &a, torch::Tensor &b) { +void mtu::copyTensor(const torch::Tensor &a, torch::Tensor &b) { b = a.detach().clone(); b.requires_grad_(a.requires_grad()); } @@ -19,8 +19,8 @@ void mtu::getDeltaProjTensor(torch::Tensor &diagTensor, deltaProjTensor = torch::cat({deltaProjTensor, deltaProjTensor}, 1); } -void mtu::dataReorderingGivenMatching(mtu::TorchMergeTree &tree, - mtu::TorchMergeTree &tree2, +void mtu::dataReorderingGivenMatching(const mtu::TorchMergeTree &tree, + const mtu::TorchMergeTree &tree2, torch::Tensor &tree1ProjIndexer, torch::Tensor &tree2ReorderingIndexes, torch::Tensor &tree2ReorderedTensor, @@ -30,12 +30,16 @@ void mtu::dataReorderingGivenMatching(mtu::TorchMergeTree &tree, bool doubleReordering) { // Reorder tree2 tensor torch::Tensor tree2DiagTensor = tree2.tensor.reshape({-1, 2}); - tree2ReorderedTensor = torch::cat({tree2DiagTensor, torch::zeros({1, 2})}); + auto zeros = torch::zeros( + {1, 2}, torch::TensorOptions().device(tree2DiagTensor.device())); + tree2ReorderedTensor = torch::cat({tree2DiagTensor, zeros}); tree2ReorderedTensor = tree2ReorderedTensor.index({tree2ReorderingIndexes}); // Create tree projection given matching torch::Tensor treeDiagTensor = tree.tensor.reshape({-1, 2}); getDeltaProjTensor(treeDiagTensor, tree2DeltaProjTensor); + if(!tree2DeltaProjTensor.device().is_cpu()) + tree1ProjIndexer = tree1ProjIndexer.to(tree2DeltaProjTensor.device()); tree2DeltaProjTensor = tree2DeltaProjTensor * tree1ProjIndexer; // Double reordering @@ -59,8 +63,8 @@ void mtu::dataReorderingGivenMatching(mtu::TorchMergeTree &tree, tree2DeltaProjTensor = tree2DeltaProjTensor.reshape({-1, 1}); } -void mtu::dataReorderingGivenMatching(mtu::TorchMergeTree &tree, - mtu::TorchMergeTree &tree2, +void mtu::dataReorderingGivenMatching(const mtu::TorchMergeTree &tree, + const mtu::TorchMergeTree &tree2, torch::Tensor &tree1ProjIndexer, torch::Tensor &tree2ReorderingIndexes, torch::Tensor &tree2ReorderedTensor, @@ -75,8 +79,8 @@ void mtu::dataReorderingGivenMatching(mtu::TorchMergeTree &tree, } void mtu::dataReorderingGivenMatching( - mtu::TorchMergeTree &tree, - mtu::TorchMergeTree &tree2, + const mtu::TorchMergeTree &tree, + const mtu::TorchMergeTree &tree2, std::vector> &matching, torch::Tensor &tree1ReorderedTensor, torch::Tensor &tree2ReorderedTensor, @@ -108,8 +112,8 @@ void mtu::dataReorderingGivenMatching( } void mtu::dataReorderingGivenMatching( - mtu::TorchMergeTree &tree, - mtu::TorchMergeTree &tree2, + const mtu::TorchMergeTree &tree, + const mtu::TorchMergeTree &tree2, std::vector> &matching, torch::Tensor &tree2ReorderedTensor) { torch::Tensor tree1ReorderedTensor; @@ -123,7 +127,8 @@ void mtu::meanBirthShift(torch::Tensor &diagTensor, torch::Tensor birthShiftValue = diagBaseTensor.index({Slice(), 0}).mean() - diagTensor.index({Slice(), 0}).mean(); torch::Tensor shiftTensor - = torch::full({diagTensor.sizes()[0], 2}, birthShiftValue.item()); + = torch::full({diagTensor.sizes()[0], 2}, birthShiftValue.item(), + torch::TensorOptions().device(diagTensor.device())); diagTensor.index_put_({None}, diagTensor + shiftTensor); } @@ -139,6 +144,8 @@ void mtu::meanBirthMaxPersShift(torch::Tensor &tensor, = (diagTensor.index({Slice(), 1}) - diagTensor.index({Slice(), 0})).max(); torch::Tensor shiftTensor = (baseMaxPers - maxPers) / 2.0; shiftTensor = torch::stack({-shiftTensor, shiftTensor}); + if(!diagTensor.device().is_cpu()) + shiftTensor = shiftTensor.to(diagTensor.device()); diagTensor.index_put_({None}, diagTensor + shiftTensor); // Shift to have same birth mean meanBirthShift(diagTensor, diagBaseTensor); @@ -158,10 +165,11 @@ void mtu::belowDiagonalPointsShift(torch::Tensor &tensor, = (goodPoints.index({Slice(), 1}) - goodPoints.index({Slice(), 0})) .median(); torch::Tensor shiftTensor - = (torch::full({badPoints.sizes()[0], 1}, pers.item()) - - badPoints.index({Slice(), 1}).reshape({-1, 1}) - + badPoints.index({Slice(), 0}).reshape({-1, 1})) - / 2.0; + = torch::full({badPoints.sizes()[0], 1}, pers.item(), + torch::TensorOptions().device(badPoints.device())); + shiftTensor = (shiftTensor - badPoints.index({Slice(), 1}).reshape({-1, 1}) + + badPoints.index({Slice(), 0}).reshape({-1, 1})) + / 2.0; shiftTensor = torch::cat({-shiftTensor, shiftTensor}, 1); badPoints = badPoints + shiftTensor; // Update tensor @@ -202,6 +210,8 @@ bool mtu::isThereMissingPairs(mtu::TorchMergeTree &interpolation) { - interTensor.reshape({-1, 2}).index({Slice(), 1})) > (maxPers * 0.001 / 100.0); torch::Tensor indexed = interTensor.reshape({-1, 2}).index({indexer}); - return indexed.sizes()[0] > interpolation.mTree.tree.getRealNumberOfNodes(); + bool isMissingPairs + = indexed.sizes()[0] > interpolation.mTree.tree.getRealNumberOfNodes(); + return isMissingPairs; } #endif diff --git a/core/base/mergeTreeAutoencoder/MergeTreeTorchUtils.h b/core/base/mergeTreeNeuralNetwork/MergeTreeTorchUtils.h similarity index 95% rename from core/base/mergeTreeAutoencoder/MergeTreeTorchUtils.h rename to core/base/mergeTreeNeuralNetwork/MergeTreeTorchUtils.h index fe25b1cde6..9bceec200c 100644 --- a/core/base/mergeTreeAutoencoder/MergeTreeTorchUtils.h +++ b/core/base/mergeTreeNeuralNetwork/MergeTreeTorchUtils.h @@ -28,7 +28,7 @@ namespace ttk { * @param[in] a input tensor. * @param[out] b copied output tensor. */ - void copyTensor(torch::Tensor &a, torch::Tensor &b); + void copyTensor(const torch::Tensor &a, torch::Tensor &b); template struct TorchMergeTree { @@ -65,8 +65,8 @@ namespace ttk { * second tree. * @param[in] doubleReordering choose to also reorder first tree. */ - void dataReorderingGivenMatching(mtu::TorchMergeTree &tree, - mtu::TorchMergeTree &tree2, + void dataReorderingGivenMatching(const mtu::TorchMergeTree &tree, + const mtu::TorchMergeTree &tree2, torch::Tensor &tree1ProjIndexer, torch::Tensor &tree2ReorderingIndexes, torch::Tensor &tree2ReorderedTensor, @@ -89,8 +89,8 @@ namespace ttk { * @param[out] tree2DeltaProjTensor tensor of the projected pairs on the * diagonal of the first tree in the second tree. */ - void dataReorderingGivenMatching(mtu::TorchMergeTree &tree, - mtu::TorchMergeTree &tree2, + void dataReorderingGivenMatching(const mtu::TorchMergeTree &tree, + const mtu::TorchMergeTree &tree2, torch::Tensor &tree1ProjIndexer, torch::Tensor &tree2ReorderingIndexes, torch::Tensor &tree2ReorderedTensor, @@ -107,8 +107,8 @@ namespace ttk { * @param[in] doubleReordering choose to also reorder first tree. */ void dataReorderingGivenMatching( - mtu::TorchMergeTree &tree, - mtu::TorchMergeTree &tree2, + const mtu::TorchMergeTree &tree, + const mtu::TorchMergeTree &tree2, std::vector> &matching, torch::Tensor &tree1ReorderedTensor, torch::Tensor &tree2ReorderedTensor, @@ -123,8 +123,8 @@ namespace ttk { * @param[out] tree2ReorderedTensor reordered torch tensor of second tree. */ void dataReorderingGivenMatching( - mtu::TorchMergeTree &tree, - mtu::TorchMergeTree &tree2, + const mtu::TorchMergeTree &tree, + const mtu::TorchMergeTree &tree2, std::vector> &matching, torch::Tensor &tree2ReorderedTensor); @@ -196,7 +196,7 @@ namespace ttk { * @param[out] out output copied torch merge tree. */ template - void copyTorchMergeTree(TorchMergeTree &tmTree, + void copyTorchMergeTree(const TorchMergeTree &tmTree, TorchMergeTree &out) { out.mTree = ftm::copyMergeTree(tmTree.mTree); copyTensor(tmTree.tensor, out.tensor); @@ -390,7 +390,6 @@ namespace ttk { bool normalized, ftm::MergeTree &mTreeOut) { std::vector &nodeCorr = tmt.nodeCorr; - torch::Tensor &tensor = tmt.tensor; std::vector &parentsOri = tmt.parentsOri; mTreeOut = ttk::ftm::copyMergeTree(tmt.mTree); @@ -405,6 +404,9 @@ namespace ttk { return true; bool isJT = tmt.mTree.tree.template isJoinTree(); + torch::Tensor tensor = tmt.tensor; + if(!tensor.device().is_cpu()) + tensor = tensor.cpu(); std::vector tensorVec( tensor.data_ptr(), tensor.data_ptr() + tensor.numel()); std::vector scalarsVector; @@ -538,8 +540,8 @@ namespace ttk { */ template void getTensorMatching( - TorchMergeTree &a, - TorchMergeTree &b, + const TorchMergeTree &a, + const TorchMergeTree &b, std::vector> &matching, std::vector &tensorMatching) { tensorMatching.clear(); @@ -562,8 +564,8 @@ namespace ttk { */ template void getInverseTensorMatching( - TorchMergeTree &a, - TorchMergeTree &b, + const TorchMergeTree &a, + const TorchMergeTree &b, std::vector> &matching, std::vector &tensorMatching) { std::vector> invMatching( diff --git a/core/base/mergeTreePrincipalGeodesics/MergeTreeAxesAlgorithmBase.cpp b/core/base/mergeTreePrincipalGeodesics/MergeTreeAxesAlgorithmBase.cpp index 4d94f7f838..5a50e8791f 100644 --- a/core/base/mergeTreePrincipalGeodesics/MergeTreeAxesAlgorithmBase.cpp +++ b/core/base/mergeTreePrincipalGeodesics/MergeTreeAxesAlgorithmBase.cpp @@ -1,12 +1,24 @@ #include -void ttk::MergeTreeAxesAlgorithmBase::reverseMatchingVector( - unsigned int noNodes, - std::vector &matchingVector, - std::vector &invMatchingVector) { - invMatchingVector.clear(); - invMatchingVector.resize(noNodes, std::numeric_limits::max()); - for(unsigned int i = 0; i < matchingVector.size(); ++i) - if(matchingVector[i] < noNodes) - invMatchingVector[matchingVector[i]] = i; +//---------------------------------------------------------------------------- +// Setter +//---------------------------------------------------------------------------- +void ttk::MergeTreeAxesAlgorithmBase::setDeterministic( + const bool deterministic) { + deterministic_ = deterministic; +} + +void ttk::MergeTreeAxesAlgorithmBase::setNumberOfProjectionSteps( + const unsigned int k) { + k_ = k; +} + +void ttk::MergeTreeAxesAlgorithmBase::setBarycenterSizeLimitPercent( + const double barycenterSizeLimitPercent) { + barycenterSizeLimitPercent_ = barycenterSizeLimitPercent; +} + +void ttk::MergeTreeAxesAlgorithmBase::setProbabilisticVectorsInit( + const bool probabilisticVectorsInit) { + probabilisticVectorsInit_ = probabilisticVectorsInit; } diff --git a/core/base/mergeTreePrincipalGeodesics/MergeTreeAxesAlgorithmBase.h b/core/base/mergeTreePrincipalGeodesics/MergeTreeAxesAlgorithmBase.h index c8865368cf..85afff4ad7 100644 --- a/core/base/mergeTreePrincipalGeodesics/MergeTreeAxesAlgorithmBase.h +++ b/core/base/mergeTreePrincipalGeodesics/MergeTreeAxesAlgorithmBase.h @@ -11,6 +11,7 @@ #pragma once #include +#include #include #include #include @@ -29,6 +30,7 @@ namespace ttk { unsigned int numberOfAxes_ = 2; unsigned int k_ = 16; double barycenterSizeLimitPercent_ = 20.0; + bool probabilisticVectorsInit_ = false; // Clean correspondence std::vector> trees2NodeCorr_; @@ -41,13 +43,24 @@ namespace ttk { // every msg } + //---------------------------------------------------------------------------- + // Setter + //---------------------------------------------------------------------------- + void setDeterministic(const bool deterministic); + + void setNumberOfProjectionSteps(const unsigned int k); + + void setBarycenterSizeLimitPercent(const double barycenterSizeLimitPercent); + + void setProbabilisticVectorsInit(const bool probabilisticVectorsInit); + //---------------------------------------------------------------------------- // Matching / Distance //---------------------------------------------------------------------------- template void computeOneDistance( - ftm::MergeTree &tree1, - ftm::MergeTree &tree2, + const ftm::MergeTree &tree1, + const ftm::MergeTree &tree2, std::vector> &matching, dataType &distance, bool isCalled = false, @@ -74,8 +87,8 @@ namespace ttk { } template - void computeOneDistance(ftm::MergeTree &tree1, - ftm::MergeTree &tree2, + void computeOneDistance(const ftm::MergeTree &tree1, + const ftm::MergeTree &tree2, dataType &distance, bool isCalled = false, bool useDoubleInput = false, @@ -98,7 +111,8 @@ namespace ttk { ftm::FTMTree_MT *treeTree = &(tree.tree); std::vector matchingVector; - getMatchingVector(barycenter, tree, matching, matchingVector); + ttk::axa::getMatchingVector( + barycenter, tree, matching, matchingVector); v.resize(barycenter.tree.getNumberOfNodes(), std::vector(2, 0)); for(unsigned int j = 0; j < barycenter.tree.getNumberOfNodes(); ++j) { @@ -218,12 +232,21 @@ namespace ttk { } // Sort all distances and their respective indexes - if(axeNumber != 0) - std::sort(distancesAndIndexes.begin(), distancesAndIndexes.end(), - [](const std::tuple &a, - const std::tuple &b) -> bool { - return (std::get<0>(a) > std::get<0>(b)); - }); + std::vector scores; + std::random_device rd; + std::default_random_engine generator(deterministic_ ? 0 : rd()); + if(axeNumber != 0) { + if(probabilisticVectorsInit_) { + scores.resize(distancesAndIndexes.size()); + for(unsigned int i = 0; i < distancesAndIndexes.size(); ++i) + scores[i] = std::get<0>(distancesAndIndexes[i]); + } else + std::sort(distancesAndIndexes.begin(), distancesAndIndexes.end(), + [](const std::tuple &a, + const std::tuple &b) -> bool { + return (std::get<0>(a) > std::get<0>(b)); + }); + } // Init vectors according farthest input // (repeat with the ith farthest until projection gives non null vector) @@ -282,10 +305,15 @@ namespace ttk { // Init next bestIndex if(not foundGoodIndex) { i += 1; - if(i < distancesAndIndexes.size()) - bestIndex = std::get<1>(distancesAndIndexes[i]); - else + if(i >= distancesAndIndexes.size()) bestIndex = -1; + else if(probabilisticVectorsInit_) { + scores[bestIndex] = 0; + std::discrete_distribution distribution( + scores.begin(), scores.end()); + bestIndex = distribution(generator); + } else + bestIndex = std::get<1>(distancesAndIndexes[i]); } // If newVector jump to the next valid bestIndex @@ -312,6 +340,8 @@ namespace ttk { std::vector &finalDistances, double barycenterSizeLimitPercent, unsigned int barycenterMaximumNumberOfPairs, + int barycenterInitIndex, + bool oneIter, bool useDoubleInput = false, bool isFirstInput = true) { MergeTreeBarycenter mergeTreeBary; @@ -324,9 +354,13 @@ namespace ttk { mergeTreeBary.setAssignmentSolver(assignmentSolverID_); mergeTreeBary.setThreadNumber(this->threadNumber_); mergeTreeBary.setDeterministic(deterministic_); + mergeTreeBary.setIsPersistenceDiagram(isPersistenceDiagram_); mergeTreeBary.setBarycenterSizeLimitPercent(barycenterSizeLimitPercent); mergeTreeBary.setBarycenterMaximumNumberOfPairs( barycenterMaximumNumberOfPairs); + mergeTreeBary.setBarycenterInitIndex(barycenterInitIndex); + if(oneIter) + mergeTreeBary.setBarycenterMaxIter(1); matchings.resize(trees.size()); mergeTreeBary.execute( @@ -334,6 +368,23 @@ namespace ttk { finalDistances = mergeTreeBary.getFinalDistances(); } + template + void computeOneBarycenter( + std::vector> &trees, + ftm::MergeTree &baryMergeTree, + std::vector>> + &matchings, + std::vector &finalDistances, + double barycenterSizeLimitPercent, + unsigned int barycenterMaximumNumberOfPairs, + bool useDoubleInput = false, + bool isFirstInput = true) { + computeOneBarycenter(trees, baryMergeTree, matchings, finalDistances, + barycenterSizeLimitPercent, + barycenterMaximumNumberOfPairs, -1, false, + useDoubleInput, isFirstInput); + } + template void computeOneBarycenter( std::vector> &trees, @@ -416,76 +467,6 @@ namespace ttk { //---------------------------------------------------------------------------- // Utils //---------------------------------------------------------------------------- - // v[i] contains the node in tree matched to the node i in barycenter - template - void getMatchingVector( - ftm::MergeTree &barycenter, - ftm::MergeTree &tree, - std::vector> &matchings, - std::vector &matchingVector) { - matchingVector.clear(); - matchingVector.resize(barycenter.tree.getNumberOfNodes(), - std::numeric_limits::max()); - for(unsigned int j = 0; j < matchings.size(); ++j) { - auto match0 = std::get<0>(matchings[j]); - auto match1 = std::get<1>(matchings[j]); - if(match0 < barycenter.tree.getNumberOfNodes() - and match1 < tree.tree.getNumberOfNodes()) - matchingVector[match0] = match1; - } - } - - // v[i] contains the node in barycenter matched to the node i in tree - template - void getInverseMatchingVector( - ftm::MergeTree &barycenter, - ftm::MergeTree &tree, - std::vector> &matchings, - std::vector &matchingVector) { - std::vector> invMatchings( - matchings.size()); - for(unsigned int i = 0; i < matchings.size(); ++i) - invMatchings[i] = std::make_tuple(std::get<1>(matchings[i]), - std::get<0>(matchings[i]), - std::get<2>(matchings[i])); - getMatchingVector(tree, barycenter, invMatchings, matchingVector); - } - - void reverseMatchingVector(unsigned int noNodes, - std::vector &matchingVector, - std::vector &invMatchingVector); - - template - void reverseMatchingVector(ftm::MergeTree &tree, - std::vector &matchingVector, - std::vector &invMatchingVector) { - reverseMatchingVector( - tree.tree.getNumberOfNodes(), matchingVector, invMatchingVector); - } - - // m[i][j] contains the node in trees[j] matched to the node i in the - // barycenter - template - void getMatchingMatrix( - ftm::MergeTree &barycenter, - std::vector> &trees, - std::vector>> - &matchings, - std::vector> &matchingMatrix) { - matchingMatrix.clear(); - matchingMatrix.resize( - barycenter.tree.getNumberOfNodes(), - std::vector( - trees.size(), std::numeric_limits::max())); - for(unsigned int i = 0; i < trees.size(); ++i) { - std::vector matchingVector; - getMatchingVector( - barycenter, trees[i], matchings[i], matchingVector); - for(unsigned int j = 0; j < matchingVector.size(); ++j) - matchingMatrix[j][i] = matchingVector[j]; - } - } - template std::tuple getParametrizedBirthDeath(ftm::FTMTree_MT *tree, ftm::idNode node) { @@ -498,7 +479,7 @@ namespace ttk { //---------------------------------------------------------------------------- template void computeBranchesCorrelationMatrix( - ftm::MergeTree &barycenter, + const ftm::MergeTree &barycenter, std::vector> &trees, std::vector>> &baryMatchings, @@ -512,7 +493,8 @@ namespace ttk { // m[i][j] contains the node in trees[j] matched to the node i in the // barycenter std::vector> matchingMatrix; - getMatchingMatrix(barycenter, trees, baryMatchings, matchingMatrix); + ttk::axa::getMatchingMatrix( + barycenter, trees, baryMatchings, matchingMatrix); std::queue queue; queue.emplace(barycenter.tree.getRoot()); diff --git a/core/base/mergeTreePrincipalGeodesics/MergeTreeAxesAlgorithmUtils.cpp b/core/base/mergeTreePrincipalGeodesics/MergeTreeAxesAlgorithmUtils.cpp index cf13a66590..3ffe313781 100644 --- a/core/base/mergeTreePrincipalGeodesics/MergeTreeAxesAlgorithmUtils.cpp +++ b/core/base/mergeTreePrincipalGeodesics/MergeTreeAxesAlgorithmUtils.cpp @@ -2,6 +2,20 @@ namespace ttk { namespace axa { + //---------------------------------------------------------------------------- + // Utils + //---------------------------------------------------------------------------- + void reverseMatchingVector(unsigned int noNodes, + std::vector &matchingVector, + std::vector &invMatchingVector) { + invMatchingVector.clear(); + invMatchingVector.resize( + noNodes, std::numeric_limits::max()); + for(unsigned int i = 0; i < matchingVector.size(); ++i) + if(matchingVector[i] < noNodes) + invMatchingVector[matchingVector[i]] = i; + } + //---------------------------------------------------------------------------- // Output Utils //---------------------------------------------------------------------------- diff --git a/core/base/mergeTreePrincipalGeodesics/MergeTreeAxesAlgorithmUtils.h b/core/base/mergeTreePrincipalGeodesics/MergeTreeAxesAlgorithmUtils.h index a1f1cd8585..d2758a82b7 100644 --- a/core/base/mergeTreePrincipalGeodesics/MergeTreeAxesAlgorithmUtils.h +++ b/core/base/mergeTreePrincipalGeodesics/MergeTreeAxesAlgorithmUtils.h @@ -1,9 +1,83 @@ #pragma once +#include #include namespace ttk { namespace axa { + //---------------------------------------------------------------------------- + // Utils + //---------------------------------------------------------------------------- + // v[i] contains the node in tree matched to the node i in barycenter + template + void getMatchingVector( + const ftm::MergeTree &barycenter, + const ftm::MergeTree &tree, + std::vector> &matchings, + std::vector &matchingVector) { + matchingVector.clear(); + matchingVector.resize(barycenter.tree.getNumberOfNodes(), + std::numeric_limits::max()); + for(unsigned int j = 0; j < matchings.size(); ++j) { + auto &match0 = std::get<0>(matchings[j]); + auto &match1 = std::get<1>(matchings[j]); + if(match0 < barycenter.tree.getNumberOfNodes() + and match1 < tree.tree.getNumberOfNodes()) + matchingVector[match0] = match1; + } + } + + // v[i] contains the node in barycenter matched to the node i in tree + template + void getInverseMatchingVector( + const ftm::MergeTree &barycenter, + const ftm::MergeTree &tree, + std::vector> &matchings, + std::vector &matchingVector) { + std::vector> invMatchings( + matchings.size()); + for(unsigned int i = 0; i < matchings.size(); ++i) + invMatchings[i] = std::make_tuple(std::get<1>(matchings[i]), + std::get<0>(matchings[i]), + std::get<2>(matchings[i])); + getMatchingVector(tree, barycenter, invMatchings, matchingVector); + } + + void reverseMatchingVector(unsigned int noNodes, + std::vector &matchingVector, + std::vector &invMatchingVector); + + template + void reverseMatchingVector(ftm::MergeTree &tree, + std::vector &matchingVector, + std::vector &invMatchingVector) { + reverseMatchingVector( + tree.tree.getNumberOfNodes(), matchingVector, invMatchingVector); + } + + // m[i][j] contains the node in trees[j] matched to the node i in the + // barycenter + template + void getMatchingMatrix( + const ftm::MergeTree &barycenter, + std::vector> &trees, + std::vector>> + &matchings, + std::vector> &matchingMatrix) { + matchingMatrix.clear(); + matchingMatrix.resize( + barycenter.tree.getNumberOfNodes(), + std::vector( + trees.size(), std::numeric_limits::max())); + for(unsigned int i = 0; i < trees.size(); ++i) { + std::vector matchingVector; + getMatchingVector( + barycenter, trees[i], matchings[i], matchingVector); + for(unsigned int j = 0; j < matchingVector.size(); ++j) + matchingMatrix[j][i] = matchingVector[j]; + } + } + //---------------------------------------------------------------------------- // Output Utils //---------------------------------------------------------------------------- diff --git a/core/base/mergeTreePrincipalGeodesics/MergeTreePrincipalGeodesics.h b/core/base/mergeTreePrincipalGeodesics/MergeTreePrincipalGeodesics.h index b5228ab3af..aaa24103fa 100644 --- a/core/base/mergeTreePrincipalGeodesics/MergeTreePrincipalGeodesics.h +++ b/core/base/mergeTreePrincipalGeodesics/MergeTreePrincipalGeodesics.h @@ -171,7 +171,8 @@ namespace ttk { if(extremityTree->getRealNumberOfNodes() != 0) { computeOneDistance(barycenter, extremity, matching, distance, true, useDoubleInput, isFirstInput); - getMatchingVector(barycenter, extremity, matching, matchingVector); + ttk::axa::getMatchingVector( + barycenter, extremity, matching, matchingVector); } else matchingVector.resize(barycenterTree->getNumberOfNodes(), std::numeric_limits::max()); @@ -545,7 +546,7 @@ namespace ttk { // Get matching matrix std::vector> matchingMatrix; - getMatchingMatrix(barycenter, trees, matchings, matchingMatrix); + ttk::axa::getMatchingMatrix(barycenter, trees, matchings, matchingMatrix); // Update for(unsigned int i = 0; i < barycenter.tree.getNumberOfNodes(); ++i) { diff --git a/core/vtk/ttkMergeTreeAutoencoder/ttk.module b/core/vtk/ttkMergeTreeAutoencoder/ttk.module index 3bffe68277..18ca77041c 100644 --- a/core/vtk/ttkMergeTreeAutoencoder/ttk.module +++ b/core/vtk/ttkMergeTreeAutoencoder/ttk.module @@ -2,10 +2,10 @@ NAME ttkMergeTreeAutoencoder SOURCES ttkMergeTreeAutoencoder.cpp - ttkMergeTreeAutoencoderUtils.cpp + ttkMergeTreeNeuralNetworkUtils.cpp HEADERS ttkMergeTreeAutoencoder.h - ttkMergeTreeAutoencoderUtils.h + ttkMergeTreeNeuralNetworkUtils.h DEPENDS mergeTreeAutoencoder ttkMergeTree diff --git a/core/vtk/ttkMergeTreeAutoencoder/ttkMergeTreeAutoencoder.cpp b/core/vtk/ttkMergeTreeAutoencoder/ttkMergeTreeAutoencoder.cpp index 09676e0f38..ec888dcabf 100644 --- a/core/vtk/ttkMergeTreeAutoencoder/ttkMergeTreeAutoencoder.cpp +++ b/core/vtk/ttkMergeTreeAutoencoder/ttkMergeTreeAutoencoder.cpp @@ -1,14 +1,13 @@ #include #include -#include +#include #include #include -#include - #include #include #include +#include #include #include #include @@ -237,279 +236,52 @@ int ttkMergeTreeAutoencoder::runOutput( // ------------------------------------------ // --- Tracking information // ------------------------------------------ - auto originsMatchingSize = originsMatchings_.size(); - std::vector> originsMatchingVectorT( - originsMatchingSize), - invOriginsMatchingVectorT = originsMatchingVectorT; - for(unsigned int l = 0; l < originsMatchingVectorT.size(); ++l) { - auto &tree1 = (l == 0 ? origins_[0] : originsPrime_[l - 1]); - auto &tree2 = (l == 0 ? originsPrime_[0] : originsPrime_[l]); - getMatchingVector(tree1.mTree, tree2.mTree, originsMatchings_[l], - originsMatchingVectorT[l]); - getInverseMatchingVector(tree1.mTree, tree2.mTree, originsMatchings_[l], - invOriginsMatchingVectorT[l]); - } + std::vector> originsMatchingVectorT, + invOriginsMatchingVectorT; + std::vector>> + invDataMatchingVectorT; + std::vector> invReconstMatchingVectorT; + ttk::wnn::makeMatchingVectors( + originsMatchings_, originsCopy_, originsPrimeCopy_, originsMatchingVectorT, + invOriginsMatchingVectorT, dataMatchings_, recs_, invDataMatchingVectorT, + reconstMatchings_, invReconstMatchingVectorT); + std::vector> originsMatchingVector; std::vector> originsPersPercent, originsPersDiff; std::vector originPersPercent, originPersDiff; std::vector originPersistenceOrder; - ttk::wae::computeTrackingInformation( - origins_, originsPrime_, originsMatchingVectorT, invOriginsMatchingVectorT, - isPersistenceDiagram_, originsMatchingVector, originsPersPercent, - originsPersDiff, originPersPercent, originPersDiff, originPersistenceOrder); - - std::vector>> - invDataMatchingVectorT(dataMatchings_.size()); - for(unsigned int l = 0; l < invDataMatchingVectorT.size(); ++l) { - invDataMatchingVectorT[l].resize(dataMatchings_[l].size()); - for(unsigned int i = 0; i < invDataMatchingVectorT[l].size(); ++i) - getInverseMatchingVector(origins_[l].mTree, recs_[i][l].mTree, - dataMatchings_[l][i], - invDataMatchingVectorT[l][i]); - } - std::vector> invReconstMatchingVectorT( - reconstMatchings_.size()); - for(unsigned int i = 0; i < invReconstMatchingVectorT.size(); ++i) { - auto l = recs_[i].size() - 1; - getInverseMatchingVector(recs_[i][0].mTree, recs_[i][l].mTree, - reconstMatchings_[i], - invReconstMatchingVectorT[i]); - } + ttk::wnn::computeTrackingInformation( + originsCopy_, originsPrimeCopy_, originsMatchingVectorT, + invOriginsMatchingVectorT, isPersistenceDiagram_, originsMatchingVector, + originsPersPercent, originsPersDiff, originPersPercent, originPersDiff, + originPersistenceOrder); // ------------------------------------------ // --- Data // ------------------------------------------ - output_data->SetNumberOfBlocks(1); - vtkSmartPointer data - = vtkSmartPointer::New(); - data->SetNumberOfBlocks(1); - vtkSmartPointer dataSeg - = vtkSmartPointer::New(); - dataSeg->SetNumberOfBlocks(recs_.size()); - bool outputSegmentation = !treesSegmentation.empty() and treesSegmentation[0]; - for(unsigned int l = 0; l < 1; ++l) { - vtkSmartPointer out_layer_i - = vtkSmartPointer::New(); - out_layer_i->SetNumberOfBlocks(recs_.size()); - std::vector *> trees(recs_.size()); - for(unsigned int i = 0; i < recs_.size(); ++i) - trees[i] = &(recs_[i][l].mTree); - - // Custom arrays - std::vector>>> - customIntArrays(recs_.size()); - std::vector>>> - customDoubleArrays(recs_.size()); - unsigned int lShift = 0; - ttk::wae::computeCustomArrays( - recs_, persCorrelationMatrix_, invDataMatchingVectorT, - invReconstMatchingVectorT, originsMatchingVector, originsMatchingVectorT, - originsPersPercent, originsPersDiff, originPersistenceOrder, l, lShift, - customIntArrays, customDoubleArrays); - - // Create output - ttk::wae::makeManyOutput(trees, treesNodes, treesNodeCorr_, out_layer_i, - customIntArrays, customDoubleArrays, - mixtureCoefficient_, isPersistenceDiagram_, - convertToDiagram_, this->debugLevel_); - if(outputSegmentation and l == 0) { - ttk::wae::makeManyOutput( - trees, treesNodes, treesNodeCorr_, treesSegmentation, dataSeg, - customIntArrays, customDoubleArrays, mixtureCoefficient_, - isPersistenceDiagram_, convertToDiagram_, this->debugLevel_); - } - data->SetBlock(l, out_layer_i); - std::stringstream ss; - ss << (l == 0 ? "Input" : "Layer") << l; - data->GetMetaData(l)->Set(vtkCompositeDataSet::NAME(), ss.str()); - } - output_data->SetBlock(0, data); - unsigned int num = 0; - output_data->GetMetaData(num)->Set( - vtkCompositeDataSet::NAME(), "layersTrees"); - if(outputSegmentation) - output_data->SetBlock(1, dataSeg); - vtkNew lossArray{}; - lossArray->SetName("Loss"); - lossArray->InsertNextTuple1(bestLoss_); - output_data->GetFieldData()->AddArray(lossArray); + ttk::wnn::makeDataOutput( + output_data, recs_, 1, treesSegmentation, persCorrelationMatrix_, + invDataMatchingVectorT, invReconstMatchingVectorT, originsMatchingVectorT, + originsMatchingVector, originsPersPercent, originsPersDiff, + originPersistenceOrder, treesNodes, treesNodeCorr_, bestLoss_, + mixtureCoefficient_, isPersistenceDiagram_, convertToDiagram_, + this->debugLevel_); // ------------------------------------------ // --- Origins // ------------------------------------------ - output_origins->SetNumberOfBlocks(2); - // Origins - vtkSmartPointer origins - = vtkSmartPointer::New(); - vtkSmartPointer originsP - = vtkSmartPointer::New(); - origins->SetNumberOfBlocks(noLayers_); - originsP->SetNumberOfBlocks(noLayers_); - std::vector *> trees(noLayers_); - std::vector>>> - customIntArrays(noLayers_); - std::vector>>> - customDoubleArrays(noLayers_); - for(unsigned int l = 0; l < noLayers_; ++l) { - trees[l] = &(origins_[l].mTree); - if(l == 0) { - std::string name2{"OriginPersPercent"}; - customDoubleArrays[l].emplace_back( - std::make_tuple(name2, originPersPercent)); - std::string name3{"OriginPersDiff"}; - customDoubleArrays[l].emplace_back( - std::make_tuple(name3, originPersDiff)); - std::string nameOrder{"OriginPersOrder"}; - customIntArrays[l].emplace_back( - std::make_tuple(nameOrder, originPersistenceOrder)); - } - } - ttk::wae::makeManyOutput(trees, origins, customIntArrays, customDoubleArrays, - mixtureCoefficient_, isPersistenceDiagram_, - convertToDiagram_, this->debugLevel_); - - customIntArrays.clear(); - customIntArrays.resize(noLayers_); - customDoubleArrays.clear(); - customDoubleArrays.resize(noLayers_); - for(unsigned int l = 0; l < noLayers_; ++l) { - trees[l] = &(originsPrime_[l].mTree); - if(l < originsMatchingVector.size()) { - std::vector customArrayMatching, - originPersOrder(trees[l]->tree.getNumberOfNodes(), -1); - for(unsigned int i = 0; i < originsMatchingVector[l].size(); ++i) { - customArrayMatching.emplace_back(originsMatchingVector[l][i]); - if(originsMatchingVector[l][i] < originPersistenceOrder.size()) - originPersOrder[i] - = originPersistenceOrder[originsMatchingVector[l][i]]; - } - std::string name{"OriginTrueNodeId"}; - customIntArrays[l].emplace_back( - std::make_tuple(name, customArrayMatching)); - std::string nameOrder{"OriginPersOrder"}; - customIntArrays[l].emplace_back( - std::make_tuple(nameOrder, originPersOrder)); - std::string name2{"OriginPersPercent"}; - customDoubleArrays[l].emplace_back( - std::make_tuple(name2, originsPersPercent[l])); - std::string name3{"OriginPersDiff"}; - customDoubleArrays[l].emplace_back( - std::make_tuple(name3, originsPersDiff[l])); - } - } - ttk::wae::makeManyOutput(trees, originsP, customIntArrays, customDoubleArrays, - mixtureCoefficient_, isPersistenceDiagram_, - convertToDiagram_, this->debugLevel_); - output_origins->SetBlock(0, origins); - output_origins->SetBlock(1, originsP); - // for(unsigned int l = 0; l < 2; ++l) { - for(unsigned int l = 0; l < noLayers_; ++l) { - if(l >= 2) - break; - std::stringstream ss; - ss << (l == 0 ? "InputOrigin" : "LayerOrigin") << l; - auto originsMetaData = origins->GetMetaData(l); - if(originsMetaData) - originsMetaData->Set(vtkCompositeDataSet::NAME(), ss.str()); - ss.str(""); - ss << (l == 0 ? "InputOriginPrime" : "LayerOriginPrime") << l; - auto originsPMetaData = originsP->GetMetaData(l); - if(originsPMetaData) - originsPMetaData->Set(vtkCompositeDataSet::NAME(), ss.str()); - } - num = 0; - output_origins->GetMetaData(num)->Set( - vtkCompositeDataSet::NAME(), "layersOrigins"); - num = 1; - output_origins->GetMetaData(num)->Set( - vtkCompositeDataSet::NAME(), "layersOriginsPrime"); + ttk::wnn::makeOriginsOutput( + output_origins, originsCopy_, originsPrimeCopy_, originPersPercent, + originPersDiff, originPersistenceOrder, originsMatchingVector, + originsPersPercent, originsPersDiff, mixtureCoefficient_, + isPersistenceDiagram_, convertToDiagram_, this->debugLevel_); // ------------------------------------------ // --- Coefficients // ------------------------------------------ - output_coef->SetNumberOfBlocks(allAlphas_[0].size()); - for(unsigned int l = 0; l < allAlphas_[0].size(); ++l) { - vtkSmartPointer coef_table = vtkSmartPointer::New(); - vtkNew treeIDArray{}; - treeIDArray->SetName("TreeID"); - treeIDArray->SetNumberOfTuples(inputTrees.size()); - for(unsigned int i = 0; i < inputTrees.size(); ++i) - treeIDArray->SetTuple1(i, i); - coef_table->AddColumn(treeIDArray); - auto noVec = allAlphas_[0][l].sizes()[0]; - for(unsigned int v = 0; v < noVec; ++v) { - // Alphas - vtkNew tArray{}; - std::string name = ttk::axa::getTableCoefficientName(noVec, v); - tArray->SetName(name.c_str()); - tArray->SetNumberOfTuples(allAlphas_.size()); - // Act Alphas - vtkNew actArray{}; - std::string actName = "Act" + name; - actArray->SetName(actName.c_str()); - actArray->SetNumberOfTuples(allAlphas_.size()); - // Scaled Alphas - vtkNew tArrayNorm{}; - std::string nameNorm = ttk::axa::getTableCoefficientNormName(noVec, v); - tArrayNorm->SetName(nameNorm.c_str()); - tArrayNorm->SetNumberOfTuples(allAlphas_.size()); - // Act Scaled Alphas - vtkNew actArrayNorm{}; - std::string actNameNorm = "Act" + nameNorm; - actArrayNorm->SetName(actNameNorm.c_str()); - actArrayNorm->SetNumberOfTuples(allAlphas_.size()); - // Fill Arrays - for(unsigned int i = 0; i < allAlphas_.size(); ++i) { - tArray->SetTuple1(i, allAlphas_[i][l][v].item()); - actArray->SetTuple1(i, allActAlphas_[i][l][v].item()); - tArrayNorm->SetTuple1(i, allScaledAlphas_[i][l][v].item()); - actArrayNorm->SetTuple1(i, allActScaledAlphas_[i][l][v].item()); - } - coef_table->AddColumn(tArray); - coef_table->AddColumn(actArray); - coef_table->AddColumn(tArrayNorm); - coef_table->AddColumn(actArrayNorm); - } - if(!clusterAsgn_.empty()) { - vtkNew clusterArray{}; - clusterArray->SetName("ClusterAssignment"); - clusterArray->SetNumberOfTuples(inputTrees.size()); - for(unsigned int i = 0; i < clusterAsgn_.size(); ++i) - clusterArray->SetTuple1(i, clusterAsgn_[i]); - coef_table->AddColumn(clusterArray); - } - if(l == 0) { - vtkNew treesNoNodesArray{}; - treesNoNodesArray->SetNumberOfTuples(recs_.size()); - treesNoNodesArray->SetName("treeNoNodes"); - for(unsigned int i = 0; i < recs_.size(); ++i) - treesNoNodesArray->SetTuple1( - i, recs_[i][0].mTree.tree.getNumberOfNodes()); - coef_table->AddColumn(treesNoNodesArray); - } - output_coef->SetBlock(l, coef_table); - std::stringstream ss; - ss << "Coef" << l; - output_coef->GetMetaData(l)->Set(vtkCompositeDataSet::NAME(), ss.str()); - } - - // Copy Field Data - // - aggregate input field data - for(unsigned int b = 0; b < inputTrees[0]->GetNumberOfBlocks(); ++b) { - vtkNew fd{}; - fd->CopyStructure(inputTrees[0]->GetBlock(b)->GetFieldData()); - fd->SetNumberOfTuples(inputTrees.size()); - for(size_t i = 0; i < inputTrees.size(); ++i) { - fd->SetTuple(i, 0, inputTrees[i]->GetBlock(b)->GetFieldData()); - } - - // - copy input field data to output row data - for(int i = 0; i < fd->GetNumberOfArrays(); ++i) { - auto array = fd->GetAbstractArray(i); - array->SetName(array->GetName()); - vtkTable::SafeDownCast(output_coef->GetBlock(0))->AddColumn(array); - } - } + ttk::wnn::makeCoefficientsOutput(output_coef, allAlphas_, allScaledAlphas_, + allActAlphas_, allActScaledAlphas_, + clusterAsgn_, recs_, inputTrees); // Field Data Input Parameters std::vector paramNames; @@ -537,58 +309,62 @@ int ttkMergeTreeAutoencoder::runOutput( for(unsigned int l = 0; l < dataMatchingVectorT.size(); ++l) { dataMatchingVectorT[l].resize(dataMatchings_[l].size()); for(unsigned int i = 0; i < dataMatchingVectorT[l].size(); ++i) { - auto &origin = (l == 0 ? origins_[0] : originsPrime_[l - 1]); - getMatchingVector(origin.mTree, recs_[i][l].mTree, dataMatchings_[l][i], - dataMatchingVectorT[l][i]); + auto &origin = (l == 0 ? originsCopy_[0] : originsPrimeCopy_[l - 1]); + ttk::axa::getMatchingVector(origin.mTree, recs_[i][l].mTree, + dataMatchings_[l][i], + dataMatchingVectorT[l][i]); } } output_vectors->SetNumberOfBlocks(2); vtkSmartPointer vectors = vtkSmartPointer::New(); - vectors->SetNumberOfBlocks(vSTensor_.size()); + vectors->SetNumberOfBlocks(noLayers_); vtkSmartPointer vectorsPrime = vtkSmartPointer::New(); - vectorsPrime->SetNumberOfBlocks(vSTensor_.size()); - for(unsigned int l = 0; l < vSTensor_.size(); ++l) { + vectorsPrime->SetNumberOfBlocks(noLayers_); + for(unsigned int l = 0; l < noLayers_; ++l) { vtkSmartPointer vectorsTable = vtkSmartPointer::New(); vtkSmartPointer vectorsPrimeTable = vtkSmartPointer::New(); - for(unsigned int v = 0; v < vSTensor_[l].sizes()[1]; ++v) { + for(unsigned int v = 0; v < layers_[l].getVSTensor().sizes()[1]; ++v) { // Vs vtkNew vectorArray{}; - std::string name - = ttk::axa::getTableVectorName(vSTensor_[l].sizes()[1], v, 0, 0, false); + std::string name = ttk::axa::getTableVectorName( + layers_[l].getVSTensor().sizes()[1], v, 0, 0, false); vectorArray->SetName(name.c_str()); - vectorArray->SetNumberOfTuples(vSTensor_[l].sizes()[0]); - for(unsigned int i = 0; i < vSTensor_[l].sizes()[0]; ++i) - vectorArray->SetTuple1(i, vSTensor_[l][i][v].item()); + vectorArray->SetNumberOfTuples(layers_[l].getVSTensor().sizes()[0]); + for(unsigned int i = 0; i < layers_[l].getVSTensor().sizes()[0]; ++i) + vectorArray->SetTuple1(i, layers_[l].getVSTensor()[i][v].item()); vectorsTable->AddColumn(vectorArray); // Vs Prime vtkNew vectorPrimeArray{}; - std::string name2 - = ttk::axa::getTableVectorName(vSTensor_[l].sizes()[1], v, 0, 0, false); + std::string name2 = ttk::axa::getTableVectorName( + layers_[l].getVSTensor().sizes()[1], v, 0, 0, false); vectorPrimeArray->SetName(name2.c_str()); - vectorPrimeArray->SetNumberOfTuples(vSPrimeTensor_[l].sizes()[0]); - for(unsigned int i = 0; i < vSPrimeTensor_[l].sizes()[0]; ++i) - vectorPrimeArray->SetTuple1(i, vSPrimeTensor_[l][i][v].item()); + vectorPrimeArray->SetNumberOfTuples( + layers_[l].getVSPrimeTensor().sizes()[0]); + for(unsigned int i = 0; i < layers_[l].getVSPrimeTensor().sizes()[0]; ++i) + vectorPrimeArray->SetTuple1( + i, layers_[l].getVSPrimeTensor()[i][v].item()); vectorsPrimeTable->AddColumn(vectorPrimeArray); } // Rev node corr vtkNew revNodeCorrArray{}; revNodeCorrArray->SetName("revNodeCorr"); - revNodeCorrArray->SetNumberOfTuples(vSTensor_[l].sizes()[0]); + revNodeCorrArray->SetNumberOfTuples(layers_[l].getVSTensor().sizes()[0]); std::vector revNodeCorr; - getReverseTorchNodeCorr(origins_[l], revNodeCorr); - for(unsigned int i = 0; i < vSTensor_[l].sizes()[0]; ++i) + getReverseTorchNodeCorr(originsCopy_[l], revNodeCorr); + for(unsigned int i = 0; i < layers_[l].getVSTensor().sizes()[0]; ++i) revNodeCorrArray->SetTuple1(i, revNodeCorr[i]); vectorsTable->AddColumn(revNodeCorrArray); // Rev node corr prime vtkNew revNodeCorrPrimeArray{}; - revNodeCorrPrimeArray->SetNumberOfTuples(vSPrimeTensor_[l].sizes()[0]); + revNodeCorrPrimeArray->SetNumberOfTuples( + layers_[l].getVSPrimeTensor().sizes()[0]); revNodeCorrPrimeArray->SetName("revNodeCorr"); std::vector revNodeCorrPrime; - getReverseTorchNodeCorr(originsPrime_[l], revNodeCorrPrime); - for(unsigned int i = 0; i < vSPrimeTensor_[l].sizes()[0]; ++i) + getReverseTorchNodeCorr(originsPrimeCopy_[l], revNodeCorrPrime); + for(unsigned int i = 0; i < layers_[l].getVSPrimeTensor().sizes()[0]; ++i) revNodeCorrPrimeArray->SetTuple1(i, revNodeCorrPrime[i]); vectorsPrimeTable->AddColumn(revNodeCorrPrimeArray); // Origins Matchings @@ -627,7 +403,7 @@ int ttkMergeTreeAutoencoder::runOutput( if(l < dataMatchingVectorT.size() - 1) addDataMatchingArray(vectorsPrimeTable, dataMatchingVectorT[l + 1]); // Reconst Matchings - if(l == vSTensor_.size() - 1) { + if(l == noLayers_ - 1) { for(unsigned int i = 0; i < invReconstMatchingVectorT.size(); ++i) { vtkNew matchingArray{}; matchingArray->SetNumberOfTuples(invReconstMatchingVectorT[i].size()); @@ -653,7 +429,7 @@ int ttkMergeTreeAutoencoder::runOutput( } output_vectors->SetBlock(0, vectors); output_vectors->SetBlock(1, vectorsPrime); - num = 0; + unsigned int num = 0; output_vectors->GetMetaData(num)->Set(vtkCompositeDataSet::NAME(), "Vectors"); num = 1; output_vectors->GetMetaData(num)->Set( diff --git a/core/vtk/ttkMergeTreeAutoencoder/ttkMergeTreeAutoencoder.h b/core/vtk/ttkMergeTreeAutoencoder/ttkMergeTreeAutoencoder.h index 656e643c18..9d4deb3961 100644 --- a/core/vtk/ttkMergeTreeAutoencoder/ttkMergeTreeAutoencoder.h +++ b/core/vtk/ttkMergeTreeAutoencoder/ttkMergeTreeAutoencoder.h @@ -99,15 +99,6 @@ class TTKMERGETREEAUTOENCODER_EXPORT ttkMergeTreeAutoencoder vtkAlgorithm::SetInputArrayToProcess(0, 2, 0, 6, name); } - void SetDoCompute(bool doCompute) { - doCompute_ = doCompute; - Modified(); - resetDataVisualization(); - } - bool GetDoCompute() { - return doCompute_; - } - void SetNormalizedWasserstein(bool nW) { normalizedWasserstein_ = nW; Modified(); diff --git a/core/vtk/ttkMergeTreeAutoencoder/ttkMergeTreeAutoencoderUtils.cpp b/core/vtk/ttkMergeTreeAutoencoder/ttkMergeTreeNeuralNetworkUtils.cpp similarity index 50% rename from core/vtk/ttkMergeTreeAutoencoder/ttkMergeTreeAutoencoderUtils.cpp rename to core/vtk/ttkMergeTreeAutoencoder/ttkMergeTreeNeuralNetworkUtils.cpp index a163934a14..b5767fc48a 100644 --- a/core/vtk/ttkMergeTreeAutoencoder/ttkMergeTreeAutoencoderUtils.cpp +++ b/core/vtk/ttkMergeTreeAutoencoder/ttkMergeTreeNeuralNetworkUtils.cpp @@ -1,9 +1,383 @@ #include -#include +#include + +#include +#include #ifdef TTK_ENABLE_TORCH namespace ttk { - namespace wae { + namespace wnn { + + void makeMatchingVectors( + std::vector>> + &originsMatchings, + std::vector> &originsCopy, + std::vector> &originsPrimeCopy, + std::vector> &originsMatchingVectorT, + std::vector> &invOriginsMatchingVectorT, + std::vector< + std::vector>>> + &dataMatchings, + std::vector>> &recs, + std::vector>> + &invDataMatchingVectorT, + std::vector>> + &reconstMatchings, + std::vector> &invReconstMatchingVectorT) { + originsMatchingVectorT.resize(originsMatchings.size()); + invOriginsMatchingVectorT = originsMatchingVectorT; + for(unsigned int l = 0; l < originsMatchingVectorT.size(); ++l) { + auto &tree1 = (l == 0 ? originsCopy[0] : originsPrimeCopy[l - 1]); + auto &tree2 = (l == 0 ? originsPrimeCopy[0] : originsPrimeCopy[l]); + ttk::axa::getMatchingVector(tree1.mTree, tree2.mTree, + originsMatchings[l], + originsMatchingVectorT[l]); + ttk::axa::getInverseMatchingVector(tree1.mTree, tree2.mTree, + originsMatchings[l], + invOriginsMatchingVectorT[l]); + } + + invDataMatchingVectorT.resize(dataMatchings.size()); + for(unsigned int l = 0; l < invDataMatchingVectorT.size(); ++l) { + invDataMatchingVectorT[l].resize(dataMatchings[l].size()); + for(unsigned int i = 0; i < invDataMatchingVectorT[l].size(); ++i) + ttk::axa::getInverseMatchingVector( + originsCopy[l].mTree, recs[i][l].mTree, dataMatchings[l][i], + invDataMatchingVectorT[l][i]); + } + invReconstMatchingVectorT.resize(reconstMatchings.size()); + for(unsigned int i = 0; i < invReconstMatchingVectorT.size(); ++i) { + auto l = recs[i].size() - 1; + ttk::axa::getInverseMatchingVector(recs[i][0].mTree, recs[i][l].mTree, + reconstMatchings[i], + invReconstMatchingVectorT[i]); + } + } + + void makeDataOutput( + vtkMultiBlockDataSet *output_data, + std::vector>> &recs, + unsigned int recSize, + std::vector &treesSegmentation, + std::vector> &persCorrelationMatrix, + std::vector>> + &invDataMatchingVectorT, + std::vector> &invReconstMatchingVectorT, + std::vector> &originsMatchingVectorT, + std::vector> &originsMatchingVector, + std::vector> &originsPersPercent, + std::vector> &originsPersDiff, + std::vector &originPersistenceOrder, + std::vector &treesNodes, + std::vector> &treesNodeCorr, + std::vector classId, + float bestLoss, + double mixtureCoefficient, + bool isPersistenceDiagram, + bool convertToDiagram, + int debugLevel) { + output_data->SetNumberOfBlocks(1); + vtkSmartPointer data + = vtkSmartPointer::New(); + data->SetNumberOfBlocks(1); + vtkSmartPointer dataSeg + = vtkSmartPointer::New(); + dataSeg->SetNumberOfBlocks(recs.size()); + bool outputSegmentation + = !treesSegmentation.empty() and treesSegmentation[0]; + for(unsigned int l = 0; l < recSize; ++l) { + vtkSmartPointer out_layer_i + = vtkSmartPointer::New(); + out_layer_i->SetNumberOfBlocks(recs.size()); + std::vector *> trees(recs.size()); + for(unsigned int i = 0; i < recs.size(); ++i) + trees[i] = &(recs[i][l].mTree); + + // Custom arrays + std::vector>>> + customIntArrays(recs.size()); + std::vector>>> + customDoubleArrays(recs.size()); + unsigned int lShift = 0; + ttk::wnn::computeCustomArrays( + recs, persCorrelationMatrix, invDataMatchingVectorT, + invReconstMatchingVectorT, originsMatchingVector, + originsMatchingVectorT, originsPersPercent, originsPersDiff, + originPersistenceOrder, l, lShift, customIntArrays, + customDoubleArrays); + if(!classId.empty()) { + for(unsigned int i = 0; i < recs.size(); ++i) + customIntArrays[i].emplace_back(std::make_tuple( + "ClassID", + std::vector( + recs[i][l].mTree.tree.getNumberOfNodes(), classId[i]))); + } + + // Create output + if(l == 0) + ttk::wnn::makeManyOutput( + trees, treesNodes, treesNodeCorr, out_layer_i, customIntArrays, + customDoubleArrays, mixtureCoefficient, isPersistenceDiagram, + convertToDiagram, debugLevel); + else + ttk::wnn::makeManyOutput(trees, out_layer_i, customIntArrays, + customDoubleArrays, mixtureCoefficient, + isPersistenceDiagram, convertToDiagram, + debugLevel); + if(outputSegmentation and l == 0) { + ttk::wnn::makeManyOutput( + trees, treesNodes, treesNodeCorr, treesSegmentation, dataSeg, + customIntArrays, customDoubleArrays, mixtureCoefficient, + isPersistenceDiagram, convertToDiagram, debugLevel); + } + data->SetBlock(l, out_layer_i); + std::stringstream ss; + ss << (l == 0 ? "Input" : "Layer") << l; + data->GetMetaData(l)->Set(vtkCompositeDataSet::NAME(), ss.str()); + } + output_data->SetBlock(0, data); + unsigned int num = 0; + output_data->GetMetaData(num)->Set( + vtkCompositeDataSet::NAME(), "layersTrees"); + if(outputSegmentation) + output_data->SetBlock(1, dataSeg); + vtkNew lossArray{}; + lossArray->SetName("Loss"); + lossArray->InsertNextTuple1(bestLoss); + output_data->GetFieldData()->AddArray(lossArray); + } + + void makeDataOutput( + vtkMultiBlockDataSet *output_data, + std::vector>> &recs, + unsigned int recSize, + std::vector &treesSegmentation, + std::vector> &persCorrelationMatrix, + std::vector>> + &invDataMatchingVectorT, + std::vector> &invReconstMatchingVectorT, + std::vector> &originsMatchingVectorT, + std::vector> &originsMatchingVector, + std::vector> &originsPersPercent, + std::vector> &originsPersDiff, + std::vector &originPersistenceOrder, + std::vector &treesNodes, + std::vector> &treesNodeCorr, + float bestLoss, + double mixtureCoefficient, + bool isPersistenceDiagram, + bool convertToDiagram, + int debugLevel) { + std::vector classId; + makeDataOutput(output_data, recs, recSize, treesSegmentation, + persCorrelationMatrix, invDataMatchingVectorT, + invReconstMatchingVectorT, originsMatchingVectorT, + originsMatchingVector, originsPersPercent, originsPersDiff, + originPersistenceOrder, treesNodes, treesNodeCorr, classId, + bestLoss, mixtureCoefficient, isPersistenceDiagram, + convertToDiagram, debugLevel); + } + + void makeOriginsOutput( + vtkMultiBlockDataSet *output_origins, + std::vector> &originsCopy, + std::vector> &originsPrimeCopy, + std::vector &originPersPercent, + std::vector &originPersDiff, + std::vector &originPersistenceOrder, + std::vector> &originsMatchingVector, + std::vector> &originsPersPercent, + std::vector> &originsPersDiff, + double mixtureCoefficient, + bool isPersistenceDiagram, + bool convertToDiagram, + int debugLevel) { + unsigned int noLayers = originsCopy.size(); + + output_origins->SetNumberOfBlocks(2); + // Origins + vtkSmartPointer origins + = vtkSmartPointer::New(); + vtkSmartPointer originsP + = vtkSmartPointer::New(); + origins->SetNumberOfBlocks(noLayers); + originsP->SetNumberOfBlocks(noLayers); + std::vector *> trees(noLayers); + std::vector>>> + customIntArrays(noLayers); + std::vector>>> + customDoubleArrays(noLayers); + for(unsigned int l = 0; l < noLayers; ++l) { + trees[l] = &(originsCopy[l].mTree); + if(l == 0) { + std::string name2{"OriginPersPercent"}; + customDoubleArrays[l].emplace_back( + std::make_tuple(name2, originPersPercent)); + std::string name3{"OriginPersDiff"}; + customDoubleArrays[l].emplace_back( + std::make_tuple(name3, originPersDiff)); + std::string nameOrder{"OriginPersOrder"}; + customIntArrays[l].emplace_back( + std::make_tuple(nameOrder, originPersistenceOrder)); + } + } + ttk::wnn::makeManyOutput( + trees, origins, customIntArrays, customDoubleArrays, mixtureCoefficient, + isPersistenceDiagram, convertToDiagram, debugLevel); + + customIntArrays.clear(); + customIntArrays.resize(noLayers); + customDoubleArrays.clear(); + customDoubleArrays.resize(noLayers); + for(unsigned int l = 0; l < noLayers; ++l) { + trees[l] = &(originsPrimeCopy[l].mTree); + if(l < originsMatchingVector.size()) { + std::vector customArrayMatching, + originPersOrder(trees[l]->tree.getNumberOfNodes(), -1); + for(unsigned int i = 0; i < originsMatchingVector[l].size(); ++i) { + customArrayMatching.emplace_back(originsMatchingVector[l][i]); + if(originsMatchingVector[l][i] < originPersistenceOrder.size()) + originPersOrder[i] + = originPersistenceOrder[originsMatchingVector[l][i]]; + } + std::string name{"OriginTrueNodeId"}; + customIntArrays[l].emplace_back( + std::make_tuple(name, customArrayMatching)); + std::string nameOrder{"OriginPersOrder"}; + customIntArrays[l].emplace_back( + std::make_tuple(nameOrder, originPersOrder)); + std::string name2{"OriginPersPercent"}; + customDoubleArrays[l].emplace_back( + std::make_tuple(name2, originsPersPercent[l])); + std::string name3{"OriginPersDiff"}; + customDoubleArrays[l].emplace_back( + std::make_tuple(name3, originsPersDiff[l])); + } + } + ttk::wnn::makeManyOutput( + trees, originsP, customIntArrays, customDoubleArrays, + mixtureCoefficient, isPersistenceDiagram, convertToDiagram, debugLevel); + output_origins->SetBlock(0, origins); + output_origins->SetBlock(1, originsP); + // for(unsigned int l = 0; l < 2; ++l) { + for(unsigned int l = 0; l < noLayers; ++l) { + if(l >= 2) + break; + std::stringstream ss; + ss << (l == 0 ? "InputOrigin" : "LayerOrigin") << l; + auto originsMetaData = origins->GetMetaData(l); + if(originsMetaData) + originsMetaData->Set(vtkCompositeDataSet::NAME(), ss.str()); + ss.str(""); + ss << (l == 0 ? "InputOriginPrime" : "LayerOriginPrime") << l; + auto originsPMetaData = originsP->GetMetaData(l); + if(originsPMetaData) + originsPMetaData->Set(vtkCompositeDataSet::NAME(), ss.str()); + } + unsigned int num = 0; + output_origins->GetMetaData(num)->Set( + vtkCompositeDataSet::NAME(), "layersOrigins"); + num = 1; + output_origins->GetMetaData(num)->Set( + vtkCompositeDataSet::NAME(), "layersOriginsPrime"); + } + + void makeCoefficientsOutput( + vtkMultiBlockDataSet *output_coef, + std::vector> &allAlphas, + std::vector> &allScaledAlphas, + std::vector> &allActAlphas, + std::vector> &allActScaledAlphas, + std::vector &clusterAsgn, + std::vector>> &recs, + std::vector> &inputTrees) { + output_coef->SetNumberOfBlocks(allAlphas[0].size()); + for(unsigned int l = 0; l < allAlphas[0].size(); ++l) { + vtkSmartPointer coef_table = vtkSmartPointer::New(); + vtkNew treeIDArray{}; + treeIDArray->SetName("TreeID"); + treeIDArray->SetNumberOfTuples(inputTrees.size()); + for(unsigned int i = 0; i < inputTrees.size(); ++i) + treeIDArray->SetTuple1(i, i); + coef_table->AddColumn(treeIDArray); + auto noVec = allAlphas[0][l].sizes()[0]; + for(unsigned int v = 0; v < noVec; ++v) { + // Alphas + vtkNew tArray{}; + std::string name = ttk::axa::getTableCoefficientName(noVec, v); + tArray->SetName(name.c_str()); + tArray->SetNumberOfTuples(allAlphas.size()); + // Act Alphas + vtkNew actArray{}; + std::string actName = "Act" + name; + actArray->SetName(actName.c_str()); + actArray->SetNumberOfTuples(allAlphas.size()); + // Scaled Alphas + vtkNew tArrayNorm{}; + std::string nameNorm + = ttk::axa::getTableCoefficientNormName(noVec, v); + tArrayNorm->SetName(nameNorm.c_str()); + tArrayNorm->SetNumberOfTuples(allAlphas.size()); + // Act Scaled Alphas + vtkNew actArrayNorm{}; + std::string actNameNorm = "Act" + nameNorm; + actArrayNorm->SetName(actNameNorm.c_str()); + actArrayNorm->SetNumberOfTuples(allAlphas.size()); + // Fill Arrays + for(unsigned int i = 0; i < allAlphas.size(); ++i) { + tArray->SetTuple1(i, allAlphas[i][l][v].item()); + actArray->SetTuple1(i, allActAlphas[i][l][v].item()); + tArrayNorm->SetTuple1(i, allScaledAlphas[i][l][v].item()); + actArrayNorm->SetTuple1( + i, allActScaledAlphas[i][l][v].item()); + } + coef_table->AddColumn(tArray); + coef_table->AddColumn(actArray); + coef_table->AddColumn(tArrayNorm); + coef_table->AddColumn(actArrayNorm); + } + if(!clusterAsgn.empty()) { + vtkNew clusterArray{}; + clusterArray->SetName("ClusterAssignment"); + clusterArray->SetNumberOfTuples(inputTrees.size()); + for(unsigned int i = 0; i < clusterAsgn.size(); ++i) + clusterArray->SetTuple1(i, clusterAsgn[i]); + coef_table->AddColumn(clusterArray); + } + if(l == 0) { + vtkNew treesNoNodesArray{}; + treesNoNodesArray->SetNumberOfTuples(recs.size()); + treesNoNodesArray->SetName("treeNoNodes"); + for(unsigned int i = 0; i < recs.size(); ++i) + treesNoNodesArray->SetTuple1( + i, recs[i][0].mTree.tree.getNumberOfNodes()); + coef_table->AddColumn(treesNoNodesArray); + } + output_coef->SetBlock(l, coef_table); + std::stringstream ss; + ss << "Coef" << l; + output_coef->GetMetaData(l)->Set(vtkCompositeDataSet::NAME(), ss.str()); + } + + // Copy Field Data + // - aggregate input field data + for(unsigned int b = 0; b < inputTrees[0]->GetNumberOfBlocks(); ++b) { + vtkNew fd{}; + fd->CopyStructure(inputTrees[0]->GetBlock(b)->GetFieldData()); + fd->SetNumberOfTuples(inputTrees.size()); + for(size_t i = 0; i < inputTrees.size(); ++i) { + fd->SetTuple(i, 0, inputTrees[i]->GetBlock(b)->GetFieldData()); + } + + // - copy input field data to output row data + for(int i = 0; i < fd->GetNumberOfArrays(); ++i) { + auto array = fd->GetAbstractArray(i); + array->SetName(array->GetName()); + vtkTable::SafeDownCast(output_coef->GetBlock(0))->AddColumn(array); + } + } + } + void makeOneOutput( ttk::ftm::MergeTree &tree, vtkUnstructuredGrid *treeNodes, @@ -402,6 +776,7 @@ namespace ttk { } } } - } // namespace wae + + } // namespace wnn } // namespace ttk #endif diff --git a/core/vtk/ttkMergeTreeAutoencoder/ttkMergeTreeAutoencoderUtils.h b/core/vtk/ttkMergeTreeAutoencoder/ttkMergeTreeNeuralNetworkUtils.h similarity index 73% rename from core/vtk/ttkMergeTreeAutoencoder/ttkMergeTreeAutoencoderUtils.h rename to core/vtk/ttkMergeTreeAutoencoder/ttkMergeTreeNeuralNetworkUtils.h index ae65f5d414..d637b6e7af 100644 --- a/core/vtk/ttkMergeTreeAutoencoder/ttkMergeTreeAutoencoderUtils.h +++ b/core/vtk/ttkMergeTreeAutoencoder/ttkMergeTreeNeuralNetworkUtils.h @@ -15,7 +15,95 @@ #ifdef TTK_ENABLE_TORCH namespace ttk { - namespace wae { + namespace wnn { + + void makeMatchingVectors( + std::vector>> + &originsMatchings, + std::vector> &originsCopy, + std::vector> &originsPrimeCopy, + std::vector> &originsMatchingVectorT, + std::vector> &invOriginsMatchingVectorT, + std::vector< + std::vector>>> + &dataMatchings, + std::vector>> &recs, + std::vector>> + &invDataMatchingVectorT, + std::vector>> + &reconstMatchings, + std::vector> &invReconstMatchingVectorT); + + void makeDataOutput( + vtkMultiBlockDataSet *output_data, + std::vector>> &recs, + unsigned int recSize, + std::vector &treesSegmentation, + std::vector> &persCorrelationMatrix, + std::vector>> + &invDataMatchingVectorT, + std::vector> &invReconstMatchingVectorT, + std::vector> &originsMatchingVectorT, + std::vector> &originsMatchingVector, + std::vector> &originsPersPercent, + std::vector> &originsPersDiff, + std::vector &originPersistenceOrder, + std::vector &treesNodes, + std::vector> &treesNodeCorr, + std::vector classId, + float bestLoss, + double mixtureCoefficient, + bool isPersistenceDiagram, + bool convertToDiagram, + int debugLevel); + + void makeDataOutput( + vtkMultiBlockDataSet *output_data, + std::vector>> &recs, + unsigned int recSize, + std::vector &treesSegmentation, + std::vector> &persCorrelationMatrix, + std::vector>> + &invDataMatchingVectorT, + std::vector> &invReconstMatchingVectorT, + std::vector> &originsMatchingVectorT, + std::vector> &originsMatchingVector, + std::vector> &originsPersPercent, + std::vector> &originsPersDiff, + std::vector &originPersistenceOrder, + std::vector &treesNodes, + std::vector> &treesNodeCorr, + float bestLoss, + double mixtureCoefficient, + bool isPersistenceDiagram, + bool convertToDiagram, + int debugLevel); + + void makeOriginsOutput( + vtkMultiBlockDataSet *output_origins, + std::vector> &originsCopy, + std::vector> &originsPrimeCopy, + std::vector &originPersPercent, + std::vector &originPersDiff, + std::vector &originPersistenceOrder, + std::vector> &originsMatchingVector, + std::vector> &originsPersPercent, + std::vector> &originsPersDiff, + double mixtureCoefficient, + bool isPersistenceDiagram, + bool convertToDiagram, + int debugLevel); + + void makeCoefficientsOutput( + vtkMultiBlockDataSet *output_coef, + std::vector> &allAlphas, + std::vector> &allScaledAlphas, + std::vector> &allActAlphas, + std::vector> &allActScaledAlphas, + std::vector &clusterAsgn, + std::vector>> &recs, + std::vector> &inputTrees); + /** * @brief Proxy function to use ttkMergeTreeVisualization to create the vtk * objects of a merge tree. @@ -249,6 +337,6 @@ namespace ttk { &customIntArrays, std::vector>>> &customDoubleArrays); - } // namespace wae + } // namespace wnn } // namespace ttk #endif diff --git a/core/vtk/ttkMergeTreeAutoencoderDecoding/ttk.module b/core/vtk/ttkMergeTreeAutoencoderDecoding/ttk.module index 1e1ff8d806..4a6b883ea8 100644 --- a/core/vtk/ttkMergeTreeAutoencoderDecoding/ttk.module +++ b/core/vtk/ttkMergeTreeAutoencoderDecoding/ttk.module @@ -6,6 +6,6 @@ HEADERS ttkMergeTreeAutoencoderDecoding.h DEPENDS mergeTreeAutoencoderDecoding + ttkMergeTreeAutoencoder ttkAlgorithm ttkMergeTree - ttkMergeTreeAutoencoder diff --git a/core/vtk/ttkMergeTreeAutoencoderDecoding/ttkMergeTreeAutoencoderDecoding.cpp b/core/vtk/ttkMergeTreeAutoencoderDecoding/ttkMergeTreeAutoencoderDecoding.cpp index e6f439c734..c97105629b 100644 --- a/core/vtk/ttkMergeTreeAutoencoderDecoding/ttkMergeTreeAutoencoderDecoding.cpp +++ b/core/vtk/ttkMergeTreeAutoencoderDecoding/ttkMergeTreeAutoencoderDecoding.cpp @@ -1,7 +1,7 @@ #include #include #include -#include +#include #include #include @@ -217,8 +217,8 @@ int ttkMergeTreeAutoencoderDecoding::RequestData( // ----------------- // Vectors // ----------------- - vSTensor_.resize(noLayers_); - vSPrimeTensor_.resize(noLayers_); + vSTensorCopy_.resize(noLayers_); + vSPrimeTensorCopy_.resize(noLayers_); auto vSPrime = vtkMultiBlockDataSet::SafeDownCast(vectors->GetBlock(1)); std::vector allRevNodeCorr(noLayers_), allRevNodeCorrPrime(noLayers_); @@ -227,7 +227,7 @@ int ttkMergeTreeAutoencoderDecoding::RequestData( #ifdef TTK_ENABLE_OPENMP #pragma omp parallel for schedule(dynamic) num_threads(this->threadNumber_) #endif - for(unsigned int l = 0; l < vSTensor_.size(); ++l) { + for(unsigned int l = 0; l < vSTensorCopy_.size(); ++l) { auto layerVectorsTable = vtkTable::SafeDownCast(vS->GetBlock(l)); auto layerVectorsPrimeTable = vtkTable::SafeDownCast(vSPrime->GetBlock(l)); auto noRows = layerVectorsTable->GetNumberOfRows(); @@ -248,8 +248,8 @@ int ttkMergeTreeAutoencoderDecoding::RequestData( ->GetVariantValue(i) .ToFloat(); } - vSTensor_[l] = torch::tensor(vSTensor).reshape({noRows, allNoAxes[l]}); - vSPrimeTensor_[l] + vSTensorCopy_[l] = torch::tensor(vSTensor).reshape({noRows, allNoAxes[l]}); + vSPrimeTensorCopy_[l] = torch::tensor(vSPrimeTensor).reshape({noRows2, allNoAxes[l]}); allRevNodeCorr[l] = ttkUtils::GetPointer(vtkDataArray::SafeDownCast( @@ -287,9 +287,9 @@ int ttkMergeTreeAutoencoderDecoding::RequestData( originsMatchingVectorT[l].resize(array->GetNumberOfTuples()); for(unsigned int i = 0; i < originsMatchingVectorT[l].size(); ++i) originsMatchingVectorT[l][i] = array->GetVariantValue(i).ToUnsignedInt(); - reverseMatchingVector(originsPrime_[l].mTree, - originsMatchingVectorT[l], - invOriginsMatchingVectorT[l]); + ttk::axa::reverseMatchingVector(originsPrimeCopy_[l].mTree, + originsMatchingVectorT[l], + invOriginsMatchingVectorT[l]); } auto dataMatchingSize = getLatentLayerIndex() + 2; std::vector>> dataMatchingVectorT( @@ -319,7 +319,7 @@ int ttkMergeTreeAutoencoderDecoding::RequestData( ->GetVariantValue(i) .ToUnsignedInt() : recs_[i][l - 1].mTree.tree.getNumberOfNodes()); - reverseMatchingVector( + ttk::axa::reverseMatchingVector( noNodes, dataMatchingVectorT[l][i], invDataMatchingVectorT[l][i]); } } @@ -343,10 +343,11 @@ int ttkMergeTreeAutoencoderDecoding::RequestData( std::vector> originsPersPercent, originsPersDiff; std::vector originPersPercent, originPersDiff; std::vector originPersistenceOrder; - ttk::wae::computeTrackingInformation( - origins_, originsPrime_, originsMatchingVectorT, invOriginsMatchingVectorT, - isPersistenceDiagram_, originsMatchingVector, originsPersPercent, - originsPersDiff, originPersPercent, originPersDiff, originPersistenceOrder); + ttk::wnn::computeTrackingInformation( + originsCopy_, originsPrimeCopy_, originsMatchingVectorT, + invOriginsMatchingVectorT, isPersistenceDiagram_, originsMatchingVector, + originsPersPercent, originsPersDiff, originPersPercent, originPersDiff, + originPersistenceOrder); // ------------------------------------------ // --- Data @@ -373,14 +374,14 @@ int ttkMergeTreeAutoencoderDecoding::RequestData( std::vector>>> customDoubleArrays(recs_.size()); unsigned int lShift = 1; - ttk::wae::computeCustomArrays( + ttk::wnn::computeCustomArrays( recs_, persCorrelationMatrix_, invDataMatchingVectorT, invReconstMatchingVectorT, originsMatchingVector, originsMatchingVectorT, originsPersPercent, originsPersDiff, originPersistenceOrder, l, lShift, customIntArrays, customDoubleArrays); // Create output - ttk::wae::makeManyOutput(trees, out_layer_i, customIntArrays, + ttk::wnn::makeManyOutput(trees, out_layer_i, customIntArrays, customDoubleArrays, mixtureCoefficient_, isPersistenceDiagram_, convertToDiagram_, this->debugLevel_); @@ -408,8 +409,9 @@ int ttkMergeTreeAutoencoderDecoding::RequestData( for(unsigned int i = 0; i < customRecs_.size(); ++i) { trees[i] = &(customRecs_[i].mTree); std::vector matchingVector; - getInverseMatchingVector(origins_[0].mTree, customRecs_[i].mTree, - customMatchings_[i], matchingVector); + ttk::axa::getInverseMatchingVector(originsCopy_[0].mTree, + customRecs_[i].mTree, + customMatchings_[i], matchingVector); customOriginPersOrder[i].resize( customRecs_[i].mTree.tree.getNumberOfNodes()); for(unsigned int j = 0; j < matchingVector.size(); ++j) { @@ -425,7 +427,7 @@ int ttkMergeTreeAutoencoderDecoding::RequestData( } vtkSmartPointer dataCustom = vtkSmartPointer::New(); - ttk::wae::makeManyOutput(trees, dataCustom, customRecsIntArrays, + ttk::wnn::makeManyOutput(trees, dataCustom, customRecsIntArrays, customRecsDoubleArrays, mixtureCoefficient_, isPersistenceDiagram_, convertToDiagram_, this->debugLevel_); diff --git a/core/vtk/ttkMergeTreePrincipalGeodesics/ttkMergeTreePrincipalGeodesics.cpp b/core/vtk/ttkMergeTreePrincipalGeodesics/ttkMergeTreePrincipalGeodesics.cpp index e1df8308ef..3e86e7b9c8 100644 --- a/core/vtk/ttkMergeTreePrincipalGeodesics/ttkMergeTreePrincipalGeodesics.cpp +++ b/core/vtk/ttkMergeTreePrincipalGeodesics/ttkMergeTreePrincipalGeodesics.cpp @@ -385,7 +385,7 @@ int ttkMergeTreePrincipalGeodesics::runOutput( // Tree matching std::vector> matchingMatrix; - getMatchingMatrix( + ttk::axa::getMatchingMatrix( barycenter_, intermediateDTrees, baryMatchings_, matchingMatrix); if(not normalizedWasserstein_) for(unsigned int j = 0; j < inputTrees.size(); ++j) diff --git a/core/vtk/ttkMergeTreePrincipalGeodesicsDecoding/ttkMergeTreePrincipalGeodesicsDecoding.cpp b/core/vtk/ttkMergeTreePrincipalGeodesicsDecoding/ttkMergeTreePrincipalGeodesicsDecoding.cpp index 265b9bb420..c72e718fd4 100644 --- a/core/vtk/ttkMergeTreePrincipalGeodesicsDecoding/ttkMergeTreePrincipalGeodesicsDecoding.cpp +++ b/core/vtk/ttkMergeTreePrincipalGeodesicsDecoding/ttkMergeTreePrincipalGeodesicsDecoding.cpp @@ -487,7 +487,7 @@ int ttkMergeTreePrincipalGeodesicsDecoding::runOutput( // ------------------------------------------ std::vector> matchingMatrix; if(!baryMatchings_.empty()) - getMatchingMatrix( + ttk::axa::getMatchingMatrix( baryMTree[0], inputMTrees, baryMatchings_, matchingMatrix); // TODO compute matching to barycenter if correlation matrix is not provided if(transferInputTreesInformation_ @@ -628,7 +628,7 @@ int ttkMergeTreePrincipalGeodesicsDecoding::runOutput( ttk::ftm::MergeTree baryMT; ttk::ftm::mergeTreeDoubleToTemplate(baryMTree[0], baryMT); std::vector matchingVector; - getInverseMatchingVector( + ttk::axa::getInverseMatchingVector( mt, baryMT, recBaryMatchings[index], matchingVector); std::vector baryNodeID(mt.tree.getNumberOfNodes(), -1); for(unsigned int n = 0; n < vSize_; ++n) { @@ -645,7 +645,7 @@ int ttkMergeTreePrincipalGeodesicsDecoding::runOutput( ttk::ftm::mergeTreeDoubleToTemplate( inputMTrees[index], inputMT); std::vector matchingVector; - getInverseMatchingVector( + ttk::axa::getInverseMatchingVector( mt, inputMT, recInputMatchings[index], matchingVector); std::vector baryNodeID(mt.tree.getNumberOfNodes(), -1); for(unsigned int n = 0; n < vSize_; ++n) {