Skip to content

Commit

Permalink
Changes I forgot to push :-/ (deepfakes#136)
Browse files Browse the repository at this point in the history
  • Loading branch information
Clorr authored Feb 7, 2018
1 parent b3ae613 commit 5815baa
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 24 deletions.
19 changes: 9 additions & 10 deletions lib/ModelAE.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
import numpy
from lib.training_data import TrainingDataGenerator, stack_images

encoderH5 = '/encoder.h5'
decoder_AH5 = '/decoder_A.h5'
decoder_BH5 = '/decoder_B.h5'
encoderH5 = 'encoder.h5'
decoder_AH5 = 'decoder_A.h5'
decoder_BH5 = 'decoder_B.h5'

class ModelAE:
def __init__(self, model_dir):

self.model_dir = model_dir

self.encoder = self.Encoder()
Expand All @@ -23,9 +22,9 @@ def load(self, swapped):
(face_A,face_B) = (decoder_AH5, decoder_BH5) if not swapped else (decoder_BH5, decoder_AH5)

try:
self.encoder.load_weights(self.model_dir + encoderH5)
self.decoder_A.load_weights(self.model_dir + face_A)
self.decoder_B.load_weights(self.model_dir + face_B)
self.encoder.load_weights(self.model_dir / encoderH5)
self.decoder_A.load_weights(self.model_dir / face_A)
self.decoder_B.load_weights(self.model_dir / face_B)
print('loaded model weights')
return True
except Exception as e:
Expand All @@ -34,9 +33,9 @@ def load(self, swapped):
return False

def save_weights(self):
self.encoder.save_weights(self.model_dir + encoderH5)
self.decoder_A.save_weights(self.model_dir + decoder_AH5)
self.decoder_B.save_weights(self.model_dir + decoder_BH5)
self.encoder.save_weights(self.model_dir / encoderH5)
self.decoder_A.save_weights(self.model_dir / decoder_AH5)
self.decoder_B.save_weights(self.model_dir / decoder_BH5)
print('saved model weights')

class TrainerAE():
Expand Down
6 changes: 3 additions & 3 deletions lib/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ class DirectoryProcessor(object):
input_dir = None
output_dir = None

verify_output = False
images_found = 0
faces_detected = 0
verify_output = False

def __init__(self, subparser, command, description='default'):
self.create_parser(subparser, command, description)
Expand Down Expand Up @@ -64,9 +64,9 @@ def get_faces(self, image):
yield faces_count, face

self.faces_detected = self.faces_detected + 1
faces_count +=1
faces_count += 1

if faces_count > 0 and self.arguments.verbose:
if faces_count > 1 and self.arguments.verbose:
print('Note: Found more than one face in an image!')
self.verify_output = True

Expand Down
17 changes: 9 additions & 8 deletions scripts/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from tqdm import tqdm

from lib.cli import DirectoryProcessor, FullPaths
from lib.utils import BackgroundGenerator
from lib.utils import BackgroundGenerator, get_folder

from plugins.PluginLoader import PluginLoader

Expand Down Expand Up @@ -116,7 +116,7 @@ def process(self):
model_name = self.arguments.trainer
conv_name = self.arguments.converter

model = PluginLoader.get_model(model_name)(self.arguments.model_dir)
model = PluginLoader.get_model(model_name)(get_folder(self.arguments.model_dir))
if not model.load(self.arguments.swap_model):
print('Model Not Found! A valid model must be provided to continue!')
exit(1)
Expand Down Expand Up @@ -160,15 +160,16 @@ def check_skipframe(self, filename):
def convert(self, converter, item):
try:
(filename, image, faces) = item

if not self.check_skipframe(filename): # process as normal
for idx, face in faces:
image = converter.patch_image(image, face)

output_file = self.output_dir / Path(filename).name

skip = self.check_skip(filename)
if self.arguments.discard_frames and skip:
return

if not skip: # process as normal
for idx, face in faces:
image = converter.patch_image(image, face)

output_file = get_folder(self.output_dir) / Path(filename).name
cv2.imwrite(str(output_file), image)
except Exception as e:
print('Failed to convert image: {}. Reason: {}'.format(filename, e))
Expand Down
4 changes: 3 additions & 1 deletion scripts/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from tqdm import tqdm

from lib.cli import DirectoryProcessor
from lib.utils import get_folder
from lib.multithreading import pool_process
from plugins.PluginLoader import PluginLoader

Expand Down Expand Up @@ -33,6 +34,7 @@ def add_optional_arguments(self, parser):

parser.add_argument('-j', '--processes',
type=int,
default=1,
help="Number of processes to use.")
return parser

Expand Down Expand Up @@ -65,7 +67,7 @@ def handleImage(self, filename):
for idx, face in self.get_faces(image):
count = idx
resized_image = self.extractor.extract(image, face, 256)
output_file = self.output_dir / Path(filename).stem
output_file = get_folder(self.output_dir) / Path(filename).stem
cv2.imwrite(str(output_file) + str(idx) + Path(filename).suffix, resized_image)
return count + 1

4 changes: 2 additions & 2 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import time

from threading import Lock
from lib.utils import get_image_paths
from lib.utils import get_image_paths, get_folder
from lib.cli import FullPaths
from plugins.PluginLoader import PluginLoader

Expand Down Expand Up @@ -122,7 +122,7 @@ def processThread(self):
# this is so that you can enter case insensitive values for trainer
trainer = self.arguments.trainer
trainer = "LowMem" if trainer.lower() == "lowmem" else trainer
model = PluginLoader.get_model(trainer)(self.arguments.model_dir)
model = PluginLoader.get_model(trainer)(get_folder(self.arguments.model_dir))
model.load(swapped=False)

images_A = get_image_paths(self.arguments.input_A)
Expand Down

0 comments on commit 5815baa

Please sign in to comment.