spVIPES.nn.networks.Encoder

class spVIPES.nn.networks.Encoder

Bases: Module

Variational encoder network for spVIPES.

This encoder maps input gene expression data to a latent representation using a variational approach. It outputs the mean and log-variance parameters of the latent distribution, so latent variables can be sampled during both training and inference.

Parameters:
  • n_input (int) – Number of input features (genes) in the expression data.

  • n_topics (int) – Number of output dimensions in the latent space (topics/factors).

  • hidden (int, default 100) – Number of hidden units in the fully connected layers.

  • dropout (float, default 0.1) – Dropout rate applied to hidden layers for regularization.

  • n_cat_list (Iterable[int], optional) – List of categorical covariate dimensions. Each element represents the number of categories for a categorical covariate (e.g., batch).

  • groups (str, optional) – Group identifier for this encoder instance.
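
To make the role of n_cat_list concrete, the sketch below shows one common way categorical covariates are fed to a fully connected layer: each covariate index is one-hot encoded and appended to the expression vector. This is an assumption for illustration only, since the page does not say how spVIPES combines covariates with the input. The helper names one_hot and append_covariates are hypothetical, and plain Python lists stand in for tensors.

```python
def one_hot(index, n_categories):
    """Return a one-hot list of length n_categories with a 1.0 at position index."""
    if not 0 <= index < n_categories:
        raise ValueError(f"index {index} is outside [0, {n_categories})")
    return [1.0 if i == index else 0.0 for i in range(n_categories)]


def append_covariates(expression, cat_indices, n_cat_list):
    """Concatenate one expression vector with one-hot encoded covariates.

    Hypothetical helper: assumes one index per covariate, in the same order
    as n_cat_list.
    """
    if len(cat_indices) != len(n_cat_list):
        raise ValueError("expected one covariate index per entry of n_cat_list")
    row = list(expression)
    for index, n_categories in zip(cat_indices, n_cat_list):
        row.extend(one_hot(index, n_categories))
    return row


# Two covariates: a batch label with 3 categories and a second label with 2.
n_cat_list = [3, 2]
row = append_covariates([0.5, 2.0, 0.0, 1.0], [1, 0], n_cat_list)
print(len(row))  # input width seen by the first layer under this assumption
```

Under this assumption, the first layer would see n_input + sum(n_cat_list) features per sample rather than n_input alone.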

Notes

The encoder uses a two-layer fully connected architecture with ReLU activations and batch normalization on the output layers. It outputs parameters for a normal distribution in latent space, following the variational autoencoder framework.

The forward pass returns both the latent representation (theta) and intermediate statistics needed for the variational objective.

__init__(n_input, n_topics, hidden=100, dropout=0.1, n_cat_list=None, groups=None)

forward(data, specie, *cat_list)

Forward pass through the variational encoder.

Parameters:
  • data (torch.Tensor) – Input gene expression data with shape (batch_size, n_input).

  • specie (int) – Species or group identifier (currently unused but kept for compatibility).

  • *cat_list (int) – Variable-length list of categorical covariate indices for each sample.

Returns:

dict – Dictionary containing the encoder outputs:

  • logtheta_loc : torch.Tensor - Mean of latent distribution

  • logtheta_logvar : torch.Tensor - Log variance of latent distribution

  • logtheta_scale : torch.Tensor - Standard deviation of latent distribution

  • log_z : torch.Tensor - Sampled latent variable (log space)

  • theta : torch.Tensor - Normalized latent representation (simplex)

  • qz : torch.distributions.Normal - Latent distribution object
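
The relationships among these outputs can be sketched in plain Python for a single sample. This is a minimal illustration rather than the spVIPES implementation. It assumes that the standard deviation is exp(0.5 * logvar), that log_z is drawn with the reparameterization trick, and that theta is the softmax of log_z. The function name sample_theta and the input values are hypothetical, and lists stand in for torch tensors.

```python
import math
import random


def sample_theta(logtheta_loc, logtheta_logvar, rng):
    """Sketch of the path from (loc, logvar) to theta for one sample.

    Assumptions (not confirmed by this page): scale = exp(0.5 * logvar),
    reparameterized sampling, and softmax normalization onto the simplex.
    """
    # Standard deviation recovered from the log variance.
    logtheta_scale = [math.exp(0.5 * lv) for lv in logtheta_logvar]
    # Reparameterization: log_z = loc + scale * eps with eps ~ N(0, 1).
    log_z = [
        loc + scale * rng.gauss(0.0, 1.0)
        for loc, scale in zip(logtheta_loc, logtheta_scale)
    ]
    # Numerically stable softmax: theta is positive and sums to 1.
    shift = max(log_z)
    exps = [math.exp(value - shift) for value in log_z]
    total = sum(exps)
    theta = [e / total for e in exps]
    return {"logtheta_scale": logtheta_scale, "log_z": log_z, "theta": theta}


rng = random.Random(0)  # fixed seed so the sketch is reproducible
out = sample_theta([0.0, 1.0, -1.0], [0.0, -2.0, -2.0], rng)
print(sum(out["theta"]))  # approximately 1.0
```

If these assumptions hold, logtheta_scale is fully determined by logtheta_logvar, while log_z and theta change from call to call because they depend on the random draw.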

static __new__(cls, *args, **kwargs)
Return type:

Any