
Commit 8ce59df

first commit (0 parents)


62 files changed (+7279, -0 lines)

.gitignore

+2
.vscode
preprocess

LICENSE

+21
MIT License

Copyright (c) 2023 Xin Ma

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

README.md

+100
## LAVITA: Latent Video Diffusion Models with Spatio-temporal Transformers<br><sub>Official PyTorch Implementation</sub>

### [Paper](https://maxin-cn.github.io/lavita_project/) | [Project Page](https://maxin-cn.github.io/lavita_project/)
4+
5+
6+
7+
This repo contains PyTorch model definitions, pre-trained weights and training/sampling code for our paper exploring
8+
latent diffusion models with transformers (LAVITA). You can find more visualizations on our [project page](https://maxin-cn.github.io/lavita_project/).

> [**LAVITA: Latent Video Diffusion Models with Spatio-temporal Transformers**](https://maxin-cn.github.io/lavita_project/)<br>
> [Xin Ma](https://maxin-cn.github.io/), [Yaohui Wang](https://wyhsirius.github.io/), [Xinyuan Chen](https://scholar.google.com/citations?user=3fWSC8YAAAAJ), [Yuan-Fang Li](https://users.monash.edu/~yli/), [Cunjian Chen](https://cunjian.github.io/), [Ziwei Liu](https://liuziwei7.github.io/), [Yu Qiao](https://scholar.google.com.hk/citations?user=gFtI-8QAAAAJ&hl=zh-CN)
> <br>Department of Data Science & AI, Faculty of Information Technology, Monash University<br>Shanghai Artificial Intelligence Laboratory, S-Lab, Nanyang Technological University<br>

We propose a novel architecture, the latent video diffusion model with spatio-temporal transformers, referred to as LAVITA, which integrates the Transformer architecture into diffusion models for the first time within the realm of video generation. Conceptually, LAVITA models spatial and temporal information separately to accommodate their inherent disparities and to reduce computational complexity. Following this design strategy, we design several Transformer-based model variants to integrate spatial and temporal information harmoniously. Moreover, we identify the best practices in architectural choices and learning strategies for LAVITA through rigorous empirical analysis. Our comprehensive evaluation demonstrates that LAVITA achieves state-of-the-art performance across several standard video generation benchmarks, including FaceForensics, SkyTimelapse, UCF101, and Taichi-HD, outperforming current best models.

![The architecture of LAVITA](visuals/architecture.svg)

This repository contains:

* 🪐 A simple PyTorch [implementation](models/lavita.py) of LAVITA
* ⚡️ Pre-trained LAVITA models trained on FaceForensics, SkyTimelapse, Taichi-HD and UCF101 (256x256)
* 🛸 A LAVITA [training script](train.py) using PyTorch DDP

## Setup

First, download and set up the repo:

```bash
git clone https://github.com/maxin-cn/LAVITA.git
cd LAVITA
```

We provide an [`environment.yml`](environment.yml) file that can be used to create a Conda environment. If you only want to run pre-trained models locally on CPU, you can remove the `cudatoolkit` and `pytorch-cuda` requirements from the file.

```bash
conda env create -f environment.yml
conda activate lavita
```
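
After activating the environment, a quick sanity check of our own (not part of the repo) confirms that PyTorch imports and can see the GPU:

```python
# Sanity check (not part of the repo): verify the PyTorch install
# and CUDA visibility before sampling or training.
import torch

print("torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("device:", torch.cuda.get_device_name(0))
```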
## Sampling

**Pre-trained LAVITA checkpoints.** You can sample from our pre-trained LAVITA models with [`sample.py`](sample/sample.py). Weights for our pre-trained LAVITA model can be found [here](https://huggingface.co/maxin-cn/LAVITA). The script has various arguments to adjust sampling steps, change the classifier-free guidance scale, etc. For example, to sample from our model on FaceForensics, you can use:

```bash
bash sample/ffs.sh
```

Or, if you want to sample hundreds of videos, you can use the following script with PyTorch DDP:

```bash
bash sample/ffs_ddp.sh
```

## Training LAVITA

We provide a training script for LAVITA in [`train.py`](train.py). This script can be used to train class-conditional and unconditional LAVITA models. To launch LAVITA (256x256) training with `N` GPUs on the FaceForensics dataset:

```bash
torchrun --nnodes=1 --nproc_per_node=N train.py --config ./configs/ffs/ffs_train.yaml
```

Or, if you have a cluster that uses Slurm, you can also train LAVITA using the following script:

```bash
sbatch slurm_scripts/ffs.slurm
```
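
For orientation, here is a minimal sketch of the DDP setup that a `torchrun`-launched script typically performs; the model and loop below are placeholders, not the actual contents of `train.py`:

```python
# Minimal DDP skeleton (illustrative; train.py's internals may differ).
# torchrun sets RANK, LOCAL_RANK, and WORLD_SIZE for every process.
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    model = torch.nn.Linear(16, 16).cuda()  # placeholder for the LAVITA model
    model = DDP(model, device_ids=[local_rank])
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)

    for step in range(10):  # placeholder loop; the real one reads the YAML config
        x = torch.randn(4, 16, device="cuda")
        loss = model(x).pow(2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```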

We also provide a video-image joint training script, [`train_with_img.py`](train_with_img.py). Like [`train.py`](train.py), it can be used to train class-conditional and unconditional LAVITA models. For example, to train a LAVITA model on the FaceForensics dataset, you can use:

```bash
torchrun --nnodes=1 --nproc_per_node=N train_with_img.py --config ./configs/ffs/ffs_img_train.yaml
```

<!-- ## BibTeX

```bibtex
@article{Peebles2022DiT,
  title={Scalable Diffusion Models with Transformers},
  author={William Peebles and Saining Xie},
  year={2022},
  journal={arXiv preprint arXiv:2212.09748},
}
``` -->

## Acknowledgments
Video generation models are improving quickly and the development of LAVITA has been greatly inspired by the following amazing works and teams: [DiT](https://github.com/facebookresearch/DiT), [U-ViT](https://github.com/baofff/U-ViT), and [Tune-A-Video](https://github.com/showlab/Tune-A-Video).

## License
The code and model weights are licensed under [CC-BY-NC](license_for_usage.txt).

configs/ffs/ffs_img_train.yaml

+45
# dataset
dataset: "ffs_img"

data_path: "/path/to/datasets/preprocessed_ffs/train/videos/"
frame_data_path: "/path/to/datasets/preprocessed_ffs/train/images/"
frame_data_txt: "/path/to/datasets/preprocessed_ffs/train_list.txt"
pretrained_model_path: "/path/to/pretrained/LAVITA/"

# save and load
results_dir: "./results_img"
pretrained:

# model config:
model: LAVITAIMG-XL/2
num_frames: 16
image_size: 256 # choices=[256, 512]
num_sampling_steps: 250
frame_interval: 3
fixed_spatial: False
attention_bias: True
learn_sigma: True # important
extras: 1 # [1, 2, 78]

# train config:
save_ceph: True # important
use_image_num: 8
learning_rate: 1e-4
ckpt_every: 10000
clip_max_norm: 0.1
start_clip_iter: 500000
local_batch_size: 4 # important
max_train_steps: 1000000
global_seed: 3407
num_workers: 8
log_every: 100
lr_warmup_steps: 0
resume_from_checkpoint:
gradient_accumulation_steps: 1 # TODO
num_classes:

# low VRAM and speed up training
use_compile: False
mixed_precision: False
enable_xformers_memory_efficient_attention: False
gradient_checkpointing: False
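
One practical note on configs like the one above: PyYAML resolves `1e-4` as a string rather than a float (its YAML 1.1 float pattern requires a decimal point), so a loader should coerce it explicitly unless the training script already does. A minimal sketch, assuming plain `yaml.safe_load` rather than whatever loader `train.py` actually uses:

```python
# Illustrative config loading (the repo's actual loader may differ).
import yaml

with open("configs/ffs/ffs_img_train.yaml") as f:
    cfg = yaml.safe_load(f)

# PyYAML parses "1e-4" as the string "1e-4", not a float, because its
# float resolver requires a decimal point ("1.0e-4" would parse as float).
lr = float(cfg["learning_rate"])
print(type(cfg["learning_rate"]).__name__, lr)  # str 0.0001
```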

configs/ffs/ffs_sample.yaml

+30
# path:
ckpt: # will be overwritten
save_img_path: "./sample_videos" # will be overwritten
pretrained_model_path: "/path/to/pretrained/LAVITA/"

# model config:
model: LAVITA-XL/2
num_frames: 16
image_size: 256 # choices=[256, 512]
frame_interval: 2
fixed_spatial: False
attention_bias: True
learn_sigma: True
extras: 1 # [1, 2, 78]
num_classes:

# model speedup
use_compile: False
use_fp16: True

# sample config:
seed:
sample_method: 'ddpm'
num_sampling_steps: 250
cfg_scale: 1.0
negative_name:

# ddp sample config
per_proc_batch_size: 2
num_fvd_samples: 2048
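
For context on `cfg_scale`: classifier-free guidance combines a conditional and an unconditional noise prediction at each denoising step, and a scale of 1.0 reduces to the conditional prediction alone, i.e. guidance is effectively off. A generic sketch of the combination step (not taken from `sample.py`):

```python
# Generic classifier-free guidance step (illustrative; sample.py may differ).
import torch

def guided_noise(eps_cond: torch.Tensor, eps_uncond: torch.Tensor,
                 cfg_scale: float) -> torch.Tensor:
    # cfg_scale == 1.0 returns eps_cond unchanged, i.e. no extra guidance.
    return eps_uncond + cfg_scale * (eps_cond - eps_uncond)

eps_c = torch.randn(2, 4, 16, 32, 32)  # hypothetical latent shape
eps_u = torch.randn_like(eps_c)
assert torch.allclose(guided_noise(eps_c, eps_u, 1.0), eps_c)
```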

configs/ffs/ffs_train.yaml

+42
# dataset
dataset: "ffs"

data_path: "/path/to/datasets/preprocess_ffs/train/videos/"
pretrained_model_path: "/path/to/pretrained/LAVITA/"

# save and load
results_dir: "./results"
pretrained:

# model config:
model: LAVITA-XL/2
num_frames: 16
image_size: 256 # choices=[256, 512]
num_sampling_steps: 250
frame_interval: 3
fixed_spatial: False
attention_bias: True
learn_sigma: True # important
extras: 1 # [1, 2, 78]

# train config:
save_ceph: True # important
learning_rate: 1e-4
ckpt_every: 10000
clip_max_norm: 0.1
start_clip_iter: 20000
local_batch_size: 5 # important
max_train_steps: 1000000
global_seed: 3407
num_workers: 8
log_every: 100
lr_warmup_steps: 0
resume_from_checkpoint:
gradient_accumulation_steps: 1 # TODO
num_classes:

# low VRAM and speed up training
use_compile: False
mixed_precision: False
enable_xformers_memory_efficient_attention: False
gradient_checkpointing: False
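
A plausible reading of `clip_max_norm` and `start_clip_iter` (our interpretation, not verified against `train.py`) is that gradient-norm clipping only kicks in after an initial phase:

```python
# Hypothetical use of clip_max_norm / start_clip_iter in a training loop.
import torch

def maybe_clip(model: torch.nn.Module, step: int,
               start_clip_iter: int = 20000, clip_max_norm: float = 0.1) -> None:
    # Only clip gradient norms once training has passed start_clip_iter steps.
    if step >= start_clip_iter and clip_max_norm > 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_max_norm)
```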

configs/sky/sky_img_train.yaml

+43
# dataset
dataset: "sky_img"

data_path: "/path/to/datasets/sky_timelapse/sky_train/"
pretrained_model_path: "/path/to/pretrained/LAVITA/"

# save and load
results_dir: "./results_img"
pretrained:

# model config:
model: LAVITAIMG-XL/2
num_frames: 16
image_size: 256 # choices=[256, 512]
num_sampling_steps: 250
frame_interval: 3
fixed_spatial: False
attention_bias: True
learn_sigma: True
extras: 1 # [1, 2, 78]

# train config:
save_ceph: True # important
use_image_num: 8 # important
learning_rate: 1e-4
ckpt_every: 10000
clip_max_norm: 0.1
start_clip_iter: 20000
local_batch_size: 4 # important
max_train_steps: 1000000
global_seed: 3407
num_workers: 8
log_every: 50
lr_warmup_steps: 0
resume_from_checkpoint:
gradient_accumulation_steps: 1 # TODO
num_classes:

# low VRAM and speed up training
use_compile: False
mixed_precision: False
enable_xformers_memory_efficient_attention: False
gradient_checkpointing: False
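
The `# low VRAM and speed up training` block toggles standard PyTorch memory/speed features. A rough sketch of how such flags commonly map onto torch APIs (an assumption; the repo's actual wiring may differ, and `enable_gradient_checkpointing` is a hypothetical method name):

```python
# Illustrative mapping of the low-VRAM config flags onto common torch APIs.
import torch

def apply_speed_flags(model: torch.nn.Module, cfg: dict) -> torch.nn.Module:
    if cfg.get("use_compile"):
        model = torch.compile(model)  # PyTorch 2.x graph compilation
    if cfg.get("gradient_checkpointing") and hasattr(model, "enable_gradient_checkpointing"):
        model.enable_gradient_checkpointing()  # hypothetical model hook
    return model

# mixed_precision typically wraps the forward pass in autocast, e.g.:
# with torch.autocast(device_type="cuda", dtype=torch.float16): ...
```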

configs/sky/sky_sample.yaml

+32
# path:
ckpt: # will be overwritten
save_img_path: "./sample_videos/" # will be overwritten
pretrained_model_path: "/path/to/pretrained/LAVITA/"

# model config:
model: LAVITA-XL/2
num_frames: 16
image_size: 256 # choices=[256, 512]
frame_interval: 2
fixed_spatial: False
attention_bias: True
learn_sigma: True
extras: 1 # [1, 2, 78]
num_classes:

# model speedup
use_compile: False
use_fp16: True

# sample config:
seed:
sample_method: 'ddpm'
num_sampling_steps: 250
cfg_scale: 1.0
run_time: 12
num_sample: 1
negative_name:

# ddp sample config
per_proc_batch_size: 1
num_fvd_samples: 2
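
For the DDP sampling settings, the number of sampling iterations is typically derived from `num_fvd_samples`, `per_proc_batch_size`, and the world size. A small sketch of that arithmetic (our assumption about how the DDP sampler divides work):

```python
# Hypothetical work division for DDP sampling (the actual script may differ).
import math

num_fvd_samples = 2      # from this config
per_proc_batch_size = 1  # from this config
world_size = 2           # hypothetical two-GPU run

global_batch = per_proc_batch_size * world_size    # 2 videos per iteration
iters = math.ceil(num_fvd_samples / global_batch)  # 1 iteration
print(f"{iters} iteration(s) -> {iters * global_batch} videos")
```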

configs/sky/sky_train.yaml

+42
# dataset
dataset: "sky"

data_path: "/path/to/datasets/sky_timelapse/sky_train/"
pretrained_model_path: "/path/to/pretrained/LAVITA/"

# save and load
results_dir: "./results"
pretrained:

# model config:
model: LAVITA-XL/2
num_frames: 16
image_size: 256 # choices=[256, 512]
num_sampling_steps: 250
frame_interval: 3
fixed_spatial: False
attention_bias: True
learn_sigma: True
extras: 1 # [1, 2, 78]

# train config:
save_ceph: True # important
learning_rate: 1e-4
ckpt_every: 10000
clip_max_norm: 0.1
start_clip_iter: 20000
local_batch_size: 5 # important
max_train_steps: 1000000
global_seed: 3407
num_workers: 8
log_every: 50
lr_warmup_steps: 0
resume_from_checkpoint:
gradient_accumulation_steps: 1 # TODO
num_classes:

# low VRAM and speed up training
use_compile: False
mixed_precision: False
enable_xformers_memory_efficient_attention: False
gradient_checkpointing: False
