Skip to content

Commit 58b928a

Browse files
mrkwjc authored and pavanky committed
Add conjugate gradient benchmark.
Tests dense and sparse ArrayFire behavior against NumPy and SciPy solutions. Works in float32 precision.
1 parent a684355 commit 58b928a

File tree

1 file changed

+200
-0
lines changed

1 file changed

+200
-0
lines changed

examples/benchmarks/bench_cg.py

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
#!/usr/bin/python
2+
3+
#######################################################
4+
# Copyright (c) 2015, ArrayFire
5+
# All rights reserved.
6+
#
7+
# This file is distributed under 3-clause BSD license.
8+
# The complete license agreement can be obtained at:
9+
# http://arrayfire.com/licenses/BSD-3-Clause
10+
########################################################
11+
12+
13+
import sys
14+
from time import time
15+
import arrayfire as af
16+
17+
# NumPy is optional: when it is missing ``np`` is None and the NumPy/SciPy
# reference runs are skipped by the ``if np:`` checks further down.
try:
    import numpy as np
except ImportError:
    # Only a failed import is tolerated; a bare ``except`` would also hide
    # KeyboardInterrupt/SystemExit raised while importing.
    np = None

# SciPy is optional as well; ``sp`` doubles as the "SciPy available" flag.
try:
    from scipy import sparse as sp
    from scipy.sparse import linalg
except ImportError:
    sp = None
    linalg = None
27+
28+
29+
def to_numpy(A):
    """Copy an array to the host as a float32 NumPy array.

    The data is fetched through ``A.to_list()``, so the result has whatever
    nesting and element order that method produces.

    NOTE(review): for 2-D input the result may be transposed if ``to_list``
    is column-major -- harmless for the symmetric matrices built in this
    file, but confirm before reusing this helper elsewhere.
    """
    host_values = A.to_list()
    return np.asarray(host_values, dtype=np.float32)
31+
32+
33+
def to_sparse(A):
    """Return a sparse ArrayFire array built from the dense array ``A``."""
    sparse_matrix = af.sparse.create_sparse_from_dense(A)
    return sparse_matrix
35+
36+
37+
def to_scipy_sparse(spA, fmt='csr'):
    """Convert a sparse ArrayFire array into a float32 SciPy sparse matrix.

    Parameters
    ----------
    spA : sparse ArrayFire array
        Its values, row index array and column index array are copied to the
        host via ``to_list()``.
        NOTE(review): the row index array is passed to SciPy as the CSR row
        pointer, which assumes ``spA`` uses CSR storage -- confirm.
    fmt : str, optional
        SciPy sparse format of the returned matrix (default ``'csr'``).

    Returns
    -------
    SciPy sparse matrix with float32 data in the requested format.
    """
    vals = np.asarray(af.sparse.sparse_get_values(spA).to_list(),
                      dtype=np.float32)
    # ``np.int`` was only an alias of the builtin ``int`` and has been removed
    # from NumPy; the builtin keeps the same dtype on every NumPy version.
    rows = np.asarray(af.sparse.sparse_get_row_idx(spA).to_list(),
                      dtype=int)
    cols = np.asarray(af.sparse.sparse_get_col_idx(spA).to_list(),
                      dtype=int)
    # csr_matrix expects (data, indices, indptr): column indices first, then
    # the row pointer array.
    matrix = sp.csr_matrix((vals, cols, rows), dtype=np.float32)
    # Honour ``fmt``; for the default 'csr' this returns the matrix unchanged.
    return matrix.asformat(fmt)
45+
46+
47+
def setup_input(n, sparsity=7):
    """Generate a random symmetric test system ``A x0 = b`` with ArrayFire.

    Parameters
    ----------
    n : int
        Number of rows/columns of ``A`` and length of ``x0`` and ``b``.
    sparsity : int, optional
        A random entry is kept only when its value (scaled to an integer in
        0..999) is divisible by ``sparsity``; larger values leave fewer
        non-zeros.

    Returns
    -------
    (A, b, x0) : the dense float32 matrix, the right-hand side
    ``b = A x0`` and the known solution ``x0``.
    """
    f32 = af.Dtype.f32
    # Uniform random values scaled and floored to whole numbers.
    scaled = af.floor(af.randu(n, n, dtype=f32) * 1000)
    # Zero every entry that is not a multiple of `sparsity`, then scale back.
    keep_mask = (scaled % sparsity) == 0
    thinned = scaled * keep_mask / 1000
    # Symmetrize and add n to the diagonal
    # (presumably to keep the system well suited to CG -- TODO confirm).
    A = thinned.T + thinned + n * af.identity(n, n, dtype=f32)
    x0 = af.randu(n, dtype=f32)
    b = af.matmul(A, x0)
    return A, b, x0
