Integrate TPP code into IPEX (#1357)

* Add libxsmm as a third-party module

* Add TPP csrc files, but hit a build link issue

* Enable build

* Libxsmm initialization python binding

* Add python bindings for the tpp/optimizer and tpp/pad/bert modules

* SQuAD inference gets the expected result

* SQuAD training passes

* Add all TPP-related code into the torch_ipex::tpp namespace

* Register the BERT fusion ops in the torch.ops.torch_ipex namespace instead of using python bindings

* Enable unpad fp32 path

* Enable bf16 unpad

* Fix unpad import issue

* Enable ipex.tpp_optimize API

* Fix transformers check issue

* Update frontend.py

* Unify the unpad/padded code paths based on the original unpad code

* Remove unused padded code

* Remove extend_profiler.py

* Add UT for tpp

* Add UT for backward

* Update Module.cpp

* 1) tpp_optimize -> fast_bert
2) add a tpp prefix for the optimizer-related python APIs
3) fix the optimizer replacement issue for non-SGD/AdamW optimizers (see the usage sketch after this list)

* Create README.md

* Create README.md

* Fix clang format issue

* Fix UT fail when transformers is not installed

* Update test_tpp_ops.py

* Fix UT fail when transformers is not installed

* Fix UT fail

* Fix UT fail
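
For context, a minimal usage sketch of the `fast_bert` frontend referenced above; the `dtype` keyword and the `transformers` model are assumptions taken from the commit bullets and the new UTs, so consult `frontend.py` for the final signature.

```python
# Sketch only: parameter names are assumptions based on the commit message
# (tpp_optimize -> fast_bert); check frontend.py for the final API.
import torch
import intel_extension_for_pytorch as ipex
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased").eval()

# Ask IPEX to swap the BERT blocks for the TPP fused (unpad) kernels added here.
model = ipex.fast_bert(model, dtype=torch.bfloat16)

with torch.no_grad():
    input_ids = torch.randint(0, 30522, (1, 128))  # BERT-base vocab size
    output = model(input_ids)
```
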
liangan1 authored Feb 16, 2023
1 parent 82fb517 commit 82b6aa0
Showing 41 changed files with 13,359 additions and 5 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
@@ -4,3 +4,6 @@
[submodule "third_party/ideep"]
path = third_party/ideep
url = https://github.com/intel/ideep.git
[submodule "third_party/libxsmm"]
path = third_party/libxsmm
url = https://github.com/libxsmm/libxsmm.git
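
When building from source at this commit, the new submodule needs to be fetched first, e.g. `git submodule update --init third_party/libxsmm` (or `git submodule update --init --recursive` to cover all submodules).
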
20 changes: 17 additions & 3 deletions csrc/cpu/CMakeLists.txt
@@ -46,6 +46,7 @@ set(IPEX_CPU_CPP_ISA_SRCS)
set(IPEX_CPU_CPP_TOOLKIT_SRCS)
set(IPEX_CPU_CPP_IDEEP_SRCS)
set(IPEX_CPU_CPP_RUNTIME_SRCS)
set(IPEX_CPU_CPP_TPP_SRCS)

set(IPEX_JIT_CPP_SRCS)
set(IPEX_UTLIS_CPP_SRCS)
@@ -62,13 +63,13 @@ add_subdirectory(${IPEX_CPU_ROOT_DIR}/isa)
add_subdirectory(${IPEX_CPU_ROOT_DIR}/toolkit)
add_subdirectory(${IPEX_CPU_ROOT_DIR}/runtime)
add_subdirectory(${IPEX_CPU_ROOT_DIR}/utils)
add_subdirectory(${IPEX_CPU_ROOT_DIR}/tpp)

add_subdirectory(${IPEX_JIT_CPP_ROOT} jit_cpu)
add_subdirectory(${IPEX_UTLIS_CPP_ROOT} csrc_utlis)

set(IPEX_CPU_CPP_SRCS ${IPEX_CPU_CPP_DYNDISP_SRCS} ${IPEX_CPU_CPP_ISA_SRCS_GEN} ${IPEX_CPU_CPP_UTILS_SRCS} ${IPEX_CPU_CPP_QUANTIZATION_SRCS} ${IPEX_JIT_CPP_SRCS}
${IPEX_CPU_CPP_ISA_SRCS} ${IPEX_CPU_CPP_IDEEP_SRCS} ${IPEX_CPU_CPP_AUTOCAST_SRCS} ${IPEX_CPU_CPP_ATEN_SRCS} ${IPEX_CPU_CPP_RUNTIME_SRCS} ${IPEX_UTLIS_CPP_SRCS}
${IPEX_CPU_CPP_TOOLKIT_SRCS})
${IPEX_CPU_CPP_ISA_SRCS} ${IPEX_CPU_CPP_IDEEP_SRCS} ${IPEX_CPU_CPP_AUTOCAST_SRCS} ${IPEX_CPU_CPP_ATEN_SRCS} ${IPEX_CPU_CPP_RUNTIME_SRCS} ${IPEX_CPU_CPP_TOOLKIT_SRCS} ${IPEX_UTLIS_CPP_SRCS} ${IPEX_CPU_CPP_TPP_SRCS})

list(REMOVE_ITEM IPEX_CPU_CPP_SRCS ${IPEX_CPU_CPP_ISA_SRCS_ORIGIN})

@@ -81,10 +82,12 @@ target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${IPEX_ROOT_DIR})
target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${IPEX_CPU_ROOT_DIR})
target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${IPEX_CPU_ROOT_DIR}/aten)
target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${IPEX_CPU_ROOT_DIR}/utils)
target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${IPEX_CPU_ROOT_DIR}/tpp)

target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${IPEX_JIT_CPP_ROOT})
target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${IPEX_UTLIS_CPP_ROOT})

target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${IPEX_CPU_CPP_THIRD_PARTY_ROOT}/libxsmm/include)
target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${IPEX_CPU_CPP_THIRD_PARTY_ROOT}/ideep/mkl-dnn/include)
target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${IPEX_CPU_CPP_THIRD_PARTY_ROOT}/ideep/mkl-dnn/third_party/oneDNN/include)
# TODO: once llga is merged into oneDNN, use oneDNN directly as the third_party instead of using that inside llga
@@ -101,6 +104,18 @@ if(CLANG_FORMAT)
add_dependencies(${PLUGIN_NAME_CPU} CL_FORMAT_CPU_NATIVE_CSRC)
endif()

include(${CMAKE_ROOT}/Modules/ExternalProject.cmake)
ExternalProject_Add(libxsmm
  SOURCE_DIR ${IPEX_CPU_CPP_THIRD_PARTY_ROOT}/libxsmm
  BUILD_IN_SOURCE 1
  CONFIGURE_COMMAND ""
  BUILD_COMMAND
    make
    "AVX=3"
    "-j"
  INSTALL_COMMAND ""
)
target_link_libraries(${PLUGIN_NAME_CPU} PRIVATE ${IPEX_CPU_CPP_THIRD_PARTY_ROOT}/libxsmm/lib/libxsmm.a)
add_dependencies(${PLUGIN_NAME_CPU} dnnl_graph)
# If Graph Compiler is built, then it should link to its LLVM dependencies,
# and not the LLVM symbols exposed by PyTorch.
@@ -114,7 +129,6 @@ if (DEFINED ENV{DNNL_GRAPH_BUILD_COMPILER_BACKEND})
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--exclude-libs=${DNNL_GRAPHCOMPILER_LLVM_LIB_EXCLUDE}")
endif()
endif()

find_package(oneMKL QUIET)
if (ONEMKL_FOUND)
target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${ONEMKL_INCLUDE_DIR})
6 changes: 6 additions & 0 deletions csrc/cpu/tpp/CMakeLists.txt
@@ -0,0 +1,6 @@
FILE(GLOB _TPP_SRCS *.cpp bert/*.cpp)
LIST(APPEND IPEX_CPU_CPP_TPP_SRCS ${_TPP_SRCS})
# LIST(APPEND IPEX_CPU_CPP_ATEN_SRCS ${_CPU_KERNELS_SRCS})
message(STATUS "IPEX_CPU_CPP_TPP_SRCS: ${IPEX_CPU_CPP_TPP_SRCS}")
# Pass to parent
set(IPEX_CPU_CPP_TPP_SRCS ${IPEX_CPU_CPP_TPP_SRCS} PARENT_SCOPE)
31 changes: 31 additions & 0 deletions csrc/cpu/tpp/README.md
@@ -0,0 +1,31 @@
# Overview
This directory mainly includes the TPP-related tools and the fused kernel implementations built on top of the TPP optimizations. The fused BERT kernels are exposed as `*_unpad` ops that work on a packed, pad-free token layout; a short sketch after the file tree below illustrates that layout.

```
├── bert #fused kernels based on TPP
│ ├── fused_bert.cpp
│ ├── fused_dense_dropout_layernorm_bwd_tmpl.h #backward for fused linear+dropout+layernorm
│ ├── fused_dense_dropout_layernorm_fwd_tmpl.h #forward for fused linear+dropout+layernorm
│ ├── fused_dense_gelu_bwd_tmpl.h #backward for fused linear+gelu
│ ├── fused_dense_gelu_fwd_tmpl.h #forward for fused linear+gelu
│ ├── fused_embedding_layernorm_dropout_bwd_tmpl.h #backward for fused embedding+add+layernorm+dropout
│ ├── fused_embedding_layernorm_dropout_fwd_tmpl.h #forward for fused embedding+add+layernorm+dropout
│ ├── fused_self_attention_bwd_tmpl.h #fused backward self-attention
│ └── fused_self_attention_fwd_tmpl.h #fused forward self-attention
├── CMakeLists.txt
├── common_loops.cpp #loops generation and tuning
├── ext_tpp.h
├── init.cpp
├── jit_compile.cpp
├── jit_compile.h
├── optim.cpp
├── optim.h
├── par_loop_generator.cpp #loops generation and tuning
├── par_loop_generator.h #loops generation and tuning
├── rtm.h
├── tensor_helper.h
├── threaded_loops.h
├── timing.h
├── utils.h
└── xsmm_functors.h #the TPP definitions based on libxsmm
```
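
To make the "unpad" wording concrete, here is a small illustrative sketch (not the kernel implementation; shapes are invented) of the packed, pad-free layout the `*_unpad` kernels consume:

```python
# Illustrative only: the *_unpad kernels work on a packed representation in
# which padding tokens have been removed, roughly as below.
import torch

hidden = torch.randn(2, 5, 8)                # [batch, seq_len, hidden]
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 0]])       # attention mask, 1 = real token

keep = mask.flatten().nonzero(as_tuple=True)[0]
packed = hidden.reshape(-1, 8)[keep]         # [num_real_tokens, hidden]
print(packed.shape)                          # torch.Size([7, 8])
```
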
276 changes: 276 additions & 0 deletions csrc/cpu/tpp/bert/fused_bert.cpp
@@ -0,0 +1,276 @@

#include <ATen/record_function.h>
//#include <torch/csrc/autograd/VariableTypeUtils.h>
//#include <torch/extension.h>

#include <dyndisp/DispatchStub.h>
#include <torch/all.h>
#include <iostream>
#include <vector>
#include "ext_tpp.h"
//#include "init.h"
#include "tensor_helper.h"
#include "threaded_loops.h"
#include "timing.h"
#include "xsmm_functors.h"

namespace torch_ipex {
namespace tpp {

static int my_rank = guess_mpi_rank();

REGISTER_LOCAL_SCOPE(b_emb, "b_emb");
REGISTER_LOCAL_SCOPE(q_gemm, "q_gemm");
REGISTER_LOCAL_SCOPE(k_gemm, "k_gemm");
REGISTER_LOCAL_SCOPE(v_gemm, "v_gemm");
REGISTER_LOCAL_SCOPE(ac_gemm, "ac_gemm");
REGISTER_LOCAL_SCOPE(o_gemm, "o_gemm");
REGISTER_LOCAL_SCOPE(i_gemm, "i_gemm");

REGISTER_LOCAL_SCOPE(db_emb, "db_emb");
REGISTER_LOCAL_SCOPE(diq_gemm, "diq_gemm");
REGISTER_LOCAL_SCOPE(dik_gemm, "dik_gemm");
REGISTER_LOCAL_SCOPE(div_gemm, "div_gemm");
REGISTER_LOCAL_SCOPE(dica_gemm, "dica_gemm");
REGISTER_LOCAL_SCOPE(dii_gemm, "dii_gemm");
REGISTER_LOCAL_SCOPE(dio_gemm, "dio_gemm");
REGISTER_LOCAL_SCOPE(dwqkv_gemm, "dwqkv_gemm");
REGISTER_LOCAL_SCOPE(dwq_gemm, "dwq_gemm");
REGISTER_LOCAL_SCOPE(dwk_gemm, "dwk_gemm");
REGISTER_LOCAL_SCOPE(dwv_gemm, "dwv_gemm");
REGISTER_LOCAL_SCOPE(dwa_gemm, "dwa_gemm");
REGISTER_LOCAL_SCOPE(dwc_gemm, "dwc_gemm");
REGISTER_LOCAL_SCOPE(dac_gemm, "dac_gemm");
REGISTER_LOCAL_SCOPE(dwi_gemm, "dwi_gemm");
REGISTER_LOCAL_SCOPE(dwo_gemm, "dwo_gemm");
REGISTER_LOCAL_SCOPE(dqkv_bias, "dqkv_bias");
REGISTER_LOCAL_SCOPE(di_bias, "di_bias");
REGISTER_LOCAL_SCOPE(do_bias, "do_bias");

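// Reduces the per-thread partial buffers in `ptrs` (num_threads buffers of
// length N) element-wise into `buf`, optionally accumulating. The
// `#pragma omp for` below implies the call happens inside an existing
// OpenMP parallel region.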
template <typename T>
inline void omp_reduce_buf(
    int num_threads,
    int N,
    float** ptrs,
    T* buf,
    bool accumulate = false) {
  ScopedTimer _t(EW_RED);
#pragma omp for
  for (int i = 0; i < N; i++) {
    float sum = 0.0;
    for (int j = 0; j < num_threads; j++) {
      sum += ptrs[j][i];
    }
    if (accumulate) {
      buf[i] += sum;
    } else {
      buf[i] = sum;
    }
  }
}

static std::vector<at::Tensor> fused_self_attention_fwd_unpad(
double p,
std::vector<at::Tensor> inputs,
bool training) {
GlobalPass _gp(FWD);
if (inputs[6].dtype() == at::kFloat) {
typedef float T;
#include "fused_self_attention_fwd_tmpl.h"
} else {
typedef bfloat16 T;
#include "fused_self_attention_fwd_tmpl.h"
}
}

static std::vector<at::Tensor> fused_self_attention_bwd_unpad(
double p,
std::vector<at::Tensor> inputs) {
GlobalPass _gp(BWD);
if (inputs[0].dtype() == at::kFloat) {
typedef float T;
#include "fused_self_attention_bwd_tmpl.h"
} else {
typedef bfloat16 T;
#include "fused_self_attention_bwd_tmpl.h"
}
}

static std::vector<at::Tensor> fused_dense_dropout_layernorm_fwd_unpad(
double p,
double eps,
std::vector<at::Tensor> inputs,
bool training) {
GlobalPass _gp(FWD);
if (inputs[0].dtype() == at::kFloat) {
typedef float T;
#include "fused_dense_dropout_layernorm_fwd_tmpl.h"
} else {
typedef bfloat16 T;
#include "fused_dense_dropout_layernorm_fwd_tmpl.h"
}
}

static std::vector<at::Tensor> fused_dense_dropout_layernorm_bwd_unpad(
double p,
std::vector<at::Tensor> inputs) {
GlobalPass _gp(BWD);
if (inputs[0].dtype() == at::kFloat) {
typedef float T;
#include "fused_dense_dropout_layernorm_bwd_tmpl.h"
} else {
typedef bfloat16 T;
#include "fused_dense_dropout_layernorm_bwd_tmpl.h"
}
}

static std::vector<at::Tensor> fused_dense_gelu_fwd_unpad(
at::Tensor t_in,
at::Tensor t_wt,
at::Tensor t_bias,
bool training) {
GlobalPass _gp(FWD);
if (t_in.dtype() == at::kFloat) {
typedef float T;
#include "fused_dense_gelu_fwd_tmpl.h"
} else {
typedef bfloat16 T;
#include "fused_dense_gelu_fwd_tmpl.h"
}
}

static std::vector<at::Tensor> fused_dense_gelu_bwd_unpad(
at::Tensor t_grad_out,
at::Tensor t_gelu_in,
at::Tensor t_in,
at::Tensor t_wt) {
GlobalPass _gp(BWD);
if (t_grad_out.dtype() == at::kFloat) {
typedef float T;
#include "fused_dense_gelu_bwd_tmpl.h"
} else {
typedef bfloat16 T;
#include "fused_dense_gelu_bwd_tmpl.h"
}
}

static std::vector<at::Tensor> fused_embedding_layernorm_dropout_fwd_unpad(
double p,
double eps,
long H,
long pad_id,
std::vector<at::Tensor> inputs,
bool training) {
GlobalPass _gp(FWD);
if (inputs[4].dtype() == at::kFloat && inputs[6].dtype() == at::kFloat) {
typedef float T;
typedef float ET;
#include "fused_embedding_layernorm_dropout_fwd_tmpl.h"
} else if (
inputs[4].dtype() == at::kBFloat16 && inputs[6].dtype() == at::kFloat) {
typedef bfloat16 T;
typedef float ET;
#include "fused_embedding_layernorm_dropout_fwd_tmpl.h"
} else if (
inputs[4].dtype() == at::kFloat && inputs[6].dtype() == at::kBFloat16) {
typedef float T;
typedef bfloat16 ET;
#include "fused_embedding_layernorm_dropout_fwd_tmpl.h"
} else if (
inputs[4].dtype() == at::kBFloat16 &&
inputs[6].dtype() == at::kBFloat16) {
typedef bfloat16 T;
typedef bfloat16 ET;
#include "fused_embedding_layernorm_dropout_fwd_tmpl.h"
} else {
PCL_ASSERT(0, "Should not come here\n");
}
}

static std::vector<at::Tensor> fused_embedding_layernorm_dropout_bwd_unpad(
double p,
long pad_id,
std::vector<at::Tensor> inputs) {
GlobalPass _gp(BWD);
if (inputs[0].dtype() == at::kFloat && inputs[6].dtype() == at::kFloat) {
typedef float T;
typedef float ET;
#include "fused_embedding_layernorm_dropout_bwd_tmpl.h"
} else if (
inputs[0].dtype() == at::kBFloat16 && inputs[6].dtype() == at::kFloat) {
typedef bfloat16 T;
typedef float ET;
#include "fused_embedding_layernorm_dropout_bwd_tmpl.h"
} else if (
inputs[0].dtype() == at::kFloat && inputs[6].dtype() == at::kBFloat16) {
typedef float T;
typedef bfloat16 ET;
#include "fused_embedding_layernorm_dropout_bwd_tmpl.h"
} else if (
inputs[0].dtype() == at::kBFloat16 &&
inputs[6].dtype() == at::kBFloat16) {
typedef bfloat16 T;
typedef bfloat16 ET;
#include "fused_embedding_layernorm_dropout_bwd_tmpl.h"
} else {
PCL_ASSERT(0, "Should not come here\n");
}
}
} // namespace tpp
} // namespace torch_ipex
namespace {
TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
m.def(
torch::schema(
"torch_ipex::fused_self_attention_fwd_unpad(float p, Tensor[] inputs, bool training) -> Tensor[]",
c10::AliasAnalysisKind::PURE_FUNCTION),
torch_ipex::tpp::fused_self_attention_fwd_unpad);

m.def(
torch::schema(
"torch_ipex::fused_self_attention_bwd_unpad(float p, Tensor[] inputs) -> Tensor[]",
c10::AliasAnalysisKind::PURE_FUNCTION),
torch_ipex::tpp::fused_self_attention_bwd_unpad);

m.def(
torch::schema(
"torch_ipex::fused_dense_dropout_layernorm_fwd_unpad(float p, float eps, Tensor[] inputs, bool training) -> Tensor[]",
c10::AliasAnalysisKind::PURE_FUNCTION),
torch_ipex::tpp::fused_dense_dropout_layernorm_fwd_unpad);

m.def(
torch::schema(
"torch_ipex::fused_dense_dropout_layernorm_bwd_unpad(float p, Tensor[] inputs) -> Tensor[]",
c10::AliasAnalysisKind::PURE_FUNCTION),
torch_ipex::tpp::fused_dense_dropout_layernorm_bwd_unpad);

m.def(
torch::schema(
"torch_ipex::fused_dense_gelu_fwd_unpad(Tensor t_in, Tensor t_wt, Tensor "
"t_bias, bool training)->Tensor[] ",
c10::AliasAnalysisKind::PURE_FUNCTION),
torch_ipex::tpp::fused_dense_gelu_fwd_unpad);

m.def(
torch::schema(
"torch_ipex::fused_dense_gelu_bwd_unpad(Tensor t_grad_out, Tensor t_gelu_in,"
"Tensor t_in, Tensor t_wt) -> Tensor[]",
c10::AliasAnalysisKind::PURE_FUNCTION),
torch_ipex::tpp::fused_dense_gelu_bwd_unpad);

m.def(
torch::schema(
"torch_ipex::fused_embedding_layernorm_dropout_fwd_unpad(float p, float "
"eps, int H, int pad_id, Tensor(a!)[] inputs, bool training) ->"
"Tensor[]",
c10::AliasAnalysisKind::PURE_FUNCTION),
torch_ipex::tpp::fused_embedding_layernorm_dropout_fwd_unpad);

m.def(
torch::schema(
"torch_ipex::fused_embedding_layernorm_dropout_bwd_unpad(float p, int "
"pad_id, Tensor(a!)[] inputs)->Tensor[] ",
c10::AliasAnalysisKind::PURE_FUNCTION),
torch_ipex::tpp::fused_embedding_layernorm_dropout_bwd_unpad);
}
} // namespace
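
Once the extension is built with this change, these registrations surface under `torch.ops.torch_ipex`. A minimal sketch that only lists the new ops, since their inputs are assumed to require the blocked tensors prepared by the TPP BERT frontend:

```python
# Sketch: lists the newly registered ops without calling them; their tensor
# layouts are assumed to come from the fast_bert path, not raw torch.randn.
import torch
import intel_extension_for_pytorch  # noqa: F401  # loads the CPU op library

for name in (
    "fused_self_attention_fwd_unpad",
    "fused_dense_dropout_layernorm_fwd_unpad",
    "fused_dense_gelu_fwd_unpad",
    "fused_embedding_layernorm_dropout_fwd_unpad",
):
    print(getattr(torch.ops.torch_ipex, name))
```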