API Reference#

Core Classes#

spVIPES Model#

The main model class for shared-private variational inference.

model.spvipes.spVIPES

Implementation of the spVIPES model.

class spVIPES.model.spvipes.spVIPES#

Bases: MultiGroupTrainingMixin, BaseModelClass

Implementation of the spVIPES model.

spVIPES (shared-private Variational Inference with Product of Experts and Supervision) is a method for integrating multi-group single-cell datasets using a shared-private latent space approach. The model learns both shared representations (common across groups) and private representations (group-specific) through a Product of Experts (PoE) framework.

Parameters:
  • adata (AnnData) – AnnData object that has been registered via setup_anndata().

  • n_hidden (int, default 128) – Number of nodes per hidden layer in the neural networks.

  • n_dimensions_shared (int, default 25) – Dimensionality of the shared latent space. This space captures features common across all groups/datasets.

  • n_dimensions_private (int, default 10) – Dimensionality of the private latent spaces. Each group gets its own private latent space of this dimensionality.

  • dropout_rate (float, default 0.1) – Dropout rate for neural networks to prevent overfitting.

  • **model_kwargs – Additional keyword arguments passed to the underlying module.

Examples

Basic usage with cell type labels:

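>>> import numpy as np
>>> import spVIPES
>>> adata = spVIPES.data.prepare_adatas({"dataset1": dataset1, "dataset2": dataset2})
>>> spVIPES.model.spVIPES.setup_anndata(adata, groups_key="groups", label_key="cell_type")
>>> model = spVIPES.model.spVIPES(adata)
>>> # train() and get_latent_representation() require group_indices_list.
>>> # Assumption: one list of positional cell indices per group.
>>> group_indices_list = [np.flatnonzero(adata.obs["groups"] == g).tolist() for g in adata.obs["groups"].unique()]
>>> model.train(group_indices_list)
>>> latents = model.get_latent_representation(group_indices_list)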

Usage with optimal transport:

>>> spVIPES.model.spVIPES.setup_anndata(adata, groups_key="groups", transport_plan_key="transport_plan")
>>> model = spVIPES.model.spVIPES(adata)
>>> model.train(group_indices_list)  # one list of cell indices per group, as required by train()

Notes

  • We recommend setting n_dimensions_private < n_dimensions_shared for optimal performance.

  • The model automatically selects the appropriate PoE variant based on the provided inputs.

  • GPU acceleration is strongly recommended for large datasets.

__init__(adata, n_hidden=128, n_dimensions_shared=25, n_dimensions_private=10, dropout_rate=0.1, **model_kwargs)#

classmethod setup_anndata(cls, adata, groups_key, match_clusters=False, transport_plan_key=None, label_key=None, batch_key=None, layer=None, **kwargs)#

Set up AnnData object for spVIPES model.

This method registers the AnnData object with the model, configuring the appropriate data fields and PoE strategy based on the provided parameters. The method automatically determines whether to use label-based PoE, optimal transport PoE, or cluster-based PoE.

Parameters:
  • adata (AnnData) – Annotated data object containing the single-cell data to be integrated.

  • groups_key (str) – Key in adata.obs that defines the grouping of cells (e.g., dataset, batch, condition). This determines which cells belong to which group for integration.

  • match_clusters (bool, default False) – Whether to match clusters when using optimal transport. If True, enables cluster-based PoE which automatically matches cell clusters between groups.

  • transport_plan_key (str, optional) – Key in adata.uns containing the precomputed optimal transport plan. If provided, enables optimal transport PoE for data integration.

  • label_key (str, optional) – Key in adata.obs containing cell type labels. If provided, enables label-based PoE which uses supervised alignment based on cell types.

  • batch_key (str, optional) – Key in adata.obs for batch information to enable batch effect correction.

  • layer (str, optional) – Key in adata.layers to use for the expression data. If None, uses adata.X.

  • **kwargs – Additional keyword arguments passed to the parent setup method.

Return type:

None

Returns:

None – The method modifies the AnnData object in place and registers it with the model.

Notes

Priority of PoE strategies (when multiple options are available):

  1. Label-based PoE (if label_key is provided)

  2. Optimal transport PoE (if transport_plan_key is provided)

  3. Cluster-based PoE (if match_clusters=True)
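The priority order can be sketched as a small selection function. This is only an illustration of the documented ordering: select_poe_strategy and the returned strategy names are hypothetical and not part of the spVIPES API, and the fallback used when no option is given is not stated in this reference.

```python
from typing import Optional


def select_poe_strategy(
    label_key: Optional[str] = None,
    transport_plan_key: Optional[str] = None,
    match_clusters: bool = False,
) -> str:
    """Hypothetical helper mirroring the documented priority order."""
    if label_key is not None:
        return "label"
    if transport_plan_key is not None:
        return "optimal_transport"
    if match_clusters:
        return "cluster"
    # The reference does not name the fallback; this label is a placeholder.
    return "unspecified"


# Labels win even when a transport plan is also supplied.
print(select_poe_strategy(label_key="cell_type", transport_plan_key="transport_plan"))
```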

Examples

Basic setup with groups only:

>>> spVIPES.model.spVIPES.setup_anndata(adata, groups_key="dataset")

Setup with cell type supervision:

>>> spVIPES.model.spVIPES.setup_anndata(adata, groups_key="dataset", label_key="cell_type")

Setup with optimal transport:

>>> spVIPES.model.spVIPES.setup_anndata(adata, groups_key="dataset", transport_plan_key="transport_matrix")

get_latent_representation(group_indices_list, adata=None, indices=None, normalized=False, give_mean=True, mc_samples=5000, batch_size=None, drop_last=None)#

Return the latent representation for each cell.

Parameters:
  • group_indices_list (list[list[int]]) – List of lists containing the indices of cells in each of the groups used as input for spVIPES.

  • adata (Optional[AnnData] (default: None)) – AnnData object with equivalent structure to initial AnnData. If None, defaults to the AnnData object used to initialize the model.

  • indices (Optional[Sequence[int]] (default: None)) – Indices of cells in adata to use. If None, all cells are used.

  • normalized (bool (default: False)) – Whether to return the normalized cell embedding (softmaxed) or not

  • give_mean (bool (default: True)) – Give mean of distribution or sample from it.

  • mc_samples (int (default: 5000)) – For distributions with no closed-form mean (e.g., logistic normal), how many Monte Carlo samples to take for computing mean.

  • batch_size (Optional[int] (default: None)) – Minibatch size for data loading into model. Defaults to scvi.settings.batch_size.

  • drop_last (Optional[bool] (default: None)) – Whether to drop the last incomplete batch. If None, automatically determined based on whether using paired PoE (True for paired, False for others).

Return type:

ndarray

Returns:

Low-dimensional latent representation for each cell.
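The mc_samples parameter exists because a softmax-transformed Gaussian (logistic normal) has no closed-form mean. A minimal NumPy sketch of such a Monte Carlo estimate is shown below; logistic_normal_mean is a hypothetical helper, and the actual sampling code inside the model may differ.

```python
import numpy as np


def logistic_normal_mean(loc, scale, mc_samples=5000, seed=0):
    """Estimate E[softmax(z)] for z ~ Normal(loc, scale) by sampling."""
    rng = np.random.default_rng(seed)
    loc = np.asarray(loc, dtype=float)
    scale = np.asarray(scale, dtype=float)
    # Gaussian draws with shape (mc_samples, n_cells, n_dims).
    eps = rng.standard_normal((mc_samples,) + loc.shape)
    samples = loc + scale * eps
    # Numerically stable softmax over the latent dimension.
    samples = samples - samples.max(axis=-1, keepdims=True)
    probs = np.exp(samples)
    probs /= probs.sum(axis=-1, keepdims=True)
    # Average over the Monte Carlo axis.
    return probs.mean(axis=0)


loc = np.zeros((3, 4))
scale = np.ones((3, 4))
mean_embedding = logistic_normal_mean(loc, scale, mc_samples=2000)
print(mean_embedding.shape)
```

Each row of the result lies on the simplex, so more samples reduce the estimation noise but never change the row sums.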

get_loadings()#

Extract per-gene weights in the linear decoder.

Shape is genes by n_latent.

Return type:

dict

static __new__(cls, *args, **kwargs)#
Return type:

Any

train(group_indices_list, batch_size=128, max_epochs=None, use_gpu=None, train_size=0.9, validation_size=None, early_stopping=False, plan_kwargs=None, n_steps_kl_warmup=None, n_epochs_kl_warmup=400, **trainer_kwargs)#

Train a multigroup spVIPES model.

This method trains the model using a custom data splitter that handles multiple groups of cells separately while maintaining the shared-private latent space learning objective.
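The required group_indices_list argument holds one list of cell indices per group. A minimal sketch of building it from a per-cell group label array follows; build_group_indices_list is a hypothetical helper, the indices are assumed to be positional, and the group order expected by the model is not specified in this reference.

```python
import numpy as np


def build_group_indices_list(group_labels):
    """Return one list of positional cell indices per group (sorted by group name)."""
    labels = np.asarray(group_labels)
    groups = np.unique(labels)  # sorted unique group names
    return [np.flatnonzero(labels == g).tolist() for g in groups]


# Stand-in for a per-cell group column such as adata.obs["groups"].
labels = ["control", "control", "treatment", "control", "treatment"]
group_indices_list = build_group_indices_list(labels)
print(group_indices_list)
```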

Parameters:
  • group_indices_list (list[list[int]]) – List of indices corresponding to each group of samples. Each inner list contains the indices for cells belonging to that specific group.

  • max_epochs (int, optional) – Number of passes through the dataset. If None, defaults to np.min([round((20000 / n_cells) * 400), 400]).

  • use_gpu (str, int, bool, optional) – GPU usage specification. Use default GPU if available (if None or True), or index of GPU to use (if int), or name of GPU (if str, e.g., "cuda:0"), or use CPU (if False).

  • train_size (float, default 0.9) – Size of training set in the range [0.0, 1.0].

  • validation_size (float, optional) – Size of the validation set. If None, defaults to 1 - train_size. If train_size + validation_size < 1, the remaining cells belong to the test set.

  • batch_size (int, default 128) – Mini-batch size to use during training.

  • early_stopping (bool, default False) – Whether to perform early stopping. Additional arguments can be passed in **trainer_kwargs.

  • plan_kwargs (dict, optional) – Keyword arguments for the training plan. Arguments passed to train() will overwrite values present in plan_kwargs, when appropriate.

  • n_steps_kl_warmup (int, optional) – Number of training steps for KL warmup. Takes precedence over n_epochs_kl_warmup.

  • n_epochs_kl_warmup (int, default 400) – Number of epochs for KL divergence warmup.

  • **trainer_kwargs – Additional keyword arguments for the trainer.

Return type:

None

Returns:

None – The model is trained in place.

Notes

This method uses a specialized MultiGroupDataSplitter that ensures proper handling of multiple cell groups during training, maintaining the integrity of the shared-private latent space learning.
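For reference, the default max_epochs heuristic documented above can be evaluated directly. The sketch below rewrites the stated expression with Python built-ins; default_max_epochs is a hypothetical name used only for illustration.

```python
def default_max_epochs(n_cells: int) -> int:
    """Documented default: np.min([round((20000 / n_cells) * 400), 400])."""
    return min(round((20000 / n_cells) * 400), 400)


# Small datasets are capped at 400 epochs; larger ones train for fewer epochs.
for n_cells in (5_000, 20_000, 40_000, 100_000):
    print(n_cells, default_max_epochs(n_cells))
```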

spVIPES Module#

The PyTorch module implementing the variational autoencoder.

module.spVIPESmodule.spVIPESmodule

PyTorch implementation of spVIPES variational autoencoder module.

class spVIPES.module.spVIPESmodule.spVIPESmodule#

Bases: BaseModuleClass

PyTorch implementation of spVIPES variational autoencoder module.

This module implements the core variational autoencoder with Product of Experts (PoE) for shared-private latent space learning. It extends scVI’s underlying VAE architecture with multi-group integration capabilities and support for different PoE strategies.

Parameters:
  • groups_lengths (list of int) – List containing the number of features (genes) for each group/dataset.

  • groups_obs_names (list) – List of observation names for each group.

  • groups_var_names (list) – List of variable (gene) names for each group.

  • groups_obs_indices (list) – List of observation indices for each group.

  • groups_var_indices (list) – List of variable indices for each group.

  • transport_plan (torch.Tensor, optional) – Precomputed optimal transport plan matrix for PoE alignment.

  • pair_data (bool, default False) – Whether to use paired data for direct cell-to-cell correspondences.

  • use_labels (bool, default False) – Whether to use cell type labels for supervised PoE alignment.

  • n_labels (int, optional) – Number of unique cell type labels when using supervised alignment.

  • n_batch (int, default 0) – Number of batches. If 0, no batch correction is performed.

  • n_hidden (int, default 128) – Number of nodes per hidden layer in encoder and decoder networks.

  • n_dimensions_shared (int, default 25) – Dimensionality of the shared latent space capturing common features.

  • n_dimensions_private (int, default 10) – Dimensionality of private latent spaces capturing group-specific features.

  • dropout_rate (float, default 0.1) – Dropout rate for neural networks to prevent overfitting.

  • use_batch_norm (bool, default True) – Whether to use batch normalization in neural networks.

  • use_layer_norm (bool, default False) – Whether to use layer normalization in neural networks.

  • log_variational_inference (bool, default True) – Whether to log-transform data before encoding for numerical stability.

  • log_variational_generative (bool, default True) – Whether to log-transform data before decoding for numerical stability.

  • dispersion ({"gene", "gene-batch", "gene-cell"}, default "gene") – Level at which to model the dispersion parameter in the negative binomial distribution.

Notes

This module is based on the scVI framework and implements the variational inference described in the spVIPES paper. The Product of Experts mechanism allows for flexible integration of multiple single-cell datasets with different feature sets.
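As background on the Product of Experts mechanism, the sketch below combines Gaussian experts by summing precisions, which is the standard closed form for a product of Gaussian densities. It assumes diagonal Gaussian experts and omits any prior expert; how spVIPES weights or selects experts (labels, transport plan, clusters) is not described here, and product_of_gaussian_experts is a hypothetical helper.

```python
import numpy as np


def product_of_gaussian_experts(means, variances):
    """Combine experts stacked on axis 0 into a single Gaussian."""
    means = np.asarray(means, dtype=float)
    variances = np.asarray(variances, dtype=float)
    precisions = 1.0 / variances
    # Precisions add; the mean is the precision-weighted average.
    poe_variance = 1.0 / precisions.sum(axis=0)
    poe_mean = poe_variance * (precisions * means).sum(axis=0)
    return poe_mean, poe_variance


# Two unit-variance experts centered at 0 and 2.
poe_mean, poe_variance = product_of_gaussian_experts([[0.0], [2.0]], [[1.0], [1.0]])
print(poe_mean, poe_variance)
```

The combined distribution is always at least as confident as the most confident expert, which is why PoE is a natural choice for fusing group-specific posteriors.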

__init__(groups_lengths, groups_obs_names, groups_var_names, groups_obs_indices, groups_var_indices, transport_plan=None, pair_data=False, use_labels=False, n_labels=None, n_batch=0, n_hidden=128, n_dimensions_shared=25, n_dimensions_private=10, dropout_rate=0.1, use_batch_norm=True, use_layer_norm=False, log_variational_inference=True, log_variational_generative=True, dispersion='gene')#

Initialize the spVIPES variational autoencoder module.

This method sets up the neural network components including encoders and decoders for each group, and configures the Product of Experts mechanism based on the provided parameters. The module extends scVI’s VAE architecture for multi-group integration with shared-private latent spaces.

Notes

The initialization automatically configures the appropriate PoE strategy based on the provided inputs (transport_plan, use_labels, pair_data). The module creates separate encoders and decoders for each group while sharing the latent space structure for integration.

The dispersion parameter controls how the negative binomial distribution variance is modeled, with gene-level being the most common choice for single-cell data.
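To make the role of the dispersion concrete, the sketch below computes the negative binomial variance under the mean and inverse-dispersion parameterization, var = mu + mu^2 / theta. That parameterization is an assumption taken from the scVI convention, and nb_variance is a hypothetical helper. With dispersion="gene", one theta per gene is shared across all cells, which NumPy broadcasting expresses directly.

```python
import numpy as np


def nb_variance(mu, theta):
    """Negative binomial variance for mean mu and inverse dispersion theta."""
    return mu + mu**2 / theta


# Rows are cells, columns are genes.
mu = np.array([[1.0, 4.0], [2.0, 8.0]])
# One inverse-dispersion value per gene, broadcast over cells.
theta_gene = np.array([2.0, 4.0])
variance = nb_variance(mu, theta_gene)
print(variance)
```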

inference(x, batch_index, groups, global_indices, **kwargs)#

Runs the encoder model.

generative(private_stats, shared_stats, poe_stats, library, groups, batch_index)#

Runs the generative model.

get_loadings(dataset, type_latent)#

Extract per-gene weights (for each Z, shape is genes by dim(Z)) in the linear decoder.

Return type:

ndarray

loss(tensors_by_group, inference_outputs, generative_outputs, kl_weight=1.0)#

Loss function.

static __new__(cls, *args, **kwargs)#
Return type:

Any

Neural Network Components#

Encoder#

nn.networks.Encoder

Variational encoder network for spVIPES.

class spVIPES.nn.networks.Encoder#

Bases: Module

Variational encoder network for spVIPES.

This encoder maps input gene expression data to latent representations using a variational approach. It outputs both mean and variance parameters for the latent distribution, enabling sampling during training and inference.

Parameters:
  • n_input (int) – Number of input features (genes) in the expression data.

  • n_topics (int) – Number of output dimensions in the latent space (topics/factors).

  • hidden (int, default 100) – Number of hidden units in the fully connected layers.

  • dropout (float, default 0.1) – Dropout rate applied to hidden layers for regularization.

  • n_cat_list (Iterable[int], optional) – List of categorical covariate dimensions. Each element represents the number of categories for a categorical covariate (e.g., batch).

  • groups (str, optional) – Group identifier for this encoder instance.

Notes

The encoder uses a two-layer fully connected architecture with ReLU activations and batch normalization on the output layers. It outputs parameters for a normal distribution in latent space, following the variational autoencoder framework.

The forward pass returns both the latent representation (theta) and intermediate statistics needed for the variational objective.

__init__(n_input, n_topics, hidden=100, dropout=0.1, n_cat_list=None, groups=None)#

forward(data, specie, *cat_list)#

Forward pass through the variational encoder.

Parameters:
  • data (torch.Tensor) – Input gene expression data with shape (batch_size, n_input).

  • specie (int) – Species or group identifier (currently unused but kept for compatibility).

  • *cat_list (int) – Variable length list of categorical covariate indices for each sample.

Returns:

dict – Dictionary containing encoder outputs:

  • logtheta_loc : torch.Tensor - Mean of latent distribution

  • logtheta_logvar : torch.Tensor - Log variance of latent distribution

  • logtheta_scale : torch.Tensor - Standard deviation of latent distribution

  • log_z : torch.Tensor - Sampled latent variable (log space)

  • theta : torch.Tensor - Normalized latent representation (simplex)

  • qz : torch.distributions.Normal - Latent distribution object
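The relationship between these outputs can be sketched with NumPy. The sketch assumes a standard reparameterized Gaussian (scale = exp(0.5 * logvar)) and a softmax mapping onto the simplex; the exact transforms inside the encoder are not stated in this reference, and the input values are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative encoder outputs for one cell and three latent dimensions.
logtheta_loc = np.array([[0.5, -0.5, 0.0]])
logtheta_logvar = np.array([[-2.0, -2.0, -2.0]])

# Standard deviation derived from the log variance.
logtheta_scale = np.exp(0.5 * logtheta_logvar)

# Reparameterized sample in log space.
log_z = logtheta_loc + logtheta_scale * rng.standard_normal(logtheta_loc.shape)

# Softmax maps the sample onto the simplex.
shifted = log_z - log_z.max(axis=-1, keepdims=True)
theta = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
print(theta.sum(axis=-1))
```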

static __new__(cls, *args, **kwargs)#
Return type:

Any

Decoder#

nn.networks.LinearDecoderSPVIPE

Linear decoder for spVIPES with shared-private latent space decomposition.

class spVIPES.nn.networks.LinearDecoderSPVIPE#

Bases: Module

Linear decoder for spVIPES with shared-private latent space decomposition.

This decoder takes separate shared and private latent representations and decodes them into gene expression parameters. It implements a mixture model that combines shared and private contributions to generate the final output distribution parameters for the negative binomial likelihood.

Parameters:
  • n_input_private (int) – Dimensionality of the private latent space input.

  • n_input_shared (int) – Dimensionality of the shared latent space input.

  • n_output (int) – Number of output features (genes) to reconstruct.

  • n_cat_list (Iterable[int], optional) – List of categorical covariate dimensions for batch correction.

  • use_batch_norm (bool, default False) – Whether to use batch normalization in the decoder layers.

  • use_layer_norm (bool, default False) – Whether to use layer normalization in the decoder layers.

  • bias (bool, default False) – Whether to include bias terms in linear layers.

  • n_hidden (int, default 256) – Number of hidden units in the mixing network.

  • **kwargs – Additional keyword arguments passed to FCLayers.

Notes

The decoder consists of three main components:

  1. Private factor regressor: Maps private latent space to gene-specific factors

  2. Shared factor regressor: Maps shared latent space to gene-specific factors

  3. Mixing network: Learns how to combine shared and private contributions

The output includes both separate private/shared reconstructions and a mixed reconstruction that combines both components according to learned mixing weights.

__init__(n_input_private, n_input_shared, n_output, n_cat_list=None, use_batch_norm=False, use_layer_norm=False, bias=False, n_hidden=256, **kwargs)#

static __new__(cls, *args, **kwargs)#
Return type:

Any

forward(dispersion, z_private, z_shared, library, *cat_list)#

Forward pass through the decoder network.

Parameters:
  • dispersion (str) – Dispersion parameter identifier (currently unused but kept for compatibility).

  • z_private (torch.Tensor) – Private latent representation with shape (batch_size, n_input_private).

  • z_shared (torch.Tensor) – Shared latent representation with shape (batch_size, n_input_shared).

  • library (torch.Tensor) – Library size factors with shape (batch_size, 1) for scaling output rates.

  • *cat_list (int) – Variable length list of categorical covariate indices.

Returns:

tuple – Tuple of decoder outputs:

  • px_scale_private : torch.Tensor - Normalized expression rates from private space

  • px_scale_shared : torch.Tensor - Normalized expression rates from shared space

  • px_rate_private : torch.Tensor - Library-scaled rates from private space

  • px_rate_shared : torch.Tensor - Library-scaled rates from shared space

  • px_mixing : torch.Tensor - Learned mixing weights (logits)

  • px_scale : torch.Tensor - Final mixed expression rates
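One way the outputs above could relate to each other is sketched below. Two points are assumptions rather than documented behavior: the mixing logits are turned into weights with a sigmoid and applied as a convex combination, and library holds log library sizes so that rates are exp(library) times the normalized scale (the scVI convention). combine_decoder_outputs is a hypothetical helper.

```python
import numpy as np


def combine_decoder_outputs(px_scale_private, px_scale_shared, px_mixing, library):
    """Illustrative combination of private and shared decoder outputs."""
    # Assumption: sigmoid turns mixing logits into weights in (0, 1).
    weight = 1.0 / (1.0 + np.exp(-px_mixing))
    px_scale = weight * px_scale_shared + (1.0 - weight) * px_scale_private
    # Assumption: library is on the log scale, with shape (batch_size, 1).
    px_rate_private = np.exp(library) * px_scale_private
    px_rate_shared = np.exp(library) * px_scale_shared
    return px_scale, px_rate_private, px_rate_shared


px_scale_private = np.array([[0.2, 0.8]])
px_scale_shared = np.array([[0.6, 0.4]])
px_mixing = np.zeros((1, 2))  # logit 0 gives equal weights
library = np.log(np.array([[100.0]]))

px_scale, px_rate_private, px_rate_shared = combine_decoder_outputs(
    px_scale_private, px_scale_shared, px_mixing, library
)
print(px_scale)
```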

Data Management#

AnnData Manager#

data._manager.AnnDataManager

Provides an interface to validate and process an AnnData object for use in scvi-tools.

class spVIPES.data._manager.AnnDataManager#

Bases: object

Provides an interface to validate and process an AnnData object for use in scvi-tools.

A class which wraps a collection of AnnDataField instances and provides an interface to validate and process an AnnData object with respect to the fields.

Parameters:
  • fields (Optional[list[type[BaseAnnDataField]]] (default: None)) – List of AnnDataFields to initialize with.

  • setup_method_args (Optional[dict] (default: None)) – Dictionary describing the model and arguments passed in by the user to setup this AnnDataManager.

  • validation_checks (Optional[AnnDataManagerValidationCheck] (default: None)) – DataClass specifying which global validation checks to run on the data object.

Examples

>>> fields = [LayerField("counts", "raw_counts")]
>>> adata_manager = AnnDataManager(fields=fields)
>>> adata_manager.register_fields(adata)

Notes

This class is not initialized with a specific AnnData object, but later sets self.adata via register_fields(). This decouples the generalized definition of the scvi-tools interface with the registration of an instance of data.
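The decoupling described in this note can be illustrated with a toy stand-in. ToyField and ToyManager below are hypothetical classes that only mimic the pattern (fields declared first, data attached later); a plain dict plays the role of the AnnData object, and none of this reflects the real AnnDataManager internals.

```python
class ToyField:
    """Hypothetical field: maps a registry key to an attribute of the data."""

    def __init__(self, registry_key, attr_key):
        self.registry_key = registry_key
        self.attr_key = attr_key

    def register(self, data):
        # Validate that the data actually provides the requested attribute.
        if self.attr_key not in data:
            raise KeyError(f"{self.attr_key!r} not found in data")
        return {"attr_key": self.attr_key}


class ToyManager:
    """Hypothetical manager: defined without data, bound to data later."""

    def __init__(self, fields=None):
        self.fields = list(fields or [])
        self.adata = None
        self.registry = {}

    def register_fields(self, data):
        for field in self.fields:
            self.registry[field.registry_key] = field.register(data)
        self.adata = data


fields = [ToyField("counts", "raw_counts")]
unregistered = ToyManager(fields=fields)  # no data attached yet

toy_adata = {"raw_counts": [[1, 0], [0, 3]]}
manager = ToyManager(fields=fields)
manager.register_fields(toy_adata)  # data is bound only at this point
print(manager.registry)
```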

__init__(fields=None, setup_method_args=None, validation_checks=None)#

register_fields(adata, source_registry=None, **transfer_kwargs)#

Registers each field associated with this instance with the AnnData object.

Either registers the fields from scratch or, if source_registry is passed in, transfers the setup from it. Sets self.adata.

Parameters:
  • adata (AnnOrMuData) – AnnData object to be registered.

  • source_registry (Optional[dict] (default: None)) – Registry created after registering an AnnData using an AnnDataManager object.

  • transfer_kwargs – Additional keywords which modify transfer behavior. Only applicable if source_registry is set.

register_new_fields(fields)#

Register new fields to a manager instance.

This is useful to augment the functionality of an existing manager.

Parameters:

fields (list[type[BaseAnnDataField]]) – List of AnnDataFields to register

transfer_fields(adata_target, **kwargs)#

Transfers an existing setup to each field associated with this instance with the target AnnData object.

Creates a new AnnDataManager instance with the same set of fields. Then, registers the fields with a target AnnData object, incorporating details of the source registry where necessary (e.g. for validation or modified data setup).

Parameters:
  • adata_target (AnnOrMuData) – AnnData object to be registered.

  • kwargs – Additional keywords which modify transfer behavior.

Return type:

AnnDataManager

validate()#

Checks whether the AnnData object was last set up with this AnnDataManager instance and re-registers it if not.

Return type:

None

update_setup_method_args(setup_method_args)#

Update setup method args.

Parameters:

setup_method_args (dict) – Despite the name, this is a dict of setup-method keyword arguments that will be used to update the existing values in the registry of this instance.

property adata_uuid: str#

Returns the UUID for the AnnData object registered with this instance.

property registry: dict#

Returns the top-level registry dictionary for the AnnData object registered with this instance as an attrdict.

property data_registry: scvi.utils.attrdict#

Returns the data registry for the AnnData object registered with this instance.

create_torch_dataset(indices=None, data_and_attributes=None)#

Creates a torch dataset from the AnnData object registered with this instance.

Parameters:
  • indices (Union[Sequence[int], Sequence[bool], None] (default: None)) – The indices of the observations in the adata to use

  • data_and_attributes (Union[list[str], dict[str, dtype], None] (default: None)) – Dictionary with keys representing keys in data registry (adata_manager.data_registry) and value equal to desired numpy loading type (later made into torch tensor) or list of such keys. A list can be used to subset to certain keys in the event that more tensors than needed have been registered. If None, defaults to all registered data.

Return type:

AnnTorchDataset

Returns:

Torch Dataset

property summary_stats: scvi.utils.attrdict#

Returns the summary stats for the AnnData object registered with this instance.

get_from_registry(registry_key)#

Returns the object in AnnData associated with the key in the data registry.

Parameters:

registry_key (str) – Key of the object to get from self.data_registry.

Return type:

np.ndarray | pd.DataFrame

Returns:

The requested data.

get_state_registry(registry_key)#

Returns the state registry for the AnnDataField registered with this instance.

Return type:

attrdict

static view_setup_method_args(registry)#

Prints setup kwargs used to produce a given registry.

Parameters:

registry (dict) – Registry produced by an AnnDataManager.

Return type:

None

view_registry(hide_state_registries=False)#

Prints summary of the registry.

Parameters:

hide_state_registries (bool (default: False)) – If True, prints a shortened summary without details of each state registry.

Return type:

None

Data Preparation#

data.prepare_adatas.prepare_adatas(adatas[, ...])

Prepare and concatenate multiple AnnData objects for spVIPES integration.

spVIPES.data.prepare_adatas.prepare_adatas(adatas, layers=None)#

Prepare and concatenate multiple AnnData objects for spVIPES integration.

This function takes multiple single-cell datasets and prepares them for multi-group integration by concatenating them into a single AnnData object while preserving group-specific metadata. It sets up all the necessary data structures for spVIPES to perform shared-private latent space learning.

Parameters:
  • adatas (dict[str, AnnData]) – Dictionary mapping group names (strings) to their corresponding AnnData objects. Each AnnData contains single-cell expression data for one group/dataset. Currently supports exactly 2 groups.

  • layers (list[list[str or None]], optional) – Specification of which layers to use from each AnnData object. Currently not implemented in the function body.

Returns:

AnnData – Concatenated AnnData object containing all groups with additional metadata:

  • groups : Added to .obs indicating which group each cell belongs to

  • indices : Added to .obs with within-group cell indices

  • groups_var_indices : In .uns, indices of variables for each group

  • groups_obs_indices : In .uns, indices of observations for each group

  • groups_obs_names : In .uns, observation names for each group

  • groups_obs : In .uns, observation metadata for each group

  • groups_lengths : In .uns, number of features per group

  • groups_var_names : In .uns, variable names for each group

  • groups_mapping : In .uns, mapping from indices to group names

Raises:

ValueError – If the number of groups provided is not exactly 2 (current limitation).

Notes

The function performs several important preprocessing steps:

  1. Variable name prefixing: Adds group prefixes to avoid name conflicts

  2. Metadata harmonization: Combines observation metadata across groups

  3. Index tracking: Creates mappings to track group-specific indices

  4. Outer join concatenation: Preserves all variables from all groups

This prepared data structure enables spVIPES to handle datasets with different feature sets (genes) while maintaining the ability to separate shared and private latent representations.
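The four steps listed above can be sketched without AnnData, using plain dictionaries as stand-ins for the per-group datasets. sketch_prepare is a hypothetical helper; the prefix format and the layout of the returned metadata are assumptions and do not reproduce the exact keys written by prepare_adatas.

```python
def sketch_prepare(groups):
    """groups maps a group name to {"obs_names": [...], "var_names": [...]}."""
    if len(groups) != 2:
        raise ValueError("This sketch, like the documented function, expects exactly 2 groups.")
    var_names, obs_group, obs_index = [], [], []
    var_indices, obs_indices = {}, {}
    for name, data in groups.items():
        # 1. Prefix variable names so identical gene symbols do not collide.
        prefixed = [f"{name}_{gene}" for gene in data["var_names"]]
        # 4. Outer join: every prefixed variable is kept on the combined axis.
        start = len(var_names)
        var_names.extend(prefixed)
        var_indices[name] = list(range(start, start + len(prefixed)))
        # 3. Track which combined rows belong to this group.
        row_start = len(obs_group)
        n_obs = len(data["obs_names"])
        obs_indices[name] = list(range(row_start, row_start + n_obs))
        # 2. Per-cell metadata: group label and within-group index.
        obs_group.extend([name] * n_obs)
        obs_index.extend(range(n_obs))
    return {
        "var_names": var_names,
        "groups": obs_group,
        "indices": obs_index,
        "groups_var_indices": var_indices,
        "groups_obs_indices": obs_indices,
    }


prepared = sketch_prepare({
    "treatment": {"obs_names": ["c1", "c2"], "var_names": ["CD3E", "CD4"]},
    "control": {"obs_names": ["c1"], "var_names": ["CD3E", "MS4A1"]},
})
print(prepared["var_names"])
```

Because of the prefixing, the shared gene CD3E appears twice on the combined variable axis, once per group.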

Examples

Basic usage with two datasets:

>>> import spVIPES
>>> import scanpy as sc
>>>
>>> # Load your datasets
>>> adata1 = sc.read_h5ad("dataset1.h5ad")
>>> adata2 = sc.read_h5ad("dataset2.h5ad")
>>>
>>> # Prepare for spVIPES
>>> adatas_dict = {"treatment": adata1, "control": adata2}
>>> combined_adata = spVIPES.data.prepare_adatas(adatas_dict)
>>>
>>> # Now ready for spVIPES setup
>>> spVIPES.model.spVIPES.setup_anndata(combined_adata, groups_key="groups")

Integration with different feature sets:

>>> # Datasets can have different genes
>>> print(f"Dataset 1: {adata1.n_vars} genes")
>>> print(f"Dataset 2: {adata2.n_vars} genes")
>>>
>>> combined = spVIPES.data.prepare_adatas({"batch1": adata1, "batch2": adata2})
>>> print(f"Combined: {combined.n_vars} genes")  # Union of all genes

Data Loaders#

Concatenated Data Loader#

dataloaders._concat_dataloader.ConcatDataLoader

DataLoader that supports a list of lists of indices to load.

class spVIPES.dataloaders._concat_dataloader.ConcatDataLoader#

Bases: DataLoader

DataLoader that supports a list of lists of indices to load.

Parameters:
  • adata_manager (AnnDataManager) – AnnDataManager object that has been created via setup_anndata.

  • indices_list (list[list[int]]) – List where each element is a list of indices in the adata to load.

  • shuffle (bool (default: True)) – Whether the data should be shuffled.

  • use_labels (bool (default: False)) – Whether to use labels for sampling.

  • batch_size (int (default: 128)) – Minibatch size to load each iteration.

  • data_and_attributes (Optional[dict] (default: None)) – Dictionary with keys representing keys in data registry (adata_manager.data_registry) and value equal to desired numpy loading type (later made into torch tensor). If None, defaults to all registered data.

  • data_loader_kwargs – Keyword arguments for DataLoader

__init__(adata_manager, indices_list, shuffle=True, use_labels=False, batch_size=128, data_and_attributes=None, drop_last=False, **data_loader_kwargs)#

static __new__(cls, *args, **kwargs)#
Return type:

Any
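A loader over several index lists has to decide what to do when the groups yield different numbers of minibatches. The sketch below shows one common pattern, cycling the shorter groups until the longest one is exhausted. Whether ConcatDataLoader behaves this way is an assumption not confirmed by this reference, and batches and iterate_concat are hypothetical helpers.

```python
from itertools import cycle


def batches(indices, batch_size):
    """Split one index list into consecutive minibatches."""
    return [indices[i:i + batch_size] for i in range(0, len(indices), batch_size)]


def iterate_concat(indices_list, batch_size):
    """Yield one minibatch per group at every step."""
    per_group = [batches(ix, batch_size) for ix in indices_list]
    longest = max(len(b) for b in per_group)
    # Groups with fewer minibatches are cycled so every step has all groups.
    iterators = [iter(b) if len(b) == longest else cycle(b) for b in per_group]
    for _ in range(longest):
        yield tuple(next(it) for it in iterators)


steps = list(iterate_concat([[0, 1, 2, 3, 4, 5], [6, 7]], batch_size=2))
for step in steps:
    print(step)
```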

AnnData Loader#

dataloaders._ann_dataloader.AnnDataLoader

DataLoader for loading tensors from AnnData objects.

class spVIPES.dataloaders._ann_dataloader.AnnDataLoader#

Bases: DataLoader

DataLoader for loading tensors from AnnData objects.

Parameters:
  • adata_manager (AnnDataManager) – AnnDataManager object with a registered AnnData object.

  • shuffle (bool (default: False)) – Whether the data should be shuffled.

  • use_labels (bool (default: False)) – Whether to use labels for weighted sampling.

  • indices (Union[Sequence[int], Sequence[bool], None] (default: None)) – The indices of the observations in the adata to load.

  • batch_size (int (default: 128)) – Minibatch size to load each iteration.

  • sampler (Optional[Sampler] (default: None)) – Defines the strategy to draw samples from the dataset. Can be any Iterable with __len__ implemented. If specified, shuffle must not be specified. By default, we use a custom sampler that is designed to get a minibatch of data with one call to __getitem__.

  • data_and_attributes (Union[list[str], dict[str, dtype], None] (default: None)) – Dictionary with keys representing keys in data registry (adata_manager.data_registry) and value equal to desired numpy loading type (later made into torch tensor) or list of such keys. A list can be used to subset to certain keys in the event that more tensors than needed have been registered. If None, defaults to all registered data.

  • iter_ndarray (bool (default: False)) – Whether to iterate over numpy arrays instead of torch tensors.

  • data_loader_kwargs – Keyword arguments for DataLoader

__init__(adata_manager, shuffle=False, use_labels=False, indices=None, batch_size=128, sampler=None, data_and_attributes=None, drop_last=False, iter_ndarray=False, **data_loader_kwargs)#

static __new__(cls, *args, **kwargs)#
Return type:

Any
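The interaction of shuffle, batch_size, and drop_last can be sketched as a generator that yields whole minibatches of indices, in the spirit of the batch-at-a-time sampler mentioned for the sampler parameter. batch_indices is a hypothetical helper and does not reproduce the loader's actual sampler class.

```python
import random


def batch_indices(n_obs, batch_size=128, shuffle=False, drop_last=False, seed=0):
    """Yield lists of observation indices, one list per minibatch."""
    order = list(range(n_obs))
    if shuffle:
        random.Random(seed).shuffle(order)
    for start in range(0, n_obs, batch_size):
        batch = order[start:start + batch_size]
        # Optionally skip the final batch when it is smaller than batch_size.
        if drop_last and len(batch) < batch_size:
            break
        yield batch


full = list(batch_indices(10, batch_size=4))
trimmed = list(batch_indices(10, batch_size=4, drop_last=True))
print(len(full), len(trimmed))
```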