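"""Inference script for a trained PPO locomotion policy (inf.py).

Loads an actor checkpoint (as produced by the accompanying training script)
and rolls it out in one of the supported environments (rex, ant, halfcheetah,
spot), cycling through four discrete control commands while rendering.
"""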
import time
import sys
import argparse

import numpy as np
import torch

sys.path.append("..")  # make the parent directory (rex_REDQ package) importable
from rex_REDQ.train import ModAntEnv, ModRexEnv
from spotmicro.spotmicro.spot_gym_env import spotGymEnv
from train import ModHalfCheetahEnv, Agent_net

DEV = "cpu"  # inference is lightweight enough to run on CPU
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--comment",
        help="Optional comment used to differentiate experiments; it is appended to the checkpoint directory name so the matching model can be loaded",
        default=""
    )
    parser.add_argument("-r", "--repeat_steps", help="Number of times to repeat each action", default=1, type=int)
    parser.add_argument("-d", "--hidden_dim", help="Hidden dim for both actor and critic", default=64, type=int)
    parser.add_argument("-e", "--env", help="Which environment to use.",
                        choices=["rex", "ant", "halfcheetah", "spot"], required=True)
    # Read arguments from the command line
    args = parser.parse_args()
    exp_name = f"PPO_{args.env}"
    hidden_dim = args.hidden_dim
    repeat_steps = args.repeat_steps
    obs_shape = 11  # extra inputs appended to each observation: 10-way control one-hot + target height

    # Environment factory: each branch returns a callable that builds the chosen env.
    if args.env == "rex":
        env_fn = lambda render=False: ModRexEnv(hard_reset=False, render=render, terrain_id="plane")
    elif args.env == "ant":
        env_fn = lambda render=False: ModAntEnv(render)
    elif args.env == "halfcheetah":
        env_fn = lambda render=False: ModHalfCheetahEnv(render)
    elif args.env == "spot":
        env_fn = lambda render=False: spotGymEnv(hard_reset=False, render=render)
print("###############################################")
print(f"Experiment Name: {exp_name}")
print(f"Hidden dim: {hidden_dim}")
print(f"Using env : {args.env}")
print(f"Repeat steps : {args.repeat_steps}")
print("###############################################")
    env = env_fn(render=True)
    agent = Agent_net(
        env.observation_space.shape[0] + obs_shape,
        env.action_space.shape[0],
        hidden_dim,
        dev=DEV
    ).to(DEV)
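    # Checkpoints are expected at model_<env>[_<comment>]/model_<hidden_dim>.pt;
    # only the actor weights ('actor_net') are needed for inference.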
    load_path = f'model_{args.env}'
    if args.comment:
        load_path = f'{load_path}_{args.comment}'
    state_dict = torch.load(f'{load_path}/model_{hidden_dim}.pt', map_location=DEV)
    agent.load_state_dict(state_dict=state_dict['actor_net'])
    # env.rex._pybullet_client.resetBasePositionAndOrientation(
    #     env.rex.quadruped, [0.5, 0, 0.21], [0, 0, 0, 1])
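    # Roll out the deterministic policy (exploration noise disabled by zeroing
    # logstd) for each of the first four control commands, for up to 500 steps
    # per command, repeating each action repeat_steps times in the environment.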
    with torch.no_grad():
        agent.logstd.zero_()  # zero the log-std so sampled actions are deterministic
        for c in range(4):
            print(f"Control command: {c}")
            control_one_hot = np.zeros(10)
            control_one_hot[c] = 1.
            is_done = False
            s = env.reset()
            env.render()
            num_steps = 0
            while not is_done:
                # Append the command one-hot and the target height to the observation.
                s = np.append(s, [*control_one_hot, 0.22])
                a = agent(torch.tensor(s, dtype=torch.float32).unsqueeze(0).to(DEV), noise=False)
                a = a.squeeze().cpu().numpy()
                # a = np.random.randn(env.action_space.shape[0]) / 1
                # a[0::3] = -0.5
                # print(a)
                for i in range(args.repeat_steps):
                    if not is_done:
                        # The modified envs also take the command index and target height.
                        s, r, is_done, _ = env.step(a, c, 0.22)
                        # print(env.env.get_body_com("torso")[:3].copy())
                        env.render()
                num_steps += 1
                if num_steps > 500:
                    print("done 500 steps")
                    break
                time.sleep(0.01)
            print("resetting")
    input("Press any key to exit\n")
    env.close()
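# Example usage (assuming a checkpoint exists at model_rex/model_64.pt,
# e.g. produced by the training script):
#   python inf.py --env rex --hidden_dim 64 --repeat_steps 1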