Annotation-efficient classification combining active learning, pre-training and semi-supervised learning for biomedical images
Repository implementing active learning and semi-supervised learning algorithms and applying them to biomedical imaging datasets.

Implemented methods (a sketch of the uncertainty scores follows the list):
- Least Confidence Sampling [1]
- Margin Sampling [1]
- Ratio Sampling [1]
- Maximum Entropy Sampling [1]
- Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning [2]
- Learning Loss for Active Learning [3]
- BatchBALD: Efficient and Diverse Batch Acquisition for Deep Bayesian Active Learning [4]
- Pseudo Labeling [5]
- Autoencoder [5]
- A Simple Framework for Contrastive Learning of Visual Representations [6]
- FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence [7]
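The four uncertainty sampling strategies [1] all rank unlabeled samples by a score derived from the model's softmax output; the method names below mirror the `--al` choices. A minimal sketch of those scores (illustrative only, not necessarily the repository's exact implementation):

```python
import torch
import torch.nn.functional as F

def uncertainty_scores(logits: torch.Tensor, method: str = "entropy_based") -> torch.Tensor:
    """Per-sample uncertainty from (N, C) logits; higher = more uncertain."""
    probs = F.softmax(logits, dim=1)
    if method == "least_confidence":
        # 1 - max class probability
        return 1.0 - probs.max(dim=1).values
    if method == "margin_confidence":
        # small gap between the top-2 probabilities = high uncertainty
        top2 = probs.topk(2, dim=1).values
        return 1.0 - (top2[:, 0] - top2[:, 1])
    if method == "ratio_confidence":
        # ratio of second-best to best probability (close to 1 = uncertain)
        top2 = probs.topk(2, dim=1).values
        return top2[:, 1] / top2[:, 0]
    if method == "entropy_based":
        # Shannon entropy of the predictive distribution
        return -(probs * probs.clamp_min(1e-12).log()).sum(dim=1)
    raise ValueError(f"unknown method: {method}")

# e.g. query the k most uncertain samples for annotation:
# query_indices = uncertainty_scores(logits).topk(k).indices
```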
Requirements:

numpy>=1.18.5
torch>=1.4.0
torchvision>=0.5.0
scikit-learn>=0.23.1
pandas>=1.0.4
Pillow>=7.1.2
matplotlib>=3.2.1
toma>=1.1.0
scikit-image>=0.17.2
pytorch-msssim
scikit-learn-extra
dataclasses
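Assuming the dependencies above are collected in a requirements.txt at the repository root (an assumption; adjust to your checkout), they can be installed with:

pip install -r requirements.txt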
The train/test splits for the three datasets can be downloaded here:
A. White Blood Cell: Download
B. Skin Lesion: Download
C. Cell Cycle: Download
usage: python3 train.py [-h] [-n NAME] [-e EPOCHS] [--start-epoch START_EPOCH] [-r]
[--load-pretrained] [-b BATCH_SIZE] [--lr LR] [--mo MO]
[--nesterov NESTEROV] [--wd WD] [-p PRINT_FREQ]
[--layers LAYERS] [--widen-factor WIDEN_FACTOR]
[--drop-rate DROP_RATE] [--no-augment]
[--add-labeled-epochs ADD_LABELED_EPOCHS]
[--add-labeled ADD_LABELED] [--start-labeled START_LABELED]
[--stop-labeled STOP_LABELED]
[--labeled-warmup-epochs LABELED_WARMUP_EPOCHS]
[--unlabeled-subset UNLABELED_SUBSET] [-o] [-m] [--rem]
[--arch {wideresnet,densenet,lenet,resnet}] [-l {ce,fl}]
[--log-path LOG_PATH]
[--al {least_confidence,margin_confidence,ratio_confidence,entropy_based,mc_dropout,learning_loss,augmentations_based}]
[--mc-dropout-iterations MC_DROPOUT_ITERATIONS]
[--augmentations_based_iterations AUGMENTATIONS_BASED_ITERATIONS]
[--root ROOT]
[--weak-supervision-strategy {active_learning,semi_supervised,random_sampling,fully_supervised}]
[--ssl {pseudo_labeling,auto_encoder,simclr,fixmatch,auto_encoder_cl,auto_encoder_no_feat,simclr_with_al,auto_encoder_with_al,fixmatch_with_al}]
[--semi-supervised-uncertainty-method {entropy_based,augmentations_based}]
[--pseudo-labeling-threshold PSEUDO_LABELING_THRESHOLD]
[--simclr-train-epochs SIMCLR_TRAIN_EPOCHS]
[--simclr-temperature SIMCLR_TEMPERATURE] [--simclr-normalize]
[--simclr-batch-size SIMCLR_BATCH_SIZE]
[--simclr-arch {lenet,resnet}]
[--simclr-base-lr SIMCLR_BASE_LR]
[--simclr-optimizer {adam,lars}] [--simclr-resume] [--weighted]
[--eval] [-d {cifar10,matek,cifar100,jurkat,plasmodium,isic}]
[--checkpoint-path CHECKPOINT_PATH] [-s {6666,9999,2323,5555}]
[--store-logs] [--run-batch] [--reset-model]
[--fixmatch-mu FIXMATCH_MU]
[--fixmatch-lambda-u FIXMATCH_LAMBDA_U]
[--fixmatch-threshold FIXMATCH_THRESHOLD]
[--fixmatch-k-img FIXMATCH_K_IMG]
[--fixmatch-epochs FIXMATCH_EPOCHS]
[--fixmatch-warmup FIXMATCH_WARMUP]
[--fixmatch-init {None,random,pretrained,simclr,autoencoder}]
[--learning-loss-weight LEARNING_LOSS_WEIGHT]
[--dlctcs-loss-weight DLCTCS_LOSS_WEIGHT]
[--autoencoder-train-epochs AUTOENCODER_TRAIN_EPOCHS]
[--autoencoder-z-dim AUTOENCODER_Z_DIM] [--autoencoder-resume]
[--k-medoids] [--k-medoids-n-clusters K_MEDOIDS_N_CLUSTERS]
[--novel-class-detection] [--gpu-id GPU_ID]
Arguments:

Short | Long | Default | Description |
---|---|---|---|
-h | --help | | show this help message and exit |
-n | --name | run_0 | name of the current experiment run |
-e | --epochs | 1000 | number of total epochs for AL training |
 | --start-epoch | 0 | starting epoch number (useful when resuming) |
-r | --resume | | flag to set if an existing model is to be loaded |
 | --load-pretrained | | whether to load pretrained ImageNet weights |
-b | --batch-size | 256 | batch size for AL training (default: 256) |
 | --learning-rate | 0.001 | initial learning rate for the AL optimizer |
 | --momentum | 0.9 | momentum |
 | --nesterov | | nesterov momentum |
 | --weight-decay | 0.0005 | weight decay for the AL optimizer |
-p | --print-freq | 10 | print frequency per step |
 | --layers | 28 | total number of layers for the ResNext architecture |
 | --widen-factor | 10 | widen factor for the ResNext architecture |
 | --drop-rate | 0.15 | dropout probability for the ResNet/LeNet architecture |
 | --no-augment | | whether to use standard augmentations |
 | --add-labeled-epochs | 20 | perform an AL cycle if recall doesn't improve for this many epochs |
 | --add-labeled | 100 | amount of labeled data added during each AL cycle |
 | --start-labeled | 100 | amount of labeled data to start AL training with |
 | --stop-labeled | 1020 | amount of labeled data at which to stop AL training |
 | --labeled-warmup-epochs | 35 | number of warmup epochs before AL training |
 | --unlabeled-subset | 0.3 | the subset of the unlabeled data to use for AL algorithms |
-o | --oversampling | | whether to perform oversampling for the labeled dataset |
-m | --merged | | merge certain classes in the dataset (see dataset scripts) |
 | --remove-classes | | remove certain classes in the dataset (see dataset scripts) |
 | --architecture | resnet | the architecture to use for AL training |
-l | --loss | ce | the loss to use (ce = cross-entropy, fl = focal loss) |
 | --log-path | ~/logs/ | the directory root for storing/retrieving the logs |
 | --uncertainty-sampling-method | entropy_based | the AL algorithm to use |
 | --mc-dropout-iterations | 25 | number of stochastic forward passes for MC dropout (see the sketch after this table) |
 | --augmentations_based_iterations | 25 | number of iterations for the augmentations-based AL algorithm |
 | --root | ~/datasets/ | the root path for the datasets |
 | --weak-supervision-strategy | semi_supervised | the weak supervision strategy to use |
 | --semi-supervised-method | fixmatch_with_al | the SSL algorithm to use |
 | --semi-supervised-uncertainty-method | entropy_based | the AL algorithm to use in conjunction with an SSL algorithm |
 | --pseudo-labeling-threshold | 0.9 | confidence threshold above which a pseudo-label is treated as the actual label |
 | --simclr-train-epochs | 200 | number of total epochs for SimCLR training |
 | --simclr-temperature | 0.1 | the temperature term for the SimCLR loss |
 | --simclr-normalize | | whether to normalize the hidden feature vectors in SimCLR |
 | --simclr-batch-size | 1024 | batch size for SimCLR training (default: 1024) |
 | --simclr-arch | resnet | the encoder architecture to use for SimCLR |
 | --simclr-base-lr | 0.25 | base learning rate for the SimCLR optimizer |
 | --simclr-optimizer | adam | the optimizer to use for SimCLR training |
 | --simclr-resume | | flag to set if an existing SimCLR model is to be loaded |
 | --weighted | | whether to use a weighted loss (only in case of ce) |
 | --eval | | only perform evaluation and exit |
-d | --dataset | matek | the dataset to train on |
 | --checkpoint-path | ~/runs/ | the directory root for saving/resuming checkpoints |
-s | --seed | 9999 | the random seed to set |
 | --store-logs | | store the logs after training |
 | --run-batch | | run all methods in batch mode |
 | --reset-model | | reset the model after every label injection cycle |
 | --fixmatch-mu | 8 | coefficient of the unlabeled batch size, i.e. mu * B from the paper |
 | --fixmatch-lambda-u | 1 | coefficient of the unlabeled loss |
 | --fixmatch-threshold | 0.95 | pseudo-label threshold |
 | --fixmatch-k-img | 8192 | number of labeled examples |
 | --fixmatch-epochs | 1000 | epochs for SSL or SSL + AL training |
 | --fixmatch-warmup | 0 | warmup epochs with unlabeled data |
 | --fixmatch-init | None | how to initialize the model before FixMatch training (random, pretrained, simclr, or autoencoder) |
 | --learning-loss-weight | 1.0 | the weight for the loss-network loss term in the objective function |
 | --dlctcs-loss-weight | 100 | the weight for the classification loss in DLCTCS |
 | --autoencoder-train-epochs | 20 | number of total epochs for autoencoder training |
 | --autoencoder-z-dim | 128 | the bottleneck dimension for the autoencoder architecture |
 | --autoencoder-resume | | flag to set if an existing autoencoder model is to be loaded |
 | --k-medoids | | perform k-medoids initialization with SimCLR |
 | --k-medoids-n-clusters | 10 | number of k-medoids clusters |
 | --novel-class-detection | | turn on novel class detection |
 | --gpu-id | 0 | the id of the GPU to use |
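For context on `--al mc_dropout` and `--mc-dropout-iterations`: MC dropout [2] keeps dropout stochastic at inference time and averages the softmax output over several forward passes; the entropy of the averaged prediction then serves as the acquisition score. A minimal sketch of the idea (assumed shape of the computation, not the repository's exact code):

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def mc_dropout_entropy(model: torch.nn.Module, x: torch.Tensor,
                       iterations: int = 25) -> torch.Tensor:
    """Predictive entropy under MC dropout [2]; higher = more uncertain."""
    # Note: train() also switches BatchNorm to training mode; real
    # implementations typically re-enable only the dropout layers.
    model.train()
    probs = torch.stack([F.softmax(model(x), dim=1) for _ in range(iterations)])
    mean_probs = probs.mean(dim=0)  # average over the stochastic passes
    model.eval()
    return -(mean_probs * mean_probs.clamp_min(1e-12).log()).sum(dim=1)
```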
To run supervised training with augmentations-based sampling on the white blood cell dataset:
python3 train.py --dataset matek --seed <seed> --root <datasets_root> --al augmentations_based --weak-supervision-strategy active_learning
To run semi-supervised training with augmentations-based sampling on the cell cycle dataset:
python3 train.py --dataset jurkat --seed <seed> --root <datasets_root> --ssl fixmatch_with_al --weak-supervision-strategy semi_supervised --semi-supervised-uncertainty-method augmentations_based
To run SimCLR self-supervised pre-training on the skin lesion dataset:
python3 train.py --dataset isic --seed <seed> --root <datasets_root> --ssl simclr --simclr-resume
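The semi-supervised example above relies on FixMatch [7], which pseudo-labels weakly augmented unlabeled images and applies a cross-entropy loss on their strongly augmented views only where the pseudo-label confidence clears `--fixmatch-threshold`, weighted by `--fixmatch-lambda-u`. A minimal sketch of that unlabeled loss term (illustrative, with assumed argument names):

```python
import torch
import torch.nn.functional as F

def fixmatch_unlabeled_loss(logits_weak: torch.Tensor,
                            logits_strong: torch.Tensor,
                            threshold: float = 0.95,
                            lambda_u: float = 1.0) -> torch.Tensor:
    """Confidence-masked consistency loss from FixMatch [7].

    logits_weak / logits_strong: outputs for weak and strong augmentations
    of the same unlabeled batch (of size mu * B, see --fixmatch-mu).
    """
    probs = F.softmax(logits_weak.detach(), dim=1)
    max_probs, pseudo_labels = probs.max(dim=1)
    mask = (max_probs >= threshold).float()              # --fixmatch-threshold
    loss = F.cross_entropy(logits_strong, pseudo_labels, reduction="none")
    return lambda_u * (loss * mask).mean()               # --fixmatch-lambda-u
```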
Results for the three datasets:
A. White Blood Cell
B. Skin Lesion
C. Cell Cycle
All results can be downloaded from here: Download
References:

[1] Settles, B. (2009). Active learning literature survey. University of Wisconsin-Madison, Department of Computer Sciences.
[2] Gal, Y., & Ghahramani, Z. (2016). Dropout as a Bayesian approximation: Representing model uncertainty in deep learning. In International Conference on Machine Learning (pp. 1050-1059).
[3] Yoo, D., & Kweon, I. S. (2019). Learning loss for active learning. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (pp. 93-102).
[4] Kirsch, A., van Amersfoort, J., & Gal, Y. (2019). BatchBALD: Efficient and diverse batch acquisition for deep Bayesian active learning. In Advances in Neural Information Processing Systems (pp. 7026-7037).
[5] Van Engelen, J. E., & Hoos, H. H. (2020). A survey on semi-supervised learning. Machine Learning, 109(2), 373-440.
[6] Chen, T., Kornblith, S., Norouzi, M., & Hinton, G. (2020). A simple framework for contrastive learning of visual representations. arXiv preprint arXiv:2002.05709.
[7] Sohn, K., Berthelot, D., Li, C. L., Zhang, Z., Carlini, N., Cubuk, E. D., ... & Raffel, C. (2020). FixMatch: Simplifying semi-supervised learning with consistency and confidence. arXiv preprint arXiv:2001.07685.