Merge pull request Stability-AI#89 from Stability-AI/dango.patch.atten_overflow

* Force cast to fp32 to avoid atten layer overflow
rromb authored Dec 7, 2022
2 parents f547c4a + e1797ae commit 8bde0cf
Showing 1 changed file with 12 additions and 2 deletions.
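
A rough illustration of the failure mode the commit title points at (not part of the commit; shapes and values below are made up): float16 tops out at about 65504, so the dot products that form the attention logits can overflow well before the softmax, while the same accumulation in float32 stays finite.

    import torch

    # Hypothetical query/key pair with head dimension 64; values around 40 are
    # enough to push the half-precision dot product past the fp16 maximum.
    q = torch.full((64,), 40.0, dtype=torch.float16)
    k = torch.full((64,), 40.0, dtype=torch.float16)

    # Same accumulation the attention einsum performs for one (i, j) pair:
    # 40 * 40 * 64 = 102400 > 65504, so the fp16 result overflows to inf.
    logit_fp16 = (q * k).sum()
    logit_fp32 = (q.float() * k.float()).sum()

    print(logit_fp16)  # tensor(inf, dtype=torch.float16)
    print(logit_fp32)  # tensor(102400.)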
14 changes: 12 additions & 2 deletions ldm/modules/attention.py
@@ -16,6 +16,9 @@
 except:
     XFORMERS_IS_AVAILBLE = False
 
+# CrossAttn precision handling
+import os
+_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
 
 def exists(val):
     return val is not None
@@ -167,9 +170,16 @@ def forward(self, x, context=None, mask=None):
 
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
 
-        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+        # force cast to fp32 to avoid overflowing
+        if _ATTN_PRECISION =="fp32":
+            with torch.autocast(enabled=False, device_type = 'cuda'):
+                q, k = q.float(), k.float()
+                sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+        else:
+            sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
 
         del q, k
 
         if exists(mask):
             mask = rearrange(mask, 'b ... -> b (...)')
             max_neg_value = -torch.finfo(sim.dtype).max
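
The new path is opt-out: _ATTN_PRECISION defaults to "fp32", so setting ATTN_PRECISION=fp16 in the environment before launching restores the previous half-precision behaviour. Below is a self-contained sketch that mirrors the guarded score computation from the patched CrossAttention.forward; the wrapper name attention_scores and the tensor shapes are assumptions for illustration, not code from the repository.

    import os
    import torch
    from torch import einsum

    # Same environment switch the patch introduces: scores are computed in fp32
    # unless the user explicitly opts back into fp16 via ATTN_PRECISION=fp16.
    _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")

    def attention_scores(q, k, scale):
        # q, k: (batch * heads, tokens, dim) tensors, as in the patched forward.
        if _ATTN_PRECISION == "fp32":
            # Disable autocast so the einsum really runs in fp32 even under AMP,
            # casting q and k up before the similarity matrix is formed.
            with torch.autocast(enabled=False, device_type="cuda"):
                q, k = q.float(), k.float()
                return einsum("b i d, b j d -> b i j", q, k) * scale
        return einsum("b i d, b j d -> b i j", q, k) * scale

For example, launching with ATTN_PRECISION=fp16 set skips the cast and reproduces the pre-patch behaviour; leaving the variable unset keeps the safer fp32 path.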