58+
59+
60+
def input_info(A, Asp):
    """Print the dimensions and the percentage of non-zero entries of ``A``.

    ``Asp`` is accepted but not used. The two memory-usage lines are printed
    as bare labels; no value is computed for them.
    """
    m, n = A.dims()
    nnz = af.sum((A != 0))
    # Percentage of non-zero entries (printed under the label "sparsity").
    percent_nonzero = 100*nnz/n**2
    print(" matrix size: %i x %i" % (m, n))
    print(" matrix sparsity: %2.2f %%" % (percent_nonzero,))
    print(" dense matrix memory usage: ")
    print(" sparse matrix memory usage: ")
67+
68+
69+
def calc_arrayfire(A, b, x0, maxiter=10):
    """Run ``maxiter`` conjugate-gradient iterations with ArrayFire.

    Starts from the zero vector and performs a fixed number of iterations;
    there is no convergence test.

    Parameters
    ----------
    A : ArrayFire array (dense or sparse) used as the system matrix.
    b : ArrayFire array, right-hand side.
    x0 : ArrayFire array, reference solution used only for the error.
    maxiter : int, optional
        Number of iterations to run.

    Returns
    -------
    (x, err) : the final iterate and ``dot(x0 - x, x0 - x)``.
    """
    x = af.constant(0, b.dims()[0], dtype=af.Dtype.f32)
    r = b - af.matmul(A, x)
    # NOTE(review): p starts as the same Python object as r, whereas the NumPy
    # variants use r.copy(). This relies on `r -= ...` below not mutating p --
    # confirm against arrayfire's in-place operator semantics.
    p = r
    for _ in range(maxiter):
        Ap = af.matmul(A, p)
        length = Ap.dims()[0]
        rr_old = af.dot(r, r)
        # The dot products are one-element arrays, hence the af.tile calls to
        # stretch them to vector length before multiplying.
        step = rr_old/af.dot(p, Ap)
        r -= af.tile(step, length) * Ap
        x += af.tile(step, length) * p
        rr_new = af.dot(r, r)
        ratio = rr_new/rr_old
        p = r + af.tile(ratio, p.dims()[0]) * p
    error = x0 - x
    return x, af.dot(error, error)
85+
86+
87+
def calc_numpy(A, b, x0, maxiter=10):
    """Run ``maxiter`` conjugate-gradient iterations with dense NumPy arrays.

    Starts from the zero vector and performs a fixed number of iterations;
    there is no convergence test.

    Parameters
    ----------
    A : 2-D array, system matrix.
    b : 1-D array, right-hand side.
    x0 : 1-D array, reference solution used only for the error.
    maxiter : int, optional
        Number of iterations to run.

    Returns
    -------
    (x, err) : the final iterate (float32) and ``dot(x0 - x, x0 - x)``.
    """
    x = np.zeros(len(b), dtype=np.float32)
    r = b - np.dot(A, x)
    # Copy so the in-place update of r below does not also change p.
    p = r.copy()
    for _ in range(maxiter):
        Ap = np.dot(A, p)
        rr_old = np.dot(r, r)
        step = rr_old/np.dot(p, Ap)
        r -= step * Ap
        x += step * p
        rr_new = np.dot(r, r)
        p = r + (rr_new/rr_old) * p
    error = x0 - x
    return x, np.dot(error, error)
103+
104+
105+
def calc_scipy_sparse(A, b, x0, maxiter=10):
    """Run ``maxiter`` conjugate-gradient iterations with a SciPy sparse matrix.

    Same hand-written iteration as the dense NumPy variant, except that the
    matrix-vector products are written ``A*vector``. Starts from the zero
    vector and has no convergence test.

    Parameters
    ----------
    A : system matrix for which ``A*vector`` is a matrix-vector product.
    b : 1-D array, right-hand side.
    x0 : 1-D array, reference solution used only for the error.
    maxiter : int, optional
        Number of iterations to run.

    Returns
    -------
    (x, err) : the final iterate (float32) and ``dot(x0 - x, x0 - x)``.
    """
    x = np.zeros(len(b), dtype=np.float32)
    r = b - A*x
    # Copy so the in-place update of r below does not also change p.
    p = r.copy()
    for _ in range(maxiter):
        Ap = A*p
        rr_old = np.dot(r, r)
        step = rr_old/np.dot(p, Ap)
        r -= step * Ap
        x += step * p
        rr_new = np.dot(r, r)
        p = r + (rr_new/rr_old) * p
    error = x0 - x
    return x, np.dot(error, error)
121+
122+
123+
def calc_scipy_sparse_linalg_cg(A, b, x0, maxiter=10):
    """Solve ``A x = b`` with ``scipy.sparse.linalg.cg`` for the benchmark.

    The solver starts from the zero vector with a relative tolerance of zero,
    so in practice it runs until ``maxiter`` is reached.

    Parameters
    ----------
    A : matrix or linear operator accepted by ``scipy.sparse.linalg.cg``.
    b : 1-D array, right-hand side.
    x0 : 1-D array, reference solution used only for the error (it is NOT
        the solver's starting guess).
    maxiter : int, optional
        Iteration limit handed to the solver.

    Returns
    -------
    (x, err) : the solver result and ``dot(x0 - x, x0 - x)``.
    """
    x = np.zeros(len(b), dtype=np.float32)
    try:
        # Current SciPy names the relative tolerance ``rtol``; the old
        # ``tol`` keyword has been removed and raises TypeError there.
        x, _ = linalg.cg(A, b, x, rtol=0., maxiter=maxiter)
    except TypeError:
        # Older SciPy releases only know the ``tol`` spelling.
        x, _ = linalg.cg(A, b, x, tol=0., maxiter=maxiter)
    res = x0 - x
    return x, np.dot(res, res)
