Skip to content

Commit d53ee09

Browse files
committed
initial commit
1 parent 258e502 commit d53ee09

File tree

5 files changed

+459
-0
lines changed

5 files changed

+459
-0
lines changed

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,6 @@ dmypy.json
127127

128128
# Pyre type checker
129129
.pyre/
130+
131+
# ignore data
132+
data/

brainextractor/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .main import BrainExtractor

brainextractor/helpers.py

+193
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
"""
2+
Helper functions
3+
"""
4+
import numpy as np
5+
import trimesh
6+
from numba import jit
7+
from scipy.spatial import cKDTree # pylint: disable=no-name-in-module
8+
9+
def sphere(shape: list, radius: float, position: list):
10+
"""
11+
Creates a binary sphere
12+
"""
13+
# assume shape and position are both a 3-tuple of int or float
14+
# the units are pixels / voxels (px for short)
15+
# radius is a int or float in px
16+
semisizes = (radius,) * 3
17+
18+
# genereate the grid for the support points
19+
# centered at the position indicated by position
20+
grid = [slice(-x0, dim - x0) for x0, dim in zip(position, shape)]
21+
position = np.ogrid[grid]
22+
# calculate the distance of all points from `position` center
23+
# scaled by the radius
24+
arr = np.zeros(shape, dtype=float)
25+
for x_i, semisize in zip(position, semisizes):
26+
# this can be generalized for exponent != 2
27+
# in which case `(x_i / semisize)`
28+
# would become `np.abs(x_i / semisize)`
29+
arr += (x_i / semisize) ** 2
30+
31+
# the inner part of the sphere will have distance below 1
32+
return arr <= 1.0
33+
34+
def cartesian(arrays, out=None):
35+
"""
36+
Generate a cartesian product of input arrays
37+
"""
38+
39+
arrays = [np.asarray(x) for x in arrays]
40+
dtype = arrays[0].dtype
41+
42+
n = np.prod([x.size for x in arrays])
43+
if out is None:
44+
out = np.zeros([n, len(arrays)], dtype=dtype)
45+
46+
m = int(n / arrays[0].size)
47+
out[:,0] = np.repeat(arrays[0], m)
48+
if arrays[1:]:
49+
cartesian(arrays[1:], out=out[0:m, 1:])
50+
for j in range(1, arrays[0].size):
51+
out[j*m:(j+1)*m, 1:] = out[0:m, 1:]
52+
return out
53+
54+
def find_enclosure(surface: trimesh.Trimesh, data_shape: tuple):
55+
"""
56+
Finds all voxels inside of a surface
57+
58+
This function stores all the surface vertices in a k-d tree
59+
and uses it to quickly look up the closest vertex to each
60+
volume voxel in the image.
61+
62+
Once the closest vertex is found, a vector is created between
63+
the voxel location and the vertex. The resulting vector is dot
64+
product with the corresponding vertex normal. Negative values
65+
indicate that the voxel lies exterior to the surface (since it
66+
is anti-parallel to the vertex normal), while positive values
67+
indicate that they are interior to the surface (parallel to
68+
the vertex normal).
69+
"""
70+
# get vertex normals for each vertex on the surface
71+
normals = surface.vertex_normals
72+
73+
# create KDTree over surface vertices
74+
searcher = cKDTree(surface.vertices)
75+
76+
# get bounding box around surface
77+
max_loc = np.ceil(np.max(surface.vertices, axis=0)).astype(np.int64)
78+
min_loc = np.floor(np.min(surface.vertices, axis=0)).astype(np.int64)
79+
80+
# build a list of locations representing the volume grid
81+
# within the bounding box
82+
locs = cartesian([
83+
np.arange(min_loc[0], max_loc[0]),
84+
np.arange(min_loc[1], max_loc[1]),
85+
np.arange(min_loc[2], max_loc[2])])
86+
87+
# find the nearest vertex to each voxel
88+
# searcher.query returns a list of vertices corresponding
89+
# to the closest vertex to the given voxel location
90+
_, nearest_idx = searcher.query(locs, n_jobs=6)
91+
nearest_vertices = surface.vertices[nearest_idx]
92+
93+
# get the directional vector from each voxel location to it's nearest vertex
94+
direction_vectors = nearest_vertices - locs
95+
96+
# find it's direction by taking the dot product with vertex normal
97+
# this is done row-by-row between directional vectors and the vertex normals
98+
dot_products = np.einsum('ij,ij->i', direction_vectors, normals[nearest_idx])
99+
100+
# get the interior (where dot product is > 0)
101+
interior = (dot_products > 0).reshape((max_loc - min_loc).astype(np.int64))
102+
103+
# create mask
104+
mask = np.zeros(data_shape)
105+
mask[min_loc[0]:max_loc[0],min_loc[1]:max_loc[1],min_loc[2]:max_loc[2]] = interior
106+
107+
# return the mask
108+
return mask
109+
110+
@jit(nopython=True, cache=True)
111+
def closest_integer_point(vertex: np.ndarray):
112+
"""
113+
Gives the closest integer point based on euclidena distance
114+
"""
115+
# get neighboring grid points to search
116+
x = vertex[0]; y = vertex[1]; z = vertex[2]
117+
x0 = np.floor(x); y0 = np.floor(y); z0 = np.floor(z)
118+
x1 = x0 + 1; y1 = y0 + 1; z1 = z0 + 1
119+
120+
# initialize min euclidean distance
121+
min_euclid = 99
122+
123+
# loop through each neighbor point
124+
for i in [x0, x1]:
125+
for j in [y0, y1]:
126+
for k in [z0, z1]:
127+
# compare coordinate and store if min euclid distance
128+
coords = np.array([i, j, k])
129+
dist = l2norm(vertex - coords)
130+
if dist < min_euclid:
131+
min_euclid = dist
132+
final_coords = coords
133+
134+
# return the final coords
135+
return final_coords.astype(np.int64)
136+
137+
@jit(nopython=True, cache=True)
138+
def bresenham3d(v0: np.ndarray, v1: np.ndarray):
139+
"""
140+
Bresenham's algorithm for 3-D line
141+
"""
142+
# initialize axis differences
143+
144+
dx = np.abs(v1[0] - v0[0])
145+
dy = np.abs(v1[1] - v0[1])
146+
dz = np.abs(v1[2] - v0[2])
147+
xs = 1 if (v1[0] > v0[0]) else -1
148+
ys = 1 if (v1[1] > v0[1]) else -1
149+
zs = 1 if (v1[2] > v0[2]) else -1
150+
151+
# determine the driving axis
152+
if dx >= dy and dx >= dz:
153+
d0 = dx; d1 = dy; d2 = dz
154+
s0 = xs; s1 = ys; s2 = zs
155+
a0 = 0; a1 = 1; a2 = 2
156+
elif dy >= dx and dy >= dz:
157+
d0 = dy; d1 = dx; d2 = dz
158+
s0 = ys; s1 = xs; s2 = zs
159+
a0 = 1; a1 = 0; a2 = 2
160+
elif dz >= dx and dz >= dy:
161+
d0 = dz; d1 = dx; d2 = dy
162+
s0 = zs; s1 = xs; s2 = ys
163+
a0 = 2; a1 = 0; a2 = 1
164+
165+
# create line array
166+
line = np.zeros((d0 + 1, 3), dtype=np.int64)
167+
line[0] = v0
168+
169+
# get points
170+
p1 = 2*d1 - d0
171+
p2 = 2*d2 - d0
172+
for i in range(d0):
173+
c = line[i].copy()
174+
c[a0] += s0
175+
if (p1 >= 0):
176+
c[a1] += s1
177+
p1 -= 2 * d0
178+
if (p2 >= 0):
179+
c[a2] += s2
180+
p2 -= 2 * d0
181+
p1 += 2 * d1
182+
p2 += 2 * d2
183+
line[i + 1] = c
184+
185+
# return list
186+
return line
187+
188+
@jit(nopython=True, cache=True)
189+
def l2norm(vec: np.ndarray):
190+
"""
191+
Computes the l2 norm for 3d vector
192+
"""
193+
return np.sqrt(vec[0]**2 + vec[1]**2 + vec[2]**2)

0 commit comments

Comments
 (0)