Skip to content

Commit

Permalink
Merge pull request guofei9987#164 from samueljsluo/master
Browse files Browse the repository at this point in the history
Add early stop feature for GA
  • Loading branch information
guofei9987 authored Jan 10, 2022
2 parents ceeca2a + b0676d5 commit 2e942b1
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions sko/GA.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@
class GeneticAlgorithmBase(SkoBase, metaclass=ABCMeta):
def __init__(self, func, n_dim,
size_pop=50, max_iter=200, prob_mut=0.001,
constraint_eq=tuple(), constraint_ueq=tuple()):
constraint_eq=tuple(), constraint_ueq=tuple(), early_stop=None):
self.func = func_transformer(func)
assert size_pop % 2 == 0, 'size_pop must be even integer'
self.size_pop = size_pop # size of population
self.max_iter = max_iter
self.prob_mut = prob_mut # probability of mutation
self.n_dim = n_dim
self.early_stop = early_stop

# constraint:
self.has_constraint = len(constraint_eq) > 0 or len(constraint_ueq) > 0
Expand Down Expand Up @@ -75,6 +76,7 @@ def mutation(self):

def run(self, max_iter=None):
self.max_iter = max_iter or self.max_iter
best = []
for i in range(self.max_iter):
self.X = self.chrom2x(self.Chrom)
self.Y = self.x2y()
Expand All @@ -90,6 +92,14 @@ def run(self, max_iter=None):
self.all_history_Y.append(self.Y)
self.all_history_FitV.append(self.FitV)

if self.early_stop:
best.append(min(self.generation_best_Y))
if len(best) >= self.early_stop:
if best.count(min(best)) == len(best):
break
else:
best.pop(0)

global_best_index = np.array(self.generation_best_Y).argmin()
self.best_x = self.generation_best_X[global_best_index]
self.best_y = self.func(np.array([self.best_x]))
Expand Down Expand Up @@ -141,8 +151,8 @@ def __init__(self, func, n_dim,
prob_mut=0.001,
lb=-1, ub=1,
constraint_eq=tuple(), constraint_ueq=tuple(),
precision=1e-7):
super().__init__(func, n_dim, size_pop, max_iter, prob_mut, constraint_eq, constraint_ueq)
precision=1e-7, early_stop=None):
super().__init__(func, n_dim, size_pop, max_iter, prob_mut, constraint_eq, constraint_ueq, early_stop)

self.lb, self.ub = np.array(lb) * np.ones(self.n_dim), np.array(ub) * np.ones(self.n_dim)
self.precision = np.array(precision) * np.ones(self.n_dim) # works when precision is int, float, list or array
Expand Down Expand Up @@ -244,7 +254,7 @@ def chrom2x(self, Chrom):
register('chrom2x', chrom2x)

return self

class RCGA(GeneticAlgorithmBase):
"""real-coding genetic algorithm
Expand Down

0 comments on commit 2e942b1

Please sign in to comment.