forked from Kyubyong/cross_vc
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain1.py
93 lines (72 loc) · 2.89 KB
/
train1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# -*- coding: utf-8 -*-
# /usr/bin/python2
'''
By kyubyong park. [email protected].
https://www.github.com/kyubyong/cross_vc
'''
from __future__ import print_function
from hparams import Hyperparams as hp
from tqdm import tqdm
from graph import Graph
import tensorflow as tf
from data_load import load_data
import numpy as np
def train1():
    """Train the net1 phoneme-recognition graph.

    Builds the training Graph, restores the latest checkpoint from
    <hp.logdir>/train1 if one exists, then loops over epochs: run the
    train op for every batch, write summaries, run eval1 in a fresh
    graph, and checkpoint. Stops once the global step exceeds 10000.
    """
    g = Graph()
    print("Training Graph loaded")

    logdir = hp.logdir + "/train1"

    # Session
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Restore saved variables: net1 weights plus the optimizer/step
        # bookkeeping variables collected under the 'training' scope.
        var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'net1') + \
                   tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 'training')
        saver = tf.train.Saver(var_list=var_list)
        ckpt = tf.train.latest_checkpoint(logdir)
        if ckpt is not None:
            saver.restore(sess, ckpt)

        # Writer & Queue
        writer = tf.summary.FileWriter(logdir, sess.graph)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        # Inspect variables: save an initial checkpoint and dump its tensor
        # names (debug aid; local import keeps the tools module optional).
        saver.save(sess, logdir + '/model_gs_0')
        from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
        # Reuse `logdir` instead of re-deriving hp.logdir + '/train1'.
        print_tensors_in_checkpoint_file(file_name=logdir + '/model_gs_0',
                                         tensor_name='', all_tensors=False)

        # Training
        while True:
            for _ in tqdm(range(g.num_batch), total=g.num_batch, ncols=70, leave=False, unit='b'):
                gs, _ = sess.run([g.global_step, g.train_op])

            # Write checkpoint files at every epoch
            merged = sess.run(g.merged)
            writer.add_summary(merged, global_step=gs)

            # evaluation (fresh graph so eval ops don't pollute the training graph)
            with tf.Graph().as_default():
                eval1()

            # Save
            saver.save(sess, logdir + '/model_gs_{}'.format(gs))

            if gs > 10000:
                break

        writer.close()
        coord.request_stop()
        coord.join(threads)
def eval1():
    """Evaluate net1 on the eval split and log summaries.

    Loads eval MFCC/phoneme data, builds an evaluation Graph, restores
    the latest net1/training checkpoint from <hp.logdir>/train1, runs
    the merged summary op once, and writes it at the current global step.
    """
    # Load data
    mfccs, phns = load_data(mode="eval1")

    # Graph
    eval_graph = Graph("eval1")
    print("Evaluation Graph loaded")

    ckpt_dir = hp.logdir + "/train1"

    # Session
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Restore saved variables (net1 weights + training bookkeeping).
        restore_vars = (tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'net1')
                        + tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 'training'))
        saver = tf.train.Saver(var_list=restore_vars)
        latest = tf.train.latest_checkpoint(ckpt_dir)
        if latest is not None:
            saver.restore(sess, latest)

        # Writer
        writer = tf.summary.FileWriter(ckpt_dir, sess.graph)

        # Evaluation: one pass over the loaded eval batch.
        feed = {eval_graph.mfccs: mfccs, eval_graph.phones: phns}
        merged, gs = sess.run([eval_graph.merged, eval_graph.global_step], feed)

        # Write summaries
        writer.add_summary(merged, global_step=gs)
        writer.close()
if __name__ == '__main__':
    # Script entry point: run stage-1 training to completion.
    train1()
    print("Done")