Skip to content

Commit

Permalink
add wav2lip training code (PaddlePaddle#142)
Browse files Browse the repository at this point in the history
* add wav2lip trainning code
  • Loading branch information
lijianshe02 authored Jan 15, 2021
1 parent 776fe80 commit edd6211
Show file tree
Hide file tree
Showing 23 changed files with 1,531 additions and 97 deletions.
108 changes: 108 additions & 0 deletions applications/tools/wav2lip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import argparse

import paddle
from ppgan.apps.wav2lip_predictor import Wav2LipPredictor

parser = argparse.ArgumentParser(
description=
'Inference code to lip-sync videos in the wild using Wav2Lip models')

parser.add_argument('--checkpoint_path',
type=str,
help='Name of saved checkpoint to load weights from',
required=True)

parser.add_argument('--face',
type=str,
help='Filepath of video/image that contains faces to use',
required=True)
parser.add_argument(
'--audio',
type=str,
help='Filepath of video/audio file to use as raw audio source',
required=True)
parser.add_argument('--outfile',
type=str,
help='Video path to save result. See default for an e.g.',
default='results/result_voice.mp4')

parser.add_argument(
'--static',
type=bool,
help='If True, then use only first video frame for inference',
default=False)
parser.add_argument(
'--fps',
type=float,
help='Can be specified only if input is a static image (default: 25)',
default=25.,
required=False)

parser.add_argument(
'--pads',
nargs='+',
type=int,
default=[0, 10, 0, 0],
help=
'Padding (top, bottom, left, right). Please adjust to include chin at least'
)

parser.add_argument('--face_det_batch_size',
type=int,
help='Batch size for face detection',
default=16)
parser.add_argument('--wav2lip_batch_size',
type=int,
help='Batch size for Wav2Lip model(s)',
default=128)

parser.add_argument(
'--resize_factor',
default=1,
type=int,
help=
'Reduce the resolution by this factor. Sometimes, best results are obtained at 480p or 720p'
)

parser.add_argument(
'--crop',
nargs='+',
type=int,
default=[0, -1, 0, -1],
help=
'Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. '
'Useful if multiple face present. -1 implies the value will be auto-inferred based on height, width'
)

parser.add_argument(
'--box',
nargs='+',
type=int,
default=[-1, -1, -1, -1],
help=
'Specify a constant bounding box for the face. Use only as a last resort if the face is not detected.'
'Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).'
)

parser.add_argument(
'--rotate',
default=False,
action='store_true',
help=
'Sometimes videos taken from a phone can be flipped 90deg. If true, will flip video right by 90deg.'
'Use if you get a flipped result, despite feeding a normal looking video')

parser.add_argument(
'--nosmooth',
default=False,
action='store_true',
help='Prevent smoothing face detections over a short temporal window')
parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")

if __name__ == "__main__":
args = parser.parse_args()
if args.cpu:
paddle.set_device('cpu')

predictor = Wav2LipPredictor(args)
predictor.run()
63 changes: 63 additions & 0 deletions configs/wav2lip.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
total_iters: 200000000
output_dir: output
checkpoints_dir: checkpoints

model:
name: Wav2LipModel
syncnet_wt: 0.0
max_eval_steps: 700
generator:
name: Wav2Lip
discriminator:
name: SyncNetColor

dataset:
train:
name: Wav2LipDataset
dataroot: data/lrs2_preprocessed
filelists_path: ./
img_size: 96
split: train
batch_size: 8
num_workers: 4
use_shared_memory: False
test:
name: Wav2LipDataset
dataroot: data/lrs2_preprocessed
filelists_path: ./
img_size: 96
split: val
batch_size: 16
num_workers: 4
use_shared_memory: False

optimizer:
optimizer_G:
name: Adam
net_names:
- netG
beta1: 0.5
optimizer_D:
name: Adam
net_names:
- netD
beta1: 0.5

validate:
interval: 3000
save_img: false

lr_scheduler:
name: LinearDecay
learning_rate: 0.0001
start_epoch: 2000000
decay_epochs: 2000000
# will get from real dataset
iters_per_epoch: 1

log_config:
interval: 10
visiual_interval: 500

snapshot_config:
interval: 3000
71 changes: 71 additions & 0 deletions configs/wav2lip_hq.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
total_iters: 200000000
output_dir: output_hq
checkpoints_dir: checkpoints_hq

model:
name: Wav2LipModelHq
syncnet_wt: 0.
disc_wt: 0.07
max_eval_steps: 700
generator:
name: Wav2Lip
discriminator_sync:
name: SyncNetColor
discriminator_hq:
name: Wav2LipDiscQual

dataset:
train:
name: Wav2LipDataset
dataroot: data/lrs2_preprocessed
filelists_path: ./
img_size: 96
split: train
batch_size: 8
num_workers: 0
use_shared_memory: False
test:
name: Wav2LipDataset
dataroot: data/lrs2_preprocessed
filelists_path: ./
img_size: 96
split: val
batch_size: 16
num_workers: 0
use_shared_memory: False

optimizer:
optimizer_G:
name: Adam
net_names:
- netG
beta1: 0.5
optimizer_DS:
name: Adam
net_names:
- netDS
beta1: 0.5
optimizer_DH:
name: Adam
net_names:
- netDH
beta1: 0.5

validate:
interval: 3000
save_img: false

lr_scheduler:
name: LinearDecay
learning_rate: 0.0001
start_epoch: 2000000
decay_epochs: 2000000
# will get from real dataset
iters_per_epoch: 1

log_config:
interval: 10
visiual_interval: 500

snapshot_config:
interval: 3000
129 changes: 129 additions & 0 deletions lsr2_preprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import sys

if sys.version_info[0] < 3 and sys.version_info[1] < 2:
raise Exception("Must be using >= Python 3.2")

from os import listdir, path

import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
import argparse, os, cv2, traceback, subprocess
from tqdm import tqdm
from glob import glob

from ppgan.utils import audio
from ppgan.faceutils import face_detection

parser = argparse.ArgumentParser()

parser.add_argument('--ngpu',
help='Number of GPUs across which to run in parallel',
default=1,
type=int)
parser.add_argument('--batch_size',
help='Single GPU Face detection batch size',
default=32,
type=int)
parser.add_argument("--data_root",
help="Root folder of the LRS2 dataset",
required=True)
parser.add_argument("--preprocessed_root",
help="Root folder of the preprocessed dataset",
required=True)

args = parser.parse_args()

fa = [
face_detection.FaceAlignment(face_detection.LandmarksType._2D,
flip_input=False)
]

template = 'ffmpeg -loglevel panic -y -i {} -strict -2 {}'
# template2 = 'ffmpeg -hide_banner -loglevel panic -threads 1 -y -i {} -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 {}'


def process_video_file(vfile, args, gpu_id):
video_stream = cv2.VideoCapture(vfile)

frames = []
while 1:
still_reading, frame = video_stream.read()
if not still_reading:
video_stream.release()
break
frames.append(frame)

vidname = os.path.basename(vfile).split('.')[0]
dirname = vfile.split('/')[-2]

fulldir = path.join(args.preprocessed_root, dirname, vidname)
os.makedirs(fulldir, exist_ok=True)

batches = [
frames[i:i + args.batch_size]
for i in range(0, len(frames), args.batch_size)
]

i = -1
for fb in batches:
preds = fa[gpu_id].get_detections_for_batch(np.asarray(fb))

for j, f in enumerate(preds):
i += 1
if f is None:
continue

x1, y1, x2, y2 = f
cv2.imwrite(path.join(fulldir, '{}.jpg'.format(i)), fb[j][y1:y2,
x1:x2])


def process_audio_file(vfile, args):
vidname = os.path.basename(vfile).split('.')[0]
dirname = vfile.split('/')[-2]

fulldir = path.join(args.preprocessed_root, dirname, vidname)
os.makedirs(fulldir, exist_ok=True)

wavpath = path.join(fulldir, 'audio.wav')

command = template.format(vfile, wavpath)
subprocess.call(command, shell=True)


def mp_handler(job):
vfile, args, gpu_id = job
try:
process_video_file(vfile, args, gpu_id)
except KeyboardInterrupt:
exit(0)
except:
traceback.print_exc()


def main(args):
print('Started processing for {} with {} GPUs'.format(
args.data_root, args.ngpu))

filelist = glob(path.join(args.data_root, '*/*.mp4'))

jobs = [(vfile, args, i % args.ngpu) for i, vfile in enumerate(filelist)]
p = ThreadPoolExecutor(args.ngpu)
futures = [p.submit(mp_handler, j) for j in jobs]
_ = [r.result() for r in tqdm(as_completed(futures), total=len(futures))]

print('Dumping audios...')

for vfile in tqdm(filelist):
try:
process_audio_file(vfile, args)
except KeyboardInterrupt:
exit(0)
except:
traceback.print_exc()
continue


if __name__ == '__main__':
main(args)
1 change: 1 addition & 0 deletions ppgan/apps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@
from .photo2cartoon_predictor import Photo2CartoonPredictor
from .styleganv2_predictor import StyleGANv2Predictor
from .pixel2style2pixel_predictor import Pixel2Style2PixelPredictor
from .wav2lip_predictor import Wav2LipPredictor
Loading

0 comments on commit edd6211

Please sign in to comment.