Commit f1123e3
Conditionally run bmm functions in fp16 based on cuda version
carlc-nv committed Mar 27, 2019
1 parent f5cd5ae commit f1123e3
Showing 2 changed files with 20 additions and 10 deletions.
27 changes: 17 additions & 10 deletions apex/amp/lists/torch_overrides.py
@@ -1,5 +1,7 @@
import torch

from .. import utils

MODULE = torch

FP16_FUNCS = [
@@ -20,10 +22,8 @@
    'matmul',
    'mm',
    'mv',

]

# TODO: ban in-place versions of these in fp16
FP32_FUNCS = [
    # Pointwise
    'acos',
@@ -54,15 +54,21 @@
    'sum',
    'var',

    # Special reduction-like BLAS
    'addbmm',
    'baddbmm',
    'bmm',

    # Misc
    'renorm'
]

# Before CUDA 9.1, batched matmul was missing fast FP16 kernels. We
# check the CUDA version -- if at least 9.1, then put the bmm
# functions on the fp16 list. Otherwise, put them on the fp32 list.
_bmms = ['addbmm',
         'baddbmm',
         'bmm']
if utils.get_cuda_version() >= (9, 1, 0):
    FP16_FUNCS.extend(_bmms)
else:
    FP32_FUNCS.extend(_bmms)

# Multi-tensor fns that may need type promotion
CASTS = [
    # Multi-tensor math
@@ -87,8 +93,9 @@
    'ne'
]

# Will possibly need to promote *all* elements of `seq`
# Functions that take sequence arguments. We need to inspect the whole
# sequence and cast to the widest type.
SEQUENCE_CASTS = [
    'cat',    # torch.cat(seq, dim=0, out=None)
    'stack'   # torch.stack(seq, dim=0, out=None)
    'cat',
    'stack'
]
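The conditional added above hinges on Python's lexicographic tuple comparison, so a two-component version such as (10, 0) still compares correctly against the three-component threshold (9, 1, 0). Below is a minimal standalone sketch of the same routing idea; the list contents are placeholders and the version string is passed in as a parameter so the sketch runs without a CUDA build of PyTorch, neither of which matches the real apex code.

# Standalone sketch of the routing logic in this commit (not the real apex lists).
FP16_FUNCS = ['mm', 'matmul']        # placeholder contents
FP32_FUNCS = ['sum', 'var']          # placeholder contents

def get_cuda_version(version_str):
    # '9.0.176' -> (9, 0, 176), '10.0' -> (10, 0)
    return tuple(int(x) for x in version_str.split('.'))

_bmms = ['addbmm', 'baddbmm', 'bmm']

# Tuples compare element by element, so (10, 0) >= (9, 1, 0) is True
# because 10 > 9 settles the comparison before the lengths matter.
if get_cuda_version('10.0') >= (9, 1, 0):
    FP16_FUNCS.extend(_bmms)   # fast FP16 batched-matmul kernels are available
else:
    FP32_FUNCS.extend(_bmms)   # fall back to FP32 for the bmm family

print(FP16_FUNCS)              # ['mm', 'matmul', 'addbmm', 'baddbmm', 'bmm']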
3 changes: 3 additions & 0 deletions apex/amp/utils.py
@@ -5,6 +5,9 @@

import torch

def get_cuda_version():
    return tuple(int(x) for x in torch.version.cuda.split('.'))

def is_fp_tensor(x):
    if is_nested(x):
        # Fast-fail version of all(is_fp_tensor)
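For reference, torch.version.cuda is a short version string (e.g. '10.0' or '9.0.176'), and None on CPU-only builds, so the new helper turns it into an integer tuple that supports ordered comparison. A small usage sketch follows; the None guard is an extra precaution added here, not something the commit itself does.

import torch

def get_cuda_version():
    # Same parsing as the helper added to apex/amp/utils.py:
    # '9.0.176' -> (9, 0, 176), '10.0' -> (10, 0)
    return tuple(int(x) for x in torch.version.cuda.split('.'))

# Extra guard (not in the commit): torch.version.cuda is None on CPU-only
# builds, where .split() would raise an AttributeError.
if torch.version.cuda is not None:
    version = get_cuda_version()
    print(version)               # e.g. (10, 0)
    print(version >= (9, 1, 0))  # e.g. (10, 0) >= (9, 1, 0) is True, so the
                                 # bmm functions go on the FP16 list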
