API Reference
Core Classes
spVIPES Model
The main model class for shared-private variational inference.
- class spVIPES.model.spvipes.spVIPES
Bases: MultiGroupTrainingMixin, BaseModelClass
Implementation of the spVIPES model.
spVIPES (shared-private Variational Inference with Product of Experts and Supervision) is a method for integrating multi-group single-cell datasets using a shared-private latent space approach. The model learns both shared representations (common across groups) and private representations (group-specific) through a Product of Experts (PoE) framework.
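This page does not spell out how the per-group experts are combined. As an illustration only, the sketch below shows the standard closed-form product of diagonal Gaussian experts used by PoE-style variational autoencoders: the joint precision is the sum of the experts' precisions and the joint mean is their precision-weighted mean. The function gaussian_poe is hypothetical, not part of the spVIPES API, and the label-, transport- and cluster-based variants described below may pair or weight cells differently.

```python
import math
from typing import List, Sequence, Tuple


def gaussian_poe(
    expert_means: Sequence[Sequence[float]],
    expert_logvars: Sequence[Sequence[float]],
) -> Tuple[List[float], List[float]]:
    """Product of diagonal Gaussian experts, one list entry per expert (group).

    For each latent dimension, the product of N(mu_i, var_i) densities is
    proportional to a Gaussian whose precision is the sum of the experts'
    precisions and whose mean is the precision-weighted mean of the experts.
    """
    n_dims = len(expert_means[0])
    joint_means, joint_logvars = [], []
    for d in range(n_dims):
        precisions = [math.exp(-lv[d]) for lv in expert_logvars]  # 1 / var_i
        total = sum(precisions)
        mean = sum(p * mu[d] for p, mu in zip(precisions, expert_means)) / total
        joint_means.append(mean)
        joint_logvars.append(-math.log(total))  # log of the joint variance 1 / total
    return joint_means, joint_logvars


# Two groups disagree on dimension 0; the second group is more certain on dimension 1.
means, logvars = gaussian_poe(
    expert_means=[[0.0, 1.0], [2.0, 3.0]],
    expert_logvars=[[0.0, 0.0], [0.0, math.log(0.25)]],
)
print(means, [math.exp(lv) for lv in logvars])
```

The more confident expert (smaller variance) pulls the joint mean toward itself, which is why the second dimension lands closer to 3.0 than to 1.0.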
- Parameters:
adata (AnnData) – AnnData object that has been registered via setup_anndata().
n_hidden (int, default 128) – Number of nodes per hidden layer in the neural networks.
n_dimensions_shared (int, default 25) – Dimensionality of the shared latent space. This space captures features common across all groups/datasets.
n_dimensions_private (int, default 10) – Dimensionality of the private latent spaces. Each group gets its own private latent space of this dimensionality.
dropout_rate (float, default 0.1) – Dropout rate for neural networks to prevent overfitting.
**model_kwargs – Additional keyword arguments passed to the underlying module.
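Examples
Basic usage with cell type labels:
>>> import numpy as np
>>> import spVIPES
>>> adata = spVIPES.data.prepare_adatas({"dataset1": dataset1, "dataset2": dataset2})
>>> spVIPES.model.spVIPES.setup_anndata(adata, groups_key="groups", label_key="cell_type")
>>> # train() and get_latent_representation() both take the per-group cell indices
>>> group_indices_list = [np.where(adata.obs["groups"] == g)[0].tolist() for g in adata.obs["groups"].unique()]
>>> model = spVIPES.model.spVIPES(adata)
>>> model.train(group_indices_list)
>>> latents = model.get_latent_representation(group_indices_list)
Usage with optimal transport (group_indices_list built as above):
>>> spVIPES.model.spVIPES.setup_anndata(adata, groups_key="groups", transport_plan_key="transport_plan")
>>> model = spVIPES.model.spVIPES(adata)
>>> model.train(group_indices_list)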
Notes
We recommend setting n_dimensions_private < n_dimensions_shared for optimal performance
The model automatically selects the appropriate PoE variant based on provided inputs
GPU acceleration is strongly recommended for large datasets
- __init__(adata, n_hidden=128, n_dimensions_shared=25, n_dimensions_private=10, dropout_rate=0.1, **model_kwargs)
- classmethod setup_anndata(cls, adata, groups_key, match_clusters=False, transport_plan_key=None, label_key=None, batch_key=None, layer=None, **kwargs)
Set up AnnData object for spVIPES model.
This method registers the AnnData object with the model, configuring the appropriate data fields and PoE strategy based on the provided parameters. The method automatically determines whether to use label-based PoE, optimal transport PoE, or cluster-based PoE.
- Parameters:
adata (AnnData) – Annotated data object containing the single-cell data to be integrated.
groups_key (str) – Key in adata.obs that defines the grouping of cells (e.g., dataset, batch, condition). This determines which cells belong to which group for integration.
match_clusters (bool, default False) – Whether to match clusters when using optimal transport. If True, enables cluster-based PoE, which automatically matches cell clusters between groups.
transport_plan_key (str, optional) – Key in adata.uns containing the precomputed optimal transport plan. If provided, enables optimal transport PoE for data integration.
label_key (str, optional) – Key in adata.obs containing cell type labels. If provided, enables label-based PoE, which uses supervised alignment based on cell types.
batch_key (str, optional) – Key in adata.obs for batch information to enable batch effect correction.
layer (str, optional) – Key in adata.layers to use for the expression data. If None, uses adata.X.
**kwargs – Additional keyword arguments passed to the parent setup method.
- Return type:
None
- Returns:
None. The method modifies the AnnData object in place and registers it with the model.
Notes
Priority of PoE strategies (when multiple options are available):
1. Label-based PoE (if label_key is provided)
2. Optimal transport PoE (if transport_plan_key is provided)
3. Cluster-based PoE (if match_clusters=True)
Examples
Basic setup with groups only:
>>> spVIPES.model.spVIPES.setup_anndata(adata, groups_key="dataset")
Setup with cell type supervision:
>>> spVIPES.model.spVIPES.setup_anndata(adata, groups_key="dataset", label_key="cell_type")
Setup with optimal transport:
>>> spVIPES.model.spVIPES.setup_anndata(adata, groups_key="dataset", transport_plan_key="transport_matrix")
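The priority order in the Notes can be read as a simple decision rule. The helper below is a hypothetical sketch of that rule, not a spVIPES function; it follows the documented order literally and returns None for the groups-only case, because this page does not say which strategy applies there.

```python
from typing import Optional


def select_poe_strategy(
    label_key: Optional[str] = None,
    transport_plan_key: Optional[str] = None,
    match_clusters: bool = False,
) -> Optional[str]:
    """Hypothetical helper that applies the documented priority order literally."""
    if label_key is not None:
        return "label"  # 1. label-based PoE
    if transport_plan_key is not None:
        return "optimal_transport"  # 2. optimal transport PoE
    if match_clusters:
        return "cluster"  # 3. cluster-based PoE
    return None  # groups only: not covered by the documented priority list


print(select_poe_strategy(label_key="cell_type", transport_plan_key="transport_plan"))  # label
```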
- get_latent_representation(group_indices_list, adata=None, indices=None, normalized=False, give_mean=True, mc_samples=5000, batch_size=None, drop_last=None)
Return the latent representation for each cell.
- Parameters:
group_indices_list (list[list[int]]) – List of lists containing the indices of cells in each of the groups used as input for spVIPES.
adata (Optional[AnnData] (default: None)) – AnnData object with equivalent structure to initial AnnData. If None, defaults to the AnnData object used to initialize the model.
indices (Optional[Sequence[int]] (default: None)) – Indices of cells in adata to use. If None, all cells are used.
normalized (bool (default: False)) – Whether to return the normalized (softmaxed) cell embedding.
give_mean (bool (default: True)) – Whether to return the mean of the distribution or a sample from it.
mc_samples (int (default: 5000)) – For distributions with no closed-form mean (e.g., logistic normal), how many Monte Carlo samples to take for computing the mean.
batch_size (Optional[int] (default: None)) – Minibatch size for data loading into the model. Defaults to scvi.settings.batch_size.
drop_last (Optional[bool] (default: None)) – Whether to drop the last incomplete batch. If None, this is determined automatically: True when using paired PoE, False otherwise.
- Return type:
- Returns:
Low-dimensional latent representation for each cell.
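group_indices_list is a list of lists of positional cell indices, one inner list per group. A minimal way to build it, assuming the per-cell group labels are the values of adata.obs["groups"] written by prepare_adatas(), is sketched below in plain Python (with an AnnData object you would pass list(adata.obs["groups"])). The helper name is hypothetical, and the group order follows first appearance, which you may want to check against adata.uns["groups_mapping"].

```python
from typing import Dict, Hashable, List, Sequence


def build_group_indices(group_labels: Sequence[Hashable]) -> List[List[int]]:
    """Return one list of positional cell indices per group, in order of first appearance."""
    positions: Dict[Hashable, List[int]] = {}
    for i, label in enumerate(group_labels):
        positions.setdefault(label, []).append(i)
    return list(positions.values())  # dicts preserve insertion order (Python 3.7+)


labels = ["human", "human", "mouse", "human", "mouse"]  # e.g. list(adata.obs["groups"])
print(build_group_indices(labels))  # [[0, 1, 3], [2, 4]]
```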
- get_loadings()
Extract per-gene weights in the linear decoder.
Shape is genes by n_latent.
- Return type:
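Because the loadings matrix is genes by n_latent, a common next step is to list the highest-weighted genes for each latent factor. The sketch below does this on a nested list with a hypothetical helper. The return type of get_loadings() is not stated on this page, so converting its output (for example a pandas DataFrame via .values.tolist() and .index) is an assumption. Ranking by raw value surfaces positively loaded genes; rank by absolute value if negative weights matter too.

```python
from typing import Dict, List, Sequence


def top_genes_per_factor(
    loadings: Sequence[Sequence[float]],
    gene_names: Sequence[str],
    n_top: int = 3,
) -> Dict[int, List[str]]:
    """Rank genes by their weight in each column of a genes x n_latent matrix."""
    n_factors = len(loadings[0])
    result: Dict[int, List[str]] = {}
    for k in range(n_factors):
        # Sort gene row indices by the weight in column k, largest first.
        order = sorted(range(len(gene_names)), key=lambda g: loadings[g][k], reverse=True)
        result[k] = [gene_names[g] for g in order[:n_top]]
    return result


loadings = [[0.9, -0.1], [0.2, 0.8], [-0.5, 0.3]]  # 3 genes x 2 latent dimensions
print(top_genes_per_factor(loadings, ["GeneA", "GeneB", "GeneC"], n_top=2))
# {0: ['GeneA', 'GeneB'], 1: ['GeneB', 'GeneC']}
```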
- train(group_indices_list, batch_size=128, max_epochs=None, use_gpu=None, train_size=0.9, validation_size=None, early_stopping=False, plan_kwargs=None, n_steps_kl_warmup=None, n_epochs_kl_warmup=400, **trainer_kwargs)
Train a multigroup spVIPES model.
This method trains the model using a custom data splitter that handles multiple groups of cells separately while maintaining the shared-private latent space learning objective.
- Parameters:
group_indices_list (list[list[int]]) – List of indices corresponding to each group of samples. Each inner list contains the indices for cells belonging to that specific group.
max_epochs (int, optional) – Number of passes through the dataset. If None, defaults to np.min([round((20000 / n_cells) * 400), 400]).
use_gpu (str, int, bool, optional) – GPU usage specification. Uses the default GPU if available (if None or True), the GPU with the given index (if int), the named device (if str, e.g., “cuda:0”), or the CPU (if False).
train_size (float, default 0.9) – Size of the training set in the range [0.0, 1.0].
validation_size (float, optional) – Size of the validation set. If None, defaults to 1 - train_size. If train_size + validation_size < 1, the remaining cells belong to the test set.
batch_size (int, default 128) – Mini-batch size to use during training.
early_stopping (bool, default False) – Whether to perform early stopping. Additional arguments can be passed in **trainer_kwargs.
plan_kwargs (dict, optional) – Keyword arguments for the training plan. Arguments passed to train() will overwrite values present in plan_kwargs, when appropriate.
n_steps_kl_warmup (int, optional) – Number of training steps for KL warmup. Takes precedence over n_epochs_kl_warmup.
n_epochs_kl_warmup (int, default 400) – Number of epochs for KL divergence warmup.
**trainer_kwargs – Additional keyword arguments for the trainer.
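Two of the defaults above reduce to simple arithmetic. The sketch evaluates the documented max_epochs formula in plain Python and, as an assumption, models KL warmup as the linear ramp used by scvi-tools training plans: the weight grows from 0 to 1 over n_epochs_kl_warmup epochs (with n_steps_kl_warmup the same ramp would be indexed by training step instead). Both function names are hypothetical.

```python
def default_max_epochs(n_cells: int) -> int:
    """Documented default: np.min([round((20000 / n_cells) * 400), 400])."""
    return min(round((20000 / n_cells) * 400), 400)


def kl_weight(epoch: int, n_epochs_kl_warmup: int = 400, max_kl_weight: float = 1.0) -> float:
    """Assumed linear KL warmup: 0 at epoch 0, max_kl_weight from n_epochs_kl_warmup on."""
    if n_epochs_kl_warmup <= 0:
        return max_kl_weight  # no warmup
    return min(max_kl_weight, max_kl_weight * epoch / n_epochs_kl_warmup)


print(default_max_epochs(5_000))    # 400 (capped)
print(default_max_epochs(100_000))  # 80
print(kl_weight(100))               # 0.25
```

Small datasets hit the 400-epoch cap, while larger ones get proportionally fewer passes, keeping the total number of gradient steps roughly bounded.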
- Return type:
None
- Returns:
None. The model is trained in place.
Notes
This method uses a specialized MultiGroupDataSplitter that ensures proper handling of multiple cell groups during training, maintaining the integrity of the shared-private latent space learning.
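The internals of MultiGroupDataSplitter are not documented here. One plausible reading of "handles multiple groups of cells separately" is that each group is split on its own, so that every group is represented in the training, validation and test sets. The sketch below illustrates that idea with the documented train_size and validation_size semantics; split_groups and its rounding rule are hypothetical.

```python
import random
from typing import List, Optional, Sequence, Tuple


def split_groups(
    group_indices_list: Sequence[Sequence[int]],
    train_size: float = 0.9,
    validation_size: Optional[float] = None,
    seed: int = 0,
) -> List[Tuple[List[int], List[int], List[int]]]:
    """Hypothetical per-group split into (train, validation, test) index lists."""
    if validation_size is None:
        validation_size = 1.0 - train_size  # documented default
    if train_size + validation_size > 1.0 + 1e-9:
        raise ValueError("train_size + validation_size must not exceed 1")
    rng = random.Random(seed)
    splits = []
    for indices in group_indices_list:
        shuffled = list(indices)
        rng.shuffle(shuffled)  # each group is split independently
        n = len(shuffled)
        n_train = round(train_size * n)
        n_val = min(round(validation_size * n), n - n_train)
        splits.append(
            (
                shuffled[:n_train],
                shuffled[n_train:n_train + n_val],
                shuffled[n_train + n_val:],  # leftover cells form the test set
            )
        )
    return splits


splits = split_groups([list(range(10)), list(range(10, 30))])
print([tuple(len(part) for part in s) for s in splits])  # [(9, 1, 0), (18, 2, 0)]
```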
spVIPES Module
The PyTorch Lightning module implementing the variational autoencoder.
- class spVIPES.module.spVIPESmodule.spVIPESmodule
Bases: BaseModuleClass
PyTorch implementation of spVIPES variational autoencoder module.
This module implements the core variational autoencoder with Product of Experts (PoE) for shared-private latent space learning. It extends scVI’s underlying VAE architecture with multi-group integration capabilities and support for different PoE strategies.
- Parameters:
groups_lengths (list of int) – List containing the number of features (genes) for each group/dataset.
groups_obs_names (list) – List of observation names for each group.
groups_var_names (list) – List of variable (gene) names for each group.
groups_obs_indices (list) – List of observation indices for each group.
groups_var_indices (list) – List of variable indices for each group.
transport_plan (torch.Tensor, optional) – Precomputed optimal transport plan matrix for PoE alignment.
pair_data (bool, default False) – Whether to use paired data for direct cell-to-cell correspondences.
use_labels (bool, default False) – Whether to use cell type labels for supervised PoE alignment.
n_labels (int, optional) – Number of unique cell type labels when using supervised alignment.
n_batch (int, default 0) – Number of batches. If 0, no batch correction is performed.
n_hidden (int, default 128) – Number of nodes per hidden layer in encoder and decoder networks.
n_dimensions_shared (int, default 25) – Dimensionality of the shared latent space capturing common features.
n_dimensions_private (int, default 10) – Dimensionality of private latent spaces capturing group-specific features.
dropout_rate (float, default 0.1) – Dropout rate for neural networks to prevent overfitting.
use_batch_norm (bool, default True) – Whether to use batch normalization in neural networks.
use_layer_norm (bool, default False) – Whether to use layer normalization in neural networks.
log_variational_inference (bool, default True) – Whether to log-transform data before encoding for numerical stability.
log_variational_generative (bool, default True) – Whether to log-transform data before decoding for numerical stability.
dispersion ({"gene", "gene-batch", "gene-cell"}, default "gene") – Level at which to model the dispersion parameter in the negative binomial distribution.
Notes
This module is based on the scVI framework and implements the variational inference described in the spVIPES paper. The Product of Experts mechanism allows for flexible integration of multiple single-cell datasets with different feature sets.
- __init__(groups_lengths, groups_obs_names, groups_var_names, groups_obs_indices, groups_var_indices, transport_plan=None, pair_data=False, use_labels=False, n_labels=None, n_batch=0, n_hidden=128, n_dimensions_shared=25, n_dimensions_private=10, dropout_rate=0.1, use_batch_norm=True, use_layer_norm=False, log_variational_inference=True, log_variational_generative=True, dispersion='gene')
Initialize the spVIPES variational autoencoder module.
This method sets up the neural network components including encoders and decoders for each group, and configures the Product of Experts mechanism based on the provided parameters. The module extends scVI’s VAE architecture for multi-group integration with shared-private latent spaces.
Notes
The initialization automatically configures the appropriate PoE strategy based on the provided inputs (transport_plan, use_labels, pair_data). The module creates separate encoders and decoders for each group while sharing the latent space structure for integration.
The dispersion parameter controls how the negative binomial distribution variance is modeled, with gene-level being the most common choice for single-cell data.
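For reference, the sketch below writes out the negative binomial log-likelihood in the mean/inverse-dispersion parameterization common in scVI-style models (variance mu + mu**2 / theta) and lists the theta shapes that the three dispersion options correspond to under the scvi-tools convention. Whether spVIPES uses exactly these shapes is an assumption, and both helpers are hypothetical.

```python
import math
from typing import Tuple


def nb_log_prob(x: int, mu: float, theta: float) -> float:
    """Log-pmf of a negative binomial with mean mu > 0 and inverse dispersion theta > 0.

    Variance is mu + mu**2 / theta, so a larger theta means less overdispersion.
    """
    log_theta_mu = math.log(theta + mu)
    return (
        math.lgamma(x + theta)
        - math.lgamma(theta)
        - math.lgamma(x + 1)
        + theta * (math.log(theta) - log_theta_mu)
        + x * (math.log(mu) - log_theta_mu)
    )


def dispersion_shape(dispersion: str, n_genes: int, n_batch: int, n_cells: int) -> Tuple[int, ...]:
    """Assumed shape of theta for each documented dispersion option."""
    shapes = {
        "gene": (n_genes,),               # one value per gene
        "gene-batch": (n_genes, n_batch),  # one value per gene and batch
        "gene-cell": (n_cells, n_genes),   # one value per gene and cell
    }
    return shapes[dispersion]


print(round(sum(math.exp(nb_log_prob(x, 5.0, 2.0)) for x in range(500)), 6))  # ~1.0
```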
- inference(x, batch_index, groups, global_indices, **kwargs)
Runs the encoder model.
- generative(private_stats, shared_stats, poe_stats, library, groups, batch_index)
Runs the generative model.
- get_loadings(dataset, type_latent)
Extract per-gene weights (for each Z, shape is genes by dim(Z)) in the linear decoder.
- Return type:
- loss(tensors_by_group, inference_outputs, generative_outputs, kl_weight=1.0)
Loss function.
Neural Network Components
Encoder
- class spVIPES.nn.networks.Encoder
Bases: Module
Variational encoder network for spVIPES.
This encoder maps input gene expression data to latent representations using a variational approach. It outputs both mean and variance parameters for the latent distribution, enabling sampling during training and inference.
- Parameters:
n_input (int) – Number of input features (genes) in the expression data.
n_topics (int) – Number of output dimensions in the latent space (topics/factors).
hidden (int, default 100) – Number of hidden units in the fully connected layers.
dropout (float, default 0.1) – Dropout rate applied to hidden layers for regularization.
n_cat_list (Iterable[int], optional) – List of categorical covariate dimensions. Each element represents the number of categories for a categorical covariate (e.g., batch).
groups (str, optional) – Group identifier for this encoder instance.
Notes
The encoder uses a two-layer fully connected architecture with ReLU activations and batch normalization on the output layers. It outputs parameters for a normal distribution in latent space, following the variational autoencoder framework.
The forward pass returns both the latent representation (theta) and intermediate statistics needed for the variational objective.
- __init__(n_input, n_topics, hidden=100, dropout=0.1, n_cat_list=None, groups=None)
- forward(data, specie, *cat_list)
Forward pass through the variational encoder.
- Parameters:
- Returns:
dict
Dictionary containing encoder outputs:
logtheta_loc : torch.Tensor - Mean of latent distribution
logtheta_logvar : torch.Tensor - Log variance of latent distribution
logtheta_scale : torch.Tensor - Standard deviation of latent distribution
log_z : torch.Tensor - Sampled latent variable (log space)
theta : torch.Tensor - Normalized latent representation (simplex)
qz : torch.distributions.Normal - Latent distribution object
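The output names suggest the usual relationships between these quantities: the standard deviation is exp(logvar / 2), log_z is a reparameterized sample loc + scale * eps, and theta is the softmax of log_z, which places it on the simplex. The single-cell sketch below assumes exactly those relationships; encoder_head is hypothetical and omits the network itself.

```python
import math
import random
from typing import Dict, List, Optional


def encoder_head(
    loc: List[float], logvar: List[float], rng: Optional[random.Random] = None
) -> Dict[str, List[float]]:
    """Derive scale, a sampled log_z and the simplex-valued theta for one cell."""
    rng = rng or random.Random(0)
    scale = [math.exp(0.5 * lv) for lv in logvar]  # standard deviation = exp(logvar / 2)
    # Reparameterization trick: log_z = loc + scale * eps, with eps ~ N(0, 1).
    log_z = [m + s * rng.gauss(0.0, 1.0) for m, s in zip(loc, scale)]
    # Numerically stable softmax maps log_z onto the simplex.
    shift = max(log_z)
    exps = [math.exp(v - shift) for v in log_z]
    total = sum(exps)
    theta = [e / total for e in exps]
    return {"logtheta_scale": scale, "log_z": log_z, "theta": theta}


out = encoder_head(loc=[0.0, 1.0, -1.0], logvar=[0.0, math.log(4.0), 0.0])
print(out["theta"], sum(out["theta"]))
```

Sampling through loc + scale * eps keeps the randomness in eps, so gradients can flow to loc and logvar during training.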
Decoder
- class spVIPES.nn.networks.LinearDecoderSPVIPE
Bases: Module
Linear decoder for spVIPES with shared-private latent space decomposition.
This decoder takes separate shared and private latent representations and decodes them into gene expression parameters. It implements a mixture model that combines shared and private contributions to generate the final output distribution parameters for the negative binomial likelihood.
- Parameters:
n_input_private (int) – Dimensionality of the private latent space input.
n_input_shared (int) – Dimensionality of the shared latent space input.
n_output (int) – Number of output features (genes) to reconstruct.
n_cat_list (Iterable[int], optional) – List of categorical covariate dimensions for batch correction.
use_batch_norm (bool, default False) – Whether to use batch normalization in the decoder layers.
use_layer_norm (bool, default False) – Whether to use layer normalization in the decoder layers.
bias (bool, default False) – Whether to include bias terms in linear layers.
n_hidden (int, default 256) – Number of hidden units in the mixing network.
**kwargs – Additional keyword arguments passed to FCLayers.
Notes
The decoder consists of three main components:
Private factor regressor: Maps private latent space to gene-specific factors
Shared factor regressor: Maps shared latent space to gene-specific factors
Mixing network: Learns how to combine shared and private contributions
The output includes both separate private/shared reconstructions and a mixed reconstruction that combines both components according to learned mixing weights.
- __init__(n_input_private, n_input_shared, n_output, n_cat_list=None, use_batch_norm=False, use_layer_norm=False, bias=False, n_hidden=256, **kwargs)
- forward(dispersion, z_private, z_shared, library, *cat_list)
Forward pass through the decoder network.
- Parameters:
dispersion (str) – Dispersion parameter identifier (currently unused but kept for compatibility).
z_private (torch.Tensor) – Private latent representation with shape (batch_size, n_input_private).
z_shared (torch.Tensor) – Shared latent representation with shape (batch_size, n_input_shared).
library (torch.Tensor) – Library size factors with shape (batch_size, 1) for scaling output rates.
*cat_list (int) – Variable length list of categorical covariate indices.
- Returns:
tuple
Tuple of decoder outputs:
px_scale_private : torch.Tensor - Normalized expression rates from private space
px_scale_shared : torch.Tensor - Normalized expression rates from shared space
px_rate_private : torch.Tensor - Library-scaled rates from private space
px_rate_shared : torch.Tensor - Library-scaled rates from shared space
px_mixing : torch.Tensor - Learned mixing weights (logits)
px_scale : torch.Tensor - Final mixed expression rates
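The exact mixing formula is not given on this page. The sketch below is one hypothetical way to combine the outputs listed above for a single cell: softmax each set of logits into private and shared expression proportions, turn px_mixing into per-gene weights with a sigmoid, take the weighted combination, renormalize it onto the simplex, and multiply by a library size. The per-gene sigmoid, the renormalization, and treating library as a size rather than a log-size are all assumptions.

```python
import math
from typing import List, Tuple


def softmax(values: List[float]) -> List[float]:
    shift = max(values)  # subtract the max for numerical stability
    exps = [math.exp(v - shift) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def mix_scales(
    private_logits: List[float],
    shared_logits: List[float],
    mixing_logits: List[float],
    library_size: float,
) -> Tuple[List[float], List[float]]:
    """Hypothetical per-gene mixture of private and shared expression for one cell."""
    px_scale_private = softmax(private_logits)
    px_scale_shared = softmax(shared_logits)
    weights = [1.0 / (1.0 + math.exp(-m)) for m in mixing_logits]  # sigmoid -> (0, 1)
    mixed = [w * p + (1.0 - w) * s for w, p, s in zip(weights, px_scale_private, px_scale_shared)]
    total = sum(mixed)
    px_scale = [m / total for m in mixed]  # renormalize so proportions sum to 1
    px_rate = [library_size * x for x in px_scale]  # expected counts for this cell
    return px_scale, px_rate


px_scale, px_rate = mix_scales([0.1, 0.5, -0.2], [1.0, 0.0, 0.3], [0.0, 2.0, -1.0], library_size=1000.0)
print(round(sum(px_rate), 6))
```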
Data Management
AnnData Manager
- class spVIPES.data._manager.AnnDataManager
Bases: object
Provides an interface to validate and process an AnnData object for use in scvi-tools.
A class which wraps a collection of AnnDataField instances and provides an interface to validate and process an AnnData object with respect to the fields.
- Parameters:
fields (Optional[list[type[BaseAnnDataField]]] (default: None)) – List of AnnDataFields to initialize with.
setup_method_args (Optional[dict] (default: None)) – Dictionary describing the model and arguments passed in by the user to set up this AnnDataManager.
validation_checks (Optional[AnnDataManagerValidationCheck] (default: None)) – Dataclass specifying which global validation checks to run on the data object.
Examples
>>> from scvi.data.fields import LayerField
>>> from spVIPES.data._manager import AnnDataManager
>>> fields = [LayerField("counts", "raw_counts")]
>>> adata_manager = AnnDataManager(fields=fields)
>>> adata_manager.register_fields(adata)
Notes
This class is not initialized with a specific AnnData object, but later sets self.adata via register_fields(). This decouples the generalized definition of the scvi-tools interface from the registration of an instance of data.
- __init__(fields=None, setup_method_args=None, validation_checks=None)
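The decoupling described in the Notes (fields are declared first, data is attached later by register_fields()) can be illustrated with a toy class. ToyFieldManager below is hypothetical and greatly simplified: it stores extractor functions instead of AnnDataField objects and skips the validation and source-registry handling that the real transfer_fields() performs.

```python
from typing import Any, Callable, Dict, Optional


class ToyFieldManager:
    """Hypothetical, simplified stand-in: fields are declared first, data is attached later."""

    def __init__(self, fields: Dict[str, Callable[[dict], Any]]):
        self.fields = fields  # registry key -> function that extracts that field
        self.adata: Optional[dict] = None  # no data object is bound at construction
        self.data_registry: Dict[str, Any] = {}

    def register_fields(self, adata: dict) -> None:
        """Bind a data object and record each declared field from it."""
        for key, extract in self.fields.items():
            self.data_registry[key] = extract(adata)
        self.adata = adata

    def transfer_fields(self, adata_target: dict) -> "ToyFieldManager":
        """Create a new manager with the same field definitions, bound to a new object."""
        new_manager = ToyFieldManager(dict(self.fields))
        new_manager.register_fields(adata_target)
        return new_manager


manager = ToyFieldManager({"X": lambda d: d["layers"]["raw_counts"]})
print(manager.adata)  # None
manager.register_fields({"layers": {"raw_counts": [[1, 2], [3, 4]]}})
target = manager.transfer_fields({"layers": {"raw_counts": [[5, 6]]}})
print(target.data_registry["X"])  # [[5, 6]]
```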
- register_fields(adata, source_registry=None, **transfer_kwargs)
Registers each field associated with this instance with the AnnData object.
Either registers or transfers the setup from source_registry if passed in. Sets self.adata.
- Parameters:
- register_new_fields(fields)
Register new fields to a manager instance.
This is useful to augment the functionality of an existing manager.
- transfer_fields(adata_target, **kwargs)
Transfers an existing setup to each field associated with this instance with the target AnnData object.
Creates a new AnnDataManager instance with the same set of fields. Then, registers the fields with a target AnnData object, incorporating details of the source registry where necessary (e.g. for validation or modified data setup).
- Parameters:
adata_target (AnnOrMuData) – AnnData object to be registered.
kwargs – Additional keywords which modify transfer behavior.
- Return type:
- validate()
Checks if AnnData was last setup with this AnnDataManager instance and reregisters it if not.
- Return type:
- update_setup_method_args(setup_method_args)
Update setup method args.
- Parameters:
setup_method_args (dict) – This is a bit of a misnomer: this is a dict representing kwargs of the setup method that will be used to update the existing values in the registry of this instance.
- property registry: dict
Returns the top-level registry dictionary for the AnnData object registered with this instance as an attrdict.
- property data_registry: scvi.utils.attrdict
Returns the data registry for the AnnData object registered with this instance.
- create_torch_dataset(indices=None, data_and_attributes=None)
Creates a torch dataset from the AnnData object registered with this instance.
- Parameters:
indices (Union[Sequence[int], Sequence[bool], None] (default: None)) – The indices of the observations in the adata to use.
data_and_attributes (Union[list[str], dict[str, dtype], None] (default: None)) – Dictionary with keys representing keys in data registry (adata_manager.data_registry) and value equal to desired numpy loading type (later made into torch tensor) or list of such keys. A list can be used to subset to certain keys in the event that more tensors than needed have been registered. If None, defaults to all registered data.
- Return type:
AnnTorchDataset
- Returns:
Torch Dataset
- property summary_stats: scvi.utils.attrdict
Returns the summary stats for the AnnData object registered with this instance.
- get_from_registry(registry_key)
Returns the object in AnnData associated with the key in the data registry.
- Parameters:
registry_key (str) – Key of object to get from self.data_registry.
- Return type:
np.ndarray | pd.DataFrame
- Returns:
The requested data.
- get_state_registry(registry_key)
Returns the state registry for the AnnDataField registered with this instance.
- Return type:
attrdict
- static view_setup_method_args(registry)
Prints setup kwargs used to produce a given registry.
Data Preparation
- spVIPES.data.prepare_adatas.prepare_adatas(adatas, layers=None)
Prepare and concatenate multiple AnnData objects for spVIPES integration.
This function takes multiple single-cell datasets and prepares them for multi-group integration by concatenating them into a single AnnData object while preserving group-specific metadata. It sets up all the necessary data structures for spVIPES to perform shared-private latent space learning.
- Parameters:
adatas (dict[str, AnnData]) – Dictionary mapping group names (strings) to their corresponding AnnData objects. Each AnnData contains single-cell expression data for one group/dataset. Currently supports exactly 2 groups.
layers (list[list[str or None]], optional) – Specification of which layers to use from each AnnData object. Currently not implemented in the function body.
- Returns:
AnnData
Concatenated AnnData object containing all groups with additional metadata:
groups : Added to .obs, indicating which group each cell belongs to
indices : Added to .obs, with within-group cell indices
groups_var_indices : In .uns, indices of variables for each group
groups_obs_indices : In .uns, indices of observations for each group
groups_obs_names : In .uns, observation names for each group
groups_obs : In .uns, observation metadata for each group
groups_lengths : In .uns, number of features per group
groups_var_names : In .uns, variable names for each group
groups_mapping : In .uns, mapping from indices to group names
- Raises:
ValueError – If more or fewer than 2 groups are provided (current limitation).
Notes
The function performs several important preprocessing steps:
Variable name prefixing: Adds group prefixes to avoid name conflicts
Metadata harmonization: Combines observation metadata across groups
Index tracking: Creates mappings to track group-specific indices
Outer join concatenation: Preserves all variables from all groups
This prepared data structure enables spVIPES to handle datasets with different feature sets (genes) while maintaining the ability to separate shared and private latent representations.
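The variable bookkeeping described above can be sketched on plain lists of gene names. The separator between group prefix and gene name, and whether groups_var_names stores prefixed or original names, are assumptions; sketch_prepare_vars is a hypothetical helper that only mirrors the documented keys and the two-group limitation.

```python
from typing import Dict, List


def sketch_prepare_vars(var_names_by_group: Dict[str, List[str]], sep: str = "_") -> dict:
    """Hypothetical bookkeeping for group-prefixed variables (exactly 2 groups)."""
    if len(var_names_by_group) != 2:
        raise ValueError("prepare_adatas currently supports exactly 2 groups")
    combined: List[str] = []
    groups_var_indices: List[List[int]] = []
    groups_lengths: List[int] = []
    groups_var_names: List[List[str]] = []
    for group, names in var_names_by_group.items():
        prefixed = [f"{group}{sep}{name}" for name in names]  # group prefix avoids name clashes
        start = len(combined)
        combined.extend(prefixed)  # distinct prefixes keep both groups' genes in the outer join
        groups_var_indices.append(list(range(start, len(combined))))
        groups_lengths.append(len(names))
        groups_var_names.append(prefixed)
    return {
        "var_names": combined,
        "groups_var_indices": groups_var_indices,
        "groups_lengths": groups_lengths,
        "groups_var_names": groups_var_names,
        "groups_mapping": dict(enumerate(var_names_by_group)),  # index -> group name
    }


info = sketch_prepare_vars({"human": ["CD4", "CD8A"], "mouse": ["CD4", "Cd8a", "Ms4a1"]})
print(info["groups_var_indices"])  # [[0, 1], [2, 3, 4]]
```

Note that "CD4" appears in both groups but yields two separate variables, which is what lets each group keep its own feature set.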
Examples
Basic usage with two datasets:
>>> import spVIPES
>>> import scanpy as sc
>>>
>>> # Load your datasets
>>> adata1 = sc.read_h5ad("dataset1.h5ad")
>>> adata2 = sc.read_h5ad("dataset2.h5ad")
>>>
>>> # Prepare for spVIPES
>>> adatas_dict = {"treatment": adata1, "control": adata2}
>>> combined_adata = spVIPES.data.prepare_adatas(adatas_dict)
>>>
>>> # Now ready for spVIPES setup
>>> spVIPES.model.spVIPES.setup_anndata(combined_adata, groups_key="groups")
Integration with different feature sets:
>>> # Datasets can have different genes
>>> print(f"Dataset 1: {adata1.n_vars} genes")
>>> print(f"Dataset 2: {adata2.n_vars} genes")
>>>
>>> combined = spVIPES.data.prepare_adatas({"batch1": adata1, "batch2": adata2})
>>> print(f"Combined: {combined.n_vars} genes")  # Outer join of the group-prefixed gene names
Data Loaders
Concatenated Data Loader
- class spVIPES.dataloaders._concat_dataloader.ConcatDataLoader
Bases: DataLoader
DataLoader that supports a list of lists of indices to load.
- Parameters:
adata_manager (AnnDataManager) – AnnDataManager object that has been created via setup_anndata.
indices_list (list[list[int]]) – List where each element is a list of indices in the adata to load.
shuffle (bool (default: True)) – Whether the data should be shuffled.
use_labels (bool (default: False)) – Whether to use labels for sampling.
batch_size (int (default: 128)) – Minibatch size to load each iteration.
data_and_attributes (Optional[dict] (default: None)) – Dictionary with keys representing keys in data registry (adata_manager.data_registry) and value equal to desired numpy loading type (later made into torch tensor). If None, defaults to all registered data.
data_loader_kwargs – Keyword arguments for DataLoader.
- __init__(adata_manager, indices_list, shuffle=True, use_labels=False, batch_size=128, data_and_attributes=None, drop_last=False, **data_loader_kwargs)
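When the groups have different sizes, the loader has to decide how many steps an epoch takes. The sketch below assumes the behaviour of the scvi-tools ConcatDataLoader that this class appears to be modeled on: the group with the most cells sets the epoch length, and the other groups are cycled. concat_batches is a hypothetical helper that works on index lists only, ignores shuffling and drop_last, and assumes every group is non-empty.

```python
from itertools import cycle
from typing import Iterator, List, Sequence, Tuple


def make_batches(indices: Sequence[int], batch_size: int) -> List[List[int]]:
    """Chunk one group's indices into consecutive minibatches."""
    return [list(indices[i:i + batch_size]) for i in range(0, len(indices), batch_size)]


def concat_batches(
    indices_list: Sequence[Sequence[int]], batch_size: int
) -> Iterator[Tuple[List[int], ...]]:
    """Yield one minibatch per group at each step; the largest group sets the epoch length."""
    per_group = [make_batches(indices, batch_size) for indices in indices_list]
    largest = max(range(len(indices_list)), key=lambda g: len(indices_list[g]))
    # The largest group is iterated once; smaller groups restart when exhausted.
    iterators = [iter(b) if g == largest else cycle(b) for g, b in enumerate(per_group)]
    for _ in range(len(per_group[largest])):
        yield tuple(next(it) for it in iterators)


for step in concat_batches([range(0, 5), range(5, 7)], batch_size=2):
    print(step)
```

Cycling the smaller groups means every training step sees cells from all groups, at the cost of revisiting the smaller groups' cells more often within an epoch.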
AnnData Loader
- class spVIPES.dataloaders._ann_dataloader.AnnDataLoader
Bases: DataLoader
DataLoader for loading tensors from AnnData objects.
- Parameters:
adata_manager (AnnDataManager) – AnnDataManager object with a registered AnnData object.
shuffle (bool (default: False)) – Whether the data should be shuffled.
use_labels (bool (default: False)) – Whether to use labels for weighted sampling.
indices (Union[Sequence[int], Sequence[bool], None] (default: None)) – The indices of the observations in the adata to load.
batch_size (int (default: 128)) – Minibatch size to load each iteration.
sampler (Optional[Sampler] (default: None)) – Defines the strategy to draw samples from the dataset. Can be any Iterable with __len__ implemented. If specified, shuffle must not be specified. By default, we use a custom sampler that is designed to get a minibatch of data with one call to __getitem__.
data_and_attributes (Union[list[str], dict[str, dtype], None] (default: None)) – Dictionary with keys representing keys in data registry (adata_manager.data_registry) and value equal to desired numpy loading type (later made into torch tensor) or list of such keys. A list can be used to subset to certain keys in the event that more tensors than needed have been registered. If None, defaults to all registered data.
iter_ndarray (bool (default: False)) – Whether to iterate over numpy arrays instead of torch tensors.
data_loader_kwargs – Keyword arguments for DataLoader.
- __init__(adata_manager, shuffle=False, use_labels=False, indices=None, batch_size=128, sampler=None, data_and_attributes=None, drop_last=False, iter_ndarray=False, **data_loader_kwargs)
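The default sampler is described as producing a whole minibatch from a single __getitem__ call. In other words, it yields lists of indices rather than single indices, much like torch.utils.data.BatchSampler. The hypothetical generator below sketches that idea, including the shuffle and drop_last options; it is not the sampler spVIPES ships.

```python
import random
from typing import Iterator, List, Optional, Sequence


def batch_index_sampler(
    indices: Sequence[int],
    batch_size: int = 128,
    shuffle: bool = False,
    drop_last: bool = False,
    seed: Optional[int] = None,
) -> Iterator[List[int]]:
    """Yield whole minibatches of indices so each batch can be fetched in one __getitem__ call."""
    order = list(indices)
    if shuffle:
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        batch = order[start:start + batch_size]
        if drop_last and len(batch) < batch_size:
            break  # skip the final incomplete batch
        yield batch


print(list(batch_index_sampler(range(5), batch_size=2)))  # [[0, 1], [2, 3], [4]]
```

Fetching a list of indices at once lets the dataset slice its arrays in one operation instead of assembling a batch from many single-row lookups.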