- Python >= 3.8 (we recommend using Anaconda)
- PyTorch >= 2.0.1
- NVIDIA GPU + CUDA
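The requirements above can be checked with a short script before installing anything else. This is an optional sketch that is not part of the SEED-X repository. It only reads `sys.version_info`, `torch.__version__` and `torch.cuda.is_available()`. The version parsing is a simplified assumption: it drops build and pre-release suffixes rather than ordering them.

```python
import sys


def parse_version(text: str) -> tuple:
    """Turn '2.0.1+cu118' or '2.1' into (2, 0, 1) / (2, 1, 0); suffixes are dropped."""
    core = text.split("+")[0]
    parts = []
    for piece in core.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    # Pad to three components so (2, 0) compares equal to (2, 0, 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts[:3])


def check_environment(min_python=(3, 8), min_torch="2.0.1") -> list:
    """Return a list of human-readable problems; an empty list means every check passed."""
    problems = []
    if sys.version_info[:2] < min_python:
        problems.append(f"Python {sys.version.split()[0]} < {'.'.join(map(str, min_python))}")
    try:
        import torch
    except ImportError:
        problems.append("PyTorch is not installed")
        return problems
    if parse_version(torch.__version__) < parse_version(min_torch):
        problems.append(f"PyTorch {torch.__version__} < {min_torch}")
    if not torch.cuda.is_available():
        problems.append("No CUDA GPU is visible to PyTorch")
    return problems


if __name__ == "__main__":
    for issue in check_environment() or ["Environment looks OK"]:
        print(issue)
```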
Clone the repo and install the dependencies:

```bash
git clone this_project
cd SEED-X
pip install -r requirements.txt
```
We release the pretrained De-Tokenizer, the pretrained foundation model SEED-X, the general instruction-tuned model SEED-X-I, and the editing model SEED-X-Edit on Google Drive. Please download the checkpoints and save them under the folder `./pretrained`, for example `./pretrained/seed_x`.
You also need to download stable-diffusion-xl-base-1.0 and Qwen-VL-Chat and save them under `./pretrained`. Then run the following script to extract the visual encoder weights from Qwen-VL-Chat:

```bash
python3 src/tools/reload_qwen_vit.py
```
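Before running inference, it can help to confirm that the downloaded folders are in place. This is a minimal sketch, not a repository tool. Only `seed_x` is named in this README. The other folder names in `EXPECTED_SUBDIRS` are hypothetical and should be adjusted to whatever the config files actually reference.

```python
from pathlib import Path

# Hypothetical folder names: only "seed_x" is confirmed by this README.
# Adjust the list to match the paths used in the repository's config files.
EXPECTED_SUBDIRS = [
    "seed_x",
    "stable-diffusion-xl-base-1.0",
    "Qwen-VL-Chat",
]


def missing_checkpoints(root="./pretrained", expected=EXPECTED_SUBDIRS) -> list:
    """Return the expected sub-folders of `root` that are absent or empty."""
    root_path = Path(root)
    missing = []
    for name in expected:
        folder = root_path / name
        if not folder.is_dir() or not any(folder.iterdir()):
            missing.append(name)
    return missing
```

An empty folder is reported as missing because an interrupted download often leaves the directory behind without any weights.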
```bash
# SEED-X De-Tokenizer: image reconstruction with ViT image features
python3 src/inference/eval_seed_x_detokenizer.py
# SEED-X De-Tokenizer: image reconstruction with ViT image features and a conditional image
python3 src/inference/eval_seed_x_detokenizer_with_condition.py
# SEED-X: image comprehension and detection
python3 src/inference/eval_img2text_seed_x.py
# SEED-X: image generation
python3 src/inference/eval_text2img_seed_x.py
# SEED-X-I: image comprehension and detection
python3 src/inference/eval_img2text_seed_x_i.py
# SEED-X-I: image generation
python3 src/inference/eval_text2img_seed_x_i.py
# SEED-X-Edit: image editing
python3 src/inference/eval_img2edit_seed_x_edit.py
```
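Because each model and task has its own script, a small lookup table can make the mapping explicit. This is a hypothetical convenience wrapper, not part of the repository. The `(model, task)` keys are labels invented here, and the script is launched with the current interpreter (`sys.executable`) rather than a literal `python3`.

```python
import subprocess
import sys

# (model, task) -> inference script. The keys are hypothetical labels;
# the script paths are copied from the commands listed above.
INFERENCE_SCRIPTS = {
    ("detokenizer", "reconstruction"): "src/inference/eval_seed_x_detokenizer.py",
    ("detokenizer", "reconstruction_with_condition"): "src/inference/eval_seed_x_detokenizer_with_condition.py",
    ("seed_x", "comprehension"): "src/inference/eval_img2text_seed_x.py",
    ("seed_x", "generation"): "src/inference/eval_text2img_seed_x.py",
    ("seed_x_i", "comprehension"): "src/inference/eval_img2text_seed_x_i.py",
    ("seed_x_i", "generation"): "src/inference/eval_text2img_seed_x_i.py",
    ("seed_x_edit", "editing"): "src/inference/eval_img2edit_seed_x_edit.py",
}


def inference_command(model: str, task: str) -> list:
    """Build the argv for one inference run; raises KeyError for an unknown pair."""
    return [sys.executable, INFERENCE_SCRIPTS[(model, task)]]


def run_inference(model: str, task: str) -> int:
    """Run the script from the repository root and return its exit code."""
    return subprocess.run(inference_command(model, task), check=False).returncode
```

For example, `run_inference("seed_x_edit", "editing")` should be equivalent to running the last command above from the repository root.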
- Prepare the pretrained models, including the pretrained foundation model SEED-X and the visual encoder of Qwen-VL-Chat (see Model Weights).
- Prepare the instruction-tuning data. For example, for the "build_llava_jsonl_datapipes" dataloader, each folder stores a number of jsonl files, each containing 10K samples, one JSON object per line. An example line:

```json
{"image": "coco/train2017/000000033471.jpg", "data": ["What are the colors of the bus in the image?", "The bus in the image is white and red.", "What feature can be seen on the back of the bus?", "The back of the bus features an advertisement.", "Is the bus driving down the street or pulled off to the side?", "The bus is driving down the street, which is crowded with people and other vehicles."]}
```
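A sample in this format can be checked by splitting `data` into question-answer turns. This is a minimal sketch of how such a line could be read, not the repository's "build_llava_jsonl_datapipes" implementation. It assumes, from the example, that `data` strictly alternates question, answer, question, answer.

```python
import json


def load_llava_sample(line: str):
    """Parse one jsonl line into (image_path, [(question, answer), ...])."""
    record = json.loads(line)
    turns = record["data"]
    # Assumption: "data" alternates question, answer, question, answer, ...
    if len(turns) % 2 != 0:
        raise ValueError(f"expected an even number of turns, got {len(turns)}")
    pairs = list(zip(turns[0::2], turns[1::2]))
    return record["image"], pairs


def iter_llava_jsonl(path: str):
    """Yield parsed samples from one jsonl file, skipping blank lines."""
    with open(path, encoding="utf-8") as f:
        for line in f:
            if line.strip():
                yield load_llava_sample(line)
```

Running `iter_llava_jsonl` over every file once before training is a cheap way to catch malformed lines early.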
For the "build_caption_datapipes_with_pixels" dataloader, each folder stores a number of .tar files, from which image-text pairs are read in webdataset format.
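In webdataset format, the files inside a tar that share a key (the file name up to its first dot) belong to one sample. This sketch groups tar members that way using only the standard library. It is not the repository's dataloader. The `jpg`/`txt` extensions used for image and caption are assumptions, so check them against your own shards.

```python
import io
import tarfile
from collections import OrderedDict


def split_key(name: str):
    """Split 'dir/000123.jpg' into ('dir/000123', 'jpg'): the key ends at the
    first '.' of the file name, following the webdataset convention."""
    head, _, filename = name.rpartition("/")
    stem, dot, ext = filename.partition(".")
    if not dot:
        return None, None
    key = f"{head}/{stem}" if head else stem
    return key, ext


def read_tar_samples(fileobj):
    """Group tar members that share a key into one {extension: bytes} dict per sample."""
    samples = OrderedDict()
    with tarfile.open(fileobj=fileobj, mode="r:*") as tar:
        for member in tar.getmembers():
            if not member.isfile():
                continue
            key, ext = split_key(member.name)
            if key is None:
                continue
            samples.setdefault(key, {})[ext] = tar.extractfile(member).read()
    return list(samples.items())


def to_image_text_pairs(samples, image_ext="jpg", text_ext="txt"):
    """Keep samples that have both an image and a caption (extensions are assumptions)."""
    pairs = []
    for _key, files in samples:
        if image_ext in files and text_ext in files:
            pairs.append((files[image_ext], files[text_ext].decode("utf-8")))
    return pairs
```

Usage: `to_image_text_pairs(read_tar_samples(open("shard-000.tar", "rb")))`, where the shard name is illustrative.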
For the "build_single_turn_edit_datapipes" dataloader, each folder likewise stores a number of jsonl files, each containing 10K samples. An example line:

```json
{"source_image": "source_images/f6f4d0669694df5b.jpg", "target_image": "target_images/f6f4d0669694df5b.jpg", "instruction": "Erase the car that is parked in front of the Roebuck building."}
```
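An editing sample pairs a source image, a target image and an instruction. This minimal validation sketch checks those three keys and resolves the relative image paths. It is not the repository's loader. The `image_root` parameter is hypothetical, because this README does not say which directory the relative paths are resolved against.

```python
import json
from pathlib import Path

REQUIRED_KEYS = ("source_image", "target_image", "instruction")


def load_edit_sample(line: str, image_root: str = "."):
    """Parse one jsonl line into (source_path, target_path, instruction).

    `image_root` is a hypothetical parameter: the example paths are relative,
    and the real root is whatever the dataloader config points to.
    """
    record = json.loads(line)
    missing = [k for k in REQUIRED_KEYS if k not in record]
    if missing:
        raise KeyError(f"missing keys: {missing}")
    root = Path(image_root)
    return root / record["source_image"], root / record["target_image"], record["instruction"].strip()
```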
- Run one of the following training scripts:

```bash
# General instruction tuning for multimodal comprehension and generation
sh scripts/train_seed_x_sft_comp_gen.sh
# Language-guided image editing
sh scripts/train_seed_x_sft_edit.sh
```
- Obtain "pytorch_model.bin" by consolidating the DeepSpeed ZeRO checkpoint with the following script:

```bash
cd train_output/seed_x_sft_comp_gen/checkpoint-xxxx
python3 zero_to_fp32.py . pytorch_model.bin
```
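When several `checkpoint-<step>` folders exist, you need to pick one before converting it. This hypothetical helper, which is not part of the repository, selects the folder with the largest step. It compares steps numerically, so `checkpoint-4000` ranks above `checkpoint-999` even though plain string sorting would put `checkpoint-999` last.

```python
import re
from pathlib import Path


def latest_checkpoint(output_dir: str):
    """Return the checkpoint-<step> sub-folder with the largest step, or None."""
    pattern = re.compile(r"^checkpoint-(\d+)$")
    best_step, best_dir = -1, None
    for child in Path(output_dir).iterdir():
        match = pattern.match(child.name)
        if child.is_dir() and match and int(match.group(1)) > best_step:
            best_step, best_dir = int(match.group(1)), child
    return best_dir
```

For example, `latest_checkpoint("train_output/seed_x_sft_comp_gen")` returns the folder to `cd` into before running `zero_to_fp32.py`.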
- Set "pretrained_model_path" in "configs/clm_models/agent_seed_x.yaml" to the new checkpoint. For example:

```yaml
pretrained_model_path: train_output/seed_x_sft_comp_gen/checkpoint-4000/pytorch_model.bin
```
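The config edit can also be scripted. This sketch makes a plain-text edit without a YAML parser, so comments and key order are preserved. It assumes that `pretrained_model_path:` appears exactly once, on its own unindented line, as in the example. That assumption is not confirmed for every config in the repository.

```python
import re


def set_pretrained_model_path(yaml_text: str, new_path: str) -> str:
    """Rewrite the top-level `pretrained_model_path:` line and keep all other lines intact.

    Assumes the key sits on its own unindented line and occurs exactly once.
    """
    pattern = re.compile(r"^pretrained_model_path:.*$", flags=re.MULTILINE)
    # A function replacement avoids backslashes in `new_path` being read as escapes
    updated, count = pattern.subn(lambda _m: f"pretrained_model_path: {new_path}", yaml_text)
    if count != 1:
        raise ValueError(f"expected exactly one pretrained_model_path line, found {count}")
    return updated
```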
- Set "llm_cfg_path" and "agent_cfg_path" in the inference script as follows; the trained LoRA weights are then loaded onto the pretrained SEED-X model automatically.

```python
llm_cfg_path = 'configs/clm_models/llm_seed_x_lora.yaml'
agent_cfg_path = 'configs/clm_models/agent_seed_x.yaml'
```
- Run the inference script:

```bash
# For image comprehension
python3 src/inference/eval_img2text_seed_x_i.py
# For image generation
python3 src/inference/eval_text2img_seed_x_i.py
# For image editing
python3 src/inference/eval_img2edit_seed_x_edit.py
```