Skip to content

Commit

Permalink
Refactor RST_NSGAIII code, add serialization, prepare experiment
Browse files Browse the repository at this point in the history
Add parameters exploration and exploitation num generations
Add parameters numElicitations1,2 and elicitationFrequency

Next step is to investigate why MinXZ does not work as good as Central
Chebyshev Ranker. Should also change rho value from 0 to small positive,
should also replace ideal point located in origin with adjustable 
one. Also should refactor Gradient Lambda Search - remove unused 
improvement directions
  • Loading branch information
tomekster committed Apr 11, 2017
1 parent 6beda51 commit e1cc8fc
Show file tree
Hide file tree
Showing 23 changed files with 746 additions and 490 deletions.
50 changes: 50 additions & 0 deletions src/core/Evaluator.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package core;

import core.points.Solution;
import history.ExecutionHistory;
import solutionRankers.ChebyshevRanker;
import utils.Geometry;
import utils.MyMath;

public class Evaluator {
public static void evaluateRun(Problem prob, ChebyshevRanker dmr, Population res) {
String pname = prob.getName();
double targetPoint[] = {};

switch(pname){
case "DTLZ1":
targetPoint = Geometry.lineCrossDTLZ1HyperplanePoint(Geometry.invert(dmr.getLambda()));
break;
case "DTLZ2":
case "DTLZ3":
case "DTLZ4":
targetPoint = Geometry.lineCrossDTLZ234HyperspherePoint(Geometry.invert(dmr.getLambda()));
break;
}
System.out.println("TARGET POINT: ");
for(double d : targetPoint){
System.out.print(d + " ");
}
System.out.println();

System.out.println("PREF: ");
for(int i=0; i< prob.getNumObjectives(); i++){
double min = Double.MAX_VALUE, sum = 0, max = -Double.MAX_VALUE;
for(Solution s : res.getSolutions()){
double o = s.getObjective(i);
min = Double.min(min, o);
max = Double.max(max, o);
sum += o;
}

System.out.println(i + ": " + min + ", " + sum/res.getSolutions().size() + ", ");
}

if(targetPoint.length > 0){
ExecutionHistory.getInstance().setFinalMinDist(MyMath.getMinDist(targetPoint, res));
ExecutionHistory.getInstance().setFinalAvgDist(MyMath.getAvgDist(targetPoint, res));
}

ExecutionHistory.getInstance().setLambdasConverged(Lambda.getInstance().converged());
}
}
100 changes: 37 additions & 63 deletions src/core/Lambda.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,25 @@ public class Lambda {

private final static Logger LOGGER = Logger.getLogger(Lambda.class.getName());

private static Lambda instance = null;

private int numObjectives;
private int numLambdas;
private ArrayList <ReferencePoint> lambdas;
private GradientLambdaSearch GLS;
protected Lambda(int numObjectives, int numLambdas) {

protected Lambda(){
// Exists only to defeat instantiation.
}

public static Lambda getInstance(){
if (instance == null){
instance = new Lambda();
}
return instance;
}

public void init(int numObjectives, int numLambdas) {
this.numObjectives = numObjectives;
this.numLambdas = numLambdas;
lambdas = new ArrayList<>();
Expand All @@ -35,7 +49,7 @@ protected Lambda(int numObjectives, int numLambdas) {
}
GLS = new GradientLambdaSearch(numObjectives);
}

private ReferencePoint getRandomLambda() {
ArrayList <Double> breakPoints = new ArrayList<>();
ArrayList <Double> dimensions = new ArrayList<>();
Expand Down Expand Up @@ -107,67 +121,6 @@ protected ArrayList <ReferencePoint> selectNewLambdas(ArrayList <ReferencePoint>
return new ArrayList<ReferencePoint>(lambdasPop.subList(0, numLambdas));
}

public Population selectKSolutionsByChebyshevBordaRanking(Population pop, int k) {
HashMap<Solution, Integer> bordaPointsMap = getBordaPointsForSolutions(pop);

ArrayList<Pair<Solution, Integer>> pairs = new ArrayList<Pair<Solution, Integer>>();

for (Solution s : bordaPointsMap.keySet()) {
pairs.add(new Pair<Solution, Integer>(s, bordaPointsMap.get(s)));
}

Collections.sort(pairs, new Comparator<Pair<Solution, Integer>>() {
@Override
public int compare(final Pair<Solution, Integer> o1, final Pair<Solution, Integer> o2) {
return Integer.compare(o2.second, o1.second); // Sort DESC by Borda points
}
});

Population res = new Population();
for (int i = 0; i < k; i++) {
res.addSolution(pairs.get(i).first.copy());
}
return res;
}

private HashMap<Solution, Integer> getBordaPointsForSolutions(Population pop) {
HashMap<Solution, Integer> bordaPointsMap = new HashMap<>();
for (ReferencePoint lambda : lambdas) {
ArrayList<Solution> ranking = buildSolutionsRanking(lambda, pop);
assert ranking.size() == pop.size();
for (int i = 0; i < ranking.size(); i++) {
Solution s = ranking.get(i);
if (!bordaPointsMap.containsKey(s)) {
bordaPointsMap.put(s, 0);
}
bordaPointsMap.put(s, bordaPointsMap.get(s) + (ranking.size() - i)/(lambda.getNumViolations() + 1));
}
}
return bordaPointsMap;
}

public static ArrayList<Solution> buildSolutionsRanking(ReferencePoint lambda, Population pop) {
ArrayList<Pair<Solution, Double>> solutionValuePairs = new ArrayList<Pair<Solution, Double>>();
for (Solution s : pop.getSolutions()) {
double chebyshevValue = ChebyshevRanker.eval(s, null, Geometry.invert(lambda.getDim()), 0);
solutionValuePairs.add(new Pair<Solution, Double>(s, chebyshevValue));
}
Collections.sort(solutionValuePairs, new Comparator<Pair<Solution, Double>>() {
@Override
public int compare(final Pair<Solution, Double> o1, final Pair<Solution, Double> o2) {
// Sort pairs by Chebyshev Function value ascending (Decreasing quality)
return Double.compare(o1.second, o2.second);
}
});

ArrayList<Solution> ranking = new ArrayList<Solution>();
for (Pair<Solution, Double> p : solutionValuePairs) {
ranking.add(p.first);
}
assert ranking.size() == pop.size();
return ranking;
}

public ArrayList<ReferencePoint> getLambdas() {
return this.lambdas;
}
Expand Down Expand Up @@ -200,4 +153,25 @@ public void nextGeneration() {
public void setLambdas(ArrayList<ReferencePoint> lambdas){
this.lambdas = lambdas;
}

public boolean converged(){
double min[] = new double[numObjectives];
double max[] = new double[numObjectives];
for(int i=0; i<numObjectives; i++){
min[i] = Double.MAX_VALUE;
max[i] = -Double.MAX_VALUE;
}

for(ReferencePoint rp : lambdas){
for(int i=0; i<numObjectives; i++){
if(rp.getDim(i) < min[i]){ min[i] = rp.getDim(i); }
if(rp.getDim(i) > max[i]){ max[i] = rp.getDim(i); }
}
}

for(int i=0; i<numObjectives; i++){
if( max[i] - min[i] > 0.001) return false;
}
return true;
}
}
75 changes: 60 additions & 15 deletions src/core/NSGAIIIParameters.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
public class NSGAIIIParameters {
private String problemName;
private int numberObjectives;
private int numberGenerations;
private int numberExplorationGenerations;
private int numberExploitationGenerations;
private int numElicitations1;
private int numElicitations2;
private int numberRuns;
private int elicitationInterval;

Expand All @@ -19,13 +22,31 @@ public class NSGAIIIParameters {

private static NSGAIIIParameters instance = null;

// if(problem.getNumObjectives() == 3){
// numElicitations1 = 50;
// numElicitations2 = 30;
// }
// if(problem.getNumObjectives() == 5){
// numElicitations1 = 70;
// numElicitations2 = 30;
// }
// if(problem.getNumObjectives() == 8){
// numElicitations1 = 100;
// numElicitations2 = 30;
// }

protected NSGAIIIParameters(){
// Exists only to defeat instantiation.
problemName = "DTLZ1";
numberObjectives = 3;
numberGenerations = 50;
numberExplorationGenerations = 100;
numberExploitationGenerations = 100;
numElicitations1 = 50;
numElicitations2 = 30;
setNumElicitations1(50);
setNumElicitations2(30);
numberRuns = 1;
elicitationInterval = 20;
elicitationInterval = 1;
showTargetPoints = true;
showLambdas = true;
showComparisons = true;
Expand All @@ -46,8 +67,12 @@ public int getNumberObjectives() {
return numberObjectives;
}

public int getNumberGenerations() {
return numberGenerations;
public int getNumberExplorationGenerations() {
return numberExplorationGenerations;
}

public int getNumberExploitationGenerations() {
return numberExploitationGenerations;
}

public int getNumberRuns() {
Expand All @@ -66,10 +91,6 @@ public boolean isShowComparisons() {
return showComparisons;
}

public int getElicitationInterval() {
return elicitationInterval;
}

public void setProblemName(String problemName) {
this.problemName = problemName;
}
Expand All @@ -78,18 +99,18 @@ public void setNumberObjectives(int numberObjectives) {
this.numberObjectives = numberObjectives;
}

public void setNumberGenerations(int numberGenerations) {
this.numberGenerations = numberGenerations;
public void setNumberExplorationGenerations(int numberGenerations) {
this.numberExplorationGenerations = numberGenerations;
}

public void setNumberExploitationGenerations(int numberGenerations) {
this.numberExploitationGenerations = numberGenerations;
}

public void setNumberRuns(int numberRuns) {
this.numberRuns = numberRuns;
}

public void setElicitationFrequency(int elicitationInterval) {
this.elicitationInterval = elicitationInterval;
}

public void setShowTargetPoints(boolean showTargetPoints) {
this.showTargetPoints = showTargetPoints;
}
Expand All @@ -102,4 +123,28 @@ public void setShowComparisons(boolean showComparisons) {
this.showComparisons = showComparisons;
}

public int getElicitationInterval() {
return elicitationInterval;
}

public void setElicitationInterval(int elicitationInterval) {
this.elicitationInterval = elicitationInterval;
}

public int getNumElicitations1() {
return numElicitations1;
}

public void setNumElicitations1(int numElicitations1) {
this.numElicitations1 = numElicitations1;
}

public int getNumElicitations2() {
return numElicitations2;
}

public void setNumElicitations2(int numElicitations2) {
this.numElicitations2 = numElicitations2;
}

}
17 changes: 15 additions & 2 deletions src/core/NSGAIIIRunnner.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;

import core.points.Solution;
import history.ExecutionHistory;
import solutionRankers.ChebyshevRanker;
import solutionRankers.ChebyshevRankerBuilder;
import utils.Geometry;
import utils.MyMath;

public class NSGAIIIRunnner {

Expand All @@ -26,15 +31,23 @@ public static void runNSGAIII() {
}

RST_NSGAIII alg = null;
Problem problem = null;
ChebyshevRanker cr = null;
try {
alg = new RST_NSGAIII((Problem) problemConstructor.newInstance(params.getNumberObjectives()), params.getNumberGenerations(), params.getElicitationInterval(), ChebyshevRankerBuilder.getCentralChebyshevRanker(params.getNumberObjectives()));
problem = (Problem) problemConstructor.newInstance(params.getNumberObjectives());
cr = ChebyshevRankerBuilder.getCentralChebyshevRanker(params.getNumberObjectives());
alg = new RST_NSGAIII(problem, params.getNumberExplorationGenerations(), params.getNumberExploitationGenerations(), params.getNumElicitations1(), params.getNumElicitations2(), params.getElicitationInterval(), cr);
} catch (InstantiationException | IllegalAccessException | IllegalArgumentException
| InvocationTargetException e1) {
// TODO Auto-generated catch block
e1.printStackTrace();
}
alg.run();
// alg.run();

Evaluator.evaluateRun(problem, cr, alg.getPopulation());
ExecutionHistory history = ExecutionHistory.getInstance();
System.out.println("Generation min: " + history.getFinalMinDist());
System.out.println("Generation avg: " + history.getFinalAvgDist());
}

}
Expand Down
12 changes: 11 additions & 1 deletion src/core/Population.java
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
package core;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

import core.points.Solution;

public class Population{
public class Population implements Serializable{
/**
*
*/
private static final long serialVersionUID = -5522918817753869897L;
private ArrayList <Solution> solutions = null;

public Population(){
Expand All @@ -20,6 +26,10 @@ public Population(Population pop) {
}
}

public Population(List<Solution> solList) {
this.solutions = new ArrayList<>(solList);
}

public void addSolution(Solution sol){
this.solutions.add(sol);
}
Expand Down
8 changes: 7 additions & 1 deletion src/core/Problem.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
package core;

import java.io.Serializable;

import core.points.Solution;
import utils.NSGAIIIRandom;

public abstract class Problem {
public abstract class Problem implements Serializable {
/**
*
*/
private static final long serialVersionUID = -5151466907576488480L;
private int numVariables = 0;
private int numObjectives = 0;
private int numConstraints = 0;
Expand Down
Loading

0 comments on commit e1cc8fc

Please sign in to comment.