Alex J. Cannon ([email protected])
`MST.PMDN` is a 'torch for R' implementation of a distributional regression model based on a deep Multivariate Skew t Parsimonious Mixture Density Network (MST-PMDN). The MST-PMDN framework represents complex joint output distributions as mixtures of MST ('sn') components. A volume (L), shape (A), orientation (D) (LAD) eigenvalue decomposition parameterization provides a tractable, interpretable, and parsimonious representation of the scale matrices of the MST components, while explicit modeling of skewness and heavy tails captures the asymmetric behavior and tail dependence observed in real-world data (e.g., compound events and extremes). Overall, this provides an MST likelihood-based deep generative model.
In an MST-PMDN model, the parameters of a mixture of multivariate skew t distributions describing a multivariate output are estimated by training a deep learning model with two input branches: one for tabular inputs and one for (optional) image inputs. The two branches are provided as user-defined 'torch' modules. Outputs from each branch are concatenated and passed through a dense fusion network, which feeds the MST-PMDN head. If neither branch is supplied, the tabular inputs are fed directly into the dense network. The overall network architecture is shown here.
Following the approach used in model-based clustering ('mclust'), scale matrices in the MST-PMDN head are represented using an LAD eigen-decomposition parameterization. The LAD attributes, the nu (degrees of freedom) parameter (n), and the alpha (skewness) parameter (s) can be forced to be "V"ariable or "E"qual between mixture components (plus "I"dentity for A and D). For the n and s parameters, the model can also be constrained to emulate a multivariate "N"ormal (Gaussian) distribution. Different model types are specified by setting the argument `constraint = "EIINN"`, `"VEVEV"`, etc., where each letter position in the argument corresponds, respectively, to the L, A, D, n, and s attributes. In the case of n, users can specify "F"ixed values for nu and pass a `fixed_nu` vector as an additional argument. If an element of `fixed_nu` is set to `NA`, the value of nu for that component is learned by the network. Furthermore, values of mu (means) (m), pi (mixing coefficients) (x), the volume-shape-orientation attributes (LAD), nu (n), and skewness (s) can be made independent of the inputs by specifying any combination of `constant_attr = "m"`, `"mx"`, ..., `"LADmxns"`.
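For illustration, the following sketch (not run) specifies a three-component mixture with variable volume and shape, a shared orientation, a fixed nu for the first component, and Gaussian (zero) skewness; the call mirrors the `train_mst_pmdn()` example further below, with the remaining training arguments assumed to take their defaults:

```r
# Illustrative sketch only; constraint letters correspond to L, A, D, n, s
fit <- train_mst_pmdn(
  inputs = x, outputs = y,     # tabular inputs and multivariate outputs
  n_mixtures = 3,
  constraint = "VVEFN",        # L = "V", A = "V", D = "E", n = "F", s = "N"
  fixed_nu = c(10, NA, NA),    # component 1 fixed at nu = 10; NA entries are learned
  constant_attr = "LAD"        # scale matrices do not depend on the inputs
)
```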
By combining appropriate values of `constraint` and `constant_attr`, MST-PMDN can emulate the Gaussian finite mixture models implemented by 'mclust', i.e., for unconditional density estimation or model-based clustering:
| mclust model | Description | MST-PMDN `constraint =` | MST-PMDN `constant_attr =` |
|---|---|---|---|
| EII | spherical, equal volume | `"EIINN"` | `"LADmx"` |
| VII | spherical, unequal volume | `"VIINN"` | `"LADmx"` |
| EEI | diagonal, equal volume and shape | `"EEINN"` | `"LADmx"` |
| VEI | diagonal, varying volume, equal shape | `"VEINN"` | `"LADmx"` |
| EVI | diagonal, equal volume, varying shape | `"EVINN"` | `"LADmx"` |
| VVI | diagonal, varying volume and shape | `"VVINN"` | `"LADmx"` |
| EEE | ellipsoidal, equal volume, shape, and orientation | `"EEENN"` | `"LADmx"` |
| EEV | ellipsoidal, equal volume and equal shape | `"EEVNN"` | `"LADmx"` |
| EVE | ellipsoidal, equal volume and orientation | `"EVENN"` | `"LADmx"` |
| VEE | ellipsoidal, equal shape and orientation | `"VEENN"` | `"LADmx"` |
| VEV | ellipsoidal, equal shape | `"VEVNN"` | `"LADmx"` |
| VVE | ellipsoidal, equal orientation | `"VVENN"` | `"LADmx"` |
| EVV | ellipsoidal, equal volume | `"EVVNN"` | `"LADmx"` |
| VVV | ellipsoidal, varying volume, shape, and orientation | `"VVVNN"` | `"LADmx"` |
A comparison between 'mclust' and MST-PMDN with the constraints in the table above is shown here for the 'iris' dataset. Similarly, if the constraint on the nu parameter (n) is loosened (e.g., `constraint = "VVVEN"` with `constant_attr = "LADmxn"`), MST-PMDN can emulate the model-based multivariate t clustering models provided by 'teigen'. Going one step further, removing the constraint on the skewness parameter (s) (e.g., `constraint = "VVVEE"` with `constant_attr = "LADmxns"`) implements model-based multivariate skew t clustering ('EMMIXuskew').
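As a hedged sketch of the unconditional case (assuming default training settings, and that a constant dummy column is an acceptable tabular input when every attribute is input-independent), an 'mclust'-style "VVV" Gaussian mixture on the iris measurements could be emulated as follows:

```r
# Sketch: unconditional Gaussian mixture (mclust "VVV") fit to the iris data
y_iris <- scale(as.matrix(iris[, 1:4]))              # standardized outputs
x_dummy <- matrix(1, nrow = nrow(y_iris), ncol = 1)  # dummy constant input
fit_iris <- train_mst_pmdn(
  inputs = x_dummy, outputs = y_iris,
  n_mixtures = 3,
  constraint = "VVVNN",    # Gaussian components; unconstrained L, A, D
  constant_attr = "LADmx"  # all attributes independent of the inputs
)
```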
While it can be used for model-based density estimation and clustering tasks, the primary purpose of the `MST.PMDN` package is to implement likelihood-based deep generative models. With an unconstrained or partially constrained `constant_attr`, the MST-PMDN framework allows the parameters of the mixture of multivariate Gaussian, t, or skew t distributions to depend on tabular and image covariates via user-specified `torch` modules. An example of this use case, demonstrated here through simultaneous prediction of significant wave height and storm surge, is provided below.
remotes::install_github("aljaca/MST.PMDN")
or
install.packages("https://github.com/aljaca/MST.PMDN/archive/refs/tags/v0.1.0.tar.gz", repos = NULL, type = "source")
library(MST.PMDN)
library(torch) # the example calls torch functions (nn_module, torch_manual_seed, etc.) directly
device <- ifelse(cuda_is_available(), "cuda", "cpu")
set.seed(1)
torch_manual_seed(1)
# Significant wave height, storm surge, and covariate data from Roberts Bank
x <- wave_surge$x # x and x_image should be appropriately scaled, e.g.,
x_image <- wave_surge$x_image # standardized to zero mean and unit standard deviation
y <- wave_surge$y
# The TabularModule takes an input vector of length input_dim, runs it
# through two dense layers (input_dim→32 and 32→16) each with
# batch-norm (BN), ReLU and 50 % dropout, then applies a final 16→16
# linear layer plus ReLU to produce a 16-dimensional output.
tabular_module <- nn_module(
"TabularModule",
initialize = function(
input_dim,
hidden_dims,
output_dim,
dropout_rate
) {
# Number of hidden layers
if (is.null(hidden_dims) || length(hidden_dims) == 0) {
# No hidden layers
self$n_hidden_layers <- 0
self$hidden_dims <- c()
} else if (!is.vector(hidden_dims) && !is.list(hidden_dims)) {
# Single hidden size passed, wrap into vector
self$hidden_dims <- c(hidden_dims)
self$n_hidden_layers <- length(self$hidden_dims)
} else {
# Vector or list of hidden sizes
self$hidden_dims <- hidden_dims
self$n_hidden_layers <- length(self$hidden_dims)
}
# Store output size and dropout rate
self$output_dim <- output_dim
self$dropout_rate <- dropout_rate
# Module lists for linear layers, batch-norms, dropouts
self$layers <- nn_module_list()
self$bns <- nn_module_list()
if (self$dropout_rate > 0) {
self$dropouts <- nn_module_list()
}
# Build hidden layers
current_dim <- input_dim
if (self$n_hidden_layers > 0) {
for (i in seq_len(self$n_hidden_layers)) {
# Linear transform
self$layers$append(
nn_linear(current_dim, self$hidden_dims[[i]])
)
# Batch normalization on hidden size
self$bns$append(
nn_batch_norm1d(self$hidden_dims[[i]])
)
# Optional dropout after activation
if (self$dropout_rate > 0) {
self$dropouts$append(
nn_dropout(p = self$dropout_rate)
)
}
# Update input size for next layer
current_dim <- self$hidden_dims[[i]]
}
}
# Final linear layer: last hidden (or input) → output_dim
self$layers$append(
nn_linear(current_dim, output_dim)
)
},
forward = function(x) {
# Pass through each hidden layer
for (i in seq_len(self$n_hidden_layers)) {
x <- self$layers[[i]](x) # linear
x <- self$bns[[i]](x) # batch-norm
x <- nnf_relu(x) # activation
# Apply dropout if configured
if (self$dropout_rate > 0 && !is.null(self$dropouts[[i]])) {
x <- self$dropouts[[i]](x)
}
}
# Final projection and activation
x <- self$layers[[length(self$layers)]](x)
x <- nnf_relu(x)
x
}
)
tabular_mod <- tabular_module(
input_dim = ncol(x),
hidden_dims = c(32, 16),
output_dim = 16,
dropout_rate = 0.5
)
# The ImageModule accepts a 2×32×32 image, applies a 3×3 conv (2→16)
# with BN, ReLU and 2×2 max-pool (→16×16), repeats with a 16→32 conv
# + BN, ReLU and max-pool (→8×8), flattens the 32×8×8 tensor to 2048
# units, and then projects it to 32 features via a linear layer, BN,
# and ReLU. Weight penalty (wd_image) is applied during training.
image_module <- nn_module(
"ImageModule",
initialize = function(
in_channels,
img_size,
conv_channels,
kernel_size = 3,
pool_kernel = 2,
output_dim = 32
) {
# Store output dim
self$output_dim <- output_dim
# Build conv stack
self$n_conv <- length(conv_channels)
self$convs <- nn_module_list()
self$bn_conv <- nn_module_list()
# Track spatial dim through conv+pool
spatial <- img_size
pad <- floor(kernel_size / 2)
for (i in seq_along(conv_channels)) {
in_ch <- if (i == 1) in_channels else conv_channels[i-1]
out_ch <- conv_channels[i]
# conv keeps spatial size (with padding)
self$convs$append(
nn_conv2d(
in_channels = in_ch,
out_channels = out_ch,
kernel_size = kernel_size,
padding = pad
)
)
self$bn_conv$append(nn_batch_norm2d(out_ch))
# Pooling halves spatial dims
spatial <- floor(spatial / pool_kernel)
}
# Store pooling layer and computed flatten_dim
self$pool <- nn_max_pool2d(kernel_size = pool_kernel)
self$flatten_dim <- tail(conv_channels, 1) * spatial * spatial
# Final head: linear( flatten_dim → output_dim ) + BN
self$fc <- nn_linear(self$flatten_dim, output_dim)
self$bn_fc <- nn_batch_norm1d(output_dim)
},
forward = function(x) {
# conv → BN → ReLU → pool
for (i in seq_len(self$n_conv)) {
x <- self$convs[[i]](x)
x <- self$bn_conv[[i]](x)
x <- nnf_relu(x)
x <- self$pool(x)
}
# Flatten and head
x <- torch_flatten(x, start_dim = 2)
x <- self$fc(x)
nnf_relu(self$bn_fc(x))
}
)
image_mod <- image_module(
in_channels = dim(x_image)[2],
img_size = dim(x_image)[3],
conv_channels = c(16, 32),
kernel_size = 3,
pool_kernel = 2,
output_dim = 32
)
# Define the fusion network, MST-PMDN head, and training setup
# Note: hyperparameters and number of epochs are not optimized
hidden_dim <- c(64, 32) # Hidden nodes in fusion network
drop_hidden <- 0.1 # Dropout for fusion network
n_mixtures <- 2 # 2 components in the MST mixture model
constraint <- "VVIFN" # LAD = "V"ariable-"V"ariable-"I"dentity; nu = 1 component "F"ixed; skewness = "N"ormal
fixed_nu <- c(50, NA) # nu = 50 for 1st component (i.e., approximately "N"ormal); "V"ariable for 2nd
constant_attr <- "" # All non-normal component attributes are free to vary with covariates
wd_tabular <- 0 # Weight decay for tabular module
wd_image <- 0.01 # Weight decay for image module
epochs <- 20 # Number of training epochs
lr <- 1e-3 # Initial Adam learning rate
batch_size <- 32 # Batch size
# Model training
fit <- train_mst_pmdn(
inputs = x,
outputs = y,
hidden_dim = hidden_dim,
drop_hidden = drop_hidden,
n_mixtures = n_mixtures,
constraint = constraint,
fixed_nu = fixed_nu,
constant_attr = constant_attr,
epochs = epochs,
lr = lr,
batch_size = batch_size,
wd_tabular = wd_tabular,
wd_image = wd_image,
image_inputs = x_image,
image_module = image_mod,
tabular_module = tabular_mod,
checkpoint_path = "wave_surge_checkpoint.pt",
device = device
)
# Model inference
pred <- predict_mst_pmdn(
fit$model,
new_inputs = x,
image_inputs = x_image,
device = device
)
print(names(pred))
print(pred$pi[1:3, ])
print(pred$mu[1:3, , ])
print(pred$nu[1:3, ])
# Draw samples
df_samples <- sample_mst_pmdn_df(
pred,
num_samples = 1,
device = device
)
print(head(df_samples))
Output from a more complete example using an extended dataset at the same location is shown here, here, and here.
The deep MST-PMDN implementation consists of the following key functions and modules:
- Purpose: Calculates a differentiable approximation of the univariate Student's t cumulative distribution function (CDF).
- Method: For `nu >= nu_switch`, transforms the input quantile `z` using a scaling factor derived from the degrees of freedom `nu`, then computes the standard normal CDF of the result using the error function (`erf`); one such transformation is sketched below. Otherwise, uses `pt` from R, manually inserts a gradient for `z` into the computation graph, and uses a finite-difference approximation for `nu`. Alternatively, can numerically integrate a torch-compatible probability density function `t_pdf_int`, which will be faster on GPUs.
- Context: Switches between the fast approximation and the slow `pt` (or numerical integration) calculation. Essential for the loss function's skewness calculation.
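The large-nu branch can be illustrated with the following sketch; the specific scaling shown is a standard normal approximation to the t CDF and is an assumption here, not necessarily the package's exact transformation:

```r
library(torch)
# Sketch: normal approximation to the Student's t CDF for large nu
t_cdf_approx <- function(z, nu) {
  z_t <- torch_tensor(z, dtype = torch_float())
  z_scaled <- z_t * (1 - 1 / (4 * nu)) / torch_sqrt(1 + z_t^2 / (2 * nu))
  0.5 * (1 + torch_erf(z_scaled / sqrt(2)))  # standard normal CDF via erf
}
t_cdf_approx(1.5, nu = 30)  # compare with pt(1.5, df = 30)
```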
- Purpose: Generates random samples from a Gamma distribution using `torch`.
- Method: Wraps R's `rgamma` function, vectorizes it using `mapply`, and converts the output to a `torch` tensor on the specified device (a minimal sketch follows below).
- Context: Used within the `sample_mst_pmdn` function to generate the scaling variable needed for sampling from the t-distribution component of the skew t.
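A minimal sketch of such a wrapper (illustrative name and signature, not the package's exact API):

```r
library(torch)
# Draw one Gamma variate per (shape, rate) pair in R and move them to torch
sample_gamma_sketch <- function(shape, rate, device = "cpu") {
  draws <- mapply(function(a, b) rgamma(1, shape = a, rate = b), shape, rate)
  torch_tensor(draws, dtype = torch_float(), device = device)
}
sample_gamma_sketch(shape = c(2, 5), rate = c(1, 1))
```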
- Purpose: Constructs a batch of orthogonal matrices (representing rotation/orientation `D`).
- Method: Uses the matrix exponential of a skew-symmetric matrix, where the input `params` parameterize the upper triangle of the skew-symmetric matrix (illustrated in the sketch below).
- Context: Used in the main model (`define_mst_pmdn`) to generate the orientation component `D` of the LAD decomposition when orientation is not fixed to the identity matrix.
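The construction can be illustrated with a small sketch (the package's internal version operates on batched, differentiable torch tensors; the function name here is illustrative):

```r
library(torch)
# One d x d orthogonal matrix from d*(d-1)/2 free parameters
make_orthogonal <- function(params, d) {
  S <- matrix(0, d, d)
  S[upper.tri(S)] <- params  # fill the upper triangle
  S <- S - t(S)              # skew-symmetric: t(S) = -S
  torch_matrix_exp(torch_tensor(S, dtype = torch_float()))  # exp(S) is orthogonal
}
D <- make_orthogonal(pi / 6, d = 2)  # 2-d rotation by pi/6
torch_matmul(D, D$t())               # approximately the identity matrix
```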
- Purpose: Initializes the component mean parameters (`mu`) using k-means clustering.
- Method: Applies k-means to the training output data to find initial centroids. These centroids initialize either the `model$mu` parameters (if constant) or the bias of the `model$fc_mu` layer (if network-dependent), with the initial weights set to zero (see the sketch below).
- Context: A heuristic that provides a potentially better starting point for training than random initialization, aiming for faster convergence.
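The idea reduces to seeding the component means with k-means centroids of the outputs, e.g.:

```r
# Sketch: k-means centroids of the training outputs seed the initial mu
km <- kmeans(y, centers = n_mixtures, nstart = 10)
init_mu <- km$centers  # n_mixtures x d matrix of initial component means
```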
- Purpose: Implements a linear layer with weight normalization.
- Method: Decomposes the weight matrix `W` into a direction `V` and a magnitude `g`, learning these instead of `W` directly (a compact sketch follows below).
- Context: Used for most linear layers within the network architecture (hidden layers and parameter prediction heads) to potentially improve training stability and convergence speed.
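A compact sketch of a weight-normalized linear layer in this spirit (illustrative; not the package's internal module):

```r
library(torch)
weight_norm_linear_sketch <- nn_module(
  "WeightNormLinearSketch",
  initialize = function(in_features, out_features) {
    self$V <- nn_parameter(torch_randn(out_features, in_features) * 0.05)  # direction
    self$g <- nn_parameter(torch_ones(out_features))                       # magnitude
    self$bias <- nn_parameter(torch_zeros(out_features))
  },
  forward = function(x) {
    v_norm <- torch_sqrt(torch_sum(self$V^2, dim = 2, keepdim = TRUE))  # row norms of V
    W <- self$V / v_norm * self$g$unsqueeze(2)                          # W = g * V / ||V||
    nnf_linear(x, W, self$bias)
  }
)
```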
- Purpose: Initializes the parameters (`V`, `g`) of a `weight_norm_linear` layer.
- Method: Uses Kaiming (He) normal initialization for the direction `V` and sets the initial magnitude `g` accordingly (sketched below).
- Context: Applied recursively to the model to ensure proper initialization of all weight-normalized layers.
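Applied to the hypothetical sketch layer above, the initialization could look like this:

```r
# Kaiming-normal direction V; g set so the effective initial weights equal V
init_weight_norm_sketch <- function(m) {
  if (inherits(m, "WeightNormLinearSketch")) {
    nn_init_kaiming_normal_(m$V)
    with_no_grad(m$g$copy_(torch_sqrt(torch_sum(m$V^2, dim = 2))))
  }
}
# model$apply(init_weight_norm_sketch)  # walk all submodules recursively
```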
- Purpose: Defines the main MST-PMDN neural network architecture.
- Method:
  - Processes optional image and tabular inputs through dedicated modules or uses raw inputs.
  - Concatenates features and passes them through a hidden MLP using `weight_norm_linear` layers.
  - Predicts mixture parameters (`pi`, `mu`, `L`, `A`, `D`, `nu`, `alpha`) using separate output heads (mostly `weight_norm_linear`, or `nn_parameter` if constant).
  - Applies constraints (Variable, Equal, Identity, Normal approximation, Fixed) to parameters based on the configuration.
  - Constructs the full scale matrix `Sigma = L * D * diag(A) * D^T` and computes its Cholesky decomposition (`scale_chol`) for each component (a single-component toy sketch follows this list).
- Output: Returns a list containing all mixture parameters (`pi`, `mu`, `scale_chol`, `nu`, `alpha`) and LAD components (`L`, `A`, `D`), batched appropriately.
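For a single component, the LAD algebra reduces to the following toy sketch (reusing the hypothetical `make_orthogonal` sketch above):

```r
library(torch)
L_vol <- 2.0                                               # volume (scalar)
A_shape <- torch_tensor(c(2, 0.5), dtype = torch_float())  # shape: diagonal with det = 1
D_orient <- make_orthogonal(pi / 6, d = 2)                 # orientation (orthogonal)
Sigma <- L_vol * torch_matmul(D_orient, torch_matmul(torch_diag(A_shape), D_orient$t()))
scale_chol <- linalg_cholesky(Sigma)                       # lower-triangular factor used downstream
```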
- Purpose: Computes the negative log-likelihood (NLL) loss.
- Method:
  - For each data point and mixture component `k`:
    - Calculates residuals: `diff = target - mu_k`.
    - Standardizes residuals: `v = scale_chol_k^{-1} * diff`.
    - Calculates the squared Mahalanobis distance: `maha = ||v||^2`.
    - Calculates the log-PDF of the symmetric multivariate t-distribution part using `maha`, `log_det(Sigma_k)`, and `nu_k`.
    - Calculates the skewness adjustment term `log(2 * T_CDF(alpha_k^T w, df = nu_k + d))`, where `w` is proportional to `v`, using `t_cdf`.
  - Combines the component log-densities using the mixture weights `pi` via `logsumexp`.
  - Returns the mean NLL over the batch (a plain-R sketch of the per-component density follows this list).
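The per-component computation can be sketched in plain R as follows (the package performs the same steps with differentiable torch operations and its own `t_cdf`; the function name is illustrative):

```r
# Toy sketch: skew-t log-density for one observation and one component
skew_t_logdens_sketch <- function(y, mu, scale_chol, nu, alpha) {
  d <- length(y)
  v <- forwardsolve(scale_chol, y - mu)      # standardized residuals
  maha <- sum(v^2)                           # squared Mahalanobis distance
  log_det <- 2 * sum(log(diag(scale_chol)))  # log|Sigma|
  log_t <- lgamma((nu + d) / 2) - lgamma(nu / 2) - (d / 2) * log(nu * pi) -
    0.5 * log_det - ((nu + d) / 2) * log(1 + maha / nu)
  w <- v * sqrt((nu + d) / (nu + maha))      # rescaled residuals
  log_t + log(2 * pt(sum(alpha * w), df = nu + d))  # add skewness adjustment
}
```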
- Purpose: On-device generation of random samples from the predicted mixture distribution.
- Method:
  - Samples component indices based on `pi`.
  - Gathers parameters for the selected components.
  - Generates t-distribution scaling factors `W` using `sample_gamma`.
  - Generates a standard multivariate skew-normal sample `X` based on the component's `alpha` (via `delta`).
  - Transforms the standard sample `X` to the output space: `Y = mu_s + W * (scale_chol_s @ X)`.
- Output: Returns a list with `samples`, a torch tensor of shape `[S, B, d]`, where `S` is `num_samples`, `B` is the batch size (rows of the predictor matrix), and `d` is the response dimension, and `components`, a torch tensor of shape `[S, B]` giving the 1-based component label (1..`G`) used for each draw. A usage sketch follows this list.
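A hedged usage sketch, assuming `sample_mst_pmdn()` accepts the same arguments as the `sample_mst_pmdn_df()` call in the example above:

```r
draws <- sample_mst_pmdn(pred, num_samples = 100, device = device)
draws$samples$shape     # [S, B, d]: 100 x nrow(x) x ncol(y)
draws$components$shape  # [S, B]
```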
- Purpose: Generates random samples from the predicted mixture distribution and returns a formatted R data frame.
- Method:
  - Samples component indices based on `pi`.
  - Gathers parameters for the selected components.
  - Generates t-distribution scaling factors `W` using `sample_gamma`.
  - Generates a standard multivariate skew-normal sample `X` based on the component's `alpha` (via `delta`).
  - Transforms the standard sample `X` to the output space: `Y = mu_s + W * (scale_chol_s @ X)`.
- Output: A data frame with `num_samples * batch_size` rows (see the post-processing example after this list) containing
  - the simulated response variables in columns `V1 ... Vd`;
  - `row`, the index (1..`B`) of the predictor row that generated the draw;
  - `draw`, the draw number (1..`num_samples`) for that predictor row;
  - `comp`, a factor giving the 1-based component label (1..`G`).
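With more than one draw per predictor row, the returned data frame can be summarized directly in R (here assuming the two-dimensional `wave_surge` response, so columns `V1` and `V2`):

```r
df_many <- sample_mst_pmdn_df(pred, num_samples = 100, device = device)
# Mean simulated response for each predictor row
aggregate(cbind(V1, V2) ~ row, data = df_many, FUN = mean)
# How often each mixture component was selected
table(df_many$comp)
```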
- Purpose: Manages the model training process.
- Method: Includes data loading, model/optimizer setup (with k-means init), training loop (loss calculation, backpropagation, optimization), validation, learning rate scheduling, checkpointing, and early stopping. Handles optional image inputs correctly.
- Output: Trained model, loss history, and training/validation indices.
- Purpose: Performs inference using the trained model.
- Method: Runs a forward pass on new inputs in evaluation mode (`torch_no_grad()`).
- Output: Raw model output list containing mixture parameters for the new inputs.
Ambrogioni, L., Güçlü, U., van Gerven, M. A., & Maris, E. (2017). The kernel mixture network: A nonparametric method for conditional density estimation of continuous random variables. arXiv:1705.07111.
Andrews, J. L., & McNicholas, P. D. (2012). Model-based clustering, classification, and discriminant analysis via mixtures of multivariate t-distributions: the tEIGEN family. Statistics and Computing, 22, 1021-1029.
Andrews, J. L., Wickins, J. R., Boers, N. M., & McNicholas, P. D. (2018). teigen: An R package for model-based clustering and classification via the multivariate t distribution. Journal of Statistical Software, 83, 1-32.
Azzalini, A., & Capitanio, A. (2003). Distributions generated by perturbation of symmetry with emphasis on a multivariate skew t-distribution. Journal of the Royal Statistical Society Series B: Statistical Methodology, 65(2), 367-389.
Banfield, J. D., & Raftery, A. E. (1993). Model-based Gaussian and non-Gaussian clustering. Biometrics, 49(3), 803-821.
Celeux, G., & Govaert, G. (1995). Gaussian parsimonious clustering models. Pattern Recognition, 28(5), 781-793.
Falbel, D., & Luraschi, J. (2025). torch: Tensors and Neural Networks with 'GPU' Acceleration. R package version 0.14.2, https://github.com/mlverse/torch, https://torch.mlverse.org/docs.
Fraley, C., & Raftery, A. E. (1998). How many clusters? Which clustering method? Answers via model-based cluster analysis. The Computer Journal, 41(8), 578-588.
Fraley, C., & Raftery, A. E. (2002). Model-based clustering, discriminant analysis, and density estimation. Journal of the American Statistical Association, 97(458), 611-631.
Kingma, D. P., & Ba, J. (2015). Adam: A method for stochastic optimization. Proceedings of the 3rd International Conference on Learning Representations (ICLR 2015), San Diego, CA, USA. arXiv:1412.6980.
Klein, N. (2024). Distributional regression for data analysis. Annual Review of Statistics and Its Application, 11, 321-346.
Lee, S., & McLachlan, G. J. (2014). Finite mixtures of multivariate skew t-distributions: some recent and new results. Statistics and Computing, 24, 181-202.
Peel, D., & McLachlan, G. J. (2000). Robust mixture modelling using the t distribution. Statistics and Computing, 10, 339-348.
Scrucca, L., Fop, M., Murphy, T. B., & Raftery, A. E. (2016). mclust 5: Clustering, classification and density estimation using Gaussian finite mixture models. The R Journal, 8(1), 289-317.
Williams, P. M. (1996). Using neural networks to model conditional multivariate densities. Neural Computation, 8(4), 843-854.