128+
129+
130+
def timeit(calc, iters, args):
    """Return the mean wall-clock time of ``calc(*args)`` in milliseconds.

    Parameters
    ----------
    calc : callable to benchmark.
    iters : int
        Number of back-to-back calls to average over.
    args : sequence of positional arguments unpacked into every call.
    """
    start = time()
    for _ in range(iters):
        calc(*args)
    elapsed_ms = 1000 * (time() - start)
    return elapsed_ms / iters  # ms
136+
137+
138+
def test():
    """Cross-check every solver on a small 50x50 system.

    The dense ArrayFire result is the reference: the sparse ArrayFire run is
    compared element-wise with a relative threshold of 1e-6, and the NumPy /
    SciPy runs (when those packages imported) with ``np.allclose``.

    Raises
    ------
    ValueError
        If any solver disagrees with the dense ArrayFire result.
    """
    print("\nTesting benchmark functions...")
    A, b, x0 = setup_input(50)  # dense A
    Asp = to_sparse(A)

    x_dense, _ = calc_arrayfire(A, b, x0)
    x_sparse, _ = calc_arrayfire(Asp, b, x0)
    relative_diff = af.abs(x_dense - x_sparse)/x_sparse
    if af.sum(relative_diff > 1e-6):
        raise ValueError("arrayfire test failed")

    if np:
        An = to_numpy(A)
        bn = to_numpy(b)
        x0n = to_numpy(x0)
        x_numpy, _ = calc_numpy(An, bn, x0n)
        if not np.allclose(x_numpy, x_dense.to_list()):
            raise ValueError("numpy test failed")

    if sp:
        # bn and x0n come from the NumPy branch above, which has run whenever
        # SciPy could be imported.
        Asc = to_scipy_sparse(Asp)
        scipy_checks = (
            (calc_scipy_sparse, "scipy.sparse test failed"),
            (calc_scipy_sparse_linalg_cg, "scipy.sparse.linalg.cg test failed"),
        )
        for solver, failure_message in scipy_checks:
            x_scipy, _ = solver(Asc, bn, x0n)
            if not np.allclose(x_scipy, x_dense.to_list()):
                raise ValueError(failure_message)

    print(" all tests passed...")
162+
163+
164+
def bench(n=4*1024, sparsity=7, maxiter=10, iters=10):
    """Time the CG solvers on one ``n`` x ``n`` system and print the results.

    Parameters
    ----------
    n : int, optional
        Problem size passed to ``setup_input``.
    sparsity : int, optional
        Sparsity parameter passed to ``setup_input``.
    maxiter : int, optional
        CG iterations per solver call.
    iters : int, optional
        Number of solver calls averaged for each timing.
    """
    # generate data
    print("\nGenerating benchmark data for n = %i ..." % n)
    A, b, x0 = setup_input(n, sparsity)  # dense A
    Asp = to_sparse(A)                   # sparse A
    input_info(A, Asp)

    # make benchmarks
    print("Benchmarking CG solver for n = %i ..." % n)
    arrayfire_runs = (
        (" arrayfire - dense: %f ms", A),
        (" arrayfire - sparse: %f ms", Asp),
    )
    for message, matrix in arrayfire_runs:
        elapsed = timeit(calc_arrayfire, iters, args=(matrix, b, x0, maxiter))
        print(message % elapsed)

    if np:
        An = to_numpy(A)
        bn = to_numpy(b)
        x0n = to_numpy(x0)
        elapsed = timeit(calc_numpy, iters, args=(An, bn, x0n, maxiter))
        print(" numpy - dense: %f ms" % elapsed)

    if sp:
        # bn and x0n come from the NumPy branch above, which has run whenever
        # SciPy could be imported.
        Asc = to_scipy_sparse(Asp)
        scipy_runs = (
            (" scipy - sparse: %f ms", calc_scipy_sparse),
            (" scipy - sparse.linalg.cg: %f ms", calc_scipy_sparse_linalg_cg),
        )
        for message, solver in scipy_runs:
            elapsed = timeit(solver, iters, args=(Asc, bn, x0n, maxiter))
            print(message % elapsed)
188+
189+
if __name__ == "__main__":
    # To force the CPU backend, call af.set_backend('cpu', unsafe=True) here.

    # Optional first command-line argument: index of the device to use.
    cli_args = sys.argv[1:]
    if cli_args:
        af.set_device(int(cli_args[0]))

    af.info()

    # Validate the solvers against each other before timing anything.
    test()

    # Problem sizes 128, 256, ..., 4096.
    for n in [2**k for k in range(7, 13)]:
        bench(n)

0 commit comments

Comments
 (0)