Skip to content

Commit

Permalink
[EMBED] improve speed
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Apr 1, 2016
1 parent e75660f commit 66bd728
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
2 changes: 1 addition & 1 deletion mshadow
9 changes: 8 additions & 1 deletion src/operator/embedding-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,14 @@ class EmbeddingOp : public Operator {
Tensor<xpu, 1> data = in_data[embedding::kData].get<xpu, 1, real_t>(s);
Tensor<xpu, 2> grad_out = out_grad[embedding::kOut].get<xpu, 2, real_t>(s);
Tensor<xpu, 2> grad_in = in_grad[embedding::kWeight].get<xpu, 2, real_t>(s);
Assign(grad_in, req[embedding::kWeight], take_grad(data, grad_out, param_.input_dim));
if (req[embedding::kWeight] == kWriteTo) {
grad_in = 0.0f;
AddTakeGrad(grad_in, data, grad_out);
} else if (req[embedding::kWeight] == kAddTo) {
AddTakeGrad(grad_in, data, grad_out);
} else {
LOG(FATAL) << "wrong req";
}
}

private:
Expand Down

0 comments on commit 66bd728

Please sign in to comment.