Skip to content

Commit

Permalink
symmetrize a lot of things
Browse files Browse the repository at this point in the history
  • Loading branch information
StoneT2000 committed Oct 7, 2024
1 parent ff651a3 commit d3bb931
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 89 deletions.
5 changes: 3 additions & 2 deletions src/luxai_s3/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,10 +241,11 @@ def update_unit_energy(unit: UnitState, mask):
new_tile_types_map = jnp.where(state.steps * params.nebula_tile_drift_speed % 1 == 0, new_tile_types_map, state.map_features.tile_type)
# new_energy_nodes = state.energy_nodes + jnp.array([1 * jnp.sign(params.energy_node_drift_speed), -1 * jnp.sign(params.energy_node_drift_speed)])

energy_node_deltas = jnp.round(jax.random.uniform(key=key, shape=(params.max_energy_nodes, 2), minval=-params.energy_node_drift_magnitude, maxval=params.energy_node_drift_magnitude)).astype(jnp.int16)
energy_node_deltas = jnp.round(jax.random.uniform(key=key, shape=(params.max_energy_nodes // 2, 2), minval=-params.energy_node_drift_magnitude, maxval=params.energy_node_drift_magnitude)).astype(jnp.int16)
energy_node_deltas_symmetric = jnp.stack([-energy_node_deltas[:, 1], -energy_node_deltas[:, 0]], axis=-1)
# TODO symmetric movement
# energy_node_deltas = jnp.round(jax.random.uniform(key=key, shape=(params.max_energy_nodes // 2, 2), minval=-params.energy_node_drift_magnitude, maxval=params.energy_node_drift_magnitude)).astype(jnp.int16)
# energy_node_deltas = jnp.concatenate((energy_node_deltas, energy_node_deltas[::-1]))
energy_node_deltas = jnp.concatenate((energy_node_deltas, energy_node_deltas_symmetric))
new_energy_nodes = jnp.clip(state.energy_nodes + energy_node_deltas, min=jnp.array([0, 0]), max=jnp.array([params.map_width, params.map_height]))
new_energy_nodes = jnp.where(state.steps * params.energy_node_drift_speed % 1 == 0, new_energy_nodes, state.energy_nodes)
state = state.replace(map_features=state.map_features.replace(tile_type=new_tile_types_map), energy_nodes=new_energy_nodes)
Expand Down
6 changes: 3 additions & 3 deletions src/luxai_s3/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
@struct.dataclass
class EnvParams:
max_steps_in_match: int = 100
map_type: int = 0
map_type: int = 1
"""Map generation algorithm. Can change between games"""
map_width: int = 24
map_height: int = 24
Expand Down Expand Up @@ -37,12 +37,12 @@ class EnvParams:


# configs for energy nodes
max_energy_nodes: int = 10
max_energy_nodes: int = 6
max_energy_per_tile: int = 20
min_energy_per_tile: int = -20


max_relic_nodes: int = 10
max_relic_nodes: int = 6
relic_config_size: int = 5
fog_of_war: bool = True
"""
Expand Down
164 changes: 81 additions & 83 deletions src/luxai_s3/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,12 @@ class EnvObs:

def serialize_env_states(env_states: list[EnvState]):
def serialize_array(root: EnvState, arr, key_path: str = ""):
if key_path in ["vision_power_map", "relic_nodes_mask", "energy_node_fns", "relic_nodes_map_weights"]:
if key_path in ["vision_power_map", "relic_nodes_mask", "energy_nodes_mask", "energy_node_fns", "relic_nodes_map_weights"]:
return None
if key_path == "relic_nodes":
return root.relic_nodes[root.relic_nodes_mask].tolist()
if key_path == "relic_node_configs":
return root.relic_node_configs[root.relic_nodes_mask].tolist()
if key_path == "energy_nodes":
return root.energy_nodes[root.energy_nodes_mask].tolist()
if isinstance(arr, jnp.ndarray):
Expand Down Expand Up @@ -252,105 +254,101 @@ def spawn_unit(
def set_tile(map_features: MapTile, x: int, y: int, tile_type: int) -> MapTile:
return map_features.replace(tile_type=map_features.tile_type.at[x, y].set(tile_type))


# @functools.partial(jax.jit, static_argnums=(1,))
def gen_map(key: chex.PRNGKey, params: EnvParams) -> chex.Array:
map_features = MapTile(energy=jnp.zeros(
shape=(params.map_height, params.map_width), dtype=jnp.int16
), tile_type=jnp.zeros(
shape=(params.map_height, params.map_width), dtype=jnp.int16
))
energy_nodes = jnp.zeros(shape=(params.max_energy_nodes, 2), dtype=jnp.int16)
energy_nodes_mask = jnp.zeros(shape=(params.max_energy_nodes), dtype=jnp.int16)
energy_nodes_mask = jnp.zeros(shape=(params.max_energy_nodes), dtype=jnp.bool)
relic_nodes = jnp.zeros(shape=(params.max_relic_nodes, 2), dtype=jnp.int16)
relic_nodes_mask = jnp.zeros(shape=(params.max_relic_nodes), dtype=jnp.bool)
if MAP_TYPES[params.map_type] == "dev0":
# assert params.map_height == 16 and params.map_width == 16
map_features = set_tile(map_features, 4, 4, NEBULA_TILE)
map_features = set_tile(map_features, slice(3, 6), slice(2, 4), NEBULA_TILE)
map_features = set_tile(map_features, slice(4, 7), slice(6, 9), NEBULA_TILE)
map_features = set_tile(map_features, 4, 5, NEBULA_TILE)
map_features = set_tile(map_features, slice(9, 12), slice(5, 6), NEBULA_TILE)
map_features = set_tile(map_features, slice(14, 16), slice(12, 15), NEBULA_TILE)

map_features = set_tile(map_features, slice(12, 15), slice(8, 10), ASTEROID_TILE)
map_features = set_tile(map_features, slice(1, 4), slice(6, 8), ASTEROID_TILE)

map_features = set_tile(map_features, slice(11, 12), slice(3, 6), ASTEROID_TILE)
map_features = set_tile(map_features, slice(4, 5), slice(10, 13), ASTEROID_TILE)
map_features = set_tile(map_features,15, 0, ASTEROID_TILE)

map_features = set_tile(map_features, 11, 11, NEBULA_TILE)
map_features = set_tile(map_features, 11, 12, NEBULA_TILE)
energy_nodes = energy_nodes.at[0, :].set(jnp.array([4, 4], dtype=jnp.int16))
energy_nodes_mask = energy_nodes_mask.at[0].set(1)
energy_nodes = energy_nodes.at[1, :].set(jnp.array([19, 19], dtype=jnp.int16))
energy_nodes_mask = energy_nodes_mask.at[1].set(1)

if MAP_TYPES[params.map_type] == "random":

### Generate nebula tiles ###
key, subkey = jax.random.split(key)
perlin_noise = generate_perlin_noise_2d(subkey, (params.map_height, params.map_width), (4, 4))
noise = jnp.where(perlin_noise > 0.5, 1, 0)
# mirror along diagonal
noise = noise | noise.T
noise = noise[::-1, ::1]
map_features = map_features.replace(tile_type=jnp.where(noise, NEBULA_TILE, 0))

### Generate asteroid tiles ###
noise = jnp.where(perlin_noise < -0.5, 1, 0)
# mirror along diagonal
noise = noise | noise.T
noise = noise[::-1, ::1]
map_features = map_features.replace(tile_type=jnp.place(map_features.tile_type, noise, ASTEROID_TILE, inplace=False))

### Generate relic nodes ###
key, subkey = jax.random.split(key)
noise = generate_perlin_noise_2d(subkey, (params.map_height, params.map_width), (4, 4))
# Find the positions of the highest noise values
flat_indices = jnp.argsort(noise.ravel())[-params.max_relic_nodes // 2:] # Get indices of two highest values
highest_positions = jnp.column_stack(jnp.unravel_index(flat_indices, noise.shape))

# relic nodes have a fixed density of 25% nearby tiles can yield points
relic_node_configs = (
jax.random.randint(
key,
shape=(
params.max_relic_nodes,
params.relic_config_size,
params.relic_config_size,
),
minval=0,
maxval=10,
dtype=jnp.int16,
)
>= 7.5
)
highest_positions = highest_positions.astype(jnp.int16)
relic_nodes_mask = relic_nodes_mask.at[0].set(True)
relic_nodes_mask = relic_nodes_mask.at[1].set(True)
mirrored_positions = jnp.stack([params.map_width - highest_positions[:, 1] - 1, params.map_height - highest_positions[:, 0] - 1], dtype=jnp.int16, axis=-1)
relic_nodes = jnp.concat([highest_positions, mirrored_positions], axis=0)

key, subkey = jax.random.split(key)
relic_nodes_mask_half = jax.random.randint(key, (params.max_relic_nodes // 2, ), minval=0, maxval=2).astype(jnp.bool)
relic_nodes_mask_half = relic_nodes_mask_half.at[0].set(True)
relic_nodes_mask = relic_nodes_mask.at[:params.max_relic_nodes // 2].set(relic_nodes_mask_half)
relic_nodes_mask = relic_nodes_mask.at[params.max_relic_nodes // 2:].set(relic_nodes_mask_half)
# import ipdb;ipdb.set_trace()
relic_node_configs = relic_node_configs.at[params.max_relic_nodes // 2:].set(relic_node_configs[:params.max_relic_nodes // 2].transpose(0, 2, 1)[:, ::-1, ::-1])

### Generate energy nodes ###
key, subkey = jax.random.split(key)
noise = generate_perlin_noise_2d(subkey, (params.map_height, params.map_width), (4, 4))
# Find the positions of the highest noise values
flat_indices = jnp.argsort(noise.ravel())[-params.max_energy_nodes // 2:] # Get indices of highest values
highest_positions = jnp.column_stack(jnp.unravel_index(flat_indices, noise.shape))
mirrored_positions = jnp.stack([params.map_width - highest_positions[:, 1] - 1, params.map_height - highest_positions[:, 0] - 1], dtype=jnp.int16, axis=-1)
energy_nodes = jnp.concat([highest_positions, mirrored_positions], axis=0)
key, subkey = jax.random.split(key)
energy_nodes_mask_half = jax.random.randint(key, (params.max_energy_nodes // 2, ), minval=0, maxval=2).astype(jnp.bool)
energy_nodes_mask_half = energy_nodes_mask_half.at[0].set(True)
energy_nodes_mask = energy_nodes_mask.at[:params.max_energy_nodes // 2].set(energy_nodes_mask_half)
energy_nodes_mask = energy_nodes_mask.at[params.max_energy_nodes // 2:].set(energy_nodes_mask_half)
energy_node_fns = jnp.array(
[
[0, 1.2, 1, 4],
[0, 0, 0, 0],
[0, 0, 0, 0],
# [1, 4, 0, 2],
[0, 1.2, 1, 4],
[0, 0, 0, 0],
[0, 0, 0, 0],
# [1, 4, 0, 0]
]
)
energy_node_fns = jnp.concat([energy_node_fns, jnp.zeros((params.max_energy_nodes - 2, 4), dtype=jnp.float32)], axis=0)

relic_node_configs = (
jax.random.randint(
key,
shape=(
params.max_relic_nodes,
params.relic_config_size,
params.relic_config_size,
),
minval=0,
maxval=10,
dtype=jnp.int16,
)
>= 8
)
# elif params.map_type == "random":
# Apply the nebula tiles to the map_features
# map_features = map_features.replace(tile_type=jnp.where(nebula_map, NEBULA_TILE, EMPTY_TILE))

key, subkey = jax.random.split(key)
perlin_noise = generate_perlin_noise_2d(subkey, (params.map_height, params.map_width), (4, 4))
noise = jnp.where(perlin_noise > 0.5, 1, 0)
noise = noise | noise.T
# Flip the noise matrix's rows and columns in reverse
noise = noise[::-1, ::1]

map_features = map_features.replace(tile_type=jnp.where(noise, NEBULA_TILE, 0))

noise = jnp.where(perlin_noise < -0.6, 1, 0)
noise = noise | noise.T
# Flip the noise matrix's rows and columns in reverse
noise = noise[::-1, ::1]

map_features = map_features.replace(tile_type=jnp.place(map_features.tile_type, noise, 2, inplace=False))

key, subkey = jax.random.split(key)
noise = generate_perlin_noise_2d(subkey, (params.map_height, params.map_width), (4, 4))
# Find the positions of the two highest noise values
flat_indices = jnp.argsort(noise.ravel())[-2:] # Get indices of two highest values
highest_positions = jnp.column_stack(jnp.unravel_index(flat_indices, noise.shape))

# Convert to int16 to match the dtype of energy_nodes
highest_positions = highest_positions.astype(jnp.int16)
# Set relic nodes to the positions of highest noise values
relic_nodes = relic_nodes.at[0, :].set(highest_positions[0])
relic_nodes_mask = relic_nodes_mask.at[0].set(True)
relic_nodes = relic_nodes.at[1, :].set(highest_positions[1])
relic_nodes_mask = relic_nodes_mask.at[1].set(True)
mirrored_pos1 = jnp.array([params.map_width - highest_positions[0][1]-1, params.map_height - highest_positions[0][0]-1], dtype=jnp.int16)
mirrored_pos2 = jnp.array([params.map_width - highest_positions[1][1]-1, params.map_height - highest_positions[1][0]-1], dtype=jnp.int16)
# Set the mirrored positions for the other two relic nodes
relic_nodes = relic_nodes.at[2, :].set(mirrored_pos1)
relic_nodes_mask = relic_nodes_mask.at[2].set(True)
relic_nodes = relic_nodes.at[3, :].set(mirrored_pos2)
relic_nodes_mask = relic_nodes_mask.at[3].set(True)
relic_node_configs = relic_node_configs.at[2].set(relic_node_configs[0])
relic_node_configs = relic_node_configs.at[3].set(relic_node_configs[1])
# import ipdb; ipdb.set_trace()
# energy_node_fns = jnp.concat([energy_node_fns, jnp.zeros((params.max_energy_nodes - 2, 4), dtype=jnp.float32)], axis=0)


return dict(
map_features=map_features,
energy_nodes=energy_nodes,
Expand Down
2 changes: 1 addition & 1 deletion src/luxai_s3/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(self, numpy_output: bool = False):
# print("Running compilation steps")
key = jax.random.key(0)
# Reset the environment
dummy_env_params = EnvParams(map_type=0)
dummy_env_params = EnvParams(map_type=1)
key, reset_key = jax.random.split(key)
obs, state = self.jax_env.reset(reset_key, params=dummy_env_params)
# Take a random action
Expand Down

0 comments on commit d3bb931

Please sign in to comment.