Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions pygem/rbf.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,3 +403,100 @@ def __call__(self, src_pts):
H[:, self.n_control_points] = 1.0
H[:, -3:] = src_pts
return np.asarray(np.dot(H, self.weights))

class RBFSinglePrecision(RBF):
    """
    Memory-optimized RBF that stores and computes large matrices in reduced
    precision (float32 by default). Other behavior matches `RBF`.

    Use this class when memory is constrained; results are returned in the
    configured dtype.

    :param array_like original_control_points: (n_points, 3) coordinates of
        the control points before deformation. Defaults to the 8 vertices of
        the unit cube.
    :param array_like deformed_control_points: (n_points, 3) coordinates of
        the control points after deformation. Defaults to the 8 vertices of
        the unit cube.
    :param func: value assigned to the `basis` property inherited from `RBF`.
        Default is 'gaussian_spline'.
    :param float radius: value assigned to the `radius` property inherited
        from `RBF`. Default is 0.5.
    :param dict extra_parameter: extra keyword arguments forwarded to the
        basis function on every evaluation. Default is no extra arguments.
    :param dtype: dtype used for the system matrix, right-hand side, weights
        and output. Anything accepted by `numpy.dtype` is allowed (a scalar
        type, a dtype instance, or a string such as ``'float32'``). Default
        is `numpy.float32`.
    """

    def __init__(self,
                 original_control_points=None,
                 deformed_control_points=None,
                 func='gaussian_spline',
                 radius=0.5,
                 extra_parameter=None,
                 dtype=np.float32):
        # NOTE(review): RBF.__init__ is not called; the attributes are set
        # directly below. The parent initializer is not visible from this
        # excerpt -- confirm nothing else needs initialising.

        # Normalise to a numpy.dtype instance so that strings and dtype
        # objects are accepted in addition to scalar types like np.float32.
        self._dtype = np.dtype(dtype)

        # Set basis and radius using the parent property setters.
        self.basis = func
        self.radius = radius

        # Control points are stored in the working dtype.
        if original_control_points is None:
            self.original_control_points = self._unit_cube(self._dtype)
        else:
            self.original_control_points = np.asarray(
                original_control_points, dtype=self._dtype)

        if deformed_control_points is None:
            self.deformed_control_points = self._unit_cube(self._dtype)
        else:
            self.deformed_control_points = np.asarray(
                deformed_control_points, dtype=self._dtype)

        # Extra parameters are small; keep them exactly as provided.
        self.extra = extra_parameter if extra_parameter else {}

        self.weights = self._get_weights(self.original_control_points,
                                         self.deformed_control_points)

    @staticmethod
    def _unit_cube(dtype):
        """
        Return the 8 vertices of the unit cube as a new (8, 3) array.

        :param dtype: dtype of the returned array.
        :return: the default control points.
        :rtype: numpy.ndarray
        """
        return np.array(
            [[0., 0., 0.], [0., 0., 1.], [0., 1., 0.], [1., 0., 0.],
             [0., 1., 1.], [1., 0., 1.], [1., 1., 0.], [1., 1., 1.]],
            dtype=dtype)

    def _get_weights(self, X, Y):
        """
        Reduced-precision version of the weight computation. The system
        matrix, the right-hand side and the returned weights use the
        configured dtype to reduce memory usage.

        :param numpy.ndarray X: (n_points, 3) original control points.
        :param numpy.ndarray Y: deformed control points, one row per row
            of `X`.
        :return: the solution of the (n_points + 4)-row linear system, with
            as many columns as `X`.
        :rtype: numpy.ndarray
        """
        X = np.asarray(X)
        npts, dim = X.shape
        size = npts + 3 + 1
        H = np.zeros((size, size), dtype=self._dtype)

        # Distances are evaluated on float64 inputs; convert once and reuse
        # the same array for both cdist arguments.
        X64 = np.asarray(X, dtype=np.float64)
        dists = cdist(X64, X64).astype(self._dtype, copy=False)
        # Slice assignment casts to H's dtype, so no extra temporary needed.
        H[:npts, :npts] = self.basis(dists, self.radius, **self.extra)
        # Drop the distance matrix before the solve to lower peak memory.
        del dists

        H[npts, :npts] = 1.0
        H[:npts, npts] = 1.0
        H[:npts, -3:] = X
        H[-3:, :npts] = X.T

        rhs = np.zeros((size, dim), dtype=self._dtype)
        rhs[:npts, :] = Y

        # H and rhs already have the working dtype; pass them directly
        # instead of copying them with astype.
        weights = np.linalg.solve(H, rhs)
        return weights.astype(self._dtype, copy=False)

    def __call__(self, src_pts):
        """
        Deform `src_pts`. Heavy temporary arrays use the configured dtype.

        :param array_like src_pts: (n_src, 3) coordinates of the points to
            deform.
        :return: the deformed points, in the configured dtype.
        :rtype: numpy.ndarray
        """
        src = np.asarray(src_pts, dtype=self._dtype)

        # Recompute the weights from the current control points on every
        # call, to keep consistency with the parent API.
        self.weights = self._get_weights(self.original_control_points,
                                         self.deformed_control_points)

        n_ctrl = self.n_control_points
        H = np.zeros((src.shape[0], n_ctrl + 3 + 1), dtype=self._dtype)

        dists = cdist(
            np.asarray(src, dtype=np.float64),
            np.asarray(self.original_control_points, dtype=np.float64),
        ).astype(self._dtype, copy=False)
        H[:, :n_ctrl] = self.basis(dists, self.radius, **self.extra)
        del dists

        H[:, n_ctrl] = 1.0
        H[:, -3:] = src
        return np.asarray(np.dot(H, self.weights), dtype=self._dtype)
Loading