train.py
import custom_layers
import custom_models
import keras
from keras import optimizers, preprocessing
import tensorflow as tf
# Adjust the config dictionaries for better results: the widths below are far too low,
# and higher depth and more latent_filters give better reconstructions; KL annealing
# (see the callback sketch below) helps keep the KL term low early in training.
model_config = {
    "encoder_widths": [16, 32],
    "decoder_widths": [32, 16],
    "kernel_size": 3,
    "activation": "swish",
    "use_bn": True,
    "depth": 1,
    "latent_filters": 3,
    "output_channels": 3,
    "kl_beta": 1,
}
dataset_config = {
    "directory": "celeba",
    "image_size": (64, 64),
    "label_mode": None,
    "batch_size": 32,
}
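# Note: image_dataset_from_directory expects the image files to sit in subdirectories of
# "celeba" (e.g. celeba/img_align_celeba/*.jpg is one common layout); with
# label_mode=None only the images are yielded, with no labels.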
model_compile = {
    "optimizer": optimizers.Adam(learning_rate=1e-4)
}
config = {
    "model_config": model_config,
    "optimizer_config": model_compile,
    "dataset_config": dataset_config,
}
dataset = preprocessing.image_dataset_from_directory(**config["dataset_config"])
dataset = dataset.map(lambda x: x / 255).repeat()  # scale pixels to [0, 1]; repeat so steps_per_epoch drives epoch length
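# Optional input-pipeline tweak (a sketch, not required for correctness): shuffling and
# prefetching usually improve sample diversity and throughput.
# dataset = dataset.shuffle(1_000).prefetch(tf.data.AUTOTUNE)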
vae = custom_models.VAE(**config["model_config"])
vae.compile(**config["optimizer_config"])
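# KL annealing (a hedged sketch, responding to the note above): ramp the KL weight from
# 0 to the configured value over the first few epochs. This assumes custom_models.VAE
# exposes a writable `kl_beta` attribute that its train_step reads each batch; if
# train_step is compiled, kl_beta may need to be a tf.Variable updated via .assign().
# To enable, append an instance to the callbacks list in the fit() call below.
class KLAnnealing(keras.callbacks.Callback):
    def __init__(self, target_beta=1.0, anneal_epochs=10):
        super().__init__()
        self.target_beta = target_beta
        self.anneal_epochs = anneal_epochs

    def on_epoch_begin(self, epoch, logs=None):
        # Linear ramp, then hold at the target value.
        frac = min(1.0, (epoch + 1) / self.anneal_epochs)
        self.model.kl_beta = self.target_beta * frac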
# Save the model after every epoch; vae.save_model is assumed to accept the
# (epoch, logs) arguments that LambdaCallback passes to on_epoch_end.
vae.fit(dataset, epochs=1000, steps_per_epoch=1000,
        callbacks=[keras.callbacks.LambdaCallback(on_epoch_end=vae.save_model)])
vae.encoder.save_weights("vae_encoder.weights.h5")
vae.decoder.save_weights("vae_decoder.weights.h5")
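# Restoring the weights later (a hedged sketch): rebuild the model from the same config,
# run one dummy batch through it so the sub-models get built, then load the weights.
# The dummy-call step assumes custom_models.VAE can be called directly on an image batch.
# vae_restored = custom_models.VAE(**config["model_config"])
# vae_restored(keras.ops.zeros((1, 64, 64, 3)))
# vae_restored.encoder.load_weights("vae_encoder.weights.h5")
# vae_restored.decoder.load_weights("vae_decoder.weights.h5")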