Optimize preparation of selfattn operators (apache#20682)
bgawrych authored Nov 4, 2021
1 parent 30734fb commit 9266a91
Showing 1 changed file with 52 additions and 42 deletions.
94 changes: 52 additions & 42 deletions src/operator/subgraph/dnnl/dnnl_transformer.cc
@@ -116,7 +116,8 @@ class SgDNNLSelfAttQKOp {
   void Forward(const OpContext& ctx,
                const std::vector<NDArray>& inputs,
                const std::vector<OpReqType>& req,
-               const std::vector<NDArray>& outputs);
+               const std::vector<NDArray>& outputs,
+               bool already_prepared);
 
   void Backward(const OpContext& ctx,
                 const std::vector<NDArray>& inputs,
@@ -163,10 +164,12 @@ static void SgDNNLSelfAttQKForward(const OpStatePtr& state_pointer,
                                    const std::vector<OpReqType>& req,
                                    const std::vector<NDArray>& outputs) {
   SgDNNLSelfAttQKOp& op = state_pointer.get_state<SgDNNLSelfAttQKOp>();
+  bool already_prepared = false;
   if (!op.IsInitialized()) {
     op.Initialize(ctx, inputs, req, outputs);
+    already_prepared = true;
   }
-  op.Forward(ctx, inputs, req, outputs);
+  op.Forward(ctx, inputs, req, outputs, already_prepared);
 }
 
 static bool SgDNNLSelfAttStorageType(const nnvm::NodeAttrs& attrs,
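
The dispatch change above is the core of the optimization: Initialize() already binds the current buffers to the operator's cached state, so the first Forward() call issued right after it can skip re-binding them. Below is a minimal, self-contained sketch of this "prepare once, skip on the first run" pattern; the names CachedOp, Prepare and Run are illustrative stand-ins, not MXNet or oneDNN API.

#include <vector>

// Sketch only: simplified stand-in for a stateful operator with cached buffers.
class CachedOp {
 public:
  bool IsInitialized() const { return initialized_; }

  // Stand-in for Initialize(): binds the buffers once and marks the op ready.
  void Prepare(const std::vector<float*>& bufs) {
    bound_ = bufs;
    initialized_ = true;
  }

  // Stand-in for Forward(): re-binding is skipped when the caller says
  // Prepare() already did it in this very invocation.
  void Run(const std::vector<float*>& bufs, bool already_prepared) {
    if (!already_prepared) {
      bound_ = bufs;  // the per-call work the optimization avoids on the first run
    }
    // ... execute the cached primitive on bound_ ...
  }

 private:
  bool initialized_ = false;
  std::vector<float*> bound_;
};

void Dispatch(CachedOp* op, const std::vector<float*>& bufs) {
  bool already_prepared = false;
  if (!op->IsInitialized()) {
    op->Prepare(bufs);
    already_prepared = true;  // Prepare() has just bound the buffers
  }
  op->Run(bufs, already_prepared);
}
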
@@ -264,21 +267,23 @@ void SgDNNLSelfAttQKOp::Initialize(const OpContext& ctx,
 void SgDNNLSelfAttQKOp::Forward(const OpContext& ctx,
                                 const std::vector<NDArray>& inputs,
                                 const std::vector<OpReqType>& req,
-                                const std::vector<NDArray>& outputs) {
-  const size_t output_lin_dim = inputs[0].shape()[2];
-  const size_t embed_dim = output_lin_dim / QKV_NUM;
-
-  MSHADOW_TYPE_SWITCH(inputs[0].dtype(), DType, {
-    DType* query_mem_ptr = inputs[0].data().dptr<DType>();
-    DType* key_mem_ptr = query_mem_ptr + embed_dim;
-    cached_query_mem_->set_data_handle(query_mem_ptr);
-    cached_key_mem_->set_data_handle(key_mem_ptr);
-  });
-
-  MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
-    cached_out_mem_->set_data_handle(outputs[0].data().dptr<DType>());
-  });
-
+                                const std::vector<NDArray>& outputs,
+                                bool already_prepared) {
+  if (!already_prepared) {
+    const size_t output_lin_dim = inputs[0].shape()[2];
+    const size_t embed_dim = output_lin_dim / QKV_NUM;
+
+    MSHADOW_TYPE_SWITCH(inputs[0].dtype(), DType, {
+      DType* query_mem_ptr = inputs[0].data().dptr<DType>();
+      DType* key_mem_ptr = query_mem_ptr + embed_dim;
+      cached_query_mem_->set_data_handle(query_mem_ptr);
+      cached_key_mem_->set_data_handle(key_mem_ptr);
+    });
+
+    MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
+      cached_out_mem_->set_data_handle(outputs[0].data().dptr<DType>());
+    });
+  }
   DNNLStream::Get()->RegisterPrimArgs(*fwd_, args_);
   DNNLStream::Get()->Submit();
 
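
For context, the work guarded by already_prepared in the hunk above is a handful of set_data_handle() calls that repoint cached oneDNN memory objects at the current iteration's buffers; the memory descriptors themselves were built once in Initialize(). A small, self-contained oneDNN sketch of that mechanism follows; the shapes and variable names are illustrative only and are not taken from the patch.

#include <vector>
#include <dnnl.hpp>

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);

  // Descriptor and memory object created once, analogous to the cached_*_mem_ members.
  dnnl::memory::desc md({2, 3}, dnnl::memory::data_type::f32,
                        dnnl::memory::format_tag::ab);
  std::vector<float> first(6, 1.f), second(6, 2.f);

  dnnl::memory cached_mem(md, eng, first.data());  // bound at "Initialize" time

  // On a later iteration the same memory object is reused; only the data
  // handle changes. This is the call the already_prepared flag now skips
  // when Initialize() has just performed it.
  cached_mem.set_data_handle(second.data());
  return 0;
}
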
@@ -484,7 +489,8 @@ class DNNLSelfAttValAttOp {
   void Forward(const OpContext& ctx,
                const std::vector<NDArray>& inputs,
                const std::vector<OpReqType>& req,
-               const std::vector<NDArray>& outputs);
+               const std::vector<NDArray>& outputs,
+               bool already_prepared);
 
   void Backward(const OpContext& ctx,
                 const std::vector<NDArray>& inputs,
@@ -538,10 +544,12 @@ static void DNNLSelfAttValAttForward(const OpStatePtr& state_pointer,
                                      const std::vector<OpReqType>& req,
                                      const std::vector<NDArray>& outputs) {
   DNNLSelfAttValAttOp& op = state_pointer.get_state<DNNLSelfAttValAttOp>();
+  bool already_prepared = false;
   if (!op.IsInitialized()) {
     op.Initialize(ctx, inputs, req, outputs);
+    already_prepared = true;
   }
-  op.Forward(ctx, inputs, req, outputs);
+  op.Forward(ctx, inputs, req, outputs, already_prepared);
 }
 
 void DNNLSelfAttValAttOp::Initialize(const OpContext& ctx,
@@ -664,29 +672,31 @@ void DNNLSelfAttValAttOp::Forward(const OpContext& ctx,
 void DNNLSelfAttValAttOp::Forward(const OpContext& ctx,
                                   const std::vector<NDArray>& inputs,
                                   const std::vector<OpReqType>& req,
-                                  const std::vector<NDArray>& outputs) {
-  // multiply by 2 as we need to skip queries and keys
-  const size_t value_offset = inputs[1].shape()[2] / QKV_NUM * 2;
-
-  auto att_buffer = inputs[0];
-  if (att_buffer.IsDNNLData())
-    att_buffer = att_buffer.Reorder2Default();
-
-  MSHADOW_TYPE_SWITCH(att_buffer.dtype(), DType, {
-    DType* attention_ptr = att_buffer.data().dptr<DType>();
-    cached_att_mem_->set_data_handle(attention_ptr);
-  });
-
-  MSHADOW_TYPE_SWITCH(inputs[1].dtype(), DType, {
-    DType* qkv_ptr = inputs[1].data().dptr<DType>();
-    DType* value_mem_ptr = qkv_ptr + value_offset;
-    cached_value_mem_->set_data_handle(value_mem_ptr);
-  });
-
-  MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
-    cached_transposed_mem_->set_data_handle(outputs[0].data().dptr<DType>());
-  });
-
+                                  const std::vector<NDArray>& outputs,
+                                  bool already_prepared) {
+  if (!already_prepared) {
+    // multiply by 2 as we need to skip queries and keys
+    const size_t value_offset = inputs[1].shape()[2] / QKV_NUM * 2;
+
+    auto att_buffer = inputs[0];
+    if (att_buffer.IsDNNLData())
+      att_buffer = att_buffer.Reorder2Default();
+
+    MSHADOW_TYPE_SWITCH(att_buffer.dtype(), DType, {
+      DType* attention_ptr = att_buffer.data().dptr<DType>();
+      cached_att_mem_->set_data_handle(attention_ptr);
+    });
+
+    MSHADOW_TYPE_SWITCH(inputs[1].dtype(), DType, {
+      DType* qkv_ptr = inputs[1].data().dptr<DType>();
+      DType* value_mem_ptr = qkv_ptr + value_offset;
+      cached_value_mem_->set_data_handle(value_mem_ptr);
+    });
+
+    MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
+      cached_transposed_mem_->set_data_handle(outputs[0].data().dptr<DType>());
+    });
+  }
   DNNLStream::Get()->RegisterPrimArgs(*fwd_, args_);
   DNNLStream::Get()->RegisterPrimArgs(*reorder_, reorder_args);
   DNNLStream::Get()->Submit();
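
A side note on the pointer offsets used in both Forward() bodies: the fused projection buffer stores the query, key and value blocks one after another along its last axis, so with QKV_NUM == 3 the key data starts embed_dim elements past the query pointer and the value data starts 2 * embed_dim elements past it, which is what the "multiply by 2" comment refers to. A tiny worked example of that arithmetic, using an assumed last dimension of 2304 purely for illustration:

#include <cassert>
#include <cstddef>

int main() {
  constexpr std::size_t QKV_NUM        = 3;
  constexpr std::size_t output_lin_dim = 2304;                          // last dim of the fused QKV buffer (assumed)
  constexpr std::size_t embed_dim      = output_lin_dim / QKV_NUM;      // 768: width of one of Q, K, V
  constexpr std::size_t value_offset   = output_lin_dim / QKV_NUM * 2;  // 1536: skip the query and key blocks

  assert(embed_dim == 768);
  assert(value_offset == 2 * embed_dim);
  return 0;
}
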
