From a005bdbaf917543d0fd29f37cdac06b77e0eb28d Mon Sep 17 00:00:00 2001
From: Daniel Bolin
Date: Tue, 26 Mar 2024 15:43:44 -0400
Subject: [PATCH] Popv fixes

---
 containers/popv/context/download-models.sh    |   2 +-
 .../popv/context/download-reference-data.sh   |   2 +-
 containers/popv/context/main.py               | 108 +++++++++++++++---
 .../popv/context/requirements-freeze.txt      |  51 +++++----
 containers/popv/context/requirements.txt      |   2 +-
 5 files changed, 122 insertions(+), 43 deletions(-)

diff --git a/containers/popv/context/download-models.sh b/containers/popv/context/download-models.sh
index 9a56172..e9ed36c 100755
--- a/containers/popv/context/download-models.sh
+++ b/containers/popv/context/download-models.sh
@@ -5,7 +5,7 @@ MODELS_ID=${1:?"A zenodo models id must be provided to download!"}
 MODELS_DIR=${2:-"./popv/models"}
 
 mkdir -p $MODELS_DIR
-zenodo_get $MODELS_ID -o $MODELS_DIR
+zenodo_get $MODELS_ID -o $MODELS_DIR --continue-on-error --retry 2 --pause 30
 
 for ARCHIVE in $MODELS_DIR/*.tar.gz; do
   MODEL=$(basename -s .tar.gz $ARCHIVE)
diff --git a/containers/popv/context/download-reference-data.sh b/containers/popv/context/download-reference-data.sh
index c34e322..a2d791b 100755
--- a/containers/popv/context/download-reference-data.sh
+++ b/containers/popv/context/download-reference-data.sh
@@ -5,4 +5,4 @@ REFERENCE_DATA_ID=${1:?"A zenodo reference data id must be provided to download!"}
 REFERENCE_DATA_DIR=${2:-"./popv/reference-data"}
 
 mkdir -p $REFERENCE_DATA_DIR
-zenodo_get $REFERENCE_DATA_ID -o $REFERENCE_DATA_DIR
+zenodo_get $REFERENCE_DATA_ID -o $REFERENCE_DATA_DIR --continue-on-error --retry 2 --pause 30
\ No newline at end of file
diff --git a/containers/popv/context/main.py b/containers/popv/context/main.py
index 1e88189..b6e357a 100644
--- a/containers/popv/context/main.py
+++ b/containers/popv/context/main.py
@@ -4,14 +4,89 @@ from pathlib import Path
 
 import anndata
-import numpy
-import popv
+import celltypist
+import h5py
+import numpy as np
+import pandas as pd
 import scanpy
+import scipy.sparse as sp_sparse
+import scvi.data.fields._layer_field as scvi_layer_field
 import torch
 
+import popv
 from src.algorithm import Algorithm, RunResult, add_common_arguments
 from src.util.layers import set_data_layer
 
+# From https://github.com/scverse/scvi-tools/blob/1.1.2/scvi/data/_utils.py#L15
+try:
+    # anndata >= 0.10
+    from anndata.experimental import CSCDataset, CSRDataset
+
+    SparseDataset = (CSRDataset, CSCDataset)
+except ImportError:
+    from anndata._core.sparse_dataset import SparseDataset
+
+
+# From https://github.com/scverse/scvi-tools/blob/1.1.2/scvi/data/_utils.py#L248
+# But with jax operations replaced by regular numpy calls
+def _check_nonnegative_integers(
+    data: t.Union[pd.DataFrame, np.ndarray, sp_sparse.spmatrix, h5py.Dataset],
+    n_to_check: int = 20,
+):
+    """Approximately checks values of data to ensure it is count data."""
+    # for backed anndata
+    if isinstance(data, h5py.Dataset) or isinstance(data, SparseDataset):
+        data = data[:100]
+
+    if isinstance(data, np.ndarray):
+        data = data
+    elif issubclass(type(data), sp_sparse.spmatrix):
+        data = data.data
+    elif isinstance(data, pd.DataFrame):
+        data = data.to_numpy()
+    else:
+        raise TypeError("data type not understood")
+
+    ret = True
+    if len(data) != 0:
+        inds = np.random.choice(len(data), size=(n_to_check,))
+        # Start of replacements
+        data = data.flat[inds]
+        negative = np.any(data < 0)
+        non_integer = np.any(data % 1 != 0)
+        # End of replacements
+        ret = not (negative or non_integer)
+    return ret
+
+
+def _fix_jax_segfault():
"""Fixes a segfault that can happen inside docker containers + when running both knn_on_scvi and scanvi. + + The error is caused by a race condition or data corruption in jax + when the algorithms load their respective model files. + I suspect there might be a slight version mismatch or similar when + creating the docker container but for now I just monkey patch the offending calls. + """ + scvi_layer_field._check_nonnegative_integers = _check_nonnegative_integers + + +def _fix_celltypist_forced_models_download(model_dir: Path): + """Prevent celltypist from redownloading all models. + + Celltypist's `Model.load` function always attempts to download + all models if it cannot detect at least one *.pkl (pickle serialized) + file in it's default models directory even when provided with + a direct path to a model file. + + Monkey patching celltypist's models directory path to a directory + with at least on *.pkl file will trick it into not downloading the models. + + Args: + model_dir (Path): Directory with at least one *.pkl file + """ + celltypist.models.models_path = model_dir + class PopvOrganMetadata(t.TypedDict): model: str @@ -45,20 +120,21 @@ def do_run( options: PopvOptions, ) -> RunResult: """Annotate data using popv.""" + _fix_jax_segfault() + data = scanpy.read_h5ad(matrix) data = self.prepare_query(data, organ, metadata["model"], options) popv.annotation.annotate_data( data, # TODO: onclass has been removed due to error in fast mode # seen_result_key is not added to the result in fast mode but still expected during compute_consensus - # https://github.com/YosefLab/PopV/blob/main/popv/annotation.py#L64 - # https://github.com/YosefLab/PopV/blob/main/popv/algorithms/_onclass.py#L199 - # Also excludes celltypist since web requests are not available inside the docker container methods=[ "knn_on_scvi", "scanvi", "svm", "rf", + # "onclass", + "celltypist", ], ) @@ -81,23 +157,25 @@ def prepare_query( reference_data_path = self.find_reference_data( options["reference_data_dir"], organ, model ) - model_path = self.find_model_dir(options["models_dir"], organ, model) reference_data = scanpy.read_h5ad(reference_data_path) n_samples_per_label = self.get_n_samples_per_label(reference_data, options) data = self.normalize_var_names(data, options) data = set_data_layer(data, options["query_layers_key"]) - if options["query_layers_key"] in ('X', 'raw'): + if options["query_layers_key"] in ("X", "raw"): options["query_layers_key"] = None - data.X = numpy.rint(data.X) + data.X = np.rint(data.X) + + model_dir = self.find_model_dir(options["models_dir"], organ, model) + _fix_celltypist_forced_models_download(model_dir) - data = self.add_model_genes(data, model_path, options["query_layers_key"]) + data = self.add_model_genes(data, model_dir, options["query_layers_key"]) data.var_names_make_unique() query = popv.preprocessing.Process_Query( data, reference_data, - save_path_trained_models=str(model_path), + save_path_trained_models=f"{model_dir}/", prediction_mode=options["prediction_mode"], query_labels_key=options["query_labels_key"], query_batch_key=options["query_batch_key"], @@ -109,7 +187,7 @@ def prepare_query( cl_obo_folder=f"{options['cell_ontology_dir']}/", compute_embedding=True, hvg=None, - use_gpu=False, # Using gpu with docker requires additional setup + accelerator="cpu", # Using gpu with docker/apptainer requires additional setup ) return query.adata @@ -128,8 +206,8 @@ def get_n_samples_per_label( ref_labels_key = options["ref_labels_key"] n_samples_per_label = 
options["samples_per_label"] if ref_labels_key in reference_data.obs.columns: - n = numpy.min(reference_data.obs.groupby(ref_labels_key).size()) - n_samples_per_label = numpy.max((n_samples_per_label, t.cast(int, n))) + n = np.min(reference_data.obs.groupby(ref_labels_key).size()) + n_samples_per_label = np.max((n_samples_per_label, t.cast(int, n))) return n_samples_per_label def find_reference_data(self, dir: Path, organ: str, model: str) -> Path: @@ -274,8 +352,8 @@ def add_model_genes( Path.joinpath(model_path, "scvi/model.pt"), map_location="cpu" )["var_names"] n_obs_data = data.X.shape[0] - new_genes = set(numpy.setdiff1d(model_genes, data.var_names)) - zeroes = numpy.zeros((n_obs_data, len(new_genes))) + new_genes = set(np.setdiff1d(model_genes, data.var_names)) + zeroes = np.zeros((n_obs_data, len(new_genes))) layers = {query_layers_key: zeroes} if query_layers_key else None new_data = scanpy.AnnData(X=zeroes, var=new_genes, layers=layers) new_data.obs_names = data.obs_names diff --git a/containers/popv/context/requirements-freeze.txt b/containers/popv/context/requirements-freeze.txt index 942c967..78baff6 100644 --- a/containers/popv/context/requirements-freeze.txt +++ b/containers/popv/context/requirements-freeze.txt @@ -3,7 +3,7 @@ aiohttp==3.9.3 aiosignal==1.3.1 anndata==0.10.6 annoy==1.17.3 -array_api_compat==1.4.1 +array_api_compat==1.5.1 astunparse==1.6.3 attrs==23.2.0 bbknn==1.6.0 @@ -11,23 +11,22 @@ beautifulsoup4==4.12.3 celltypist==1.6.2 certifi==2024.2.2 charset-normalizer==3.3.2 -chex==0.1.85 +chex==0.1.86 click==8.1.7 contextlib2==21.6.0 contourpy==1.2.0 cycler==0.12.1 Cython==3.0.9 -dm-tree==0.1.8 docrep==0.3.2 et-xmlfile==1.1.0 -etils==1.7.0 +etils==1.8.0 fbpca==1.0 -filelock==3.13.1 -flatbuffers==24.3.7 -flax==0.8.1 -fonttools==4.49.0 +filelock==3.13.3 +flatbuffers==24.3.25 +flax==0.8.2 +fonttools==4.50.0 frozenlist==1.4.1 -fsspec==2024.2.0 +fsspec==2024.3.1 gast==0.5.4 gdown==5.1.0 geosketch==1.2 @@ -35,23 +34,24 @@ google-pasta==0.2.0 grpcio==1.62.1 h5py==3.10.0 harmony-pytorch==0.1.8 -huggingface-hub==0.21.4 +huggingface-hub==0.22.1 idna==3.6 igraph==0.11.4 -importlib_resources==6.2.0 +importlib_resources==6.4.0 intervaltree==3.1.0 jax==0.4.25 jaxlib==0.4.25 Jinja2==3.1.3 joblib==1.3.2 -keras==3.0.5 +keras==3.1.1 kiwisolver==1.4.5 +legacy-api-wrap==1.4 leidenalg==0.10.2 -libclang==16.0.6 +libclang==18.1.1 lightning==2.1.4 -lightning-utilities==0.10.1 +lightning-utilities==0.11.1 llvmlite==0.42.0 -Markdown==3.5.2 +Markdown==3.6 markdown-it-py==3.0.0 MarkupSafe==2.1.5 matplotlib==3.8.3 @@ -67,7 +67,7 @@ namex==0.0.7 natsort==8.4.0 nest-asyncio==1.6.0 networkx==3.2.1 -numba==0.59.0 +numba==0.59.1 numpy==1.26.4 numpyro==0.14.0 nvidia-cublas-cu12==12.1.3.1 @@ -87,7 +87,8 @@ OnClass==1.3 openpyxl==3.1.2 opt-einsum==3.3.0 optax==0.2.1 -orbax-checkpoint==0.5.5 +optree==0.11.0 +orbax-checkpoint==0.5.7 packaging==24.0 pandas==1.5.3 patsy==0.5.6 @@ -110,13 +111,13 @@ requests==2.31.0 rich==13.7.1 safetensors==0.4.2 scanorama==1.7.4 -scanpy==1.9.8 +scanpy==1.10.0 scikit-learn==1.1.3 scikit-misc==0.3.1 scipy==1.12.0 scvi-tools==1.1.2 seaborn==0.13.2 -sentence-transformers==2.5.1 +sentence-transformers==2.6.1 session_info==1.0.0 six==1.16.0 sortedcontainers==2.4.0 @@ -128,16 +129,16 @@ tensorboard==2.16.2 tensorboard-data-server==0.7.2 tensorflow==2.16.1 tensorflow-io-gcs-filesystem==0.36.0 -tensorstore==0.1.54 +tensorstore==0.1.56 termcolor==2.4.0 texttable==1.7.0 -threadpoolctl==3.3.0 +threadpoolctl==3.4.0 tokenizers==0.15.2 toolz==0.12.1 torch==2.2.1 -torchmetrics==1.3.1 
+torchmetrics==1.3.2
 tqdm==4.66.2
-transformers==4.38.2
+transformers==4.39.1
 triton==2.2.0
 typing_extensions==4.10.0
 umap-learn==0.5.5
@@ -146,5 +147,5 @@ Werkzeug==3.0.1
 wget==3.2
 wrapt==1.16.0
 yarl==1.9.4
-zenodo-get==1.4.0
-zipp==3.17.0
+zenodo-get==1.5.1
+zipp==3.18.1
diff --git a/containers/popv/context/requirements.txt b/containers/popv/context/requirements.txt
index 4c3b1ca..dfad203 100644
--- a/containers/popv/context/requirements.txt
+++ b/containers/popv/context/requirements.txt
@@ -1,2 +1,2 @@
 popv==0.4.*
-zenodo_get==1.4.0
+zenodo_get==1.5.*
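
Note for reviewers: the sketch below shows how the two workarounds introduced in
main.py are meant to compose at runtime. It is illustrative only and not part of
the patch; `run_annotation` and its arguments are hypothetical stand-ins for the
real call sites (`do_run` applies the jax fix, `prepare_query` applies the
celltypist fix before `Process_Query` is built), and it assumes main.py is
importable as a module.

# Illustrative sketch only -- not part of the patch.
from pathlib import Path

import celltypist
import scanpy
import scvi.data.fields._layer_field as scvi_layer_field

import popv
from main import _check_nonnegative_integers  # numpy-only check defined in main.py


def run_annotation(matrix: Path, model_dir: Path) -> None:
    # 1. Replace scvi's jax-backed count check with the numpy version,
    #    avoiding the segfault seen when knn_on_scvi and scanvi both
    #    load their model files inside the container.
    scvi_layer_field._check_nonnegative_integers = _check_nonnegative_integers

    # 2. Point celltypist at a directory that already holds a *.pkl model
    #    so Model.load skips its forced download of every model.
    celltypist.models.models_path = model_dir

    # 3. Only after both patches are in place, run annotation with the
    #    method list used by this patch.
    data = scanpy.read_h5ad(matrix)
    popv.annotation.annotate_data(
        data,
        methods=["knn_on_scvi", "scanvi", "svm", "rf", "celltypist"],
    )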