From d0f58aa80fe1d591c3ccc0602d18c1770bba7f2f Mon Sep 17 00:00:00 2001 From: tewalds Date: Tue, 15 Aug 2017 10:56:35 +0100 Subject: [PATCH] Change the app.run and entry points to be more explicit to avoid the confusion (eg https://github.com/deepmind/pysc2/pull/14) of why the flags weren't being parsed. PiperOrigin-RevId: 165287285 --- pysc2/bin/agent.py | 8 ++++---- pysc2/bin/gen_actions.py | 2 +- pysc2/bin/map_list.py | 2 +- pysc2/bin/play.py | 8 ++++---- pysc2/bin/replay_actions.py | 3 ++- pysc2/bin/replay_info.py | 9 ++++----- pysc2/bin/valid_actions.py | 3 ++- pysc2/lib/app.py | 9 ++++++--- pysc2/lib/basetest.py | 2 +- setup.py | 6 +++--- 10 files changed, 28 insertions(+), 24 deletions(-) diff --git a/pysc2/bin/agent.py b/pysc2/bin/agent.py index 368751aac..217bdbe0b 100755 --- a/pysc2/bin/agent.py +++ b/pysc2/bin/agent.py @@ -79,7 +79,7 @@ def run_thread(agent_cls, map_name, visualize): env.save_replay(agent_cls.__name__) -def _main(unused_argv): +def main(unused_argv): """Run an agent.""" stopwatch.sw.enabled = FLAGS.profile or FLAGS.trace stopwatch.sw.trace = FLAGS.trace @@ -104,9 +104,9 @@ def _main(unused_argv): print(stopwatch.sw) -def main(): # Needed so setup.py scripts work. - app.really_start(_main) +def entry_point(): # Needed so setup.py scripts work. + app.run(main) if __name__ == "__main__": - main() + app.run(main) diff --git a/pysc2/bin/gen_actions.py b/pysc2/bin/gen_actions.py index 1bf7a8ef4..e33445b71 100755 --- a/pysc2/bin/gen_actions.py +++ b/pysc2/bin/gen_actions.py @@ -211,4 +211,4 @@ def main(unused_argv): if __name__ == "__main__": - app.really_start(main) + app.run(main) diff --git a/pysc2/bin/map_list.py b/pysc2/bin/map_list.py index 64cd4b05e..c848d359b 100755 --- a/pysc2/bin/map_list.py +++ b/pysc2/bin/map_list.py @@ -31,4 +31,4 @@ def main(unused_argv): if __name__ == "__main__": - app.run() + app.run(main) diff --git a/pysc2/bin/play.py b/pysc2/bin/play.py index 6c0e239a7..ce7a32c75 100755 --- a/pysc2/bin/play.py +++ b/pysc2/bin/play.py @@ -66,7 +66,7 @@ flags.DEFINE_string("replay", None, "Name of a replay to show.") -def _main(unused_argv): +def main(unused_argv): """Run SC2 to play a game or a replay.""" stopwatch.sw.enabled = FLAGS.profile or FLAGS.trace stopwatch.sw.trace = FLAGS.trace @@ -175,9 +175,9 @@ def _main(unused_argv): print(stopwatch.sw) -def main(): # Needed so setup.py scripts work. - app.really_start(_main) +def entry_point(): # Needed so setup.py scripts work. + app.run(main) if __name__ == "__main__": - main() + app.run(main) diff --git a/pysc2/bin/replay_actions.py b/pysc2/bin/replay_actions.py index 77426421b..d1c9b00fb 100755 --- a/pysc2/bin/replay_actions.py +++ b/pysc2/bin/replay_actions.py @@ -365,5 +365,6 @@ def main(unused_argv): stats_queue.put(None) # Tell the stats_thread to print and exit. stats_thread.join() + if __name__ == "__main__": - app.run() + app.run(main) diff --git a/pysc2/bin/replay_info.py b/pysc2/bin/replay_info.py index ea5d03dea..0b051bc4b 100755 --- a/pysc2/bin/replay_info.py +++ b/pysc2/bin/replay_info.py @@ -92,7 +92,7 @@ def _replay_info(replay_path): print(info) -def _main(argv=()): +def main(argv): if not argv: raise ValueError("No replay directory or path specified.") if len(argv) > 2: @@ -108,10 +108,9 @@ def _main(argv=()): pass -def main(): # Needed so the setup.py scripts work. - app.really_start(_main) +def entry_point(): # Needed so the setup.py scripts work. + app.run(main) if __name__ == "__main__": - main() - + app.run(main) diff --git a/pysc2/bin/valid_actions.py b/pysc2/bin/valid_actions.py index 85e366e0b..ce66f5082 100755 --- a/pysc2/bin/valid_actions.py +++ b/pysc2/bin/valid_actions.py @@ -53,5 +53,6 @@ def main(unused_argv): print("Total base actions:", count) print("Total possible actions (flattened):", flattened) + if __name__ == "__main__": - app.run() + app.run(main) diff --git a/pysc2/lib/app.py b/pysc2/lib/app.py index 13679e431..85e9c944e 100644 --- a/pysc2/lib/app.py +++ b/pysc2/lib/app.py @@ -39,9 +39,6 @@ def usage(): print(str(FLAGS)) sys.exit() - def run(): - really_start(sys.modules["__main__"].main) - def really_start(main): try: argv = FLAGS(sys.argv) @@ -52,3 +49,9 @@ def really_start(main): if FLAGS.help: usage() sys.exit(main(argv)) + + +# It's ok to override the apputils.app.run as it doesn't seem to do much more +# than call really_start anyway. +def run(main=None): + really_start(main or sys.modules["__main__"].main) diff --git a/pysc2/lib/basetest.py b/pysc2/lib/basetest.py index d2a2a3e27..e7c4a2bbd 100644 --- a/pysc2/lib/basetest.py +++ b/pysc2/lib/basetest.py @@ -31,4 +31,4 @@ TestCase = unittest.TestCase def main(): - app.really_start(lambda argv: unittest.main(argv=argv)) + app.run(lambda argv: unittest.main(argv=argv)) diff --git a/setup.py b/setup.py index fd5b00073..e982dc334 100755 --- a/setup.py +++ b/setup.py @@ -64,9 +64,9 @@ def read(fname): ], entry_points={ 'console_scripts': [ - 'pysc2_agent = pysc2.bin.agent:main', - 'pysc2_play = pysc2.bin.play:main', - 'pysc2_replay_info = pysc2.bin.replay_info:main', + 'pysc2_agent = pysc2.bin.agent:entry_point', + 'pysc2_play = pysc2.bin.play:entry_point', + 'pysc2_replay_info = pysc2.bin.replay_info:entry_point', ], }, classifiers=[