Skip to content

Commit

Permalink
Allow step_mul to be specified per-step in the SC2 environment, remov…
Browse files Browse the repository at this point in the history
…e support for update_observations on step.

PiperOrigin-RevId: 212249716
  • Loading branch information
PySC2 Team authored and tewalds committed Sep 12, 2018
1 parent 23c4934 commit 5ab11a7
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 125 deletions.
7 changes: 4 additions & 3 deletions pysc2/env/mock_sc2_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ def reset(self):
self._episode_steps = 0
return self.step([None] * self._num_agents)

def step(self, actions, update_observation=True):
def step(self, actions, step_mul=None):
"""Returns `next_observation` modifying its `step_type` if necessary."""
del update_observation # ignored currently
del step_mul # ignored currently

if len(actions) != self._num_agents:
raise ValueError(
Expand Down Expand Up @@ -246,7 +246,8 @@ def save_replay(self, *args, **kwargs):
def _default_observation(self, obs_spec, agent_index):
"""Returns a mock observation from an SC2Env."""

response_observation = dummy_observation.Builder(obs_spec).build()
response_observation = dummy_observation.Builder(
obs_spec).game_loop(0).build()
features_ = self._features[agent_index]
observation = features_.transform_obs(response_observation)

Expand Down
52 changes: 14 additions & 38 deletions pysc2/env/sc2_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from pysc2.lib import features
from pysc2.lib import metrics
from pysc2.lib import portspicker
from pysc2.lib import protocol
from pysc2.lib import renderer_human
from pysc2.lib import run_parallel
from pysc2.lib import stopwatch
Expand Down Expand Up @@ -438,14 +437,12 @@ def reset(self):
return self._observe()

@sw.decorate("step_env")
def step(self, actions, update_observation=None):
def step(self, actions, step_mul=None):
"""Apply actions, step the world forward, and return observations.
Args:
actions: A list of actions meeting the action spec, one per agent.
update_observation: A list of booleans, whether to retrieve a new
observation after this step, one per agent. **Note** that if the
game ends a new observation will be retrieved regardless.
step_mul: If specified, use this rather than the environment's default.
Returns:
A tuple of TimeStep namedtuples, one per agent.
Expand All @@ -460,34 +457,27 @@ def step(self, actions, update_observation=None):
self._controllers, self._features, self._obs, actions))

self._state = environment.StepType.MID
return self._step(update_observation)
return self._step(step_mul)

def _step(self, update_observation=None):
if self._controllers[0].status != protocol.Status.ended:
# It's currently possible for the game to enter the 'ended' Status
# during the call to act() - although it should only do this on a call
# to step(). We skip step when that happens (the episode has completed).
with self._metrics.measure_step_time(self._step_mul):
self._parallel.run((c.step, self._step_mul) for c in self._controllers)
def _step(self, step_mul=None):
step_mul_ = step_mul or self._step_mul
with self._metrics.measure_step_time(step_mul_):
self._parallel.run((c.step, step_mul_) for c in self._controllers)

return self._observe(update_observation)

def _observe(self, update_observation=None):
if update_observation is None:
update_observation = [True] * len(self._controllers)
return self._observe()

self._update_observations(update_observation)
def _observe(self):
with self._metrics.measure_observation_time():
self._obs = self._parallel.run(c.observe for c in self._controllers)
self._agent_obs = [f.transform_obs(o)
for f, o in zip(self._features, self._obs)]

# TODO(tewalds): How should we handle more than 2 agents and the case where
# the episode can end early for some agents?
outcome = [0] * self._num_agents
discount = self._discount
episode_complete = any(o.player_result for o in self._obs)
if episode_complete or self._controllers[0].status == protocol.Status.ended:
if not all(update_observation):
# The episode completed so we send new observations to everyone.
self._update_observations([not i for i in update_observation])

if episode_complete:
self._state = environment.StepType.LAST
discount = 0
for i, o in enumerate(self._obs):
Expand Down Expand Up @@ -542,20 +532,6 @@ def zero_on_first_step(value):
discount=zero_on_first_step(discount),
observation=o) for r, o in zip(reward, self._agent_obs))

def _update_observations(self, update_observation):
with self._metrics.measure_observation_time():
# Only retrieve the observation for an agent if it requests us to do so.
next_obs = self._parallel.run(
c.observe if observe else lambda: None
for c, observe in zip(self._controllers, update_observation))

# If a new observation was retrieved, transform it.
# Otherwise keep the previous observation.
for index, obs in enumerate(next_obs):
if update_observation[index]:
self._obs[index] = obs
self._agent_obs[index] = self._features[index].transform_obs(obs)

def send_chat_messages(self, messages):
"""Useful for logging messages into the replay."""
self._parallel.run(
Expand Down
30 changes: 0 additions & 30 deletions pysc2/lib/remote_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,34 +98,6 @@ def _valid_status(self, *args, **kwargs):
return decorator


def catch_game_end(func):
"""Decorator to handle 'Game has already ended' exceptions."""
@functools.wraps(func)
def _catch_game_end(self, *args, **kwargs):
"""Decorator to handle 'Game has already ended' exceptions."""
prev_status = self.status
try:
return func(self, *args, **kwargs)
except protocol.ProtocolError as protocol_error:
if prev_status == Status.in_game and (
"Game has already ended" in str(protocol_error)):
# It's currently possible for us to receive this error even though
# our previous status was in_game. This shouldn't happen according
# to the protocol. It does happen sometimes when we don't observe on
# every step (possibly also requiring us to be playing against a
# built-in bot). To work around the issue, we catch the exception
# and so let the client code continue.
logging.warning(
"Received a 'Game has already ended' error from SC2 whilst status "
"in_game. Suppressing the exception, returning None.")

return None
else:
raise

return _catch_game_end


class RemoteController(object):
"""Implements a python interface to interact with the SC2 binary.
Expand Down Expand Up @@ -236,15 +208,13 @@ def observe(self):
return self._client.send(observation=sc_pb.RequestObservation())

@valid_status(Status.in_game, Status.in_replay)
@catch_game_end
@sw.decorate
def step(self, count=1):
"""Step the engine forward by one (or more) step."""
return self._client.send(step=sc_pb.RequestStep(count=count))

@skip_status(Status.in_replay)
@valid_status(Status.in_game)
@catch_game_end
@sw.decorate
def actions(self, req_action):
"""Send a `sc_pb.RequestAction`, which may include multiple actions."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,86 +30,40 @@
)


class StepWithoutObserveTest(utils.TestCase):
class StepMulOverrideTest(utils.TestCase):

def test_returns_observation_on_first_step_despite_no_observe(self):
def test_returns_game_loop_zero_on_first_step_despite_override(self):
with sc2_env.SC2Env(
map_name="DefeatRoaches",
players=[sc2_env.Agent(sc2_env.Race.random)],
step_mul=1,
agent_interface_format=AGENT_INTERFACE_FORMAT) as env:
timestep = env.step(
actions=[actions.FUNCTIONS.no_op()],
update_observation=[False])
step_mul=1234)

self.assertEqual(
timestep[0].observation.game_loop[0],
0)

def test_returns_old_observation_when_no_observe(self):
def test_respects_override(self):
with sc2_env.SC2Env(
map_name="DefeatRoaches",
players=[sc2_env.Agent(sc2_env.Race.random)],
step_mul=1,
agent_interface_format=AGENT_INTERFACE_FORMAT) as env:

for step in range(10):
observe = step % 3 == 0
expected_game_loop = 0
for delta in range(10):
timestep = env.step(
actions=[actions.FUNCTIONS.no_op()],
update_observation=[observe])
step_mul=delta)

expected_game_loop = 3 * (step // 3)
expected_game_loop += delta
self.assertEqual(
timestep[0].observation.game_loop[0],
expected_game_loop)

def test_respects_observe_parameter_per_player(self):
with sc2_env.SC2Env(
map_name="Simple64",
players=[
sc2_env.Agent(sc2_env.Race.random),
sc2_env.Agent(sc2_env.Race.random),
],
step_mul=1,
agent_interface_format=AGENT_INTERFACE_FORMAT) as env:

for step in range(10):
observe = step % 3 == 0
timestep = env.step(
actions=[actions.FUNCTIONS.no_op()] * 2,
update_observation=[observe, True])

expected_game_loop = 3 * (step // 3)
self.assertEqual(
timestep[0].observation.game_loop[0],
expected_game_loop)

self.assertEqual(
timestep[1].observation.game_loop[0],
step)

def test_episode_ends_when_not_observing(self):
with sc2_env.SC2Env(
map_name="Simple64",
players=[
sc2_env.Agent(sc2_env.Race.random),
sc2_env.Bot(sc2_env.Race.random, sc2_env.Difficulty.cheat_insane)],
step_mul=1000,
agent_interface_format=AGENT_INTERFACE_FORMAT) as env:

ended = False
for _ in range(100):
timestep = env.step(
actions=[actions.FUNCTIONS.no_op()],
update_observation=[False])

if timestep[0].last():
ended = True
break

self.assertTrue(ended)


if __name__ == "__main__":
absltest.main()

0 comments on commit 5ab11a7

Please sign in to comment.