[ICLR'25] ViDiT-Q: Efficient and Accurate Quantization of Diffusion Transformers for Image and Video Generation
This repo contains the official code of our ICLR'25 paper: ViDiT-Q: Efficient and Accurate Quantization of Diffusion Transformers for Image and Video Generation.
- [25/01] We release the CUDA kernels and an end-to-end accelerated pipeline for OpenSORA v1.2 and PixArt-Sigma. For more details, please refer to ./kernels/readme.md.
- [25/01] We updated and reorganized the code for ViDiT-Q with improved file organization; the quantization-related code is now collected into a standalone Python package `quant_utils`, which is easier to apply to new models.
- [25/01] ViDiT-Q is accepted by ICLR'25.
- [24/07] We released the ViDiT-Q algorithm-level quantization simulation code at https://github.com/thu-nics/ViDiT-Q. The code has since been updated; the deprecated older version is available on the `viditq_old` branch.
We introduce ViDiT-Q, a quantization method specialized for diffusion transformers. For popular large-scale models (e.g., OpenSORA, Latte, PixArt-α, PixArt-Σ) for video and image generation, ViDiT-Q achieves W8A8 quantization without metric degradation and W4A8 quantization without notable visual quality degradation.
[Video comparisons: FP16 vs. ViDiT-Q W4A8 Mixed Precision, and FP16 vs. ViDiT-Q W4A8 Mixed Precision vs. ViDiT-Q W4A4 Mixed Precision]
For more examples, please refer to our Project Page: https://a-suozhang.xyz/viditq.github.io/
We pack the quantization-related code (with the ViDiT-Q methodology as a special case) into a standalone Python package located in the `quant_utils` folder. It can be easily adapted to an existing codebase by defining a `quant_model` class that inherits from the original model class definition. The `examples/` folder contains examples of applying ViDiT-Q to diffusers pipelines (PixArt-Sigma) and to a normal code repository (OpenSORA). PRs for more model support are welcome. We also provide examples for the image-generation DiT and the Flux model (not all ViDiT-Q techniques are fully supported for them).
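As a rough illustration of this "inherit and wrap" pattern (the class and helper names below are hypothetical stand-ins, not the actual `quant_utils` API):

```python
import torch
import torch.nn as nn

class OriginalModel(nn.Module):
    """Stand-in for the host codebase's model class (e.g. OpenSORA's DiT)."""
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(64, 64)

    def forward(self, x):
        return self.proj(x)

class QuantLinear(nn.Linear):
    """Stand-in quantized linear; the real layer fake-quantizes weights/activations."""
    def forward(self, x):
        # weight/activation quantize-dequantize with calibrated scales would go here
        return super().forward(x)

def swap_linear_layers(module: nn.Module) -> None:
    """Recursively replace nn.Linear children with QuantLinear, keeping the FP weights."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            q = QuantLinear(child.in_features, child.out_features, bias=child.bias is not None)
            q.load_state_dict(child.state_dict())
            setattr(module, name, q)
        else:
            swap_linear_layers(child)

class QuantModel(OriginalModel):
    """Inherits the original model and swaps its layers, so the rest of the pipeline is unchanged."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        swap_linear_layers(self)

model = QuantModel()
print(model(torch.randn(2, 64)).shape)  # QuantModel drops in wherever OriginalModel was used
```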
We recommend using conda for environment management. For each model in the examples folder, refer to the original codebase's README for environment setup. We recommend using an independent environment for each model, since they may contain conflicting package versions.
Then, in each environment, to enable quantization software simulation, install the `qdiff` package locally from the `./quant_utils` folder (the `-e` flag enables an editable installation, in case you want to modify the quantization-related code):
cd ./quant_utils
pip install -e .
Optionally, for inference with the hardware CUDA kernels to obtain practical resource savings, install the `viditq_extension` package locally from the `./kernels` folder:
cd ./kernels
pip install -e .
- Clone the OpenSORA codebase as a submodule into `./examples/opensora1.2/` (the OpenSORA codebase needs to be slightly modified to support precomputed text embeddings, so we include the OpenSORA codebase here).
- Set up the environment following `./examples/opensora1.2/OpenSORA/README.md`.
- Modify the file paths for checkpoints in `./examples/opensora1.2/OpenSORA/configs`.
The T5 text embedder is large and needs to be loaded into GPU memory. For memory-constrained scenarios (e.g., an RTX 3090 with 24 GB), we support precomputing the text embeddings and saving them offline, to avoid loading T5 onto the GPU.
For the OpenSORA codebase, the scheduler code needs to be modified as follows (we provide the modified version in `./examples/opensora1.2/Open-Sora/opensora/schedulers/rf/__init__.py`, which is the only modification made to the original OpenSORA codebase). In `rf/__init__.py`, `precompute_text_embeds=False` is added as an input argument of `sample()`, and the text embeddings are saved to disk so the text encoder does not need to be kept on the GPU:

```python
# INFO: save the text embeds to avoid loading the text_encoder
save_d = {}
save_d['model_args'] = model_args
torch.save(save_d, './precomputed_text_embeds.pth')
```
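Correspondingly, the load side is just the mirror of this save (a schematic only; the actual loading logic lives in the modified scheduler and is gated by `precompute_text_embeds`):

```python
import torch

# schematic: reuse the saved T5 embeddings instead of running the text encoder
save_d = torch.load('./precomputed_text_embeds.pth')
model_args = save_d['model_args']
```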
To configure whether to use precomputed text embeddings, add `precompute_text_embeds = False` to the OpenSORA `*.py` config file (example: `examples/opensora1.2/configs/software_simulation.py`).
Generate the videos in FP16 precision with the prompts specified by `--prompt-path`; `configs/software_simulation.py` specifies the generation details (e.g., resolution). The generated videos are saved in the `save_dir` path given in the config. (The command-line `--prompt-path` overrides the one in the config file.)
python fp_inference.py configs/software_simulation.py \
--prompt-path ./prompts.txt
We introduce an additional `config.yaml` to specify the quantization details. The `ptq_config` attribute in the OpenSORA `*.py` config file specifies which quant config to use. We provide three examples for baseline and ViDiT-Q quantization:
- examples/opensora1.2/configs/config.yaml
- examples/opensora1.2/configs/w8a8.yaml
- examples/opensora1.2/configs/w4a8_mixed_precision.yaml
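For reference, the link between the two config files is a single attribute in the OpenSORA Python config (illustrative excerpt; the attribute name follows the example config):

```python
# in e.g. configs/software_simulation.py: select which quant config to use
ptq_config = "./configs/w4a8_mixed_precision.yaml"  # or config.yaml / w8a8.yaml
```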
Some quantization techniques (e.g., SmoothQuant) require calibration of the activation distribution. The calib data generation process runs FP16 inference and saves the activations into `calib_data.pth`. The calib data is saved at the `calib_data` path specified in the quant config (the `ptq_config` file referenced by the OpenSORA config).
python get_calib_data.py configs/software_simulation.py \
--prompt-path ./t2v_samples.txt
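Conceptually, calibration just records activation statistics while running FP16 inference on the calibration prompts. The sketch below uses forward hooks on the linear layers (a generic illustration; the actual collection logic in `qdiff` differs in detail):

```python
import torch
import torch.nn as nn

def collect_activation_stats(model: nn.Module, run_inference) -> dict:
    """Record per-layer absolute activation maxima seen during FP16 calibration runs."""
    stats, handles = {}, []

    def make_hook(name):
        def hook(module, inputs, output):
            amax = inputs[0].detach().abs().amax()
            stats[name] = torch.maximum(stats.get(name, amax), amax)
        return hook

    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            handles.append(module.register_forward_hook(make_hook(name)))
    run_inference()  # e.g. denoise the calibration prompts in FP16
    for h in handles:
        h.remove()
    return stats     # e.g. torch.save(stats, 'calib_data.pth')
```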
The PTQ process generates the quant_params (scaling factors and zero points in FP16 format) in the `save_dir`. The quantization-related configurations (e.g., bitwidth) can be modified in the YAML-formatted quant config.
python ptq.py configs/software_simulation.py --save-dir "./logs/w4a8_mp"
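As a reminder of what these quant_params are, the sketch below computes a per-tensor scaling factor and zero point for a generic uniform asymmetric quantizer (the exact calibration logic in `qdiff`, e.g. per-channel or dynamic variants, may differ):

```python
import torch

def compute_quant_params(x: torch.Tensor, n_bits: int = 8):
    """Per-tensor scale and zero point for uniform asymmetric quantization (generic sketch)."""
    qmax = 2 ** n_bits - 1
    x_min, x_max = x.min(), x.max()
    scale = (x_max - x_min).clamp(min=1e-8) / qmax  # FP16 scaling factor
    zero_point = torch.round(-x_min / scale)        # integer offset so that x_min maps to 0
    return scale, zero_point
```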
The quant_inference process reads the PTQ-generated quant params and conducts a software simulation of quantized inference (the computation is still in FP16 and can take longer than plain FP16 inference). It generates videos in the `save_dir`.
python quant_inference.py configs/software_simulation.py --save-dir "./logs/w4a8_mp"
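"Software simulation" here means fake quantization: tensors are rounded onto the integer grid and immediately de-quantized, so the rounding error of the chosen bitwidth is reproduced while all matmuls still run in FP16 (a generic sketch, not the exact `qdiff` code):

```python
import torch

def fake_quant(x: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor, n_bits: int = 8):
    """Quantize-dequantize: reproduces the quantization error while keeping FP16 compute."""
    qmax = 2 ** n_bits - 1
    q = torch.clamp(torch.round(x / scale) + zero_point, 0, qmax)  # integer codes
    return (q - zero_point) * scale                                # back to floating point
```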
When the `hardware` flag is set to True in the OpenSORA config, quant inference runs with the CUDA kernels to achieve "real" quantized inference. It reads the PTQ-generated quant params, exports the converted integer weights `int_weight.pt` with reduced size, and generates videos with the CUDA kernels in the `viditq_extension` package.
python quant_inference.py configs/cuda_kernel.py --save-dir "./logs/cuda_kernel_test"
Note that mixed-precision quantization is supported by simply modifying the config: point the OpenSORA config at the quant config `./examples/opensora1.2/configs/w4a8_mixed_precision.yaml`, then run both the ptq and quant_inference processes.
python ptq.py configs/software_simulation.py --save-dir "./logs/w8a8"
python quant_inference.py configs/software_simulation.py --save-dir "./logs/w8a8"
- Install diffusers and download the PixArt-Sigma pipeline (we provide an example of saving the pipeline locally in `examples/pixart/download_huggingface_model.py`; see the sketch below).
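A minimal sketch of what such a download script does, assuming the standard diffusers API (the model id and save path below are illustrative; refer to `examples/pixart/download_huggingface_model.py` for the actual ones):

```python
import torch
from diffusers import PixArtSigmaPipeline

# download the PixArt-Sigma pipeline from the HuggingFace hub and save it locally
pipe = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",  # illustrative model id
    torch_dtype=torch.float16,
)
pipe.save_pretrained("./pixart_sigma_pipeline")  # illustrative local path
```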
Generate images in the `log` path (`precompute_text_embeds` is not supported here, but could be easily implemented).
python fp_inference.py --log "./logs/fp16"
We provide three examples for baseline and ViDiT-Q quantization:
- examples/opensora1.2/configs/config.yaml.
- examples/opensora1.2/configs/w4a8_mixed_precision.yaml.
- examples/pixart/configs/w4a8_mixed_precision.yaml
The calibration data is generated under the `log` folder.
python get_calib_data.py --quant-config "./configs/${CFG}" --log "./logs/${LOG}" --prompt $PROMPT_PATH
python ptq.py --quant-config "./configs/${CFG}" --log "./logs/${LOG}"
python quant_inference.py --quant-config "./configs/${CFG}" --log "./logs/${LOG}"
python quant_inference.py --quant-config "./configs/${CFG}" --log "./logs/${LOG}" --hardware
If you find our work helpful, please consider citing:
@misc{zhao2024viditq,
title={ViDiT-Q: Efficient and Accurate Quantization of Diffusion Transformers for Image and Video Generation},
author={Tianchen Zhao and Tongcheng Fang and Enshu Liu and Wan Rui and Widyadewi Soedarmadji and Shiyao Li and Zinan Lin and Guohao Dai and Shengen Yan and Huazhong Yang and Xuefei Ning and Yu Wang},
year={2024},
eprint={2406.02540},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
Our code was developed based on OpenSORA v1.0 (Apache License), PixArt-alpha (AGPL-3.0 license), PixArt-sigma (AGPL-3.0 license), and q-diffusion (MIT License).