Merge remote-tracking branch 'remotes/origin/branch-1.0.0'
paynie committed Jul 24, 2017
2 parents d789576 + a671498 commit 0cc2354
Showing 21 changed files with 791 additions and 358 deletions.
5 changes: 4 additions & 1 deletion README.md
@@ -34,8 +34,11 @@

## Design

-* [psFunc](./docs/design/psfFunc.md)
+* [Model Partitioning (modelPartitioner)](./docs/design/model_partitioner.md)
+* [Asynchronous Control (syncController)](./docs/design/sync_controller.md)
+* [Customized Functions (psFunc)](./docs/design/psfFunc.md)
+* [Description of the Core Classes](./docs/apis/interface_api.md)


## Algorithm

@@ -655,7 +655,7 @@ public AngelConfiguration(Configuration conf) {
*/
public static final String ANGEL_MATRIXTRANSFER_CHECK_INTERVAL_MS = ANGEL_PREFIX
+ "matrixtransfer.check.interval.ms";
-  public static final int DEFAULT_ANGEL_MATRIXTRANSFER_CHECK_INTERVAL_MS = 1000;
+  public static final int DEFAULT_ANGEL_MATRIXTRANSFER_CHECK_INTERVAL_MS = 100;

// //////////////////////////////
// Matrix transfer Configs.
@@ -158,6 +158,7 @@ public PartitionGetParam getPartParam() {

@Override
public String toString() {
return "GetUDFRequest [getFuncClass=" + getFuncClass + ", partParam=" + partParam + "]";
return "GetUDFRequest{" + "getFuncClass='" + getFuncClass + '\'' + ", partParam=" + partParam
+ "} " + super.toString();
}
}
@@ -139,8 +139,8 @@ public boolean equals(Object obj) {
return true;
}

-  @Override
-  public String toString() {
-    return "PartitionRequest [clock=" + clock + ", partKey=" + partKey + "]";
+  @Override public String toString() {
+    return "PartitionRequest{" + "clock=" + clock + ", partKey=" + partKey + "} " + super
+        .toString();
}
}
@@ -103,4 +103,8 @@ public void setServerId(ParameterServerId serverId) {
}

public abstract TransportMethod getType();

+  @Override public String toString() {
+    return "Request{" + "serverId=" + serverId + '}';
+  }
}
@@ -129,6 +129,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) {

@SuppressWarnings("unchecked")
private void getSplit(int seqId, GetUDFRequest request, ChannelHandlerContext ctx) {
+    long startTs = System.currentTimeMillis();
GetUDFResponse response = null;
try {
Class<? extends GetFunc> funcClass = (Class<? extends GetFunc>) Class.forName(request.getGetFuncClass());
@@ -145,10 +145,14 @@ private void getSplit(int seqId, GetUDFRequest request, ChannelHandlerContext ctx) {
response.setResponseType(ResponseType.FATAL);
}

+    long endTs = System.currentTimeMillis();
+    LOG.debug("get psf seqId=" + seqId + " process time=" + (endTs - startTs));

ByteBuf buf = ByteBufUtils.newByteBuf(4 + response.bufferLen(), useDirectorBuffer);
buf.writeInt(seqId);
response.serialize(buf);
ctx.writeAndFlush(buf);
LOG.debug("get psf seqId=" + seqId + " serialize and send time=" + (System.currentTimeMillis() - endTs));
}

@SuppressWarnings("unchecked")
@@ -851,7 +851,7 @@ private void printDispatchInfo() {
}

for (Entry<Integer, Request> entry : seqIdToRequestMap.entrySet()) {
LOG.debug("infight request id=" + entry.getKey() + ", request context=" + entry.getValue()
LOG.debug("infight request seqId=" + entry.getKey() + ", request context=" + entry.getValue()
+ ", request channel=" + entry.getValue().getContext().getChannel());
}
}
@@ -914,12 +914,12 @@ private void removeTimeOutRequestItem() {
for (Entry<Integer, Request> entry : seqIdToRequestMap.entrySet()) {
Request item = entry.getValue();
item.getContext().addWaitTimeTicks(checkPeriodMS * 10);
LOG.debug("request " + entry.getKey() + " wait time="
LOG.debug("request seqId=" + entry.getKey() + " wait time="
+ item.getContext().getWaitTimeTicks());
if (item.getContext().getWaitTimeTicks() > requestTimeOut) {
item = seqIdToRequestMap.get(entry.getKey());
if (item != null) {
LOG.info("remove timeout request " + item);
LOG.info("remove timeout request seqId=" + entry.getKey());
removeNum++;
requestFailed(entry.getKey(), item);
}
@@ -1253,7 +1253,7 @@ public RequesterChannelFutureListener(int seqId, Request request) {

@Override
public void operationComplete(ChannelFuture future) throws Exception {
LOG.debug("send request " + request + " with seqId " + seqId + "complete");
LOG.debug("send request " + request + " with seqId=" + seqId + " complete");
if (!future.isSuccess()) {
LOG.error("send " + seqId + " failed ", future.cause());
future.cause().printStackTrace();
@@ -1281,7 +1281,7 @@ public void run() {

TransportMethod method = request.getType();

LOG.debug("response handler, seqid = " + seqId + ", method = " + method + ", ts = "
LOG.debug("response handler, seqId=" + seqId + ", method=" + method + ", ts="
+ System.currentTimeMillis());

switch (method) {
@@ -379,6 +379,7 @@ public GetResult get(GetFunc func) throws InterruptedException, ExecutionException {
List<PartitionGetParam> partParams = param.split();
int size = partParams.size();

LOG.debug("get psf request " + func + " start, rpc request number=" + size);
List<Future<PartitionGetResult>> futureResultList =
new ArrayList<Future<PartitionGetResult>>(size);
List<PartitionGetResult> resultList = new ArrayList<PartitionGetResult>(size);
4 changes: 2 additions & 2 deletions docs/algo/kmeans_on_angel.md
@@ -8,11 +8,11 @@

![kmeans](../img/kmeans.png)

-where ![xi](../img/xi.png) denotes the i-th sample, ![ci](../img/ci.png) denotes the cluster closest to the i-th sample, and ![miu_i](../img/miu_i.png) denotes the centroid of the j-th cluster.
+where ![xi](../img/xi.png) denotes the i-th sample, ![ci](../img/ci.png) denotes the cluster closest to the i-th sample, and ![miu_j](../img/miu_j.png) denotes the centroid of the j-th cluster.


## Mini-batch KMeans
"Web-Scale K-Means Clustering"提出一种在朴素KMeans算法基础上改进的KMeans算法,用mini-batch方法训练,每次迭代选择一个mini-batch的样本更新簇心,如下所示:
"Web-Scale K-Means Clustering"提出一种在朴素KMeans算法基础上改进的KMeans算法,用mini-batch方法训练,每次迭代选择一个mini-batch的样本集来更新簇心,如下所示:

![mini_batch_kmeans](../img/mini_batch_kmeans.png)

52 changes: 52 additions & 0 deletions docs/algo/kmeans_on_angel_en.md
@@ -0,0 +1,52 @@
# KMeans

> KMeans is a method that aims to cluster data into K groups of equal variance. The conventional KMeans algorithm has a performance bottleneck; when implemented with PS, however, KMeans achieves the same level of accuracy with better performance.
## 1. Introduction

The KMeans algorithm assigns each data point to its *nearest* cluster, where the *distance* is measured between the data point and the cluster's *centroid*. In general, the KMeans algorithm is implemented in an iterative way as shown below:

![kmeans](../img/kmeans.png)

where ![xi](../img/xi.png) is the i-th sample and ![ci](../img/ci.png) is its nearest cluster; ![miu_j](../img/miu_j.png) is the centroid of the j-th cluster.


## Mini-batch KMeans
"Web-Scale K-Means Clustering" proposes a improved KMeans algorithm to address the latency, scalability and sparsity requirements in user-facing web applications, using mini-batch optimization for training. As shown below:

![mini_batch_kmeans](../img/mini_batch_kmeans.png)
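
The update in the figure, following the "Web-Scale K-Means Clustering" paper, moves each centroid toward the samples assigned to it with a per-centroid learning rate that decays as the centroid absorbs more samples; restated here (our notation, not the original image's):

```math
\eta_c = \frac{1}{n_c}, \qquad \mu_c \leftarrow (1 - \eta_c)\,\mu_c + \eta_c\,x
```

where ``$ n_c $`` is the running count of samples assigned to centroid ``$ c $`` and ``$ x $`` is the current sample.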


## 2. Distributed Implementation on Angel

### Model Storage
KMeans on Angel stores the K centroids on the ParameterServer as a K×N matrix, where K is the number of clusters and N is the data dimension, i.e. the number of features.

### Model Updating
KMeans on Angel is trained in an iterative fashion; during each iteration, the centroids are updated with a mini-batch of samples.

### Algorithm
The KMeans on Angel algorithm is shown below:

![KMeans_on_Angel](../img/KMeans_on_Angel.png)
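
A minimal, single-process sketch of the mini-batch update loop follows. It uses plain Java arrays in place of Angel's PS-backed K×N centroid matrix, and every name in it (`MiniBatchKMeansSketch`, `centroids`, `counts`) is illustrative rather than part of the Angel API:

```java
import java.util.List;

/**
 * Illustrative single-process sketch of one mini-batch KMeans step.
 * A plain double[K][N] array stands in for the centroid matrix that
 * KMeans on Angel keeps on the parameter servers.
 */
public class MiniBatchKMeansSketch {

  /** Updates centroids in place with one mini-batch of samples. */
  public static void miniBatchStep(double[][] centroids, int[] counts, List<double[]> miniBatch) {
    for (double[] x : miniBatch) {
      int nearest = nearestCentroid(centroids, x);
      counts[nearest]++;
      double eta = 1.0 / counts[nearest];        // decaying per-centroid learning rate
      double[] c = centroids[nearest];
      for (int d = 0; d < c.length; d++) {
        c[d] = (1.0 - eta) * c[d] + eta * x[d];  // move the centroid toward the sample
      }
    }
  }

  /** Returns the index of the centroid with the smallest squared distance to x. */
  private static int nearestCentroid(double[][] centroids, double[] x) {
    int best = 0;
    double bestDist = Double.MAX_VALUE;
    for (int k = 0; k < centroids.length; k++) {
      double dist = 0.0;
      for (int d = 0; d < x.length; d++) {
        double diff = centroids[k][d] - x[d];
        dist += diff * diff;
      }
      if (dist < bestDist) {
        bestDist = dist;
        best = k;
      }
    }
    return best;
  }
}
```

In the distributed setting, each worker would compute such updates on its data shard and push centroid deltas to the parameter servers instead of mutating a local array.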


## 3. Execution & Performance

### Input Format

* The data format is set by "ml.data.type", which supports "libsvm" and "dummy" formats. For details, see [Angel Data Format](data_format_en.md).

### Parameters
* IO Parameters
  * angel.train.data.path: input path
  * ml.feature.num: number of features
  * ml.data.type: [Angel Data Format](data_format_en.md), can be "dummy" or "libsvm"
  * angel.save.modelPath: save path for the trained model
  * angel.log.path: save path for the log
* Algorithm Parameters
  * ml.kmeans.center.num: K, the number of clusters
  * ml.kmeans.sample.ratio.perbath: sample ratio per mini-batch
  * ml.kmeans.c: learning rate
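
Purely as an illustration, the parameters above could be collected into a job properties snippet like the following; every value and path here is hypothetical, not a default from Angel:

```
# hypothetical values; substitute your own paths and sizes
angel.train.data.path=hdfs://mycluster/kmeans/train
ml.feature.num=10000
ml.data.type=libsvm
angel.save.modelPath=hdfs://mycluster/kmeans/model
angel.log.path=hdfs://mycluster/kmeans/log
ml.kmeans.center.num=100
ml.kmeans.sample.ratio.perbath=0.5
ml.kmeans.c=0.1
```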

### Performance
73 changes: 73 additions & 0 deletions docs/algo/lda_on_angel_en.md
@@ -0,0 +1,73 @@
# LDA (Latent Dirichlet Allocation)

---

> LDA is a widely used topic-modeling technique: a Bayesian generative model for discovering hidden topical patterns that helps with dimensionality reduction and text analysis.
## 1. Introduction

### Overview

A text corpus ``$ C $`` contains a set of documents ``$ \{D_1, \cdots, D_{M}\} $``, and each document ``$ D_i $`` contains a set of words, ``$ D_i = (t_1, t_2, \cdots, t_{N_i}) $``. A word is a basic unit of the vocabulary, which is denoted by ``$ V $``. The number of topics in LDA, ``$ K $``, needs to be specified. In LDA, each document is modeled as a random mixture over ``$ K $`` latent topics, ``$ \theta_d $``, whereas each topic is modeled as a ``$ V $``-dimensional distribution over words, ``$ \phi_k $``.

LDA models the generative process for each document in the corpus. It draws a ``$ K $`` dimensional topic distribution, ``$ \theta_d $``, from a Dirichlet distribution, ``$ Dir(\alpha) $``, where ``$ \alpha $`` is the parameter vector of the Dirichlet (hyperparameter of the LDA). To generate each word ``$ t_{dn} $`` in document ``$ d $``, LDA first draws the topic of the word, ``$ z_{dn} $``, from a multinomial distribution ``$ Mult(\theta_d) $``, and then draws the word ``$ w_{dn} \in V $`` from a multinomial distribution ``$ Mult(\phi_{z_{dn}}) $``.

### Gibbs Sampling
A common inference technique for LDA is Gibbs sampling, an MCMC method that samples from the posterior distribution of ``$ z_{dn} $`` to infer the distribution over topics for each document and the distribution over words for each topic. Commonly used Gibbs sampling variants include Collapsed Gibbs Sampling (CGS), SparseLDA, AliasLDA, F+LDA, LightLDA and WarpLDA, to name a few; our experiment results suggest that F+LDA is the most suitable for training LDA on Angel.

### Collapsed Gibbs Sampling (CGS)
We use ``$ Z=\{z_d\}_{d=1}^D $`` to represent the set of topics for all words, ``$ \Phi = [\phi_1 \cdots \phi_{V}] $`` to represent the ``$ V \times K $`` topic-word matrix, and ``$ \Theta = [\theta_1 \cdots \theta_D] $`` to represent the matrix whose columns are the topic distributions for all documents. Training LDA then requires inferring the posterior of the latent variables ``$ (\Theta, \Phi, Z) $`` given the observed corpus ``$ C $`` and the hyperparameters. Using conjugate priors, CGS gives a closed-form expression for the posterior of ``$ Z $``, resulting in simple iterations for sampling ``$ z_{dn} $`` following the conditional probability below:

```math
p(z_{dn} = k \mid t_{dn} = w, Z_{\neg dn}, C_{\neg dn}) \propto \frac{C_{wk}^{\neg dn} + \beta}{C_{k}^{\neg dn} + V\beta}\,(C_{dk}^{\neg dn} + \alpha)
```
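
For intuition, one collapsed Gibbs sampling step for a single token can be sketched as below. This is a plain-array illustration of the conditional above, not Angel's actual sampler (F+LDA organizes the same computation around an F+ tree); all names are ours:

```java
import java.util.Random;

/**
 * Illustrative collapsed Gibbs sampling step for the token of word w in document d.
 * cdk, cwk and ck are the document-topic, word-topic and topic count tables.
 */
public class CgsStepSketch {

  public static int resample(int d, int w, int oldTopic, int[][] cdk, int[][] cwk, int[] ck,
      double alpha, double beta, int vocabSize, Random rand) {
    // Remove the token's current assignment from the counts (the "neg dn" terms).
    cdk[d][oldTopic]--;
    cwk[w][oldTopic]--;
    ck[oldTopic]--;

    // p(z = k | ...) is proportional to (cdk + alpha) * (cwk + beta) / (ck + V * beta).
    int numTopics = ck.length;
    double[] p = new double[numTopics];
    double sum = 0.0;
    for (int k = 0; k < numTopics; k++) {
      p[k] = (cdk[d][k] + alpha) * (cwk[w][k] + beta) / (ck[k] + vocabSize * beta);
      sum += p[k];
    }

    // Draw the new topic from the unnormalized distribution.
    double u = rand.nextDouble() * sum;
    int newTopic = numTopics - 1;
    for (int k = 0; k < numTopics; k++) {
      u -= p[k];
      if (u <= 0.0) {
        newTopic = k;
        break;
      }
    }

    // Add the token back under its new topic.
    cdk[d][newTopic]++;
    cwk[w][newTopic]++;
    ck[newTopic]++;
    return newTopic;
  }
}
```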

### F+LDA
F+LDA factorizes the probability into two parts, ``$ C_{dk} \frac{C_{wk} + \beta}{C_k + V\beta} $`` and ``$ \alpha \frac{C_{wk} + \beta}{C_k + V\beta} $``. Because ``$ C_d $`` is sparse, sampling is only done over its non-zero elements; for the remaining part, F+LDA uses an F+ tree for searching, reducing the complexity to O(log K). Overall, F+LDA's complexity is ``$ O(K_d) $``, where ``$ K_d $`` is the number of non-zero elements in the document-topic matrix.

## 2. Distributed Implementation on Angel

The overall framework for training LDA on Angel is shown in the figure below. There are two comparatively large matrices in LDA, ``$ C_w $`` and ``$ C_d $``: we slice ``$ C_d $`` across the workers and ``$ C_w $`` across the servers. In each iteration, workers pull ``$ C_w $`` from the servers to draw topics, and send the updates on ``$ C_w $`` back to the servers.

![Architecture for LDA on Angel](../img/lda_ps.png)

## 3. Execution & Performance

### Input Format

* Each line is a document, and each document consists of a set of word ids; word ids are separated by `,`.

```math
wid_0, wid_1, ..., wid_n
```
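
For example, a document whose tokens map to word ids 0, 5, 5 and 12 would be stored as the single line `0,5,5,12`.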

### Parameters

* Data Parameters
  * angel.train.data.path: input path
  * angel.save.model.path: save path for the trained model
* Algorithm Parameters
  * ml.epoch.num: number of iterations
  * ml.lda.word.num: number of words
  * ml.lda.topic.num: number of topics
  * ml.worker.thread.num: number of threads within each worker
  * ml.lda.alpha: alpha
  * ml.lda.beta: beta


### Performance

* **Data**
  * PubMED

* **Resource**
  * worker: 20
  * ps: 20

* **Angel vs Spark**: training time for 100 iterations
  * Angel: 15 min
  * Spark: > 300 min
