Integrate TPP code into IPEX (#1357)
* Add libxsmm as a third-party module
* Add TPP csrc files (initially hitting a build/link issue)
* Enable the build
* Add a Python binding for libxsmm initialization
* Add Python bindings for the tpp/optimizer and tpp/pad/bert modules
* SQuAD inference gets the expected result
* SQuAD training passes
* Move all TPP-related code into the torch_ipex::tpp namespace
* Register the BERT fusion ops into the torch.ops.ipex namespace instead of using Python bindings
* Enable the unpad fp32 path
* Enable bf16 unpad
* Fix an unpad import issue
* Enable the ipex.tpp_optimize API
* Fix the transformers check issue
* Update frontend.py
* Unify the unpad/padded code based on the original unpad code
* Remove the unused padded code
* Remove extend_profiler.py
* Add UTs for TPP
* Add UTs for backward
* Update Module.cpp
* Rename tpp_optimize to fast_bert, add a tpp prefix to the optimizer-related Python APIs, and fix optimizer replacement for optimizers other than SGD/AdamW
* Create README.md
* Fix clang-format issues
* Fix UT failures when transformers is not installed
* Update test_tpp_ops.py
* Fix remaining UT failures
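Since the commit renames the user-facing entry point from tpp_optimize to fast_bert, a minimal usage sketch may help orient readers. This is an illustration only: the exact fast_bert signature (the dtype keyword, optimizer handling, return convention) is assumed rather than taken from this diff, and the transformers BERT model is just an example workload.

```python
# Hypothetical usage sketch of the fast_bert API referenced in the commit message.
# The exact signature (dtype keyword, optimizer handling) is an assumption.
import torch
import intel_extension_for_pytorch as ipex
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
model.eval()

# Ask IPEX to swap in the TPP-based fused BERT kernels (bf16 shown as an example).
model = ipex.fast_bert(model, dtype=torch.bfloat16)

with torch.no_grad():
    input_ids = torch.randint(0, model.config.vocab_size, (1, 128))
    outputs = model(input_ids)
```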
Showing 41 changed files with 13,359 additions and 5 deletions.
@@ -0,0 +1,6 @@
FILE(GLOB _TPP_SRCS *.cpp bert/*.cpp)
LIST(APPEND IPEX_CPU_CPP_TPP_SRCS ${_TPP_SRCS})
# LIST(APPEND IPEX_CPU_CPP_ATEN_SRCS ${_CPU_KERNELS_SRCS})
message(STATUS "IPEX_CPU_CPP_TPP_SRCS: ${IPEX_CPU_CPP_TPP_SRCS}")
# Pass to parent
set(IPEX_CPU_CPP_TPP_SRCS ${IPEX_CPU_CPP_TPP_SRCS} PARENT_SCOPE)
@@ -0,0 +1,31 @@
# Overview
This directory mainly contains the TPP (Tensor Processing Primitives) related tools and the fused kernel implementations based on the TPP optimizations. A rough eager-mode PyTorch reference for one of these fused kernels is sketched after the listing below.

```
├── bert                                               # fused kernels based on TPP
│   ├── fused_bert.cpp
│   ├── fused_dense_dropout_layernorm_bwd_tmpl.h       # backward for fused linear+dropout+layernorm
│   ├── fused_dense_dropout_layernorm_fwd_tmpl.h       # forward for fused linear+dropout+layernorm
│   ├── fused_dense_gelu_bwd_tmpl.h                    # backward for fused linear+gelu
│   ├── fused_dense_gelu_fwd_tmpl.h                    # forward for fused linear+gelu
│   ├── fused_embedding_layernorm_dropout_bwd_tmpl.h   # backward for fused embedding+add+layernorm+dropout
│   ├── fused_embedding_layernorm_dropout_fwd_tmpl.h   # forward for fused embedding+add+layernorm+dropout
│   ├── fused_self_attention_bwd_tmpl.h                # fused backward self-attention
│   └── fused_self_attention_fwd_tmpl.h                # fused forward self-attention
├── CMakeLists.txt
├── common_loops.cpp                                   # loop generation and tuning
├── ext_tpp.h
├── init.cpp
├── jit_compile.cpp
├── jit_compile.h
├── optim.cpp
├── optim.h
├── par_loop_generator.cpp                             # loop generation and tuning
├── par_loop_generator.h                               # loop generation and tuning
├── rtm.h
├── tensor_helper.h
├── threaded_loops.h
├── timing.h
├── utils.h
└── xsmm_functors.h                                    # the TPP definitions based on libxsmm
```
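To make the kernel names above concrete, here is a rough eager-mode PyTorch reference for what the fused linear+dropout+layernorm forward computes. This is a sketch for orientation only: the dropout probability, eps, tensor shapes, and the assumption that the BERT layer's residual add is folded into the same kernel are illustrative, not taken from this README.

```python
import torch
import torch.nn.functional as F

# Rough eager-mode reference for the "fused linear+dropout+layernorm" pattern;
# the TPP kernel performs the equivalent work in a single blocked pass.
# Folding the residual add into the kernel is an assumption based on the BERT layer.
def dense_dropout_layernorm_ref(x, weight, bias, ln_weight, ln_bias, residual,
                                p=0.1, eps=1e-12, training=True):
    y = F.linear(x, weight, bias)              # dense
    y = F.dropout(y, p=p, training=training)   # dropout
    y = y + residual                           # residual add (BERT-style)
    return F.layer_norm(y, y.shape[-1:], ln_weight, ln_bias, eps)

x = torch.randn(8, 128, 768)                   # (batch, sequence, hidden); illustrative shapes
w = torch.randn(768, 768)
b = torch.zeros(768)
out = dense_dropout_layernorm_ref(x, w, b, torch.ones(768), torch.zeros(768), residual=x)
```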
@@ -0,0 +1,276 @@

#include <ATen/record_function.h>
//#include <torch/csrc/autograd/VariableTypeUtils.h>
//#include <torch/extension.h>

#include <dyndisp/DispatchStub.h>
#include <torch/all.h>
#include <iostream>
#include <vector>
#include "ext_tpp.h"
//#include "init.h"
#include "tensor_helper.h"
#include "threaded_loops.h"
#include "timing.h"
#include "xsmm_functors.h"

namespace torch_ipex {
namespace tpp {

static int my_rank = guess_mpi_rank();

REGISTER_LOCAL_SCOPE(b_emb, "b_emb");
REGISTER_LOCAL_SCOPE(q_gemm, "q_gemm");
REGISTER_LOCAL_SCOPE(k_gemm, "k_gemm");
REGISTER_LOCAL_SCOPE(v_gemm, "v_gemm");
REGISTER_LOCAL_SCOPE(ac_gemm, "ac_gemm");
REGISTER_LOCAL_SCOPE(o_gemm, "o_gemm");
REGISTER_LOCAL_SCOPE(i_gemm, "i_gemm");

REGISTER_LOCAL_SCOPE(db_emb, "db_emb");
REGISTER_LOCAL_SCOPE(diq_gemm, "diq_gemm");
REGISTER_LOCAL_SCOPE(dik_gemm, "dik_gemm");
REGISTER_LOCAL_SCOPE(div_gemm, "div_gemm");
REGISTER_LOCAL_SCOPE(dica_gemm, "dica_gemm");
REGISTER_LOCAL_SCOPE(dii_gemm, "dii_gemm");
REGISTER_LOCAL_SCOPE(dio_gemm, "dio_gemm");
REGISTER_LOCAL_SCOPE(dwqkv_gemm, "dwqkv_gemm");
REGISTER_LOCAL_SCOPE(dwq_gemm, "dwq_gemm");
REGISTER_LOCAL_SCOPE(dwk_gemm, "dwk_gemm");
REGISTER_LOCAL_SCOPE(dwv_gemm, "dwv_gemm");
REGISTER_LOCAL_SCOPE(dwa_gemm, "dwa_gemm");
REGISTER_LOCAL_SCOPE(dwc_gemm, "dwc_gemm");
REGISTER_LOCAL_SCOPE(dac_gemm, "dac_gemm");
REGISTER_LOCAL_SCOPE(dwi_gemm, "dwi_gemm");
REGISTER_LOCAL_SCOPE(dwo_gemm, "dwo_gemm");
REGISTER_LOCAL_SCOPE(dqkv_bias, "dqkv_bias");
REGISTER_LOCAL_SCOPE(di_bias, "di_bias");
REGISTER_LOCAL_SCOPE(do_bias, "do_bias");

template <typename T>
inline void omp_reduce_buf(
    int num_threads,
    int N,
    float** ptrs,
    T* buf,
    bool accumulate = false) {
  ScopedTimer _t(EW_RED);
#pragma omp for
  for (int i = 0; i < N; i++) {
    float sum = 0.0;
    for (int j = 0; j < num_threads; j++) {
      sum += ptrs[j][i];
    }
    if (accumulate) {
      buf[i] += sum;
    } else {
      buf[i] = sum;
    }
  }
}

static std::vector<at::Tensor> fused_self_attention_fwd_unpad(
    double p,
    std::vector<at::Tensor> inputs,
    bool training) {
  GlobalPass _gp(FWD);
  if (inputs[6].dtype() == at::kFloat) {
    typedef float T;
#include "fused_self_attention_fwd_tmpl.h"
  } else {
    typedef bfloat16 T;
#include "fused_self_attention_fwd_tmpl.h"
  }
}

static std::vector<at::Tensor> fused_self_attention_bwd_unpad(
    double p,
    std::vector<at::Tensor> inputs) {
  GlobalPass _gp(BWD);
  if (inputs[0].dtype() == at::kFloat) {
    typedef float T;
#include "fused_self_attention_bwd_tmpl.h"
  } else {
    typedef bfloat16 T;
#include "fused_self_attention_bwd_tmpl.h"
  }
}

static std::vector<at::Tensor> fused_dense_dropout_layernorm_fwd_unpad(
    double p,
    double eps,
    std::vector<at::Tensor> inputs,
    bool training) {
  GlobalPass _gp(FWD);
  if (inputs[0].dtype() == at::kFloat) {
    typedef float T;
#include "fused_dense_dropout_layernorm_fwd_tmpl.h"
  } else {
    typedef bfloat16 T;
#include "fused_dense_dropout_layernorm_fwd_tmpl.h"
  }
}

static std::vector<at::Tensor> fused_dense_dropout_layernorm_bwd_unpad(
    double p,
    std::vector<at::Tensor> inputs) {
  GlobalPass _gp(BWD);
  if (inputs[0].dtype() == at::kFloat) {
    typedef float T;
#include "fused_dense_dropout_layernorm_bwd_tmpl.h"
  } else {
    typedef bfloat16 T;
#include "fused_dense_dropout_layernorm_bwd_tmpl.h"
  }
}

static std::vector<at::Tensor> fused_dense_gelu_fwd_unpad(
    at::Tensor t_in,
    at::Tensor t_wt,
    at::Tensor t_bias,
    bool training) {
  GlobalPass _gp(FWD);
  if (t_in.dtype() == at::kFloat) {
    typedef float T;
#include "fused_dense_gelu_fwd_tmpl.h"
  } else {
    typedef bfloat16 T;
#include "fused_dense_gelu_fwd_tmpl.h"
  }
}

static std::vector<at::Tensor> fused_dense_gelu_bwd_unpad(
    at::Tensor t_grad_out,
    at::Tensor t_gelu_in,
    at::Tensor t_in,
    at::Tensor t_wt) {
  GlobalPass _gp(BWD);
  if (t_grad_out.dtype() == at::kFloat) {
    typedef float T;
#include "fused_dense_gelu_bwd_tmpl.h"
  } else {
    typedef bfloat16 T;
#include "fused_dense_gelu_bwd_tmpl.h"
  }
}

static std::vector<at::Tensor> fused_embedding_layernorm_dropout_fwd_unpad(
    double p,
    double eps,
    long H,
    long pad_id,
    std::vector<at::Tensor> inputs,
    bool training) {
  GlobalPass _gp(FWD);
  if (inputs[4].dtype() == at::kFloat && inputs[6].dtype() == at::kFloat) {
    typedef float T;
    typedef float ET;
#include "fused_embedding_layernorm_dropout_fwd_tmpl.h"
  } else if (
      inputs[4].dtype() == at::kBFloat16 && inputs[6].dtype() == at::kFloat) {
    typedef bfloat16 T;
    typedef float ET;
#include "fused_embedding_layernorm_dropout_fwd_tmpl.h"
  } else if (
      inputs[4].dtype() == at::kFloat && inputs[6].dtype() == at::kBFloat16) {
    typedef float T;
    typedef bfloat16 ET;
#include "fused_embedding_layernorm_dropout_fwd_tmpl.h"
  } else if (
      inputs[4].dtype() == at::kBFloat16 &&
      inputs[6].dtype() == at::kBFloat16) {
    typedef bfloat16 T;
    typedef bfloat16 ET;
#include "fused_embedding_layernorm_dropout_fwd_tmpl.h"
  } else {
    PCL_ASSERT(0, "Should not come here\n");
  }
}

static std::vector<at::Tensor> fused_embedding_layernorm_dropout_bwd_unpad(
    double p,
    long pad_id,
    std::vector<at::Tensor> inputs) {
  GlobalPass _gp(BWD);
  if (inputs[0].dtype() == at::kFloat && inputs[6].dtype() == at::kFloat) {
    typedef float T;
    typedef float ET;
#include "fused_embedding_layernorm_dropout_bwd_tmpl.h"
  } else if (
      inputs[0].dtype() == at::kBFloat16 && inputs[6].dtype() == at::kFloat) {
    typedef bfloat16 T;
    typedef float ET;
#include "fused_embedding_layernorm_dropout_bwd_tmpl.h"
  } else if (
      inputs[0].dtype() == at::kFloat && inputs[6].dtype() == at::kBFloat16) {
    typedef float T;
    typedef bfloat16 ET;
#include "fused_embedding_layernorm_dropout_bwd_tmpl.h"
  } else if (
      inputs[0].dtype() == at::kBFloat16 &&
      inputs[6].dtype() == at::kBFloat16) {
    typedef bfloat16 T;
    typedef bfloat16 ET;
#include "fused_embedding_layernorm_dropout_bwd_tmpl.h"
  } else {
    PCL_ASSERT(0, "Should not come here\n");
  }
}

} // namespace tpp
} // namespace torch_ipex

namespace {
TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
  m.def(
      torch::schema(
          "torch_ipex::fused_self_attention_fwd_unpad(float p, Tensor[] inputs, bool training) -> Tensor[]",
          c10::AliasAnalysisKind::PURE_FUNCTION),
      torch_ipex::tpp::fused_self_attention_fwd_unpad);

  m.def(
      torch::schema(
          "torch_ipex::fused_self_attention_bwd_unpad(float p, Tensor[] inputs) -> Tensor[]",
          c10::AliasAnalysisKind::PURE_FUNCTION),
      torch_ipex::tpp::fused_self_attention_bwd_unpad);

  m.def(
      torch::schema(
          "torch_ipex::fused_dense_dropout_layernorm_fwd_unpad(float p, float eps, Tensor[] inputs, bool training) -> Tensor[]",
          c10::AliasAnalysisKind::PURE_FUNCTION),
      torch_ipex::tpp::fused_dense_dropout_layernorm_fwd_unpad);

  m.def(
      torch::schema(
          "torch_ipex::fused_dense_dropout_layernorm_bwd_unpad(float p, Tensor[] inputs) -> Tensor[]",
          c10::AliasAnalysisKind::PURE_FUNCTION),
      torch_ipex::tpp::fused_dense_dropout_layernorm_bwd_unpad);

  m.def(
      torch::schema(
          "torch_ipex::fused_dense_gelu_fwd_unpad(Tensor t_in, Tensor t_wt, Tensor "
          "t_bias, bool training) -> Tensor[]",
          c10::AliasAnalysisKind::PURE_FUNCTION),
      torch_ipex::tpp::fused_dense_gelu_fwd_unpad);

  m.def(
      torch::schema(
          "torch_ipex::fused_dense_gelu_bwd_unpad(Tensor t_grad_out, Tensor t_gelu_in, "
          "Tensor t_in, Tensor t_wt) -> Tensor[]",
          c10::AliasAnalysisKind::PURE_FUNCTION),
      torch_ipex::tpp::fused_dense_gelu_bwd_unpad);

  m.def(
      torch::schema(
          "torch_ipex::fused_embedding_layernorm_dropout_fwd_unpad(float p, float "
          "eps, int H, int pad_id, Tensor(a!)[] inputs, bool training) -> "
          "Tensor[]",
          c10::AliasAnalysisKind::PURE_FUNCTION),
      torch_ipex::tpp::fused_embedding_layernorm_dropout_fwd_unpad);

  m.def(
      torch::schema(
          "torch_ipex::fused_embedding_layernorm_dropout_bwd_unpad(float p, int "
          "pad_id, Tensor(a!)[] inputs) -> Tensor[]",
          c10::AliasAnalysisKind::PURE_FUNCTION),
      torch_ipex::tpp::fused_embedding_layernorm_dropout_bwd_unpad);
}
} // namespace
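The TORCH_LIBRARY_FRAGMENT block above is what replaces the earlier pybind-style bindings: each schema becomes a custom op reachable from Python under torch.ops.torch_ipex. A minimal sketch of that lookup, assuming intel_extension_for_pytorch is installed; the op is only resolved here, not invoked, because the expected tensor count and blocked layouts are defined by the kernel templates rather than by this file.

```python
import torch
import intel_extension_for_pytorch  # noqa: F401 - loading the extension registers the ops

# Once the extension library is loaded, the schemas registered above resolve here.
op = torch.ops.torch_ipex.fused_dense_gelu_fwd_unpad
print(op)

# Invoking it requires tensors in the blocked layout prepared by the Python-side
# BERT modules (see the bert/ directory), so no call with arbitrary tensors is shown.
```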