-
Notifications
You must be signed in to change notification settings - Fork 5
/
hookean_springs.py
30 lines (27 loc) · 971 Bytes
/
hookean_springs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
import torch.nn as nn
def make_incidence(indices, num_vertices, dtype=torch.float32):
    """Build the dense spring/vertex incidence matrix.

    Row ``i`` holds ``+1`` in column ``a`` and ``-1`` in column ``b`` for the
    ``i``-th pair ``(a, b)`` of ``indices``, so ``incidence.mm(x)`` yields the
    per-spring differences ``x[a] - x[b]``. A degenerate pair ``(a, a)``
    produces an all-zero row, consistent with ``x[a] - x[a] == 0``.

    Args:
        indices: iterable of ``(a, b)`` vertex-index pairs, e.g. a list of
            tuples or a ``(num_springs, 2)`` integer tensor.
        num_vertices: number of columns (vertices) of the matrix.
        dtype: dtype of the returned matrix (default ``torch.float32``).

    Returns:
        A ``(num_springs, num_vertices)`` tensor of the requested dtype.
    """
    # this creates a dense matrix (incidence), but
    # sparse matrices or convolutions might be more appropriate
    # in certain cases
    pairs = torch.as_tensor(
        [(int(a), int(b)) for a, b in indices], dtype=torch.long
    ).reshape(-1, 2)  # reshape keeps a (0, 2) shape for empty input
    num_springs = pairs.shape[0]
    incidence = torch.zeros(num_springs, num_vertices, dtype=dtype)
    rows = torch.arange(num_springs)
    ones = torch.ones(num_springs, dtype=dtype)
    # accumulate (rather than assign) so a pair (a, a) cancels to a zero row
    incidence.index_put_((rows, pairs[:, 0]), ones, accumulate=True)
    incidence.index_put_((rows, pairs[:, 1]), -ones, accumulate=True)
    return incidence
class HookeanSprings(nn.Module):
    """Potential energy of a network of linear (Hookean) springs.

    Spring ``i`` joins the vertex pair ``indices[i] = (a, b)`` and contributes
    ``0.5 * k[i] * (|x[a] - x[b]| - l0[i]) ** 2`` to the total energy.

    Args:
        indices: iterable of ``(a, b)`` vertex-index pairs, one per spring.
        l0: rest lengths, one per spring. A tensor, or anything
            ``torch.as_tensor`` accepts (a scalar broadcasts to all springs).
        k: spring stiffnesses, same conventions as ``l0``.
        num_vertices: number of vertices (rows of the position matrix).
        eps: small constant added under the square root so the length and its
            gradient stay finite for zero-length springs.
    """

    def __init__(self, indices, l0, k, num_vertices, eps=1e-6):
        super().__init__()
        self.indices = indices
        self.eps = eps
        self.register_buffer("incidence", make_incidence(indices, num_vertices))
        # register_buffer only accepts tensors; as_tensor passes tensors through
        self.register_buffer("l0", torch.as_tensor(l0))
        self.register_buffer("k", torch.as_tensor(k))

    def energy(self, x):
        """Return the total spring energy for vertex positions ``x``.

        Args:
            x: 2-D tensor of vertex positions, one row per vertex
                (``torch.mm`` requires a matrix).

        Returns:
            A 0-dim tensor with the summed energy (differentiable in ``x``).
        """
        # match x's dtype so e.g. float64 positions work (mm needs equal dtypes)
        incidence = self.incidence.to(dtype=x.dtype)
        d = incidence.mm(x)  # one row per spring: x[a] - x[b]
        length = (d.pow(2).sum(1) + self.eps).sqrt()
        stretch = length - self.l0
        return 0.5 * (self.k * stretch.pow(2)).sum()

    def forward(self, x):
        """Alias for ``energy`` so the module can be called as ``springs(x)``."""
        return self.energy(x)