refactor(stochastic): move training loop to StochasticPlanner.run method
Signed-off-by: Thiago P. Bueno <[email protected]>
thiagopbueno committed Mar 28, 2020
1 parent e39a4ff commit a75e52a
Showing 4 changed files with 90 additions and 111 deletions.
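For orientation: this commit moves the per-timestep optimization loop (variable initialization, feeding the batched initial state, the tqdm epoch loop, and summary writing) out of the planners' __call__ methods and into a shared StochasticPlanner.run(state, timestep, feed_dict) method, and declares build() and __call__ as abstract on the base class. The skeleton below is a hypothetical subclass, not part of this commit, sketching the resulting contract.

from tfplan.planners.stochastic import StochasticPlanner


class MyPlanner(StochasticPlanner):  # hypothetical subclass, for illustration only
    def build(self):
        # Wire up the graph ops here. Note that _build_optimization_ops() now
        # reads self.loss instead of taking a loss argument, and run() expects
        # self.init_op, self.train_op, self.loss, self.avg_total_reward and
        # (optionally) self.summaries to have been set by build().
        pass

    def __call__(self, state, timestep):
        # Planner-specific placeholders only (noise samples, steps_to_go, ...);
        # run() adds the batched initial state itself.
        feed_dict = {}
        # Shared loop: runs init_op, then optimizes for config["epochs"] epochs
        # with a tqdm progress bar, optionally writing TensorBoard summaries.
        self.run(state, timestep, feed_dict)
        # StraightLinePlanner reads self.trajectory.actions here;
        # HindsightPlanner reads self.action.
        return self._get_action(self.trajectory.actions, feed_dict)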
13 changes: 4 additions & 9 deletions tests/test_planner_straightline.py
@@ -23,6 +23,7 @@
import rddlgym

from tfplan.planners import DEFAULT_CONFIG, StraightLinePlanner
from tfplan.planners.stochastic import utils


BATCH_SIZE = 16
@@ -106,14 +107,6 @@ def test_get_batch_initial_state(planner):
assert batch_fluent.shape[0] == planner.compiler.batch_size


def test_get_noise_samples(planner):
# pylint: disable=protected-access
with tf.Session(graph=planner.compiler.graph) as sess:
samples_ = planner._get_noise_samples(sess)
assert planner.simulator.noise.dtype == samples_.dtype
assert planner.simulator.noise.shape.as_list() == list(samples_.shape)


def test_get_action(planner):
# pylint: disable=protected-access
env = rddlgym.make(planner.rddl, mode=rddlgym.GYM)
@@ -122,7 +115,9 @@ def test_get_action(planner):
sess.run(tf.global_variables_initializer())
state = env.observation_space.sample()
batch_state = planner._get_batch_initial_state(state)
samples = planner._get_noise_samples(sess)
samples = utils.evaluate_noise_samples_as_inputs(
sess, planner.simulator.samples
)
feed_dict = {
planner.initial_state: batch_state,
planner.simulator.noise: samples,
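The test change mirrors the planner API change in this commit: the _get_noise_samples() wrapper is removed from StraightLinePlanner, so callers evaluate the reparameterization noise through the stochastic utils module directly. In short:

# Before this commit: thin wrapper on the planner
samples = planner._get_noise_samples(sess)

# After this commit: call the util directly
samples = utils.evaluate_noise_samples_as_inputs(sess, planner.simulator.samples)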
68 changes: 64 additions & 4 deletions tfplan/planners/stochastic/__init__.py
@@ -13,10 +13,15 @@
# You should have received a copy of the GNU General Public License
# along with tf-plan. If not, see <http://www.gnu.org/licenses/>.

from collections import OrderedDict
# pylint: disable=missing-docstring

import abc
import collections
import os

import numpy as np
import tensorflow as tf
from tqdm import trange

from rddl2tf.compilers import ReparameterizationCompiler

@@ -34,6 +39,8 @@ class StochasticPlanner(Planner):
config (Dict[str, Any]): The planner config dict.
"""

__metaclass__ = abc.ABCMeta

def __init__(self, rddl, compiler_cls, config):
super().__init__(rddl, ReparameterizationCompiler, config)

@@ -45,9 +52,23 @@ def __init__(self, rddl, compiler_cls, config):
self.optimizer = None
self.grads_and_vars = None

self.avg_total_reward = None
self.loss = None

self.init_op = None
self.train_op = None

self.summaries = None

@abc.abstractmethod
def build(self):
"""Builds the planner."""
raise NotImplementedError

@abc.abstractmethod
def __call__(self, state, timestep):
raise NotImplementedError

def _build_init_ops(self):
self.init_op = tf.global_variables_initializer()

@@ -64,11 +85,11 @@ def _build_sequence_length_ops(self):
tf.reshape(self.steps_to_go, [1]), [self.batch_size]
)

def _build_optimization_ops(self, loss):
def _build_optimization_ops(self):
with tf.name_scope("optimization"):
self.optimizer = ActionOptimizer(self.config["optimization"])
self.optimizer.build()
self.grads_and_vars = self.optimizer.compute_gradients(loss)
self.grads_and_vars = self.optimizer.compute_gradients(self.loss)
self.train_op = self.optimizer.apply_gradients(self.grads_and_vars)

def _get_batch_initial_state(self, state):
@@ -85,10 +106,49 @@ def _get_batch_initial_state(self, state):
def _get_action(self, actions, feed_dict):
action_fluent_ordering = self.compiler.rddl.domain.action_fluent_ordering
actions = self._sess.run(actions, feed_dict=feed_dict)
action = OrderedDict(
action = collections.OrderedDict(
{
name: fluent[0][0]
for name, fluent in zip(action_fluent_ordering, actions)
}
)
return action

def run(self, state, timestep, feed_dict):
self._sess.run(self.init_op)

feed_dict = {
**feed_dict,
self.initial_state: self._get_batch_initial_state(state),
}

if self.summaries:
logdir = os.path.join(self.config.get("logdir"), f"timestep={timestep}")
writer = tf.compat.v1.summary.FileWriter(logdir)

run_id = self.config.get("run_id", 0)
pid = os.getpid()
position = run_id % self.config.get("num_workers", 1)
epochs = self.config["epochs"]
desc = f"(pid={pid}) Run #{run_id:<3d} / step={timestep:<3d}"

with trange(
epochs, desc=desc, unit="epoch", position=position, leave=False
) as t:

for step in t:
_, loss_, avg_total_reward_ = self._sess.run(
[self.train_op, self.loss, self.avg_total_reward],
feed_dict=feed_dict,
)

if self.summaries:
summary_ = self._sess.run(self.summaries, feed_dict=feed_dict)
writer.add_summary(summary_, step)

t.set_postfix(
loss=f"{loss_:10.4f}", avg_total_reward=f"{avg_total_reward_:10.4f}"
)

if self.summaries:
writer.close()
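The new run() method and the build helpers above read several entries from self.config. The sketch below lists the keys they use; the values are illustrative assumptions, not defaults taken from the repository.

config = {
    "epochs": 300,             # optimization steps per call to run()
    "horizon": 40,             # planners use it to compute steps_to_go
    "logdir": "/tmp/tfplan",   # run() writes summaries under <logdir>/timestep=<t> when summaries are enabled
    "run_id": 0,               # labels the tqdm progress bar (defaults to 0)
    "num_workers": 1,          # run_id % num_workers picks the progress-bar row
    "verbose": False,          # planners add histogram summaries when True
    "optimization": {          # passed to ActionOptimizer; contents are assumed
        "optimizer": "GradientDescent",
        "learning_rate": 1e-3,
    },
}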
54 changes: 9 additions & 45 deletions tfplan/planners/stochastic/hindsight.py
@@ -16,10 +16,7 @@
# pylint: disable=missing-docstring


import os

import tensorflow as tf
from tqdm import trange

from rddl2tf.compilers import ReparameterizationCompiler

@@ -67,10 +64,6 @@ def __init__(self, rddl, config):
self.writer = None
self.summaries = None

@property
def logdir(self):
return self.config.get("logdir") or f"/tmp/tfplan/hindsight/{self.rddl}"

def build(self):
with self.graph.as_default():
self._build_base_policy_ops()
@@ -80,7 +73,7 @@ def build(self):
self._build_sequence_length_ops()
self._build_trajectory_ops()
self._build_loss_ops()
self._build_optimization_ops(self.loss)
self._build_optimization_ops()
self._build_summary_ops()
self._build_init_ops()

@@ -137,12 +130,13 @@ def _build_loss_ops(self):
self.loss = tf.square(self.avg_total_reward)

def _build_summary_ops(self):
with tf.name_scope("summary"):
_ = tf.compat.v1.summary.FileWriter(self.logdir, self.graph)
tf.compat.v1.summary.scalar("avg_total_reward", self.avg_total_reward)
tf.compat.v1.summary.scalar("loss", self.loss)
if self.config["verbose"]:

with tf.name_scope("summary"):
_ = tf.compat.v1.summary.FileWriter(self.config["logdir"], self.graph)
tf.compat.v1.summary.scalar("avg_total_reward", self.avg_total_reward)
tf.compat.v1.summary.scalar("loss", self.loss)

if self.config["verbose"]:
tf.compat.v1.summary.histogram("reward", self.reward)
tf.compat.v1.summary.histogram(
"scenario_total_reward", self.scenario_total_reward
@@ -156,16 +150,9 @@ def _build_summary_ops(self):
tf.compat.v1.summary.histogram(f"{var_name}_grad", grad)
tf.compat.v1.summary.histogram(var_name, variable)

self.summaries = tf.compat.v1.summary.merge_all()
self.summaries = tf.compat.v1.summary.merge_all()

def __call__(self, state, timestep):
# pylint: disable=too-many-locals

logdir = os.path.join(self.logdir, f"timestep={timestep}")
self.writer = tf.compat.v1.summary.FileWriter(logdir)

self._sess.run(self.init_op)

next_state_noise = utils.evaluate_noise_samples_as_inputs(
self._sess, self.cell_samples
)
@@ -174,35 +161,12 @@ def __call__(self, state, timestep):
)

feed_dict = {
self.initial_state: self._get_batch_initial_state(state),
self.cell_noise: next_state_noise,
self.simulator.noise: scenario_noise,
self.steps_to_go: self.config["horizon"] - timestep - 1,
}

run_id = self.config.get("run_id", 0)
pid = os.getpid()
position = run_id % self.config.get("num_workers", 1)
epochs = self.config["epochs"]
desc = f"(pid={pid}) Run #{run_id:<3d} / step={timestep:<3d}"

with trange(
epochs, desc=desc, unit="epoch", position=position, leave=False
) as t:

for step in t:
_, loss_, avg_total_reward_, summary_ = self._sess.run(
[self.train_op, self.loss, self.avg_total_reward, self.summaries],
feed_dict=feed_dict,
)

self.writer.add_summary(summary_, step)

t.set_postfix(
loss=f"{loss_:10.4f}", avg_total_reward=f"{avg_total_reward_:10.4f}"
)

self.writer.close()
self.run(state, timestep, feed_dict)

action = self._get_action(self.action, feed_dict)
return action
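Since the rendered diff interleaves removed and added lines, HindsightPlanner.__call__ after this commit reads roughly as follows, reassembled from the hunks above; the construction of scenario_noise is collapsed in the diff and is therefore not reproduced here:

    def __call__(self, state, timestep):
        next_state_noise = utils.evaluate_noise_samples_as_inputs(
            self._sess, self.cell_samples
        )
        scenario_noise = ...  # computed by a utils call collapsed in the diff above

        feed_dict = {
            self.cell_noise: next_state_noise,
            self.simulator.noise: scenario_noise,
            self.steps_to_go: self.config["horizon"] - timestep - 1,
        }

        self.run(state, timestep, feed_dict)

        action = self._get_action(self.action, feed_dict)
        return action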
66 changes: 13 additions & 53 deletions tfplan/planners/stochastic/straightline.py
@@ -16,12 +16,7 @@
# pylint: disable=missing-docstring


# from collections import OrderedDict
import os

# import numpy as np
import tensorflow as tf
from tqdm import trange

from rddl2tf.compilers import ReparameterizationCompiler

@@ -30,8 +25,6 @@
from tfplan.planners.stochastic.simulation import Simulator
from tfplan.planners.stochastic import utils

# from tfplan.train.optimizer import ActionOptimizer


class StraightLinePlanner(StochasticPlanner):
"""StraightLinePlanner class implements the online gradient-based
@@ -61,18 +54,14 @@ def __init__(self, rddl, config):
self.writer = None
self.summaries = None

@property
def logdir(self):
return self.config.get("logdir") or f"/tmp/tfplan/straigthline/{self.rddl}"

def build(self,):
with self.graph.as_default():
self._build_policy_ops()
self._build_initial_state_ops()
self._build_sequence_length_ops()
self._build_trajectory_ops()
self._build_loss_ops()
self._build_optimization_ops(self.loss)
self._build_optimization_ops()
self._build_summary_ops()
self._build_init_ops()

@@ -97,12 +86,13 @@ def _build_loss_ops(self):
self.loss = tf.square(self.avg_total_reward)

def _build_summary_ops(self):
with tf.name_scope("summary"):
_ = tf.compat.v1.summary.FileWriter(self.logdir, self.graph)
tf.compat.v1.summary.scalar("avg_total_reward", self.avg_total_reward)
tf.compat.v1.summary.scalar("loss", self.loss)
if self.config["verbose"]:

with tf.name_scope("summary"):
_ = tf.compat.v1.summary.FileWriter(self.config["logdir"], self.graph)
tf.compat.v1.summary.scalar("avg_total_reward", self.avg_total_reward)
tf.compat.v1.summary.scalar("loss", self.loss)

if self.config["verbose"]:
tf.compat.v1.summary.histogram("total_reward", self.total_reward)
tf.compat.v1.summary.histogram("scenario_noise", self.simulator.noise)

@@ -111,49 +101,19 @@ def _build_summary_ops(self):
tf.compat.v1.summary.histogram(f"{var_name}_grad", grad)
tf.compat.v1.summary.histogram(var_name, variable)

self.summaries = tf.compat.v1.summary.merge_all()
self.summaries = tf.compat.v1.summary.merge_all()

def __call__(self, state, timestep):
# pylint: disable=too-many-locals

logdir = os.path.join(self.logdir, f"timestep={timestep}")
self.writer = tf.compat.v1.summary.FileWriter(logdir)

self._sess.run(self.init_op)

run_id = self.config.get("run_id", 0)
pid = os.getpid()
position = run_id % self.config.get("num_workers", 1)
epochs = self.config["epochs"]
desc = f"(pid={pid}) Run #{run_id:<3d} / step={timestep:<3d}"
scenario_noise = utils.evaluate_noise_samples_as_inputs(
self._sess, self.simulator.samples
)

feed_dict = {
self.initial_state: self._get_batch_initial_state(state),
self.simulator.noise: self._get_noise_samples(self._sess),
self.simulator.noise: scenario_noise,
self.steps_to_go: self.config["horizon"] - timestep,
}

with trange(
epochs, unit="epoch", desc=desc, position=position, leave=False
) as t:

for step in t:
_, loss_, avg_total_reward_, summary_ = self._sess.run(
[self.train_op, self.loss, self.avg_total_reward, self.summaries],
feed_dict=feed_dict,
)

self.writer.add_summary(summary_, step)

t.set_postfix(
loss=f"{loss_:10.4f}", avg_total_reward=f"{avg_total_reward_:10.4f}"
)

self.writer.close()
self.run(state, timestep, feed_dict)

action = self._get_action(self.trajectory.actions, feed_dict)
return action

def _get_noise_samples(self, sess):
samples = utils.evaluate_noise_samples_as_inputs(sess, self.simulator.samples)
return samples
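Likewise, StraightLinePlanner.__call__ after this commit, reassembled from the interleaved hunk above (the initial_state entry is no longer fed here because run() feeds the batched initial state itself):

    def __call__(self, state, timestep):
        scenario_noise = utils.evaluate_noise_samples_as_inputs(
            self._sess, self.simulator.samples
        )

        feed_dict = {
            self.simulator.noise: scenario_noise,
            self.steps_to_go: self.config["horizon"] - timestep,
        }

        self.run(state, timestep, feed_dict)

        action = self._get_action(self.trajectory.actions, feed_dict)
        return action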
