spVIPES.nn.networks.LinearDecoderSPVIPE
- class spVIPES.nn.networks.LinearDecoderSPVIPE
Bases: Module
Linear decoder for spVIPES with shared-private latent space decomposition.
This decoder takes separate shared and private latent representations and decodes them into gene expression parameters. It implements a mixture model that combines shared and private contributions to generate the final output distribution parameters for the negative binomial likelihood.
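A minimal sketch of how library-scaled rates can parameterize a negative binomial likelihood. It assumes, as in scVI-style models, that the rate is used as the NB mean and that a gene-wise inverse dispersion `theta` is held outside this decoder (the decoder does not return one). It uses `torch.distributions.NegativeBinomial` purely for illustration; spVIPES may compute the likelihood with a different implementation, and `mu`, `theta`, and `counts` are hypothetical values.

```python
import torch
from torch.distributions import NegativeBinomial

# Hypothetical decoder output: mean expression (rate) for 2 cells x 3 genes.
mu = torch.tensor([[5.0, 0.5, 20.0], [1.0, 3.0, 0.1]])
# Hypothetical gene-wise inverse dispersion (assumed to live outside the decoder).
theta = torch.tensor([2.0, 10.0, 0.5])

# Mean/inverse-dispersion form mapped onto PyTorch's total_count/logits form:
# with logits = log(mu) - log(theta), the mean total_count * exp(logits) equals mu.
nb = NegativeBinomial(total_count=theta, logits=torch.log(mu) - torch.log(theta))

# Hypothetical observed counts for the same cells and genes.
counts = torch.tensor([[4.0, 0.0, 31.0], [2.0, 1.0, 0.0]])
# Per-cell log-likelihood: sum over genes.
log_lik = nb.log_prob(counts).sum(dim=-1)
print(log_lik.shape)  # torch.Size([2])
```

Under this parameterization the variance is `mu + mu**2 / theta`, so a small `theta` corresponds to strong overdispersion.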
- Parameters:
  - n_input_private (int) – Dimensionality of the private latent space input.
  - n_input_shared (int) – Dimensionality of the shared latent space input.
  - n_output (int) – Number of output features (genes) to reconstruct.
  - n_cat_list (Iterable[int], optional) – List of categorical covariate dimensions for batch correction.
  - use_batch_norm (bool, default False) – Whether to use batch normalization in the decoder layers.
  - use_layer_norm (bool, default False) – Whether to use layer normalization in the decoder layers.
  - bias (bool, default False) – Whether to include bias terms in the linear layers.
  - n_hidden (int, default 256) – Number of hidden units in the mixing network.
  - **kwargs – Additional keyword arguments passed to FCLayers.
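Since extra keyword arguments are forwarded to `FCLayers`, categorical covariates are presumably handled the way scvi-tools' `FCLayers` handles them: each covariate with more than one category is one-hot encoded and concatenated to the layer input. The helper `inject_covariates` below is a hypothetical sketch of that convention, not a spVIPES or scvi-tools function, and the tensors are made-up examples.

```python
import torch
import torch.nn.functional as F


def inject_covariates(x: torch.Tensor, n_cat_list: list[int], *cat_list: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: concatenate one-hot covariates to the layer input."""
    one_hots = []
    for n_cat, cat in zip(n_cat_list, cat_list):
        # A covariate with a single category carries no information and is skipped.
        if n_cat > 1:
            # cat has shape (batch_size, 1) and holds integer category indices.
            one_hots.append(F.one_hot(cat.squeeze(-1).long(), num_classes=n_cat).float())
    return torch.cat([x, *one_hots], dim=-1)


# Example: 4 cells, a 10-dimensional latent input, one covariate with 3 batches.
z_shared = torch.randn(4, 10)
batch_index = torch.tensor([[0], [1], [2], [1]])
x = inject_covariates(z_shared, [3], batch_index)
print(x.shape)  # torch.Size([4, 13])
```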
Notes
The decoder consists of three main components:
1. Private factor regressor: maps the private latent space to gene-specific factors.
2. Shared factor regressor: maps the shared latent space to gene-specific factors.
3. Mixing network: learns how to combine the shared and private contributions.
The output includes both separate private/shared reconstructions and a mixed reconstruction that combines both components according to learned mixing weights.
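The three components can be sketched as a toy PyTorch module. `ToySharedPrivateDecoder` is hypothetical and simplified: the regressors are plain `nn.Linear` layers without covariates or normalization, `library` is assumed to hold log library sizes (the scvi-tools convention), and the mixing rule (a per-gene sigmoid weight on the private profile followed by renormalization) is an assumption chosen for illustration. The actual spVIPES combination may differ.

```python
import torch
from torch import nn


class ToySharedPrivateDecoder(nn.Module):
    """Hypothetical, simplified stand-in for LinearDecoderSPVIPE (not the spVIPES code)."""

    def __init__(self, n_input_private: int, n_input_shared: int, n_output: int,
                 n_hidden: int = 256, bias: bool = False):
        super().__init__()
        # Private factor regressor: private latent -> one factor per gene.
        self.private_regressor = nn.Linear(n_input_private, n_output, bias=bias)
        # Shared factor regressor: shared latent -> one factor per gene.
        self.shared_regressor = nn.Linear(n_input_shared, n_output, bias=bias)
        # Mixing network: both latents -> per-gene mixing logits.
        self.mixing = nn.Sequential(
            nn.Linear(n_input_private + n_input_shared, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_output),
        )

    def forward(self, z_private, z_shared, library):
        # Softmax over genes gives normalized expression profiles.
        px_scale_private = torch.softmax(self.private_regressor(z_private), dim=-1)
        px_scale_shared = torch.softmax(self.shared_regressor(z_shared), dim=-1)
        # Assumes library holds log library sizes with shape (batch_size, 1).
        px_rate_private = torch.exp(library) * px_scale_private
        px_rate_shared = torch.exp(library) * px_scale_shared
        px_mixing = self.mixing(torch.cat([z_private, z_shared], dim=-1))
        # Assumed mixing rule: sigmoid weight on private, complement on shared.
        w = torch.sigmoid(px_mixing)
        px_scale = w * px_scale_private + (1 - w) * px_scale_shared
        # Per-gene weights break the sum-to-one property, so renormalize per cell.
        px_scale = px_scale / px_scale.sum(dim=-1, keepdim=True)
        return (px_scale_private, px_scale_shared, px_rate_private,
                px_rate_shared, px_mixing, px_scale)


decoder = ToySharedPrivateDecoder(n_input_private=5, n_input_shared=10, n_output=200)
z_private = torch.randn(8, 5)
z_shared = torch.randn(8, 10)
library = torch.log(torch.full((8, 1), 1e4))
outputs = decoder(z_private, z_shared, library)
print([tuple(t.shape) for t in outputs])  # six tensors, each of shape (8, 200)
```

Keeping the mixing logits in the returned tuple mirrors the documented outputs and lets a caller inspect, per gene, how strongly each cell leans on its private versus shared profile.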
- __init__(n_input_private, n_input_shared, n_output, n_cat_list=None, use_batch_norm=False, use_layer_norm=False, bias=False, n_hidden=256, **kwargs)
- forward(dispersion, z_private, z_shared, library, *cat_list)
Forward pass through the decoder network.
- Parameters:
  - dispersion (str) – Dispersion parameter identifier (currently unused but kept for compatibility).
  - z_private (torch.Tensor) – Private latent representation with shape (batch_size, n_input_private).
  - z_shared (torch.Tensor) – Shared latent representation with shape (batch_size, n_input_shared).
  - library (torch.Tensor) – Library size factors with shape (batch_size, 1) for scaling output rates.
  - *cat_list (int) – Variable-length list of categorical covariate indices.
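A short sketch of the library-size scaling, assuming the scvi-tools convention that `library` stores log total counts with shape `(batch_size, 1)`, so that rates are `exp(library) * px_scale`. Whether spVIPES passes log or linear library sizes is not stated here, and the tensors are made-up examples.

```python
import torch

# Hypothetical normalized profiles: 3 cells x 4 genes, each row sums to 1.
px_scale = torch.softmax(torch.randn(3, 4), dim=-1)
# Assumed convention: log total counts per cell, shape (batch_size, 1).
library = torch.log(torch.tensor([[1000.0], [2500.0], [400.0]]))
# Broadcasting (3, 1) * (3, 4) scales every gene of a cell by that cell's library size.
px_rate = torch.exp(library) * px_scale
# Each row of px_rate now sums to the cell's library size (up to float error).
print(px_rate.sum(dim=-1))
```

The trailing singleton dimension is what makes the broadcast work; a library tensor of shape `(batch_size,)` would instead try to align with the gene axis.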
- Returns:
tuple – Tuple of decoder outputs:
  - px_scale_private (torch.Tensor) – Normalized expression rates from the private space.
  - px_scale_shared (torch.Tensor) – Normalized expression rates from the shared space.
  - px_rate_private (torch.Tensor) – Library-scaled rates from the private space.
  - px_rate_shared (torch.Tensor) – Library-scaled rates from the shared space.
  - px_mixing (torch.Tensor) – Learned mixing weights (logits).
  - px_scale (torch.Tensor) – Final mixed expression rates.