From 7a5886ecd6d4f5d0948a162c8d156e3a625defe1 Mon Sep 17 00:00:00 2001 From: Can Ergen Date: Wed, 11 Dec 2024 21:57:02 -0800 Subject: [PATCH] Revision done. --- src/scvi_hub_models/__main__.py | 12 +- .../config/haniffa_covid_pbmc.json | 5 +- .../config/heart_cell_atlas.json | 4 + .../config/human_lung_cell_atlas.json | 5 +- .../config/mouse_thymus_cite.json | 7 +- src/scvi_hub_models/config/neurips_cite.json | 7 +- src/scvi_hub_models/config/test_scvi.json | 4 +- src/scvi_hub_models/models/_base_workflow.py | 184 +++--- .../models/_haniffa_covid_pbmc.py | 31 +- .../models/_heart_cell_atlas.py | 9 +- .../models/_human_lung_cell_atlas.py | 17 +- .../models/_mouse_thymus_cite.py | 31 +- src/scvi_hub_models/models/_neurips_cite.py | 49 +- src/scvi_hub_models/models/_test_scvi.py | 27 +- src/scvi_hub_models/test.ipynb | 603 ++++++++++++++++++ 15 files changed, 821 insertions(+), 174 deletions(-) create mode 100644 src/scvi_hub_models/test.ipynb diff --git a/src/scvi_hub_models/__main__.py b/src/scvi_hub_models/__main__.py index f39cf3e..dc083eb 100644 --- a/src/scvi_hub_models/__main__.py +++ b/src/scvi_hub_models/__main__.py @@ -12,7 +12,15 @@ @click.option("--dry_run", type=bool, default=False, help="Dry run the workflow.") @click.option("--config_key", type=str, help="Use a different config file, e.g. for test purpose.") @click.option("--save_dir", type=str, help="Directory to save intermediate results (defaults temporary).") -def run_workflow(model_name: str, dry_run: bool, config_key: str = None, save_dir: str = None) -> None: +@click.option("--reload_data", type=bool, help="Reload the data or get from DVC.") +@click.option("--reload_model", type=bool, help="Reload the model or get from DVC.") +def run_workflow( + model_name: str, + dry_run: bool, + config_key: str = None, + save_dir: str = None, + reload_data: bool = False, + reload_model: bool = False) -> None: """Run the workflow for a specific model.""" from importlib import import_module if not config_key: @@ -22,7 +30,7 @@ def run_workflow(model_name: str, dry_run: bool, config_key: str = None, save_di Workflow = workflow_module._Workflow config = json_data_store[config_key] - workflow = Workflow(save_dir=save_dir, dry_run=dry_run, config=config) + workflow = Workflow(save_dir=save_dir, dry_run=dry_run, config=config, reload_data=reload_data, reload_model=reload_model) workflow.run() diff --git a/src/scvi_hub_models/config/haniffa_covid_pbmc.json b/src/scvi_hub_models/config/haniffa_covid_pbmc.json index fe5bb16..12c9e50 100644 --- a/src/scvi_hub_models/config/haniffa_covid_pbmc.json +++ b/src/scvi_hub_models/config/haniffa_covid_pbmc.json @@ -2,9 +2,11 @@ "model_dir": "haniffa_covid_pbmc", "model_class": "TOTALVI", "repo_name": "scvi-tools/haniffa_covid_pbmc_totalvi", + "reload_data": true, "extra_data_kwargs": { "reference_adata_cxg_id": "c7775e88-49bf-4ba2-a03b-93f00447c958", - "reference_adata_fname": "haniffa_covid_pbmc.h5ad" + "reference_adata_fname": "haniffa_covid_pbmc.h5ad", + "large_training_file_name": "haniffa_covid_pbmc.h5mu" }, "metadata": { "training_data_url": "https://datasets.cellxgene.cziscience.com/5ad66a4f-d619-4cb3-8015-a87c755647b3.h5ad", @@ -16,6 +18,7 @@ "description": "CITE-seq to measure RNA and surface proteins in thymocytes from wild-type and T cell lineage-restricted mice to generate a comprehensive timeline of cell state for each T cell lineage.", "references": "Steier, Z., Aylard, D.A., McIntyre, L.L. et al. Single-cell multiomic analysis of thymocyte development reveals drivers of CD4+ T cell and CD8+ T cell lineage commitment. Nat Immunol 24, 1579–1590 (2023). https://doi.org/10.1038/s41590-023-01584-0." }, + "criticism_settings": { "n_samples": 3, "cell_type_key": "cell_type" diff --git a/src/scvi_hub_models/config/heart_cell_atlas.json b/src/scvi_hub_models/config/heart_cell_atlas.json index 71ee449..70a77f5 100644 --- a/src/scvi_hub_models/config/heart_cell_atlas.json +++ b/src/scvi_hub_models/config/heart_cell_atlas.json @@ -2,6 +2,10 @@ "model_dir": "heart_cell_atlas_scvi", "model_class": "SCVI", "repo_name": "scvi-tools/heart-cell-atlas-scvi", + "extra_data_kwargs": { + "reference_adata_fname": "heart_cell_atlas.h5ad", + "large_training_file_name": "heart_cell_atlas.h5ad" + }, "metadata": { "training_data_url": "https://www.heartcellatlas.org/#DataSources", "tissues": ["heart"], diff --git a/src/scvi_hub_models/config/human_lung_cell_atlas.json b/src/scvi_hub_models/config/human_lung_cell_atlas.json index 3054b4b..a26c61e 100644 --- a/src/scvi_hub_models/config/human_lung_cell_atlas.json +++ b/src/scvi_hub_models/config/human_lung_cell_atlas.json @@ -1,5 +1,5 @@ { - "model_dir": "hlca_scanvi_reference", + "model_dir": "hlca_reference_scanvi", "model_class": "SCANVI", "repo_name": "scvi-tools/human-lung-cell-atlas-scanvi", "extra_data_kwargs": { @@ -7,7 +7,8 @@ "legacy_model_hash": "a7cd60f4342292b3cba54545bcd8a34decdc8e6b82163f009273d543e7e3910e", "legacy_model_dir": "hlca_scanvi_reference_legacy", "reference_adata_cxg_id": "066943a2-fdac-4b29-b348-40cede398e4e", - "reference_adata_fname": "hlca_core.h5ad" + "reference_adata_fname": "hlca_core.h5ad", + "large_training_file_name": "hlca_core.h5ad" }, "metadata": { "training_data_url": "https://cellxgene.cziscience.com/collections/6f6d381a-7701-4781-935c-db10d30de293", diff --git a/src/scvi_hub_models/config/mouse_thymus_cite.json b/src/scvi_hub_models/config/mouse_thymus_cite.json index 06b590a..5c2ca42 100644 --- a/src/scvi_hub_models/config/mouse_thymus_cite.json +++ b/src/scvi_hub_models/config/mouse_thymus_cite.json @@ -1,11 +1,10 @@ { - "model_dir": "mouse_thymus_cite", + "model_dir": "mouse_thymus_cite_totalvi", "model_class": "TOTALVI", - "repo_name": "scvi-tools/mouse_thymus_totalvi", - "reload_data": true, + "repo_name": "scvi-tools/mouse_thymus_cite_totalvi", "extra_data_kwargs": { "reference_adata_cxg_id": "c14c54f8-85d8-45db-9de7-6ab572cc748a", - "reference_adata_fname": "thymus_cite.h5ad", + "reference_adata_fname": "mouse_thymus_cite.h5ad", "large_training_file_name": "mouse_thymus_cite.h5mu" }, "metadata": { diff --git a/src/scvi_hub_models/config/neurips_cite.json b/src/scvi_hub_models/config/neurips_cite.json index 8d7eec6..e8cf6e0 100644 --- a/src/scvi_hub_models/config/neurips_cite.json +++ b/src/scvi_hub_models/config/neurips_cite.json @@ -1,9 +1,12 @@ { - "model_dir": "bone_marrow_cite", + "model_dir": "bone_marrow_cite_totalvi", "model_class": "TOTALVI", "repo_name": "scvi-tools/bone_marrow_cite_totalvi", "extra_data_kwargs": { - "reference_adata_fname": "bmmc_cite.h5ad" + "url": "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE194122&format=file&file=GSE194122%5Fopenproblems%5Fneurips2021%5Fcite%5FBMMC%5Fprocessed%2Eh5ad%2Egz", + "hash": "b9b50fade9349719cba23c97c6515d3501a32ee3735fe95fe51221d2e8a5f361", + "reference_adata_fname": "bmmc_cite.h5ad.gz", + "large_training_file_name": "neurips_bone_marrow_cite.h5mu" }, "metadata": { "training_data_url": "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE194122&format=file&file=GSE194122%5Fopenproblems%5Fneurips2021%5Fcite%5FBMMC%5Fprocessed%2Eh5ad%2Egz", diff --git a/src/scvi_hub_models/config/test_scvi.json b/src/scvi_hub_models/config/test_scvi.json index 3ba0f83..104b9e1 100644 --- a/src/scvi_hub_models/config/test_scvi.json +++ b/src/scvi_hub_models/config/test_scvi.json @@ -2,9 +2,11 @@ "model_dir": "test_scvi", "model_class": "SCVI", "repo_name": "scvi-tools/test-scvi", + "extra_data_kwargs": { + "large_training_file_name": "test_data.h5ad" + }, "collection_name": "test", "minify_model": false, - "extra_data_kwargs": {}, "metadata": { "tissues": ["synthetic"], "data_modalities": ["rna"], diff --git a/src/scvi_hub_models/models/_base_workflow.py b/src/scvi_hub_models/models/_base_workflow.py index 08d19c8..0787577 100644 --- a/src/scvi_hub_models/models/_base_workflow.py +++ b/src/scvi_hub_models/models/_base_workflow.py @@ -14,28 +14,9 @@ from scvi.hub import HubMetadata, HubModel, HubModelCardHelper from scvi.model.base import BaseModelClass -import subprocess, os -from pydrive.auth import GoogleAuth -from pydrive.drive import GoogleDrive - - -def upload_gdrive(path): - remote_url = subprocess.check_output(["dvc", "remote", "list"], text=True).split()[-1] - - # Extract the Google Drive folder ID from the remote URL - folder_id = remote_url.split("gdrive://")[-1] - - # Check if DVC detected an update - if "modified" in subprocess.check_output(["dvc", "diff", "--json"], text=True): - GoogleAuth().LocalWebserverAuth() - drive = GoogleDrive(GoogleAuth()) - file = drive.CreateFile({"title": os.path.basename(path), "parents": [{"id": folder_id}]}) - file.SetContentFile(path) - file.Upload() - print(f"Uploaded {path} to Google Drive folder {folder_id}.") - # Specify your repository and target file -repo_path = "." + +repo_path = os.path.abspath(Path(__file__).parent.parent.parent.parent) dvc_repo = Repo(repo_path) git_repo = git.Repo(repo_path) @@ -67,17 +48,25 @@ class BaseModelWorkflow: config A :class:`~frozendict.frozendict` containing the configuration for the workflow. Can only be set once. + reload_data + If ``True``, the data will be reloaded. Otherwise, it will be pulled from DVC. Defaults to ``False``. + reload_model + If ``True``, the model will be reloaded. Otherwise, it will be pulled from DVC. Defaults to ``False``. """ def __init__( self, save_dir: str | None = None, dry_run: bool = False, - config: frozendict | None = None + config: frozendict | None = None, + reload_data: bool = True, + reload_model: bool = True, ): self.save_dir = save_dir self.dry_run = dry_run self.config = config + self.reload_data = reload_data + self.reload_model = reload_model @property def save_dir(self): @@ -114,22 +103,41 @@ def config(self, value: frozendict): value = frozendict(value) self._config = value + @property + def reload_data(self): + return self._reload_data + + @reload_data.setter + def reload_data(self, value: bool): + if hasattr(self, "_reload_data"): + raise AttributeError("`reload_data` can only be set once.") + self._reload_data = value + + @property + def reload_model(self): + return self._reload_model + + @reload_model.setter + def reload_model(self, value: bool): + if hasattr(self, "_reload_model"): + raise AttributeError("`reload_model` can only be set once.") + self._reload_model = value + def get_adata(self) -> anndata.AnnData | None: """Download and load the dataset.""" logger.info("Loading dataset.") if self.dry_run: return None - if self.config['reload_data']: - path_file = os.path.join('data/', self.config['extra_data_kwargs']['large_training_file_name']) - adata = self.download_adata() + if self.reload_data: + path_file = os.path.join(f'{repo_path}/data/', self.config['extra_data_kwargs']['large_training_file_name']) + print(path_file) + adata = self.download_adata(path_file) dvc_repo.add(path_file) git_repo.index.commit(f"Track {path_file} with DVC") - print(f"Pushing {path_file} to DVC remote...") dvc_repo.push() git_repo.remote().push() - upload_gdrive(path_file) else: - path_file = os.path.join('data/', self.config['extra_data_kwargs']['large_training_file_name']) + path_file = os.path.join(f'{repo_path}/data/', self.config['extra_data_kwargs']['large_training_file_name']) dvc_repo.pull([path_file]) if path_file.endswith(".h5mu"): adata = mudata.read_h5mu(path_file) @@ -137,34 +145,40 @@ def get_adata(self) -> anndata.AnnData | None: adata = anndata.read_h5ad(path_file) return adata - def _get_adata(self, url: str, hash: str, file_path: str) -> str: + def get_model(self, adata) -> BaseModelClass | None: + """Download and load the model.""" + logger.info("Loading model.") + if self.dry_run: + return None + if self.reload_model: + path_file = os.path.join(f'{repo_path}/data/', self.config['model_dir']) + model = self.load_model(adata) + model.save(path_file, overwrite=True, save_anndata=False) + dvc_repo.add(path_file) + git_repo.index.commit(f"Track {path_file} with DVC") + dvc_repo.push() + git_repo.remote().push() + else: + path_file = os.path.join(f'{repo_path}/data/', self.config['model_dir']) + dvc_repo.pull([path_file]) + model = self.default_load_model(adata, self.config['model_class'], path_file) + return model + + def _get_adata(self, url: str, hash: str, file_path: str, processor: str | None = None) -> str: logger.info("Downloading and reading data.") if self.dry_run: return None - retrieve( + file_out = retrieve( url=url, known_hash=hash, fname=file_path, path=self.save_dir, - processor=None, - ) - return anndata.read_h5ad(os.path.join(self.save_dir, file_path)) - - def _download_model(self, url: str, hash: str, file_path: str) -> str: - logger.info("Downloading model.") - if self.dry_run: - return None - - return retrieve( - url=url, # - known_hash=hash, #config["adata_hashes"][tissue], - fname=file_path, # f"{tissue}_adata.h5ad" - path=self.save_dir, - processor=None, + processor=processor, ) + return anndata.read_h5ad(file_out) - def _load_model(self, model_path: str, adata: anndata.AnnData, model_name: str): + def default_load_model(self, adata: anndata.AnnData, model_name: str, model_path: str | None = None) -> BaseModelClass: """Load the model.""" logger.info("Loading model.") if self.dry_run: @@ -178,48 +192,20 @@ def _load_model(self, model_path: str, adata: anndata.AnnData, model_name: str): elif model_name == "CondSCVI": from scvi.model import CondSCVI model_cls = CondSCVI + elif model_name == "TOTALVI": + from scvi.model import TOTALVI + model_cls = TOTALVI elif model_name == "Stereoscope": from scvi.external import RNAStereoscope model_cls = RNAStereoscope else: raise ValueError(f"Model {model_name} not recognized.") - model = model_cls.load(os.path.join(self.save_dir, model_path), adata=adata) + if model_path is None: + model_path = os.path.join(self.save_dir, self.config["model_dir"]) + model = model_cls.load(model_path, adata=adata) return model - def _create_hub_model( - self, - model_path: str, - training_data_url: str | None = None - ) -> HubModel | None: - logger.info("Creating the HubModel.") - if self.dry_run: - return None - - if training_data_url is None: - training_data_url = self.config.get("training_data_url", None) - - metadata = self.config["metadata"] - hub_metadata = HubMetadata.from_dir( - model_path, - anndata_version=anndata_version - ) - model_card = HubModelCardHelper.from_dir( - model_path, - anndata_version=anndata_version, - license_info=metadata.get("license_info", "mit"), - data_modalities=metadata.get("data_modalities", None), - tissues=metadata.get("tissues", None), - data_is_annotated=metadata.get("data_is_annotated", False), - data_is_minified=metadata.get("data_is_minified", False), - training_data_url=training_data_url, - training_code_url=metadata.get("training_code_url", None), - description=metadata.get("description", None), - references=metadata.get("references", None), - ) - - return HubModel(model_path, hub_metadata, model_card) - def _minify_and_save_model( self, model: BaseModelClass, @@ -252,11 +238,47 @@ def _minify_and_save_model( qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) adata.obsm[qzm_key] = qzm adata.obsm[qzv_key] = qzv - model.minify_adata(use_latent_qzm_key=qzm_key, use_latent_qzv_key=qzv_key) + if isinstance(adata, mudata.MuData): + model.minify_mudata(use_latent_qzm_key=qzm_key, use_latent_qzv_key=qzv_key) + else: + model.minify_adata(use_latent_qzm_key=qzm_key, use_latent_qzv_key=qzv_key) model.save(mini_model_path, overwrite=True, save_anndata=True) return mini_model_path + def _create_hub_model( + self, + model_path: str, + training_data_url: str | None = None + ) -> HubModel | None: + logger.info("Creating the HubModel.") + if self.dry_run: + return None + + if training_data_url is None: + training_data_url = self.config.get("training_data_url", None) + + metadata = self.config["metadata"] + hub_metadata = HubMetadata.from_dir( + model_path, + anndata_version=anndata_version + ) + model_card = HubModelCardHelper.from_dir( + model_path, + anndata_version=anndata_version, + license_info=metadata.get("license_info", "mit"), + data_modalities=metadata.get("data_modalities", None), + tissues=metadata.get("tissues", None), + data_is_annotated=metadata.get("data_is_annotated", False), + data_is_minified=metadata.get("data_is_minified", False), + training_data_url=training_data_url, + training_code_url=metadata.get("training_code_url", None), + description=metadata.get("description", None), + references=metadata.get("references", None), + ) + + return HubModel(model_path, hub_metadata, model_card) + def _upload_hub_model(self, hub_model: HubModel, repo_name: str | None = None, **kwargs) -> HubModel: """Upload the HubModel to Hugging Face.""" collection_name = self.config.get("collection_name", None) diff --git a/src/scvi_hub_models/models/_haniffa_covid_pbmc.py b/src/scvi_hub_models/models/_haniffa_covid_pbmc.py index dc47600..6262d18 100644 --- a/src/scvi_hub_models/models/_haniffa_covid_pbmc.py +++ b/src/scvi_hub_models/models/_haniffa_covid_pbmc.py @@ -18,15 +18,12 @@ def _load_adata(self) -> AnnData: adata_path = os.path.join(self.save_dir, self.config['extra_data_kwargs']["reference_adata_fname"]) if not os.path.exists(adata_path): - # TODO for next LTX remove census_version='latest'. - download_source_h5ad(self.config['extra_data_kwargs']["reference_adata_cxg_id"], to_path=adata_path, census_version='latest') + download_source_h5ad(self.config['extra_data_kwargs']["reference_adata_cxg_id"], to_path=adata_path) return sc.read_h5ad(adata_path) def _preprocess_adata(self, adata: AnnData) -> AnnData: import scanpy as sc - print(adata, adata.X.data) - sc.pp.filter_genes(adata, min_counts=3) adata.layers["counts"] = adata.X.copy() sc.pp.highly_variable_genes( @@ -38,20 +35,28 @@ def _preprocess_adata(self, adata: AnnData) -> AnnData: batch_key="sample_id", span=1.0, ) - protein_adata = AnnData(adata.obsm["protein_expression"]) + protein_adata = AnnData( + adata.uns['antibody_raw.X'].toarray(), + obs=adata.obs, + var=adata.uns['antibody_features']) + protein_adata.obs_names = adata.obs_names - del adata.obsm["protein_expression"] - adata = MuData({"rna": adata, "protein": protein_adata}) + del adata.uns['antibody_raw.X'] + del adata.uns['antibody_features'] + del adata.uns['antibody_X'] + del adata.uns['neighbors'] + mdata = MuData({"rna": adata, "protein": protein_adata}) - return adata + return mdata - def load_adata(self) -> AnnData | None: + def download_adata(self, path) -> AnnData | None: """Download and load the dataset.""" logger.info(f"Saving dataset to {self.save_dir} and preprocessing.") if self.dry_run: return None adata = self._load_adata() mdata = self._preprocess_adata(adata) + mdata.write_h5mu(path) return mdata def _initialize_model(self, mdata: MuData) -> TOTALVI: @@ -59,7 +64,7 @@ def _initialize_model(self, mdata: MuData) -> TOTALVI: mdata, rna_layer="counts", protein_layer=None, - batch_key="sample_id", + batch_key="donor_id", modalities={ "rna_layer": "rna", "protein_layer": "protein", @@ -70,11 +75,11 @@ def _initialize_model(self, mdata: MuData) -> TOTALVI: def _train_model(self, model: TOTALVI) -> TOTALVI: """Train the scVI model.""" - model.train(max_epochs=200) + model.train(max_epochs=50) return model - def get_model(self, adata) -> TOTALVI | None: + def load_model(self, adata) -> TOTALVI | None: """Initialize and train the scVI model.""" logger.info("Training the scVI model.") if self.dry_run: @@ -85,7 +90,7 @@ def get_model(self, adata) -> TOTALVI | None: def run(self): super().run() - mdata = self.load_adata() + mdata = self.get_adata() model = self.get_model(mdata) model_path = self._minify_and_save_model(model, mdata) hub_model = self._create_hub_model(model_path) diff --git a/src/scvi_hub_models/models/_heart_cell_atlas.py b/src/scvi_hub_models/models/_heart_cell_atlas.py index 65c3822..abea1b5 100644 --- a/src/scvi_hub_models/models/_heart_cell_atlas.py +++ b/src/scvi_hub_models/models/_heart_cell_atlas.py @@ -31,13 +31,14 @@ def _preprocess_adata(self, adata: AnnData) -> AnnData: return adata - def load_adata(self) -> AnnData | None: + def download_adata(self, path) -> AnnData | None: """Download and load the dataset.""" logger.info(f"Saving heart cell atlas dataset to {self.save_dir}.") if self.dry_run: return None adata = self._load_adata() adata = self._preprocess_adata(adata) + adata.write_h5ad(path) return adata def _initialize_model(self, adata: AnnData) -> SCVI: @@ -51,11 +52,11 @@ def _initialize_model(self, adata: AnnData) -> SCVI: def _train_model(self, model: SCVI) -> SCVI: """Train the scVI model.""" - model.train(max_epochs=5) + model.train(max_epochs=200) return model - def get_model(self, adata) -> SCVI | None: + def load_model(self, adata) -> SCVI | None: """Initialize and train the scVI model.""" logger.info("Training the scVI model.") if self.dry_run: @@ -66,7 +67,7 @@ def get_model(self, adata) -> SCVI | None: def run(self): super().run() - adata = self.load_adata() + adata = self.get_adata() model = self.get_model(adata) model_path = self._minify_and_save_model(model, adata) hub_model = self._create_hub_model(model_path) diff --git a/src/scvi_hub_models/models/_human_lung_cell_atlas.py b/src/scvi_hub_models/models/_human_lung_cell_atlas.py index f68baef..6dd8101 100644 --- a/src/scvi_hub_models/models/_human_lung_cell_atlas.py +++ b/src/scvi_hub_models/models/_human_lung_cell_atlas.py @@ -10,6 +10,9 @@ class _Workflow(BaseModelWorkflow): + def load_model(self, adata: anndata.AnnData) -> BaseModelWorkflow: + return self.default_load_model(adata, self.config['model_class']) + def _download_model(self): from pathlib import Path @@ -23,7 +26,6 @@ def _download_model(self): path=self.save_dir, ) untarred = sorted(untarred) - print(untarred) return str(Path(untarred[0]).parent) def _get_model(self) -> str: @@ -64,7 +66,7 @@ def _preprocess_reference_adata(self, adata: anndata.AnnData, model_path: str) - # .X does not contain raw counts initially adata.X = adata.raw.X - _, genes, _, _ = _load_saved_files(model_path, load_adata=False) + _, genes, _, _ = _load_saved_files(os.path.join(self.save_dir, self.config["model_dir"]), load_adata=False) adata = adata[:, adata.var.index.isin(genes)].copy() # get rid of some var columns that we dont need @@ -108,13 +110,14 @@ def _download_embedding_adata(self) -> str: adata = anndata.io.read_h5ad(adata) return adata[adata.obs["core_or_extension"] == "core"].copy() - def _get_adata(self, model_path: str) -> anndata.AnnData: + def download_adata(self, path) -> anndata.AnnData: logging.info("Loading data.") if self.dry_run: return None ref_adata = self._download_reference_adata() - ref_adata = self._preprocess_reference_adata(ref_adata, model_path) + ref_adata = self._preprocess_reference_adata(ref_adata, self.model_path) ref_adata = self._postprocess_reference_adata(ref_adata) + ref_adata.write_h5ad(path) return ref_adata @property @@ -124,9 +127,9 @@ def id(self) -> str: def run(self): super().run() - model_path = self._get_model() - adata = self._get_adata(model_path) - model = self._load_model(model_path, adata, "SCANVI") + self._get_model() + adata = self.get_adata() + model = self.get_model(adata) model_path = self._minify_and_save_model(model, adata) hub_model = self._create_hub_model(model_path) hub_model = self._upload_hub_model(hub_model) diff --git a/src/scvi_hub_models/models/_mouse_thymus_cite.py b/src/scvi_hub_models/models/_mouse_thymus_cite.py index 56c1ec8..255b1ee 100644 --- a/src/scvi_hub_models/models/_mouse_thymus_cite.py +++ b/src/scvi_hub_models/models/_mouse_thymus_cite.py @@ -23,28 +23,35 @@ def _load_adata(self) -> AnnData: return sc.read_h5ad(adata_path) def _preprocess_adata(self, adata: AnnData) -> AnnData: - import scanpy as sc - - print(adata, adata.X.data) - - sc.pp.filter_genes(adata, min_counts=3) matching_indices = [adata.raw.var_names.get_loc(gene) for gene in adata.var_names] adata.layers["counts"] = adata.raw.X[:, matching_indices].copy() - protein_adata = AnnData(adata.obsm["protein_expression"]) + sc.pp.highly_variable_genes( + adata, + n_top_genes=4000, + subset=True, + layer="counts", + flavor="seurat_v3", + span=1.0, + ) + protein_adata = AnnData(adata.obsm["protein_expression"], obs=adata.obs) protein_adata.obs_names = adata.obs_names del adata.obsm["protein_expression"] - adata = MuData({"rna": adata, "protein": protein_adata}) + del adata.obsm['denoised_genes'] + del adata.obsm['denoised_proteins'] + del adata.uns['AB_adata'] + mdata = MuData({"rna": adata, "protein": protein_adata}) + print(mdata) - return adata + return mdata - def download_adata(self) -> AnnData | None: + def download_adata(self, path) -> AnnData | None: """Download and load the dataset.""" - logger.info(f"Saving dataset to {self.save_dir} and preprocessing.") + logger.info(f"Saving dataset to {path} and preprocessing.") if self.dry_run: return None adata = self._load_adata() mdata = self._preprocess_adata(adata) - mdata.write_h5mu(f'data/{self.config['extra_data_kwargs']['large_training_file_name']}') + mdata.write_h5mu(path) return mdata def _initialize_model(self, mdata: MuData) -> TOTALVI: @@ -67,7 +74,7 @@ def _train_model(self, model: TOTALVI) -> TOTALVI: return model - def get_model(self, adata) -> TOTALVI | None: + def load_model(self, adata) -> TOTALVI | None: """Initialize and train the scVI model.""" logger.info("Training the scVI model.") if self.dry_run: diff --git a/src/scvi_hub_models/models/_neurips_cite.py b/src/scvi_hub_models/models/_neurips_cite.py index 74f6cfd..af4c470 100644 --- a/src/scvi_hub_models/models/_neurips_cite.py +++ b/src/scvi_hub_models/models/_neurips_cite.py @@ -1,9 +1,9 @@ import logging -import os import scanpy as sc from anndata import AnnData from mudata import MuData +from pooch import Decompress from scvi.model import TOTALVI from scvi_hub_models.models import BaseModelWorkflow @@ -12,52 +12,45 @@ class _Workflow(BaseModelWorkflow): - - def _load_adata(self) -> AnnData: - from cellxgene_census import download_source_h5ad - - adata_path = os.path.join(self.save_dir, self.config['extra_data_kwargs']["reference_adata_fname"]) - if not os.path.exists(adata_path): - # TODO for next LTS remove census_version='latest'. - download_source_h5ad(self.config['extra_data_kwargs']["reference_adata_cxg_id"], to_path=adata_path, census_version='latest') - return sc.read_h5ad(adata_path) - def _preprocess_adata(self, adata: AnnData) -> AnnData: - import scanpy as sc - - sc.pp.filter_genes(adata, min_counts=3) - adata.layers["counts"] = adata.X.copy() + rna = adata[:, adata.var['feature_types']=='GEX'].copy() + protein = adata[:, adata.var['feature_types']=='ADT'].copy() + protein.layers["counts"] = protein.layers["counts"].toarray() + sc.pp.filter_genes(rna, min_counts=3) sc.pp.highly_variable_genes( - adata, + rna, n_top_genes=4000, subset=True, layer="counts", flavor="seurat_v3", - batch_key="sample_id", + batch_key="Site", span=1.0, ) - protein_adata = AnnData(adata.obsm["protein_expression"]) - protein_adata.obs_names = adata.obs_names - del adata.obsm["protein_expression"] - adata = MuData({"rna": adata, "protein": protein_adata}) + adata = MuData({"rna": rna, "protein": protein}) return adata - def load_adata(self) -> AnnData | None: + def download_adata(self, path) -> AnnData | None: """Download and load the dataset.""" - logger.info(f"Saving dataset to {self.save_dir} and preprocessing.") + logger.info(f"Saving dataset to {path} and preprocessing.") if self.dry_run: return None - adata = self._load_adata() + adata = self._get_adata( + url=self.config["extra_data_kwargs"]["url"], + hash=self.config["extra_data_kwargs"]["hash"], + file_path=self.config["extra_data_kwargs"]["reference_adata_fname"], + processor=Decompress(), + ) mdata = self._preprocess_adata(adata) + mdata.write_h5mu(path) return mdata def _initialize_model(self, mdata: MuData) -> TOTALVI: TOTALVI.setup_mudata( mdata, rna_layer="counts", - protein_layer=None, - batch_key="sample_id", + protein_layer="counts", + batch_key="batch", modalities={ "rna_layer": "rna", "protein_layer": "protein", @@ -72,7 +65,7 @@ def _train_model(self, model: TOTALVI) -> TOTALVI: return model - def get_model(self, adata) -> TOTALVI | None: + def load_model(self, adata) -> TOTALVI | None: """Initialize and train the scVI model.""" logger.info("Training the scVI model.") if self.dry_run: @@ -83,7 +76,7 @@ def get_model(self, adata) -> TOTALVI | None: def run(self): super().run() - mdata = self.load_adata() + mdata = self.get_adata() model = self.get_model(mdata) model_path = self._minify_and_save_model(model, mdata) hub_model = self._create_hub_model(model_path) diff --git a/src/scvi_hub_models/models/_test_scvi.py b/src/scvi_hub_models/models/_test_scvi.py index b72913d..7c326b1 100644 --- a/src/scvi_hub_models/models/_test_scvi.py +++ b/src/scvi_hub_models/models/_test_scvi.py @@ -12,29 +12,23 @@ class _Workflow(BaseModelWorkflow): - def load_dataset(self) -> AnnData | None: + def download_adata(self, path) -> AnnData | None: from scvi.data import synthetic_iid logger.info("Loading synthetic dataset.") if self.dry_run: return None + adata = synthetic_iid() + adata.write_h5ad(path) + return adata - return synthetic_iid() - - def initialize_model(self, adata: AnnData | None) -> SCVI | None: - logger.info("Initializing the scVI model.") + def load_model(self, adata: AnnData) -> SCVI: + logger.info("Training the scVI model.") if self.dry_run: return None - SCVI.setup_anndata(adata) - return SCVI(adata) - - def train_model(self, model: SCVI | None) -> SCVI | None: - logger.info("Training the scVI model.") - if self.dry_run: - return model - - model.train(max_epochs=1) + model = SCVI(adata) + model.train(max_epochs=10) return model @property @@ -44,9 +38,8 @@ def id(self) -> str: def run(self): super().run() - adata = self.load_dataset() - model = self.initialize_model(adata) - model = self.train_model(model) + adata = self.get_adata() + model = self.get_model(adata) model_path = self._minify_and_save_model(model, adata) hub_model = self._create_hub_model(model_path) hub_model = self._upload_hub_model(hub_model) diff --git a/src/scvi_hub_models/test.ipynb b/src/scvi_hub_models/test.ipynb new file mode 100644 index 0000000..c730192 --- /dev/null +++ b/src/scvi_hub_models/test.ipynb @@ -0,0 +1,603 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.insert(0, '/home/cane/Documents/scvi-tools/scvi-hub-models/src')" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "import logging\n", + "import os\n", + "\n", + "import scanpy as sc\n", + "from anndata import AnnData\n", + "from mudata import MuData\n", + "from scvi.model import TOTALVI\n", + "\n", + "from scvi_hub_models.models import BaseModelWorkflow\n", + "\n", + "logger = logging.getLogger(__name__)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/cane/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/mudata/_core/mudata.py:1531: FutureWarning: From 0.4 .update() will not pull obs/var columns from individual modalities by default anymore. Set mudata.set_options(pull_on_update=False) to adopt the new behaviour, which will become the default. Use new pull_obs/pull_var and push_obs/push_var methods for more flexibility.\n", + " self._update_attr(\"var\", axis=0, join_common=join_common)\n", + "/home/cane/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/mudata/_core/mudata.py:1429: FutureWarning: From 0.4 .update() will not pull obs/var columns from individual modalities by default anymore. Set mudata.set_options(pull_on_update=False) to adopt the new behaviour, which will become the default. Use new pull_obs/pull_var and push_obs/push_var methods for more flexibility.\n", + " self._update_attr(\"obs\", axis=1, join_common=join_common)\n" + ] + } + ], + "source": [ + "import mudata\n", + "mu = mudata.read_h5mu('/home/cane/Documents/scvi-tools/scvi-hub-models/test/mini_totalvi/mdata.h5mu')" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
MuData object with n_obs × n_vars = 72042 × 4111\n",
+       "  obs:\t'_scvi_labels', 'observed_lib_size'\n",
+       "  var:\t'cv_gene'\n",
+       "  uns:\t'_scvi_adata_minify_type', '_scvi_manager_uuid', '_scvi_uuid'\n",
+       "  obsm:\t'totalvi_latent_qzm', 'totalvi_latent_qzv'\n",
+       "  2 modalities\n",
+       "    rna:\t72042 x 4000\n",
+       "      obs:\t'percent_mito', 'n_counts', 'n_genes', 'n_protein_counts', 'n_proteins', 'leiden_totalVI_res1.4', 'leiden_totalVI_res1.0', 'leiden_totalVI_res0.6', 'annotations_clean', 'mean_pseudotime', 'Pseudotime_bin', 'curve1', 'curve2', 'difference', 'weight_curve1', 'weight_curve2', 'UMIs_RNA', 'UMIs_protein', 'n_genes_pt', 'n_proteins_pt', 'percent_mito_pt', 'Experiment', 'slingshot_clusters', 'organism_ontology_term_id', 'disease_ontology_term_id', 'sex_ontology_term_id', 'tissue_type', 'tissue_ontology_term_id', 'suspension_type', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'is_primary_data', 'development_stage_ontology_term_id', 'batch_indices', 'sample_id', 'Location', 'donor_id', 'sample_weeks', 'genotype', 'Lineage_by_genotypeSlingshot', 'Lineage_by_genotype', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid', '_scvi_batch', 'cv_cell'\n",
+       "      var:\t'gene_id', 'gene_name', 'expression_type', 'n_cells', 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'feature_length', 'feature_type', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm', 'cv_gene'\n",
+       "      uns:\t'AB_adata', 'annotations_clean_colors', 'batch_condition', 'batch_indices_colors', 'citation', 'hvg', 'leiden', 'neighbors', 'protein_names', 'schema_reference', 'schema_version', 'title', 'totalVI_genes', 'totalVI_proteins', 'umap'\n",
+       "      obsm:\t'X_totalVI', 'X_umap', 'denoised_genes', 'denoised_proteins'\n",
+       "      varm:\t'lfc_model', 'lfc_raw'\n",
+       "      layers:\t'counts'\n",
+       "    protein:\t72042 x 111\n",
+       "      obs:\t'percent_mito', 'n_counts', 'n_genes', 'n_protein_counts', 'n_proteins', 'leiden_totalVI_res1.4', 'leiden_totalVI_res1.0', 'leiden_totalVI_res0.6', 'annotations_clean', 'mean_pseudotime', 'Pseudotime_bin', 'curve1', 'curve2', 'difference', 'weight_curve1', 'weight_curve2', 'UMIs_RNA', 'UMIs_protein', 'n_genes_pt', 'n_proteins_pt', 'percent_mito_pt', 'Experiment', 'slingshot_clusters', 'organism_ontology_term_id', 'disease_ontology_term_id', 'sex_ontology_term_id', 'tissue_type', 'tissue_ontology_term_id', 'suspension_type', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'is_primary_data', 'development_stage_ontology_term_id', 'batch_indices', 'sample_id', 'Location', 'donor_id', 'sample_weeks', 'genotype', 'Lineage_by_genotypeSlingshot', 'Lineage_by_genotype', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid', '_scvi_batch', 'cv_cell'\n",
+       "      var:\t'cv_gene'\n",
+       "      varm:\t'lfc_model', 'lfc_raw'
" + ], + "text/plain": [ + "MuData object with n_obs × n_vars = 72042 × 4111\n", + " obs:\t'_scvi_labels', 'observed_lib_size'\n", + " var:\t'cv_gene'\n", + " uns:\t'_scvi_adata_minify_type', '_scvi_manager_uuid', '_scvi_uuid'\n", + " obsm:\t'totalvi_latent_qzm', 'totalvi_latent_qzv'\n", + " 2 modalities\n", + " rna:\t72042 x 4000\n", + " obs:\t'percent_mito', 'n_counts', 'n_genes', 'n_protein_counts', 'n_proteins', 'leiden_totalVI_res1.4', 'leiden_totalVI_res1.0', 'leiden_totalVI_res0.6', 'annotations_clean', 'mean_pseudotime', 'Pseudotime_bin', 'curve1', 'curve2', 'difference', 'weight_curve1', 'weight_curve2', 'UMIs_RNA', 'UMIs_protein', 'n_genes_pt', 'n_proteins_pt', 'percent_mito_pt', 'Experiment', 'slingshot_clusters', 'organism_ontology_term_id', 'disease_ontology_term_id', 'sex_ontology_term_id', 'tissue_type', 'tissue_ontology_term_id', 'suspension_type', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'is_primary_data', 'development_stage_ontology_term_id', 'batch_indices', 'sample_id', 'Location', 'donor_id', 'sample_weeks', 'genotype', 'Lineage_by_genotypeSlingshot', 'Lineage_by_genotype', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid', '_scvi_batch', 'cv_cell'\n", + " var:\t'gene_id', 'gene_name', 'expression_type', 'n_cells', 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'feature_length', 'feature_type', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm', 'cv_gene'\n", + " uns:\t'AB_adata', 'annotations_clean_colors', 'batch_condition', 'batch_indices_colors', 'citation', 'hvg', 'leiden', 'neighbors', 'protein_names', 'schema_reference', 'schema_version', 'title', 'totalVI_genes', 'totalVI_proteins', 'umap'\n", + " obsm:\t'X_totalVI', 'X_umap', 'denoised_genes', 'denoised_proteins'\n", + " varm:\t'lfc_model', 'lfc_raw'\n", + " layers:\t'counts'\n", + " protein:\t72042 x 111\n", + " obs:\t'percent_mito', 'n_counts', 'n_genes', 'n_protein_counts', 'n_proteins', 'leiden_totalVI_res1.4', 'leiden_totalVI_res1.0', 'leiden_totalVI_res0.6', 'annotations_clean', 'mean_pseudotime', 'Pseudotime_bin', 'curve1', 'curve2', 'difference', 'weight_curve1', 'weight_curve2', 'UMIs_RNA', 'UMIs_protein', 'n_genes_pt', 'n_proteins_pt', 'percent_mito_pt', 'Experiment', 'slingshot_clusters', 'organism_ontology_term_id', 'disease_ontology_term_id', 'sex_ontology_term_id', 'tissue_type', 'tissue_ontology_term_id', 'suspension_type', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'is_primary_data', 'development_stage_ontology_term_id', 'batch_indices', 'sample_id', 'Location', 'donor_id', 'sample_weeks', 'genotype', 'Lineage_by_genotypeSlingshot', 'Lineage_by_genotype', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid', '_scvi_batch', 'cv_cell'\n", + " var:\t'cv_gene'\n", + " varm:\t'lfc_model', 'lfc_raw'" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mu" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "import logging\n", + "from scvi_hub_models.config import json_data_store" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "from scvi_hub_models.models import _neurips_cite\n", + "Workflow = _neurips_cite._Workflow" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "self = Workflow(config=json_data_store['neurips_cite'], save_dir='.', reload_model=True, reload_data=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'bmmc_cite.h5ad.gz'" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "self.config['extra_data_kwargs']['reference_adata_fname']" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Decompressing '/home/cane/.cache/pooch/bmmc_cite.h5ad.gz' to '/home/cane/.cache/pooch/bmmc_cite.h5ad.gz.decomp' using method 'auto'.\n" + ] + } + ], + "source": [ + "from pooch import retrieve\n", + "from pooch import Decompress\n", + "\n", + "adata_path = retrieve(\n", + " url=self.config['extra_data_kwargs']['url'],\n", + " known_hash=\"b9b50fade9349719cba23c97c6515d3501a32ee3735fe95fe51221d2e8a5f361\",\n", + " fname='bmmc_cite.h5ad.gz',\n", + " processor=Decompress(),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/cane/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/anndata/_core/anndata.py:1758: UserWarning: Variable names are not unique. To make them unique, call `.var_names_make_unique`.\n", + " utils.warn_names_duplicates(\"var\")\n" + ] + } + ], + "source": [ + "import anndata\n", + "\n", + "ad = anndata.read_h5ad('/home/cane/.cache/pooch/bmmc_cite.h5ad')" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AnnData object with n_obs × n_vars = 90261 × 14087\n", + " obs: 'GEX_n_genes_by_counts', 'GEX_pct_counts_mt', 'GEX_size_factors', 'GEX_phase', 'ADT_n_antibodies_by_counts', 'ADT_total_counts', 'ADT_iso_count', 'cell_type', 'batch', 'ADT_pseudotime_order', 'GEX_pseudotime_order', 'Samplename', 'Site', 'DonorNumber', 'Modality', 'VendorLot', 'DonorID', 'DonorAge', 'DonorBMI', 'DonorBloodType', 'DonorRace', 'Ethnicity', 'DonorGender', 'QCMeds', 'DonorSmoker', 'is_train'\n", + " var: 'feature_types', 'gene_id'\n", + " uns: 'dataset_id', 'genome', 'organism'\n", + " obsm: 'ADT_X_pca', 'ADT_X_umap', 'ADT_isotype_controls', 'GEX_X_pca', 'GEX_X_umap'\n", + " layers: 'counts'" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ad" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/home/cane/Documents/scvi-tools/scvi-hub-models/data/neurips_bone_marrow_cite.h5mu\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/cane/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/anndata/_core/anndata.py:1758: UserWarning: Variable names are not unique. To make them unique, call `.var_names_make_unique`.\n", + " utils.warn_names_duplicates(\"var\")\n", + "/home/cane/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/mudata/_core/mudata.py:1531: FutureWarning: From 0.4 .update() will not pull obs/var columns from individual modalities by default anymore. Set mudata.set_options(pull_on_update=False) to adopt the new behaviour, which will become the default. Use new pull_obs/pull_var and push_obs/push_var methods for more flexibility.\n", + " self._update_attr(\"var\", axis=0, join_common=join_common)\n", + "/home/cane/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/mudata/_core/mudata.py:931: UserWarning: Cannot join columns with the same name because var_names are intersecting.\n", + " warnings.warn(\n", + "/home/cane/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/mudata/_core/mudata.py:1429: FutureWarning: From 0.4 .update() will not pull obs/var columns from individual modalities by default anymore. Set mudata.set_options(pull_on_update=False) to adopt the new behaviour, which will become the default. Use new pull_obs/pull_var and push_obs/push_var methods for more flexibility.\n", + " self._update_attr(\"obs\", axis=1, join_common=join_common)\n", + "/home/cane/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/mudata/_core/mudata.py:1531: FutureWarning: From 0.4 .update() will not pull obs/var columns from individual modalities by default anymore. Set mudata.set_options(pull_on_update=False) to adopt the new behaviour, which will become the default. Use new pull_obs/pull_var and push_obs/push_var methods for more flexibility.\n", + " self._update_attr(\"var\", axis=0, join_common=join_common)\n", + "/home/cane/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/mudata/_core/mudata.py:1429: FutureWarning: From 0.4 .update() will not pull obs/var columns from individual modalities by default anymore. Set mudata.set_options(pull_on_update=False) to adopt the new behaviour, which will become the default. Use new pull_obs/pull_var and push_obs/push_var methods for more flexibility.\n", + " self._update_attr(\"obs\", axis=1, join_common=join_common)\n" + ] + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "mdata = self.get_adata()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "GEX_n_genes_by_counts                         893\n",
+       "GEX_pct_counts_mt                        6.723979\n",
+       "GEX_size_factors                         0.356535\n",
+       "GEX_phase                                      G1\n",
+       "ADT_n_antibodies_by_counts                    115\n",
+       "ADT_total_counts                           2828.0\n",
+       "ADT_iso_count                                 5.0\n",
+       "cell_type                     Naive CD20+ B IGKC+\n",
+       "batch                                        s1d1\n",
+       "ADT_pseudotime_order                          NaN\n",
+       "GEX_pseudotime_order                          NaN\n",
+       "Samplename                      site1_donor1_cite\n",
+       "Site                                        site1\n",
+       "DonorNumber                                donor1\n",
+       "Modality                                     cite\n",
+       "VendorLot                                 3054455\n",
+       "DonorID                                     15078\n",
+       "DonorAge                                       34\n",
+       "DonorBMI                                     24.8\n",
+       "DonorBloodType                                 B-\n",
+       "DonorRace                                   White\n",
+       "Ethnicity                      HISPANIC OR LATINO\n",
+       "DonorGender                                  Male\n",
+       "QCMeds                                      False\n",
+       "DonorSmoker                             Nonsmoker\n",
+       "is_train                                    train\n",
+       "_scvi_batch                                     0\n",
+       "cv_cell                                  7.578709\n",
+       "Name: GCATTAGCATAAGCGG-1-s1d1, dtype: object"
+      ]
+     },
+     "execution_count": 25,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "mdata['rna'].obs.iloc[0]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "INFO     Computing empirical prior initialization for protein background.                                          \n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "GPU available: True (cuda), used: True\n",
+      "TPU available: False, using: 0 TPU cores\n",
+      "HPU available: False, using: 0 HPUs\n",
+      "You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
+      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+      "/home/cane/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/torch/optim/lr_scheduler.py:62: UserWarning: The verbose parameter is deprecated. Please use get_last_lr() to access the learning rate.\n",
+      "  warnings.warn(\n",
+      "/home/cane/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/lightning/pytorch/core/optimizer.py:316: The lr scheduler dict contains the key(s) ['monitor'], but the keys will be ignored. You need to call `lr_scheduler.step()` manually in manual optimization.\n",
+      "/home/cane/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.\n",
+      "/home/cane/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch 200/200: 100%|██████████| 200/200 [22:05<00:00,  7.33s/it, v_num=1, train_loss_step=1.58e+3, train_loss_epoch=1.6e+3]"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "`Trainer.fit` stopped: `max_epochs=200` reached.\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch 200/200: 100%|██████████| 200/200 [22:05<00:00,  6.63s/it, v_num=1, train_loss_step=1.58e+3, train_loss_epoch=1.6e+3]\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "['rna', 'protein']\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/cane/Documents/scvi-tools/src/scvi/criticism/_ppc.py:293: UserWarning: n_top_genes_fallback=100 is greater than 10% of the number ofgenes f(134) in the dataset. Setting it to 10%.\n",
+      "  warnings.warn(\n"
+     ]
+    },
+    {
+     "ename": "ValueError",
+     "evalue": "Only one class present in y_true. ROC AUC score is not defined in that case.",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
+      "Cell \u001b[0;32mIn[23], line 2\u001b[0m\n\u001b[1;32m      1\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_model(mdata)\n\u001b[0;32m----> 2\u001b[0m model_path \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_minify_and_save_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmdata\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m      3\u001b[0m hub_model \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_create_hub_model(model_path)\n\u001b[1;32m      4\u001b[0m hub_model \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_upload_hub_model(hub_model)\n",
+      "File \u001b[0;32m~/Documents/scvi-tools/scvi-hub-models/src/scvi_hub_models/models/_base_workflow.py:239\u001b[0m, in \u001b[0;36m_minify_and_save_model\u001b[0;34m(self, model, adata)\u001b[0m\n\u001b[1;32m    237\u001b[0m qzm, qzv \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mget_latent_representation(give_mean\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, return_dist\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m    238\u001b[0m adata\u001b[38;5;241m.\u001b[39mobsm[qzm_key] \u001b[38;5;241m=\u001b[39m qzm\n\u001b[0;32m--> 239\u001b[0m adata\u001b[38;5;241m.\u001b[39mobsm[qzv_key] \u001b[38;5;241m=\u001b[39m qzv\n\u001b[1;32m    240\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(adata, mudata\u001b[38;5;241m.\u001b[39mMuData):\n\u001b[1;32m    241\u001b[0m     model\u001b[38;5;241m.\u001b[39mminify_mudata(use_latent_qzm_key\u001b[38;5;241m=\u001b[39mqzm_key, use_latent_qzv_key\u001b[38;5;241m=\u001b[39mqzv_key)\n",
+      "File \u001b[0;32m~/Documents/scvi-tools/src/scvi/criticism/_create_criticism_report.py:76\u001b[0m, in \u001b[0;36mcreate_criticism_report\u001b[0;34m(model, adata, skip_metrics, n_samples, label_key, save_folder)\u001b[0m\n\u001b[1;32m     74\u001b[0m md_cell_wise_cv, md_gene_wise_cv, md_de \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m     75\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m modalities:\n\u001b[0;32m---> 76\u001b[0m     md_cell_wise_cv_, md_gene_wise_cv_, md_de_ \u001b[38;5;241m=\u001b[39m \u001b[43mcompute_metrics\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m     77\u001b[0m \u001b[43m        \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43madata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mskip_metrics\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_samples\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabel_key\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodality\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mi\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     78\u001b[0m     md_cell_wise_cv \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mModality: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mi\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m+\u001b[39m md_cell_wise_cv_ \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m     79\u001b[0m     md_gene_wise_cv \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mModality: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mi\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m+\u001b[39m md_gene_wise_cv_ \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n",
+      "File \u001b[0;32m~/Documents/scvi-tools/src/scvi/criticism/_create_criticism_report.py:116\u001b[0m, in \u001b[0;36mcompute_metrics\u001b[0;34m(model, adata, skip_metrics, n_samples, label_key, modality)\u001b[0m\n\u001b[1;32m    114\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m label_key \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m labels_state_registry\u001b[38;5;241m.\u001b[39moriginal_key \u001b[38;5;241m!=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_scvi_labels\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m    115\u001b[0m     label_key \u001b[38;5;241m=\u001b[39m labels_state_registry\u001b[38;5;241m.\u001b[39moriginal_key\n\u001b[0;32m--> 116\u001b[0m \u001b[43mppc\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdifferential_expression\u001b[49m\u001b[43m(\u001b[49m\u001b[43mde_groupby\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlabel_key\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mp_val_thresh\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0.2\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m    117\u001b[0m summary_df \u001b[38;5;241m=\u001b[39m ppc\u001b[38;5;241m.\u001b[39mmetrics[METRIC_DIFF_EXP][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msummary\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39mset_index(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgroup\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m    118\u001b[0m summary_df \u001b[38;5;241m=\u001b[39m summary_df\u001b[38;5;241m.\u001b[39mdrop(columns\u001b[38;5;241m=\u001b[39m[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n",
+      "File \u001b[0;32m~/Documents/scvi-tools/src/scvi/utils/_dependencies.py:24\u001b[0m, in \u001b[0;36mdependencies..decorator..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     21\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(fn)\n\u001b[1;32m     22\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapper\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m     23\u001b[0m     error_on_missing_dependencies(\u001b[38;5;241m*\u001b[39mmodules)\n\u001b[0;32m---> 24\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+      "File \u001b[0;32m~/Documents/scvi-tools/src/scvi/criticism/_ppc.py:430\u001b[0m, in \u001b[0;36mPosteriorPredictiveCheck.differential_expression\u001b[0;34m(self, de_groupby, de_method, n_samples, cell_scale_factor, p_val_thresh, n_top_genes_fallback)\u001b[0m\n\u001b[1;32m    427\u001b[0m         true \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mzeros_like(pred)\n\u001b[1;32m    428\u001b[0m         true[np\u001b[38;5;241m.\u001b[39margsort(raw_adj_p_vals)[:n_top_genes_fallback]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m--> 430\u001b[0m     roc_aucs\u001b[38;5;241m.\u001b[39mappend(\u001b[43mroc_auc_score\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrue\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpred\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m    431\u001b[0m     pr_aucs\u001b[38;5;241m.\u001b[39mappend(average_precision_score(true, pred))\n\u001b[1;32m    433\u001b[0m \u001b[38;5;66;03m# Compute means over samples\u001b[39;00m\n",
+      "File \u001b[0;32m~/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/sklearn/utils/_param_validation.py:213\u001b[0m, in \u001b[0;36mvalidate_params..decorator..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    207\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m    208\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m config_context(\n\u001b[1;32m    209\u001b[0m         skip_parameter_validation\u001b[38;5;241m=\u001b[39m(\n\u001b[1;32m    210\u001b[0m             prefer_skip_nested_validation \u001b[38;5;129;01mor\u001b[39;00m global_skip_validation\n\u001b[1;32m    211\u001b[0m         )\n\u001b[1;32m    212\u001b[0m     ):\n\u001b[0;32m--> 213\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    214\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m InvalidParameterError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m    215\u001b[0m     \u001b[38;5;66;03m# When the function is just a wrapper around an estimator, we allow\u001b[39;00m\n\u001b[1;32m    216\u001b[0m     \u001b[38;5;66;03m# the function to delegate validation to the estimator, but we replace\u001b[39;00m\n\u001b[1;32m    217\u001b[0m     \u001b[38;5;66;03m# the name of the estimator by the name of the function in the error\u001b[39;00m\n\u001b[1;32m    218\u001b[0m     \u001b[38;5;66;03m# message to avoid confusion.\u001b[39;00m\n\u001b[1;32m    219\u001b[0m     msg \u001b[38;5;241m=\u001b[39m re\u001b[38;5;241m.\u001b[39msub(\n\u001b[1;32m    220\u001b[0m         \u001b[38;5;124mr\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mparameter of \u001b[39m\u001b[38;5;124m\\\u001b[39m\u001b[38;5;124mw+ must be\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m    221\u001b[0m         \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mparameter of \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunc\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m must be\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m    222\u001b[0m         \u001b[38;5;28mstr\u001b[39m(e),\n\u001b[1;32m    223\u001b[0m     )\n",
+      "File \u001b[0;32m~/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/sklearn/metrics/_ranking.py:640\u001b[0m, in \u001b[0;36mroc_auc_score\u001b[0;34m(y_true, y_score, average, sample_weight, max_fpr, multi_class, labels)\u001b[0m\n\u001b[1;32m    638\u001b[0m     labels \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39munique(y_true)\n\u001b[1;32m    639\u001b[0m     y_true \u001b[38;5;241m=\u001b[39m label_binarize(y_true, classes\u001b[38;5;241m=\u001b[39mlabels)[:, \u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m--> 640\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_average_binary_score\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    641\u001b[0m \u001b[43m        \u001b[49m\u001b[43mpartial\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_binary_roc_auc_score\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmax_fpr\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmax_fpr\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    642\u001b[0m \u001b[43m        \u001b[49m\u001b[43my_true\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    643\u001b[0m \u001b[43m        \u001b[49m\u001b[43my_score\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    644\u001b[0m \u001b[43m        \u001b[49m\u001b[43maverage\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    645\u001b[0m \u001b[43m        \u001b[49m\u001b[43msample_weight\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msample_weight\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    646\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    647\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:  \u001b[38;5;66;03m# multilabel-indicator\u001b[39;00m\n\u001b[1;32m    648\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m _average_binary_score(\n\u001b[1;32m    649\u001b[0m         partial(_binary_roc_auc_score, max_fpr\u001b[38;5;241m=\u001b[39mmax_fpr),\n\u001b[1;32m    650\u001b[0m         y_true,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    653\u001b[0m         sample_weight\u001b[38;5;241m=\u001b[39msample_weight,\n\u001b[1;32m    654\u001b[0m     )\n",
+      "File \u001b[0;32m~/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/sklearn/metrics/_base.py:76\u001b[0m, in \u001b[0;36m_average_binary_score\u001b[0;34m(binary_metric, y_true, y_score, average, sample_weight)\u001b[0m\n\u001b[1;32m     73\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{0}\u001b[39;00m\u001b[38;5;124m format is not supported\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(y_type))\n\u001b[1;32m     75\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m y_type \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbinary\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[0;32m---> 76\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mbinary_metric\u001b[49m\u001b[43m(\u001b[49m\u001b[43my_true\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_score\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msample_weight\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msample_weight\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     78\u001b[0m check_consistent_length(y_true, y_score, sample_weight)\n\u001b[1;32m     79\u001b[0m y_true \u001b[38;5;241m=\u001b[39m check_array(y_true)\n",
+      "File \u001b[0;32m~/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/sklearn/metrics/_ranking.py:382\u001b[0m, in \u001b[0;36m_binary_roc_auc_score\u001b[0;34m(y_true, y_score, sample_weight, max_fpr)\u001b[0m\n\u001b[1;32m    380\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Binary roc auc score.\"\"\"\u001b[39;00m\n\u001b[1;32m    381\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(np\u001b[38;5;241m.\u001b[39munique(y_true)) \u001b[38;5;241m!=\u001b[39m \u001b[38;5;241m2\u001b[39m:\n\u001b[0;32m--> 382\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m    383\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mOnly one class present in y_true. ROC AUC score \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    384\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mis not defined in that case.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    385\u001b[0m     )\n\u001b[1;32m    387\u001b[0m fpr, tpr, _ \u001b[38;5;241m=\u001b[39m roc_curve(y_true, y_score, sample_weight\u001b[38;5;241m=\u001b[39msample_weight)\n\u001b[1;32m    388\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m max_fpr \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mor\u001b[39;00m max_fpr \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m1\u001b[39m:\n",
+      "\u001b[0;31mValueError\u001b[0m: Only one class present in y_true. ROC AUC score is not defined in that case."
+     ]
+    }
+   ],
+   "source": [
+    "model = self.get_model(mdata)\n",
+    "model_path = self._minify_and_save_model(model, mdata)\n",
+    "hub_model = self._create_hub_model(model_path)\n",
+    "hub_model = self._upload_hub_model(hub_model)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "INFO     File ../mini_totalvi/model.pt already downloaded                                                          \n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/cane/Documents/scvi-tools/src/scvi/model/base/_save_load.py:76: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
+      "  model = torch.load(model_path, map_location=map_location)\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "INFO     File ../mini_totalvi/model.pt already downloaded                                                          \n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/cane/Documents/scvi-tools/src/scvi/model/base/_save_load.py:76: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
+      "  model = torch.load(model_path, map_location=map_location)\n",
+      "No files have been modified since last commit. Skipping to prevent empty commit.\n",
+      "No files have been modified since last commit. Skipping to prevent empty commit.\n"
+     ]
+    }
+   ],
+   "source": [
+    "hub_model = self._create_hub_model(model_path)\n",
+    "hub_model = self._upload_hub_model(hub_model)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/cane/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/mudata/_core/mudata.py:1531: FutureWarning: From 0.4 .update() will not pull obs/var columns from individual modalities by default anymore. Set mudata.set_options(pull_on_update=False) to adopt the new behaviour, which will become the default. Use new pull_obs/pull_var and push_obs/push_var methods for more flexibility.\n",
+      "  self._update_attr(\"var\", axis=0, join_common=join_common)\n",
+      "/home/cane/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/mudata/_core/mudata.py:1429: FutureWarning: From 0.4 .update() will not pull obs/var columns from individual modalities by default anymore. Set mudata.set_options(pull_on_update=False) to adopt the new behaviour, which will become the default. Use new pull_obs/pull_var and push_obs/push_var methods for more flexibility.\n",
+      "  self._update_attr(\"obs\", axis=1, join_common=join_common)\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "
MuData object with n_obs × n_vars = 72042 × 16104\n",
+       "  obs:\t'_scvi_labels', 'cv_cell'\n",
+       "  var:\t'cv_gene'\n",
+       "  uns:\t'_scvi_uuid', '_scvi_manager_uuid'\n",
+       "  2 modalities\n",
+       "    rna:\t72042 x 15993\n",
+       "      obs:\t'percent_mito', 'n_counts', 'n_genes', 'n_protein_counts', 'n_proteins', 'leiden_totalVI_res1.4', 'leiden_totalVI_res1.0', 'leiden_totalVI_res0.6', 'annotations_clean', 'mean_pseudotime', 'Pseudotime_bin', 'curve1', 'curve2', 'difference', 'weight_curve1', 'weight_curve2', 'UMIs_RNA', 'UMIs_protein', 'n_genes_pt', 'n_proteins_pt', 'percent_mito_pt', 'Experiment', 'slingshot_clusters', 'organism_ontology_term_id', 'disease_ontology_term_id', 'sex_ontology_term_id', 'tissue_type', 'tissue_ontology_term_id', 'suspension_type', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'is_primary_data', 'development_stage_ontology_term_id', 'batch_indices', 'sample_id', 'Location', 'donor_id', 'sample_weeks', 'genotype', 'Lineage_by_genotypeSlingshot', 'Lineage_by_genotype', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid', '_scvi_batch'\n",
+       "      var:\t'gene_id', 'gene_name', 'expression_type', 'n_cells', 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'feature_length', 'feature_type', 'n_counts'\n",
+       "      uns:\t'AB_adata', 'annotations_clean_colors', 'batch_condition', 'batch_indices_colors', 'citation', 'leiden', 'neighbors', 'protein_names', 'schema_reference', 'schema_version', 'title', 'totalVI_genes', 'totalVI_proteins', 'umap'\n",
+       "      obsm:\t'X_totalVI', 'X_umap', 'denoised_genes', 'denoised_proteins'\n",
+       "      layers:\t'counts'\n",
+       "    protein:\t72042 x 111\n",
+       "      obs:\t'percent_mito', 'n_counts', 'n_genes', 'n_protein_counts', 'n_proteins', 'leiden_totalVI_res1.4', 'leiden_totalVI_res1.0', 'leiden_totalVI_res0.6', 'annotations_clean', 'mean_pseudotime', 'Pseudotime_bin', 'curve1', 'curve2', 'difference', 'weight_curve1', 'weight_curve2', 'UMIs_RNA', 'UMIs_protein', 'n_genes_pt', 'n_proteins_pt', 'percent_mito_pt', 'Experiment', 'slingshot_clusters', 'organism_ontology_term_id', 'disease_ontology_term_id', 'sex_ontology_term_id', 'tissue_type', 'tissue_ontology_term_id', 'suspension_type', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'is_primary_data', 'development_stage_ontology_term_id', 'batch_indices', 'sample_id', 'Location', 'donor_id', 'sample_weeks', 'genotype', 'Lineage_by_genotypeSlingshot', 'Lineage_by_genotype', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid', '_scvi_batch'
" + ], + "text/plain": [ + "MuData object with n_obs × n_vars = 72042 × 16104\n", + " obs:\t'_scvi_labels', 'cv_cell'\n", + " var:\t'cv_gene'\n", + " uns:\t'_scvi_uuid', '_scvi_manager_uuid'\n", + " 2 modalities\n", + " rna:\t72042 x 15993\n", + " obs:\t'percent_mito', 'n_counts', 'n_genes', 'n_protein_counts', 'n_proteins', 'leiden_totalVI_res1.4', 'leiden_totalVI_res1.0', 'leiden_totalVI_res0.6', 'annotations_clean', 'mean_pseudotime', 'Pseudotime_bin', 'curve1', 'curve2', 'difference', 'weight_curve1', 'weight_curve2', 'UMIs_RNA', 'UMIs_protein', 'n_genes_pt', 'n_proteins_pt', 'percent_mito_pt', 'Experiment', 'slingshot_clusters', 'organism_ontology_term_id', 'disease_ontology_term_id', 'sex_ontology_term_id', 'tissue_type', 'tissue_ontology_term_id', 'suspension_type', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'is_primary_data', 'development_stage_ontology_term_id', 'batch_indices', 'sample_id', 'Location', 'donor_id', 'sample_weeks', 'genotype', 'Lineage_by_genotypeSlingshot', 'Lineage_by_genotype', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid', '_scvi_batch'\n", + " var:\t'gene_id', 'gene_name', 'expression_type', 'n_cells', 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'feature_length', 'feature_type', 'n_counts'\n", + " uns:\t'AB_adata', 'annotations_clean_colors', 'batch_condition', 'batch_indices_colors', 'citation', 'leiden', 'neighbors', 'protein_names', 'schema_reference', 'schema_version', 'title', 'totalVI_genes', 'totalVI_proteins', 'umap'\n", + " obsm:\t'X_totalVI', 'X_umap', 'denoised_genes', 'denoised_proteins'\n", + " layers:\t'counts'\n", + " protein:\t72042 x 111\n", + " obs:\t'percent_mito', 'n_counts', 'n_genes', 'n_protein_counts', 'n_proteins', 'leiden_totalVI_res1.4', 'leiden_totalVI_res1.0', 'leiden_totalVI_res0.6', 'annotations_clean', 'mean_pseudotime', 'Pseudotime_bin', 'curve1', 'curve2', 'difference', 'weight_curve1', 'weight_curve2', 'UMIs_RNA', 'UMIs_protein', 'n_genes_pt', 'n_proteins_pt', 'percent_mito_pt', 'Experiment', 'slingshot_clusters', 'organism_ontology_term_id', 'disease_ontology_term_id', 'sex_ontology_term_id', 'tissue_type', 'tissue_ontology_term_id', 'suspension_type', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'is_primary_data', 'development_stage_ontology_term_id', 'batch_indices', 'sample_id', 'Location', 'donor_id', 'sample_weeks', 'genotype', 'Lineage_by_genotypeSlingshot', 'Lineage_by_genotype', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid', '_scvi_batch'" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mdata.update()\n", + "mdata['protein'].obs = mdata['rna'].obs\n", + "mdata" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['rna', 'protein']\n" + ] + }, + { + "ename": "KeyError", + "evalue": "\"Values ['rna', 'protein'], from ['rna', 'protein'], are not valid obs/ var names or indices.\"", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[19], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m model_path \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_minify_and_save_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmdata\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Documents/scvi-tools/scvi-hub-models/src/scvi_hub_models/models/_base_workflow.py:267\u001b[0m, in \u001b[0;36mBaseModelWorkflow._minify_and_save_model\u001b[0;34m(self, model, adata)\u001b[0m\n\u001b[1;32m 265\u001b[0m os\u001b[38;5;241m.\u001b[39mmakedirs(mini_model_path)\n\u001b[1;32m 266\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcreate_criticism_report\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mTrue\u001b[39;00m) \u001b[38;5;129;01mand\u001b[39;00m model\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m \u001b[38;5;129;01min\u001b[39;00m SUPPORTED_PPC_MODELS:\n\u001b[0;32m--> 267\u001b[0m \u001b[43mcreate_criticism_report\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 268\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 269\u001b[0m \u001b[43m \u001b[49m\u001b[43msave_folder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmini_model_path\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 270\u001b[0m \u001b[43m \u001b[49m\u001b[43mn_samples\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconfig\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcriticism_settings\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mn_samples\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m3\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 271\u001b[0m \u001b[43m \u001b[49m\u001b[43mlabel_key\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconfig\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcriticism_settings\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcell_type_key\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 272\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 274\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mminify_model\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mTrue\u001b[39;00m) \u001b[38;5;129;01mand\u001b[39;00m model\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m \u001b[38;5;129;01min\u001b[39;00m SUPPORTED_MINIFIED_MODELS:\n\u001b[1;32m 275\u001b[0m qzm_key \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmodel_name\u001b[38;5;241m.\u001b[39mlower()\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m_latent_qzm\u001b[39m\u001b[38;5;124m\"\u001b[39m\n", + "File \u001b[0;32m~/Documents/scvi-tools/src/scvi/criticism/_create_criticism_report.py:76\u001b[0m, in \u001b[0;36mcreate_criticism_report\u001b[0;34m(model, adata, skip_metrics, n_samples, label_key, save_folder)\u001b[0m\n\u001b[1;32m 74\u001b[0m md_cell_wise_cv, md_gene_wise_cv, md_de \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 75\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m [modalities]:\n\u001b[0;32m---> 76\u001b[0m md_cell_wise_cv_, md_gene_wise_cv_, md_de_ \u001b[38;5;241m=\u001b[39m \u001b[43mcompute_metrics\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 77\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43madata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mskip_metrics\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_samples\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabel_key\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodality\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mi\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 78\u001b[0m md_cell_wise_cv \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mModality: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mi\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m+\u001b[39m md_cell_wise_cv_ \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 79\u001b[0m md_gene_wise_cv \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mModality: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mi\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m+\u001b[39m md_gene_wise_cv_ \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n", + "File \u001b[0;32m~/Documents/scvi-tools/src/scvi/criticism/_create_criticism_report.py:99\u001b[0m, in \u001b[0;36mcompute_metrics\u001b[0;34m(model, adata, skip_metrics, n_samples, label_key, modality)\u001b[0m\n\u001b[1;32m 97\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcompute_metrics\u001b[39m(model, adata, skip_metrics, n_samples, label_key, modality\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[1;32m 98\u001b[0m models_dict \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel\u001b[39m\u001b[38;5;124m\"\u001b[39m: model}\n\u001b[0;32m---> 99\u001b[0m ppc \u001b[38;5;241m=\u001b[39m \u001b[43mPPC\u001b[49m\u001b[43m(\u001b[49m\u001b[43madata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodels_dict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_samples\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mn_samples\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodality\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodality\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 100\u001b[0m \u001b[38;5;66;03m# run ppc+cv\u001b[39;00m\n\u001b[1;32m 101\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m METRIC_CV_CELL \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m skip_metrics:\n", + "File \u001b[0;32m~/Documents/scvi-tools/src/scvi/criticism/_ppc.py:85\u001b[0m, in \u001b[0;36mPosteriorPredictiveCheck.__init__\u001b[0;34m(self, adata, models_dict, count_layer_key, n_samples, indices, modality)\u001b[0m\n\u001b[1;32m 83\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(adata, MuData):\n\u001b[1;32m 84\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m modality \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mModality must be defined for MuData.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m---> 85\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39madata \u001b[38;5;241m=\u001b[39m \u001b[43madata\u001b[49m\u001b[43m[\u001b[49m\u001b[43mmodality\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 86\u001b[0m raw_counts \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 87\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39madata\u001b[38;5;241m.\u001b[39mlayers[count_layer_key] \u001b[38;5;28;01mif\u001b[39;00m count_layer_key \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 88\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39madata\u001b[38;5;241m.\u001b[39mX\n\u001b[1;32m 89\u001b[0m )\n\u001b[1;32m 90\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n", + "File \u001b[0;32m~/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/mudata/_core/mudata.py:516\u001b[0m, in \u001b[0;36mMuData.__getitem__\u001b[0;34m(self, index)\u001b[0m\n\u001b[1;32m 514\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmod[index]\n\u001b[1;32m 515\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mMuData\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mas_view\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mindex\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mindex\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/mudata/_core/mudata.py:166\u001b[0m, in \u001b[0;36mMuData.__init__\u001b[0;34m(self, data, feature_types_names, as_view, index, **kwargs)\u001b[0m\n\u001b[1;32m 164\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_init_common()\n\u001b[1;32m 165\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m as_view:\n\u001b[0;32m--> 166\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_init_as_view\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mindex\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 167\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m 169\u001b[0m \u001b[38;5;66;03m# Add all modalities to a MuData object\u001b[39;00m\n", + "File \u001b[0;32m~/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/mudata/_core/mudata.py:265\u001b[0m, in \u001b[0;36mMuData._init_as_view\u001b[0;34m(self, mudata_ref, index)\u001b[0m\n\u001b[1;32m 262\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_init_as_view\u001b[39m(\u001b[38;5;28mself\u001b[39m, mudata_ref: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMuData\u001b[39m\u001b[38;5;124m\"\u001b[39m, index):\n\u001b[1;32m 263\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01manndata\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_core\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mindex\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m _normalize_indices\n\u001b[0;32m--> 265\u001b[0m obsidx, varidx \u001b[38;5;241m=\u001b[39m \u001b[43m_normalize_indices\u001b[49m\u001b[43m(\u001b[49m\u001b[43mindex\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmudata_ref\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mobs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mindex\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmudata_ref\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvar\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mindex\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 267\u001b[0m \u001b[38;5;66;03m# to handle single-element subsets, otherwise when subsetting a Dataframe\u001b[39;00m\n\u001b[1;32m 268\u001b[0m \u001b[38;5;66;03m# we get a Series\u001b[39;00m\n\u001b[1;32m 269\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(obsidx, Integral):\n", + "File \u001b[0;32m~/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/anndata/_core/index.py:32\u001b[0m, in \u001b[0;36m_normalize_indices\u001b[0;34m(index, names0, names1)\u001b[0m\n\u001b[1;32m 30\u001b[0m index \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mtuple\u001b[39m(i\u001b[38;5;241m.\u001b[39mvalues \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(i, pd\u001b[38;5;241m.\u001b[39mSeries) \u001b[38;5;28;01melse\u001b[39;00m i \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m index)\n\u001b[1;32m 31\u001b[0m ax0, ax1 \u001b[38;5;241m=\u001b[39m unpack_index(index)\n\u001b[0;32m---> 32\u001b[0m ax0 \u001b[38;5;241m=\u001b[39m \u001b[43m_normalize_index\u001b[49m\u001b[43m(\u001b[49m\u001b[43max0\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnames0\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 33\u001b[0m ax1 \u001b[38;5;241m=\u001b[39m _normalize_index(ax1, names1)\n\u001b[1;32m 34\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m ax0, ax1\n", + "File \u001b[0;32m~/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/anndata/_core/index.py:99\u001b[0m, in \u001b[0;36m_normalize_index\u001b[0;34m(indexer, index)\u001b[0m\n\u001b[1;32m 97\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m np\u001b[38;5;241m.\u001b[39many(positions \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m0\u001b[39m):\n\u001b[1;32m 98\u001b[0m not_found \u001b[38;5;241m=\u001b[39m indexer[positions \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m---> 99\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m(\n\u001b[1;32m 100\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mValues \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlist\u001b[39m(not_found)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, from \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlist\u001b[39m(indexer)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 101\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mare not valid obs/ var names or indices.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 102\u001b[0m )\n\u001b[1;32m 103\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m positions \u001b[38;5;66;03m# np.ndarray[int]\u001b[39;00m\n\u001b[1;32m 104\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mIndexError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnknown indexer \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mindexer\u001b[38;5;132;01m!r}\u001b[39;00m\u001b[38;5;124m of type \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(indexer)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n", + "\u001b[0;31mKeyError\u001b[0m: \"Values ['rna', 'protein'], from ['rna', 'protein'], are not valid obs/ var names or indices.\"" + ] + } + ], + "source": [ + "model_path = self._minify_and_save_model(model, mdata)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "hub_model = self._create_hub_model(model_path)\n", + "hub_model = self._upload_hub_model(hub_model)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "scvi-tools", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}