forked from PWhiddy/PokemonRedExperiments
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrender_all_needed_grids.py
77 lines (68 loc) · 2.98 KB
/
render_all_needed_grids.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from os.path import exists
from pathlib import Path
import sys
import uuid
from red_gym_env import RedGymEnv
from stable_baselines3 import A2C, PPO
from stable_baselines3.common import env_checker
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.callbacks import CheckpointCallback
from argparse_pokemon import *
def make_env(rank, env_conf, seed=0):
    """
    Build a zero-argument factory that creates one seeded RedGymEnv.

    The environment itself is not constructed here; construction happens
    when the returned callable is invoked.
    :param rank: (int) index of the subprocess, added to ``seed`` for this env
    :param env_conf: configuration passed unchanged to ``RedGymEnv``
    :param seed: (int) the initial seed for RNG
    :return: callable taking no arguments and returning the seeded env
    """
    # Runs once, right away, in the caller -- not inside the factory.
    set_random_seed(seed)

    def _build_env():
        red_env = RedGymEnv(env_conf)
        # Offset the base seed by rank so each env is seeded differently.
        red_env.seed(seed + rank)
        return red_env

    return _build_env
def run_save(save, num_cpu=40):
    """
    Run a single rollout/learn pass for the checkpoint at ``save``.

    Session output is written under ``grid_renders/session_<stem of save>``.
    If ``save`` exists on disk it is loaded as a PPO checkpoint with the
    learning rate and clip range overridden to zero; otherwise a new
    ``CnnPolicy`` PPO model is created. Either way ``learn`` is then called
    for ``ep_length * num_cpu`` timesteps, and the vectorized environment is
    closed afterwards, even if an error occurs.

    :param save: (str or Path) path of the checkpoint to load
    :param num_cpu: (int) number of environment subprocesses; also sets the
        number of episodes per training iteration
    """
    save = Path(save)
    ep_length = 2048 * 8
    sess_path = f'grid_renders/session_{save.stem}'
    # get_args / change_env come from argparse_pokemon; presumably they let
    # command-line options override the defaults below -- TODO confirm.
    args = get_args(usage_string="render_all_needed_grids.py save", ep_length=ep_length, sess_path=sess_path)

    env_config = {
        'headless': True, 'save_final_state': True, 'early_stop': False,
        'action_freq': 24, 'init_state': '../has_pokedex_nballs.state', 'max_steps': ep_length,
        'print_rewards': True, 'save_video': True, 'fast_video': False, 'session_path': sess_path,
        'gb_path': '../PokemonRed.gb', 'debug': False, 'sim_frame_dist': 2_000_000.0
    }
    env_config = change_env(env_config, args)

    env = SubprocVecEnv([make_env(i, env_config) for i in range(num_cpu)])
    try:
        checkpoint_callback = CheckpointCallback(save_freq=ep_length, save_path=sess_path,
                                                 name_prefix='poke')

        if exists(save):
            print('\nloading checkpoint')
            # Override the stored hyperparameters at load time: zero learning
            # rate and clip range, and a rollout length of one full episode.
            custom_objects = {
                "learning_rate": 0.0,
                "lr_schedule": lambda _: 0.0,
                "clip_range": lambda _: 0.0,
                "n_steps": ep_length
            }
            model = PPO.load(save, env=env, custom_objects=custom_objects)
            # Re-size the loaded model and its rollout buffer to this run's
            # episode length and env count, then reset the buffer.
            model.n_steps = ep_length
            model.n_envs = num_cpu
            model.rollout_buffer.buffer_size = ep_length
            model.rollout_buffer.n_envs = num_cpu
            model.rollout_buffer.reset()
        else:
            print('initializing new policy')
            model = PPO('CnnPolicy', env, verbose=1, n_steps=ep_length, batch_size=512, n_epochs=1, gamma=0.999)

        # One rollout of ep_length steps from each of the num_cpu envs.
        model.learn(total_timesteps=(ep_length)*num_cpu, callback=checkpoint_callback)
    finally:
        # Shut the worker subprocesses down so repeated calls do not leak them.
        env.close()
if __name__ == '__main__':
    # The checkpoint path is the first command-line argument; exit with a
    # usage message instead of an IndexError traceback when it is missing.
    if len(sys.argv) < 2:
        sys.exit('usage: render_all_needed_grids.py save')
    run_save(sys.argv[1])
# all_saves = list(Path('session_4da05e87').glob('*.zip'))
# selected_saves = [Path('session_4da05e87/init')] + all_saves[:10] + all_saves[10:120:5] + all_saves[120:420:10]
# len(selected_saves)
# for idx, save in enumerate(selected_saves):