forked from tensorflow/models

Commit 860b4e4 (1 parent: b54b536): 166 changed files with 2,804 additions and 0 deletions.
@@ -0,0 +1,90 @@
# Stochastic Ensemble Value Expansion

*A hybrid model-based/model-free reinforcement learning algorithm for sample-efficient continuous control.*

This is the code repository accompanying the paper *Sample-Efficient Reinforcement Learning with
Stochastic Ensemble Value Expansion*, by Buckman et al. (2018).

#### Abstract:
Merging model-free and model-based approaches in reinforcement learning has the potential to achieve
the high performance of model-free algorithms with low sample complexity. This is difficult because
an imperfect dynamics model can degrade the performance of the learning algorithm, and in sufficiently
complex environments, the dynamics model will always be imperfect. As a result, a key challenge is to
combine model-based approaches with model-free learning in such a way that errors in the model do not
degrade performance. We propose *stochastic ensemble value expansion* (STEVE), a novel model-based
technique that addresses this issue. By dynamically interpolating between model rollouts of various horizon
lengths for each individual example, STEVE ensures that the model is only utilized when doing so does not
introduce significant errors. Our approach outperforms model-free baselines on challenging continuous
control benchmarks with an order-of-magnitude increase in sample efficiency, and in contrast to previous
model-based approaches, performance does not degrade as the environment gets more complex.

## Installation
This code is compatible with Ubuntu 16.04 and Python 2.7. There are several prerequisites:
* NumPy, SciPy, and Portalocker: `pip install numpy scipy portalocker`
* TensorFlow 1.6 or above. Instructions can be found on the official TensorFlow page:
[https://www.tensorflow.org/install/install_linux](https://www.tensorflow.org/install/install_linux).
We suggest installing the GPU version of TensorFlow to speed up training.
* OpenAI Gym version 0.9.4. Instructions can be found in the OpenAI Gym repository:
[https://github.com/openai/gym#installation](https://github.com/openai/gym#installation).
Note that you need to replace `pip install gym[all]` with `pip install gym[all]==0.9.4`, which
ensures that you get the correct version of Gym. (The current version of Gym has deprecated
the -v1 MuJoCo environments, which are the environments studied in this paper.)
* MuJoCo version 1.31, which can be downloaded here: [https://www.roboti.us/download/mjpro131_linux.zip](https://www.roboti.us/download/mjpro131_linux.zip).
Simply run:
```
cd ~; mkdir -p .mujoco; cd .mujoco/; wget https://www.roboti.us/download/mjpro131_linux.zip; unzip mjpro131_linux.zip
```
You also need to get a license, and put the license key in `~/.mujoco/` as well.
* Optionally, Roboschool version 1.1. This is needed only to replicate the Roboschool experiments.
Instructions can be found in the OpenAI Roboschool repository:
[https://github.com/openai/roboschool#installation](https://github.com/openai/roboschool#installation).
* Optionally, MoviePy to render trained agents. Instructions are on the MoviePy homepage:
[https://zulko.github.io/moviepy/install.html](https://zulko.github.io/moviepy/install.html).
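
To verify that Gym and MuJoCo are installed correctly, a quick sanity check along these lines should run without errors. This is a minimal sketch assuming Gym 0.9.4's `(obs, reward, done, info)` step API; HalfCheetah-v1 is just an example environment:
```
import gym

# Any of the -v1 MuJoCo environments studied in the paper should work here.
env = gym.make("HalfCheetah-v1")
obs = env.reset()
for _ in range(10):
    # Sample a random action and take one environment step.
    obs, reward, done, info = env.step(env.action_space.sample())
    if done:
        obs = env.reset()
print("Gym/MuJoCo installation looks OK.")
```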

## Running Experiments
To run an experiment, run `master.py` and pass in a config file and GPU ID. For example:
```
python master.py config/experiments/speedruns/humanoid/speedy_steve0.json 0
```
The `config/experiments/` directory contains configuration files for all of the experiments run in the paper.

The GPU ID specifies the GPU that should be used to learn the policy. For model-based approaches, the
next GPU (i.e. GPU_ID+1) is used to learn the worldmodel in parallel.

To resume an experiment that was interrupted, use the same config file and pass the `--resume` flag:
```
python master.py config/experiments/speedruns/humanoid/speedy_steve0.json 0 --resume
```
## Output
For each experiment, two folders are created in the output directory: `<ENVIRONMENT>/<EXPERIMENT>/log`
and `<ENVIRONMENT>/<EXPERIMENT>/checkpoints`. The log directory contains the following:

* `hps.json` contains the accumulated hyperparameters of the config file used to generate these results.
* `valuerl.log` and `worldmodel.log` contain the log output of the learners. `worldmodel.log` will not
exist if you are not learning a worldmodel.
* `<EXPERIMENT>.greedy.csv` records all of the scores of our evaluators. The four columns contain time (hours),
epochs, frames, and score; the snippet below shows one way to load it.
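
Assuming NumPy and the column order above, a minimal sketch for loading a `.greedy.csv` file and reporting the best evaluation score (the path is a placeholder):
```
import numpy as np

# Columns: time (hours), epochs, frames, score. Placeholder path.
data = np.loadtxt("output/humanoid/speedy_steve0/log/speedy_steve0.greedy.csv", delimiter=",")
hours, epochs, frames, scores = data.T
best = scores.argmax()
print("best score %.2f after %d frames (%.1f hours)" % (scores[best], frames[best], hours[best]))
```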

The checkpoints directory contains the most recent versions of the policy and worldmodel, as well as checkpoints
of the policy, worldmodel, and their respective replay buffers at various points throughout training.

## Code Organization
`master.py` launches four types of processes: a ValueRlLearner to learn the policy, a WorldmodelLearner
to learn the dynamics model, several Interactors to gather data from the environment to train on, and
a few Evaluators to run the greedy policy in the environment and record the score.

`learner.py` contains a general framework for models which learn from a replay buffer. This is where
most of the code for the overall training loop is located. `valuerl_learner.py` and `worldmodel_learner.py`
contain a small amount of model-specific training-loop code.

`valuerl.py` implements the core model for all of the value-function-based policy learning techniques studied
in the paper, including DDPG, MVE, and STEVE. Similarly, `worldmodel.py` contains the core model for
our dynamics model and reward function. A toy sketch of the STEVE reweighting idea follows below.
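
For intuition only (this is not the code from `valuerl.py`): STEVE forms candidate value targets from rollouts of several horizon lengths, then weights each candidate by the inverse of its ensemble variance, so horizons on which the ensemble disagrees contribute less. A minimal NumPy sketch of that reweighting:
```
import numpy as np

def steve_target(candidates):
    """Combine candidate targets of shape [num_horizons, ensemble_size].

    Each row is one rollout horizon's target as estimated by every member
    of the ensemble. Toy sketch of inverse-variance reweighting only.
    """
    means = candidates.mean(axis=1)                   # per-horizon mean target
    weights = 1.0 / (candidates.var(axis=1) + 1e-8)   # inverse ensemble variance
    weights /= weights.sum()                          # normalize the weights
    return (weights * means).sum()                    # variance-weighted target

# Horizons with low disagreement (rows 0 and 1) dominate the combined target.
rng = np.random.RandomState(0)
candidates = rng.normal([[1.0], [1.1], [2.0], [0.5]], [[0.01], [0.02], [1.0], [2.0]], size=(4, 3))
print(steve_target(candidates))
```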

`replay.py` contains the code for the replay buffer. `nn.py`, `envwrap.py`, `config.py`, and `util.py`
each contain various helper functions.

`toy_demo.py` is a self-contained demo, written in NumPy, that was used to generate the results for the
toy examples in the first segment of the paper.

`visualizer.py` is a utility script for loading trained policies and inspecting them. In addition to a
config file and a GPU ID, it takes the filename of the model to load as a mandatory third argument.

## Contact
Please contact GitHub user buckman-google ([email protected]) with any questions.
@@ -0,0 +1,143 @@
from __future__ import print_function
from builtins import zip
from builtins import range
from builtins import object
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import numpy as np
import tensorflow as tf
import time, os, traceback, multiprocessing, portalocker

import envwrap
import valuerl
import util
from config import config


def run_env(pipe):
    """Run one environment instance, communicating over a pipe.

    Protocol: after a reset, send the initial observation; then repeatedly
    receive an action, step the environment, and send back the transition.
    """
    env = envwrap.get_env(config["env"]["name"])
    reset = True
    while True:
        if reset: pipe.send(env.reset())
        action = pipe.recv()
        obs, reward, done, reset = env.step(action)
        pipe.send((obs, reward, done, reset))

class AgentManager(object):
    """
    Interact with the environment according to the learned policy.
    """
    def __init__(self, proc_num, evaluation, policy_lock, batch_size, config):
        self.evaluation = evaluation
        self.policy_lock = policy_lock
        self.batch_size = batch_size
        self.config = config

        self.log_path = util.create_directory("%s/%s/%s/%s" % (config["output_root"], config["env"]["name"], config["name"], config["log_path"])) + "/%s" % config["name"]
        self.load_path = util.create_directory("%s/%s/%s/%s" % (config["output_root"], config["env"]["name"], config["name"], config["save_model_path"]))

        ## placeholders for intermediate states (basis for rollout)
        self.obs_loader = tf.placeholder(tf.float32, [self.batch_size, np.prod(self.config["env"]["obs_dims"])])

        ## build model
        self.valuerl = valuerl.ValueRL(self.config["name"], self.config["env"], self.config["policy_config"])
        self.policy_actions = self.valuerl.build_evalution_graph(self.obs_loader, mode="exploit" if self.evaluation else "explore")

        ## interactors: one child process (and pipe) per environment copy
        self.agent_pipes, self.agent_child_pipes = list(zip(*[multiprocessing.Pipe() for _ in range(self.batch_size)]))
        self.agents = [multiprocessing.Process(target=run_env, args=(self.agent_child_pipes[i],)) for i in range(self.batch_size)]
        for agent in self.agents: agent.start()
        self.obs = [pipe.recv() for pipe in self.agent_pipes]
        self.total_rewards = [0. for _ in self.agent_pipes]
        self.loaded_policy = False

        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())

        self.rollout_i = 0
        self.proc_num = proc_num
        self.epoch = -1
        self.frame_total = 0
        self.hours = 0.

        self.first = True

    def get_action(self, obs):
        # Until a policy has been loaded from a checkpoint, fall back to random actions.
        if self.loaded_policy:
            all_actions = self.sess.run(self.policy_actions, feed_dict={self.obs_loader: obs})
            all_actions = np.clip(all_actions, -1., 1.)
            return all_actions[:self.batch_size]
        else:
            return [self.get_random_action() for _ in range(obs.shape[0])]

    def get_random_action(self, *args, **kwargs):
        # Uniform random action in [-1, 1]^action_dim.
        return np.random.random(self.config["env"]["action_dim"]) * 2 - 1

    def step(self):
        actions = self.get_action(np.stack(self.obs))
        self.first = False
        for pipe, action in zip(self.agent_pipes, actions): pipe.send(action)
        next_obs, rewards, dones, resets = list(zip(*[pipe.recv() for pipe in self.agent_pipes]))

        frames = list(zip(self.obs, next_obs, actions, rewards, dones))

        # When an environment resets, its pipe sends a fresh initial observation.
        self.obs = [o if not resets[i] else self.agent_pipes[i].recv() for i, o in enumerate(next_obs)]

        for i, (t, r, reset) in enumerate(zip(self.total_rewards, rewards, resets)):
            if reset:
                self.total_rewards[i] = 0.
                # Evaluators append each completed episode's score to the greedy CSV;
                # the file lock guards against interleaved writes from concurrent processes.
                if self.evaluation and self.loaded_policy:
                    with portalocker.Lock(self.log_path + '.greedy.csv', mode="a") as f:
                        f.write("%2f,%d,%d,%2f\n" % (self.hours, self.epoch, self.frame_total, t + r))
            else:
                self.total_rewards[i] = t + r

        # Evaluators reload the latest policy checkpoint at every episode boundary.
        if self.evaluation and np.any(resets): self.reload()

        self.rollout_i += 1
        return frames

    def reload(self):
        # No-op until the learner has written at least one checkpoint.
        if not os.path.exists("%s/%s.params.index" % (self.load_path, self.valuerl.saveid)): return False
        with self.policy_lock:
            self.valuerl.load(self.sess, self.load_path)
            self.epoch, self.frame_total, self.hours = self.sess.run([self.valuerl.epoch_n, self.valuerl.frame_n, self.valuerl.hours])
        self.loaded_policy = True
        self.first = True
        return True

def main(proc_num, evaluation, policy_replay_frame_queue, model_replay_frame_queue, policy_lock, config):
    try:
        # Seed each process differently so interactors do not collect identical episodes.
        np.random.seed((proc_num * int(time.time())) % (2 ** 32 - 1))
        agentmanager = AgentManager(proc_num, evaluation, policy_lock, config["evaluator_config"]["batch_size"] if evaluation else config["agent_config"]["batch_size"], config)
        frame_i = 0
        while True:
            new_frames = agentmanager.step()
            if not evaluation:
                # Interactors feed fresh transitions to the policy (and, if present, worldmodel) replay buffers.
                policy_replay_frame_queue.put(new_frames)
                if model_replay_frame_queue is not None: model_replay_frame_queue.put(new_frames)
                if frame_i % config["agent_config"]["reload_every_n"] == 0: agentmanager.reload()
            frame_i += len(new_frames)

    except Exception as e:
        print('Caught exception in agent process %d' % proc_num)
        traceback.print_exc()
        print()
        try:
            for i in agentmanager.agents: i.join()
        except:
            pass
        raise e
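
master.py itself is not part of this excerpt. Purely as an illustration (the process counts, queue capacity, and the `agent` module name are assumptions, not the repository's actual values), the `main()` entry point above could be wired up like this:
```
import multiprocessing

import agent                # assumes the file above is saved as agent.py
from config import config   # parses the config path, root GPU, and --resume flag

if __name__ == "__main__":
    policy_queue = multiprocessing.Queue(100)   # capacity is an arbitrary choice
    model_queue = multiprocessing.Queue(100)
    policy_lock = multiprocessing.Lock()

    # Two interactors (evaluation=False) feed the replay queues; one
    # evaluator (evaluation=True) only scores the current greedy policy.
    procs = [multiprocessing.Process(
                 target=agent.main,
                 args=(i, False, policy_queue, model_queue, policy_lock, config))
             for i in range(2)]
    procs.append(multiprocessing.Process(
                 target=agent.main,
                 args=(2, True, None, None, policy_lock, config)))
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```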
@@ -0,0 +1,38 @@
from __future__ import print_function
from builtins import str
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import argparse, json, util, traceback

parser = argparse.ArgumentParser()
parser.add_argument("config")
parser.add_argument("root_gpu", type=int)
parser.add_argument("--resume", action="store_true")
args = parser.parse_args()

config_loc = args.config
config = util.ConfigDict(config_loc)

# The experiment name is the config filename with its ".json" extension stripped.
config["name"] = config_loc.split("/")[-1][:-5]
config["resume"] = args.resume

cstr = str(config)

def log_config():
    # Record the fully-resolved hyperparameters alongside the experiment's logs.
    HPS_PATH = util.create_directory("output/" + config["env"]["name"] + "/" + config["name"] + "/" + config["log_path"]) + "/hps.json"
    print("ROOT GPU: " + str(args.root_gpu) + "\n" + str(cstr))
    with open(HPS_PATH, "w") as f:
        f.write("ROOT GPU: " + str(args.root_gpu) + "\n" + str(cstr))
@@ -0,0 +1,3 @@
{
  "inherits": ["config/core/basic.json"]
}
@@ -0,0 +1,14 @@
{
  "inherits": [
    "config/core/basic.json",
    "config/core/model.json"
  ],
  "updates": {
    "policy_config": {
      "value_expansion": {
        "rollout_len": 3,
        "mean_k_return": true
      }
    }
  }
}
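
These experiment configs layer an "updates" block on top of one or more inherited base files. `util.ConfigDict` itself is not shown in this commit; purely as a sketch of how such inheritance could be resolved, assuming a recursive deep-merge (a hypothetical helper, not the repository's implementation):
```
import json

def load_config(path):
    # Hypothetical resolver: merge inherited files in order, then overlay
    # this file's "updates" block and any remaining top-level keys.
    with open(path) as f:
        raw = json.load(f)
    merged = {}
    for parent in raw.pop("inherits", []):
        deep_merge(merged, load_config(parent))
    deep_merge(merged, raw.pop("updates", {}))
    deep_merge(merged, raw)
    return merged

def deep_merge(base, overrides):
    # Recursively overlay nested dicts; scalars and lists are replaced.
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(base.get(key), dict):
            deep_merge(base[key], value)
        else:
            base[key] = value
    return base
```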
@@ -0,0 +1,14 @@
{
  "inherits": [
    "config/core/basic.json",
    "config/core/model.json"
  ],
  "updates": {
    "policy_config": {
      "value_expansion": {
        "rollout_len": 3,
        "tdk_trick": true
      }
    }
  }
}
@@ -0,0 +1,14 @@
{
  "inherits": [
    "config/core/basic.json",
    "config/core/model.json"
  ],
  "updates": {
    "policy_config": {
      "value_expansion": {
        "rollout_len": 3,
        "lambda_return": 0.25
      }
    }
  }
}
@@ -0,0 +1,15 @@
{
  "inherits": [
    "config/core/basic.json",
    "config/core/model.json",
    "config/core/bayesian.json"
  ],
  "updates": {
    "policy_config": {
      "value_expansion": {
        "rollout_len": 3,
        "steve_reweight": true
      }
    }
  }
}
@@ -0,0 +1,16 @@
{
  "inherits": [
    "config/core/basic.json",
    "config/core/model.json",
    "config/core/bayesian.json"
  ],
  "updates": {
    "policy_config": {
      "value_expansion": {
        "rollout_len": 3,
        "steve_reweight": true,
        "covariances": true
      }
    }
  }
}