-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathmain_kth_rladder.py
68 lines (54 loc) · 1.85 KB
/
main_kth_rladder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from model.kth.model_rladder import *
from analysis.quantitative import print_results
from analysis.qualitative import save_sequences
from analysis.layer_removal import test_layer_subsets
import tensorflow as tf
import cPickle
def train(resume_training=False):
    """Train the future frame predictor, optionally continuing a saved run.

    Args:
        resume_training: When truthy, MODEL is first restored from SAVE_PATH
            before training starts; otherwise training starts from MODEL's
            current state.
    """
    if resume_training:
        # Pick up from the weights previously written to SAVE_PATH.
        MODEL.load(SAVE_PATH)

    # Options forwarded verbatim to MODEL.train; save_frequency is fixed at 100.
    train_options = {
        'x': TRAIN_PREPROCESSOR,
        'batch_size': BATCH_SIZE,
        'iterations': TRAIN_STEPS,
        'save_path': SAVE_PATH,
        'save_frequency': 100,
    }
    MODEL.train(**train_options)
def test():
    """Evaluate the saved model and persist/print its error metrics.

    Restores MODEL from SAVE_PATH, calls MODEL.test on TEST_PREPROCESSOR with
    the 'mse', 'psnr' and 'dssim' metrics, pickles the resulting
    (errs_pred, errs_base) pair to SAVE_ROOT + 'predictions.pkl', and prints
    both sets of errors through print_results.
    """
    # Load model from disk and perform testing
    MODEL.load(SAVE_PATH)
    errs_pred, errs_base = MODEL.test(
        x=TEST_PREPROCESSOR,
        batch_size=BATCH_SIZE,
        metric=('mse', 'psnr', 'dssim'),
    )

    # Save baseline/prediction errors. The with-block closes (and therefore
    # flushes) the file deterministically, so analyse() can read it back
    # immediately afterwards.
    with open(SAVE_ROOT + 'predictions.pkl', 'wb') as results_file:
        cPickle.dump((errs_pred, errs_base), results_file, cPickle.HIGHEST_PROTOCOL)

    # Print baseline/prediction errors
    print_results('Baseline errors', errs_base)
    print_results('Prediction errors', errs_pred)
def run():
    """Restore the saved model and run it on the test data with plotting on.

    MODEL is loaded from SAVE_PATH and MODEL.run is called on
    TEST_PREPROCESSOR with plot=True. The value returned by MODEL.run is not
    used by this script, so it is intentionally discarded.
    """
    # Load from disk and start predicting
    MODEL.load(SAVE_PATH)
    MODEL.run(
        x=TEST_PREPROCESSOR,
        batch_size=BATCH_SIZE,
        plot=True,
    )
def analyse():
    """Re-print the stored test errors and run the qualitative analyses.

    Reads the (errs_pred, errs_base) pair from SAVE_ROOT + 'predictions.pkl',
    prints both via print_results, then restores MODEL from SAVE_PATH and
    calls save_sequences (output under SAVE_ROOT + 'qual/') followed by
    test_layer_subsets (output under SAVE_ROOT + 'lremoval/') on the indices
    returned by save_sequences.
    """
    # Load results from disk; the with-block releases the file handle as soon
    # as unpickling is done. NOTE: the pickle is one this script wrote itself
    # in test() -- do not point this at untrusted files.
    with open(SAVE_ROOT + 'predictions.pkl', 'rb') as results_file:
        errs_pred, errs_base = cPickle.load(results_file)

    # Print results
    print_results('Baseline errors', errs_base)
    print_results('Prediction errors', errs_pred)

    # Plot best predictions, perform layer removal
    MODEL.load(SAVE_PATH)
    # Only the first of the three values returned by save_sequences is used.
    indices, _, _ = save_sequences(MODEL, TEST_PREPROCESSOR, SAVE_ROOT + 'qual/')
    test_layer_subsets(MODEL, TEST_PREPROCESSOR, indices, SAVE_ROOT + 'lremoval/')
# Script entry point: runs the full pipeline in order. test() writes the
# predictions pickle that analyse() reads back, so the order matters.
if __name__ == '__main__':
    # DEVICE is not defined in this file; presumably it comes from the
    # wildcard import of model.kth.model_rladder -- TODO confirm.
    with tf.device(DEVICE):
        # Start a fresh training run; pass resume_training=True to load the
        # model from SAVE_PATH first.
        train(resume_training=False)
        test()
        analyse()
        # Uncomment to also run the prediction pass with plotting enabled.
        # run()