From 038b179549ff06c4ebd9326d689e6eeef7ed0761 Mon Sep 17 00:00:00 2001
From: arahi10
Date: Sun, 4 Jun 2023 14:58:43 +0900
Subject: [PATCH 1/7] Fix bugs in LM-MA-ES: default parameters and parameter
 update expressions.

---
 evosax/strategies/lm_ma_es.py | 157 +++++++++++++++-------------------
 1 file changed, 67 insertions(+), 90 deletions(-)

diff --git a/evosax/strategies/lm_ma_es.py b/evosax/strategies/lm_ma_es.py
index deba2a0..6cb1135 100644
--- a/evosax/strategies/lm_ma_es.py
+++ b/evosax/strategies/lm_ma_es.py
@@ -1,10 +1,10 @@
-from typing import Tuple, Optional, Union
+from typing import Optional, Tuple, Union
+
+import chex
 import jax
 import jax.numpy as jnp
-import chex
+from evosax.strategy import Strategy
 from flax import struct
-from .cma_es import get_cma_elite_weights
-from ..strategy import Strategy
 
 
 @struct.dataclass
@@ -17,18 +17,15 @@ class EvoState:
     c_d: chex.Array
     weights_truncated: chex.Array
     best_member: chex.Array
+    z: chex.Array
     best_fitness: float = jnp.finfo(jnp.float32).max
     gen_counter: int = 0
 
 
 @struct.dataclass
 class EvoParams:
-    mu_eff: float
-    c_1: float
-    c_mu: float
     c_sigma: float
     d_sigma: float
-    chi_n: float
     mu_w: float
     c_m: float = 1.0
     sigma_init: float = 0.065
@@ -55,12 +52,7 @@ def __init__(
         Reference: https://arxiv.org/pdf/1705.06693.pdf
         """
         super().__init__(
-            popsize,
-            num_dims,
-            pholder_params,
-            mean_decay,
-            n_devices,
-            **fitness_kwargs
+            popsize, num_dims, pholder_params, mean_decay, n_devices, **fitness_kwargs
         )
         assert 0 <= elite_ratio <= 1
         self.elite_ratio = elite_ratio
@@ -77,53 +69,40 @@ def __init__(
     @property
     def params_strategy(self) -> EvoParams:
         """Return default parameters of evolution strategy."""
-        _, weights_truncated, mu_eff, c_1, c_mu = get_cma_elite_weights(
-            self.popsize, self.elite_popsize, self.num_dims, self.max_dims_sq
+        w_hat = jnp.array(
+            [
+                jnp.log(self.elite_popsize + 0.5) - jnp.log(i)
+                for i in range(1, self.elite_popsize + 1)
+            ]
         )
-
+        weights_truncated = w_hat / jnp.sum(w_hat)
         # lrate for cumulation of step-size control and rank-one update
-        c_sigma = (mu_eff + 2) / (self.num_dims + mu_eff + 5)
-        d_sigma = (
-            1
-            + 2
-            * jnp.maximum(0, jnp.sqrt((mu_eff - 1) / (self.num_dims + 1)) - 1)
-            + c_sigma
-        )
-        chi_n = jnp.sqrt(self.num_dims) * (
-            1.0
-            - (1.0 / (4.0 * self.num_dims))
-            + 1.0 / (21.0 * (self.max_dims_sq ** 2))
-        )
-        mu_w = 1 / jnp.sum(weights_truncated ** 2)
+        c_sigma = (2 * self.popsize) / self.num_dims
+        d_sigma = 2
+        mu_w = 1 / jnp.sum(jnp.square(weights_truncated))
         params = EvoParams(
-            mu_eff=mu_eff,
-            c_1=c_1,
-            c_mu=c_mu,
             c_sigma=c_sigma,
             d_sigma=d_sigma,
-            chi_n=chi_n,
             mu_w=mu_w,
             sigma_init=self.sigma_init,
         )
         return params
 
-    def initialize_strategy(
-        self, rng: chex.PRNGKey, params: EvoParams
-    ) -> EvoState:
+    def initialize_strategy(self, rng: chex.PRNGKey, params: EvoParams) -> EvoState:
         """`initialize` the evolution strategy."""
-        _, weights_truncated, _, _, _ = get_cma_elite_weights(
-            self.popsize, self.elite_popsize, self.num_dims, self.max_dims_sq
+        w_hat = jnp.array(
+            [
+                jnp.log(self.elite_popsize + 0.5) - jnp.log(i)
+                for i in range(1, self.elite_popsize + 1)
+            ]
         )
+        weights_truncated = w_hat / jnp.sum(w_hat)
         c_d = jnp.array(
-            [1 / (1.5 ** i * self.num_dims) for i in range(self.memory_size)]
+            [1 / (1.5**i * self.num_dims) for i in range(self.memory_size)]
         )
         c_c = jnp.array(
-            [
-                self.popsize / (4 ** i * self.num_dims)
-                for i in range(self.memory_size)
-            ]
+            [self.popsize / (4**i * self.num_dims) for i in range(self.memory_size)]
         )
-        c_c = jnp.minimum(c_c, 1.99)
         # Initialize evolution paths & covariance matrix
         initialization = jax.random.uniform(
@@ -136,11 +115,12 @@ def initialize_strategy(
             p_sigma=jnp.zeros(self.num_dims),
             sigma=params.sigma_init,
             mean=initialization,
-            M=jnp.zeros((self.num_dims, self.memory_size)),
+            M=jnp.zeros((self.memory_size, self.num_dims)),
             weights_truncated=weights_truncated,
             c_d=c_d,
             c_c=c_c,
             best_member=initialization,
+            z=jnp.zeros((self.popsize, self.num_dims)),
         )
         return state
 
@@ -148,7 +128,7 @@ def ask_strategy(
         self, rng: chex.PRNGKey, state: EvoState, params: EvoParams
     ) -> Tuple[chex.Array, EvoState]:
         """`ask` for new parameter candidates to evaluate next."""
-        x = sample(
+        x, z = sample(
             rng,
             state.mean,
             state.sigma,
@@ -158,7 +138,7 @@ def ask_strategy(
             state.c_d,
             state.gen_counter,
         )
-        return x, state
+        return x, state.replace(z=z)
 
     def tell_strategy(
         self,
@@ -171,87 +151,86 @@ def tell_strategy(
         # Sort new results, extract elite, store best performer
         concat_p_f = jnp.hstack([jnp.expand_dims(fitness, 1), x])
         sorted_solutions = concat_p_f[concat_p_f[:, 0].argsort()]
+        concat_z_f = jnp.hstack([jnp.expand_dims(fitness, 1), state.z])
+        sorted_zvectors = concat_z_f[concat_z_f[:, 0].argsort()]
+        sorted_z = sorted_zvectors[:, 1:]
+        if state.best_fitness < sorted_solutions[0, 0]:
+            state.best_fitness = sorted_solutions[0, 0]
+            state.best_member = sorted_solutions[0, 1:]
         # Update mean, isotropic/anisotropic paths, covariance, stepsize
-        mean, z_k = update_mean(
+        mean = update_mean(
             state.mean,
-            state.sigma,
             sorted_solutions,
+            self.elite_popsize,
             params.c_m,
             state.weights_truncated,
         )
-        p_sigma, norm_p_sigma = update_p_sigma(
-            z_k,
+        p_sigma, norm_p_sigma, wz = update_p_sigma(
+            sorted_z,
+            self.elite_popsize,
             state.p_sigma,
             params.c_sigma,
-            params.mu_eff,
+            params.mu_w,
            state.weights_truncated,
        )
         M = update_M_matrix(
             state.M,
-            z_k,
+            wz,
             state.c_c,
             params.mu_w,
-            state.weights_truncated,
         )
         sigma = update_sigma(
             state.sigma,
             norm_p_sigma,
             params.c_sigma,
             params.d_sigma,
-            params.chi_n,
+            self.num_dims,
         )
         return state.replace(mean=mean, p_sigma=p_sigma, M=M, sigma=sigma)
 
 
 def update_mean(
     mean: chex.Array,
-    sigma: float,
     sorted_solutions: chex.Array,
+    elite_popsize: int,
     c_m: float,
     weights_truncated: chex.Array,
 ) -> Tuple[chex.Array, chex.Array]:
     """Update mean of strategy."""
-    z_k = sorted_solutions[:, 1:] - mean  # ~ N(0, σ^2 C)
-    y_k = z_k / sigma  # ~ N(0, C)
-    y_w = jnp.sum(y_k.T * weights_truncated, axis=1)
-    mean += c_m * sigma * y_w
-    return mean, z_k
+    y_k = sorted_solutions[:elite_popsize, 1:] - mean  # ~ N(0, σ^2 C)
+    G_m = weights_truncated.T @ y_k
+    mean += c_m * G_m
+    return mean
 
 
 def update_p_sigma(
     z_k: chex.Array,
+    elite_popsize: int,
     p_sigma: chex.Array,
     c_sigma: float,
-    mu_eff: float,
+    mu_w: float,
     weights_truncated: chex.Array,
 ) -> Tuple[chex.Array, float]:
     """Update evolution path for covariance matrix."""
-    z_w = jnp.sum(z_k.T * weights_truncated, axis=1)
+    wz = weights_truncated.T @ (z_k[:elite_popsize, :])
     p_sigma_new = (1 - c_sigma) * p_sigma + jnp.sqrt(
-        c_sigma * (2 - c_sigma) * mu_eff
-    ) * z_w
+        mu_w * c_sigma * (2 - c_sigma)
+    ) * wz
     norm_p_sigma = jnp.linalg.norm(p_sigma_new)
-    return p_sigma_new, norm_p_sigma
+    return p_sigma_new, norm_p_sigma, wz
 
 
 def update_M_matrix(
     M: chex.Array,
-    z_k: chex.Array,
+    wz: chex.Array,
     c_c: chex.Array,
     mu_w: float,
-    weights_truncated: chex.Array,
 ) -> chex.Array:
     """Update the M matrix."""
-    weighted_elite = jnp.sum(
-        jnp.array([w * z for w, z in zip(weights_truncated, z_k)]),
-        axis=0,
-    )
     # Loop over individual memory components - this could be vectorized!
-    for i in range(M.shape[1]):
-        new_m = (1 - c_c[i]) * M[:, i] + jnp.sqrt(
-            mu_w * c_c[i] * (2 - c_c[i])
-        ) * weighted_elite
-        M = M.at[:, i].set(new_m)
+    for i in range(M.shape[0]):
+        new_m = (1 - c_c[i]) * M[i, :] + jnp.sqrt(mu_w * c_c[i] * (2 - c_c[i])) * wz
+        M = M.at[i, :].set(new_m)
     return M
 
 
@@ -260,11 +239,11 @@ def update_sigma(
     norm_p_sigma: float,
     c_sigma: float,
     d_sigma: float,
-    chi_n: float,
+    n: int,
 ) -> float:
     """Update stepsize sigma."""
     sigma_new = sigma * jnp.exp(
-        (c_sigma / d_sigma) * (norm_p_sigma / chi_n - 1)
+        (c_sigma / d_sigma) * (jnp.square(norm_p_sigma) / n - 1)
     )
     return sigma_new
 
 
@@ -278,15 +257,13 @@ def sample(
     pop_size: int,
     c_d: chex.Array,
     gen_counter: int,
-) -> chex.Array:
+) -> tuple[chex.Array, chex.Array]:
     """Jittable Gaussian Sample Helper."""
-    z = jax.random.normal(rng, (n_dim, pop_size))  # ~ N(0, I)
-    for j in range(M.shape[1]):
+    z = jax.random.normal(rng, (pop_size, n_dim))  # ~ N(0, I)
+    d = jnp.copy(z)
+    for j in range(M.shape[0]):
         update_bool = gen_counter > j
-        new_z = (1 - c_d[j]) * z + (c_d[j] * M[:, j])[:, jnp.newaxis] * (
-            M[:, j][:, jnp.newaxis] * z
-        )
-        z = jax.lax.select(update_bool, new_z, z)
-    z = jnp.swapaxes(z, 1, 0)
-    x = mean + sigma * z  # ~ N(m, σ^2 C)
-    return x
+        new_d = (1 - c_d[j]) * d + c_d[j] * jnp.outer(jnp.dot(d, M[j, :]), M[j, :])
+        d = jax.lax.select(update_bool, new_d, d)
+    x = mean + sigma * d  # ~ N(m, σ^2 C)
+    return x, z

From 2381be70f27b88b2f12a8d98458e58f86b671171 Mon Sep 17 00:00:00 2001
From: arahi10
Date: Sun, 4 Jun 2023 15:22:32 +0900
Subject: [PATCH 2/7] Fix import

---
 evosax/strategies/lm_ma_es.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/evosax/strategies/lm_ma_es.py b/evosax/strategies/lm_ma_es.py
index 6cb1135..5c42b0a 100644
--- a/evosax/strategies/lm_ma_es.py
+++ b/evosax/strategies/lm_ma_es.py
@@ -1,10 +1,9 @@
 from typing import Optional, Tuple, Union
-
-import chex
 import jax
 import jax.numpy as jnp
-from evosax.strategy import Strategy
+import chex
 from flax import struct
+from ..strategy import Strategy
 
 
 @struct.dataclass

From 8e069a0802044e78631b84d8a8e211b1486b87b3 Mon Sep 17 00:00:00 2001
From: arahi10
Date: Sun, 4 Jun 2023 15:54:32 +0900
Subject: [PATCH 3/7] Fix LMMAES

---
 evosax/strategies/lm_ma_es.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/evosax/strategies/lm_ma_es.py b/evosax/strategies/lm_ma_es.py
index 5c42b0a..5a63e03 100644
--- a/evosax/strategies/lm_ma_es.py
+++ b/evosax/strategies/lm_ma_es.py
@@ -153,9 +153,6 @@ def tell_strategy(
         concat_z_f = jnp.hstack([jnp.expand_dims(fitness, 1), state.z])
         sorted_zvectors = concat_z_f[concat_z_f[:, 0].argsort()]
         sorted_z = sorted_zvectors[:, 1:]
-        if state.best_fitness < sorted_solutions[0, 0]:
-            state.best_fitness = sorted_solutions[0, 0]
-            state.best_member = sorted_solutions[0, 1:]
         # Update mean, isotropic/anisotropic paths, covariance, stepsize
         mean = update_mean(
             state.mean,

From 6d8b3aeb5c31d152ddac7507163a2432c0e50c40 Mon Sep 17 00:00:00 2001
From: arahi10
Date: Sun, 4 Jun 2023 23:39:41 +0900
Subject: [PATCH 4/7] Fix LMMAES; add update of best fitness and best member.

---
 evosax/strategies/lm_ma_es.py | 19 ++++++++++++++++++-
 1 file changed, 18 insertions(+), 1 deletion(-)

diff --git a/evosax/strategies/lm_ma_es.py b/evosax/strategies/lm_ma_es.py
index 5a63e03..ac38797 100644
--- a/evosax/strategies/lm_ma_es.py
+++ b/evosax/strategies/lm_ma_es.py
@@ -153,6 +153,16 @@ def tell_strategy(
         concat_z_f = jnp.hstack([jnp.expand_dims(fitness, 1), state.z])
         sorted_zvectors = concat_z_f[concat_z_f[:, 0].argsort()]
         sorted_z = sorted_zvectors[:, 1:]
+        new_best_fitness = jax.lax.select(
+            state.best_fitness > sorted_solutions[0, 0],
+            sorted_solutions[0, 0],
+            state.best_fitness,
+        )
+        new_best_member = jax.lax.select(
+            state.best_fitness > sorted_solutions[0, 0],
+            sorted_solutions[0, 1:],
+            state.best_member,
+        )
         # Update mean, isotropic/anisotropic paths, covariance, stepsize
         mean = update_mean(
             state.mean,
@@ -182,7 +192,14 @@ def tell_strategy(
             params.d_sigma,
             self.num_dims,
         )
-        return state.replace(mean=mean, p_sigma=p_sigma, M=M, sigma=sigma)
+        return state.replace(
+            mean=mean,
+            p_sigma=p_sigma,
+            M=M,
+            sigma=sigma,
+            best_fitness=new_best_fitness,
+            best_member=new_best_member,
+        )
 
 
 def update_mean(

From 7f9476c04fa61dfdd60030f9330262b12e9340c6 Mon Sep 17 00:00:00 2001
From: arahi10
Date: Mon, 5 Jun 2023 00:47:00 +0900
Subject: [PATCH 5/7] Fix LMMAES

Updating best_fitness and best_member is done in Strategy's tell function,
so we do not need it here.
---
 evosax/strategies/lm_ma_es.py | 12 +-----------
 1 file changed, 1 insertion(+), 11 deletions(-)

diff --git a/evosax/strategies/lm_ma_es.py b/evosax/strategies/lm_ma_es.py
index ac38797..66a8100 100644
--- a/evosax/strategies/lm_ma_es.py
+++ b/evosax/strategies/lm_ma_es.py
@@ -147,22 +147,12 @@ def tell_strategy(
         params: chex.ArrayTree,
     ) -> EvoState:
         """`tell` performance data for strategy state update."""
-        # Sort new results, extract elite, store best performer
+        # Sort new results, extract elite
         concat_p_f = jnp.hstack([jnp.expand_dims(fitness, 1), x])
         sorted_solutions = concat_p_f[concat_p_f[:, 0].argsort()]
         concat_z_f = jnp.hstack([jnp.expand_dims(fitness, 1), state.z])
         sorted_zvectors = concat_z_f[concat_z_f[:, 0].argsort()]
         sorted_z = sorted_zvectors[:, 1:]
-        new_best_fitness = jax.lax.select(
-            state.best_fitness > sorted_solutions[0, 0],
-            sorted_solutions[0, 0],
-            state.best_fitness,
-        )
-        new_best_member = jax.lax.select(
-            state.best_fitness > sorted_solutions[0, 0],
-            sorted_solutions[0, 1:],
-            state.best_member,
-        )
         # Update mean, isotropic/anisotropic paths, covariance, stepsize
         mean = update_mean(
             state.mean,

From 81a1cc264ea73369690e0ab3c41c535e03611bab Mon Sep 17 00:00:00 2001
From: arahi10
Date: Mon, 5 Jun 2023 00:50:47 +0900
Subject: [PATCH 6/7] Fix LMMAES

Updating best_fitness and best_member is done in Strategy's tell function,
so we do not need it here.
---
 evosax/strategies/lm_ma_es.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/evosax/strategies/lm_ma_es.py b/evosax/strategies/lm_ma_es.py
index 66a8100..b53549c 100644
--- a/evosax/strategies/lm_ma_es.py
+++ b/evosax/strategies/lm_ma_es.py
@@ -187,8 +187,6 @@ def tell_strategy(
             p_sigma=p_sigma,
             M=M,
             sigma=sigma,
-            best_fitness=new_best_fitness,
-            best_member=new_best_member,
         )
 
 
 def update_mean(

From f17f77e6465427b806d178a01bc034acbbb44af6 Mon Sep 17 00:00:00 2001
From: arahi10
Date: Mon, 5 Jun 2023 08:13:18 +0900
Subject: [PATCH 7/7] Fix LMMAES

Fix type check typo.
---
 evosax/strategies/lm_ma_es.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/evosax/strategies/lm_ma_es.py b/evosax/strategies/lm_ma_es.py
index b53549c..3b09be7 100644
--- a/evosax/strategies/lm_ma_es.py
+++ b/evosax/strategies/lm_ma_es.py
@@ -258,7 +258,7 @@ def sample(
     pop_size: int,
     c_d: chex.Array,
     gen_counter: int,
-) -> tuple[chex.Array, chex.Array]:
+) -> Tuple[chex.Array, chex.Array]:
     """Jittable Gaussian Sample Helper."""
     z = jax.random.normal(rng, (pop_size, n_dim))  # ~ N(0, I)
     d = jnp.copy(z)
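
Usage sketch (illustrative; not part of the patch series). With d_sigma = 2 and
c_sigma = 2 * popsize / num_dims, the patched update_sigma appears to follow the
step-size rule sigma <- sigma * exp((c_sigma / d_sigma) * (||p_sigma||^2 / n - 1))
from the LM-MA-ES reference cited in the docstring (arXiv:1705.06693). The code
below is a minimal sketch of driving the patched strategy through evosax's
ask/tell loop; it assumes the class is exported as LM_MA_ES and that the usual
default_params / initialize / ask / tell interface applies, and the popsize,
num_dims, and generation count are hypothetical toy settings.

    import jax
    import jax.numpy as jnp
    from evosax import LM_MA_ES

    rng = jax.random.PRNGKey(0)
    # Toy settings: 32 candidates per generation on a 10-dimensional problem.
    strategy = LM_MA_ES(popsize=32, num_dims=10, elite_ratio=0.5)
    es_params = strategy.default_params
    state = strategy.initialize(rng, es_params)

    for _ in range(50):
        rng, rng_ask = jax.random.split(rng)
        # ask() draws x = mean + sigma * d and stores the raw z samples in the
        # state so that tell() can reuse them for the p_sigma and M updates.
        x, state = strategy.ask(rng_ask, state, es_params)
        fitness = jnp.sum(x**2, axis=1)  # sphere objective, minimized
        state = strategy.tell(x, fitness, state, es_params)

    print(state.best_fitness)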