-
Notifications
You must be signed in to change notification settings - Fork 104
/
Copy pathsu3_project.cuh
127 lines (110 loc) · 3.37 KB
/
su3_project.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#pragma once
/**
* @file su3_project.cuh
*
* @section Description
*
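* This header file defines an iterative SU(3) projection algorithm: the input matrix is first made unitary by a polar-decomposition iteration and then projected onto SU(3) by removing the phase and magnitude of its determinant.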
*/
#include <quda_matrix.h>
namespace quda {

  /**
   * @brief Check the unitarity of the input matrix to a given
   * tolerance.
   *
   * A matrix U is unitary iff U^{-1} = U^\dagger, i.e. inv(j,i) must equal
   * the complex conjugate of in(i,j).  The real parts are therefore
   * compared for equality and the imaginary parts for sign-flipped
   * equality.
   *
   * @param inv Output: overwritten with the inverse of @p in (polarSu3
   *            relies on this to reuse the inverse in its next iteration)
   * @param in  The input matrix whose unitarity is checked
   * @param tol Absolute tolerance applied to each real/imaginary component
   * @return 0 if every component agrees to within @p tol, 1 otherwise.
   *         A NaN deviation (e.g. from inverting a singular matrix) is
   *         reported as non-unitary.
   */
  template <typename Float2, typename Float>
  __host__ __device__ int checkUnitary(Matrix<Float2,3> &inv, Matrix<Float2,3> in, const Float tol)
  {
    computeMatrixInverse(in, &inv);
    for (int i = 0; i < 3; i++) {
      for (int j = 0; j < 3; j++) {
        // Written as !(dev <= tol) instead of (dev > tol) so that a NaN
        // deviation fails the check rather than silently passing it.
        if (!(fabs(in(i,j).x - inv(j,i).x) <= tol)) return 1;
        if (!(fabs(in(i,j).y + inv(j,i).y) <= tol)) return 1;
      }
    }
    return 0;
  }

  /**
   * @brief Debugging variant of checkUnitary: check the unitarity of the
   * input matrix to a given tolerance (1e-14 by default) and print the
   * deviation of each component as it is checked.
   *
   * Printing stops at the first component that exceeds the tolerance,
   * because the function returns immediately at that point.
   *
   * @param inv Output: overwritten with the inverse of @p in
   * @param in  The input matrix whose unitarity is checked
   * @param tol Absolute tolerance applied to each real/imaginary component
   * @return 0 if every component agrees to within @p tol, 1 otherwise
   */
  template <typename Float2>
  __host__ __device__ int checkUnitaryPrint(Matrix<Float2,3> &inv, Matrix<Float2,3> in, double tol = 1e-14)
  {
    computeMatrixInverse(in, &inv);
    for (int i = 0; i < 3; i++) {
      for (int j = 0; j < 3; j++) {
        // inv is a reference, so it is indexed directly (not via *inv).
        const double dev_re = fabs(in(i,j).x - inv(j,i).x);
        const double dev_im = fabs(in(i,j).y + inv(j,i).y);
        printf("TESTR: %+.3le %+.3le %+.3le\n", (double)in(i,j).x, (double)inv(j,i).x, dev_re);
        printf("TESTI: %+.3le %+.3le %+.3le\n", (double)in(i,j).y, (double)inv(j,i).y, dev_im);
        // No cudaDeviceSynchronize() here: calling it from device code is
        // deprecated (removed in CUDA 12). Device printf output is flushed
        // when the host next synchronizes.
        if (!(dev_re <= tol)) return 1;
        if (!(dev_im <= tol)) return 1;
      }
    }
    return 0;
  }

  /**
   * @brief Project the input matrix onto the SU(3) group.  First unitarize
   * the matrix, then project onto the special unitary group.
   *
   * Unitarization iterates U_{k+1} = (U_k + conj(U_k^{-1})) / 2 until
   * checkUnitary reports convergence.  NOTE(review): this presumes quda's
   * conj(Matrix) is the Hermitian conjugate (conjugate transpose), which
   * makes this the Newton iteration for the unitary polar factor — confirm
   * against quda_matrix.h.
   *
   * The SU(3) projection writes det(U) = r e^{i theta} and multiplies U by
   * r^{-1/3} e^{-i theta/3}, so the result has determinant 1.
   *
   * @param in       The matrix to project; overwritten with the result
   * @param tol      Tolerance passed to checkUnitary to decide convergence
   * @param max_iter Upper bound on unitarization iterations (at least one
   *                 is always performed).  It prevents a hung thread when
   *                 @p tol is unreachable in the working precision; if the
   *                 bound is hit, the last iterate is projected as-is.
   */
  template <typename Float>
  __host__ __device__ void polarSu3(Matrix<complex<Float>,3> &in, Float tol, int max_iter = 100)
  {
    Matrix<complex<Float>,3> inv, out;
    out = in;
    computeMatrixInverse(out, &inv);

    // Iterate until the matrix is unitary.  checkUnitary also refreshes
    // inv with the inverse of the new iterate for the next step.
    int iter = 0;
    do {
      out = (out + conj(inv)) * static_cast<Float>(0.5);
    } while (checkUnitary(inv, out, tol) && ++iter < max_iter);

    // Now project onto the special unitary group.  The phase correction is
    // evaluated in double precision.
    const complex<Float> det = getDeterminant(out);
    double mod = det.x * det.x + det.y * det.y;
    mod = pow(mod, (1. / 6.));  // |det|^{1/3}
    double angle = atan2(det.y, det.x);
    angle /= -3.;
    complex<Float> phase;
    phase.x = static_cast<Float>(cos(angle) / mod);
    phase.y = static_cast<Float>(sin(angle) / mod);
    in = out * phase;
  }

} // namespace quda