Fix LMMAES #56

Open · wants to merge 7 commits into base: main
Changes from 1 commit
157 changes: 67 additions & 90 deletions evosax/strategies/lm_ma_es.py
@@ -1,10 +1,10 @@
-from typing import Tuple, Optional, Union
+from typing import Optional, Tuple, Union
 
+import chex
 import jax
 import jax.numpy as jnp
-import chex
+from evosax.strategy import Strategy
 from flax import struct
-from .cma_es import get_cma_elite_weights
-from ..strategy import Strategy
 
 
 @struct.dataclass
@@ -17,18 +17,15 @@ class EvoState:
     c_d: chex.Array
     weights_truncated: chex.Array
     best_member: chex.Array
+    z: chex.Array
     best_fitness: float = jnp.finfo(jnp.float32).max
     gen_counter: int = 0
 
 
 @struct.dataclass
 class EvoParams:
-    mu_eff: float
-    c_1: float
-    c_mu: float
     c_sigma: float
     d_sigma: float
-    chi_n: float
     mu_w: float
     c_m: float = 1.0
     sigma_init: float = 0.065
@@ -55,12 +52,7 @@ def __init__(
         Reference: https://arxiv.org/pdf/1705.06693.pdf
         """
         super().__init__(
-            popsize,
-            num_dims,
-            pholder_params,
-            mean_decay,
-            n_devices,
-            **fitness_kwargs
+            popsize, num_dims, pholder_params, mean_decay, n_devices, **fitness_kwargs
         )
         assert 0 <= elite_ratio <= 1
         self.elite_ratio = elite_ratio
@@ -77,53 +69,40 @@ def __init__(
     @property
     def params_strategy(self) -> EvoParams:
         """Return default parameters of evolution strategy."""
-        _, weights_truncated, mu_eff, c_1, c_mu = get_cma_elite_weights(
-            self.popsize, self.elite_popsize, self.num_dims, self.max_dims_sq
+        w_hat = jnp.array(
+            [
+                jnp.log(self.elite_popsize + 0.5) - jnp.log(i)
+                for i in range(1, self.elite_popsize + 1)
+            ]
         )
+        weights_truncated = w_hat / jnp.sum(w_hat)
         # lrate for cumulation of step-size control and rank-one update
-        c_sigma = (mu_eff + 2) / (self.num_dims + mu_eff + 5)
-        d_sigma = (
-            1
-            + 2
-            * jnp.maximum(0, jnp.sqrt((mu_eff - 1) / (self.num_dims + 1)) - 1)
-            + c_sigma
-        )
-        chi_n = jnp.sqrt(self.num_dims) * (
-            1.0
-            - (1.0 / (4.0 * self.num_dims))
-            + 1.0 / (21.0 * (self.max_dims_sq ** 2))
-        )
-        mu_w = 1 / jnp.sum(weights_truncated ** 2)
+        c_sigma = (2 * self.popsize) / self.num_dims
+        d_sigma = 2
+        mu_w = 1 / jnp.sum(jnp.square(weights_truncated))
         params = EvoParams(
-            mu_eff=mu_eff,
-            c_1=c_1,
-            c_mu=c_mu,
             c_sigma=c_sigma,
             d_sigma=d_sigma,
-            chi_n=chi_n,
             mu_w=mu_w,
             sigma_init=self.sigma_init,
         )
         return params

-    def initialize_strategy(
-        self, rng: chex.PRNGKey, params: EvoParams
-    ) -> EvoState:
+    def initialize_strategy(self, rng: chex.PRNGKey, params: EvoParams) -> EvoState:
         """`initialize` the evolution strategy."""
-        _, weights_truncated, _, _, _ = get_cma_elite_weights(
-            self.popsize, self.elite_popsize, self.num_dims, self.max_dims_sq
+        w_hat = jnp.array(
+            [
+                jnp.log(self.elite_popsize + 0.5) - jnp.log(i)
+                for i in range(1, self.elite_popsize + 1)
+            ]
         )
+        weights_truncated = w_hat / jnp.sum(w_hat)
         c_d = jnp.array(
-            [1 / (1.5 ** i * self.num_dims) for i in range(self.memory_size)]
+            [1 / (1.5**i * self.num_dims) for i in range(self.memory_size)]
         )
         c_c = jnp.array(
-            [
-                self.popsize / (4 ** i * self.num_dims)
-                for i in range(self.memory_size)
-            ]
+            [self.popsize / (4**i * self.num_dims) for i in range(self.memory_size)]
         )
         c_c = jnp.minimum(c_c, 1.99)
 
         # Initialize evolution paths & covariance matrix
         initialization = jax.random.uniform(
@@ -136,19 +115,20 @@ def initialize_strategy(
             p_sigma=jnp.zeros(self.num_dims),
             sigma=params.sigma_init,
             mean=initialization,
-            M=jnp.zeros((self.num_dims, self.memory_size)),
+            M=jnp.zeros((self.memory_size, self.num_dims)),
             weights_truncated=weights_truncated,
             c_d=c_d,
             c_c=c_c,
             best_member=initialization,
+            z=jnp.zeros((self.popsize, self.num_dims)),
         )
         return state

     def ask_strategy(
         self, rng: chex.PRNGKey, state: EvoState, params: EvoParams
     ) -> Tuple[chex.Array, EvoState]:
         """`ask` for new parameter candidates to evaluate next."""
-        x = sample(
+        x, z = sample(
             rng,
             state.mean,
             state.sigma,
@@ -158,7 +138,7 @@ def ask_strategy(
             state.c_d,
             state.gen_counter,
         )
-        return x, state
+        return x, state.replace(z=z)
Owner:
Maybe I am missing it, but where is the z needed? (popsize, num_dims) will require quite a lot of memory for many applications. Can we reconstruct this from the sampled x candidates instead?

Author (@arahi10), Jun 4, 2023:
Yes, this modification is the largest difference between the original evosax code and the numpy reference code in terms of memory use.
We have to use z to update p_sigma and M (please see the update functions), but recovering z from the solution vectors would require computing inverse matrices, which may cause both numerical errors and extra cost. The procedure for getting z from the solutions would look like this:
iteratively compute the inverse of (1 - c_d[j]) * I + c_d[j] * jnp.outer(M[j, :], M[j, :]), as in the sample function, selecting j by checking the generation counter.

Are there any linear algebra tricks that avoid computing inverse matrices? I can't come up with one.

Owner:

Mhmm, I am not sure -- but I guess this is a bit of a problem. It will be hard to scale to neuroevolution, even though this is the limited-memory version :( I will give it more thought.

Author:

Hello. If possible, could you please approve the workflow and run the tests before working out a new implementation? I am wondering whether there are any implementation mistakes at this point.
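
A note on the linear-algebra question above: each per-slot transform in sample is a rank-one update of the identity, (1 - c_d[j]) * I + c_d[j] * outer(M[j], M[j]), and the Sherman-Morrison formula gives the inverse of such a matrix in closed form, so z could in principle be reconstructed from the x candidates using dot products only, instead of being stored in the state. A minimal sketch under this PR's conventions; reconstruct_z is a hypothetical helper, not part of the diff:

import jax
import jax.numpy as jnp

def reconstruct_z(x, mean, sigma, M, c_d, gen_counter):
    """Hypothetical: recover z from sampled x without storing it in EvoState.

    sample() applies, for each memory slot j,
        d -> d @ ((1 - c_d[j]) * I + c_d[j] * outer(m, m)),  m = M[j, :].
    By Sherman-Morrison, (alpha * I + beta * outer(m, m))^-1
        = (I - gamma * outer(m, m)) / alpha,
    with gamma = beta / (alpha + beta * m @ m), so each slot can be undone
    without any explicit matrix inversion.
    """
    d = (x - mean) / sigma  # undo the affine map x = mean + sigma * d
    for j in reversed(range(M.shape[0])):  # invert the slots in reverse order
        m = M[j, :]
        alpha, beta = 1 - c_d[j], c_d[j]
        gamma = beta / (alpha + beta * jnp.dot(m, m))
        new_d = (d - gamma * jnp.outer(jnp.dot(d, m), m)) / alpha
        # Only slots that were active at sampling time (gen_counter > j) are undone
        d = jax.lax.select(gen_counter > j, new_d, d)
    return d  # equals the originally drawn z up to floating-point error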


     def tell_strategy(
         self,
@@ -171,87 +151,86 @@ def tell_strategy(
         # Sort new results, extract elite, store best performer
         concat_p_f = jnp.hstack([jnp.expand_dims(fitness, 1), x])
         sorted_solutions = concat_p_f[concat_p_f[:, 0].argsort()]
+        concat_z_f = jnp.hstack([jnp.expand_dims(fitness, 1), state.z])
+        sorted_zvectors = concat_z_f[concat_z_f[:, 0].argsort()]
+        sorted_z = sorted_zvectors[:, 1:]
+        if state.best_fitness < sorted_solutions[0, 0]:
+            state.best_fitness = sorted_solutions[0, 0]
+            state.best_member = sorted_solutions[0, 1:]
         # Update mean, isotropic/anisotropic paths, covariance, stepsize
-        mean, z_k = update_mean(
+        mean = update_mean(
             state.mean,
             state.sigma,
             sorted_solutions,
+            self.elite_popsize,
             params.c_m,
             state.weights_truncated,
         )
-        p_sigma, norm_p_sigma = update_p_sigma(
-            z_k,
+        p_sigma, norm_p_sigma, wz = update_p_sigma(
+            sorted_z,
+            self.elite_popsize,
             state.p_sigma,
             params.c_sigma,
-            params.mu_eff,
+            params.mu_w,
             state.weights_truncated,
         )
         M = update_M_matrix(
             state.M,
-            z_k,
+            wz,
             state.c_c,
             params.mu_w,
-            state.weights_truncated,
         )
         sigma = update_sigma(
             state.sigma,
             norm_p_sigma,
             params.c_sigma,
             params.d_sigma,
-            params.chi_n,
+            self.num_dims,
         )
         return state.replace(mean=mean, p_sigma=p_sigma, M=M, sigma=sigma)


 def update_mean(
     mean: chex.Array,
     sigma: float,
     sorted_solutions: chex.Array,
+    elite_popsize: int,
     c_m: float,
     weights_truncated: chex.Array,
 ) -> Tuple[chex.Array, chex.Array]:
     """Update mean of strategy."""
-    z_k = sorted_solutions[:, 1:] - mean  # ~ N(0, σ^2 C)
-    y_k = z_k / sigma  # ~ N(0, C)
-    y_w = jnp.sum(y_k.T * weights_truncated, axis=1)
-    mean += c_m * sigma * y_w
-    return mean, z_k
+    y_k = sorted_solutions[:elite_popsize, 1:] - mean  # ~ N(0, σ^2 C)
+    G_m = weights_truncated.T @ y_k
+    mean += c_m * G_m
+    return mean


 def update_p_sigma(
     z_k: chex.Array,
+    elite_popsize: int,
     p_sigma: chex.Array,
     c_sigma: float,
-    mu_eff: float,
+    mu_w: float,
     weights_truncated: chex.Array,
 ) -> Tuple[chex.Array, float]:
     """Update evolution path for covariance matrix."""
-    z_w = jnp.sum(z_k.T * weights_truncated, axis=1)
+    wz = weights_truncated.T @ (z_k[:elite_popsize, :])
     p_sigma_new = (1 - c_sigma) * p_sigma + jnp.sqrt(
-        c_sigma * (2 - c_sigma) * mu_eff
-    ) * z_w
+        mu_w * c_sigma * (2 - c_sigma)
+    ) * wz
     norm_p_sigma = jnp.linalg.norm(p_sigma_new)
-    return p_sigma_new, norm_p_sigma
+    return p_sigma_new, norm_p_sigma, wz


 def update_M_matrix(
     M: chex.Array,
-    z_k: chex.Array,
+    wz: chex.Array,
     c_c: chex.Array,
     mu_w: float,
-    weights_truncated: chex.Array,
 ) -> chex.Array:
     """Update the M matrix."""
-    weighted_elite = jnp.sum(
-        jnp.array([w * z for w, z in zip(weights_truncated, z_k)]),
-        axis=0,
-    )
     # Loop over individual memory components - this could be vectorized!
-    for i in range(M.shape[1]):
-        new_m = (1 - c_c[i]) * M[:, i] + jnp.sqrt(
-            mu_w * c_c[i] * (2 - c_c[i])
-        ) * weighted_elite
-        M = M.at[:, i].set(new_m)
+    for i in range(M.shape[0]):
+        new_m = (1 - c_c[i]) * M[i, :] + jnp.sqrt(mu_w * c_c[i] * (2 - c_c[i])) * wz
+        M = M.at[i, :].set(new_m)
     return M


@@ -260,11 +239,11 @@ def update_sigma(
     sigma: float,
     norm_p_sigma: float,
     c_sigma: float,
     d_sigma: float,
-    chi_n: float,
+    n: int,
 ) -> float:
     """Update stepsize sigma."""
     sigma_new = sigma * jnp.exp(
-        (c_sigma / d_sigma) * (norm_p_sigma / chi_n - 1)
+        (c_sigma / d_sigma) * (jnp.square(norm_p_sigma) / n - 1)
     )
     return sigma_new

@@ -278,15 +257,13 @@ def sample(
     pop_size: int,
     c_d: chex.Array,
     gen_counter: int,
-) -> chex.Array:
+) -> tuple[chex.Array, chex.Array]:
     """Jittable Gaussian Sample Helper."""
-    z = jax.random.normal(rng, (n_dim, pop_size))  # ~ N(0, I)
-    for j in range(M.shape[1]):
+    z = jax.random.normal(rng, (pop_size, n_dim))  # ~ N(0, I)
+    d = jnp.copy(z)
+    for j in range(M.shape[0]):
         update_bool = gen_counter > j
-        new_z = (1 - c_d[j]) * z + (c_d[j] * M[:, j])[:, jnp.newaxis] * (
-            M[:, j][:, jnp.newaxis] * z
-        )
-        z = jax.lax.select(update_bool, new_z, z)
-    z = jnp.swapaxes(z, 1, 0)
-    x = mean + sigma * z  # ~ N(m, σ^2 C)
-    return x
+        new_d = (1 - c_d[j]) * d + c_d[j] * jnp.outer(jnp.dot(d, M[j, :]), M[j, :])
+        d = jax.lax.select(update_bool, new_d, d)
+    x = mean + sigma * d  # ~ N(m, σ^2 C)
+    return x, z
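
For reference, the modified strategy would still be driven through evosax's standard ask/tell loop. A minimal usage sketch, assuming LM_MA_ES is exported from the package root and using a toy sphere objective (both are illustrative assumptions, not part of this PR):

import jax
import jax.numpy as jnp
from evosax import LM_MA_ES

rng = jax.random.PRNGKey(0)
strategy = LM_MA_ES(popsize=32, num_dims=100)
es_params = strategy.default_params
state = strategy.initialize(rng, es_params)

for _ in range(50):
    rng, rng_ask = jax.random.split(rng)
    # ask() draws candidates via sample(); with this PR it also stores z in state
    x, state = strategy.ask(rng_ask, state, es_params)
    fitness = jnp.sum(jnp.square(x), axis=1)  # toy sphere objective
    # tell() sorts by fitness and runs the mean / p_sigma / M / sigma updates
    state = strategy.tell(x, fitness, state, es_params)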