Skip to content

Commit fe227d6

Browse files
cdtwiggmeta-codesync[bot]
authored andcommitted
Migrate MPPCA and finalize numpy migration with Python bindings (#967)
Summary: Pull Request resolved: #967 Completed the migration from PyTorch tensors to numpy arrays for both MPPCA (Mixture of Probabilistic PCA) operations and kd-tree function bindings. This commit provides the Python bindings for both features. Key changes: - Implemented array_mppca for converting MPPCA models to numpy arrays - Changed Mppca.to_tensors() to Mppca.to_arrays() returning numpy arrays instead of torch tensors - Added Python bindings for find_closest_points, find_closest_points_with_normals, and find_closest_points_on_mesh using numpy arrays - Removed dependency on ATen/PyTorch tensor headers from geometry module - Updated test_posePrior to use numpy array indexing instead of tensor.select() Note: This commit includes bindings for both the kd-tree functions (from the previous commit) and MPPCA, as they share common infrastructure changes (removing tensor includes and adding array includes). Reviewed By: jeongseok-meta Differential Revision: D89891113
1 parent eb0431d commit fe227d6

File tree

5 files changed

+190
-16
lines changed

5 files changed

+190
-16
lines changed

pymomentum/cmake/build_variables.bzl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ geometry_public_headers = [
120120
"geometry/array_blend_shape.h",
121121
"geometry/array_joint_parameters_to_positions.h",
122122
"geometry/array_kd_tree.h",
123+
"geometry/array_mppca.h",
123124
"geometry/array_parameter_transform.h",
124125
"geometry/array_skeleton_state.h",
125126
"geometry/array_skinning.h",
@@ -140,6 +141,7 @@ geometry_sources = [
140141
"geometry/array_blend_shape.cpp",
141142
"geometry/array_joint_parameters_to_positions.cpp",
142143
"geometry/array_kd_tree.cpp",
144+
"geometry/array_mppca.cpp",
143145
"geometry/array_parameter_transform.cpp",
144146
"geometry/array_skeleton_state.cpp",
145147
"geometry/array_skinning.cpp",
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
8+
#include "pymomentum/geometry/array_mppca.h"
9+
10+
#include <momentum/common/exception.h>
11+
#include <momentum/math/constants.h>
12+
#include <momentum/math/mppca.h>
13+
14+
#include <Eigen/Core>
15+
16+
namespace py = pybind11;
17+
18+
namespace pymomentum {
19+
20+
std::tuple<
21+
py::array_t<float>,
22+
py::array_t<float>,
23+
py::array_t<float>,
24+
py::array_t<float>,
25+
py::array_t<int>>
26+
mppcaToArrays(
27+
const momentum::Mppca& mppca,
28+
std::optional<const momentum::ParameterTransform*> paramTransform) {
29+
const auto nMixtures = mppca.p;
30+
const auto dimension = mppca.d;
31+
32+
// Create output arrays
33+
py::array_t<float> pi_array(static_cast<py::ssize_t>(nMixtures));
34+
py::array_t<float> mu_array(
35+
{static_cast<py::ssize_t>(nMixtures), static_cast<py::ssize_t>(dimension)});
36+
37+
auto pi = pi_array.mutable_unchecked<1>();
38+
auto mu = mu_array.mutable_unchecked<2>();
39+
40+
// Copy mu values
41+
for (py::ssize_t iMix = 0; iMix < nMixtures; ++iMix) {
42+
for (py::ssize_t d = 0; d < dimension; ++d) {
43+
mu(iMix, d) = mppca.mu(iMix, d);
44+
}
45+
}
46+
47+
// Process each mixture component
48+
MT_THROW_IF(mppca.Cinv.size() != nMixtures, "Invalid Mppca");
49+
50+
Eigen::VectorXf sigma_vec(nMixtures);
51+
int W_rank = 0; // Will be determined from first mixture
52+
53+
// First pass: determine W rank from first mixture
54+
{
55+
Eigen::SelfAdjointEigenSolver<Eigen::MatrixXf> Cinv_eigs(mppca.Cinv[0]);
56+
Eigen::VectorXf C_eigenvalues = Cinv_eigs.eigenvalues().cwiseInverse();
57+
const float sigma2 = C_eigenvalues(C_eigenvalues.size() - 1);
58+
C_eigenvalues.array() -= sigma2;
59+
60+
W_rank = C_eigenvalues.size();
61+
for (Eigen::Index i = 0; i < C_eigenvalues.size(); ++i) {
62+
if (C_eigenvalues(i) < 0.0001) {
63+
W_rank = i;
64+
break;
65+
}
66+
}
67+
}
68+
69+
// Create W array with determined rank
70+
py::array_t<float> W_array(
71+
{static_cast<py::ssize_t>(nMixtures),
72+
static_cast<py::ssize_t>(W_rank),
73+
static_cast<py::ssize_t>(dimension)});
74+
auto W = W_array.mutable_unchecked<3>();
75+
76+
// Second pass: fill in all values
77+
for (Eigen::Index iMix = 0; iMix < nMixtures; ++iMix) {
78+
Eigen::SelfAdjointEigenSolver<Eigen::MatrixXf> Cinv_eigs(mppca.Cinv[iMix]);
79+
80+
// Eigenvalues of the inverse are the inverse of the eigenvalues:
81+
Eigen::VectorXf C_eigenvalues = Cinv_eigs.eigenvalues().cwiseInverse();
82+
83+
// Assume that it's not full rank and hence the last eigenvalue is sigma^2.
84+
const float sigma2 = C_eigenvalues(C_eigenvalues.size() - 1);
85+
assert(sigma2 >= 0);
86+
sigma_vec[iMix] = std::sqrt(sigma2);
87+
88+
// (sigma^2*I + W^T*W) has eigenvalues (sigma^2 + lambda)
89+
// where the lambda are the eigenvalues for W^T*W (which we want):
90+
C_eigenvalues.array() -= sigma2;
91+
92+
// Fill W for this mixture
93+
for (Eigen::Index jComponent = 0; jComponent < W_rank; ++jComponent) {
94+
const float scale = std::sqrt(C_eigenvalues(jComponent));
95+
for (Eigen::Index d = 0; d < dimension; ++d) {
96+
W(iMix, jComponent, d) = scale * Cinv_eigs.eigenvectors()(d, jComponent);
97+
}
98+
}
99+
100+
const float C_logDeterminant = -Cinv_eigs.eigenvalues().array().log().sum();
101+
102+
// We have:
103+
// Rpre(c) = std::log(pi(c))
104+
// - 0.5 * C_logDeterminant
105+
// - 0.5 * static_cast<double>(d) * std::log(2.0 * PI));
106+
// so std::log(pi(c)) = Rpre(c) + 0.5 * C_logDeterminant + 0.5 *
107+
// d * std::log(2.0 * PI));
108+
const float log_pi = mppca.Rpre(iMix) + 0.5f * C_logDeterminant +
109+
0.5f * static_cast<float>(mppca.d) * std::log(2.0 * momentum::pi<float>());
110+
pi(iMix) = std::exp(log_pi);
111+
}
112+
113+
// Create sigma array
114+
py::array_t<float> sigma_array(static_cast<py::ssize_t>(nMixtures));
115+
auto sigma = sigma_array.mutable_unchecked<1>();
116+
for (py::ssize_t iMix = 0; iMix < nMixtures; ++iMix) {
117+
sigma(iMix) = sigma_vec[iMix];
118+
}
119+
120+
// Create parameter indices array
121+
py::array_t<int> param_indices_array(static_cast<py::ssize_t>(dimension));
122+
auto param_indices = param_indices_array.mutable_unchecked<1>();
123+
124+
for (py::ssize_t i = 0; i < dimension; ++i) {
125+
param_indices(i) = -1; // Default to -1
126+
}
127+
128+
if (paramTransform.has_value()) {
129+
for (Eigen::Index i = 0; i < mppca.names.size() && i < dimension; ++i) {
130+
auto paramIdx = (*paramTransform)->getParameterIdByName(mppca.names[i]);
131+
if (paramIdx != momentum::kInvalidIndex) {
132+
param_indices(i) = static_cast<int>(paramIdx);
133+
}
134+
}
135+
}
136+
137+
return {pi_array, mu_array, W_array, sigma_array, param_indices_array};
138+
}
139+
140+
} // namespace pymomentum

pymomentum/geometry/array_mppca.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
8+
#pragma once
9+
10+
#include <momentum/character/parameter_transform.h>
11+
#include <momentum/math/fwd.h>
12+
13+
#include <pybind11/numpy.h>
14+
15+
#include <optional>
16+
#include <tuple>
17+
18+
namespace pymomentum {
19+
20+
/// Convert an MPPCA model to numpy arrays.
21+
/// Returns (pi, mu, W, sigma, parameter_indices) where:
22+
/// - pi: [nMixtures] mixture weights
23+
/// - mu: [nMixtures, dimension] means
24+
/// - W: [nMixtures, rank, dimension] PCA basis vectors
25+
/// - sigma: [nMixtures] noise standard deviations
26+
/// - parameter_indices: [dimension] mapping to parameter transform indices (-1 if not found)
27+
std::tuple<
28+
pybind11::array_t<float>,
29+
pybind11::array_t<float>,
30+
pybind11::array_t<float>,
31+
pybind11::array_t<float>,
32+
pybind11::array_t<int>>
33+
mppcaToArrays(
34+
const momentum::Mppca& mppca,
35+
std::optional<const momentum::ParameterTransform*> paramTransform);
36+
37+
} // namespace pymomentum

pymomentum/geometry/geometry_pybind.cpp

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "pymomentum/geometry/array_blend_shape.h"
99
#include "pymomentum/geometry/array_joint_parameters_to_positions.h"
1010
#include "pymomentum/geometry/array_kd_tree.h"
11+
#include "pymomentum/geometry/array_mppca.h"
1112
#include "pymomentum/geometry/array_parameter_transform.h"
1213
#include "pymomentum/geometry/array_skeleton_state.h"
1314
#include "pymomentum/geometry/array_vertex_normals.h"
@@ -22,11 +23,6 @@
2223
#include "pymomentum/geometry/skeleton_pybind.h"
2324
#include "pymomentum/geometry/skin_weights_pybind.h"
2425

25-
// Keep tensor versions for functions without array equivalents yet
26-
#include "pymomentum/tensor_momentum/tensor_kd_tree.h"
27-
#include "pymomentum/tensor_momentum/tensor_mppca.h"
28-
#include "pymomentum/torch_bridge.h"
29-
3026
#include <momentum/character/blend_shape.h>
3127
#include <momentum/character/character.h>
3228
#include <momentum/character/character_utility.h>
@@ -367,8 +363,8 @@ Each PPCA model is a Gaussian with mean mu and covariance (sigma^2*I + W*W^T).
367363
.def_readonly("n_dimension", &mm::Mppca::d, R"(The dimension of the parameter space.)")
368364
.def_readonly("names", &mm::Mppca::names, R"(The names of the parameters.)")
369365
.def(
370-
"to_tensors",
371-
&mppcaToTensors,
366+
"to_arrays",
367+
&mppcaToArrays,
372368
R"(Return the parameters defining the mixture of probabilistic PCA models.
373369
374370
Each PPCA model a Gaussian N(mu, cov) where the covariance matrix is
@@ -377,16 +373,16 @@ Each PPCA model a Gaussian N(mu, cov) where the covariance matrix is
377373
Note that mu is a vector of length :meth:`dimension` and W is a matrix of dimension :meth:`dimension` x q
378374
where q is the dimensionality of the PCA subspace.
379375
380-
The resulting tensors are as follows:
376+
The resulting arrays are as follows:
381377
382-
* pi: a [n]-dimensional tensor containing the mixture weights. It sums to 1.
383-
* mu: a [n x d]-dimensional tensor containing the mean pose for each mixture.
384-
* weights: a [n x d x q]-dimensional tensor containing the q vectors spanning the PCA space.
385-
* sigma: a [n]-dimensional tensor containing the uniform part of the covariance matrix.
386-
* param_idx: a [d]-dimensional tensor containing the indices of the parameters.
378+
* pi: a [n]-dimensional array containing the mixture weights. It sums to 1.
379+
* mu: a [n x d]-dimensional array containing the mean pose for each mixture.
380+
* weights: a [n x q x d]-dimensional array containing the q vectors spanning the PCA space.
381+
* sigma: a [n]-dimensional array containing the uniform part of the covariance matrix.
382+
* param_idx: a [d]-dimensional array containing the indices of the parameters.
387383
388-
:param parameter_transform: An optional parameter transform used to map the parameters; if not present, then the param_idx tensor will be empty.
389-
:return: an tuple (pi, mean, weights, sigma, param_idx) for the Probabilistic PCA model.)",
384+
:param parameter_transform: An optional parameter transform used to map the parameters; if not present, then the param_idx array will be empty.
385+
:return: a tuple (pi, mean, weights, sigma, param_idx) for the Probabilistic PCA model.)",
390386
py::arg("parameter_transform") = std::optional<const mm::ParameterTransform*>())
391387
.def("get_mixture", &getMppcaModel, py::arg("i_model"))
392388
.def_static(

pymomentum/geometry/momentum_geometry.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
#include <momentum/character/character.h>
1111
#include <momentum/math/mppca.h>
1212

13-
#include <ATen/ATen.h>
1413
#include <pybind11/numpy.h>
1514
#include <pybind11/pybind11.h>
1615

0 commit comments

Comments
 (0)