From b5b56801508eabdbbd92d4b38baa381dae4e881a Mon Sep 17 00:00:00 2001
From: Aarni Koskela
Date: Tue, 25 Jul 2023 16:24:24 +0300
Subject: [PATCH] Dead code removal (#48)

* Remove old commented-out attention code
* Mark two functions as likely unused
* Use exists() and default() from sgm.util
---
 sgm/modules/attention.py | 38 ++++----------------------------------
 1 file changed, 4 insertions(+), 34 deletions(-)

diff --git a/sgm/modules/attention.py b/sgm/modules/attention.py
index 38af6e65..4cd2908d 100644
--- a/sgm/modules/attention.py
+++ b/sgm/modules/attention.py
@@ -1,6 +1,5 @@
 import logging
 import math
-from inspect import isfunction
 from typing import Any, Optional
 
 import torch
@@ -8,6 +7,8 @@
 from einops import rearrange, repeat
 from packaging import version
 from torch import nn
 
+from ..util import exists, default
+
 logger = logging.getLogger(__name__)
 
@@ -58,25 +59,11 @@
 from .diffusionmodules.util import checkpoint
 
 
-def exists(val):
-    return val is not None
-
-
-def uniq(arr):
+def uniq(arr):  # TODO: this seems unused
     return {el: True for el in arr}.keys()
 
 
-def default(val, d):
-    if exists(val):
-        return val
-    return d() if isfunction(d) else d
-
-
-def max_neg_value(t):
-    return -torch.finfo(t.dtype).max
-
-
-def init_(tensor):
+def init_(tensor):  # TODO: this seems unused
     dim = tensor.shape[-1]
     std = 1 / math.sqrt(dim)
     tensor.uniform_(-std, std)
@@ -256,23 +243,6 @@ def forward(
 
         q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
 
-        ## old
-        """
-        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
-        del q, k
-
-        if exists(mask):
-            mask = rearrange(mask, 'b ... -> b (...)')
-            max_neg_value = -torch.finfo(sim.dtype).max
-            mask = repeat(mask, 'b j -> (b h) () j', h=h)
-            sim.masked_fill_(~mask, max_neg_value)
-
-        # attention, what we cannot get enough of
-        sim = sim.softmax(dim=-1)
-
-        out = einsum('b i j, b j d -> b i d', sim, v)
-        """
-        ## new
         with sdp_kernel(**BACKEND_MAP[self.backend]):
             # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
             out = F.scaled_dot_product_attention(
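Note (not part of the patch): the block deleted in the last hunk is the old einsum-based attention that had been left behind as a commented-out string; forward() already relies solely on F.scaled_dot_product_attention. The sketch below, with hypothetical tensor sizes and no mask, shows why the dead block was safe to drop (both paths compute the same result up to fp32 rounding) and records the semantics of the exists()/default() helpers that are now imported from sgm.util instead of being defined locally.

# Illustrative sketch only -- not part of the patch. Assumes PyTorch >= 2.0
# (needed for F.scaled_dot_product_attention) and made-up tensor sizes.
import torch
import torch.nn.functional as F
from torch import einsum

# The removed helpers keep the same semantics, now imported from sgm.util:
#   exists(val)     -> val is not None
#   default(val, d) -> val if it exists, else d() if d is callable, else d

b_h, n, d = 4, 16, 32            # (batch * heads, tokens, head dim), as in the old layout
q, k, v = (torch.randn(b_h, n, d) for _ in range(3))
scale = d ** -0.5                # matches the module's self.scale = dim_head ** -0.5

# Old (removed, commented-out) path: explicit similarity matrix + softmax.
sim = einsum("b i d, b j d -> b i j", q, k) * scale
attn = sim.softmax(dim=-1)
out_old = einsum("b i j, b j d -> b i d", attn, v)

# New path used by forward(): fused scaled dot product attention.
out_new = F.scaled_dot_product_attention(q, k, v)

print("max abs diff:", (out_old - out_new).abs().max().item())  # ~1e-7, fp32 rounding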