from typing import Literal
import numpy as np
import torch
from scvi import REGISTRY_KEYS
from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data
from scvi.nn import Encoder
from torch.distributions import Normal
from torch.distributions import kl_divergence as kl
from ._base_components import DecoderSCGEN
class SCGENVAE(BaseModuleClass):
    """
    Variational auto-encoder model.

    Parameters
    ----------
    n_input
        Number of input genes
    n_hidden
        Number of nodes per hidden layer
    n_latent
        Dimensionality of the latent space
    n_layers
        Number of hidden layers used for encoder and decoder NNs
    dropout_rate
        Dropout rate for neural networks
    log_variational
        Stored as ``self.log_variational``; this class itself does not
        transform its input based on it.
    latent_distribution
        Latent distribution passed to the encoder.
    use_batch_norm
        Whether to use batch norm in layers. Only the encoder reads this
        setting (enabled for ``"encoder"`` or ``"both"``); the decoder is
        constructed without it.
    use_layer_norm
        Whether to use layer norm in layers. Only the encoder reads this
        setting (enabled for ``"encoder"`` or ``"both"``); the decoder is
        constructed without it.
    kl_weight
        Weight for kl divergence
    """

    def __init__(
        self,
        n_input: int,
        n_hidden: int = 800,
        n_latent: int = 10,
        n_layers: int = 2,
        dropout_rate: float = 0.1,
        log_variational: bool = False,
        latent_distribution: str = "normal",
        use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "both",
        use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "none",
        kl_weight: float = 0.00005,
    ):
        super().__init__()
        self.n_layers = n_layers
        self.n_latent = n_latent
        self.log_variational = log_variational
        # Record the distribution actually given to the encoder below
        # (this was previously hard-coded to "normal" regardless of the argument).
        self.latent_distribution = latent_distribution
        self.kl_weight = kl_weight

        use_batch_norm_encoder = use_batch_norm in ("encoder", "both")
        use_layer_norm_encoder = use_layer_norm in ("encoder", "both")

        self.z_encoder = Encoder(
            n_input,
            n_latent,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
            distribution=latent_distribution,
            use_batch_norm=use_batch_norm_encoder,
            use_layer_norm=use_layer_norm_encoder,
            activation_fn=torch.nn.LeakyReLU,
        )

        # The decoder maps the latent code back to the n_input genes.
        # NOTE(review): the "decoder"/"both" norm settings are not forwarded
        # here; confirm whether DecoderSCGEN should receive them.
        n_input_decoder = n_latent
        self.decoder = DecoderSCGEN(
            n_input_decoder,
            n_input,
            n_layers=n_layers,
            n_hidden=n_hidden,
            activation_fn=torch.nn.LeakyReLU,
            dropout_rate=dropout_rate,
        )

    def _get_inference_input(self, tensors):
        """Select the inputs of :meth:`inference` (just the data matrix ``x``)."""
        return {"x": tensors[REGISTRY_KEYS.X_KEY]}

    def _get_generative_input(self, tensors, inference_outputs):
        """Select the inputs of :meth:`generative` (the latent sample ``z``)."""
        return {"z": inference_outputs["z"]}

    @auto_move_data
    def inference(self, x):
        """
        High level inference method.

        Runs the inference (encoder) model.

        Returns
        -------
        dict with keys ``"z"`` (latent sample), ``"qz_m"`` and ``"qz_v"``
        (mean and variance of the approximate posterior, as returned by the
        encoder).
        """
        qz_m, qz_v, z = self.z_encoder(x)
        return {"z": z, "qz_m": qz_m, "qz_v": qz_v}

    @auto_move_data
    def generative(self, z):
        """Runs the generative model; returns ``{"px": decoder(z)}``."""
        return {"px": self.decoder(z)}

    def loss(
        self,
        tensors,
        inference_outputs,
        generative_outputs,
    ):
        """
        Compute the scGen training objective.

        The loss is the batch mean of ``0.5 * reconstruction + 0.5 * kl_weight * KL``,
        where reconstruction is the summed squared error between ``x`` and the
        decoder output, and KL is ``KL(N(qz_m, sqrt(qz_v)) || N(0, 1))`` summed
        over dim 1.

        Returns
        -------
        LossOutput carrying the scalar loss and the per-row reconstruction and
        KL terms.
        """
        x = tensors[REGISTRY_KEYS.X_KEY]
        qz_m = inference_outputs["qz_m"]
        qz_v = inference_outputs["qz_v"]
        px = generative_outputs["px"]

        # Normal is parameterised by standard deviation, hence the sqrt of the variance.
        kld = kl(
            Normal(qz_m, torch.sqrt(qz_v)),
            Normal(0, 1),
        ).sum(dim=1)
        rl = self.get_reconstruction_loss(x, px)
        loss = (0.5 * rl + 0.5 * (kld * self.kl_weight)).mean()
        return LossOutput(loss=loss, reconstruction_loss=rl, kl_local=kld)

    @torch.no_grad()
    def sample(
        self,
        tensors,
        n_samples=1,
    ) -> np.ndarray:
        r"""
        Generate observation samples from the posterior predictive distribution.

        The posterior predictive distribution is written as :math:`p(\hat{x} \mid x)`.
        Every sample re-runs the full forward pass (so any randomness in the
        encoder's ``z`` is redrawn) and then draws from a unit-variance Normal
        centred on the decoder output.

        Parameters
        ----------
        tensors
            Tensors dict
        n_samples
            Number of required samples for each cell; must be at least 1.

        Returns
        -------
        x_new : :py:class:`numpy.ndarray`
            The decoder-output shape (presumably (n_cells, n_genes)) when
            ``n_samples == 1``; otherwise that shape with a trailing
            ``n_samples`` axis, i.e. (n_cells, n_genes, n_samples).

        Raises
        ------
        ValueError
            If ``n_samples`` is smaller than 1.
        """
        if n_samples < 1:
            raise ValueError("n_samples must be >= 1")

        # ``inference`` only accepts ``x``; forwarding ``n_samples`` as an
        # inference kwarg raised a TypeError, so draw one forward pass per sample.
        draws = []
        for _ in range(n_samples):
            _, generative_outputs = self.forward(
                tensors,
                compute_loss=False,
            )
            draws.append(Normal(generative_outputs["px"], 1).sample())

        samples = torch.stack(draws, dim=-1) if n_samples > 1 else draws[0]
        return samples.cpu().numpy()

    def get_reconstruction_loss(self, x, px) -> torch.Tensor:
        """Squared error between ``x`` and ``px``, summed over dim 1 (one value per row)."""
        return ((x - px) ** 2).sum(dim=1)