ModelAE.py (forked from deepfakes/faceswap)
# AutoEncoder base classes
import time
import numpy

from lib.training_data import minibatchAB, stack_images

encoderH5 = '/encoder.h5'
decoder_AH5 = '/decoder_A.h5'
decoder_BH5 = '/decoder_B.h5'


class ModelAE:
    # Base model: a single shared encoder with separate decoders for face A
    # and face B. Subclasses supply Encoder(), Decoder() and initModel().
    def __init__(self, model_dir):
        self.model_dir = model_dir

        self.encoder = self.Encoder()
        self.decoder_A = self.Decoder()
        self.decoder_B = self.Decoder()

        self.initModel()

    def load(self, swapped):
        # When swapped, decoder A loads B's weights and vice versa.
        (face_A, face_B) = (decoder_AH5, decoder_BH5) if not swapped else (decoder_BH5, decoder_AH5)

        try:
            self.encoder.load_weights(self.model_dir + encoderH5)
            self.decoder_A.load_weights(self.model_dir + face_A)
            self.decoder_B.load_weights(self.model_dir + face_B)
            print('loaded model weights')
            return True
        except Exception as e:
            print('Failed loading existing training data.')
            print(e)
            return False

    def save_weights(self):
        self.encoder.save_weights(self.model_dir + encoderH5)
        self.decoder_A.save_weights(self.model_dir + decoder_AH5)
        self.decoder_B.save_weights(self.model_dir + decoder_BH5)
        print('saved model weights')


class TrainerAE():
    def __init__(self, model, fn_A, fn_B, batch_size=64):
        self.batch_size = batch_size
        self.model = model
        self.images_A = minibatchAB(fn_A, self.batch_size)
        self.images_B = minibatchAB(fn_B, self.batch_size)

    def train_one_step(self, iter, viewer):
        # Train each autoencoder on one warped/target minibatch and report losses.
        epoch, warped_A, target_A = next(self.images_A)
        epoch, warped_B, target_B = next(self.images_B)

        loss_A = self.model.autoencoder_A.train_on_batch(warped_A, target_A)
        loss_B = self.model.autoencoder_B.train_on_batch(warped_B, target_B)
        print("[{0}] [#{1:05d}] loss_A: {2:.5f}, loss_B: {3:.5f}".format(time.strftime("%H:%M:%S"), iter, loss_A, loss_B),
              end='\r')

        if viewer is not None:
            viewer(self.show_sample(target_A[0:14], target_B[0:14]), "training")

    def show_sample(self, test_A, test_B):
        # Build a preview grid of (original, reconstruction, face-swapped
        # reconstruction) triplets for samples from both A and B.
        figure_A = numpy.stack([
            test_A,
            self.model.autoencoder_A.predict(test_A),
            self.model.autoencoder_B.predict(test_A),
        ], axis=1)
        figure_B = numpy.stack([
            test_B,
            self.model.autoencoder_B.predict(test_B),
            self.model.autoencoder_A.predict(test_B),
        ], axis=1)

        figure = numpy.concatenate([figure_A, figure_B], axis=0)
        figure = figure.reshape((4, 7) + figure.shape[1:])
        figure = stack_images(figure)

        return numpy.clip(figure * 255, 0, 255).astype('uint8')
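

# ---------------------------------------------------------------------------
# Illustrative only (not part of the original file): ModelAE is a base class,
# so a concrete model is expected to provide Encoder(), Decoder() and
# initModel(), and to expose autoencoder_A / autoencoder_B for TrainerAE.
# The sketch below is a minimal, hypothetical subclass using a deliberately
# tiny Keras network; the class name, layer sizes, optimizer and loss are
# assumptions for illustration, not the project's actual architecture.
from keras.layers import Input, Dense, Flatten, Reshape
from keras.models import Model as KerasModel

IMAGE_SHAPE = (64, 64, 3)   # assumed training image size
ENCODER_DIM = 512           # assumed latent dimension


class ToyModel(ModelAE):
    def initModel(self):
        # Both autoencoders share the encoder; only the decoders differ.
        x = Input(shape=IMAGE_SHAPE)
        self.autoencoder_A = KerasModel(x, self.decoder_A(self.encoder(x)))
        self.autoencoder_B = KerasModel(x, self.decoder_B(self.encoder(x)))
        self.autoencoder_A.compile(optimizer='adam', loss='mean_absolute_error')
        self.autoencoder_B.compile(optimizer='adam', loss='mean_absolute_error')

    def Encoder(self):
        # Flatten the image and project it to a fixed-size latent vector.
        inp = Input(shape=IMAGE_SHAPE)
        out = Dense(ENCODER_DIM, activation='relu')(Flatten()(inp))
        return KerasModel(inp, out)

    def Decoder(self):
        # Map the latent vector back to an image in [0, 1].
        inp = Input(shape=(ENCODER_DIM,))
        out = Reshape(IMAGE_SHAPE)(Dense(64 * 64 * 3, activation='sigmoid')(inp))
        return KerasModel(inp, out)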