diff --git a/src/luxai_s3/env.py b/src/luxai_s3/env.py index 5fe0e90..df9b529 100644 --- a/src/luxai_s3/env.py +++ b/src/luxai_s3/env.py @@ -241,10 +241,11 @@ def update_unit_energy(unit: UnitState, mask): new_tile_types_map = jnp.where(state.steps * params.nebula_tile_drift_speed % 1 == 0, new_tile_types_map, state.map_features.tile_type) # new_energy_nodes = state.energy_nodes + jnp.array([1 * jnp.sign(params.energy_node_drift_speed), -1 * jnp.sign(params.energy_node_drift_speed)]) - energy_node_deltas = jnp.round(jax.random.uniform(key=key, shape=(params.max_energy_nodes, 2), minval=-params.energy_node_drift_magnitude, maxval=params.energy_node_drift_magnitude)).astype(jnp.int16) + energy_node_deltas = jnp.round(jax.random.uniform(key=key, shape=(params.max_energy_nodes // 2, 2), minval=-params.energy_node_drift_magnitude, maxval=params.energy_node_drift_magnitude)).astype(jnp.int16) + energy_node_deltas_symmetric = jnp.stack([-energy_node_deltas[:, 1], -energy_node_deltas[:, 0]], axis=-1) # TODO symmetric movement # energy_node_deltas = jnp.round(jax.random.uniform(key=key, shape=(params.max_energy_nodes // 2, 2), minval=-params.energy_node_drift_magnitude, maxval=params.energy_node_drift_magnitude)).astype(jnp.int16) - # energy_node_deltas = jnp.concatenate((energy_node_deltas, energy_node_deltas[::-1])) + energy_node_deltas = jnp.concatenate((energy_node_deltas, energy_node_deltas_symmetric)) new_energy_nodes = jnp.clip(state.energy_nodes + energy_node_deltas, min=jnp.array([0, 0]), max=jnp.array([params.map_width, params.map_height])) new_energy_nodes = jnp.where(state.steps * params.energy_node_drift_speed % 1 == 0, new_energy_nodes, state.energy_nodes) state = state.replace(map_features=state.map_features.replace(tile_type=new_tile_types_map), energy_nodes=new_energy_nodes) diff --git a/src/luxai_s3/params.py b/src/luxai_s3/params.py index daf1a3d..8cc11fa 100644 --- a/src/luxai_s3/params.py +++ b/src/luxai_s3/params.py @@ -5,7 +5,7 @@ @struct.dataclass class EnvParams: max_steps_in_match: int = 100 - map_type: int = 0 + map_type: int = 1 """Map generation algorithm. Can change between games""" map_width: int = 24 map_height: int = 24 @@ -37,12 +37,12 @@ class EnvParams: # configs for energy nodes - max_energy_nodes: int = 10 + max_energy_nodes: int = 6 max_energy_per_tile: int = 20 min_energy_per_tile: int = -20 - max_relic_nodes: int = 10 + max_relic_nodes: int = 6 relic_config_size: int = 5 fog_of_war: bool = True """ diff --git a/src/luxai_s3/state.py b/src/luxai_s3/state.py index 9a48d98..58d0baa 100644 --- a/src/luxai_s3/state.py +++ b/src/luxai_s3/state.py @@ -113,10 +113,12 @@ class EnvObs: def serialize_env_states(env_states: list[EnvState]): def serialize_array(root: EnvState, arr, key_path: str = ""): - if key_path in ["vision_power_map", "relic_nodes_mask", "energy_node_fns", "relic_nodes_map_weights"]: + if key_path in ["vision_power_map", "relic_nodes_mask", "energy_nodes_mask", "energy_node_fns", "relic_nodes_map_weights"]: return None if key_path == "relic_nodes": return root.relic_nodes[root.relic_nodes_mask].tolist() + if key_path == "relic_node_configs": + return root.relic_node_configs[root.relic_nodes_mask].tolist() if key_path == "energy_nodes": return root.energy_nodes[root.energy_nodes_mask].tolist() if isinstance(arr, jnp.ndarray): @@ -252,7 +254,7 @@ def spawn_unit( def set_tile(map_features: MapTile, x: int, y: int, tile_type: int) -> MapTile: return map_features.replace(tile_type=map_features.tile_type.at[x, y].set(tile_type)) - +# @functools.partial(jax.jit, static_argnums=(1,)) def gen_map(key: chex.PRNGKey, params: EnvParams) -> chex.Array: map_features = MapTile(energy=jnp.zeros( shape=(params.map_height, params.map_width), dtype=jnp.int16 @@ -260,97 +262,93 @@ def gen_map(key: chex.PRNGKey, params: EnvParams) -> chex.Array: shape=(params.map_height, params.map_width), dtype=jnp.int16 )) energy_nodes = jnp.zeros(shape=(params.max_energy_nodes, 2), dtype=jnp.int16) - energy_nodes_mask = jnp.zeros(shape=(params.max_energy_nodes), dtype=jnp.int16) + energy_nodes_mask = jnp.zeros(shape=(params.max_energy_nodes), dtype=jnp.bool) relic_nodes = jnp.zeros(shape=(params.max_relic_nodes, 2), dtype=jnp.int16) relic_nodes_mask = jnp.zeros(shape=(params.max_relic_nodes), dtype=jnp.bool) - if MAP_TYPES[params.map_type] == "dev0": - # assert params.map_height == 16 and params.map_width == 16 - map_features = set_tile(map_features, 4, 4, NEBULA_TILE) - map_features = set_tile(map_features, slice(3, 6), slice(2, 4), NEBULA_TILE) - map_features = set_tile(map_features, slice(4, 7), slice(6, 9), NEBULA_TILE) - map_features = set_tile(map_features, 4, 5, NEBULA_TILE) - map_features = set_tile(map_features, slice(9, 12), slice(5, 6), NEBULA_TILE) - map_features = set_tile(map_features, slice(14, 16), slice(12, 15), NEBULA_TILE) - - map_features = set_tile(map_features, slice(12, 15), slice(8, 10), ASTEROID_TILE) - map_features = set_tile(map_features, slice(1, 4), slice(6, 8), ASTEROID_TILE) - - map_features = set_tile(map_features, slice(11, 12), slice(3, 6), ASTEROID_TILE) - map_features = set_tile(map_features, slice(4, 5), slice(10, 13), ASTEROID_TILE) - map_features = set_tile(map_features,15, 0, ASTEROID_TILE) - - map_features = set_tile(map_features, 11, 11, NEBULA_TILE) - map_features = set_tile(map_features, 11, 12, NEBULA_TILE) - energy_nodes = energy_nodes.at[0, :].set(jnp.array([4, 4], dtype=jnp.int16)) - energy_nodes_mask = energy_nodes_mask.at[0].set(1) - energy_nodes = energy_nodes.at[1, :].set(jnp.array([19, 19], dtype=jnp.int16)) - energy_nodes_mask = energy_nodes_mask.at[1].set(1) + + if MAP_TYPES[params.map_type] == "random": + + ### Generate nebula tiles ### + key, subkey = jax.random.split(key) + perlin_noise = generate_perlin_noise_2d(subkey, (params.map_height, params.map_width), (4, 4)) + noise = jnp.where(perlin_noise > 0.5, 1, 0) + # mirror along diagonal + noise = noise | noise.T + noise = noise[::-1, ::1] + map_features = map_features.replace(tile_type=jnp.where(noise, NEBULA_TILE, 0)) + + ### Generate asteroid tiles ### + noise = jnp.where(perlin_noise < -0.5, 1, 0) + # mirror along diagonal + noise = noise | noise.T + noise = noise[::-1, ::1] + map_features = map_features.replace(tile_type=jnp.place(map_features.tile_type, noise, ASTEROID_TILE, inplace=False)) + + ### Generate relic nodes ### + key, subkey = jax.random.split(key) + noise = generate_perlin_noise_2d(subkey, (params.map_height, params.map_width), (4, 4)) + # Find the positions of the highest noise values + flat_indices = jnp.argsort(noise.ravel())[-params.max_relic_nodes // 2:] # Get indices of two highest values + highest_positions = jnp.column_stack(jnp.unravel_index(flat_indices, noise.shape)) + + # relic nodes have a fixed density of 25% nearby tiles can yield points + relic_node_configs = ( + jax.random.randint( + key, + shape=( + params.max_relic_nodes, + params.relic_config_size, + params.relic_config_size, + ), + minval=0, + maxval=10, + dtype=jnp.int16, + ) + >= 7.5 + ) + highest_positions = highest_positions.astype(jnp.int16) + relic_nodes_mask = relic_nodes_mask.at[0].set(True) + relic_nodes_mask = relic_nodes_mask.at[1].set(True) + mirrored_positions = jnp.stack([params.map_width - highest_positions[:, 1] - 1, params.map_height - highest_positions[:, 0] - 1], dtype=jnp.int16, axis=-1) + relic_nodes = jnp.concat([highest_positions, mirrored_positions], axis=0) + + key, subkey = jax.random.split(key) + relic_nodes_mask_half = jax.random.randint(key, (params.max_relic_nodes // 2, ), minval=0, maxval=2).astype(jnp.bool) + relic_nodes_mask_half = relic_nodes_mask_half.at[0].set(True) + relic_nodes_mask = relic_nodes_mask.at[:params.max_relic_nodes // 2].set(relic_nodes_mask_half) + relic_nodes_mask = relic_nodes_mask.at[params.max_relic_nodes // 2:].set(relic_nodes_mask_half) + # import ipdb;ipdb.set_trace() + relic_node_configs = relic_node_configs.at[params.max_relic_nodes // 2:].set(relic_node_configs[:params.max_relic_nodes // 2].transpose(0, 2, 1)[:, ::-1, ::-1]) + + ### Generate energy nodes ### + key, subkey = jax.random.split(key) + noise = generate_perlin_noise_2d(subkey, (params.map_height, params.map_width), (4, 4)) + # Find the positions of the highest noise values + flat_indices = jnp.argsort(noise.ravel())[-params.max_energy_nodes // 2:] # Get indices of highest values + highest_positions = jnp.column_stack(jnp.unravel_index(flat_indices, noise.shape)) + mirrored_positions = jnp.stack([params.map_width - highest_positions[:, 1] - 1, params.map_height - highest_positions[:, 0] - 1], dtype=jnp.int16, axis=-1) + energy_nodes = jnp.concat([highest_positions, mirrored_positions], axis=0) + key, subkey = jax.random.split(key) + energy_nodes_mask_half = jax.random.randint(key, (params.max_energy_nodes // 2, ), minval=0, maxval=2).astype(jnp.bool) + energy_nodes_mask_half = energy_nodes_mask_half.at[0].set(True) + energy_nodes_mask = energy_nodes_mask.at[:params.max_energy_nodes // 2].set(energy_nodes_mask_half) + energy_nodes_mask = energy_nodes_mask.at[params.max_energy_nodes // 2:].set(energy_nodes_mask_half) energy_node_fns = jnp.array( [ [0, 1.2, 1, 4], + [0, 0, 0, 0], + [0, 0, 0, 0], # [1, 4, 0, 2], [0, 1.2, 1, 4], + [0, 0, 0, 0], + [0, 0, 0, 0], # [1, 4, 0, 0] ] ) - energy_node_fns = jnp.concat([energy_node_fns, jnp.zeros((params.max_energy_nodes - 2, 4), dtype=jnp.float32)], axis=0) - - relic_node_configs = ( - jax.random.randint( - key, - shape=( - params.max_relic_nodes, - params.relic_config_size, - params.relic_config_size, - ), - minval=0, - maxval=10, - dtype=jnp.int16, - ) - >= 8 - ) - # elif params.map_type == "random": - # Apply the nebula tiles to the map_features - # map_features = map_features.replace(tile_type=jnp.where(nebula_map, NEBULA_TILE, EMPTY_TILE)) - - key, subkey = jax.random.split(key) - perlin_noise = generate_perlin_noise_2d(subkey, (params.map_height, params.map_width), (4, 4)) - noise = jnp.where(perlin_noise > 0.5, 1, 0) - noise = noise | noise.T - # Flip the noise matrix's rows and columns in reverse - noise = noise[::-1, ::1] - - map_features = map_features.replace(tile_type=jnp.where(noise, NEBULA_TILE, 0)) - - noise = jnp.where(perlin_noise < -0.6, 1, 0) - noise = noise | noise.T - # Flip the noise matrix's rows and columns in reverse - noise = noise[::-1, ::1] - - map_features = map_features.replace(tile_type=jnp.place(map_features.tile_type, noise, 2, inplace=False)) - - key, subkey = jax.random.split(key) - noise = generate_perlin_noise_2d(subkey, (params.map_height, params.map_width), (4, 4)) - # Find the positions of the two highest noise values - flat_indices = jnp.argsort(noise.ravel())[-2:] # Get indices of two highest values - highest_positions = jnp.column_stack(jnp.unravel_index(flat_indices, noise.shape)) - - # Convert to int16 to match the dtype of energy_nodes - highest_positions = highest_positions.astype(jnp.int16) - # Set relic nodes to the positions of highest noise values - relic_nodes = relic_nodes.at[0, :].set(highest_positions[0]) - relic_nodes_mask = relic_nodes_mask.at[0].set(True) - relic_nodes = relic_nodes.at[1, :].set(highest_positions[1]) - relic_nodes_mask = relic_nodes_mask.at[1].set(True) - mirrored_pos1 = jnp.array([params.map_width - highest_positions[0][1]-1, params.map_height - highest_positions[0][0]-1], dtype=jnp.int16) - mirrored_pos2 = jnp.array([params.map_width - highest_positions[1][1]-1, params.map_height - highest_positions[1][0]-1], dtype=jnp.int16) - # Set the mirrored positions for the other two relic nodes - relic_nodes = relic_nodes.at[2, :].set(mirrored_pos1) - relic_nodes_mask = relic_nodes_mask.at[2].set(True) - relic_nodes = relic_nodes.at[3, :].set(mirrored_pos2) - relic_nodes_mask = relic_nodes_mask.at[3].set(True) - relic_node_configs = relic_node_configs.at[2].set(relic_node_configs[0]) - relic_node_configs = relic_node_configs.at[3].set(relic_node_configs[1]) + # import ipdb; ipdb.set_trace() + # energy_node_fns = jnp.concat([energy_node_fns, jnp.zeros((params.max_energy_nodes - 2, 4), dtype=jnp.float32)], axis=0) + + return dict( map_features=map_features, energy_nodes=energy_nodes, diff --git a/src/luxai_s3/wrappers.py b/src/luxai_s3/wrappers.py index 21c63cc..f234ea4 100644 --- a/src/luxai_s3/wrappers.py +++ b/src/luxai_s3/wrappers.py @@ -26,7 +26,7 @@ def __init__(self, numpy_output: bool = False): # print("Running compilation steps") key = jax.random.key(0) # Reset the environment - dummy_env_params = EnvParams(map_type=0) + dummy_env_params = EnvParams(map_type=1) key, reset_key = jax.random.split(key) obs, state = self.jax_env.reset(reset_key, params=dummy_env_params) # Take a random action