diff --git a/sgm/modules/attention.py b/sgm/modules/attention.py
index 9340ddfb..1de2918f 100644
--- a/sgm/modules/attention.py
+++ b/sgm/modules/attention.py
@@ -1,5 +1,6 @@
 import logging
 import math
+from inspect import isfunction
 from typing import Any, Optional
 
 import torch
@@ -7,8 +8,6 @@
 from einops import rearrange, repeat
 from packaging import version
 from torch import nn
 
-from ..util import exists, default
-
 logger = logging.getLogger(__name__)
 
@@ -59,11 +58,25 @@
 from .diffusionmodules.util import checkpoint
 
 
-def uniq(arr):  # TODO: this seems unused
+def exists(val):
+    return val is not None
+
+
+def uniq(arr):
     return {el: True for el in arr}.keys()
 
 
-def init_(tensor):  # TODO: this seems unused
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
+
+def max_neg_value(t):
+    return -torch.finfo(t.dtype).max
+
+
+def init_(tensor):
     dim = tensor.shape[-1]
     std = 1 / math.sqrt(dim)
     tensor.uniform_(-std, std)
@@ -243,6 +256,23 @@ def forward(
 
         q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
 
+        ## old
+        """
+        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+        del q, k
+
+        if exists(mask):
+            mask = rearrange(mask, 'b ... -> b (...)')
+            max_neg_value = -torch.finfo(sim.dtype).max
+            mask = repeat(mask, 'b j -> (b h) () j', h=h)
+            sim.masked_fill_(~mask, max_neg_value)
+
+        # attention, what we cannot get enough of
+        sim = sim.softmax(dim=-1)
+
+        out = einsum('b i j, b j d -> b i d', sim, v)
+        """
+        ## new
         with sdp_kernel(**BACKEND_MAP[self.backend]):
             # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
             out = F.scaled_dot_product_attention(
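
Note on the CrossAttention hunk: the patch keeps the hand-rolled einsum attention only as a commented-out "## old" block and dispatches to the fused F.scaled_dot_product_attention call. The standalone sketch below (not part of the diff) checks that the two paths agree; the tensor sizes and mask layout are illustrative assumptions, and it relies on PyTorch's default SDPA backend selection rather than the patch's sdp_kernel/BACKEND_MAP context manager. SDPA's default scale of dim_head ** -0.5 matches the module's self.scale, and a boolean attn_mask uses the same True-means-attend convention as the old masked_fill_ logic.

# Standalone equivalence check (illustrative, not part of the patch).
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import einsum

b, h, n, d = 2, 4, 16, 64                       # batch, heads, tokens, dim_head (arbitrary)
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))
mask = torch.ones(b, n, dtype=torch.bool)       # True = attend
mask[:, -4:] = False                            # hide the last 4 key positions

# Old path: explicit einsum + softmax on "(b h) n d" tensors, as in the commented-out block.
q_, k_, v_ = (rearrange(t, "b h n d -> (b h) n d") for t in (q, k, v))
sim = einsum("b i d, b j d -> b i j", q_, k_) * d**-0.5   # self.scale = dim_head ** -0.5
sim.masked_fill_(~repeat(mask, "b j -> (b h) () j", h=h), -torch.finfo(sim.dtype).max)
out_old = einsum("b i j, b j d -> b i d", sim.softmax(dim=-1), v_)
out_old = rearrange(out_old, "(b h) n d -> b h n d", h=h)

# New path: fused kernel on "b h n d" tensors; boolean mask broadcast over heads and queries.
out_new = F.scaled_dot_product_attention(q, k, v, attn_mask=mask[:, None, None, :])

print(torch.allclose(out_old, out_new, atol=1e-5))  # expected: True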