spVIPES.nn.networks.Encoder
- class spVIPES.nn.networks.Encoder
Bases: Module
Variational encoder network for spVIPES.
This encoder maps input gene expression data to latent representations using a variational approach. It outputs both the mean and the (log) variance of the latent distribution, which allows latent samples to be drawn during training and inference.
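Sampling from a distribution parameterized by the encoder only trains the encoder if gradients can flow through the draw. The sketch below uses plain PyTorch, with toy tensors standing in for the encoder's mean and log variance (they are not spVIPES outputs), to show the reparameterization trick that `torch.distributions.Normal.rsample()` implements.

```python
import torch
from torch.distributions import Normal

# Toy stand-ins for encoder outputs: a batch of 4 cells, 3 latent dimensions.
loc = torch.zeros(4, 3, requires_grad=True)
logvar = torch.zeros(4, 3, requires_grad=True)

# Standard deviation from log variance: sigma = exp(0.5 * log(sigma^2)).
scale = torch.exp(0.5 * logvar)
qz = Normal(loc, scale)

# rsample() returns loc + scale * eps with eps ~ N(0, I), so the draw stays
# differentiable with respect to loc and logvar.
z = qz.rsample()
z.sum().backward()
print(loc.grad is not None, logvar.grad is not None)  # True True

# sample() draws without tracking gradients, so it cannot train the encoder.
print(qz.sample().requires_grad)  # False
```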
- Parameters:
  - n_input (int) – Number of input features (genes) in the expression data.
  - n_topics (int) – Number of output dimensions in the latent space (topics/factors).
  - hidden (int, default 100) – Number of hidden units in the fully connected layers.
  - dropout (float, default 0.1) – Dropout rate applied to hidden layers for regularization.
  - n_cat_list (Iterable[int], optional) – List of categorical covariate dimensions. Each element is the number of categories of one categorical covariate (e.g., batch).
  - groups (str, optional) – Group identifier for this encoder instance.
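This page does not say how the categories in `n_cat_list` enter the network. One common scheme in single-cell VAEs is to one-hot encode each covariate and concatenate it to the input features. The sketch below assumes that scheme; `inject_covariates` is a hypothetical helper, not part of spVIPES.

```python
import torch
import torch.nn.functional as F


def inject_covariates(x, n_cat_list, *cat_list):
    """Hypothetical helper: append one-hot covariates to the input features.

    x: (batch, n_features) float tensor.
    cat_list: one (batch, 1) integer tensor per covariate, values in [0, n_cat).
    """
    if len(cat_list) != len(n_cat_list):
        raise ValueError("need exactly one tensor per entry of n_cat_list")
    one_hots = [
        F.one_hot(cat.squeeze(-1).long(), num_classes=n_cat).float()
        for n_cat, cat in zip(n_cat_list, cat_list)
    ]
    return torch.cat([x, *one_hots], dim=-1)


x = torch.randn(5, 100)                          # 5 cells, 100 genes
batch = torch.tensor([[0], [1], [2], [1], [0]])  # one covariate with 3 categories
out = inject_covariates(x, [3], batch)
print(out.shape)  # torch.Size([5, 103])
```

Under this assumption, the first linear layer would need `n_input + sum(n_cat_list)` input units.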
Notes
The encoder uses a two-layer fully connected architecture with ReLU activations and batch normalization on the output layers. It outputs parameters for a normal distribution in latent space, following the variational autoencoder framework.
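The architecture described above can be sketched in PyTorch as follows. `EncoderSketch` is hypothetical: the exact placement of dropout, the absence of covariate injection, and the use of separate mean and log-variance heads are assumptions made for illustration, not a copy of the spVIPES source.

```python
import torch
from torch import nn


class EncoderSketch(nn.Module):
    """Hypothetical two-layer encoder following the Notes; not spVIPES code."""

    def __init__(self, n_input, n_topics, hidden=100, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(n_input, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.drop = nn.Dropout(dropout)
        # Separate output heads for the mean and log variance of the latent
        # Normal, each followed by batch normalization.
        self.loc_head = nn.Linear(hidden, n_topics)
        self.loc_bn = nn.BatchNorm1d(n_topics)
        self.logvar_head = nn.Linear(hidden, n_topics)
        self.logvar_bn = nn.BatchNorm1d(n_topics)

    def forward(self, x):
        h = torch.relu(self.fc1(x))
        h = self.drop(torch.relu(self.fc2(h)))
        loc = self.loc_bn(self.loc_head(h))
        logvar = self.logvar_bn(self.logvar_head(h))
        return loc, logvar


enc = EncoderSketch(n_input=200, n_topics=10)
loc, logvar = enc(torch.randn(8, 200))
print(loc.shape, logvar.shape)  # torch.Size([8, 10]) torch.Size([8, 10])
```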
The forward pass returns both the latent representation (theta) and intermediate statistics needed for the variational objective.
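The returned distribution statistics are what a variational objective typically consumes through a KL divergence term. The sketch below uses a standard normal prior purely as an illustrative assumption (spVIPES may use a different prior) and checks PyTorch's `kl_divergence` against the closed form for diagonal Gaussians.

```python
import torch
from torch.distributions import Normal, kl_divergence

loc = torch.tensor([[0.5, -1.0, 0.0]])
logvar = torch.tensor([[0.0, 0.2, -0.3]])
qz = Normal(loc, torch.exp(0.5 * logvar))

# Illustrative prior only: N(0, I).
prior = Normal(torch.zeros_like(loc), torch.ones_like(loc))

# Per-cell KL, summed over latent dimensions.
kl = kl_divergence(qz, prior).sum(dim=-1)

# Closed form against N(0, I): 0.5 * sum(exp(logvar) + loc^2 - 1 - logvar)
kl_closed = 0.5 * (logvar.exp() + loc.pow(2) - 1 - logvar).sum(dim=-1)
print(torch.allclose(kl, kl_closed))  # True
```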
- __init__(n_input, n_topics, hidden=100, dropout=0.1, n_cat_list=None, groups=None)
- forward(data, specie, *cat_list)
Forward pass through the variational encoder.
- Parameters:
- Returns:
  dict – Dictionary containing encoder outputs:
  - logtheta_loc : torch.Tensor – Mean of the latent distribution.
  - logtheta_logvar : torch.Tensor – Log variance of the latent distribution.
  - logtheta_scale : torch.Tensor – Standard deviation of the latent distribution.
  - log_z : torch.Tensor – Sampled latent variable (log space).
  - theta : torch.Tensor – Normalized latent representation (on the simplex).
  - qz : torch.distributions.Normal – Latent distribution object.
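The returned keys are related to one another. The sketch below shows one plausible way they could be derived from a mean and log variance. `assemble_outputs` is hypothetical, and using a softmax to map `log_z` onto the simplex is an assumption consistent with the descriptions above, not confirmed spVIPES behavior.

```python
import torch
from torch.distributions import Normal


def assemble_outputs(loc, logvar):
    """Hypothetical helper showing how the documented return keys relate."""
    scale = torch.exp(0.5 * logvar)       # standard deviation from log variance
    qz = Normal(loc, scale)
    log_z = qz.rsample()                  # reparameterized sample in log space
    theta = torch.softmax(log_z, dim=-1)  # assumed map onto the simplex
    return {
        "logtheta_loc": loc,
        "logtheta_logvar": logvar,
        "logtheta_scale": scale,
        "log_z": log_z,
        "theta": theta,
        "qz": qz,
    }


out = assemble_outputs(torch.randn(6, 4), torch.randn(6, 4))
print(out["theta"].sum(dim=-1))  # each row sums to 1 (up to float error)
```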