forked from PaddlePaddle/PaddleGAN
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add wav2lip training code (PaddlePaddle#142)
* add wav2lip training code
- Loading branch information
1 parent
776fe80
commit edd6211
Showing
23 changed files
with
1,531 additions
and
97 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
import argparse | ||
|
||
import paddle | ||
from ppgan.apps.wav2lip_predictor import Wav2LipPredictor | ||
|
||
def _str2bool(value):
    """Convert a command-line token such as 'true'/'false' into a bool.

    ``type=bool`` must not be used with argparse: it calls ``bool()`` on the
    raw string, so every non-empty value, including ``'False'``, becomes True.

    Args:
        value: the raw option value (or an actual bool).

    Returns:
        The parsed boolean. An empty string maps to False, as it did with
        ``type=bool``.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'Boolean value expected, got {!r}'.format(value))


# Command-line interface of the Wav2Lip inference script.
parser = argparse.ArgumentParser(
    description=
    'Inference code to lip-sync videos in the wild using Wav2Lip models')

# Required inputs.
parser.add_argument('--checkpoint_path',
                    type=str,
                    help='Name of saved checkpoint to load weights from',
                    required=True)

parser.add_argument('--face',
                    type=str,
                    help='Filepath of video/image that contains faces to use',
                    required=True)
parser.add_argument(
    '--audio',
    type=str,
    help='Filepath of video/audio file to use as raw audio source',
    required=True)
parser.add_argument('--outfile',
                    type=str,
                    help='Video path to save result. See default for an e.g.',
                    default='results/result_voice.mp4')

# Parsed with _str2bool so that "--static False" really yields False.
parser.add_argument(
    '--static',
    type=_str2bool,
    help='If True, then use only first video frame for inference',
    default=False)
parser.add_argument(
    '--fps',
    type=float,
    help='Can be specified only if input is a static image (default: 25)',
    default=25.,
    required=False)

parser.add_argument(
    '--pads',
    nargs='+',
    type=int,
    default=[0, 10, 0, 0],
    help=
    'Padding (top, bottom, left, right). Please adjust to include chin at least'
)

parser.add_argument('--face_det_batch_size',
                    type=int,
                    help='Batch size for face detection',
                    default=16)
parser.add_argument('--wav2lip_batch_size',
                    type=int,
                    help='Batch size for Wav2Lip model(s)',
                    default=128)

parser.add_argument(
    '--resize_factor',
    default=1,
    type=int,
    help=
    'Reduce the resolution by this factor. Sometimes, best results are obtained at 480p or 720p'
)

parser.add_argument(
    '--crop',
    nargs='+',
    type=int,
    default=[0, -1, 0, -1],
    help=
    'Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. '
    'Useful if multiple face present. -1 implies the value will be auto-inferred based on height, width'
)

# The adjacent help literals below are concatenated, hence the explicit
# trailing space at the end of the first sentence.
parser.add_argument(
    '--box',
    nargs='+',
    type=int,
    default=[-1, -1, -1, -1],
    help=
    'Specify a constant bounding box for the face. Use only as a last resort if the face is not detected. '
    'Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).'
)

parser.add_argument(
    '--rotate',
    default=False,
    action='store_true',
    help=
    'Sometimes videos taken from a phone can be flipped 90deg. If true, will flip video right by 90deg. '
    'Use if you get a flipped result, despite feeding a normal looking video')

parser.add_argument(
    '--nosmooth',
    default=False,
    action='store_true',
    help='Prevent smoothing face detections over a short temporal window')
parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
|
||
if __name__ == "__main__":
    # Read the options from the command line.
    args = parser.parse_args()

    # Switch paddle to the CPU device only when --cpu was passed.
    if args.cpu:
        paddle.set_device('cpu')

    # Build the predictor from the parsed options and run it right away.
    Wav2LipPredictor(args).run()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
total_iters: 200000000 | ||
output_dir: output | ||
checkpoints_dir: checkpoints | ||
|
||
model: | ||
name: Wav2LipModel | ||
syncnet_wt: 0.0 | ||
max_eval_steps: 700 | ||
generator: | ||
name: Wav2Lip | ||
discriminator: | ||
name: SyncNetColor | ||
|
||
dataset: | ||
train: | ||
name: Wav2LipDataset | ||
dataroot: data/lrs2_preprocessed | ||
filelists_path: ./ | ||
img_size: 96 | ||
split: train | ||
batch_size: 8 | ||
num_workers: 4 | ||
use_shared_memory: False | ||
test: | ||
name: Wav2LipDataset | ||
dataroot: data/lrs2_preprocessed | ||
filelists_path: ./ | ||
img_size: 96 | ||
split: val | ||
batch_size: 16 | ||
num_workers: 4 | ||
use_shared_memory: False | ||
|
||
optimizer: | ||
optimizer_G: | ||
name: Adam | ||
net_names: | ||
- netG | ||
beta1: 0.5 | ||
optimizer_D: | ||
name: Adam | ||
net_names: | ||
- netD | ||
beta1: 0.5 | ||
|
||
validate: | ||
interval: 3000 | ||
save_img: false | ||
|
||
lr_scheduler: | ||
name: LinearDecay | ||
learning_rate: 0.0001 | ||
start_epoch: 2000000 | ||
decay_epochs: 2000000 | ||
# placeholder value; replaced by the value obtained from the real dataset | ||
iters_per_epoch: 1 | ||
|
||
log_config: | ||
interval: 10 | ||
visiual_interval: 500 | ||
|
||
snapshot_config: | ||
interval: 3000 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
total_iters: 200000000 | ||
output_dir: output_hq | ||
checkpoints_dir: checkpoints_hq | ||
|
||
model: | ||
name: Wav2LipModelHq | ||
syncnet_wt: 0. | ||
disc_wt: 0.07 | ||
max_eval_steps: 700 | ||
generator: | ||
name: Wav2Lip | ||
discriminator_sync: | ||
name: SyncNetColor | ||
discriminator_hq: | ||
name: Wav2LipDiscQual | ||
|
||
dataset: | ||
train: | ||
name: Wav2LipDataset | ||
dataroot: data/lrs2_preprocessed | ||
filelists_path: ./ | ||
img_size: 96 | ||
split: train | ||
batch_size: 8 | ||
num_workers: 0 | ||
use_shared_memory: False | ||
test: | ||
name: Wav2LipDataset | ||
dataroot: data/lrs2_preprocessed | ||
filelists_path: ./ | ||
img_size: 96 | ||
split: val | ||
batch_size: 16 | ||
num_workers: 0 | ||
use_shared_memory: False | ||
|
||
optimizer: | ||
optimizer_G: | ||
name: Adam | ||
net_names: | ||
- netG | ||
beta1: 0.5 | ||
optimizer_DS: | ||
name: Adam | ||
net_names: | ||
- netDS | ||
beta1: 0.5 | ||
optimizer_DH: | ||
name: Adam | ||
net_names: | ||
- netDH | ||
beta1: 0.5 | ||
|
||
validate: | ||
interval: 3000 | ||
save_img: false | ||
|
||
lr_scheduler: | ||
name: LinearDecay | ||
learning_rate: 0.0001 | ||
start_epoch: 2000000 | ||
decay_epochs: 2000000 | ||
# placeholder value; replaced by the value obtained from the real dataset | ||
iters_per_epoch: 1 | ||
|
||
log_config: | ||
interval: 10 | ||
visiual_interval: 500 | ||
|
||
snapshot_config: | ||
interval: 3000 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
import sys | ||
|
||
def require_min_python(version_info=None):
    """Raise if the interpreter is older than Python 3.2.

    The version is compared as a tuple. Testing ``major < 3 and minor < 2``
    separately would accept Python 2.7 (minor 7 is not < 2) as well as
    Python 3.0 and 3.1 (major 3 is not < 3).

    Args:
        version_info: a ``(major, minor, ...)`` sequence; defaults to
            ``sys.version_info``.

    Raises:
        Exception: if the version is lower than 3.2.
    """
    if version_info is None:
        version_info = sys.version_info
    if tuple(version_info[:2]) < (3, 2):
        raise Exception("Must be using >= Python 3.2")


# Fail fast, before the remaining imports, on unsupported interpreters.
require_min_python()
|
||
from os import listdir, path | ||
|
||
import multiprocessing as mp | ||
from concurrent.futures import ThreadPoolExecutor, as_completed | ||
import numpy as np | ||
import argparse, os, cv2, traceback, subprocess | ||
from tqdm import tqdm | ||
from glob import glob | ||
|
||
from ppgan.utils import audio | ||
from ppgan.faceutils import face_detection | ||
|
||
# Command-line interface of the preprocessing script.
parser = argparse.ArgumentParser()

# Parallelism and batching options.
parser.add_argument('--ngpu', type=int, default=1,
                    help='Number of GPUs across which to run in parallel')
parser.add_argument('--batch_size', type=int, default=32,
                    help='Single GPU Face detection batch size')

# Input and output locations; both are mandatory.
parser.add_argument("--data_root", required=True,
                    help="Root folder of the LRS2 dataset")
parser.add_argument("--preprocessed_root", required=True,
                    help="Root folder of the preprocessed dataset")
|
||
# Parse the command line at module level; the namespace is handed to main()
# by the script entry point.
args = parser.parse_args()

# Face detectors; process_video_file picks one with fa[gpu_id].
# NOTE(review): only a single detector is created here regardless of --ngpu —
# confirm this is intended.
fa = [
    face_detection.FaceAlignment(face_detection.LandmarksType._2D,
                                 flip_input=False)
]

# ffmpeg command line; the two placeholders are filled with the input video
# path and the output wav path in process_audio_file.
template = 'ffmpeg -loglevel panic -y -i {} -strict -2 {}'
# Unused alternative command, kept for reference:
# template2 = 'ffmpeg -hide_banner -loglevel panic -threads 1 -y -i {} -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 {}'
|
||
|
||
def process_video_file(vfile, args, gpu_id):
    """Detect a face in every frame of a video and save the crops as JPEGs.

    All frames are read into memory, run through a face detector in batches
    of ``args.batch_size``, and each detected box is cropped and written to
    ``<args.preprocessed_root>/<parent dir of vfile>/<video name>/<i>.jpg``,
    where ``i`` is the running index of the detection result. Frames for
    which the detector returns ``None`` are skipped (their index is still
    consumed).

    Args:
        vfile: path of the input video.
        args: parsed options; ``preprocessed_root`` and ``batch_size`` are read.
        gpu_id: index used to select a detector from the module-level ``fa``.
    """
    video_stream = cv2.VideoCapture(vfile)

    frames = []
    try:
        while True:
            still_reading, frame = video_stream.read()
            if not still_reading:
                break
            frames.append(frame)
    finally:
        # Release the capture even if reading fails part-way.
        video_stream.release()

    vidname = os.path.basename(vfile).split('.')[0]
    # Name of the directory containing the video; os.path handles the
    # platform separator, unlike splitting on '/'.
    dirname = path.basename(path.dirname(vfile))

    fulldir = path.join(args.preprocessed_root, dirname, vidname)
    os.makedirs(fulldir, exist_ok=True)

    # `fa` may hold fewer detectors than there are worker ids; wrap around
    # instead of raising IndexError for gpu_id >= len(fa).
    detector = fa[gpu_id % len(fa)]

    frame_index = -1
    for start in range(0, len(frames), args.batch_size):
        fb = frames[start:start + args.batch_size]
        preds = detector.get_detections_for_batch(np.asarray(fb))

        for j, f in enumerate(preds):
            frame_index += 1
            if f is None:
                continue

            x1, y1, x2, y2 = f
            cv2.imwrite(path.join(fulldir, '{}.jpg'.format(frame_index)),
                        fb[j][y1:y2, x1:x2])
|
||
|
||
def process_audio_file(vfile, args):
    """Extract the audio track of a video into ``audio.wav`` using ffmpeg.

    The output is written to
    ``<args.preprocessed_root>/<parent dir of vfile>/<video name>/audio.wav``;
    the directory is created if needed. The ffmpeg exit status is not checked.

    Args:
        vfile: path of the input video.
        args: parsed options; ``preprocessed_root`` is read.

    Raises:
        OSError: if the ffmpeg executable cannot be started.
    """
    vidname = os.path.basename(vfile).split('.')[0]
    # Name of the directory containing the video; os.path handles the
    # platform separator, unlike splitting on '/'.
    dirname = path.basename(path.dirname(vfile))

    fulldir = path.join(args.preprocessed_root, dirname, vidname)
    os.makedirs(fulldir, exist_ok=True)

    wavpath = path.join(fulldir, 'audio.wav')

    # Turn the module-level template into an argument list and run it without
    # a shell, so paths containing spaces or shell metacharacters are passed
    # to ffmpeg verbatim. The two '{}' tokens are replaced, in order, by the
    # input and output paths.
    substitutions = iter((vfile, wavpath))
    command = [
        next(substitutions) if token == '{}' else token
        for token in template.split()
    ]
    subprocess.call(command)
|
||
|
||
def mp_handler(job):
    """Run process_video_file for one job, logging failures instead of raising.

    Args:
        job: a ``(vfile, args, gpu_id)`` tuple.

    Any ordinary exception is printed with its traceback so that one broken
    video does not abort the whole run. A KeyboardInterrupt ends with exit
    status 0.
    """
    vfile, args, gpu_id = job
    try:
        process_video_file(vfile, args, gpu_id)
    except KeyboardInterrupt:
        # sys.exit rather than the site-provided exit() helper, which is not
        # guaranteed to exist.
        sys.exit(0)
    except Exception:
        # Not a bare except: SystemExit and friends must keep propagating.
        traceback.print_exc()
|
||
|
||
def main(args):
    """Preprocess every ``*/*.mp4`` under ``args.data_root``.

    First the face crops are produced by ``mp_handler`` on a thread pool of
    ``args.ngpu`` workers (job ``i`` gets worker id ``i % args.ngpu``), then
    the audio of each video is dumped sequentially. A failure on a single
    file is printed and processing continues with the next one.

    Args:
        args: parsed options; ``data_root`` and ``ngpu`` are read here and the
            namespace is forwarded to the per-file workers.
    """
    print('Started processing for {} with {} GPUs'.format(
        args.data_root, args.ngpu))

    filelist = glob(path.join(args.data_root, '*/*.mp4'))

    jobs = [(vfile, args, i % args.ngpu) for i, vfile in enumerate(filelist)]
    # The context manager shuts the pool down once all jobs are finished.
    with ThreadPoolExecutor(args.ngpu) as pool:
        futures = [pool.submit(mp_handler, job) for job in jobs]
        for future in tqdm(as_completed(futures), total=len(futures)):
            future.result()

    print('Dumping audios...')

    for vfile in tqdm(filelist):
        try:
            process_audio_file(vfile, args)
        except KeyboardInterrupt:
            sys.exit(0)
        except Exception:
            # Keep going: one unreadable file must not stop the whole dump.
            traceback.print_exc()
|
||
|
||
# Script entry point: run the preprocessing with the module-level `args`
# namespace produced by parser.parse_args().
if __name__ == '__main__':
    main(args)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.