Dead code removal (Stability-AI#48)
* Remove old commented-out attention code

* Mark two functions as likely unused

* Use exists() and default() from sgm.util
akx authored Jul 25, 2023
1 parent 6f6d3f8 commit b5b5680
Showing 1 changed file with 4 additions and 34 deletions.
38 changes: 4 additions & 34 deletions sgm/modules/attention.py
@@ -1,13 +1,14 @@
 import logging
 import math
-from inspect import isfunction
 from typing import Any, Optional
 
 import torch
 import torch.nn.functional as F
 from einops import rearrange, repeat
 from packaging import version
 from torch import nn
+from ..util import exists, default
+
 
 
 logger = logging.getLogger(__name__)
@@ -58,25 +59,11 @@
 from .diffusionmodules.util import checkpoint
 
 
-def exists(val):
-    return val is not None
-
-
-def uniq(arr):
+def uniq(arr):  # TODO: this seems unused
     return {el: True for el in arr}.keys()
 
 
-def default(val, d):
-    if exists(val):
-        return val
-    return d() if isfunction(d) else d
-
-
-def max_neg_value(t):
-    return -torch.finfo(t.dtype).max
-
-
-def init_(tensor):
+def init_(tensor):  # TODO: this seems unused
     dim = tensor.shape[-1]
     std = 1 / math.sqrt(dim)
     tensor.uniform_(-std, std)
@@ -256,23 +243,6 @@ def forward(
 
         q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
 
-        ## old
-        """
-        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
-        del q, k
-
-        if exists(mask):
-            mask = rearrange(mask, 'b ... -> b (...)')
-            max_neg_value = -torch.finfo(sim.dtype).max
-            mask = repeat(mask, 'b j -> (b h) () j', h=h)
-            sim.masked_fill_(~mask, max_neg_value)
-
-        # attention, what we cannot get enough of
-        sim = sim.softmax(dim=-1)
-
-        out = einsum('b i j, b j d -> b i d', sim, v)
-        """
-        ## new
         with sdp_kernel(**BACKEND_MAP[self.backend]):
             # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
             out = F.scaled_dot_product_attention(
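For reference, the removed commented-out path and the retained F.scaled_dot_product_attention call compute the same result in the unmasked, no-dropout case. A small standalone check (a sketch, not part of this commit; assumes PyTorch >= 2.0 and the default 1/sqrt(dim_head) softmax scaling):

import torch
import torch.nn.functional as F
from torch import einsum

# Shapes follow the removed code: q, k, v are (batch * heads, tokens, dim_head).
bh, n, d = 8, 16, 64
q, k, v = (torch.randn(bh, n, d) for _ in range(3))
scale = d ** -0.5  # corresponds to self.scale = dim_head ** -0.5 in the old path

# Old path, as written in the deleted comment block (no mask applied).
sim = einsum("b i d, b j d -> b i j", q, k) * scale
attn = sim.softmax(dim=-1)
out_old = einsum("b i j, b j d -> b i d", attn, v)

# New path; F.scaled_dot_product_attention also scales by 1/sqrt(dim_head) by default.
out_new = F.scaled_dot_product_attention(q, k, v)

assert torch.allclose(out_old, out_new, atol=1e-5)

F.scaled_dot_product_attention also accepts an attn_mask argument for the masked case, so the manual masked_fill_ bookkeeping from the old block is no longer needed.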

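On the last bullet of the commit message: the local exists() and default() definitions removed here are replaced by the same-named helpers imported from sgm.util. As a sketch, equivalent definitions mirroring the removed code (the actual sgm/util.py may differ in details):

from inspect import isfunction


def exists(x):
    # A value counts as "provided" whenever it is not None.
    return x is not None


def default(val, d):
    # Return val if provided, otherwise the fallback d,
    # calling d first when it is a zero-argument function.
    if exists(val):
        return val
    return d() if isfunction(d) else d

Call sites in attention.py, such as the usual context = default(context, x) pattern, keep working unchanged; only the definition site moves into the shared utility module.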