Skip to content

Commit

Permalink
Move attention testing/benchmarking code out of package (Stability-AI#47
Browse files Browse the repository at this point in the history
)
  • Loading branch information
akx authored Jul 25, 2023
1 parent b5b5680 commit 2897fdc
Show file tree
Hide file tree
Showing 2 changed files with 319 additions and 314 deletions.
319 changes: 319 additions & 0 deletions scripts/tests/attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,319 @@
import torch
import einops
from torch.backends.cuda import SDPBackend
import torch.nn.functional as F
import torch.utils.benchmark as benchmark

from sgm.modules.attention import SpatialTransformer, BasicTransformerBlock


def benchmark_attn():
# Lets define a helpful benchmarking function:
# https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
device = "cuda" if torch.cuda.is_available() else "cpu"

def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
t0 = benchmark.Timer(
stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
)
return t0.blocked_autorange().mean * 1e6

# Lets define the hyper-parameters of our input
batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32

dtype = torch.float16

query = torch.rand(
batch_size,
num_heads,
max_sequence_len,
embed_dimension,
device=device,
dtype=dtype,
)
key = torch.rand(
batch_size,
num_heads,
max_sequence_len,
embed_dimension,
device=device,
dtype=dtype,
)
value = torch.rand(
batch_size,
num_heads,
max_sequence_len,
embed_dimension,
device=device,
dtype=dtype,
)

print(f"q/k/v shape:", query.shape, key.shape, value.shape)

# Lets explore the speed of each of the 3 implementations
from torch.backends.cuda import SDPBackend, sdp_kernel

# Helpful arguments mapper
backend_map = {
SDPBackend.MATH: {
"enable_math": True,
"enable_flash": False,
"enable_mem_efficient": False,
},
SDPBackend.FLASH_ATTENTION: {
"enable_math": False,
"enable_flash": True,
"enable_mem_efficient": False,
},
SDPBackend.EFFICIENT_ATTENTION: {
"enable_math": False,
"enable_flash": False,
"enable_mem_efficient": True,
},
}

from torch.profiler import ProfilerActivity, profile, record_function

activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]

print(
f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
)
with profile(
activities=activities, record_shapes=False, profile_memory=True
) as prof:
with record_function("Default detailed stats"):
for _ in range(25):
o = F.scaled_dot_product_attention(query, key, value)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

print(
f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
)
with sdp_kernel(**backend_map[SDPBackend.MATH]):
with profile(
activities=activities, record_shapes=False, profile_memory=True
) as prof:
with record_function("Math implmentation stats"):
for _ in range(25):
o = F.scaled_dot_product_attention(query, key, value)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
try:
print(
f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
)
except RuntimeError:
print("FlashAttention is not supported. See warnings for reasons.")
with profile(
activities=activities, record_shapes=False, profile_memory=True
) as prof:
with record_function("FlashAttention stats"):
for _ in range(25):
o = F.scaled_dot_product_attention(query, key, value)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
try:
print(
f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
)
except RuntimeError:
print("EfficientAttention is not supported. See warnings for reasons.")
with profile(
activities=activities, record_shapes=False, profile_memory=True
) as prof:
with record_function("EfficientAttention stats"):
for _ in range(25):
o = F.scaled_dot_product_attention(query, key, value)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))


def run_model(model, x, context):
return model(x, context)


def benchmark_transformer_blocks():
device = "cuda" if torch.cuda.is_available() else "cpu"
import torch.utils.benchmark as benchmark

def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
t0 = benchmark.Timer(
stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
)
return t0.blocked_autorange().mean * 1e6

checkpoint = True
compile = False

batch_size = 32
h, w = 64, 64
context_len = 77
embed_dimension = 1024
context_dim = 1024
d_head = 64

transformer_depth = 4

n_heads = embed_dimension // d_head

dtype = torch.float16

model_native = SpatialTransformer(
embed_dimension,
n_heads,
d_head,
context_dim=context_dim,
use_linear=True,
use_checkpoint=checkpoint,
attn_type="softmax",
depth=transformer_depth,
sdp_backend=SDPBackend.FLASH_ATTENTION,
).to(device)
model_efficient_attn = SpatialTransformer(
embed_dimension,
n_heads,
d_head,
context_dim=context_dim,
use_linear=True,
depth=transformer_depth,
use_checkpoint=checkpoint,
attn_type="softmax-xformers",
).to(device)
if not checkpoint and compile:
print("compiling models")
model_native = torch.compile(model_native)
model_efficient_attn = torch.compile(model_efficient_attn)

x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype)
c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype)

from torch.profiler import ProfilerActivity, profile, record_function

activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]

with torch.autocast("cuda"):
print(
f"The native model runs in {benchmark_torch_function_in_microseconds(model_native.forward, x, c):.3f} microseconds"
)
print(
f"The efficientattn model runs in {benchmark_torch_function_in_microseconds(model_efficient_attn.forward, x, c):.3f} microseconds"
)

print(75 * "+")
print("NATIVE")
print(75 * "+")
torch.cuda.reset_peak_memory_stats()
with profile(
activities=activities, record_shapes=False, profile_memory=True
) as prof:
with record_function("NativeAttention stats"):
for _ in range(25):
model_native(x, c)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by native block")

print(75 * "+")
print("Xformers")
print(75 * "+")
torch.cuda.reset_peak_memory_stats()
with profile(
activities=activities, record_shapes=False, profile_memory=True
) as prof:
with record_function("xformers stats"):
for _ in range(25):
model_efficient_attn(x, c)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by xformers block")


def test01():
# conv1x1 vs linear
from sgm.util import count_params

conv = torch.nn.Conv2d(3, 32, kernel_size=1).cuda()
print(count_params(conv))
linear = torch.nn.Linear(3, 32).cuda()
print(count_params(linear))

print(conv.weight.shape)

# use same initialization
linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1))
linear.bias = torch.nn.Parameter(conv.bias)

print(linear.weight.shape)

x = torch.randn(11, 3, 64, 64).cuda()

xr = einops.rearrange(x, "b c h w -> b (h w) c").contiguous()
print(xr.shape)
out_linear = linear(xr)
print(out_linear.mean(), out_linear.shape)

out_conv = conv(x)
print(out_conv.mean(), out_conv.shape)
print("done with test01.\n")


def test02():
# try cosine flash attention
import time

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
print("testing cosine flash attention...")
DIM = 1024
SEQLEN = 4096
BS = 16

print(" softmax (vanilla) first...")
model = BasicTransformerBlock(
dim=DIM,
n_heads=16,
d_head=64,
dropout=0.0,
context_dim=None,
attn_mode="softmax",
).cuda()
try:
x = torch.randn(BS, SEQLEN, DIM).cuda()
tic = time.time()
y = model(x)
toc = time.time()
print(y.shape, toc - tic)
except RuntimeError as e:
# likely oom
print(str(e))

print("\n now flash-cosine...")
model = BasicTransformerBlock(
dim=DIM,
n_heads=16,
d_head=64,
dropout=0.0,
context_dim=None,
attn_mode="flash-cosine",
).cuda()
x = torch.randn(BS, SEQLEN, DIM).cuda()
tic = time.time()
y = model(x)
toc = time.time()
print(y.shape, toc - tic)
print("done with test02.\n")


if __name__ == "__main__":
# test01()
# test02()
# test03()

# benchmark_attn()
benchmark_transformer_blocks()

print("done.")
Loading

0 comments on commit 2897fdc

Please sign in to comment.