Skip to content

Commit ac8b0ab

Browse files
authored
Merge pull request #10 from monkey0722/feature/dfs
feat: Add Depth-First Search (DFS)
2 parents 179e45e + 4efe401 commit ac8b0ab

File tree

2 files changed

+390
-0
lines changed

2 files changed

+390
-0
lines changed

src/graph/dfs.hpp

Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
#ifndef DFS_HPP
2+
#define DFS_HPP
3+
4+
#include <algorithm>
5+
#include <concepts>
6+
#include <iostream>
7+
#include <stack>
8+
#include <unordered_map>
9+
#include <unordered_set>
10+
#include <vector>
11+
12+
template <typename T>
13+
concept GraphNode = std::equality_comparable<T> && std::copy_constructible<T>;
14+
15+
template <typename T>
16+
concept HashableNode = GraphNode<T> && requires(T x) {
17+
{ std::hash<T>{}(x) } -> std::convertible_to<std::size_t>;
18+
};
19+
20+
template <HashableNode NodeType>
21+
class DFS {
22+
private:
23+
using Graph = std::unordered_map<NodeType, std::vector<NodeType>>;
24+
Graph adjacencyList;
25+
26+
void logVisit(const NodeType& node) const {
27+
std::cout << "Visiting node: " << node << std::endl;
28+
}
29+
30+
// Helper function for recursive DFS traversal
31+
void traverseRecursive(const NodeType& node, std::unordered_map<NodeType, bool>& visited,
32+
std::vector<NodeType>& result) const {
33+
visited[node] = true;
34+
logVisit(node);
35+
result.push_back(node);
36+
37+
if (auto it = adjacencyList.find(node); it != adjacencyList.end()) {
38+
for (const auto& neighbor : it->second) {
39+
if (!visited[neighbor]) {
40+
std::cout << "Moving from " << node << " to " << neighbor << std::endl;
41+
traverseRecursive(neighbor, visited, result);
42+
}
43+
}
44+
}
45+
}
46+
47+
public:
48+
void addEdge(const NodeType& from, const NodeType& to) {
49+
adjacencyList[from].push_back(to);
50+
// In the case of an isolated point, create an empty adjacent list
51+
if (adjacencyList.find(to) == adjacencyList.end()) {
52+
adjacencyList[to] = std::vector<NodeType>();
53+
}
54+
}
55+
56+
// Iterative DFS traversal using a stack
57+
[[nodiscard]] std::vector<NodeType> traverse(const NodeType& start) const {
58+
std::stack<NodeType> stack;
59+
std::unordered_map<NodeType, bool> visited;
60+
std::vector<NodeType> result;
61+
62+
std::cout << "Starting DFS traversal from node: " << start << std::endl;
63+
64+
stack.push(start);
65+
66+
while (!stack.empty()) {
67+
NodeType current = stack.top();
68+
stack.pop();
69+
70+
if (!visited[current]) {
71+
logVisit(current);
72+
result.push_back(current);
73+
visited[current] = true;
74+
75+
if (auto it = adjacencyList.find(current); it != adjacencyList.end()) {
76+
// Push neighbors in reverse order to process them in the original order
77+
for (auto it2 = it->second.rbegin(); it2 != it->second.rend(); ++it2) {
78+
const auto& neighbor = *it2;
79+
if (!visited[neighbor]) {
80+
std::cout << "Pushing " << neighbor << " to stack" << std::endl;
81+
stack.push(neighbor);
82+
}
83+
}
84+
}
85+
}
86+
}
87+
88+
return result;
89+
}
90+
91+
// Recursive DFS traversal
92+
[[nodiscard]] std::vector<NodeType> traverseRecursive(const NodeType& start) const {
93+
std::unordered_map<NodeType, bool> visited;
94+
std::vector<NodeType> result;
95+
96+
std::cout << "Starting recursive DFS traversal from node: " << start << std::endl;
97+
traverseRecursive(start, visited, result);
98+
99+
return result;
100+
}
101+
102+
// Find path using DFS (not necessarily the shortest)
103+
[[nodiscard]] std::vector<NodeType> findPath(const NodeType& start, const NodeType& target) const {
104+
std::stack<NodeType> stack;
105+
std::unordered_map<NodeType, bool> visited;
106+
std::unordered_map<NodeType, NodeType> parent;
107+
108+
std::cout << "Finding path from " << start << " to " << target << std::endl;
109+
110+
stack.push(start);
111+
visited[start] = true;
112+
113+
bool found = false;
114+
while (!stack.empty() && !found) {
115+
NodeType current = stack.top();
116+
stack.pop();
117+
118+
if (current == target) {
119+
std::cout << "Target " << target << " found!" << std::endl;
120+
found = true;
121+
break;
122+
}
123+
124+
if (auto it = adjacencyList.find(current); it != adjacencyList.end()) {
125+
for (const auto& neighbor : it->second) {
126+
if (!visited[neighbor]) {
127+
stack.push(neighbor);
128+
visited[neighbor] = true;
129+
parent[neighbor] = current;
130+
}
131+
}
132+
}
133+
}
134+
135+
std::vector<NodeType> path;
136+
if (found) {
137+
NodeType current = target;
138+
while (current != start) {
139+
path.push_back(current);
140+
current = parent[current];
141+
}
142+
path.push_back(start);
143+
std::reverse(path.begin(), path.end());
144+
}
145+
146+
return path;
147+
}
148+
149+
// Detect cycles in the graph
150+
[[nodiscard]] bool hasCycle() const {
151+
std::unordered_map<NodeType, bool> visited;
152+
std::unordered_map<NodeType, bool> inStack;
153+
154+
for (const auto& [node, _] : adjacencyList) {
155+
if (!visited[node]) {
156+
if (hasCycleUtil(node, visited, inStack)) {
157+
return true;
158+
}
159+
}
160+
}
161+
return false;
162+
}
163+
164+
private:
165+
bool hasCycleUtil(const NodeType& node, std::unordered_map<NodeType, bool>& visited,
166+
std::unordered_map<NodeType, bool>& inStack) const {
167+
visited[node] = true;
168+
inStack[node] = true;
169+
170+
if (auto it = adjacencyList.find(node); it != adjacencyList.end()) {
171+
for (const auto& neighbor : it->second) {
172+
if (!visited[neighbor]) {
173+
if (hasCycleUtil(neighbor, visited, inStack)) {
174+
return true;
175+
}
176+
} else if (inStack[neighbor]) {
177+
// If the neighbor is already in the recursion stack, we found a cycle
178+
return true;
179+
}
180+
}
181+
}
182+
183+
inStack[node] = false; // Remove the node from recursion stack
184+
return false;
185+
}
186+
187+
public:
188+
// Topological sort (only works for DAGs)
189+
[[nodiscard]] std::vector<NodeType> topologicalSort() const {
190+
if (hasCycle()) {
191+
std::cout << "Graph has a cycle, topological sort not possible" << std::endl;
192+
return {};
193+
}
194+
195+
std::unordered_map<NodeType, bool> visited;
196+
std::stack<NodeType> stack;
197+
std::vector<NodeType> result;
198+
199+
for (const auto& [node, _] : adjacencyList) {
200+
if (!visited[node]) {
201+
topologicalSortUtil(node, visited, stack);
202+
}
203+
}
204+
205+
while (!stack.empty()) {
206+
result.push_back(stack.top());
207+
stack.pop();
208+
}
209+
210+
return result;
211+
}
212+
213+
private:
214+
void topologicalSortUtil(const NodeType& node, std::unordered_map<NodeType, bool>& visited,
215+
std::stack<NodeType>& stack) const {
216+
visited[node] = true;
217+
218+
if (auto it = adjacencyList.find(node); it != adjacencyList.end()) {
219+
for (const auto& neighbor : it->second) {
220+
if (!visited[neighbor]) {
221+
topologicalSortUtil(neighbor, visited, stack);
222+
}
223+
}
224+
}
225+
226+
// All descendants processed, push current node to stack
227+
stack.push(node);
228+
}
229+
230+
public:
231+
[[nodiscard]] size_t countConnectedComponents() const {
232+
std::unordered_set<NodeType> unvisited;
233+
for (const auto& [node, _] : adjacencyList) {
234+
unvisited.insert(node);
235+
}
236+
237+
size_t components = 0;
238+
while (!unvisited.empty()) {
239+
NodeType start = *unvisited.begin();
240+
std::cout << "Starting new component exploration from node: " << start << std::endl;
241+
auto visited = traverse(start);
242+
for (const auto& node : visited) {
243+
unvisited.erase(node);
244+
}
245+
components++;
246+
}
247+
248+
return components;
249+
}
250+
};
251+
252+
#endif // DFS_HPP

