Commit 7b2cc72

completed configurable backends

1 parent d329967 commit 7b2cc72

21 files changed: +168 −92 lines changed
Lines changed: 12 additions & 22 deletions

````diff
@@ -1,18 +1,15 @@
-# Demo
+# Node Ranking
 
-As a quick start, let us construct a graph
-and a set of nodes. The graph's class can be
-imported either from the `networkx` library or from
-`pygrank` itself. The two are in large part interoperable
-and both can be parsed by our algorithms.
-But our implementation is tailored to graph signal
-processing needs and thus tends to be faster and consume
-only a fraction of the memory.
+Here we will see how an appropriate convergence manager
+can be used to speed up a node ranking process, where
+nodes obtain ordinal values 1,2,3,... based on their
+importance in the graph structure (1 is the most
+important node). For starters, let us construct some data to test with:
 
 ```python
-from pygrank import Graph
+import pygrank as pg
 
-graph = Graph()
+graph = pg.Graph()
 graph.add_edge("A", "B")
 graph.add_edge("B", "C")
 graph.add_edge("C", "D")
@@ -24,16 +21,8 @@ seeds = {"A", "B"}
 ```
 
 We now run a personalized PageRank
-to score the structural relatedness of graph nodes to the ones of the given set.
-First, let us import the library:
-
-```python
-import pygrank as pg
-```
-
-For instructional purposes,
-we experiment with (personalized) *PageRank*
-and make it output the node order of ranks.
+to score the structural relatedness of graph nodes to the ones of the given set
+and apply a postprocessor that ranks nodes based on their score:
 
 ```python
 ranker = pg.PageRank(alpha=0.85, tol=1.E-6, normalization="auto") >> pg.Ordinals()
@@ -61,6 +50,7 @@ print(ordinals["B"], ordinals["D"], ordinals["E"])
 # 3.0 5.0 4.0
 ```
 
-Close to the previous results at a fraction of the time! For large graphs,
+This is close to the previous results at a fraction of the time!
+For large graphs,
 most ordinals would be near the ideal ones. Note that convergence time
 does not take into account the time needed to preprocess graphs.
````
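
As a sanity check, the documented snippet assembles into a runnable script; the edges beyond `("C", "D")` are illustrative, since the hunk truncates the graph construction:

```python
import pygrank as pg

graph = pg.Graph()
graph.add_edge("A", "B")
graph.add_edge("B", "C")
graph.add_edge("C", "D")
graph.add_edge("D", "E")  # illustrative; the hunk cuts off here
seeds = {"A", "B"}

# personalized PageRank piped into an ordinal-ranking postprocessor via >>
ranker = pg.PageRank(alpha=0.85, tol=1.E-6, normalization="auto") >> pg.Ordinals()
ordinals = ranker(graph, {v: 1 for v in seeds})
print(ordinals["B"], ordinals["D"], ordinals["E"])
```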

docs/userguide/quickstart.md

Lines changed: 2 additions & 2 deletions

````diff
@@ -1,7 +1,7 @@
 # Quickstart
 
 ## 1. Install and import
-Install the library using `pip install pygrank` and import it. Construct a node ranking algorithm from a graph filter by incrementally applying postprocessors using >>. There are many components and parameters available. You can use [autotuning](autotuning.md) to find good configurations.
+Install the library using `pip install pygrank` and import it. Construct a node ranking algorithm from a graph filter by incrementally applying postprocessors using >>. There are many components and parameters available. Use [autotuning](autotuning.md) to find good configurations.
 
 ```python
 import pygrank as pg
@@ -28,4 +28,4 @@ Evaluate the scores using a stochastic generalization of the unsupervised conduc
 measure = pg.Conductance() # an evaluation measure
 pg.benchmark_print_line("My conductance", measure(scores)) # pretty
 print("Cite this algorithm as:", hk5_advanced.cite())
-```~~
+```
````
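
For context, a hedged sketch of the quickstart flow this hunk edits. The graph, the seeds, and the exact filter behind `hk5_advanced` are not shown in the hunk, so the `pg.HeatKernel` pipeline below is an illustrative assumption:

```python
import pygrank as pg

# illustrative data; the quickstart's actual graph is not part of this hunk
graph = pg.Graph()
graph.add_edge("A", "B")
graph.add_edge("B", "C")
seeds = {"A"}

# assumed filter: a heat kernel with postprocessors chained via >>
hk5_advanced = pg.HeatKernel(t=5) >> pg.Sweep() >> pg.Normalize()
scores = hk5_advanced(graph, {v: 1 for v in seeds})

measure = pg.Conductance()  # an evaluation measure, as in the hunk
pg.benchmark_print_line("My conductance", measure(scores))  # pretty printing
print("Cite this algorithm as:", hk5_advanced.cite())
```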

docs/userguide/setup.md

Lines changed: 39 additions & 12 deletions

````diff
@@ -7,6 +7,17 @@ and install or upgrade to the latest version of `pygrank` with:
 pip install --upgrade pygrank
 ```
 
+## Creating graphs
+
+When working on practical problems,
+use `networkx` to construct graphs
+by adding edges between Python objects.
+`pygrank` also provides its own `pygrank.Graph` class
+that implements a subset of `networkx.Graph` operations
+to gain several optimizations; it tends to be faster for the
+construction of large graphs and consumes
+only a fraction of the memory.
+
 ## Backends
 
 Several popular computational backends are supported.
@@ -41,12 +52,13 @@ The same message points to a configuration file stored under *home/.pygrank*.
 In addition to automatically downloaded content, there is a JSON configuration
 file specifying the default backend to be set upon first import and the option
 to silence the reminder message. The configuration looks like this and can either be
-edited directly or programmatically set with `pg.set_backend_preference(name, reminder=True)`:
+edited directly or programmatically set with `pg.set_backend_preference(name, reminder=True, **init)`:
 
 ```json
 {
     "backend": "numpy",
-    "reminder": "true"
+    "reminder": "true",
+    "init": {}
 }
 ```
 
@@ -64,28 +76,26 @@ necessarily be the fastest option for dense or very sparse graphs.
 
 ### <span class="component">tensorflow</span>
 <b class="parameters">About</b><br>Performs computations within the `tensorflow` execution environment.
-The latter is an open-source platform for machine learning developed by the Google Brain team.
-It allows for efficient computation across multiple CPUs and GPUs, making it suitable for
-performant large-scale data processing and deep learning applications.
+The latter is an open-source platform for machine learning developed by the Google Brain team.
 There
 are two modes in which this backend can be executed: `"dense"` (default) and `"sparse"`.
 The mode may be provided as additional arguments to the
-`pg.set_backend("tensorflow", mode=...)` call.
+`pg.set_backend("tensorflow", mode="dense", device="auto")` call.
 In dense mode, the tensorflow backend attempts to store graphs in dense square
 matrices that take full advantage of tensorflow's parallelization.
 If there is not enough memory to allocate a sparse adjacency matrix,
 the backend generates a sparse version and creates a warning.
-<br>
+The backend's initialization also accepts a device string or object to
+which computations should be internally transferred. This needs to
+be one among tensorflow's available devices.
 <br>
 <b class="parameters">Installation</b><br> `pip install tensorflow[and-cuda]`<br>On Windows install WSL2 (Windows Subsystem for Linux) first.<br>
 <b class="parameters">Links</b><br> [tensorflow](https://www.tensorflow.org/install)
 
 
 ### <span class="component">pytorch</span>
 <b class="parameters">About</b><br>Performs computations within the `pytorch` execution environment.
-The latter is an open-source platform for machine learning developed by Meta's AI Research lab.
-It is known for its flexibility, ease of use, and dynamic computation graph, which makes it popular
-in research and production.
+The latter is an open-source platform for machine learning developed by Meta's AI Research lab.
 Similarly to `"tensorflow"`, there
 are two modes in which this backend can be executed: `"dense"` (default) and `"sparse"`.
 The mode may be provided as additional arguments to the
@@ -94,9 +104,26 @@ In dense mode, the pytorch backend attempts to store graphs in dense square
 matrices that take full advantage of pytorch's parallelization.
 If there is not enough memory to allocate a sparse adjacency matrix,
 the backend generates a sparse version and creates a warning.
+The backend's initialization also accepts a device string or object to
+which computations should be internally transferred. This needs to
+be one among pytorch's available devices (typically `"cuda"` or `"cpu"`).
 <br>
-<br>
-<br>
+<b class="parameters">Installation</b><br> For full installation instructions visit pytorch's website in the links below.<br>
+<b class="parameters">Links</b><br> [pytorch](https://pytorch.org/get-started/locally)
+
+### <span class="component">torch_sparse</span>
+<b class="parameters">About</b><br>Performs computations within the `pytorch` execution environment,
+but contrary to the `"pytorch"` backend uses the sparse computations of the `torch_sparse` library.
+The latter provides optimized sparse matrix operations on top of `pytorch`.
+The backend's initialization only accepts a device string or object to
+which computations should be internally transferred. This needs to
+be one among pytorch's available devices (typically `"cuda"` or `"cpu"`).
+
+!!! info
+    `"torch_sparse"` is much more computationally efficient than `"pytorch"`
+    for computations with sparse data structures.
+
 <b class="parameters">Installation</b><br> For full installation instructions visit pytorch's website in the links below.<br>
 <b class="parameters">Links</b><br> [pytorch](https://pytorch.org/get-started/locally) <br>
 [torch_sparse](https://github.com/rusty1s/pytorch_sparse)
````
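
To summarize the configuration surface documented above, a short sketch; the backend names and arguments are the ones this page mentions, while the exact combination is illustrative:

```python
import pygrank as pg

# switch the active backend for the session, with backend-specific arguments
pg.set_backend("pytorch", mode="dense", device="auto")

# or scope a backend to a block of computations
with pg.Backend("numpy"):
    ...  # node ranking code here runs on numpy

# persist a default backend (and its init arguments) under ~/.pygrank
pg.set_backend_preference("numpy")
```
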
```diff
@@ -1,17 +1,22 @@
 import pygrank as pg
 import torch
+
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 with pg.Backend("torch_sparse", device=device):
     _, graph, community = next(pg.load_datasets_one_community(["amazon"]))
-    ppr = pg.PageRank(alpha=0.9, normalization="symmetric", assume_immutability=True,
-                      convergence=pg.ConvergenceManager(max_iters=38, error_type="iters"))
+    ppr = pg.PageRank(
+        alpha=0.9,
+        normalization="symmetric",
+        assume_immutability=True,
+        convergence=pg.ConvergenceManager(max_iters=38, error_type="iters"),
+    )
     ppr.preprocessor(graph)
     signal = pg.to_signal(graph, {node: 1.0 for node in community})
     torch.cuda.synchronize()  # correct timing
     scores = ppr(signal)
     print(ppr.convergence)
     print(scores["B00005MHUG"])  # 0.00508212111890316
     print(scores["B00006RGI2"])  # 0.70645672082901
-    print(scores["0006497993"])  # 0.19633759558200836
+    print(scores["0006497993"])  # 0.19633759558200836
```

mkdocs.yml

Lines changed: 2 additions & 0 deletions

```diff
@@ -13,8 +13,10 @@ nav:
     - 'userguide/autotuning.md'
     - 'userguide/preprocessing.md'
   - Applications:
+    - 'advanced/ranking.md'
     - 'advanced/community.md'
     - 'advanced/gnn.md'
+    - 'advanced/fairness.md'
   - R&D:
     - 'tips/citations.md'
     - 'tips/big.md'
```

pygrank/algorithms/convergence.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -141,7 +141,7 @@ def start(self, restart_timer: bool = True):
 
     def has_converged(self, new_ranks: BackendPrimitive) -> bool:
         # TODO: convert to any backend
-        new_ranks = np.array(new_ranks).squeeze()
+        new_ranks = backend.to_numpy(new_ranks).squeeze()
         self.accumulated_ranks = (
             self.accumulated_ranks * self.iteration + new_ranks
         ) / (self.iteration + 1)
```
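
Aside from the backend-agnostic conversion, the surrounding update maintains a running mean over rank vectors; a minimal numpy sketch of that arithmetic (the vector values are illustrative):

```python
import numpy as np

# same form as the accumulated_ranks update above
accumulated = np.zeros(3)
for iteration, new_ranks in enumerate(
    [np.array([1.0, 2.0, 3.0]), np.array([3.0, 2.0, 1.0])]
):
    accumulated = (accumulated * iteration + new_ranks) / (iteration + 1)
print(accumulated)  # [2. 2. 2.], the element-wise mean of the vectors seen so far
```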

pygrank/core/backend/__init__.py

Lines changed: 6 additions & 6 deletions

```diff
@@ -110,8 +110,8 @@ def converted(*args, **kwargs):
             return converted
 
         setattr(thismod, api, converter(mod.__dict__[api]))
-    else:  # pragma: no cover
-        raise Exception("Missing implementation for " + str(api))
+    # else:  # pragma: no cover
+    #     raise Exception("Missing implementation for " + str(api))
     return mod.backend_init(*args, **kwargs)
@@ -157,9 +157,9 @@ def get_backend_preference():  # pragma: no cover
     return {"mod_name": mod_name, **init_parameters}
 
 
-def set_backend_preference(mod_name: str ,
-                           remind_where_to_find: bool = True,
-                           **kwargs):  # pragma: no cover
+def set_backend_preference(
+    mod_name: str, remind_where_to_find: bool = True, **kwargs
+):  # pragma: no cover
     default_dir = os.path.join(os.path.expanduser("~"), ".pygrank")
     if not os.path.exists(default_dir):
         os.makedirs(default_dir)
@@ -169,7 +169,7 @@ def set_backend_preference(mod_name: str ,
         {
             "backend": mod_name.lower(),
             "reminder": str(remind_where_to_find).lower(),
-            "init": {str(k): str(v) for k, v in kwargs.items()}
+            "init": {str(k): str(v) for k, v in kwargs.items()},
        },
        config_file,
    )
```

pygrank/core/backend/ddask.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -18,7 +18,7 @@ def backend_init(*args, splits: int = 8, client=None, **kwargs):
     if client is None:
         client = dsk.distributed.Client(*args, **kwargs)
         __pygrank_dask_config["client"] = client
-    else:
+    elif client is not None:
         __pygrank_dask_config["client"] = client
     return __pygrank_dask_config["client"]
@@ -117,7 +117,8 @@ def multiply_and_collect(signal, split):
 
     # Use Dask to parallelize the multiplication
     futures = [
-        __pygrank_dask_config["client"].submit(multiply_and_collect, signal, split) for split in M_splits
+        __pygrank_dask_config["client"].submit(multiply_and_collect, signal, split)
+        for split in M_splits
     ]
     results = __pygrank_dask_config["client"].gather(futures)
```
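
The reflowed comprehension is the usual `dask.distributed` submit/gather pattern; a self-contained sketch with a stand-in worker in place of `multiply_and_collect`:

```python
from dask.distributed import Client

client = Client(processes=False)  # small in-process cluster for illustration

def square(x):  # stand-in for multiply_and_collect
    return x * x

futures = [client.submit(square, split) for split in range(4)]
print(client.gather(futures))  # [0, 1, 4, 9]
client.close()
```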

pygrank/core/backend/numpy.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -46,7 +46,7 @@ def to_array(obj, copy_array=False):
     if obj.__class__.__module__ == "tensorflow.python.framework.ops":
         return obj.numpy()
     if obj.__class__.__module__ == "torch":
-        return obj.detach().numpy()
+        return obj.detach().cpu().numpy()
     return np.array(obj)
```
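The added `.cpu()` is what makes the conversion safe for GPU tensors, since `.numpy()` raises on CUDA storage; a quick torch check:

```python
import torch

t = torch.ones(3, device="cuda" if torch.cuda.is_available() else "cpu")
# t.numpy() would raise TypeError for a CUDA tensor; this works on any device
print(t.detach().cpu().numpy())
```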

pygrank/core/backend/pytorch.py

Lines changed: 29 additions & 12 deletions

```diff
@@ -34,10 +34,15 @@ def diag(x, offset=0):
 def backend_init(mode="dense", device=None):
     __pygrank_torch_config["mode"] = mode
     if device is not None and device == "auto":
-        if not isinstance(__pygrank_torch_config["device"], str) or __pygrank_torch_config["device"] != "auto":
+        if (
+            not isinstance(__pygrank_torch_config["device"], str)
+            or __pygrank_torch_config["device"] != "auto"
+        ):
             return
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        warnings.warn(f"[pygrank.backend.pytorch] Automatically detected device to run on {device}: {torch.cuda.get_device_name(device)}")
+        warnings.warn(
+            f"[pygrank.backend.pytorch] Automatically detected device to run on {device}: {torch.cuda.get_device_name(device)}"
+        )
     if device is not None and isinstance(device, str):
         device = torch.device(device)
     __pygrank_torch_config["device"] = device
@@ -94,14 +99,22 @@ def scipy_sparse_to_backend(M):
         return torch.FloatTensor(M.todense()).to(__pygrank_torch_config["device"])
     except MemoryError:
         warnings.warn(
-            f"[pygrank.backend.pytorch] Not enough memory to convert a scipy sparse matrix with shape {M.shape} to a numpy dense matrix before moving it to your device.\nWill create a torch.sparse_coo_tensor instead.\nAdd the option mode=\"sparse\" to the backend's initialization to hide this message,\nbut prefer switching to the torch_sparse backend for a performant implementation.")
+            f"[pygrank.backend.pytorch] Not enough memory to convert a scipy sparse matrix with shape {M.shape} "
+            f"to a numpy dense matrix before moving it to your device.\nWill create a torch.sparse_coo_tensor instead."
+            f'\nAdd the option mode="sparse" to the backend\'s initialization to hide this message,'
+            f"\nbut prefer switching to the torch_sparse backend for a performant implementation."
+        )
 
     coo = M.tocoo()
-    return torch.sparse_coo_tensor(
-        torch.LongTensor(np.vstack((coo.col, coo.row))),
-        torch.FloatTensor(coo.data),
-        coo.shape,
-    ).coalesce().to(__pygrank_torch_config["device"])
+    return (
+        torch.sparse_coo_tensor(
+            torch.LongTensor(np.vstack((coo.col, coo.row))),
+            torch.FloatTensor(coo.data),
+            coo.shape,
+        )
+        .coalesce()
+        .to(__pygrank_torch_config["device"])
+    )
@@ -111,12 +124,16 @@ def to_array(obj, copy_array=False):
                 return torch.clone(obj).to(__pygrank_torch_config["device"])
             return obj.to(__pygrank_torch_config["device"])
         return torch.ravel(obj).to(__pygrank_torch_config["device"])
-    return torch.ravel(torch.FloatTensor(np.array([v for v in obj], dtype=np.float32))).to(__pygrank_torch_config["device"])
+    return torch.ravel(
+        torch.FloatTensor(np.array([v for v in obj], dtype=np.float32))
+    ).to(__pygrank_torch_config["device"])
 
 
 def to_primitive(obj):
     if isinstance(obj, float):
-        return torch.tensor(obj, dtype=torch.float32).to(__pygrank_torch_config["device"])
+        return torch.tensor(obj, dtype=torch.float32).to(
+            __pygrank_torch_config["device"]
+        )
     return torch.FloatTensor(obj).to(__pygrank_torch_config["device"])
@@ -132,9 +149,9 @@ def self_normalize(obj):
 
 
 def conv(signal, M):
-    #if M.is_sparse:
+    # if M.is_sparse:
     return torch.mv(M, signal)
-    #return M@signal.reshape((-1,1))
+    # return M@signal.reshape((-1,1))
 
 
 def length(x):
```
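
The sparse fallback path can be exercised in isolation; a minimal sketch with a random scipy matrix, keeping the backend's (col, row) index ordering:

```python
import numpy as np
import scipy.sparse
import torch

M = scipy.sparse.random(4, 4, density=0.5, format="csr")
coo = M.tocoo()
T = torch.sparse_coo_tensor(
    torch.LongTensor(np.vstack((coo.col, coo.row))),  # (col, row) as in the backend
    torch.FloatTensor(coo.data),
    coo.shape,
).coalesce()
print(T)  # a coalesced torch sparse COO tensor mirroring M's transpose layout
```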
