[ISSUE-271] support categorical feature
bluesjjw committed Dec 19, 2017
1 parent 6f44025 commit 06f8eef
Showing 13 changed files with 387 additions and 154 deletions.
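
The whole change hinges on one configuration value, read through MLConf.ML_GBDT_CATE_FEAT in GBDTController.init() further down: "none" means no categorical features, "all" treats every feature as categorical, and otherwise a comma-separated list of fid:numCategories pairs names the categorical features. Below is a minimal standalone sketch of that parsing, not part of the commit; the class name is illustrative, and it assumes (matching the assert in init()) that each declared category count stays below the configured number of splits.

import java.util.ArrayList;
import java.util.List;

// Illustrative parser mirroring the logic added to GBDTController.init().
public class CateFeatConf {
  public static List<Integer> parse(String cateFeatStr, int numFeature, int numSplit) {
    List<Integer> cateFeatList = new ArrayList<>();
    switch (cateFeatStr) {
      case "all":  // every feature is treated as categorical
        for (int fid = 0; fid < numFeature; fid++) cateFeatList.add(fid);
        break;
      case "none": // purely numerical training data
        break;
      default:     // e.g. "0:2,3:5": feature 0 has 2 categories, feature 3 has 5
        for (String item : cateFeatStr.split(",")) {
          String[] fidAndNum = item.split(":");
          int fid = Integer.parseInt(fidAndNum[0]);
          int num = Integer.parseInt(fidAndNum[1]);
          if (num >= numSplit) {
            throw new IllegalArgumentException("category count must be below the split number");
          }
          if (!cateFeatList.contains(fid)) cateFeatList.add(fid);
        }
    }
    return cateFeatList;
  }
}

Note that the per-feature category count is not taken from this string in the committed code; it is re-derived from the merged sketch in getSketch().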
@@ -68,7 +68,7 @@ public AngelConf(Configuration conf) {
/** Training data path. */
public static final String ANGEL_TRAIN_DATA_PATH = "angel.train.data.path";

-/** Training data path. */
+/** Predict data path. */
public static final String ANGEL_PREDICT_DATA_PATH = "angel.predict.data.path";

/** Input data path use by Angel */
@@ -70,8 +70,9 @@ public void setConf() {
conf.setBoolean("mapred.mapper.new-api", true);
conf.setBoolean(AngelConf.ANGEL_JOB_OUTPUT_PATH_DELETEONEXIST, true);

-// Use local deploy mode and dummy data spliter
+// Use local deploy mode and dummy data format
conf.set(AngelConf.ANGEL_DEPLOY_MODE, "LOCAL");
conf.set(MLConf.ML_DATA_FORMAT(), String.valueOf(dataFmt));

// set input, output path
conf.set(AngelConf.ANGEL_INPUTFORMAT_CLASS, CombineTextInputFormat.class.getName());
@@ -56,6 +56,7 @@ class GBDTLearner(override val ctx: TaskContext) extends MLLearner(ctx) {
param.gradHistNamePrefix = GBDTModel.GRAD_HIST_MAT_PREFIX
param.activeTreeNodesName = GBDTModel.ACTIVE_NODE_MAT
param.sampledFeaturesName = GBDTModel.FEAT_SAMPLE_MAT
param.cateFeatureName = GBDTModel.FEAT_CATEGORY_MAT
param.splitFeaturesName = GBDTModel.SPLIT_FEAT_MAT
param.splitValuesName = GBDTModel.SPLIT_VALUE_MAT
param.splitGainsName = GBDTModel.SPLIT_GAIN_MAT
@@ -175,6 +176,7 @@ class GBDTLearner(override val ctx: TaskContext) extends MLLearner(ctx) {
if (!nextClock && controller.phase == GBDTPhase.CREATE_SKETCH) {
LOG.info(s"******Current phase: CREATE_SKETCH, clock[${controller.clock}]******")
controller.createSketch
controller.mergeCateFeatSketch
controller.setPhase(GBDTPhase.GET_SKETCH)
nextClock = true
}
@@ -32,10 +32,12 @@ import org.apache.commons.logging.LogFactory
import org.apache.hadoop.conf.Configuration

object GBDTModel {

val SKETCH_MAT: String = "gbdt.sketch"
val GRAD_HIST_MAT_PREFIX: String = "gbdt.grad.histogram.node"
val ACTIVE_NODE_MAT: String = "gbdt.active.nodes"
-val FEAT_SAMPLE_MAT: String = "gbdt.feature.sample."
+val FEAT_SAMPLE_MAT: String = "gbdt.feature.sample"
val FEAT_CATEGORY_MAT = "gbdt.feature.category"
val SPLIT_FEAT_MAT: String = "gbdt.split.feature"
val SPLIT_VALUE_MAT: String = "gbdt.split.value"
val SPLIT_GAIN_MAT: String = "gbdt.split.gain"
@@ -61,11 +63,14 @@ class GBDTModel(conf: Configuration, _ctx: TaskContext = null) extends MLModel(c
val maxTreeDepth = conf.getInt(MLConf.ML_GBDT_TREE_DEPTH, MLConf.DEFAULT_ML_GBDT_TREE_DEPTH)
val splitNum = conf.getInt(MLConf.ML_GBDT_SPLIT_NUM, MLConf.DEFAULT_ML_GBDT_SPLIT_NUM)
val featSampleRatio = conf.getFloat(MLConf.ML_GBDT_SAMPLE_RATIO, MLConf.DEFAULT_ML_GBDT_SAMPLE_RATIO)
val cateFeatStr = conf.get(MLConf.ML_GBDT_CATE_FEAT, MLConf.DEFAULT_ML_GBDT_CATE_FEAT)
val cateFeatNum = if (cateFeatStr.contains(",")) cateFeatStr.split(",").size else 1

val maxTNodeNum: Int = Maths.pow(2, maxTreeDepth) - 1

// # parameter server
val psNumber = conf.getInt(AngelConf.ANGEL_PS_NUMBER, 1)
val workerNumber = conf.getInt(AngelConf.ANGEL_WORKERGROUP_ACTUAL_NUM, 1)

// adjust feature number to ensure the parameter partition
if (featNum % psNumber != 0) {
@@ -140,6 +145,15 @@ class GBDTModel(conf: Configuration, _ctx: TaskContext = null) extends MLModel(c
.setOplogType("DENSE_DOUBLE")
addPSModel(NODE_PRED_MAT, nodePred)

// Matrix 10: categorical feature
val featCategory = PSModel(FEAT_CATEGORY_MAT, workerNumber, cateFeatNum * splitNum, 1, cateFeatNum * splitNum)
.setRowType(RowType.T_DOUBLE_DENSE)
.setOplogType("DENSE_DOUBLE")
.setNeedSave(false)
addPSModel(FEAT_CATEGORY_MAT, featCategory)



super.setSavePath(conf)
super.setLoadPath(conf)

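
For reference, the new FEAT_CATEGORY_MAT ("gbdt.feature.category") matrix above has one row per worker and cateFeatNum * splitNum columns: each worker writes the row indexed by its task id, and within a row the i-th categorical feature, taken in the sorted order of the categorical feature list, owns splitNum consecutive columns for its candidate category values. A hypothetical index helper, not part of the commit, spelling out that layout:

// Hypothetical helper describing the FEAT_CATEGORY_MAT layout used by GBDTController:
//   row    = worker (task) index that produced the local candidates
//   column = i * splitNum + j, where i is the feature's position in the sorted
//            categorical feature list and j indexes its candidate category values
public final class CateFeatMatLayout {
  public static int row(int workerIndex) {
    return workerIndex;
  }
  public static int column(int sortedCateFeatIndex, int candidateIndex, int splitNum) {
    return sortedCateFeatIndex * splitNum + candidateIndex;
  }
}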
@@ -17,10 +17,13 @@
package com.tencent.angel.ml.GBDT.algo;


import com.tencent.angel.conf.AngelConf;
import com.tencent.angel.ml.GBDT.GBDTModel;
import com.tencent.angel.ml.GBDT.algo.RegTree.*;
import com.tencent.angel.ml.GBDT.algo.tree.SplitEntry;
import com.tencent.angel.ml.GBDT.algo.tree.TYahooSketchSplit;
import com.tencent.angel.ml.GBDT.psf.GBDTGradHistGetRowFunc;
import com.tencent.angel.ml.GBDT.psf.GBDTGradHistGetRowResult;
-import com.tencent.angel.ml.GBDT.algo.RegTree.*;
import com.tencent.angel.ml.GBDT.psf.HistAggrParam;
import com.tencent.angel.ml.conf.MLConf;
import com.tencent.angel.ml.math.vector.*;
@@ -32,8 +35,6 @@
import com.tencent.angel.ml.objective.ObjFunc;
import com.tencent.angel.ml.objective.RegLossObj;
import com.tencent.angel.ml.param.GBDTParam;
-import com.tencent.angel.ml.GBDT.algo.tree.SplitEntry;
-import com.tencent.angel.ml.GBDT.algo.tree.TYahooSketchSplit;
import com.tencent.angel.ml.utils.Maths;
import com.tencent.angel.worker.task.TaskContext;
import org.apache.commons.logging.Log;
@@ -66,8 +67,10 @@ public class GBDTController {
// gradient and hessian
public List<GradPair> gradPairs = new ArrayList<>();

-public float sketches[]; // size: featureNum * splitNum
-public int[] fset; // sampled features in the current tree
+public float[] sketches; // size: featureNum * splitNum
+public List<Integer> cateFeatList; // categorical feature set, null: none, empty: all, else: partial
+public Map<Integer, Integer> cateFeatNum; // number of splits of categorical features
+public int[] fSet; // sampled features in the current tree
public int[] activeNode; // active tree node, 1:active, 0:inactive
public int[] activeNodeStat; // >=1:running, 0:finished, -1:failed
public int[] instancePos; // map tree node to instance, each item is instance id
@@ -101,6 +104,31 @@ public void init() throws Exception {
LossHelper loss = new Loss.BinaryLogisticLoss();
objfunc = new RegLossObj(loss);
this.sketches = new float[this.param.numFeature * this.param.numSplit];

String cateFeatStr = this.taskContext.getConf().get(MLConf.ML_GBDT_CATE_FEAT());
cateFeatList = new ArrayList<>();
cateFeatNum = new HashMap<>();
switch (cateFeatStr) {
case "all":
for (int fid = 0; fid < this.param.numFeature; fid++) {
cateFeatList.add(fid);
}
break;
case "none":
break;
default:
String[] splits = cateFeatStr.split(",");
for (int i = 0; i < splits.length; i++) {
String[] fidAndNum = splits[i].split(":");
int fid = Integer.parseInt(fidAndNum[0]);
int num = Integer.parseInt(fidAndNum[1]);
assert num < this.param.numSplit;
if (!cateFeatList.contains(fid)) {
cateFeatList.add(fid);
}
}
}

this.maxNodeNum = Maths.pow(2, this.param.maxDepth) - 1;
this.activeNode = new int[maxNodeNum];
this.activeNodeStat = new int[maxNodeNum];
@@ -152,33 +180,115 @@ private void calGradPairs() {
// create data sketch, push candidate split value to PS
public void createSketch() throws Exception {
PSModel sketch = model.getPSModel(this.param.sketchName);
PSModel cateFeat = model.getPSModel(this.param.cateFeatureName);
if (taskContext.getTaskIndex() == 0) {
LOG.info("------Create sketch------");
long startTime = System.currentTimeMillis();
DenseDoubleVector sketchVec =
new DenseDoubleVector(this.param.numFeature * this.param.numSplit);
DenseDoubleVector cateFeatVec = null;
if (!this.cateFeatList.isEmpty()) {
cateFeatVec = new DenseDoubleVector(this.cateFeatList.size() * this.param.numSplit);
}

// 1. calculate candidate split value
-float[][] splits = TYahooSketchSplit.getSplitValue(this.trainDataStore, this.param.numSplit);
+float[][] splits = TYahooSketchSplit.getSplitValue(this.trainDataStore, this.param.numSplit,
+    this.cateFeatList);

LOG.info("Splits of feature 242883: " + Arrays.toString(splits[242883]));

if (splits.length == this.param.numFeature && splits[0].length == this.param.numSplit) {
for (int fid = 0; fid < splits.length; fid++) {
LOG.debug(String.format("Candidate splits of fid[%d]: %s",
fid, Arrays.toString(splits[fid])));
if (cateFeatList.contains(fid)) {
continue;
}
for (int j = 0; j < splits[fid].length; j++) {
sketchVec.set(fid * this.param.numSplit + j, splits[fid][j]);
}
}
} else {
LOG.error("Incompatible sketches size.");
}

// categorical features
if (!this.cateFeatList.isEmpty()) {
Collections.sort(this.cateFeatList);
for (int i = 0; i < this.cateFeatList.size(); i++) {
int fid = this.cateFeatList.get(i);
int start = i * this.param.numSplit;
for (int j = 0; j < splits[fid].length; j++) {
if (splits[fid][j] == 0 && j > 0)
break;
cateFeatVec.set(start + j, splits[fid][j]);
}
}
}

// 2. push local sketch to PS
sketch.increment(0, sketchVec);
if (null != cateFeatVec) {
cateFeat.increment(this.taskContext.getTaskIndex(), cateFeatVec);
}
// 3. set phase to GET_SKETCH
this.phase = GBDTPhase.GET_SKETCH;
LOG.info(String.format("Create sketch cost: %d ms", System.currentTimeMillis() - startTime));
}

Set<String> needFlushMatrixSet = new HashSet<String>(1);
needFlushMatrixSet.add(this.param.sketchName);
needFlushMatrixSet.add(this.param.cateFeatureName);
clockAllMatrix(needFlushMatrixSet, true);
}
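
// Note (not part of the commit): the categorical slots written above follow an implicit
// packing convention: candidate category values are stored from the start of a feature's
// splitNum-wide slot, and a zero in any position after the first marks the end of the list.
// A hedged reader for that convention, shown on a plain array for brevity; the commit
// stores the values in a DenseDoubleVector.
static java.util.List<Double> readCategories(double[] row, int slotStart, int splitNum) {
  java.util.List<Double> categories = new java.util.ArrayList<>();
  for (int j = 0; j < splitNum; j++) {
    double v = row[slotStart + j];
    if (j > 0 && v == 0.0) {
      break; // a zero after the first position ends the list
    }
    categories.add(v);
  }
  return categories;
}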

public void mergeCateFeatSketch() throws Exception {

LOG.info("------Merge categorical features------");

Set<String> needFlushMatrixSet = new HashSet<String>(1);

// the leader worker
if (!this.cateFeatList.isEmpty() && this.taskContext.getTaskIndex() == 0) {

PSModel cateFeat = model.getPSModel(this.param.cateFeatureName);
PSModel sketch = model.getPSModel(this.param.sketchName);

Set<Double>[] featSet = new HashSet[cateFeatList.size()];
for (int i = 0; i < cateFeatList.size(); i++) {
featSet[i] = new HashSet<>();
}

int workerNum = this.taskContext.getConf().getInt(AngelConf.ANGEL_WORKERGROUP_ACTUAL_NUM, 1);

// merge categorical features
for (int worker = 0; worker < workerNum; worker++) {
DenseDoubleVector vec = (DenseDoubleVector) cateFeat.getRow(worker);
for (int i = 0; i < cateFeatList.size(); i++) {
int fid = cateFeatList.get(i);
int start = i * this.param.numSplit;
for (int j = 0; j < this.param.numSplit; j++) {
double fvalue = vec.get(start + j);
featSet[i].add(fvalue);
}
}
}

// create updates
SparseDoubleVector cateFeatVec = new SparseDoubleVector(this.param.numFeature * this.param.numSplit);
for (int i = 0; i < cateFeatList.size(); i++) {
int fid = cateFeatList.get(i);
int start = fid * this.param.numSplit;
List<Double> sortedValue = new ArrayList<>(featSet[i]);
Collections.sort(sortedValue);
assert sortedValue.size() < this.param.numSplit;
for (int j = 0; j < sortedValue.size(); j++) {
cateFeatVec.set(start + j, sortedValue.get(j));
}
}

sketch.increment(0, cateFeatVec);
needFlushMatrixSet.add(this.param.sketchName);
}

clockAllMatrix(needFlushMatrixSet, true);
}
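
// In short (illustration, not part of the commit): once every task has pushed its local
// categorical candidates and the matrices have been clocked, task 0 pulls each worker's
// row, takes the per-feature union of the observed category values, sorts it, and writes
// it into that feature's slot of the global gbdt.sketch matrix, so every task later reads
// one consistent candidate list. A minimal per-feature merge mirroring the loop above:
static java.util.List<Double> mergeCategories(java.util.List<double[]> perWorkerSlots) {
  java.util.Set<Double> union = new java.util.TreeSet<>(); // TreeSet keeps the values sorted
  for (double[] slot : perWorkerSlots) {
    for (double v : slot) {
      union.add(v);
    }
  }
  // e.g. worker 0 reporting {0, 1, 2} and worker 1 reporting {1, 3} for the same
  // feature merge to the candidate list [0, 1, 2, 3]
  return new java.util.ArrayList<>(union);
}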

@@ -189,9 +299,26 @@ public void getSketch() throws Exception {
long startTime = System.currentTimeMillis();
DenseDoubleVector sketchVector = (DenseDoubleVector) sketch.getRow(0);
LOG.info(String.format("Get sketch cost: %d ms", System.currentTimeMillis() - startTime));

for (int i = 0; i < sketchVector.getDimension(); i++) {
this.sketches[i] = (float) sketchVector.get(i);
}

// number of categorical feature
for (int i = 0; i < cateFeatList.size(); i++) {
int fid = cateFeatList.get(i);
int start = fid * this.param.numSplit;
int splitNum = 1;
for (int j = 0; j < this.param.numSplit; j++) {
if (this.sketches[start + j + 1] > this.sketches[start + j] ) {
splitNum++;
} else
break;
}
this.cateFeatNum.put(fid, splitNum);
}

LOG.info("Number of splits of categorical features: " + this.cateFeatNum.entrySet().toString());
this.phase = GBDTPhase.NEW_TREE;
}
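
// Note (not part of the commit): the loop above recovers each categorical feature's
// category count from the merged sketch by walking its slot while the values keep
// strictly increasing. A standalone, bounds-checked version of that counting rule,
// assuming the slot is zero-padded after the last (largest) category value:
static int countCategories(float[] sketches, int slotStart, int splitNum) {
  int count = 1; // the first candidate always counts
  for (int j = 0; j + 1 < splitNum; j++) {
    if (sketches[slotStart + j + 1] > sketches[slotStart + j]) {
      count++;
    } else {
      break; // values stopped increasing: end of the category list
    }
  }
  return count;
}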

@@ -233,14 +360,14 @@ public void createNewTree() throws Exception {
PSModel featSample = model.getPSModel(this.param.sampledFeaturesName);
DenseIntVector sampleFeatureVector =
(DenseIntVector) featSample.getRow(this.currentTree);
-this.fset = sampleFeatureVector.getValues();
+this.fSet = sampleFeatureVector.getValues();
this.forest[this.currentTree].fset = sampleFeatureVector.getValues();
} else {
// 2.2. if use all the features, only called one
-if (null == this.fset) {
-this.fset = new int[this.trainDataStore.featureMeta.numFeature];
-for (int fid = 0; fid < this.fset.length; fid++) {
-this.fset[fid] = fid;
+if (null == this.fSet) {
+this.fSet = new int[this.trainDataStore.featureMeta.numFeature];
+for (int fid = 0; fid < this.fSet.length; fid++) {
+this.fSet[fid] = fid;
}
}
}
@@ -394,9 +521,8 @@ public void findSplit() throws Exception {
// 2.3. find best split result of this tree node
if (this.param.isServerSplit) {
// 2.3.1 using server split

if (splitEntry.getFid() != -1) {
-int trueSplitFid = this.fset[splitEntry.getFid()];
+int trueSplitFid = this.fSet[splitEntry.getFid()];
int splitIdx = (int) splitEntry.getFvalue();
float trueSplitValue = this.sketches[trueSplitFid * this.param.numSplit + splitIdx];
LOG.info(String.format("Best split of node[%d]: feature[%d], value[%f], "
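
With server-side split finding, the SplitEntry returned from the parameter server carries an index into the sampled feature set and a split index rather than a raw threshold, and the code above maps it back through fSet and the local sketch. A hedged illustration of that mapping (made-up helper, not part of the commit):

// Maps a server-side split result back to a real feature id and split value,
// as done in findSplit() above; for a categorical feature the returned candidate
// is one of the merged category values rather than a numeric threshold.
static float trueSplitValue(float[] sketches, int[] fSet, int numSplit,
                            int sampledFid, int splitIdx) {
  int trueSplitFid = fSet[sampledFid];                  // sampled index -> real feature id
  return sketches[trueSplitFid * numSplit + splitIdx];  // candidate value from the sketch
}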