tests/graph/dfs_test.cpp

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
#include "../src/graph/dfs.hpp"
2+
3+
#include <gtest/gtest.h>
4+
5+
TEST(DFSTest, BasicTraversal) {
6+
DFS<int> dfs;
7+
8+
dfs.addEdge(0, 1);
9+
dfs.addEdge(0, 2);
10+
dfs.addEdge(2, 3);
11+
dfs.addEdge(2, 4);
12+
13+
auto result = dfs.traverse(0);
14+
15+
ASSERT_EQ(result.size(), 5);
16+
EXPECT_EQ(result[0], 0);
17+
// DFS will go deep first, so we expect a different order than BFS
18+
EXPECT_TRUE((result[1] == 1 && result[2] == 2) || (result[1] == 2 && result[3] == 3));
19+
}
20+
21+
TEST(DFSTest, RecursiveTraversal) {
22+
DFS<int> dfs;
23+
24+
dfs.addEdge(0, 1);
25+
dfs.addEdge(0, 2);
26+
dfs.addEdge(2, 3);
27+
dfs.addEdge(2, 4);
28+
29+
auto result = dfs.traverseRecursive(0);
30+
31+
ASSERT_EQ(result.size(), 5);
32+
EXPECT_EQ(result[0], 0);
33+
// The recursive DFS should follow the order of adjacency list
34+
}
35+
36+
TEST(DFSTest, FindPath) {
37+
DFS<int> dfs;
38+
39+
// Add an edge as an unoriented graph
40+
dfs.addEdge(0, 1);
41+
dfs.addEdge(1, 0);
42+
dfs.addEdge(0, 2);
43+
dfs.addEdge(2, 0);
44+
dfs.addEdge(0, 3);
45+
dfs.addEdge(3, 0);
46+
dfs.addEdge(2, 3);
47+
dfs.addEdge(3, 2);
48+
49+
auto path = dfs.findPath(1, 3);
50+
51+
ASSERT_GT(path.size(), 0);
52+
EXPECT_EQ(path[0], 1);
53+
EXPECT_EQ(path[path.size() - 1], 3);
54+
}
55+
56+
TEST(DFSTest, CycleDetection) {
57+
DFS<int> dfs1;
58+
59+
// Create a graph with a cycle
60+
dfs1.addEdge(0, 1);
61+
dfs1.addEdge(1, 2);
62+
dfs1.addEdge(2, 0);
63+
64+
EXPECT_TRUE(dfs1.hasCycle());
65+
66+
DFS<int> dfs2;
67+
68+
// Create a graph without a cycle (DAG)
69+
dfs2.addEdge(0, 1);
70+
dfs2.addEdge(0, 2);
71+
dfs2.addEdge(1, 3);
72+
dfs2.addEdge(2, 3);
73+
74+
EXPECT_FALSE(dfs2.hasCycle());
75+
}
76+
77+
TEST(DFSTest, TopologicalSort) {
78+
DFS<int> dfs;
79+
80+
// Create a DAG
81+
dfs.addEdge(5, 2);
82+
dfs.addEdge(5, 0);
83+
dfs.addEdge(4, 0);
84+
dfs.addEdge(4, 1);
85+
dfs.addEdge(2, 3);
86+
dfs.addEdge(3, 1);
87+
88+
auto result = dfs.topologicalSort();
89+
90+
ASSERT_EQ(result.size(), 6);
91+
92+
// Check that for each edge (u, v), u comes before v in the topological sort
93+
std::unordered_map<int, int> position;
94+
for (size_t i = 0; i < result.size(); ++i) {
95+
position[result[i]] = i;
96+
}
97+
98+
EXPECT_LT(position[5], position[2]);
99+
EXPECT_LT(position[5], position[0]);
100+
EXPECT_LT(position[4], position[0]);
101+
EXPECT_LT(position[4], position[1]);
102+
EXPECT_LT(position[2], position[3]);
103+
EXPECT_LT(position[3], position[1]);
104+
}
105+
106+
TEST(DFSTest, ConnectedComponents) {
107+
DFS<int> dfs;
108+
109+
// As an undirected graph, follow the edges
110+
dfs.addEdge(0, 1);
111+
dfs.addEdge(1, 0);
112+
dfs.addEdge(1, 2);
113+
dfs.addEdge(2, 1);
114+
115+
dfs.addEdge(3, 4);
116+
dfs.addEdge(4, 3);
117+
118+
// Isolated points are also tracked
119+
dfs.addEdge(5, 5);
120+
121+
EXPECT_EQ(dfs.countConnectedComponents(), 3);
122+
}
123+
124+
TEST(DFSTest, CustomNodeType) {
125+
DFS<std::string> dfs;
126+
127+
dfs.addEdge("A", "B");
128+
dfs.addEdge("B", "A");
129+
dfs.addEdge("B", "C");
130+
dfs.addEdge("C", "B");
131+
132+
auto result = dfs.traverse("A");
133+
134+
ASSERT_EQ(result.size(), 3);
135+
EXPECT_EQ(result[0], "A");
136+
// The order may vary depending on the implementation
137+
EXPECT_TRUE((result[1] == "B" && result[2] == "C") || (result[1] == "C" && result[2] == "B"));
138+
}

0 commit comments

Comments
 (0)