forked from Stability-AI/generative-models
Move attention testing/benchmarking code out of package (Stability-AI#47)
Showing 2 changed files with 319 additions and 314 deletions.
@@ -0,0 +1,319 @@
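# Standalone attention testing / benchmarking utilities (moved out of the sgm
# package by this commit): compares torch scaled-dot-product-attention backends,
# SpatialTransformer attention variants, and a couple of small sanity tests.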
import torch
import einops
from torch.backends.cuda import SDPBackend
import torch.nn.functional as F
import torch.utils.benchmark as benchmark

from sgm.modules.attention import SpatialTransformer, BasicTransformerBlock

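
# Benchmarks F.scaled_dot_product_attention under the math, flash, and
# memory-efficient SDP backends, and profiles each with torch.profiler.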
def benchmark_attn():
    # Let's define a helpful benchmarking function:
    # https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
    device = "cuda" if torch.cuda.is_available() else "cpu"

    def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
        t0 = benchmark.Timer(
            stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
        )
        return t0.blocked_autorange().mean * 1e6

    # Let's define the hyper-parameters of our input
    batch_size = 32
    max_sequence_len = 1024
    num_heads = 32
    embed_dimension = 32

    dtype = torch.float16

    query = torch.rand(
        batch_size,
        num_heads,
        max_sequence_len,
        embed_dimension,
        device=device,
        dtype=dtype,
    )
    key = torch.rand(
        batch_size,
        num_heads,
        max_sequence_len,
        embed_dimension,
        device=device,
        dtype=dtype,
    )
    value = torch.rand(
        batch_size,
        num_heads,
        max_sequence_len,
        embed_dimension,
        device=device,
        dtype=dtype,
    )

    print("q/k/v shape:", query.shape, key.shape, value.shape)

    # Let's explore the speed of each of the 3 implementations
    from torch.backends.cuda import SDPBackend, sdp_kernel

    # Helpful arguments mapper
    backend_map = {
        SDPBackend.MATH: {
            "enable_math": True,
            "enable_flash": False,
            "enable_mem_efficient": False,
        },
        SDPBackend.FLASH_ATTENTION: {
            "enable_math": False,
            "enable_flash": True,
            "enable_mem_efficient": False,
        },
        SDPBackend.EFFICIENT_ATTENTION: {
            "enable_math": False,
            "enable_flash": False,
            "enable_mem_efficient": True,
        },
    }

    from torch.profiler import ProfilerActivity, profile, record_function

    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]

    print(
        f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
    )
    with profile(
        activities=activities, record_shapes=False, profile_memory=True
    ) as prof:
        with record_function("Default detailed stats"):
            for _ in range(25):
                o = F.scaled_dot_product_attention(query, key, value)
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    with sdp_kernel(**backend_map[SDPBackend.MATH]):
        print(
            f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
        )
        with profile(
            activities=activities, record_shapes=False, profile_memory=True
        ) as prof:
            with record_function("Math implementation stats"):
                for _ in range(25):
                    o = F.scaled_dot_product_attention(query, key, value)
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
        try:
            print(
                f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
            )
        except RuntimeError:
            print("FlashAttention is not supported. See warnings for reasons.")
        with profile(
            activities=activities, record_shapes=False, profile_memory=True
        ) as prof:
            with record_function("FlashAttention stats"):
                for _ in range(25):
                    o = F.scaled_dot_product_attention(query, key, value)
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
        try:
            print(
                f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
            )
        except RuntimeError:
            print("EfficientAttention is not supported. See warnings for reasons.")
        with profile(
            activities=activities, record_shapes=False, profile_memory=True
        ) as prof:
            with record_function("EfficientAttention stats"):
                for _ in range(25):
                    o = F.scaled_dot_product_attention(query, key, value)
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))


def run_model(model, x, context):
    return model(x, context)

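
# Compares a SpatialTransformer built with native SDP attention ("softmax")
# against the xformers variant ("softmax-xformers"): wall-clock time, profiler
# stats, and peak GPU memory for each.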
def benchmark_transformer_blocks():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    import torch.utils.benchmark as benchmark

    def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
        t0 = benchmark.Timer(
            stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
        )
        return t0.blocked_autorange().mean * 1e6

    checkpoint = True
    compile = False

    batch_size = 32
    h, w = 64, 64
    context_len = 77
    embed_dimension = 1024
    context_dim = 1024
    d_head = 64

    transformer_depth = 4

    n_heads = embed_dimension // d_head

    dtype = torch.float16

    model_native = SpatialTransformer(
        embed_dimension,
        n_heads,
        d_head,
        context_dim=context_dim,
        use_linear=True,
        use_checkpoint=checkpoint,
        attn_type="softmax",
        depth=transformer_depth,
        sdp_backend=SDPBackend.FLASH_ATTENTION,
    ).to(device)
    model_efficient_attn = SpatialTransformer(
        embed_dimension,
        n_heads,
        d_head,
        context_dim=context_dim,
        use_linear=True,
        depth=transformer_depth,
        use_checkpoint=checkpoint,
        attn_type="softmax-xformers",
    ).to(device)
    if not checkpoint and compile:
        print("compiling models")
        model_native = torch.compile(model_native)
        model_efficient_attn = torch.compile(model_efficient_attn)

    x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype)
    c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype)

    from torch.profiler import ProfilerActivity, profile, record_function

    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]

    with torch.autocast("cuda"):
        print(
            f"The native model runs in {benchmark_torch_function_in_microseconds(model_native.forward, x, c):.3f} microseconds"
        )
        print(
            f"The efficientattn model runs in {benchmark_torch_function_in_microseconds(model_efficient_attn.forward, x, c):.3f} microseconds"
        )

        print(75 * "+")
        print("NATIVE")
        print(75 * "+")
        torch.cuda.reset_peak_memory_stats()
        with profile(
            activities=activities, record_shapes=False, profile_memory=True
        ) as prof:
            with record_function("NativeAttention stats"):
                for _ in range(25):
                    model_native(x, c)
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
        print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by native block")

        print(75 * "+")
        print("Xformers")
        print(75 * "+")
        torch.cuda.reset_peak_memory_stats()
        with profile(
            activities=activities, record_shapes=False, profile_memory=True
        ) as prof:
            with record_function("xformers stats"):
                for _ in range(25):
                    model_efficient_attn(x, c)
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
        print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by xformers block")

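
# test01: a 1x1 convolution and a linear layer compute the same map; after copying
# the squeezed conv weight/bias into the linear layer, both should produce matching
# outputs (compare the printed means and shapes).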
def test01():
    # conv1x1 vs linear
    from sgm.util import count_params

    conv = torch.nn.Conv2d(3, 32, kernel_size=1).cuda()
    print(count_params(conv))
    linear = torch.nn.Linear(3, 32).cuda()
    print(count_params(linear))

    print(conv.weight.shape)

    # use same initialization
    linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1))
    linear.bias = torch.nn.Parameter(conv.bias)

    print(linear.weight.shape)

    x = torch.randn(11, 3, 64, 64).cuda()

    xr = einops.rearrange(x, "b c h w -> b (h w) c").contiguous()
    print(xr.shape)
    out_linear = linear(xr)
    print(out_linear.mean(), out_linear.shape)

    out_conv = conv(x)
    print(out_conv.mean(), out_conv.shape)
    print("done with test01.\n")

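
# test02: time a vanilla softmax BasicTransformerBlock against the "flash-cosine"
# attention mode on a 4096-token sequence; the vanilla path may run out of memory.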
def test02():
    # try cosine flash attention
    import time

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
    print("testing cosine flash attention...")
    DIM = 1024
    SEQLEN = 4096
    BS = 16

    print(" softmax (vanilla) first...")
    model = BasicTransformerBlock(
        dim=DIM,
        n_heads=16,
        d_head=64,
        dropout=0.0,
        context_dim=None,
        attn_mode="softmax",
    ).cuda()
    try:
        x = torch.randn(BS, SEQLEN, DIM).cuda()
        tic = time.time()
        y = model(x)
        toc = time.time()
        print(y.shape, toc - tic)
    except RuntimeError as e:
        # likely oom
        print(str(e))

    print("\n now flash-cosine...")
    model = BasicTransformerBlock(
        dim=DIM,
        n_heads=16,
        d_head=64,
        dropout=0.0,
        context_dim=None,
        attn_mode="flash-cosine",
    ).cuda()
    x = torch.randn(BS, SEQLEN, DIM).cuda()
    tic = time.time()
    y = model(x)
    toc = time.time()
    print(y.shape, toc - tic)
    print("done with test02.\n")


if __name__ == "__main__":
    # test01()
    # test02()
    # test03()

    # benchmark_attn()
    benchmark_transformer_blocks()

    print("done.")