Skip to content

Commit

Permalink
Full typings support (#6)
Browse files Browse the repository at this point in the history
* almost full type hints support

* type hints support

* fix pyright config

* updated readme
  • Loading branch information
Howuhh authored Feb 11, 2024
1 parent 7a3e55e commit 98c4969
Show file tree
Hide file tree
Showing 27 changed files with 242 additions and 185 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/codestyle.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,7 @@ jobs:
pip install -e ".[dev]"
- name: check codestyle
run: |
ruff --config pyproject.toml --diff .
ruff --config pyproject.toml --diff .
- name: check type hints
run: |
pyright --project=pyproject.toml src/xminigrid
11 changes: 9 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
repos:
# ruff checking
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.1.4
hooks:
# Run the linter.
- id: ruff
args: [--fix]
# Run the formatter.
- id: ruff-format
- id: ruff-format

# pyright checking
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.350
hooks:
- id: pyright
args: [--project=pyproject.toml]
3 changes: 2 additions & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pip install -e ".[dev]"

## Code style

We use awesome [Ruff](https://docs.astral.sh/ruff/) linter and formatter.
We use awesome [Ruff](https://docs.astral.sh/ruff/) linter and formatter and [Pyright](https://microsoft.github.io/pyright/#/) for type checking.
The CI will run several checks on the new code pushed to the repository.
These checks can also be run locally without waiting for the CI by following the steps below:

Expand All @@ -42,6 +42,7 @@ skipped (not recommended) with `git commit --no-verify`.
Be sure to run and fix all issues from the `pre-commit run --all-files` before the push!
If you want to see possible problems before pre-commit, you can run `ruff check --diff .`
and `ruff format --check` to see exact linter and formatter suggestions and possible fixes.
Similarly, run `pyright src/xminigrid` to see possible problems with type hints.

# License

Expand Down
38 changes: 16 additions & 22 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
<a href="https://badge.fury.io/py/xminigrid">
<img src="https://badge.fury.io/py/xminigrid.svg"/>
</a>
<a href="https://github.com/corl-team/xland-minigrid/main/LICENSE">
<img src="https://img.shields.io/badge/license-Apache_2.0-blue"/>
</a>
<a href="https://github.com/astral-sh/ruff">
<img src="https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/charliermarsh/ruff/main/assets/badge/v2.json"/>
</a>
<a href="https://arxiv.org/abs/2312.12044">
<img src="https://img.shields.io/badge/arXiv-2210.07105-b31b1b.svg"/>
</a>
<a href="https://twitter.com/vladkurenkov/status/1731709425524543550">
<img src="https://badgen.net/badge/icon/twitter?icon=twitter&label"/>
</a>
Expand All @@ -21,13 +21,18 @@
</a>
</p>


[//]: # ( <a href="https://badge.fury.io/py/xminigrid">)

[//]: # ( <img src="https://img.shields.io/pypi/dm/xminigrid?color=yellow&label=Downloads"/>)

[//]: # ( </a>)

[//]: # ( <a href="https://github.com/corl-team/xland-minigrid/main/LICENSE">)

[//]: # ( <img src="https://img.shields.io/badge/license-Apache_2.0-blue"/>)

[//]: # ( </a>)

![img](figures/readme-main-img.png)

# Meta-Reinforcement Learning in JAX
Expand All @@ -51,11 +56,10 @@ diverse task distributions
- 📈 Easily scales to $2^{16}$ parallel environments and millions of steps per second on a single GPU
- 🔥 Multi-GPU PPO baselines in the [PureJaxRL](https://github.com/luchris429/purejaxrl) style, which can achieve **1 trillion** environment steps under two days

How cool is that? For more details, take a look at the [technical paper (soon)]() or
How cool is that? For more details, take a look at the [technical paper](https://arxiv.org/abs/2312.12044) or
[examples](examples), which will walk you through the basics and training your own adaptive agents in minutes!

![img](figures/times_minigrid.jpg)
TODO: update this with the latest version of the codebase...

## Installation 🎁

Expand Down Expand Up @@ -90,6 +94,7 @@ On the high level, current API combines [dm_env](https://github.com/google-deepm
```python
import jax
import xminigrid
from xminigrid.wrappers import GymAutoResetWrapper

key = jax.random.PRNGKey(0)
reset_key, ruleset_key = jax.random.split(key)
Expand All @@ -103,6 +108,9 @@ ruleset = benchmark.sample_ruleset(ruleset_key)
env, env_params = xminigrid.make("XLand-MiniGrid-R9-25x25")
env_params = env_params.replace(ruleset=ruleset)

# auto-reset wrapper
env = GymAutoResetWrapper(env)

# fully jit-compatible step and reset methods
timestep = jax.jit(env.reset)(env_params, reset_key)
timestep = jax.jit(env.step)(env_params, timestep, action=0)
Expand Down Expand Up @@ -157,7 +165,7 @@ While composing rules and goals by hand is flexible, it can quickly become cumbe
Besides, it's hard to express efficiently in a JAX-compatible way due to the high number of heterogeneous computations

To avoid significant overhead during training and facilitate reliable comparisons between agents,
we pre-sampled several benchmarks with up to **five million unique tasks**, following the procedure used to train DeepMind
we pre-sampled several benchmarks with up to **three million unique tasks**, following the procedure used to train DeepMind
AdA agent from the original XLand. These benchmarks differ in the generation configs, producing distributions with
varying levels of diversity and average difficulty of the tasks. They can be used for different purposes, for example
the `trivial-1m` benchmark can be used to debug your agents, allowing very quick iterations. However, we would caution
Expand Down Expand Up @@ -189,7 +197,7 @@ We also provide the [script](scripts/ruleset_generator.py) used to generate thes
python scripts/ruleset_generator.py --help
```

In depth description of all available benchmarks is provided [here (soon)]().
In depth description of all available benchmarks is provided [in the technical paper](https://arxiv.org/abs/2312.12044) (Section 3).

**P.S.** Be aware, that benchmarks can change, as we are currently testing and balancing them!

Expand Down Expand Up @@ -249,20 +257,6 @@ Furthermore, we provide standalone implementations that can be trained in Colab:
available. How much fun would that be 🤔? However, we hope that they will
help to get started quickly!

## Roadmap 🗓️

With the initial release of XLand-MiniGrid, things are just getting started. There is a long way to go in
terms of polishing the code, adding new features, and improving the overall user experience. What we
currently plan to improve in forthcoming releases:
1. Tweaks to the benchmark generation, time-limits
2. Documentation (in code and as standalone site)
3. Full type hints coverage, type checking
4. Tests
5. More examples and tutorials

After that we will start thinking on new major features, environments and bechmarks.
However, we should perfect the core before that.

## Contributing 🔨

We welcome anyone interested in helping out! Please take a look at our [contribution guide](CONTRIBUTING.md)
Expand Down
10 changes: 6 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ classifiers = [
]

dependencies = [
"jax>=0.4.13",
"jaxlib>=0.4.13",
"flax>=0.7.0",
"jax>=0.4.16",
"jaxlib>=0.4.16",
"flax>=0.8.0",
"rich>=13.4.2",
]

Expand Down Expand Up @@ -96,7 +96,9 @@ exclude = [
"**/__pycache__",
]

reportMissingImports = true
reportMissingImports = "none"
reportMissingTypeStubs = false
reportMissingModuleSource = false

pythonVersion = "3.10"
pythonPlatform = "All"
2 changes: 1 addition & 1 deletion src/xminigrid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .registration import make, register, registered_environments

# TODO: add __all__
__version__ = "0.5.1"
__version__ = "0.6.0"

# ---------- XLand-MiniGrid environments ----------

Expand Down
5 changes: 2 additions & 3 deletions src/xminigrid/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import jax.numpy as jnp
import jax.tree_util as jtu
from flax import struct
from jax.random import KeyArray
from tqdm.auto import tqdm

from .types import RuleSet
Expand Down Expand Up @@ -43,11 +42,11 @@ def num_rulesets(self) -> int:
def get_ruleset(self, ruleset_id: int | jax.Array) -> RuleSet:
return get_ruleset(self.goals, self.rules, self.init_tiles, ruleset_id)

def sample_ruleset(self, key: KeyArray) -> RuleSet:
def sample_ruleset(self, key: jax.Array) -> RuleSet:
ruleset_id = jax.random.randint(key, shape=(), minval=0, maxval=self.num_rulesets())
return self.get_ruleset(ruleset_id)

def shuffle(self, key: KeyArray) -> Benchmark:
def shuffle(self, key: jax.Array) -> Benchmark:
idxs = jax.random.permutation(key, jnp.arange(len(self.num_rules)))
return jtu.tree_map(lambda a: a[idxs], self)

Expand Down
4 changes: 2 additions & 2 deletions src/xminigrid/core/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import jax.numpy as jnp
from typing_extensions import TypeAlias

from ..types import AgentState, GridState
from ..types import AgentState, GridState, IntOrArray
from .constants import DIRECTIONS, TILES_REGISTRY, Colors, Tiles
from .grid import check_can_put, check_pickable, check_walkable, equal

Expand Down Expand Up @@ -109,7 +109,7 @@ def toggle(grid: GridState, agent: AgentState) -> ActionOutput:
return new_grid, agent, next_position


def take_action(grid: GridState, agent: AgentState, action: int) -> ActionOutput:
def take_action(grid: GridState, agent: AgentState, action: IntOrArray) -> ActionOutput:
# This will evaluate all actions.
# Can we fix this and choose only one function? It'll speed everything up dramatically.
actions = (
Expand Down
31 changes: 15 additions & 16 deletions src/xminigrid/core/grid.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
from __future__ import annotations

from typing import Callable
from typing import Callable, Union

import jax
import jax.numpy as jnp
from jax.random import KeyArray

from ..types import GridState, Tile
from ..types import GridState, IntOrArray, Tile
from .constants import FREE_TO_PUT_DOWN, LOS_BLOCKING, PICKABLE, TILES_REGISTRY, WALKABLE, Colors, Tiles


def empty_world(height: int, width: int) -> GridState:
def empty_world(height: IntOrArray, width: IntOrArray) -> GridState:
grid = jnp.zeros((height, width, 2), dtype=jnp.uint8)
grid = grid.at[:, :, 0:2].set(TILES_REGISTRY[Tiles.FLOOR, Colors.BLACK])
return grid
Expand All @@ -21,7 +20,7 @@ def equal(tile1: Tile, tile2: Tile) -> Tile:
return jnp.all(jnp.equal(tile1, tile2))


def get_neighbouring_tiles(grid: GridState, y: int | jax.Array, x: int | jax.Array) -> tuple[Tile, Tile, Tile, Tile]:
def get_neighbouring_tiles(grid: GridState, y: IntOrArray, x: IntOrArray) -> tuple[Tile, Tile, Tile, Tile]:
# end_of_map = TILES_REGISTRY[Tiles.END_OF_MAP, Colors.END_OF_MAP]
end_of_map = Tiles.END_OF_MAP

Expand All @@ -36,31 +35,31 @@ def get_neighbouring_tiles(grid: GridState, y: int | jax.Array, x: int | jax.Arr
return up_tile, right_tile, down_tile, left_tile


def horizontal_line(grid: GridState, x: int, y: int, length: int, tile: Tile) -> GridState:
def horizontal_line(grid: GridState, x: IntOrArray, y: IntOrArray, length: IntOrArray, tile: Tile) -> GridState:
grid = grid.at[y, x : x + length].set(tile)
return grid


def vertical_line(grid: GridState, x: int, y: int, length: int, tile: Tile) -> GridState:
def vertical_line(grid: GridState, x: IntOrArray, y: IntOrArray, length: IntOrArray, tile: Tile) -> GridState:
grid = grid.at[y : y + length, x].set(tile)
return grid


def rectangle(grid: GridState, x: int, y: int, h: int, w: int, tile: Tile) -> GridState:
def rectangle(grid: GridState, x: IntOrArray, y: IntOrArray, h: IntOrArray, w: IntOrArray, tile: Tile) -> GridState:
grid = vertical_line(grid, x, y, h, tile)
grid = vertical_line(grid, x + w - 1, y, h, tile)
grid = horizontal_line(grid, x, y, w, tile)
grid = horizontal_line(grid, x, y + h - 1, w, tile)
return grid


def room(height: int, width: int) -> GridState:
def room(height: IntOrArray, width: IntOrArray) -> GridState:
grid = empty_world(height, width)
grid = rectangle(grid, 0, 0, height, width, tile=TILES_REGISTRY[Tiles.WALL, Colors.GREY])
return grid


def two_rooms(height: int, width: int) -> GridState:
def two_rooms(height: IntOrArray, width: IntOrArray) -> GridState:
wall_tile: Tile = TILES_REGISTRY[Tiles.WALL, Colors.GREY]

grid = empty_world(height, width)
Expand All @@ -69,7 +68,7 @@ def two_rooms(height: int, width: int) -> GridState:
return grid


def four_rooms(height: int, width: int) -> GridState:
def four_rooms(height: IntOrArray, width: IntOrArray) -> GridState:
wall_tile: Tile = TILES_REGISTRY[Tiles.WALL, Colors.GREY]

grid = empty_world(height, width)
Expand All @@ -79,7 +78,7 @@ def four_rooms(height: int, width: int) -> GridState:
return grid


def nine_rooms(height: int, width: int) -> GridState:
def nine_rooms(height: IntOrArray, width: IntOrArray) -> GridState:
wall_tile: Tile = TILES_REGISTRY[Tiles.WALL, Colors.GREY]

grid = empty_world(height, width)
Expand Down Expand Up @@ -118,7 +117,7 @@ def check_see_behind(grid: GridState, position: jax.Array) -> jax.Array:
return is_not_blocking


def align_with_up(grid: GridState, direction: int | jax.Array) -> GridState:
def align_with_up(grid: GridState, direction: IntOrArray) -> GridState:
aligned_grid = jax.lax.switch(
direction,
(
Expand Down Expand Up @@ -151,15 +150,15 @@ def free_tiles_mask(grid: GridState) -> jax.Array:
return mask


def coordinates_mask(grid: GridState, address: tuple[int, int], comparison_fn: Callable) -> jax.Array:
def coordinates_mask(grid: GridState, address: tuple[IntOrArray, IntOrArray], comparison_fn: Callable) -> jax.Array:
positions = jnp.mgrid[: grid.shape[0], : grid.shape[1]]
cond_1 = comparison_fn(positions[0], address[0])
cond_2 = comparison_fn(positions[1], address[1])
mask = jnp.logical_and(cond_1, cond_2)
return mask


def sample_coordinates(key: KeyArray, grid: GridState, num: int, mask: jax.Array | None = None) -> jax.Array:
def sample_coordinates(key: jax.Array, grid: GridState, num: int, mask: jax.Array | None = None) -> jax.Array:
if mask is None:
mask = jnp.ones((grid.shape[0], grid.shape[1]), dtype=jnp.bool_)

Expand All @@ -175,7 +174,7 @@ def sample_coordinates(key: KeyArray, grid: GridState, num: int, mask: jax.Array
return coords


def sample_direction(key: KeyArray) -> jax.Array:
def sample_direction(key: jax.Array) -> jax.Array:
return jax.random.randint(key, shape=(), minval=0, maxval=4)


Expand Down
Loading

0 comments on commit 98c4969

Please sign in to comment.