Skip to content

Commit cefed24

Browse files
Merge pull request #1065 from MatPont/mtnn_refactor
[MTNN] merge tree neural network refactor
2 parents 7b36150 + 9d97e76 commit cefed24

38 files changed

+6141
-3731
lines changed

core/base/ftmTree/FTMNode.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,22 @@ namespace ttk {
154154
}
155155
}
156156

157+
inline void removeDownSuperArcs(std::vector<idSuperArc> &idSa) {
158+
if(idSa.empty())
159+
return;
160+
std::vector<bool> toDelete(
161+
(*std::max_element(idSa.begin(), idSa.end())) + 1, false);
162+
for(auto &id : idSa)
163+
toDelete[id] = true;
164+
vect_downSuperArcList_.erase(
165+
std::remove_if(vect_downSuperArcList_.begin(),
166+
vect_downSuperArcList_.end(),
167+
[&toDelete](const idSuperArc &i) {
168+
return i < toDelete.size() and toDelete[i];
169+
}),
170+
vect_downSuperArcList_.end());
171+
}
172+
157173
// Find and remove the arc
158174
inline void removeUpSuperArc(idSuperArc idSa) {
159175
for(idSuperArc i = 0; i < vect_upSuperArcList_.size(); ++i) {

core/base/ftmTree/FTMTreeUtils.cpp

Lines changed: 40 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -13,34 +13,34 @@ namespace ttk {
1313
// --------------------
1414
// Is
1515
// --------------------
16-
bool FTMTree_MT::isNodeOriginDefined(idNode nodeId) {
16+
bool FTMTree_MT::isNodeOriginDefined(idNode nodeId) const {
1717
unsigned int const origin
1818
= (unsigned int)this->getNode(nodeId)->getOrigin();
1919
return origin != nullNodes && origin < this->getNumberOfNodes();
2020
}
2121

22-
bool FTMTree_MT::isRoot(idNode nodeId) {
22+
bool FTMTree_MT::isRoot(idNode nodeId) const {
2323
return this->getNode(nodeId)->getNumberOfUpSuperArcs() == 0;
2424
}
2525

26-
bool FTMTree_MT::isLeaf(idNode nodeId) {
26+
bool FTMTree_MT::isLeaf(idNode nodeId) const {
2727
return this->getNode(nodeId)->getNumberOfDownSuperArcs() == 0;
2828
}
2929

30-
bool FTMTree_MT::isNodeAlone(idNode nodeId) {
30+
bool FTMTree_MT::isNodeAlone(idNode nodeId) const {
3131
return this->isRoot(nodeId) and this->isLeaf(nodeId);
3232
}
3333

34-
bool FTMTree_MT::isFullMerge() {
34+
bool FTMTree_MT::isFullMerge() const {
3535
idNode const treeRoot = this->getRoot();
3636
return (unsigned int)this->getNode(treeRoot)->getOrigin() == treeRoot;
3737
}
3838

39-
bool FTMTree_MT::isBranchOrigin(idNode nodeId) {
39+
bool FTMTree_MT::isBranchOrigin(idNode nodeId) const {
4040
return this->getParentSafe(this->getNode(nodeId)->getOrigin()) != nodeId;
4141
}
4242

43-
bool FTMTree_MT::isNodeMerged(idNode nodeId) {
43+
bool FTMTree_MT::isNodeMerged(idNode nodeId) const {
4444
bool merged = this->isNodeAlone(nodeId)
4545
or this->isNodeAlone(this->getNode(nodeId)->getOrigin());
4646
auto nodeIdOrigin = this->getNode(nodeId)->getOrigin();
@@ -49,11 +49,11 @@ namespace ttk {
4949
return merged;
5050
}
5151

52-
bool FTMTree_MT::isNodeIdInconsistent(idNode nodeId) {
52+
bool FTMTree_MT::isNodeIdInconsistent(idNode nodeId) const {
5353
return nodeId >= this->getNumberOfNodes();
5454
}
5555

56-
bool FTMTree_MT::isThereOnlyOnePersistencePair() {
56+
bool FTMTree_MT::isThereOnlyOnePersistencePair() const {
5757
idNode const treeRoot = this->getRoot();
5858
unsigned int cptNodeAlone = 0;
5959
idNode otherNode = treeRoot;
@@ -74,7 +74,7 @@ namespace ttk {
7474
}
7575

7676
// Do not normalize node is if root or son of a merged root
77-
bool FTMTree_MT::notNeedToNormalize(idNode nodeId) {
77+
bool FTMTree_MT::notNeedToNormalize(idNode nodeId) const {
7878
auto nodeIdParent = this->getParentSafe(nodeId);
7979
return this->isRoot(nodeId)
8080
or (this->isRoot(nodeIdParent)
@@ -84,7 +84,7 @@ namespace ttk {
8484
// and nodeIdOrigin == nodeIdParent) )
8585
}
8686

87-
bool FTMTree_MT::isMultiPersPair(idNode nodeId) {
87+
bool FTMTree_MT::isMultiPersPair(idNode nodeId) const {
8888
auto nodeOriginOrigin
8989
= (unsigned int)this->getNode(this->getNode(nodeId)->getOrigin())
9090
->getOrigin();
@@ -94,14 +94,14 @@ namespace ttk {
9494
// --------------------
9595
// Get
9696
// --------------------
97-
idNode FTMTree_MT::getRoot() {
97+
idNode FTMTree_MT::getRoot() const {
9898
for(idNode node = 0; node < this->getNumberOfNodes(); ++node)
9999
if(this->isRoot(node) and !this->isLeaf(node))
100100
return node;
101101
return nullNodes;
102102
}
103103

104-
idNode FTMTree_MT::getParentSafe(idNode nodeId) {
104+
idNode FTMTree_MT::getParentSafe(idNode nodeId) const {
105105
if(!this->isRoot(nodeId)) {
106106
// _ Nodes in merge trees should have only one parent
107107
idSuperArc const arcId = this->getNode(nodeId)->getUpSuperArcId(0);
@@ -112,7 +112,7 @@ namespace ttk {
112112
}
113113

114114
void FTMTree_MT::getChildren(idNode nodeId,
115-
std::vector<idNode> &childrens) {
115+
std::vector<idNode> &childrens) const {
116116
childrens.clear();
117117
for(idSuperArc i = 0;
118118
i < this->getNode(nodeId)->getNumberOfDownSuperArcs(); ++i) {
@@ -121,33 +121,34 @@ namespace ttk {
121121
}
122122
}
123123

124-
void FTMTree_MT::getLeavesFromTree(std::vector<idNode> &treeLeaves) {
124+
void FTMTree_MT::getLeavesFromTree(std::vector<idNode> &treeLeaves) const {
125125
treeLeaves.clear();
126126
for(idNode i = 0; i < this->getNumberOfNodes(); ++i) {
127127
if(this->isLeaf(i) and !this->isRoot(i))
128128
treeLeaves.push_back(i);
129129
}
130130
}
131131

132-
int FTMTree_MT::getNumberOfLeavesFromTree() {
132+
int FTMTree_MT::getNumberOfLeavesFromTree() const {
133133
std::vector<idNode> leaves;
134134
this->getLeavesFromTree(leaves);
135135
return leaves.size();
136136
}
137137

138-
int FTMTree_MT::getNumberOfNodeAlone() {
138+
int FTMTree_MT::getNumberOfNodeAlone() const {
139139
int cpt = 0;
140140
for(idNode i = 0; i < this->getNumberOfNodes(); ++i)
141141
cpt += this->isNodeAlone(i) ? 1 : 0;
142142
return cpt;
143143
}
144144

145-
int FTMTree_MT::getRealNumberOfNodes() {
145+
int FTMTree_MT::getRealNumberOfNodes() const {
146146
return this->getNumberOfNodes() - this->getNumberOfNodeAlone();
147147
}
148148

149149
void FTMTree_MT::getBranchOriginsFromThisBranch(
150-
idNode node, std::tuple<std::vector<idNode>, std::vector<idNode>> &res) {
150+
idNode node,
151+
std::tuple<std::vector<idNode>, std::vector<idNode>> &res) const {
151152
std::vector<idNode> branchOrigins, nonBranchOrigins;
152153

153154
idNode const nodeOrigin = this->getNode(node)->getOrigin();
@@ -166,7 +167,7 @@ namespace ttk {
166167
void FTMTree_MT::getTreeBranching(
167168
std::vector<idNode> &branching,
168169
std::vector<int> &branchingID,
169-
std::vector<std::vector<idNode>> &nodeBranching) {
170+
std::vector<std::vector<idNode>> &nodeBranching) const {
170171
branching = std::vector<idNode>(this->getNumberOfNodes());
171172
branchingID = std::vector<int>(this->getNumberOfNodes(), -1);
172173
nodeBranching
@@ -200,31 +201,31 @@ namespace ttk {
200201
}
201202

202203
void FTMTree_MT::getTreeBranching(std::vector<idNode> &branching,
203-
std::vector<int> &branchingID) {
204+
std::vector<int> &branchingID) const {
204205
std::vector<std::vector<idNode>> nodeBranching;
205206
this->getTreeBranching(branching, branchingID, nodeBranching);
206207
}
207208

208-
void FTMTree_MT::getAllRoots(std::vector<idNode> &roots) {
209+
void FTMTree_MT::getAllRoots(std::vector<idNode> &roots) const {
209210
roots.clear();
210211
for(idNode node = 0; node < this->getNumberOfNodes(); ++node)
211212
if(this->isRoot(node) and !this->isLeaf(node))
212213
roots.push_back(node);
213214
}
214215

215-
int FTMTree_MT::getNumberOfRoot() {
216+
int FTMTree_MT::getNumberOfRoot() const {
216217
int noRoot = 0;
217218
for(idNode node = 0; node < this->getNumberOfNodes(); ++node)
218219
if(this->isRoot(node) and !this->isLeaf(node))
219220
++noRoot;
220221
return noRoot;
221222
}
222223

223-
int FTMTree_MT::getNumberOfChildren(idNode nodeId) {
224+
int FTMTree_MT::getNumberOfChildren(idNode nodeId) const {
224225
return this->getNode(nodeId)->getNumberOfDownSuperArcs();
225226
}
226227

227-
int FTMTree_MT::getTreeDepth() {
228+
int FTMTree_MT::getTreeDepth() const {
228229
int maxDepth = 0;
229230
std::queue<std::tuple<idNode, int>> queue;
230231
queue.emplace(this->getRoot(), 0);
@@ -242,7 +243,7 @@ namespace ttk {
242243
return maxDepth;
243244
}
244245

245-
int FTMTree_MT::getNodeLevel(idNode nodeId) {
246+
int FTMTree_MT::getNodeLevel(idNode nodeId) const {
246247
int level = 0;
247248
auto root = this->getRoot();
248249
int const noRoot = this->getNumberOfRoot();
@@ -261,7 +262,7 @@ namespace ttk {
261262
return level;
262263
}
263264

264-
void FTMTree_MT::getAllNodeLevel(std::vector<int> &allNodeLevel) {
265+
void FTMTree_MT::getAllNodeLevel(std::vector<int> &allNodeLevel) const {
265266
allNodeLevel = std::vector<int>(this->getNumberOfNodes());
266267
std::queue<std::tuple<idNode, int>> queue;
267268
queue.emplace(this->getRoot(), 0);
@@ -279,7 +280,7 @@ namespace ttk {
279280
}
280281

281282
void FTMTree_MT::getLevelToNode(
282-
std::vector<std::vector<idNode>> &levelToNode) {
283+
std::vector<std::vector<idNode>> &levelToNode) const {
283284
std::vector<int> allNodeLevel;
284285
this->getAllNodeLevel(allNodeLevel);
285286
int const maxLevel
@@ -290,9 +291,10 @@ namespace ttk {
290291
}
291292
}
292293

293-
void FTMTree_MT::getBranchSubtree(std::vector<idNode> &branching,
294-
idNode branchRoot,
295-
std::vector<idNode> &branchSubtree) {
294+
void
295+
FTMTree_MT::getBranchSubtree(std::vector<idNode> &branching,
296+
idNode branchRoot,
297+
std::vector<idNode> &branchSubtree) const {
296298
branchSubtree.clear();
297299
std::queue<idNode> queue;
298300
queue.push(branchRoot);
@@ -316,7 +318,7 @@ namespace ttk {
316318
// Persistence
317319
// --------------------
318320
void FTMTree_MT::getMultiPersOriginsVectorFromTree(
319-
std::vector<std::vector<idNode>> &treeMultiPers) {
321+
std::vector<std::vector<idNode>> &treeMultiPers) const {
320322
treeMultiPers
321323
= std::vector<std::vector<idNode>>(this->getNumberOfNodes());
322324
for(unsigned int i = 0; i < this->getNumberOfNodes(); ++i)
@@ -398,7 +400,7 @@ namespace ttk {
398400
// --------------------
399401
// Create/Delete/Modify Tree
400402
// --------------------
401-
void FTMTree_MT::copyMergeTreeStructure(FTMTree_MT *tree) {
403+
void FTMTree_MT::copyMergeTreeStructure(const FTMTree_MT *tree) {
402404
// Add Nodes
403405
for(unsigned int i = 0; i < tree->getNumberOfNodes(); ++i)
404406
this->makeNode(i);
@@ -418,7 +420,7 @@ namespace ttk {
418420
// --------------------
419421
// Utils
420422
// --------------------
421-
void FTMTree_MT::printNodeSS(idNode node, std::stringstream &ss) {
423+
void FTMTree_MT::printNodeSS(idNode node, std::stringstream &ss) const {
422424
ss << "(" << node << ") \\ ";
423425

424426
std::vector<idNode> children;
@@ -431,7 +433,7 @@ namespace ttk {
431433
ss << std::endl;
432434
}
433435

434-
std::stringstream FTMTree_MT::printSubTree(idNode subRoot) {
436+
std::stringstream FTMTree_MT::printSubTree(idNode subRoot) const {
435437
std::stringstream ss;
436438
ss << "Nodes----------" << std::endl;
437439
std::queue<idNode> queue;
@@ -450,7 +452,7 @@ namespace ttk {
450452
return ss;
451453
}
452454

453-
std::stringstream FTMTree_MT::printTree(bool doPrint) {
455+
std::stringstream FTMTree_MT::printTree(bool doPrint) const {
454456
std::stringstream ss;
455457
std::vector<idNode> allRoots;
456458
this->getAllRoots(allRoots);
@@ -471,7 +473,7 @@ namespace ttk {
471473
return ss;
472474
}
473475

474-
std::stringstream FTMTree_MT::printTreeStats(bool doPrint) {
476+
std::stringstream FTMTree_MT::printTreeStats(bool doPrint) const {
475477
auto noNodesT = this->getNumberOfNodes();
476478
auto noNodes = this->getRealNumberOfNodes();
477479
std::stringstream ss;
@@ -483,7 +485,7 @@ namespace ttk {
483485
}
484486

485487
std::stringstream
486-
FTMTree_MT::printMultiPersOriginsVectorFromTree(bool doPrint) {
488+
FTMTree_MT::printMultiPersOriginsVectorFromTree(bool doPrint) const {
487489
std::stringstream ss;
488490
std::vector<std::vector<idNode>> vec;
489491
this->getMultiPersOriginsVectorFromTree(vec);

core/base/ftmTree/FTMTreeUtils.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ namespace ttk {
148148
}
149149

150150
template <class dataType>
151-
void getTreeScalars(ftm::FTMTree_MT *tree,
151+
void getTreeScalars(const ftm::FTMTree_MT *tree,
152152
std::vector<dataType> &scalarsVector) {
153153
scalarsVector.clear();
154154
for(unsigned int i = 0; i < tree->getNumberOfNodes(); ++i)
@@ -162,7 +162,7 @@ namespace ttk {
162162
}
163163

164164
template <class dataType>
165-
MergeTree<dataType> copyMergeTree(ftm::FTMTree_MT *tree,
165+
MergeTree<dataType> copyMergeTree(const ftm::FTMTree_MT *tree,
166166
bool doSplitMultiPersPairs = false) {
167167
std::vector<dataType> scalarsVector;
168168
getTreeScalars<dataType>(tree, scalarsVector);
@@ -201,7 +201,7 @@ namespace ttk {
201201
}
202202

203203
template <class dataType>
204-
MergeTree<dataType> copyMergeTree(MergeTree<dataType> &mergeTree,
204+
MergeTree<dataType> copyMergeTree(const MergeTree<dataType> &mergeTree,
205205
bool doSplitMultiPersPairs = false) {
206206
return copyMergeTree<dataType>(&(mergeTree.tree), doSplitMultiPersPairs);
207207
}

0 commit comments

Comments
 (0)