From 3fd61df0e51430531d4ec306bea62776d2657706 Mon Sep 17 00:00:00 2001 From: tewalds Date: Tue, 29 May 2018 17:55:40 +0100 Subject: [PATCH] Rename the current way of playing vs humans, and add a new way that plays smoother. This new way runs an instance on either end, and tunnels the LAN play over ssh tunnels and a udp proxy. PiperOrigin-RevId: 198411194 --- pysc2/bin/agent_remote.py | 229 +++++++++++++++++++++++++ pysc2/bin/play_vs_agent.py | 242 +++++++++++++++++--------- pysc2/env/lan_sc2_env.py | 331 ++++++++++++++++++++++++++++++++++++ pysc2/env/remote_sc2_env.py | 7 +- setup.py | 1 + 5 files changed, 722 insertions(+), 88 deletions(-) create mode 100644 pysc2/bin/agent_remote.py create mode 100644 pysc2/env/lan_sc2_env.py diff --git a/pysc2/bin/agent_remote.py b/pysc2/bin/agent_remote.py new file mode 100644 index 000000000..9eb5f0500 --- /dev/null +++ b/pysc2/bin/agent_remote.py @@ -0,0 +1,229 @@ +#!/usr/bin/python +# Copyright 2017 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +r"""Play an agent with an SC2 instance that isn't owned. + +This can be used to play on the sc2ai.net ladder, as well as to play vs humans. + +To play on ladder: + $ python -m pysc2.bin.agent_remote --agent \ + --host_port --lan_port + +To play vs humans: + $ python -m pysc2.bin.agent_remote --human --map +then copy the string it generates which is something similar to above + +If you want to play remotely, you'll need to port forward (eg with ssh -L or -R) +the host_port from localhost on one machine to localhost on the other. + +You can also set your race, observation options, etc by cmdline flags. + +When playing vs humans it launches both instances on the human side. This means +you only need to port-forward a single port (ie the websocket betwen SC2 and the +agent), but you also need to transfer the entire observation, which is much +bigger than the actions transferred over the lan connection between the two SC2 +instances. It also makes it easy to maintain version compatibility since they +are the same binary. Unfortunately it means higher cpu usage where the human is +playing, which on a Mac becomes problematic as OSX slows down the instance +running in the background. There can also be observation differences between +Mac/Win and Linux. For these reasons, prefer play_vs_agent which runs the +instance next to the agent, and tunnels the lan actions instead. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import importlib +from absl import logging +import platform +import time + +from absl import app +from absl import flags +import portpicker + +from pysc2 import maps +from pysc2 import run_configs +from pysc2.env import remote_sc2_env +from pysc2.env import run_loop +from pysc2.env import sc2_env +from pysc2.lib import renderer_human + +from s2clientprotocol import sc2api_pb2 as sc_pb + +FLAGS = flags.FLAGS +flags.DEFINE_bool("render", platform.system() == "Linux", + "Whether to render with pygame.") +flags.DEFINE_bool("realtime", False, "Whether to run in realtime mode.") + +flags.DEFINE_string("agent", "pysc2.agents.random_agent.RandomAgent", + "Which agent to run, as a python path to an Agent class.") +flags.DEFINE_enum("agent_race", "random", sc2_env.Race._member_names_, # pylint: disable=protected-access + "Agent's race.") + +flags.DEFINE_float("fps", 22.4, "Frames per second to run the game.") +flags.DEFINE_integer("step_mul", 8, "Game steps per agent step.") + +flags.DEFINE_integer("feature_screen_size", 84, + "Resolution for screen feature layers.") +flags.DEFINE_integer("feature_minimap_size", 64, + "Resolution for minimap feature layers.") +flags.DEFINE_integer("rgb_screen_size", 256, + "Resolution for rendered screen.") +flags.DEFINE_integer("rgb_minimap_size", 128, + "Resolution for rendered minimap.") +flags.DEFINE_enum("action_space", "FEATURES", + sc2_env.ActionSpace._member_names_, # pylint: disable=protected-access + "Which action space to use. Needed if you take both feature " + "and rgb observations.") +flags.DEFINE_bool("use_feature_units", False, + "Whether to include feature units.") + +flags.DEFINE_enum("user_race", "random", sc2_env.Race._member_names_, # pylint: disable=protected-access + "User's race.") + +flags.DEFINE_string("host", "127.0.0.1", "Game Host") +flags.DEFINE_integer("host_port", None, "Host port") +flags.DEFINE_integer("lan_port", None, "Host port") + +flags.DEFINE_string("map", None, "Name of a map to use to play.") + +flags.DEFINE_bool("human", False, "Whether to host a game as a human.") + + +def main(unused_argv): + if FLAGS.human: + human() + else: + agent() + + +def agent(): + """Run the agent, connecting to a (remote) host started independently.""" + agent_module, agent_name = FLAGS.agent.rsplit(".", 1) + agent_cls = getattr(importlib.import_module(agent_module), agent_name) + + logging.info("Starting agent:") + with remote_sc2_env.RemoteSC2Env( + map_name=FLAGS.map, + host=FLAGS.host, + host_port=FLAGS.host_port, + lan_port=FLAGS.lan_port, + race=sc2_env.Race[FLAGS.agent_race], + step_mul=FLAGS.step_mul, + agent_interface_format=sc2_env.parse_agent_interface_format( + feature_screen=FLAGS.feature_screen_size, + feature_minimap=FLAGS.feature_minimap_size, + rgb_screen=FLAGS.rgb_screen_size, + rgb_minimap=FLAGS.rgb_minimap_size, + action_space=FLAGS.action_space, + use_feature_units=FLAGS.use_feature_units), + visualize=FLAGS.render) as env: + agents = [agent_cls()] + logging.info("Connected, starting run_loop.") + try: + run_loop.run_loop(agents, env) + except remote_sc2_env.RestartException: + pass + logging.info("Done.") + + +def human(): + """Run a host which expects one player to connect remotely.""" + run_config = run_configs.get() + + map_inst = maps.get(FLAGS.map) + + if not FLAGS.rgb_screen_size or not FLAGS.rgb_minimap_size: + logging.info("Use --rgb_screen_size and --rgb_minimap_size if you want rgb " + "observations.") + + while True: + start_port = portpicker.pick_unused_port() + ports = [start_port + p for p in range(4)] # 2 * num_players + if all(portpicker.is_port_free(p) for p in ports): + break + + host_proc = run_config.start(extra_ports=ports, host=FLAGS.host, + timeout_seconds=300, window_loc=(50, 50)) + client_proc = run_config.start(extra_ports=ports, host=FLAGS.host, + connect=False, window_loc=(700, 50)) + + create = sc_pb.RequestCreateGame( + realtime=FLAGS.realtime, local_map=sc_pb.LocalMap(map_path=map_inst.path)) + create.player_setup.add(type=sc_pb.Participant) + create.player_setup.add(type=sc_pb.Participant) + + controller = host_proc.controller + controller.save_map(map_inst.path, map_inst.data(run_config)) + controller.create_game(create) + + print("-" * 80) + print("Join host: play_vs_agent --map %s --host %s --host_port %s " + "--lan_port %s" % (FLAGS.map, FLAGS.host, client_proc.port, start_port)) + print("-" * 80) + + join = sc_pb.RequestJoinGame() + join.shared_port = 0 # unused + join.server_ports.game_port = ports.pop(0) + join.server_ports.base_port = ports.pop(0) + join.client_ports.add(game_port=ports.pop(0), base_port=ports.pop(0)) + + join.race = sc2_env.Race[FLAGS.user_race] + if FLAGS.render: + join.options.raw = True + join.options.score = True + if FLAGS.feature_screen_size and FLAGS.feature_minimap_size: + fl = join.options.feature_layer + fl.width = 24 + fl.resolution.x = FLAGS.feature_screen_size + fl.resolution.y = FLAGS.feature_screen_size + fl.minimap_resolution.x = FLAGS.feature_minimap_size + fl.minimap_resolution.y = FLAGS.feature_minimap_size + if FLAGS.rgb_screen_size and FLAGS.rgb_minimap_size: + join.options.render.resolution.x = FLAGS.rgb_screen_size + join.options.render.resolution.y = FLAGS.rgb_screen_size + join.options.render.minimap_resolution.x = FLAGS.rgb_minimap_size + join.options.render.minimap_resolution.y = FLAGS.rgb_minimap_size + controller.join_game(join) + + if FLAGS.render: + renderer = renderer_human.RendererHuman( + fps=FLAGS.fps, render_feature_grid=False) + renderer.run(run_configs.get(), controller, max_episodes=1) + else: # Still step forward so the Mac/Windows renderer works. + try: + while True: + frame_start_time = time.time() + if not FLAGS.realtime: + controller.step() + obs = controller.observe() + + if obs.player_result: + break + time.sleep(max(0, frame_start_time - time.time() + 1 / FLAGS.fps)) + except KeyboardInterrupt: + pass + + for p in [host_proc, client_proc]: + p.close() + + +def entry_point(): # Needed so setup.py scripts work. + app.run(main) + + +if __name__ == "__main__": + app.run(main) diff --git a/pysc2/bin/play_vs_agent.py b/pysc2/bin/play_vs_agent.py index 89597632e..840a2e86b 100644 --- a/pysc2/bin/play_vs_agent.py +++ b/pysc2/bin/play_vs_agent.py @@ -12,7 +12,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Play as a human against an agent.""" +"""Play as a human against an agent by setting up a LAN game. + +This needs to be called twice, once for the human, and once for the agent. + +The human plays on the host. There you run it as: +$ python -m pysc2.bin.play_vs_agent --human --map --remote + +And on the machine the agent plays on: +$ python -m pysc2.bin.play_vs_agent --agent + +The `--remote` arg is used to create an SSH tunnel to the remote agent's +machine, so can be dropped if it's running on the same machine. + +SC2 is limited to only allow LAN games on localhost, so we need to forward the +ports between machines. SSH is used to do this with the `--remote` arg. If the +agent is on the same machine as the host, this arg can be dropped. SSH doesn't +forward UDP, so this also sets up a UDP proxy. As part of that it sets up a TCP +server that is also used as a settings server. Note that you won't have an +opportunity to give ssh a password, so you must use ssh keys for authentication. +""" from __future__ import absolute_import from __future__ import division @@ -21,6 +40,7 @@ import importlib from absl import logging import platform +import sys import time from absl import app @@ -29,7 +49,7 @@ from pysc2 import maps from pysc2 import run_configs -from pysc2.env import remote_sc2_env +from pysc2.env import lan_sc2_env from pysc2.env import run_loop from pysc2.env import sc2_env from pysc2.lib import renderer_human @@ -67,20 +87,26 @@ flags.DEFINE_enum("user_race", "random", sc2_env.Race._member_names_, # pylint: disable=protected-access "User's race.") -flags.DEFINE_string("host", "127.0.0.1", "Game Host") -flags.DEFINE_integer("host_port", None, "Host port") -flags.DEFINE_integer("lan_port", None, "Host port") +flags.DEFINE_string("host", "127.0.0.1", "Game Host. Can be 127.0.0.1 or ::1") +flags.DEFINE_integer( + "config_port", 14380, + "Where to set/find the config port. The host starts a tcp server to share " + "the config with the client, and to proxy udp traffic if played over an " + "ssh tunnel. This sets that port, and is also the start of the range of " + "ports used for LAN play.") +flags.DEFINE_string("remote", None, + "Where to set up the ssh tunnels to the client.") flags.DEFINE_string("map", None, "Name of a map to use to play.") +flags.DEFINE_bool("human", False, "Whether to host a game as a human.") -def main(unused_argv): - """Run SC2 to play a game or a replay.""" - if FLAGS.host_port: - agent() +def main(unused_argv): + if FLAGS.human: + human() else: - host() + agent() def agent(): @@ -89,11 +115,9 @@ def agent(): agent_cls = getattr(importlib.import_module(agent_module), agent_name) logging.info("Starting agent:") - with remote_sc2_env.RemoteSC2Env( - map_name=FLAGS.map, + with lan_sc2_env.LanSC2Env( host=FLAGS.host, - host_port=FLAGS.host_port, - lan_port=FLAGS.lan_port, + config_port=FLAGS.config_port, race=sc2_env.Race[FLAGS.agent_race], step_mul=FLAGS.step_mul, agent_interface_format=sc2_env.parse_agent_interface_format( @@ -108,12 +132,12 @@ def agent(): logging.info("Connected, starting run_loop.") try: run_loop.run_loop(agents, env) - except remote_sc2_env.RestartException: + except lan_sc2_env.RestartException: pass logging.info("Done.") -def host(): +def human(): """Run a host which expects one player to connect remotely.""" run_config = run_configs.get() @@ -123,75 +147,125 @@ def host(): logging.info("Use --rgb_screen_size and --rgb_minimap_size if you want rgb " "observations.") - while True: - start_port = portpicker.pick_unused_port() - ports = [start_port + p for p in range(4)] # 2 * num_players - if all(portpicker.is_port_free(p) for p in ports): - break - - host_proc = run_config.start(extra_ports=ports, host=FLAGS.host, - timeout_seconds=300, window_loc=(50, 50)) - client_proc = run_config.start(extra_ports=ports, host=FLAGS.host, - connect=False, window_loc=(700, 50)) - - create = sc_pb.RequestCreateGame( - realtime=FLAGS.realtime, local_map=sc_pb.LocalMap(map_path=map_inst.path)) - create.player_setup.add(type=sc_pb.Participant) - create.player_setup.add(type=sc_pb.Participant) - - controller = host_proc.controller - controller.save_map(map_inst.path, map_inst.data(run_config)) - controller.create_game(create) - - print("-" * 80) - print("Join host: play_vs_agent --map %s --host %s --host_port %s " - "--lan_port %s" % (FLAGS.map, FLAGS.host, client_proc.port, start_port)) - print("-" * 80) - - join = sc_pb.RequestJoinGame() - join.shared_port = 0 # unused - join.server_ports.game_port = ports.pop(0) - join.server_ports.base_port = ports.pop(0) - join.client_ports.add(game_port=ports.pop(0), base_port=ports.pop(0)) - - join.race = sc2_env.Race[FLAGS.user_race] - if FLAGS.render: - join.options.raw = True - join.options.score = True - if FLAGS.feature_screen_size and FLAGS.feature_minimap_size: - fl = join.options.feature_layer - fl.width = 24 - fl.resolution.x = FLAGS.feature_screen_size - fl.resolution.y = FLAGS.feature_screen_size - fl.minimap_resolution.x = FLAGS.feature_minimap_size - fl.minimap_resolution.y = FLAGS.feature_minimap_size - if FLAGS.rgb_screen_size and FLAGS.rgb_minimap_size: - join.options.render.resolution.x = FLAGS.rgb_screen_size - join.options.render.resolution.y = FLAGS.rgb_screen_size - join.options.render.minimap_resolution.x = FLAGS.rgb_minimap_size - join.options.render.minimap_resolution.y = FLAGS.rgb_minimap_size - controller.join_game(join) - - if FLAGS.render: - renderer = renderer_human.RendererHuman( - fps=FLAGS.fps, render_feature_grid=False) - renderer.run(run_configs.get(), controller, max_episodes=1) - else: # Still step forward so the Mac/Windows renderer works. - try: - while True: - frame_start_time = time.time() - if not FLAGS.realtime: - controller.step() - obs = controller.observe() - - if obs.player_result: + ports = [FLAGS.config_port + p for p in range(5)] # tcp + 2 * num_players + if not all(portpicker.is_port_free(p) for p in ports): + sys.exit("Need 5 free ports after the config port.") + + proc = None + ssh_proc = None + tcp_conn = None + udp_sock = None + try: + proc = run_config.start(extra_ports=ports[1:], timeout_seconds=300, + host=FLAGS.host, window_loc=(50, 50)) + + tcp_port = ports[0] + settings = { + "remote": FLAGS.remote, + "game_version": proc.version.game_version, + "realtime": FLAGS.realtime, + "map_name": map_inst.name, + "map_path": map_inst.path, + "map_data": map_inst.data(run_config), + "ports": { + "server": {"game": ports[1], "base": ports[2]}, + "client": {"game": ports[3], "base": ports[4]}, + } + } + + create = sc_pb.RequestCreateGame( + realtime=settings["realtime"], + local_map=sc_pb.LocalMap(map_path=settings["map_path"])) + create.player_setup.add(type=sc_pb.Participant) + create.player_setup.add(type=sc_pb.Participant) + + controller = proc.controller + controller.save_map(settings["map_path"], settings["map_data"]) + controller.create_game(create) + + if FLAGS.remote: + ssh_proc = lan_sc2_env.forward_ports( + FLAGS.remote, proc.host, [settings["ports"]["client"]["base"]], + [tcp_port, settings["ports"]["server"]["base"]]) + + print("-" * 80) + print("Join: play_vs_agent --host %s --config_port %s" % (proc.host, + tcp_port)) + print("-" * 80) + + tcp_conn = lan_sc2_env.tcp_server( + lan_sc2_env.Addr(proc.host, tcp_port), settings) + + if FLAGS.remote: + udp_sock = lan_sc2_env.udp_server( + lan_sc2_env.Addr(proc.host, settings["ports"]["client"]["game"])) + + lan_sc2_env.daemon_thread( + lan_sc2_env.tcp_to_udp, + (tcp_conn, udp_sock, + lan_sc2_env.Addr(proc.host, settings["ports"]["server"]["game"]))) + + lan_sc2_env.daemon_thread(lan_sc2_env.udp_to_tcp, (udp_sock, tcp_conn)) + + join = sc_pb.RequestJoinGame() + join.shared_port = 0 # unused + join.server_ports.game_port = settings["ports"]["server"]["game"] + join.server_ports.base_port = settings["ports"]["server"]["base"] + join.client_ports.add(game_port=settings["ports"]["client"]["game"], + base_port=settings["ports"]["client"]["base"]) + + join.race = sc2_env.Race[FLAGS.user_race] + if FLAGS.render: + join.options.raw = True + join.options.score = True + if FLAGS.feature_screen_size and FLAGS.feature_minimap_size: + fl = join.options.feature_layer + fl.width = 24 + fl.resolution.x = FLAGS.feature_screen_size + fl.resolution.y = FLAGS.feature_screen_size + fl.minimap_resolution.x = FLAGS.feature_minimap_size + fl.minimap_resolution.y = FLAGS.feature_minimap_size + if FLAGS.rgb_screen_size and FLAGS.rgb_minimap_size: + join.options.render.resolution.x = FLAGS.rgb_screen_size + join.options.render.resolution.y = FLAGS.rgb_screen_size + join.options.render.minimap_resolution.x = FLAGS.rgb_minimap_size + join.options.render.minimap_resolution.y = FLAGS.rgb_minimap_size + controller.join_game(join) + + if FLAGS.render: + renderer = renderer_human.RendererHuman( + fps=FLAGS.fps, render_feature_grid=False) + renderer.run(run_configs.get(), controller, max_episodes=1) + else: # Still step forward so the Mac/Windows renderer works. + try: + while True: + frame_start_time = time.time() + if not FLAGS.realtime: + controller.step() + obs = controller.observe() + + if obs.player_result: + break + time.sleep(max(0, frame_start_time - time.time() + 1 / FLAGS.fps)) + except KeyboardInterrupt: + pass + + finally: + if tcp_conn: + tcp_conn.close() + if proc: + proc.close() + if udp_sock: + udp_sock.close() + if ssh_proc: + ssh_proc.terminate() + for _ in range(5): + if ssh_proc.poll() is not None: break - time.sleep(max(0, frame_start_time - time.time() + 1 / FLAGS.fps)) - except KeyboardInterrupt: - pass - - for p in [host_proc, client_proc]: - p.close() + time.sleep(1) + if ssh_proc.poll() is None: + ssh_proc.kill() + ssh_proc.wait() def entry_point(): # Needed so setup.py scripts work. diff --git a/pysc2/env/lan_sc2_env.py b/pysc2/env/lan_sc2_env.py new file mode 100644 index 000000000..e6e79887e --- /dev/null +++ b/pysc2/env/lan_sc2_env.py @@ -0,0 +1,331 @@ +# Copyright 2017 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A Starcraft II environment for playing LAN games vs humans. + +Check pysc2/bin/play_vs_agent.py for documentation. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import binascii +import collections +import hashlib +import json +from absl import logging +import os +import socket +import struct +import subprocess +import threading + +from pysc2 import run_configs +from pysc2.env import sc2_env +from pysc2.lib import run_parallel +import whichcraft + +from s2clientprotocol import sc2api_pb2 as sc_pb + + +class Addr(collections.namedtuple("Addr", ["ip", "port"])): + + def __str__(self): + ip = "[%s]" % self.ip if ":" in self.ip else self.ip + return "%s:%s" % (ip, self.port) + + +def daemon_thread(target, args): + t = threading.Thread(target=target, args=args) + t.daemon = True + t.start() + return t + + +def udp_server(addr): + family = socket.AF_INET6 if ":" in addr.ip else socket.AF_INET + sock = socket.socket(family, socket.SOCK_DGRAM, socket.IPPROTO_UDP) + sock.bind(addr) + return sock + + +def tcp_server(tcp_addr, settings): + """Start up the tcp server, send the settings.""" + family = socket.AF_INET6 if ":" in tcp_addr.ip else socket.AF_INET + sock = socket.socket(family, socket.SOCK_STREAM, socket.IPPROTO_TCP) + sock.bind(tcp_addr) + sock.listen(1) + logging.info("Waiting for connection on %s", tcp_addr) + conn, addr = sock.accept() + logging.info("Accepted connection from %s", Addr(*addr)) + + # Send map_data independently for py2/3 and json encoding reasons. + write_tcp(conn, settings["map_data"]) + send_settings = {k: v for k, v in settings.items() if k != "map_data"} + logging.debug("settings: %s", send_settings) + write_tcp(conn, json.dumps(send_settings).encode()) + return conn + + +def tcp_client(tcp_addr): + """Connect to the tcp server, and return the settings.""" + family = socket.AF_INET6 if ":" in tcp_addr.ip else socket.AF_INET + sock = socket.socket(family, socket.SOCK_STREAM, socket.IPPROTO_TCP) + logging.info("Connecting to: %s", tcp_addr) + sock.connect(tcp_addr) + logging.info("Connected.") + + map_data = read_tcp(sock) + settings_str = read_tcp(sock) + if not settings_str: + raise socket.error("Failed to read") + settings = json.loads(settings_str.decode()) + logging.info("Got settings. map_name: %s.", settings["map_name"]) + logging.debug("settings: %s", settings) + settings["map_data"] = map_data + return sock, settings + + +def log_msg(prefix, msg): + logging.debug("%s: len: %s, hash: %s, msg: 0x%s", prefix, len(msg), + hashlib.md5(msg).hexdigest()[:6], binascii.hexlify(msg[:25])) + + +def udp_to_tcp(udp_sock, tcp_conn): + while True: + msg, _ = udp_sock.recvfrom(2**16) + log_msg("read_udp", msg) + if not msg: + return + write_tcp(tcp_conn, msg) + + +def tcp_to_udp(tcp_conn, udp_sock, udp_to_addr): + while True: + msg = read_tcp(tcp_conn) + if not msg: + return + log_msg("write_udp", msg) + udp_sock.sendto(msg, udp_to_addr) + + +def read_tcp(conn): + read_size = read_tcp_size(conn, 4) + if not read_size: + return + size = struct.unpack("@I", read_size)[0] + msg = read_tcp_size(conn, size) + log_msg("read_tcp", msg) + return msg + + +def read_tcp_size(conn, size): + """Read `size` number of bytes from `conn`, retrying as needed.""" + chunks = [] + bytes_read = 0 + while bytes_read < size: + chunk = conn.recv(size - bytes_read) + if not chunk: + if bytes_read > 0: + logging.warning("Incomplete read: %s of %s.", bytes_read, size) + return + chunks.append(chunk) + bytes_read += len(chunk) + return b"".join(chunks) + + +def write_tcp(conn, msg): + log_msg("write_tcp", msg) + conn.sendall(struct.pack("@I", len(msg))) + conn.sendall(msg) + + +def forward_ports(remote_host, local_host, local_listen_ports, + remote_listen_ports): + """Forwards ports such that multiplayer works between machines. + + Args: + remote_host: Where to ssh to. + local_host: "127.0.0.1" or "::1". + local_listen_ports: Which ports to listen on locally to forward remotely. + remote_listen_ports: Which ports to listen on remotely to forward locally. + + Returns: + The ssh process. + + Raises: + ValueError: if it can't find ssh. + """ + if ":" in local_host and not local_host.startswith("["): + local_host = "[%s]" % local_host + + ssh = whichcraft.which("ssh") or whichcraft.which("plink") + if not ssh: + raise ValueError("Couldn't find an ssh client.") + + args = [ssh, remote_host] + for local_port in local_listen_ports: + args += ["-L", "%s:%s:%s:%s" % (local_host, local_port, + local_host, local_port)] + for remote_port in remote_listen_ports: + args += ["-R", "%s:%s:%s:%s" % (local_host, remote_port, + local_host, remote_port)] + + logging.info("SSH port forwarding: %s", " ".join(args)) + return subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, + stdin=subprocess.PIPE, close_fds=(os.name == "posix")) + + +class RestartException(Exception): + pass + + +class LanSC2Env(sc2_env.SC2Env): + """A Starcraft II environment for playing vs humans over LAN. + + This owns a single instance, and expects to join a game hosted by some other + script, likely play_vs_agent.py. + """ + + def __init__(self, # pylint: disable=invalid-name + _only_use_kwargs=None, + host="127.0.0.1", + config_port=None, + race=None, + agent_interface_format=None, + discount=1., + visualize=False, + step_mul=None, + replay_dir=None): + """Create a SC2 Env that connects to a remote instance of the game. + + This assumes that the game is already up and running, and it only needs to + join. You need some other script to launch the process and call + RequestCreateGame. It also assumes that it's a multiplayer game, and that + the ports are consecutive. + + You must pass a resolution that you want to play at. You can send either + feature layer resolution or rgb resolution or both. If you send both you + must also choose which to use as your action space. Regardless of which you + choose you must send both the screen and minimap resolutions. + + For each of the 4 resolutions, either specify size or both width and + height. If you specify size then both width and height will take that value. + + Args: + _only_use_kwargs: Don't pass args, only kwargs. + host: Which ip to use. Either ipv4 or ipv6 localhost. + config_port: Where to find the config port. + race: Race for this agent. + agent_interface_format: AgentInterfaceFormat object describing the + format of communication between the agent and the environment. + discount: Returned as part of the observation. + visualize: Whether to pop up a window showing the camera and feature + layers. This won't work without access to a window manager. + step_mul: How many game steps per agent step (action/observation). None + means use the map default. + replay_dir: Directory to save a replay. + + Raises: + ValueError: if the race is invalid. + ValueError: if the resolutions aren't specified correctly. + ValueError: if the host or port are invalid. + """ + if _only_use_kwargs: + raise ValueError("All arguments must be passed as keyword arguments.") + + if host not in ("127.0.0.1", "::1"): + raise ValueError("Bad host arguments. Must be a localhost") + if not config_port: + raise ValueError("Must pass a config_port.") + + if agent_interface_format is None: + raise ValueError("Please specify agent_interface_format.") + + if not race: + race = sc2_env.Race.random + + self._num_agents = 1 + self._discount = discount + self._step_mul = step_mul or 8 + self._save_replay_episodes = 1 if replay_dir else 0 + self._replay_dir = replay_dir + + self._score_index = -1 # Win/loss only. + self._score_multiplier = 1 + self._episode_length = 0 # No limit. + + self._run_config = run_configs.get() + self._parallel = run_parallel.RunParallel() # Needed for multiplayer. + + interface = self._get_interface( + agent_interface_format=agent_interface_format, require_raw=visualize) + + self._launch_remote(host, config_port, race, interface) + + self._finalize([agent_interface_format], [interface], visualize) + + def _launch_remote(self, host, config_port, race, interface): + """Make sure this stays synced with bin/play_vs_agent.py.""" + self._tcp_conn, settings = tcp_client(Addr(host, config_port)) + + self._map_name = settings["map_name"] + + if settings["remote"]: + self._udp_sock = udp_server( + Addr(host, settings["ports"]["server"]["game"])) + + daemon_thread(tcp_to_udp, + (self._tcp_conn, self._udp_sock, + Addr(host, settings["ports"]["client"]["game"]))) + + daemon_thread(udp_to_tcp, (self._udp_sock, self._tcp_conn)) + + extra_ports = [ + settings["ports"]["server"]["game"], + settings["ports"]["server"]["base"], + settings["ports"]["client"]["game"], + settings["ports"]["client"]["base"], + ] + + self._sc2_procs = [self._run_config.start( + extra_ports=extra_ports, host=host, version=settings["game_version"], + window_loc=(700, 50))] + self._controllers = [p.controller for p in self._sc2_procs] + + # Create the join request. + join = sc_pb.RequestJoinGame(options=interface) + join.race = race + join.shared_port = 0 # unused + join.server_ports.game_port = settings["ports"]["server"]["game"] + join.server_ports.base_port = settings["ports"]["server"]["base"] + join.client_ports.add(game_port=settings["ports"]["client"]["game"], + base_port=settings["ports"]["client"]["base"]) + + self._controllers[0].save_map(settings["map_path"], settings["map_data"]) + self._controllers[0].join_game(join) + + def _restart(self): + # Can't restart since it's not clear how you'd coordinate that with the + # other players. + raise RestartException("Can't restart") + + def close(self): + if hasattr(self, "_tcp_conn") and self._tcp_conn: + self._tcp_conn.close() + self._tcp_conn = None + if hasattr(self, "_udp_sock") and self._udp_sock: + self._udp_sock.close() + self._udp_sock = None + super(LanSC2Env, self).close() diff --git a/pysc2/env/remote_sc2_env.py b/pysc2/env/remote_sc2_env.py index 3b09d6f15..558b3ab2f 100644 --- a/pysc2/env/remote_sc2_env.py +++ b/pysc2/env/remote_sc2_env.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""A Starcraft II environment.""" +"""A Starcraft II environment for playing using remote SC2 instances.""" from __future__ import absolute_import from __future__ import division @@ -87,8 +87,7 @@ def __init__(self, # pylint: disable=invalid-name replay_dir: Directory to save a replay. Raises: - ValueError: if the agent_race, bot_race or difficulty are invalid. - ValueError: if too many players are requested for a map. + ValueError: if the race is invalid. ValueError: if the resolutions aren't specified correctly. """ if _only_use_kwargs: @@ -125,7 +124,7 @@ def __init__(self, # pylint: disable=invalid-name def _connect_remote(self, host, host_port, lan_port, race, map_inst, interface): - """Make sure this stays synced with bin/play_vs_agent.py.""" + """Make sure this stays synced with bin/agent_remote.py.""" # Connect! logging.info("Connecting...") self._controllers = [remote_controller.RemoteController(host, host_port)] diff --git a/setup.py b/setup.py index 82750729a..daece0daa 100755 --- a/setup.py +++ b/setup.py @@ -72,6 +72,7 @@ 'six', 'sk-video', 'websocket-client', + 'whichcraft', ], entry_points={ 'console_scripts': [