# 2. train_model.py (forked from Sentdex/pygta5)
import numpy as np
from grabscreen import grab_screen
import cv2
import time
import os
import pandas as pd
from tqdm import tqdm
from collections import deque
from models import inception_v3 as googlenet
from random import shuffle
# ---- Run configuration -------------------------------------------------
# Training files are visited as training_data-1.npy .. training_data-<FILE_I_END>.npy.
FILE_I_END = 1860

# Frame dimensions handed to the network and used to reshape samples.
WIDTH, HEIGHT = 480, 270
LR = 1e-3          # passed to googlenet (presumably the learning rate)
EPOCHS = 30        # number of full passes over all training files

MODEL_NAME = ''    # target of model.save and the fit run_id
PREV_MODEL = ''    # checkpoint passed to model.load when LOAD_MODEL is set
LOAD_MODEL = True

# Zero-initialised counters, one per label below (not read by the training
# loop in this file).
wl = sl = al = dl = 0
wal = wdl = sal = sdl = nkl = 0

# One-hot target vectors for the nine output classes, in index order:
# w, s, a, d, wa, wd, sa, sd, nk. Each name gets its own independent list.
w, s, a, d, wa, wd, sa, sd, nk = (
    [1 if column == row else 0 for column in range(9)] for row in range(9)
)
# Build the network with the project's inception_v3 factory (imported as
# googlenet). The third argument, 3, presumably matches the trailing 3 used
# when reshaping frames (channel/frame count) - confirm against models.py.
# output=9 lines up with the nine one-hot label vectors defined above.
model = googlenet(
    WIDTH,
    HEIGHT,
    3,
    LR,
    output=9,
    model_name=MODEL_NAME,
)

# Resume from a previously saved model when requested.
if LOAD_MODEL:
    model.load(PREV_MODEL)
    print('We have loaded a previous model!!!!')
# Samples held out from the end of each training file for validation.
VALIDATION_SIZE = 50
# Within an epoch, checkpoint after every SAVE_EVERY-th file (counting from 0).
SAVE_EVERY = 10


def _split_samples(train_data, validation_size=VALIDATION_SIZE):
    """Split one file's samples into network-ready training and validation sets.

    Each sample is read as ``sample[0]`` (the frame) and ``sample[1]`` (the
    target choice). The last ``validation_size`` samples form the validation
    set; everything before them forms the training set.

    Returns:
        ``(X, Y, test_x, test_y)`` in the form passed to ``model.fit``.
    """
    train = train_data[:-validation_size]
    test = train_data[-validation_size:]

    # NOTE(review): frames are reshaped to (WIDTH, HEIGHT, 3). If they were
    # saved as (HEIGHT, WIDTH, 3) this reinterprets the pixel buffer rather
    # than transposing it. Left unchanged because the network is built as
    # googlenet(WIDTH, HEIGHT, 3, ...) and presumably expects this layout
    # (as would any previously saved model) - confirm against models.py.
    X = np.array([sample[0] for sample in train]).reshape(-1, WIDTH, HEIGHT, 3)
    Y = [sample[1] for sample in train]

    test_x = np.array([sample[0] for sample in test]).reshape(-1, WIDTH, HEIGHT, 3)
    test_y = [sample[1] for sample in test]
    return X, Y, test_x, test_y


def _train_on_file(file_index):
    """Fit the model for one pass over training_data-<file_index>.npy.

    Any failure while loading or fitting is printed together with the file
    name and then swallowed, so one missing or malformed file does not abort
    the whole run.
    """
    file_name = 'J:/phase10-random-padded/training_data-{}.npy'.format(file_index)
    short_name = 'training_data-{}.npy'.format(file_index)
    try:
        # allow_pickle=True: NumPy >= 1.16.3 refuses to load object arrays
        # (which a list of [frame, choice] pairs presumably is) without it.
        # Unpickling can execute code, so only point this at trusted, locally
        # generated training data.
        train_data = np.load(file_name, allow_pickle=True)
        print(short_name, len(train_data))

        # With VALIDATION_SIZE or fewer samples the training split is empty.
        if len(train_data) <= VALIDATION_SIZE:
            print('Skipping {}: only {} samples'.format(short_name, len(train_data)))
            return

        X, Y, test_x, test_y = _split_samples(train_data)
        model.fit({'input': X}, {'targets': Y}, n_epoch=1,
                  validation_set=({'input': test_x}, {'targets': test_y}),
                  snapshot_step=2500, show_metric=True, run_id=MODEL_NAME)
    except Exception as err:
        # Deliberate best effort: report and move on to the next file.
        print('{}: {}'.format(short_name, err))


def _save_model():
    """Save the current model to MODEL_NAME, reporting (not raising) failures."""
    print('SAVING MODEL!')
    try:
        model.save(MODEL_NAME)
    except Exception as err:
        print('Could not save model: {}'.format(err))


# Iterate through the training files: every epoch visits each file once, in a
# freshly shuffled order. The loop variable is named `epoch` (not `e`) so that
# no `except ... as` clause can unbind it.
for epoch in range(EPOCHS):
    print('Epoch {}/{}'.format(epoch + 1, EPOCHS))
    data_order = list(range(1, FILE_I_END + 1))
    shuffle(data_order)

    for count, file_index in enumerate(data_order):
        _train_on_file(file_index)
        # Checkpoint periodically, even if this particular file failed: the
        # weights then simply reflect the last successful fit.
        if count % SAVE_EVERY == 0:
            _save_model()

    # The periodic save skips files trained after the last multiple of
    # SAVE_EVERY; save again so that work is kept (on the final epoch nothing
    # else would save it).
    _save_model()

# To monitor training: tensorboard --logdir=foo:J:/phase10-code/log