intel benchmark matmul gets 60 TFLOPS?
geohot committed Jun 4, 2023
1 parent 657e642 commit fbf17f0
Showing 2 changed files with 64 additions and 3 deletions.
57 changes: 57 additions & 0 deletions extra/intel/benchmark_matmul.py
@@ -0,0 +1,57 @@
import time

onnx_path = "/tmp/my.onnx"
N = 2048
CNT = 400

"""
import torch
import torch.nn as nn
#dtype = torch.bfloat16
dtype = torch.float32
class MatMul(nn.Module):
def __init__(self):
super().__init__()
self.a = nn.Linear(N, N, bias=False)
def forward(self, x):
x = x.to(dtype)
for i in range(CNT): x = self.a(x).relu()
return x.to(torch.float32)
torch_model = MatMul().to(dtype)
torch.onnx.export(torch_model, torch.randn(N, N), onnx_path)
"""

"""
import onnx
from tinygrad.tensor import Tensor
from extra.onnx import get_run_onnx
out = get_run_onnx(onnx.load(onnx_path))({"onnx::MatMul_0": Tensor.zeros(N, N)})
for x in out.values(): x.realize()
"""

# Step 2: compile the ONNX model with OpenVINO for the Intel GPU.
from openvino.runtime import Core
core = Core()
devices = core.available_devices
for device in devices:
  device_name = core.get_property(device, "FULL_DEVICE_NAME")
  print(f"{device}: {device_name}")
model = core.read_model(onnx_path)
compiled_model = core.compile_model(model, device_name='GPU.0')
print(compiled_model)
ireq = compiled_model.create_infer_request()
# fill the inputs with a constant so the matmuls run on real data
for model_input in compiled_model.inputs:
  tensor = ireq.get_tensor(model_input)
  tensor.data[:] = 2
  print(tensor)
print("request")
# warm up: the first runs include kernel compilation and allocation
ireq.infer()
ireq.infer()
print("did one")

# Step 3: time REPS inferences; each inference is CNT matmuls at 2*N^3 FLOPs.
REPS = 20
st = time.perf_counter()
for i in range(REPS): ireq.infer()
et = time.perf_counter() - st
print(f"{et*1000:.2f} ms {(CNT*N*N*N*REPS*2/et)*1e-9:.2f} GFLOPS")

10 changes: 7 additions & 3 deletions extra/intel/joint_matrix_bfloat16.cpp
@@ -81,15 +81,19 @@ void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A, big_matri

   queue q;
   auto start = std::chrono::steady_clock::now();
-  q.submit(program).wait();
+  auto e = q.submit(program);
+  auto submit = std::chrono::steady_clock::now();
+  e.wait();
   auto end = std::chrono::steady_clock::now();
-  std::cout << "compute: " << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << " ms" << std::endl;
+  std::cout << "submit: " << std::chrono::duration_cast<std::chrono::milliseconds>(submit - start).count() << " ms" << std::endl;
+  std::cout << "compute: " << std::chrono::duration_cast<std::chrono::milliseconds>(end - submit).count() << " ms" << std::endl;
 
   // ahh, freeing is slow
 }
 
 //#define SCALE 1024
-#define SCALE 64
+//#define SCALE 64
+#define SCALE 256
 static constexpr size_t MATRIX_M = TM * SCALE;
 static constexpr size_t MATRIX_N = TN * SCALE;
 static constexpr size_t MATRIX_K = TK * SCALE;
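
The SYCL change above separates queue-submit time from actual kernel execution, so the compute number no longer includes enqueue overhead. The same split can be observed from the Python benchmark; a minimal sketch, assuming OpenVINO's asynchronous InferRequest API (start_async/wait), which the commit itself does not use:

# split wall time into enqueue (submit) and execution (compute), mirroring
# the SYCL change above; start_async/wait are an assumption, not in the diff
st = time.perf_counter()
ireq.start_async()                 # enqueue the inference, returns quickly
submit = time.perf_counter()
ireq.wait()                        # block until the device finishes
end = time.perf_counter()
print(f"submit: {(submit-st)*1000:.2f} ms compute: {(end-submit)*1000:.2f} ms")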
