diff --git a/sko/GA.py b/sko/GA.py index cded290..4662ff6 100644 --- a/sko/GA.py +++ b/sko/GA.py @@ -14,13 +14,14 @@ class GeneticAlgorithmBase(SkoBase, metaclass=ABCMeta): def __init__(self, func, n_dim, size_pop=50, max_iter=200, prob_mut=0.001, - constraint_eq=tuple(), constraint_ueq=tuple()): + constraint_eq=tuple(), constraint_ueq=tuple(), early_stop=None): self.func = func_transformer(func) assert size_pop % 2 == 0, 'size_pop must be even integer' self.size_pop = size_pop # size of population self.max_iter = max_iter self.prob_mut = prob_mut # probability of mutation self.n_dim = n_dim + self.early_stop = early_stop # constraint: self.has_constraint = len(constraint_eq) > 0 or len(constraint_ueq) > 0 @@ -75,6 +76,7 @@ def mutation(self): def run(self, max_iter=None): self.max_iter = max_iter or self.max_iter + best = [] for i in range(self.max_iter): self.X = self.chrom2x(self.Chrom) self.Y = self.x2y() @@ -90,6 +92,14 @@ def run(self, max_iter=None): self.all_history_Y.append(self.Y) self.all_history_FitV.append(self.FitV) + if self.early_stop: + best.append(min(self.generation_best_Y)) + if len(best) >= self.early_stop: + if best.count(min(best)) == len(best): + break + else: + best.pop(0) + global_best_index = np.array(self.generation_best_Y).argmin() self.best_x = self.generation_best_X[global_best_index] self.best_y = self.func(np.array([self.best_x])) @@ -141,8 +151,8 @@ def __init__(self, func, n_dim, prob_mut=0.001, lb=-1, ub=1, constraint_eq=tuple(), constraint_ueq=tuple(), - precision=1e-7): - super().__init__(func, n_dim, size_pop, max_iter, prob_mut, constraint_eq, constraint_ueq) + precision=1e-7, early_stop=None): + super().__init__(func, n_dim, size_pop, max_iter, prob_mut, constraint_eq, constraint_ueq, early_stop) self.lb, self.ub = np.array(lb) * np.ones(self.n_dim), np.array(ub) * np.ones(self.n_dim) self.precision = np.array(precision) * np.ones(self.n_dim) # works when precision is int, float, list or array @@ -244,7 +254,7 @@ def chrom2x(self, Chrom): register('chrom2x', chrom2x) return self - + class RCGA(GeneticAlgorithmBase): """real-coding genetic algorithm