scgen.SCGEN

class scgen.SCGEN(adata, n_hidden=800, n_latent=100, n_layers=2, dropout_rate=0.2, **model_kwargs)

Implementation of the scGen model for batch removal and perturbation prediction.

Parameters
adata : AnnData

AnnData object that has been registered via setup_anndata().

n_hidden : int (default: 800)

Number of nodes per hidden layer.

n_latent : int (default: 100)

Dimensionality of the latent space.

n_layers : int (default: 2)

Number of hidden layers used in the encoder and decoder networks.

dropout_rate : float (default: 0.2)

Dropout rate for neural networks.

**model_kwargs

Keyword arguments for SCGENVAE.
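To make the size parameters concrete, the sketch below lists the layer widths of a plain fully connected encoder and decoder built from n_hidden, n_latent and n_layers. It is a hypothetical illustration, not SCGENVAE's code: the helpers layer_dims and n_params are invented here, and the sketch assumes equal-width hidden layers, a decoder that mirrors the encoder, and a single latent output standing in for separate mean and variance heads. dropout_rate is left out because it does not change layer shapes.

```python
def layer_dims(n_input, n_hidden=800, n_latent=100, n_layers=2):
    """Return (encoder, decoder) layer widths for a plain fully connected VAE.

    Hypothetical helper: assumes every hidden layer has n_hidden units and the
    decoder mirrors the encoder. The encoder's mean and variance heads are
    collapsed into a single n_latent output.
    """
    if n_layers < 1:
        raise ValueError("n_layers must be at least 1")
    hidden = [n_hidden] * n_layers
    encoder = [n_input] + hidden + [n_latent]
    decoder = [n_latent] + hidden + [n_input]
    return encoder, decoder


def n_params(dims):
    """Count weights plus biases for a chain of dense layers."""
    return sum(a * b + b for a, b in zip(dims, dims[1:]))


# 2000 genes with the default hyperparameters
encoder, decoder = layer_dims(n_input=2000)
print(encoder)            # [2000, 800, 800, 100]
print(decoder)            # [100, 800, 800, 2000]
print(n_params(encoder))  # 2321700
```

Most of the encoder's parameters sit in the first layer, so with thousands of input genes n_hidden has a much larger effect on model size than n_latent.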

Examples

>>> import scgen
>>> scgen.SCGEN.setup_anndata(adata)
>>> vae = scgen.SCGEN(adata)
>>> vae.train()
>>> adata.obsm["X_scgen"] = vae.get_latent_representation()

Attributes

adata

Data attached to model instance.

adata_manager

Manager instance associated with self.adata.

device

The current device that the module's params are on.

history

Metrics computed during training.

is_trained

Whether the model has been trained.

test_indices

Observations that are in test set.

train_indices

Observations that are in train set.

validation_indices

Observations that are in validation set.
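The three index attributes partition the observations of adata. A minimal sketch of such a random split follows, assuming a train fraction that is rounded up, a validation fraction that is rounded down (or takes the remainder when unset), and a test set made of whatever is left. split_indices and its validation_size argument are hypothetical here; the splitter used by train() may round or shuffle differently.

```python
import math

import numpy as np


def split_indices(n_obs, train_size=0.9, validation_size=None, seed=0):
    """Randomly partition observation indices into train/validation/test.

    Hypothetical helper: the train fraction is rounded up, the validation set
    takes the remainder when validation_size is None, and any leftover
    observations form the test set.
    """
    if not 0.0 < train_size <= 1.0:
        raise ValueError("train_size must be in (0, 1]")
    n_train = math.ceil(train_size * n_obs)
    if validation_size is None:
        n_val = n_obs - n_train
    else:
        n_val = math.floor(validation_size * n_obs)
    if n_train + n_val > n_obs:
        raise ValueError("train_size + validation_size exceeds 1")
    perm = np.random.default_rng(seed).permutation(n_obs)
    return perm[:n_train], perm[n_train:n_train + n_val], perm[n_train + n_val:]


train_idx, val_idx, test_idx = split_indices(1000, train_size=0.75, validation_size=0.125)
print(len(train_idx), len(val_idx), len(test_idx))  # 750 125 125
```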

Methods

batch_removal([adata])

Removes batch effects.
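One simple way to picture latent-space batch correction is to shift every batch so that its mean latent code matches that of a reference batch. The sketch below does this with the largest batch as the reference. It is a simplified, hypothetical illustration (align_batch_means is not part of scgen): the actual batch_removal method may work per cell type and may decode the corrected codes back to gene expression, neither of which is shown here.

```python
import numpy as np


def align_batch_means(z, batches):
    """Shift each batch in latent space so its mean matches the largest batch.

    z: (n_cells, n_latent) latent codes; batches: length-n_cells batch labels.
    Hypothetical helper; ignores cell types and does not decode.
    """
    z = np.asarray(z, dtype=float)
    batches = np.asarray(batches)
    labels, counts = np.unique(batches, return_counts=True)
    reference = labels[np.argmax(counts)]  # batch with the most cells
    ref_mean = z[batches == reference].mean(axis=0)
    corrected = z.copy()
    for label in labels:
        mask = batches == label
        corrected[mask] += ref_mean - z[mask].mean(axis=0)
    return corrected


rng = np.random.default_rng(0)
z = rng.normal(size=(6, 3))
batches = np.array(["a", "a", "a", "b", "b", "c"])
z[batches == "b"] += 5.0  # simulate a batch offset
corrected = align_batch_means(z, batches)
```

A pure mean shift removes offsets between batches but leaves within-batch structure untouched, which is why it is usually applied per cell type rather than globally.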

binary_classifier(adata, delta, ctrl_key, ...)

Latent space classifier.
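binary_classifier takes a delta vector between conditions. A rough sketch of a linear classifier along such a vector projects each latent code onto the unit vector in the delta direction and thresholds at the midpoint between the projected control and stimulated means. The helper names and the midpoint rule are assumptions for illustration; the library method plots its own scores and may not normalize delta or use this threshold.

```python
import numpy as np


def delta_scores(z, delta):
    """Project latent codes (n_cells, n_latent) onto the unit vector along delta."""
    z = np.asarray(z, dtype=float)
    delta = np.asarray(delta, dtype=float)
    norm = np.linalg.norm(delta)
    if norm == 0:
        raise ValueError("delta must be non-zero")
    return z @ (delta / norm)


def classify_by_delta(z, ctrl_mean, stim_mean):
    """Label each cell 'stim' or 'ctrl' by which side of the midpoint it falls on.

    delta is the difference of the two condition means; the threshold is the
    midpoint of their projections (a nearest-centroid rule along delta).
    """
    ctrl_mean = np.asarray(ctrl_mean, dtype=float)
    stim_mean = np.asarray(stim_mean, dtype=float)
    delta = stim_mean - ctrl_mean
    threshold = delta_scores(np.vstack([ctrl_mean, stim_mean]), delta).mean()
    return np.where(delta_scores(z, delta) > threshold, "stim", "ctrl")


labels = classify_by_delta([[0.1, 5.0], [1.9, -3.0]], ctrl_mean=[0.0, 0.0], stim_mean=[2.0, 0.0])
print(labels)  # ['ctrl' 'stim']
```

Only the component along delta matters: the large second coordinates in the example have no effect on the labels.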

convert_legacy_save(dir_path, output_dir_path)

Converts a legacy saved model (<v0.15.0) to the updated save format.

get_anndata_manager(adata[, required])

Retrieves the AnnDataManager for a given AnnData object specific to this model instance.

get_decoded_expression([adata, indices, ...])

Get decoded expression.

get_elbo([adata, indices, batch_size])

Return the ELBO for the data.
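For a VAE with a standard normal prior, the ELBO is the expected log-likelihood of the data minus the KL divergence from the approximate posterior to the prior. The sketch below computes the negative ELBO (the quantity minimized during training) per cell, assuming a diagonal Gaussian posterior and a unit-variance Gaussian likelihood, so that reconstruction reduces to half the squared error up to an additive constant. These modelling choices and the helper names are assumptions, not a description of SCGENVAE's internals.

```python
import numpy as np


def gaussian_kl(mu, var):
    """KL(N(mu, diag(var)) || N(0, I)) per cell, summed over latent dimensions."""
    mu = np.asarray(mu, dtype=float)
    var = np.asarray(var, dtype=float)
    return 0.5 * np.sum(mu ** 2 + var - np.log(var) - 1.0, axis=-1)


def negative_elbo(x, x_hat, mu, var, kl_weight=1.0):
    """Per-cell negative ELBO with a squared-error reconstruction term."""
    x = np.asarray(x, dtype=float)
    x_hat = np.asarray(x_hat, dtype=float)
    reconstruction = 0.5 * np.sum((x - x_hat) ** 2, axis=-1)
    return reconstruction + kl_weight * gaussian_kl(mu, var)


# Perfect reconstruction and a posterior equal to the prior give zero loss
x = np.array([[1.0, 2.0]])
print(negative_elbo(x, x, mu=np.zeros((1, 3)), var=np.ones((1, 3))))  # [0.]
```

The ELBO itself is the negation of this value, so higher ELBO means a better fit.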

get_from_registry(adata, registry_key)

Returns the object in AnnData associated with the key in the data registry.

get_latent_representation([adata, indices, ...])

Return the latent representation for each cell.

get_marginal_ll([adata, indices, ...])

Return the marginal log-likelihood (LL) for the data.

get_reconstruction_error([adata, indices, ...])

Return the reconstruction error for the data.

load(dir_path[, adata, use_gpu, prefix, ...])

Instantiate a model from the saved output.

predict([ctrl_key, stim_key, ...])

Predicts the expression of a user-specified cell type in the stimulated condition.
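scGen's perturbation prediction is based on latent-space vector arithmetic: a delta between stimulated and control cells is added to the control cells of the held-out cell type. A minimal sketch of that idea on precomputed latent codes follows. predict_latent and its default condition labels are hypothetical, and the real predict method may balance cell types before averaging and decodes the shifted codes into expression.

```python
import numpy as np


def predict_latent(z, conditions, cell_types, target_type,
                   ctrl_key="control", stim_key="stimulated"):
    """Shift control cells of target_type by the control-to-stimulated delta.

    The delta is estimated from all cell types except target_type, whose
    stimulated cells are treated as unseen. Hypothetical sketch: no cell-type
    balancing and no decoding back to gene space.
    """
    z = np.asarray(z, dtype=float)
    conditions = np.asarray(conditions)
    cell_types = np.asarray(cell_types)
    others = cell_types != target_type
    delta = (z[others & (conditions == stim_key)].mean(axis=0)
             - z[others & (conditions == ctrl_key)].mean(axis=0))
    ctrl_target = z[(cell_types == target_type) & (conditions == ctrl_key)]
    return ctrl_target + delta, delta


z = np.array([[0.0, 0.0], [0.0, 2.0], [3.0, 0.0], [3.0, 2.0], [10.0, 10.0]])
conditions = np.array(["control", "control", "stimulated", "stimulated", "control"])
cell_types = np.array(["A", "A", "A", "A", "B"])
pred, delta = predict_latent(z, conditions, cell_types, target_type="B")
print(delta)  # [3. 0.]
print(pred)   # [[13. 10.]]
```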

reg_mean_plot(adata, axis_keys, labels[, ...])

Plots a mean-matching figure for a set of specified genes.

reg_var_plot(adata, axis_keys, labels[, ...])

Plots a variance-matching figure for a set of specified genes.
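Both plots compare a per-gene statistic (mean or variance) between real and predicted cells. A common way to summarize such a plot is the R² of a least-squares line through the points, which for a simple linear fit equals the squared Pearson correlation. The sketch below assumes cells in rows and genes in columns; matching_r2 is a hypothetical helper, and the plotting methods may compute the statistic over a different subset of genes or cells.

```python
import numpy as np


def matching_r2(x_real, x_pred, stat="mean"):
    """R^2 of a least-squares line between per-gene statistics.

    x_real, x_pred: (n_cells, n_genes) matrices of real and predicted cells.
    stat: "mean" or "var", computed per gene across cells.
    """
    reducer = {"mean": np.mean, "var": np.var}[stat]
    a = reducer(np.asarray(x_real, dtype=float), axis=0)
    b = reducer(np.asarray(x_pred, dtype=float), axis=0)
    r = np.corrcoef(a, b)[0, 1]  # Pearson correlation of the two gene profiles
    return r ** 2


print(round(matching_r2([[1.0, 2.0, 3.0]], [[1.0, 3.0, 2.0]]), 2))  # 0.25
```

Because R² only measures linear agreement, a prediction that is off by a constant scale or shift still scores 1.0; the fitted slope and intercept are worth checking alongside it.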

register_manager(adata_manager)

Registers an AnnDataManager instance with this model class.

save(dir_path[, prefix, overwrite, save_anndata])

Save the state of the model.

setup_anndata(adata[, batch_key, labels_key])

Sets up the AnnData object for this model.

to_device(device)

Move model to device.

train([max_epochs, use_gpu, train_size, ...])

Train the model.

view_anndata_setup([adata, ...])

Print summary of the setup for the initial AnnData or a given AnnData object.

view_setup_args(dir_path[, prefix])

Print args used to setup a saved model.