Unofficial training code for SegGPT.
From left to right: input image, masked label, ground truth, raw model prediction, discretized model prediction.
- This implementation is based on my understanding of the paper and on reverse-engineering the official implementation.
- I have tested this code by fine-tuning on the OEM dataset and got promising results. However, the code may still contain bugs or mistakes. Feel free to raise an issue.
- Fine-tuning from the provided checkpoint requires a lot of GPU memory (at least 24GB) because it trains the whole ViT-16 backbone. Consider using a smaller batch size or a smaller model overall. I might implement LoRA training to support smaller VRAM in the future if this repo gains enough traction.
This code is developed with Python 3.9.
Install the required packages by running:
pip install -r requirements.txt
Alternatively, create a new conda environment with the required packages by running:
conda env create -f env.yml
Set up your dataset directory as follows:
<root_dataset_path>
├── images
│   ├── image1.tif
│   ├── image2.tif
│   ...
└── labels
    ├── image1.tif
    ├── image2.tif
    ...
Note:
- Each image and its label must have the same name and extension (or you can modify data.py to support your needs).
- The extension does not have to be .tif as long as the file can be loaded using the PIL library.
- The label is a single-channel image where each pixel value represents the class of that pixel.
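The naming rule above can be checked before training. The following is a minimal sketch, not part of this repo: `find_unpaired` is a hypothetical helper that only compares file names in the `images` and `labels` folders and does not open the files or verify that labels are single-channel.

```python
from pathlib import Path


def find_unpaired(root):
    """Return (images_without_label, labels_without_image) as sorted name lists.

    Hypothetical helper: assumes the images/ and labels/ layout described above.
    """
    root = Path(root)
    images = {p.name for p in (root / "images").iterdir() if p.is_file()}
    labels = {p.name for p in (root / "labels").iterdir() if p.is_file()}
    # Names present on one side only break the "same name and extension" rule.
    return sorted(images - labels), sorted(labels - images)
```

Both returned lists should be empty for a correctly prepared dataset directory.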
Create a .json config file. You can use the provided configs/base.json as a template. Then, run:
python train.py --config <path_to_json_config>
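If you prefer to derive a config from the template programmatically, something like the sketch below may help. It assumes the config is a flat JSON object; `derive_config` is a hypothetical helper and not part of this repo, and no particular key names are implied.

```python
import json
from pathlib import Path


def derive_config(base_path, out_path, overrides):
    """Copy a JSON config, apply top-level overrides, and write the result.

    Hypothetical helper: assumes the config file holds a single JSON object.
    """
    config = json.loads(Path(base_path).read_text())
    # Refuse keys the template does not define, to catch typos early.
    unknown = set(overrides) - set(config)
    if unknown:
        raise KeyError(f"keys not present in the base config: {sorted(unknown)}")
    config.update(overrides)
    Path(out_path).write_text(json.dumps(config, indent=2))
    return config
```

The written file can then be passed to train.py through --config as shown above.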
The training uses the DDP strategy and utilizes all available GPUs by default. You can choose which GPUs to use by setting the CUDA_VISIBLE_DEVICES environment variable.
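As an illustration of restricting the visible GPUs, here is a small sketch that prepares the command and environment for a launch from Python. `build_launch` is a hypothetical helper, not part of this repo; it relies only on CUDA_VISIBLE_DEVICES accepting a comma-separated list of device indices.

```python
import os
import sys


def build_launch(config_path, gpu_ids):
    """Return (command, env) for running train.py on the given GPU indices only."""
    # Copy the current environment so the parent process is left unchanged.
    env = dict(os.environ)
    env["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpu_ids)
    command = [sys.executable, "train.py", "--config", str(config_path)]
    return command, env
```

The returned pair can be handed to `subprocess.run(command, env=env, check=True)` from the repository root.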
You can also launch TensorBoard to monitor the training progress:
tensorboard --logdir logs
In the paper, the authors mention using a learnable tensor for in-context tuning. You can find my implementation of this in model.py.