RETapp - A revised Gradio-based app for training models to predict diseases from retinal images.
This repository is a revised Gradio-based app for training models that predict diseases from retinal images, built on RETFound_MAE (source: https://github.com/rmaphoh/RETFound_MAE).
The aim is to let doctors train their own models on their own datasets of retinal images easily.
We also updated the packages to be compatible with Python 3.8+, CUDA 11.7, and Ubuntu 22.04.
First, install the dependencies:
pip install -r requirements.txt
Then install CUDA and PyTorch. Ubuntu 22.04 (64-bit) supports CUDA 11.7+; here we use CUDA 11.7 with pytorch==1.13.1+cu117:
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
# or
conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.7 -c pytorch -c nvidia
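To verify that the CUDA build is picked up (a quick optional check, not part of the original instructions):

```python
# Sanity-check the PyTorch install
import torch

print(torch.__version__)          # expect 1.13.1+cu117
print(torch.version.cuda)         # expect 11.7
print(torch.cuda.is_available())  # expect True on a machine with a CUDA 11.7 driver
```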
- Create a new folder for the datasets in the repo:
data/
e.g. the OCTID dataset, downloaded here, was split into 3 folders (train, val, and test) and then organised into 5 classes: ANormal, ARMD, CSR, Diabetic_retinopathy, Macular_Hole.
The structure should be like this:
├── data folder
    ├── train
        ├── class_a
        ├── class_b
        ├── class_c
    ├── val
        ├── class_a
        ├── class_b
        ├── class_c
    ├── test
        ├── class_a
        ├── class_b
        ├── class_c
Note: the folder name should be the same as the class name.
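If your dataset arrives as one flat folder per class, a small script along these lines can produce the layout above (a minimal sketch; the source folder OCTID_raw, the destination data/OCTID, and the 70/15/15 split ratios are illustrative assumptions):

```python
# Hypothetical helper: split a flat dataset (one folder per class)
# into the train/val/test layout expected by the app.
import random
import shutil
from pathlib import Path

src = Path("OCTID_raw")    # one subfolder per class, e.g. OCTID_raw/ARMD/
dst = Path("data/OCTID")   # will receive train/val/test subfolders
ratios = {"train": 0.7, "val": 0.15}  # remainder goes to test

random.seed(0)
for class_dir in src.iterdir():
    if not class_dir.is_dir():
        continue
    images = sorted(class_dir.glob("*"))  # assumes flat image files per class
    random.shuffle(images)
    n_train = int(len(images) * ratios["train"])
    n_val = int(len(images) * ratios["val"])
    splits = {
        "train": images[:n_train],
        "val": images[n_train:n_train + n_val],
        "test": images[n_train + n_val:],
    }
    for split, files in splits.items():
        out_dir = dst / split / class_dir.name  # folder name doubles as class name
        out_dir.mkdir(parents=True, exist_ok=True)
        for f in files:
            shutil.copy2(f, out_dir / f.name)
```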
- Training
Run python train_web.py
to launch the Gradio app and input the parameters.
The parameters are listed below:
--batch_size 16
--world_size 1
--model vit_large_patch16
--epochs 50
--blr 0 (base learning rate)
--layer_decay 0.65
--weight_decay 0.05
--nb_classes 5 (number of classes)
--data_path ./data/OCTID/
--task ./finetune_OCTID/ (path to the task folder, including metrics and checkpoints)
--finetune ./models/RETFound_cfp_weights.pth (path to the pretrained weights)
--input_size 224
--drop_path 0.1
--device cuda
You can see the training progress (TensorBoard) and logs in the task folder.
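To monitor training, you can point TensorBoard at the task folder (assuming the example task path above):
tensorboard --logdir ./finetune_OCTID/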
- Prediction
Modify the parameters in app.py to point to the fine-tuned model for the task you want and to choose the base model.
Run python app.py
to launch the Gradio app.
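For orientation, a prediction app in the spirit of app.py might look like the sketch below; the checkpoint path, the 'model' checkpoint key, the class names, and the ImageNet normalisation statistics are assumptions, not the repository's actual code:

```python
# Minimal sketch of a Gradio prediction UI (illustrative, not the real app.py)
import gradio as gr
import torch
from torchvision import transforms

import models_vit  # model definitions from this repository

CLASSES = ["ANormal", "ARMD", "CSR", "Diabetic_retinopathy", "Macular_Hole"]
CKPT = "./finetune_OCTID/checkpoint-best.pth"  # assumed fine-tuned checkpoint path

model = models_vit.__dict__["vit_large_patch16"](
    num_classes=len(CLASSES), drop_path_rate=0.1, global_pool=True)
checkpoint = torch.load(CKPT, map_location="cpu")
model.load_state_dict(checkpoint["model"])  # assumes MAE-style checkpoint format
model.eval()

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],   # ImageNet stats (assumed)
                         std=[0.229, 0.224, 0.225]),
])

def predict(image):
    # image arrives as a PIL image from the Gradio component
    x = preprocess(image).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        probs = torch.softmax(model(x), dim=1)[0]
    return {c: float(p) for c, p in zip(CLASSES, probs)}

gr.Interface(fn=predict,
             inputs=gr.Image(type="pil"),
             outputs=gr.Label(num_top_classes=5)).launch()
```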
Official repo for RETFound: a foundation model for generalizable disease detection from retinal images, which is based on MAE.
Please contact [email protected] or [email protected] if you have questions.
A Keras version implemented by Yuka Kihara can be found here.
- RETFound is pre-trained on 1.6 million retinal images with self-supervised learning
- RETFound has been validated in multiple disease detection tasks
- RETFound can be efficiently adapted to customised tasks
- 🐉2024/01: Feature vector notebook is now online!
- 🐉2024/01: Data split and model checkpoints for public datasets are now online!
- 🎄2023/12: Colab notebook is now online - free GPU & simple operation!
- 2023/10: changed the input_size hyperparameter to support any image size
- 2023/09: a visualisation demo was added
- Create environment with conda:
conda create -n retfound python=3.7.5 -y
conda activate retfound
- Install dependencies
git clone https://github.com/rmaphoh/RETFound_MAE/
cd RETFound_MAE
pip install -r requirement.txt
To fine-tune RETFound on your own data, follow these steps:
- Download the RETFound pre-trained weights
|  | ViT-Large |
| --- | --- |
| Colour fundus image | download |
| OCT | download |
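After downloading, you can sanity-check a weights file before fine-tuning; the 'model' key below matches the loading code shown later in this README:

```python
# Inspect the downloaded RETFound checkpoint
import torch

checkpoint = torch.load('RETFound_cfp_weights.pth', map_location='cpu')
state_dict = checkpoint['model']
print(len(state_dict), 'tensors in the checkpoint')
print(list(state_dict)[:5])  # first few parameter names, e.g. patch embedding weights
```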
- Organise your data into this directory structure (Public datasets used in this study can be downloaded here)
├── data folder
    ├── train
        ├── class_a
        ├── class_b
        ├── class_c
    ├── val
        ├── class_a
        ├── class_b
        ├── class_c
    ├── test
        ├── class_a
        ├── class_b
        ├── class_c
- Start fine-tuning (using IDRiD as an example). A fine-tuned checkpoint will be saved during training. Evaluation will be run after training.
python -m torch.distributed.launch --nproc_per_node=1 --master_port=48798 main_finetune.py \
--batch_size 16 \
--world_size 1 \
--model vit_large_patch16 \
--epochs 50 \
--blr 5e-3 --layer_decay 0.65 \
--weight_decay 0.05 --drop_path 0.2 \
--nb_classes 5 \
--data_path ./IDRiD_data/ \
--task ./finetune_IDRiD/ \
--finetune ./RETFound_cfp_weights.pth \
--input_size 224
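Note that --blr is a base learning rate: in the MAE-style training code that RETFound builds on, the absolute learning rate is scaled by the effective batch size. A small illustration (accum_iter is the gradient-accumulation flag and defaults to 1):

```python
# Effective learning rate implied by the command above (MAE-style scaling)
blr = 5e-3
batch_size, world_size, accum_iter = 16, 1, 1  # values from the example command
eff_batch_size = batch_size * accum_iter * world_size
lr = blr * eff_batch_size / 256
print(lr)  # 0.0003125
```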
- For evaluation only (download data and model checkpoints here; change the path below)
python -m torch.distributed.launch --nproc_per_node=1 --master_port=48798 main_finetune.py \
--eval --batch_size 16 \
--world_size 1 \
--model vit_large_patch16 \
--epochs 50 \
--blr 5e-3 --layer_decay 0.65 \
--weight_decay 0.05 --drop_path 0.2 \
--nb_classes 5 \
--data_path ./IDRiD_data/ \
--task ./internal_IDRiD/ \
--resume ./finetune_IDRiD/checkpoint-best.pth \
--input_size 224
import torch
import models_vit
from util.pos_embed import interpolate_pos_embed
from timm.models.layers import trunc_normal_

# call the model
model = models_vit.__dict__['vit_large_patch16'](
    num_classes=2,
    drop_path_rate=0.2,
    global_pool=True,
)

# load RETFound weights
checkpoint = torch.load('RETFound_cfp_weights.pth', map_location='cpu')
checkpoint_model = checkpoint['model']
state_dict = model.state_dict()
# drop the pre-trained classification head if its shape does not match the new task
for k in ['head.weight', 'head.bias']:
    if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
        print(f"Removing key {k} from pretrained checkpoint")
        del checkpoint_model[k]

# interpolate position embedding (handles input sizes other than the pre-training size)
interpolate_pos_embed(model, checkpoint_model)

# load pre-trained model; only the head and fc_norm layers should be missing
msg = model.load_state_dict(checkpoint_model, strict=False)
assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'}

# manually initialize fc layer
trunc_normal_(model.head.weight, std=2e-5)

print("Model = %s" % str(model))
If you find this repository useful, please consider citing this paper:
@article{zhou2023foundation,
title={A foundation model for generalizable disease detection from retinal images},
author={Zhou, Yukun and Chia, Mark A and Wagner, Siegfried K and Ayhan, Murat S and Williamson, Dominic J and Struyven, Robbert R and Liu, Timing and Xu, Moucheng and Lozano, Mateo G and Woodward-Court, Peter and others},
journal={Nature},
volume={622},
number={7981},
pages={156--163},
year={2023},
publisher={Nature Publishing Group UK London}
}