spVIPES.module.spVIPESmodule.spVIPESmodule
- class spVIPES.module.spVIPESmodule.spVIPESmodule
Bases: BaseModuleClass
PyTorch implementation of the spVIPES variational autoencoder module.
This module implements the core variational autoencoder with Product of Experts (PoE) for shared-private latent space learning. It extends scVI’s underlying VAE architecture with multi-group integration capabilities and support for different PoE strategies.
- Parameters:
groups_lengths (list of int) – List containing the number of features (genes) for each group/dataset.
groups_obs_names (list) – List of observation names for each group.
groups_var_names (list) – List of variable (gene) names for each group.
groups_obs_indices (list) – List of observation indices for each group.
groups_var_indices (list) – List of variable indices for each group.
transport_plan (torch.Tensor, optional) – Precomputed optimal transport plan matrix for PoE alignment.
pair_data (bool, default False) – Whether to use paired data for direct cell-to-cell correspondences.
use_labels (bool, default False) – Whether to use cell type labels for supervised PoE alignment.
n_labels (int, optional) – Number of unique cell type labels when using supervised alignment.
n_batch (int, default 0) – Number of batches. If 0, no batch correction is performed.
n_hidden (int, default 128) – Number of nodes per hidden layer in encoder and decoder networks.
n_dimensions_shared (int, default 25) – Dimensionality of the shared latent space capturing common features.
n_dimensions_private (int, default 10) – Dimensionality of private latent spaces capturing group-specific features.
dropout_rate (float, default 0.1) – Dropout rate for neural networks to prevent overfitting.
use_batch_norm (bool, default True) – Whether to use batch normalization in neural networks.
use_layer_norm (bool, default False) – Whether to use layer normalization in neural networks.
log_variational_inference (bool, default True) – Whether to log-transform data before encoding for numerical stability.
log_variational_generative (bool, default True) – Whether to log-transform data before decoding for numerical stability.
dispersion ({"gene", "gene-batch", "gene-cell"}, default "gene") – Level at which to model the dispersion parameter in the negative binomial distribution.
Notes
This module is based on the scVI framework and implements the variational inference described in the spVIPES paper. The Product of Experts mechanism allows for flexible integration of multiple single-cell datasets with different feature sets.
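To make the Product of Experts idea concrete, the following minimal sketch combines several one-dimensional Gaussian experts by precision weighting, which is the standard closed form for a product of Gaussians. It is not taken from the spVIPES source: the function name `product_of_gaussian_experts` and the optional standard-normal prior expert are assumptions for illustration, and the module's actual PoE (with transport plans, labels, or paired cells) may weight experts differently.

```python
def product_of_gaussian_experts(means, variances, include_prior=False):
    """Combine 1-D Gaussian experts N(mean_i, var_i) into one Gaussian.

    Hypothetical helper for illustration only; not part of the spVIPES API.
    """
    # Each expert contributes its precision (inverse variance).
    precisions = [1.0 / v for v in variances]
    weighted_means = [m * p for m, p in zip(means, precisions)]

    if include_prior:
        # Assumption: an optional N(0, 1) prior expert, as used in some PoE VAEs.
        precisions.append(1.0)
        weighted_means.append(0.0)

    total_precision = sum(precisions)
    poe_var = 1.0 / total_precision
    poe_mean = sum(weighted_means) * poe_var
    return poe_mean, poe_var


# Two equally confident experts that disagree on the mean.
mean, var = product_of_gaussian_experts([1.0, 3.0], [1.0, 1.0])
print(mean, var)  # 2.0 0.5
```

The combined variance is smaller than that of either expert, which is why a PoE posterior becomes sharper when groups agree on the shared latent representation.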
- __init__(groups_lengths, groups_obs_names, groups_var_names, groups_obs_indices, groups_var_indices, transport_plan=None, pair_data=False, use_labels=False, n_labels=None, n_batch=0, n_hidden=128, n_dimensions_shared=25, n_dimensions_private=10, dropout_rate=0.1, use_batch_norm=True, use_layer_norm=False, log_variational_inference=True, log_variational_generative=True, dispersion='gene')
Initialize the spVIPES variational autoencoder module.
This method sets up the neural network components including encoders and decoders for each group, and configures the Product of Experts mechanism based on the provided parameters. The module extends scVI’s VAE architecture for multi-group integration with shared-private latent spaces.
Notes
The initialization automatically configures the appropriate PoE strategy based on the provided inputs (transport_plan, use_labels, pair_data). The module creates separate encoders and decoders for each group while sharing the latent space structure for integration.
The dispersion parameter controls how the negative binomial distribution variance is modeled, with gene-level being the most common choice for single-cell data.
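As a rough illustration of what the dispersion level changes, the sketch below uses the mean and inverse-dispersion parameterization of the negative binomial, where the variance is `mu + mu**2 / theta`. The parameter shapes per level follow the convention used in scVI and are assumed, not confirmed, to carry over to this module. Both function names are hypothetical.

```python
def nb_variance(mu, theta):
    """Variance of a negative binomial with mean mu and inverse dispersion theta."""
    return mu + mu ** 2 / theta


def dispersion_param_shape(dispersion, n_genes, n_batch, n_cells):
    """Shape of the inverse-dispersion parameter per level.

    Assumption: follows the scVI convention; not read from the spVIPES source.
    """
    shapes = {
        "gene": (n_genes,),                # one value per gene
        "gene-batch": (n_genes, n_batch),  # one value per gene and batch
        "gene-cell": (n_cells, n_genes),   # one value per cell and gene
    }
    if dispersion not in shapes:
        raise ValueError(f"unknown dispersion level: {dispersion!r}")
    return shapes[dispersion]


print(nb_variance(10.0, 5.0))  # 30.0
print(dispersion_param_shape("gene", n_genes=2000, n_batch=3, n_cells=500))  # (2000,)
```

A small `theta` gives strong overdispersion relative to a Poisson with the same mean, while a very large `theta` brings the variance close to `mu`.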
- inference(x, batch_index, groups, global_indices, **kwargs)
Runs the encoder model.
- generative(private_stats, shared_stats, poe_stats, library, groups, batch_index)
Runs the generative model.
- get_loadings(dataset, type_latent)
Extract per-gene weights (for each Z, shape is genes by dim(Z)) in the linear decoder.
- Return type:
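- loss(tensors_by_group, inference_outputs, generative_outputs, kl_weight=1.0)
Computes the training loss from the inference and generative outputs. kl_weight (default 1.0) scales the KL divergence contribution.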
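The role of `kl_weight` can be sketched with the usual VAE objective: a reconstruction term plus weighted KL divergences. The example assumes diagonal Gaussian posteriors with a standard normal prior. The helper names and the exact set of KL terms are hypothetical, and the actual spVIPES loss may include further or differently combined terms.

```python
import math


def kl_diag_gaussian_to_standard_normal(mean, var):
    """KL( N(mean, diag(var)) || N(0, I) ), summed over latent dimensions."""
    return 0.5 * sum(v + m * m - 1.0 - math.log(v) for m, v in zip(mean, var))


def weighted_vae_loss(reconstruction_loss, kl_terms, kl_weight=1.0):
    """Reconstruction loss plus kl_weight times the sum of KL terms.

    Hypothetical helper; a generic VAE-style objective, not the spVIPES code.
    """
    return reconstruction_loss + kl_weight * sum(kl_terms)


# A posterior equal to the prior incurs no KL penalty.
print(kl_diag_gaussian_to_standard_normal([0.0, 0.0], [1.0, 1.0]))  # 0.0

# With kl_weight=0.5, the KL terms (2.0 + 3.0) contribute 2.5.
print(weighted_vae_loss(10.0, [2.0, 3.0], kl_weight=0.5))  # 12.5
```

Ramping `kl_weight` from 0 towards 1 during early training (KL warm-up) is a common way to keep the latent space from collapsing before the decoder has learned to reconstruct.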