scgen.SCGEN
class scgen.SCGEN(adata, n_hidden=800, n_latent=100, n_layers=2, dropout_rate=0.2, **model_kwargs)
Implementation of the scGen model for batch removal and perturbation prediction.
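scGen's perturbation prediction is based on vector arithmetic in the latent space. First estimate a difference vector (delta) between the mean latent codes of stimulated and control cells, then add it to the latent codes of the cells whose response you want to predict. The NumPy example below is a minimal sketch of that idea, not the library's implementation. The latent codes are random toy arrays standing in for the output of get_latent_representation(), and the function name predict_by_latent_arithmetic is hypothetical. The real predict method may add steps that are omitted here, such as balancing cell types before averaging and decoding the shifted codes back to gene expression.

```python
import numpy as np

def predict_by_latent_arithmetic(z_ctrl, z_stim, z_query):
    """Hypothetical helper: shift query latent codes by the stim-minus-ctrl mean difference.

    z_ctrl, z_stim: (n_cells, n_latent) latent codes of training cells in the
    control and stimulated conditions.
    z_query: latent codes of control cells whose stimulated state we want.
    """
    # delta: difference between the two condition means in latent space
    delta = z_stim.mean(axis=0) - z_ctrl.mean(axis=0)
    # Predicted latent codes; scGen would then decode these to expression space
    return z_query + delta, delta

# Toy latent codes (assumption): stimulation shifts every latent dimension by about 0.5
rng = np.random.default_rng(0)
z_ctrl = rng.normal(0.0, 1.0, size=(200, 100))
z_stim = rng.normal(0.5, 1.0, size=(200, 100))
z_query = rng.normal(0.0, 1.0, size=(50, 100))

z_pred, delta = predict_by_latent_arithmetic(z_ctrl, z_stim, z_query)
print(z_pred.shape)  # (50, 100)
```

Because the shift is a single vector, every query cell moves by the same amount. The trained encoder and decoder are what make this linear operation meaningful in gene space.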
Parameters
- adata : AnnData
  AnnData object that has been registered via setup_anndata().
- n_hidden : int (default: 800)
  Number of nodes per hidden layer.
- n_latent : int (default: 100)
  Dimensionality of the latent space.
- n_layers : int (default: 2)
  Number of hidden layers used for the encoder and decoder neural networks.
- dropout_rate : float (default: 0.2)
  Dropout rate for the neural networks.
- **model_kwargs
  Keyword arguments passed to SCGENVAE.
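The toy forward pass below shows how n_hidden, n_layers, n_latent and dropout_rate determine the shape of an encoder. It uses n_layers fully connected ReLU layers of width n_hidden, applies optional inverted dropout at rate dropout_rate, and ends with a linear map to n_latent dimensions. This is a hypothetical, simplified NumPy sketch. toy_encoder is not part of scgen and its weights are random. The actual SCGENVAE module is built with PyTorch and may include details such as normalization layers and a separate variance output, which are left out here.

```python
import numpy as np

def toy_encoder(x, n_hidden=800, n_latent=100, n_layers=2, dropout_rate=0.2,
                training=False, seed=0):
    """Hypothetical toy encoder: illustrates tensor shapes only (random weights)."""
    rng = np.random.default_rng(seed)
    h = x
    for _ in range(n_layers):
        w = rng.normal(0.0, 0.01, size=(h.shape[1], n_hidden))
        h = np.maximum(h @ w, 0.0)  # fully connected layer followed by ReLU
        if training and dropout_rate > 0:
            # Inverted dropout: zero units with probability dropout_rate, rescale the rest
            keep = rng.random(h.shape) >= dropout_rate
            h = h * keep / (1.0 - dropout_rate)
    w_out = rng.normal(0.0, 0.01, size=(h.shape[1], n_latent))
    return h @ w_out  # one n_latent-dimensional code per cell

x = np.random.default_rng(1).random((32, 2000))  # toy input: 32 cells x 2000 genes
z = toy_encoder(x)
print(z.shape)  # (32, 100)
```

The decoder mirrors this layout in reverse, mapping n_latent back to the number of genes.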
Examples
>>> vae = scgen.SCGEN(adata)
>>> vae.train()
>>> adata.obsm["X_scgen"] = vae.get_latent_representation()
Attributes
adata: Data attached to the model instance.
adata_manager: Manager instance associated with self.adata.
device: The current device that the module's parameters are on.
history: Metrics computed during training.
is_trained: Whether the model has been trained.
test_indices: Observations in the test set.
train_indices: Observations in the training set.
validation_indices: Observations in the validation set.
Methods
batch_removal([adata]): Removes batch effects.
binary_classifier(adata, delta, ctrl_key, ...): Latent space classifier.
convert_legacy_save(dir_path, output_dir_path): Converts a legacy saved model (<v0.15.0) to the updated save format.
get_anndata_manager(adata[, required]): Retrieves the AnnDataManager for a given AnnData object specific to this model instance.
get_decoded_expression([adata, indices, ...]): Gets the decoded expression.
get_elbo([adata, indices, batch_size]): Returns the ELBO for the data.
get_from_registry(adata, registry_key): Returns the object in AnnData associated with the key in the data registry.
get_latent_representation([adata, indices, ...]): Returns the latent representation for each cell.
get_marginal_ll([adata, indices, ...]): Returns the marginal log-likelihood for the data.
get_reconstruction_error([adata, indices, ...]): Returns the reconstruction error for the data.
load(dir_path[, adata, use_gpu, prefix, ...]): Instantiates a model from the saved output.
predict([ctrl_key, stim_key, ...]): Predicts the user-specified cell type in the stimulated condition.
reg_mean_plot(adata, axis_keys, labels[, ...]): Plots the mean-matching figure for a specific set of genes.
reg_var_plot(adata, axis_keys, labels[, ...]): Plots the variance-matching figure for a specific set of genes.
register_manager(adata_manager): Registers an AnnDataManager instance with this model class.
save(dir_path[, prefix, overwrite, save_anndata]): Saves the state of the model.
setup_anndata(adata[, batch_key, labels_key]): Sets up the AnnData object for this model.
to_device(device): Moves the model to the given device.
train([max_epochs, use_gpu, train_size, ...]): Trains the model.
view_anndata_setup([adata, ...]): Prints a summary of the setup for the initial AnnData or a given AnnData object.
view_setup_args(dir_path[, prefix]): Prints the arguments used to set up a saved model.
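batch_removal works in the latent space. As we understand the scGen approach, the batch with the most cells of a given cell type serves as a reference for that cell type. Every other batch of that cell type is shifted by the difference between the reference mean and its own mean, and the result is then decoded. The NumPy sketch below illustrates only this mean-shift idea on toy latent codes. shift_batches_to_largest is a hypothetical helper. The library's handling of cell types found in a single batch and its decoding step are not reproduced.

```python
import numpy as np

def shift_batches_to_largest(z, batches, cell_types):
    """Hypothetical helper: per cell type, align each batch's latent mean with
    the mean of the largest batch of that cell type."""
    z_corrected = z.copy()
    for ct in np.unique(cell_types):
        in_ct = cell_types == ct
        ct_batches, counts = np.unique(batches[in_ct], return_counts=True)
        ref = ct_batches[np.argmax(counts)]  # reference: batch with most cells of this type
        ref_mean = z[in_ct & (batches == ref)].mean(axis=0)
        for b in ct_batches:
            mask = in_ct & (batches == b)
            # Add the difference between the reference mean and this batch's mean
            z_corrected[mask] += ref_mean - z[mask].mean(axis=0)
    return z_corrected

# Toy latent codes (assumption): 2 cell types x 2 batches, batch "b2" offset by +3
rng = np.random.default_rng(0)
cell_types = np.array(["B"] * 60 + ["T"] * 60)
batches = np.array((["b1"] * 40 + ["b2"] * 20) * 2)
z = rng.normal(size=(120, 10))
z[cell_types == "T"] += 5.0  # biological difference that should be kept
z[batches == "b2"] += 3.0    # batch effect that should be removed
z_corrected = shift_batches_to_largest(z, batches, cell_types)
```

Shifts are computed separately within each cell type. Differences between cell types are therefore left in place, while batch offsets within a cell type are removed.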
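reg_mean_plot and reg_var_plot compare two groups of cells gene by gene, for example predicted and real stimulated cells. We assume the agreement score these plots report is the R² of a simple linear fit of per-gene means (or per-gene variances). Under that assumption, the score equals the squared Pearson correlation of those two vectors, which the sketch below computes with NumPy. mean_r2 and var_r2 are hypothetical helpers and the Poisson counts are synthetic. The plotting, gene selection and file saving done by the real methods are omitted.

```python
import numpy as np

def mean_r2(x_expr, y_expr):
    """Hypothetical helper: squared Pearson r between per-gene means of two groups."""
    r = np.corrcoef(x_expr.mean(axis=0), y_expr.mean(axis=0))[0, 1]
    return float(r ** 2)

def var_r2(x_expr, y_expr):
    """Hypothetical helper: the same statistic on per-gene variances."""
    r = np.corrcoef(x_expr.var(axis=0), y_expr.var(axis=0))[0, 1]
    return float(r ** 2)

# Synthetic data (assumption): both groups share the same per-gene Poisson rates
rng = np.random.default_rng(0)
rates = rng.gamma(shape=2.0, scale=1.0, size=500)          # one rate per gene
real = rng.poisson(rates, size=(300, 500)).astype(float)   # 300 "real" cells
pred = rng.poisson(rates, size=(250, 500)).astype(float)   # 250 "predicted" cells

print(f"mean R^2: {mean_r2(pred, real):.3f}, variance R^2: {var_r2(pred, real):.3f}")
```

The two groups may contain different numbers of cells, because only the per-gene summaries are compared.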
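binary_classifier is described as a latent space classifier. As we understand it, the method scores each cell by the dot product between the difference vector delta and that cell's latent representation, so cells that look more like the stimulated condition receive larger scores. The sketch below reproduces only this scoring step on toy latent codes. The helper name delta_projection_scores and the midpoint threshold used to turn scores into labels are illustrative assumptions, not the library's decision rule.

```python
import numpy as np

def delta_projection_scores(z, delta):
    """Hypothetical helper: dot product of each cell's latent code with delta."""
    return z @ delta

# Toy latent codes (assumption): stimulation shifts only the first latent dimension
rng = np.random.default_rng(0)
n_latent = 100
z_ctrl = rng.normal(0.0, 1.0, size=(100, n_latent))
z_stim = rng.normal(0.0, 1.0, size=(100, n_latent))
z_stim[:, 0] += 4.0

# Difference vector between the two condition means
delta = z_stim.mean(axis=0) - z_ctrl.mean(axis=0)

s_ctrl = delta_projection_scores(z_ctrl, delta)
s_stim = delta_projection_scores(z_stim, delta)

# Illustrative decision rule (assumption): threshold halfway between the group means
threshold = (s_ctrl.mean() + s_stim.mean()) / 2
accuracy = ((s_ctrl < threshold).sum() + (s_stim >= threshold).sum()) / (len(s_ctrl) + len(s_stim))
print(f"accuracy on toy data: {accuracy:.2f}")
```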