Fix attention autocast (#94)
Under AMP, torch autocasts attention weights to FP32 because of softmax, but
does not cast them back to the user-specified data type.
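A minimal sketch of the behavior (assuming a CUDA device; shapes and dtype
are arbitrary and only for illustration):

```python
import torch

q = torch.randn(8, 64, device="cuda")
k = torch.randn(8, 64, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    logits = q @ k.transpose(-2, -1)      # matmul runs in bfloat16 under autocast
    attn = torch.softmax(logits, dim=-1)  # softmax is promoted to float32
    print(logits.dtype, attn.dtype)       # torch.bfloat16 torch.float32
```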
Until recently, we explicitly passed the autocast dtype in all
autograd function wrappers (reference:
https://github.com/SHI-Labs/NATTEN/blob/3b54c76185904f3cb59a49fff7bc044e4513d106/src/natten/functional.py#L149),
but this is wrong because the user might be running BF16.
According to the latest torch documentation, this behavior has not changed
since the first NATTEN release.
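For illustration, the old pattern looked roughly like the following (a
hypothetical sketch, not the exact code behind the link above): hardcoding
the cast dtype in the wrapper silently forces FP16 even when the user runs
BF16 autocast.

```python
import torch
from torch.cuda.amp import custom_fwd, custom_bwd

class NeighborhoodAttnFn(torch.autograd.Function):
    # Hypothetical sketch: a fixed cast dtype baked into the wrapper.
    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)  # wrong for users running BF16
    def forward(ctx, query, key, value):
        ...

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        ...
```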
Because this is error-prone, this commit explicitly casts all
attention tensors to match the dtype of the value tensor. If the dtypes
already match, the cast is a no-op, so it should not interfere with AMP
mechanics.
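A minimal sketch of the change (a hypothetical helper, not NATTEN's exact
code): cast the attention weights to value's dtype right before applying
them; Tensor.to returns the input unchanged when the dtypes already match.

```python
import torch

def apply_attention(attn: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # Cast attention weights to match value; no-op if dtypes already match.
    attn = attn.to(value.dtype)
    return attn @ value
```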
Reference: #93