test_deterministic.py (forked from DLR-RM/stable-baselines3)
import pytest

from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
from stable_baselines3.common.noise import NormalActionNoise

N_STEPS_TRAINING = 500
SEED = 0


@pytest.mark.parametrize("algo", [A2C, DQN, PPO, SAC, TD3])
def test_deterministic_training_common(algo):
    # Actions and rewards collected from two independent runs with the same seed
    results = [[], []]
    rewards = [[], []]
    # Smaller network
    kwargs = {"policy_kwargs": dict(net_arch=[64])}
    env_id = "Pendulum-v1"
    if algo in [TD3, SAC]:
        # Off-policy algorithms with continuous actions: add action noise, shorten warmup
        kwargs.update({"action_noise": NormalActionNoise(0.0, 0.1), "learning_starts": 100, "train_freq": 4})
    else:
        if algo == DQN:
            # DQN needs a discrete action space
            env_id = "CartPole-v1"
            kwargs.update({"learning_starts": 100, "target_update_interval": 100})
        elif algo == PPO:
            kwargs.update({"n_steps": 64, "n_epochs": 4})

    # Train twice with the same seed and roll out 100 steps each time
    for i in range(2):
        model = algo("MlpPolicy", env_id, seed=SEED, **kwargs)
        model.learn(N_STEPS_TRAINING)
        env = model.get_env()
        obs = env.reset()
        for _ in range(100):
            action, _ = model.predict(obs, deterministic=False)
            obs, reward, _, _ = env.step(action)
            results[i].append(action)
            rewards[i].append(reward)
    # Both runs must produce identical actions and rewards
    assert sum(results[0]) == sum(results[1]), results
    assert sum(rewards[0]) == sum(rewards[1]), rewards
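
A note on running this file: assuming pytest and the environments used above (Pendulum-v1, CartPole-v1) are installed, the check can be run on its own with something like:

    pytest test_deterministic.py -k test_deterministic_training_common

The -k filter is optional here, since the file contains only this one parametrized test.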