Version 1.0.0 release
clone repository
diegoluna3 authored Sep 30, 2024
2 parents d304cbd + 689ca72 commit 15f9359
Showing 51 changed files with 107,638 additions and 65 deletions.
1 change: 1 addition & 0 deletions .gitattributes
@@ -0,0 +1 @@
*.ckpt filter=lfs diff=lfs merge=lfs -text
70 changes: 5 additions & 65 deletions .gitignore
@@ -20,11 +20,9 @@ parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
@@ -39,17 +37,13 @@ pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
@@ -58,8 +52,6 @@
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
@@ -72,51 +64,16 @@ instance/
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
.python-version

# celery beat schedule file
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py
@@ -127,8 +84,6 @@ celerybeat.pid
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
@@ -142,21 +97,6 @@ venv.bak/

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# vscode project settings
.vscode
111 changes: 111 additions & 0 deletions README.md
@@ -0,0 +1,111 @@
# DeepPaint

<img alt="DeepPaint" src="./pipeline.png" width="800px" style="max-width: 100%;">

A Python package for the classification and exploration of Cell Painting data. This package relies on [`lightning`](https://lightning.ai) for training and evaluation of the `DenseNet` model.

## Installation

1. Ensure `Python >=3.11` and `conda` are installed on your machine.
The recommended installer for `conda` is [`miniforge`](https://github.com/conda-forge/miniforge).
2. Clone this repository
~~~bash
$ git clone https://github.com/jhuapl-bio/DeepPaint.git
~~~
3. Navigate to the [DeepPaint](.) directory (containing the README)
~~~bash
$ cd DeepPaint
~~~
4. Create a `conda` virtual environment from the [`environment.yml`](./environment.yml) file and activate it
~~~bash
$ conda env create -n <env_name> -f environment.yml
$ conda activate <env_name>
~~~
5. Install the `DeepPaint` package with pip
~~~bash
$ pip install .
~~~
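
To verify the installation, you can print the CLI help (this should list the subcommands described below):
~~~bash
$ python -m deep_paint --help
~~~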

## Usage (Overview)

The `DeepPaint` package can be run as a module with the command `python -m deep_paint` to invoke the CLI. This is the entry point for training and evaluating models.

## CLI (Command Line Interface)

Four commands are available:
- `fit`: Train or finetune a model
- `validate`: Run one evaluation epoch on a validation set
- `test`: Run one test epoch on a test set
- `predict`: Get predictions from a trained model on part or all of a dataset

These commands correspond to the [lightning.pytorch.Trainer](https://lightning.ai/docs/pytorch/stable/common/trainer.html#methods) methods. All commands can be run with the `--config` argument to specify a configuration file.

### Config File

The configuration files used for training, getting model predictions, and getting model embeddings are available in the [configs](./results/configs/) directory. Be sure to update the paths in the configuration files (they are commented for convenience).

The configuration file is a YAML file that contains all the necessary parameters for training, evaluating, or testing a model. The YAML file is divided into the following fields:

| Field | Subclass | Description | Required? |
|--------------|------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------|-----------|
| model | [`LightningModule`](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#lightningmodule) | Model architecture and hyperparameters ||
| data | [`LightningDataModule`](https://lightning.ai/docs/pytorch/stable/data/datamodule.html#lightningdatamodule) | Data preprocessing and augmentation ||
| trainer | [`Trainer`](https://lightning.ai/docs/pytorch/stable/common/trainer.html) | [Training arguments](https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api) ||
| optimizer | [`Optimizer`](https://pytorch.org/docs/stable/optim.html#torch.optim.Optimizer) | Optimizer ||
| lr_scheduler | [`LRScheduler`](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate) | Learning Rate Scheduler ||
| ckpt_path | N/A | Path to model checkpoint ||

All fields except `trainer` and `ckpt_path` require a `class_path` parameter specifying the full import path to the class. The remaining entries under the field are parsed as keyword arguments to the class constructor via the `init_args` parameter.
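
For illustration, a minimal config following this structure might look like the sketch below; the `init_args` values shown are placeholders and the optimizer and scheduler classes are only examples, so use the files in [configs](./results/configs/) as the actual reference.
~~~yaml
model:
  class_path: deep_paint.lightning.module.DeepLightningModule
  init_args: {}             # model keyword arguments go here (placeholder)
data:
  class_path: deep_paint.lightning.datamodule.DeepLightningDataModule
  init_args:
    batch_size: 32          # placeholder keyword argument
trainer:
  max_epochs: 100
  accelerator: gpu
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 0.001
lr_scheduler:
  class_path: torch.optim.lr_scheduler.CosineAnnealingLR
  init_args:
    T_max: 100
ckpt_path: null             # or the path to a checkpoint to resume from
~~~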


### Example Usage

- Train a model:
~~~bash
python -m deep_paint fit --config /path/to/your_config.yaml
~~~
- Run a validation epoch:
~~~bash
python -m deep_paint validate --config /path/to/your_config.yaml
~~~
- Run a test epoch:
~~~bash
python -m deep_paint test --config /path/to/your_config.yaml
~~~
- Get model predictions:
~~~bash
python -m deep_paint predict --config /path/to/your_config.yaml
~~~

### Getting Model Embeddings

A custom script has been created to extract embeddings from a trained model. The script can be run with the following command:
~~~bash
python -m deep_paint.utils.embeddings --config /path/to/your_config.yaml
~~~

This config file differs slightly from the config file used for the four main commands. Refer to the [configs](./results/configs/) directory for examples.

## Results

### Overview

The [results](./results/) directory contains the following subdirectories:
- `checkpoints`: Contains model checkpoints
- `configs`: Contains configuration files used for training, getting model predictions, and getting model embeddings
- `embeddings`: Contains embeddings extracted from the model on the test set of the `RxRx2` data
- `logs`: Contains CSV files extracted from `tensorboard` logs
- `metadata`: Contains custom metadata used for training the `DenseNet` model
- `predictions`: Contains model predictions on the test set of the `RxRx2` data

### Data Availability

The [`RxRx2`](https://www.rxrx.ai/rxrx2) dataset was used for training and evaluation of the `DenseNet` model. The dataset is freely available to download from the [`RxRx.ai`](https://www.rxrx.ai/) website.

### Model Weights

The [checkpoints](./results/checkpoints) directory contains model checkpoints for the binary and multiclass `DenseNet` models. These checkpoints can be used to load the trained models and make predictions.
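
As a minimal sketch (assuming the module saves its hyperparameters in the checkpoint; the filename below is a placeholder for one of the files in that directory), a checkpoint can be loaded with Lightning's standard `load_from_checkpoint`:
~~~python
from deep_paint.lightning.module import DeepLightningModule

# Placeholder checkpoint filename; substitute one of the files in results/checkpoints
model = DeepLightningModule.load_from_checkpoint("results/checkpoints/model.ckpt")
model.eval()  # switch to inference mode
~~~
In practice, predictions are generated with the `predict` subcommand and a config whose `ckpt_path` field points at the desired checkpoint, as described above.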

### Notebooks

The [notebooks](./notebooks/) directory contains Jupyter notebooks that demonstrate the performance of the `DenseNet` model on the `RxRx2` dataset. The notebooks contain visualizations of the model predictions and embeddings.
2 changes: 2 additions & 0 deletions deep_paint/__init__.py
@@ -0,0 +1,2 @@
import deep_paint.lightning
import deep_paint.utils
23 changes: 23 additions & 0 deletions deep_paint/__main__.py
@@ -0,0 +1,23 @@
import torch

from deep_paint.lightning.cli import DeepLightningCLI


def main():
    """Driver script."""
    # Enable TensorFloat-32 matmuls on Ampere/Hopper GPUs (A100/H100)
    if torch.cuda.is_available():
        device_name = torch.cuda.get_device_name()
        if any(gpu in device_name for gpu in ("A100", "H100")):
            torch.set_float32_matmul_precision("high")

    # Instantiating the CLI parses the command line and runs the chosen subcommand
    cli = DeepLightningCLI(
        subclass_mode_model=True,
        subclass_mode_data=True,
        seed_everything_default=42,
        save_config_kwargs={"overwrite": True},
        parser_kwargs={"default_env": True},
    )


if __name__ == "__main__":
    main()
4 changes: 4 additions & 0 deletions deep_paint/lightning/__init__.py
@@ -0,0 +1,4 @@
from .callback import PlotConfusionMatrix, ModuleFreezeUnfreeze
from .cli import DeepLightningCLI
from .datamodule import ForwardDataset, DeepDataset, DeepLightningDataModule
from .module import DeepLightningModule