Skip to content

Commit

Permalink
fix evaluation in standalone (#14)
Browse files Browse the repository at this point in the history
* fix evaluation in standalone ppo
  • Loading branch information
Howuhh authored Mar 24, 2024
1 parent 1f5bc0c commit 501a636
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
8 changes: 8 additions & 0 deletions examples/train_meta_standalone.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,13 @@
"env, env_params = xminigrid.make(config.env_id)\n",
"env = GymAutoResetWrapper(env)\n",
"\n",
"# enabling image observations if needed\n",
"if config.img_obs:\n",
" from xminigrid.experimental.img_obs import RGBImgObservationWrapper\n",
"\n",
" env = RGBImgObservationWrapper(env)\n",
" \n",
"\n",
"ruleset = xminigrid.load_benchmark(config.benchmark_id).get_ruleset(ruleset_id=0)\n",
"env_params = env_params.replace(ruleset=ruleset)\n",
"\n",
Expand All @@ -739,6 +746,7 @@
" rnn_hidden_dim=config.rnn_hidden_dim,\n",
" rnn_num_layers=config.rnn_num_layers,\n",
" head_hidden_dim=config.head_hidden_dim,\n",
" img_obs=config.img_obs,\n",
")\n",
"\n",
"# jitting all functions\n",
Expand Down
14 changes: 14 additions & 0 deletions examples/train_single_standalone.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,19 @@
"env, env_params = xminigrid.make(config.env_id)\n",
"env = GymAutoResetWrapper(env)\n",
"\n",
"# for single-task XLand environments\n",
"if config.benchmark_id is not None:\n",
" assert \"XLand-MiniGrid\" in config.env_id, \"Benchmarks should be used only with XLand environments.\"\n",
" assert config.ruleset_id is not None, \"Ruleset ID should be specified for benchmarks usage.\"\n",
" benchmark = xminigrid.load_benchmark(config.benchmark_id)\n",
" env_params = env_params.replace(ruleset=benchmark.get_ruleset(config.ruleset_id))\n",
"\n",
"# enabling image observations if needed\n",
"if config.img_obs:\n",
" from xminigrid.experimental.img_obs import RGBImgObservationWrapper\n",
"\n",
" env = RGBImgObservationWrapper(env)\n",
"\n",
"# you can use train_state from the final runner_state also\n",
"# we just demo here how to do it if you loaded params from the checkpoint\n",
"params = train_info[\"runner_state\"][1].params\n",
Expand All @@ -699,6 +712,7 @@
" rnn_hidden_dim=config.rnn_hidden_dim,\n",
" rnn_num_layers=config.rnn_num_layers,\n",
" head_hidden_dim=config.head_hidden_dim,\n",
" img_obs=config.img_obs,\n",
")\n",
"\n",
"# jitting all functions\n",
Expand Down

0 comments on commit 501a636

Please sign in to comment.