Skip to content

Commit

Permalink
Enable Windows Arm64 (pytorch#133088)
Browse files Browse the repository at this point in the history
This PR enables Pytorch for Windows on Arm64 - CPU only.
Currently, there aren't any checks in place to build and test for Windows on Arm64, but we're working to implement those as soon as possible.
We recommend using [Arm Performance Libraries (APL)](https://developer.arm.com/Tools%20and%20Software/Arm%20Performance%20Libraries) as a BLAS option, which is introduced in this PR.

Pull Request resolved: pytorch#133088
Approved by: https://github.com/malfet

Co-authored-by: cristian panaite <[email protected]>
Co-authored-by: Stefan-Alin Pahontu <[email protected]>
Co-authored-by: Ozan Aydin <[email protected]>
  • Loading branch information
4 people authored and pytorchmergebot committed Oct 24, 2024
1 parent f7bb11d commit b021486
Show file tree
Hide file tree
Showing 8 changed files with 166 additions and 7 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ if(NOT CMAKE_SYSTEM_PROCESSOR MATCHES "^(s390x|ppc64le)$")
list(APPEND ATen_CPU_DEPENDENCY_LIBS cpuinfo)
endif()

if(NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE)
if(NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE AND NOT (MSVC AND CMAKE_SYSTEM_PROCESSOR STREQUAL "ARM64"))
if(NOT MSVC)
# Bump up optimization level for sleef to -O1, since at -O0 the compiler
# excessively spills intermediate vector registers to the stack
Expand Down
65 changes: 64 additions & 1 deletion aten/src/ATen/native/BatchLinearAlgebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,46 @@ extern "C" void dgetrf_(int *m, int *n, double *a, int *lda, int *ipiv, int *inf
extern "C" void sgetrf_(int *m, int *n, float *a, int *lda, int *ipiv, int *info);

// potrs
#if defined(_WIN32) && defined(_M_ARM64)

// The functions zpotrs, cpotrs, dpotrs, and spotrs are not directly available in LAPACKE on Windows on ARM,
// so we need to have wrapper functions to call them.
// The issue on ARM platform can be found below:
// https://community.arm.com/support-forums/f/high-performance-computing-forum/56512/unable-to-use-lapack---potrs-functions

#define LAPACK_COL_MAJOR 102
#define LAPACK_ROW_MAJOR 101

extern "C" int LAPACKE_zpotrs(int matrix_layout, char uplo, int n, int nrhs, const std::complex<double> *a, int lda, std::complex<double> *b, int ldb);
extern "C" int LAPACKE_cpotrs(int matrix_layout, char uplo, int n, int nrhs, const std::complex<float> *a, int lda, std::complex<float> *b, int ldb);
extern "C" int LAPACKE_dpotrs(int matrix_layout, char uplo, int n, int nrhs, const double *a, int lda, double *b, int ldb);
extern "C" int LAPACKE_spotrs(int matrix_layout, char uplo, int n, int nrhs, const float *a, int lda, float *b, int ldb);

static inline void zpotrs_(char *uplo, int *n, int *nrhs, std::complex<double> *a, int *lda, std::complex<double> *b, int *ldb, int *info) {
*info = LAPACKE_zpotrs(LAPACK_COL_MAJOR, *uplo, *n, *nrhs, a, *lda, b, *ldb);
}

static inline void cpotrs_(char *uplo, int *n, int *nrhs, std::complex<float> *a, int *lda, std::complex<float> *b, int *ldb, int *info) {
*info = LAPACKE_cpotrs(LAPACK_COL_MAJOR, *uplo, *n, *nrhs, a, *lda, b, *ldb);
}

static inline void dpotrs_(char *uplo, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, int *info){
*info = LAPACKE_dpotrs(LAPACK_COL_MAJOR, *uplo, *n, *nrhs, a, *lda, b, *ldb);
}

static inline void spotrs_(char *uplo, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, int *info) {
*info = LAPACKE_spotrs(LAPACK_COL_MAJOR, *uplo, *n, *nrhs, a, *lda, b, *ldb);
}

#else

extern "C" void zpotrs_(char *uplo, int *n, int *nrhs, std::complex<double> *a, int *lda, std::complex<double> *b, int *ldb, int *info);
extern "C" void cpotrs_(char *uplo, int *n, int *nrhs, std::complex<float> *a, int *lda, std::complex<float> *b, int *ldb, int *info);
extern "C" void dpotrs_(char *uplo, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, int *info);
extern "C" void spotrs_(char *uplo, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, int *info);

#endif

// potrf
extern "C" void zpotrf_(char *uplo, int *n, std::complex<double> *a, int *lda, int *info);
extern "C" void cpotrf_(char *uplo, int *n, std::complex<float> *a, int *lda, int *info);
Expand Down Expand Up @@ -284,11 +319,39 @@ extern "C" void dorgqr_(int *m, int *n, int *k, double *a, int *lda, double *tau
extern "C" void sorgqr_(int *m, int *n, int *k, float *a, int *lda, float *tau, float *work, int *lwork, int *info);

// ormqr
#if defined(_WIN32) && defined(_M_ARM64)

// The functions zunmqr, cunmqr, dormqr, and sormqr are not directly available in LAPACKE on Windows on ARM,
// so we need to have wrapper functions to call them.
// The issue on ARM platform can be found below:
// https://community.arm.com/support-forums/f/high-performance-computing-forum/56512/unable-to-use-lapack---potrs-functions

extern "C" int LAPACKE_zunmqr_work(int matrix_layout, char side, char trans, int m, int n, int k, const std::complex<double> *a, int lda, const std::complex<double> *tau, std::complex<double> *c, int ldc, std::complex<double> *work, int lwork);
extern "C" int LAPACKE_cunmqr_work(int matrix_layout, char side, char trans, int m, int n, int k, const std::complex<float> *a, int lda, const std::complex<float> *tau, std::complex<float> *c, int ldc, std::complex<float> *work, int lwork);
extern "C" int LAPACKE_dormqr_work(int matrix_layout, char side, char trans, int m, int n, int k, const double *a, int lda, const double *tau, double *c, int ldc, double *work, int lwork);
extern "C" int LAPACKE_sormqr_work(int matrix_layout, char side, char trans, int m, int n, int k, const float *a, int lda, const float *tau, float *c, int ldc, float *work, int lwork);

static inline void zunmqr_(char *side, char *trans, int *m, int *n, int *k, std::complex<double> *a, int *lda, std::complex<double> *tau, std::complex<double> *c, int *ldc, std::complex<double> *work, int *lwork, int *info) {
*info = LAPACKE_zunmqr_work(LAPACK_COL_MAJOR, *side, *trans, *m, *n, *k, a, *lda, tau, c, *ldc, work, *lwork);
}

static inline void cunmqr_(char *side, char *trans, int *m, int *n, int *k, std::complex<float> *a, int *lda, std::complex<float> *tau, std::complex<float> *c, int *ldc, std::complex<float> *work, int *lwork, int *info) {
*info = LAPACKE_cunmqr_work(LAPACK_COL_MAJOR, *side, *trans, *m, *n, *k, a, *lda, tau, c, *ldc, work, *lwork);
}

static inline void dormqr_(char *side, char *trans, int *m, int *n, int *k, double *a, int *lda, double *tau, double *c, int *ldc, double *work, int *lwork, int *info) {
*info = LAPACKE_dormqr_work(LAPACK_COL_MAJOR, *side, *trans, *m, *n, *k, a, *lda, tau, c, *ldc, work, *lwork);
}

static inline void sormqr_(char *side, char *trans, int *m, int *n, int *k, float *a, int *lda, float *tau, float *c, int *ldc, float *work, int *lwork, int *info) {
*info = LAPACKE_sormqr_work(LAPACK_COL_MAJOR, *side, *trans, *m, *n, *k, a, *lda, tau, c, *ldc, work, *lwork);
}
#else
extern "C" void zunmqr_(char *side, char *trans, int *m, int *n, int *k, std::complex<double> *a, int *lda, std::complex<double> *tau, std::complex<double> *c, int *ldc, std::complex<double> *work, int *lwork, int *info);
extern "C" void cunmqr_(char *side, char *trans, int *m, int *n, int *k, std::complex<float> *a, int *lda, std::complex<float> *tau, std::complex<float> *c, int *ldc, std::complex<float> *work, int *lwork, int *info);
extern "C" void dormqr_(char *side, char *trans, int *m, int *n, int *k, double *a, int *lda, double *tau, double *c, int *ldc, double *work, int *lwork, int *info);
extern "C" void sormqr_(char *side, char *trans, int *m, int *n, int *k, float *a, int *lda, float *tau, float *c, int *ldc, float *work, int *lwork, int *info);

#endif
// syevd
extern "C" void zheevd_(char *jobz, char *uplo, int *n, std::complex<double> *a, int *lda, double *w, std::complex<double> *work, int *lwork, double *rwork, int *lrwork, int *iwork, int *liwork, int *info);
extern "C" void cheevd_(char *jobz, char *uplo, int *n, std::complex<float> *a, int *lda, float *w, std::complex<float> *work, int *lwork, float *rwork, int *lrwork, int *iwork, int *liwork, int *info);
Expand Down
5 changes: 4 additions & 1 deletion caffe2/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1719,7 +1719,10 @@ if(BUILD_TEST)
endif()
else()
add_executable(${test_name}_${CPU_CAPABILITY} "${test_src}")
target_link_libraries(${test_name}_${CPU_CAPABILITY} torch_library sleef gtest_main)
target_link_libraries(${test_name}_${CPU_CAPABILITY} torch_library gtest_main)
if(NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "ARM64")
target_link_libraries(${test_name}_${CPU_CAPABILITY} sleef)
endif()
endif()
target_include_directories(${test_name}_${CPU_CAPABILITY} PRIVATE $<INSTALL_INTERFACE:include>)
target_include_directories(${test_name}_${CPU_CAPABILITY} PRIVATE $<BUILD_INTERFACE:${CMAKE_BINARY_DIR}/include>)
Expand Down
10 changes: 8 additions & 2 deletions cmake/Dependencies.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ else()
set(AT_MKLDNN_ENABLED 0)
set(AT_MKL_ENABLED 0)
endif()
set_property(CACHE BLAS PROPERTY STRINGS "ATLAS;BLIS;Eigen;FLAME;Generic;MKL;OpenBLAS;vecLib")
set_property(CACHE BLAS PROPERTY STRINGS "ATLAS;BLIS;Eigen;FLAME;Generic;MKL;OpenBLAS;vecLib;APL")
message(STATUS "Trying to find preferred BLAS backend of choice: " ${BLAS})

if(BLAS STREQUAL "Eigen")
Expand Down Expand Up @@ -226,6 +226,12 @@ elseif(BLAS STREQUAL "FlexiBLAS")
find_package(FlexiBLAS REQUIRED)
include_directories(SYSTEM ${FlexiBLAS_INCLUDE_DIR})
list(APPEND Caffe2_DEPENDENCY_LIBS ${FlexiBLAS_LIB})
elseif(BLAS STREQUAL "APL")
find_package(APL REQUIRED)
include_directories(SYSTEM ${APL_INCLUDE_DIR})
set(BLAS_INFO "apl")
set(BLAS_FOUND 1)
set(BLAS_LIBRARIES ${APL_LIBRARIES})
elseif(BLAS STREQUAL "Generic")
# On Debian family, the CBLAS ABIs have been merged into libblas.so
if(ENV{GENERIC_BLAS_LIBRARIES} STREQUAL "")
Expand All @@ -246,7 +252,7 @@ endif()
if(NOT INTERN_BUILD_MOBILE)
set(AT_MKL_SEQUENTIAL 0)
set(USE_BLAS 1)
if(NOT (ATLAS_FOUND OR BLIS_FOUND OR GENERIC_BLAS_FOUND OR MKL_FOUND OR OpenBLAS_FOUND OR VECLIB_FOUND OR FlexiBLAS_FOUND OR NVPL_BLAS_FOUND))
if(NOT (ATLAS_FOUND OR BLIS_FOUND OR GENERIC_BLAS_FOUND OR MKL_FOUND OR OpenBLAS_FOUND OR VECLIB_FOUND OR FlexiBLAS_FOUND OR NVPL_BLAS_FOUND OR APL_FOUND))
message(WARNING "Preferred BLAS (" ${BLAS} ") cannot be found, now searching for a general BLAS library")
find_package(BLAS)
if(NOT BLAS_FOUND)
Expand Down
58 changes: 58 additions & 0 deletions cmake/Modules/FindAPL.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# - Find APL (Arm Performance Libraries)
#
# This module sets the following variables:
# APL_INCLUDE_SEARCH_PATHS - list of paths to search for APL include files
# APL_LIB_SEARCH_PATHS - list of paths to search for APL libraries
# APL_FOUND - set to true if APL is found
# APL_INCLUDE_DIR - path to include dir.
# APL_LIB_DIR - path to include dir.
# APL_LIBRARIES - list of libraries for base APL

SET(APL_INCLUDE_SEARCH_PATHS $ENV{ARMPL_DIR}/include)
SET(APL_LIB_SEARCH_PATHS $ENV{ARMPL_DIR}/lib)

SET(APL_FOUND ON)

# Check include file
FIND_PATH(APL_INCLUDE_DIR NAMES armpl.h PATHS ${APL_INCLUDE_SEARCH_PATHS})
IF(NOT APL_INCLUDE_DIR)
SET(APL_FOUND OFF)
MESSAGE(STATUS "Could not verify APL include directory. Turning APL_FOUND off")
ENDIF()

# Check lib file
FIND_PATH(APL_LIB_DIR NAMES libarmpl_lp64_mp.dll.lib libomp.dll.lib libarmpl_lp64_mp.a PATHS ${APL_LIB_SEARCH_PATHS})
IF(NOT APL_LIB_DIR)
SET(APL_FOUND OFF)
MESSAGE(STATUS "Could not verify APL lib directory. Turning APL_FOUND off")
ENDIF()

IF (APL_FOUND)
IF(WIN32)
set(APL_LIBRARIES
"${APL_LIB_DIR}/libarmpl_lp64_mp.dll.lib"
"${APL_LIB_DIR}/libomp.dll.lib"
)
ELSEIF(UNIX)
set(APL_LIBRARIES
"${APL_LIB_DIR}/libarmpl_lp64_mp.a"
)
ENDIF()
MESSAGE(STATUS "Found APL header: ${APL_INCLUDE_DIR}")
MESSAGE(STATUS "Found APL library: ${APL_LIB_DIR}")
message(STATUS "APL_LIBRARIES: ${APL_LIBRARIES}")
SET(CMAKE_REQUIRED_LIBRARIES ${APL_LIBRARIES})
include(CheckCSourceRuns)
CHECK_C_SOURCE_RUNS("
#include <stdlib.h>
#include <stdio.h>
float x[4] = { 1, 2, 3, 4 };
float y[4] = { .1, .01, .001, .0001 };
extern float cblas_sdot();
int main() {
int i;
double r = cblas_sdot(4, x, 1, y, 1);
exit((float)r != (float).1234);
}" BLAS_USE_CBLAS_DOT )
MESSAGE(STATUS "BLAS_USE_CBLAS_DOT: ${BLAS_USE_CBLAS_DOT}")
ENDIF (APL_FOUND)
28 changes: 28 additions & 0 deletions cmake/Modules/FindLAPACK.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,34 @@ if(BLAS_FOUND)
endif(LAPACK_LIBRARIES)
endif()

#Arm Performance Libraries
IF((NOT LAPACK_INFO) AND (BLAS_INFO STREQUAL "apl"))
SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES})
check_function_exists("cheev_" APL_LAPACK_WORKS)
if(APL_LAPACK_WORKS)
check_function_exists("cgesdd_" LAPACK_CGESDD_WORKS)
if(NOT LAPACK_CGESDD_WORKS)
find_library(GFORTRAN_LIBRARY
NAMES libgfortran.a gfortran
PATHS ${CMAKE_C_IMPLICIT_LINK_DIRECTORIES})
list(APPEND CMAKE_REQUIRED_LIBRARIES "${GFORTRAN_LIBRARY}")
unset(LAPACK_CGESDD_WORKS CACHE)
check_function_exists("cgesdd_" LAPACK_CGESDD_WORKS)
if(LAPACK_CGESDD_WORKS)
list(APPEND LAPACK_LIBRARIES "${GFORTRAN_LIBRARY}")
else()
message(WARNING "APL has been compiled with Lapack support, but cgesdd can not be used")
set(APL_LAPACK_WORKS NO)
endif()
endif()
endif()
set(CMAKE_REQUIRED_LIBRARIES)
if(APL_LAPACK_WORKS)
SET(LAPACK_INFO "apl")
else()
message(STATUS "It seems APL has not been compiled with Lapack support")
endif()
endif()
else(BLAS_FOUND)
message(STATUS "LAPACK requires BLAS")
endif(BLAS_FOUND)
Expand Down
2 changes: 1 addition & 1 deletion torch/_inductor/cpp_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,7 @@ def _get_torch_related_args(
if not aot_mode:
libraries.append("torch_python")

if _IS_WINDOWS:
if _IS_WINDOWS and platform.machine().lower() != "arm64":
libraries.append("sleef")

return include_dirs, libraries_dirs, libraries
Expand Down
3 changes: 2 additions & 1 deletion torch/utils/cpp_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import importlib
import importlib.abc
import os
import platform
import re
import shlex
import shutil
Expand Down Expand Up @@ -994,7 +995,7 @@ def CppExtension(name, sources, *args, **kwargs):
libraries.append('torch')
libraries.append('torch_cpu')
libraries.append('torch_python')
if IS_WINDOWS:
if IS_WINDOWS and platform.machine().lower() != "arm64":
libraries.append("sleef")

kwargs['libraries'] = libraries
Expand Down

0 comments on commit b021486

Please sign in to comment.