Skip to content

Commit

Permalink
[Stdlib] Restore bf16<->f32 conversion functions (#33911)
Browse files Browse the repository at this point in the history
Those were reverted in [Internal Link]
because we can rely on compiler-rt to do the conversion, but it's
actually useful to have these in Mojo as well (for parameter evaluation,
for example), so restore them.

modular-orig-commit: 4e2a3e69f508c408077a59ae2e0f2166f51b3cb6
  • Loading branch information
abduld authored Mar 6, 2024
1 parent 3941b19 commit 87773c1
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 16 deletions.
150 changes: 134 additions & 16 deletions stdlib/src/builtin/simd.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ These are Mojo built-ins, so you don't need to import them.

from math._numerics import FPUtils
from math.limit import inf, neginf
from math.math import _simd_apply, nan
from math.math import _simd_apply, nan, isnan
from sys import llvm_intrinsic
from sys.info import has_neon, is_x86, simdwidthof

Expand Down Expand Up @@ -81,6 +81,23 @@ fn _simd_construction_checks[type: DType, size: Int]():
]()


@always_inline("nodebug")
fn _unchecked_zero[type: DType, size: Int]() -> SIMD[type, size]:
    """Builds an all-zero SIMD vector directly from MLIR operations.

    An index-typed zero constant is cast to the element type and then
    splatted across every lane. This helper does not call
    `_simd_construction_checks`.

    Parameters:
        type: The element type of the resulting vector.
        size: The number of lanes in the resulting vector.

    Returns:
        A SIMD vector whose elements are all 0.
    """
    # Zero expressed as an index scalar constant.
    var index_zero = __mlir_op.`kgen.param.constant`[
        _type = __mlir_type[`!pop.scalar<index>`],
        value = __mlir_attr[`#pop.simd<0> : !pop.scalar<index>`],
    ]()
    # The same zero, converted to a scalar of the requested element type.
    var typed_zero = __mlir_op.`pop.cast`[
        _type = __mlir_type[`!pop.scalar<`, type.value, `>`]
    ](index_zero)
    # Broadcast the scalar across all `size` lanes.
    var lanes = __mlir_op.`pop.simd.splat`[
        _type = __mlir_type[`!pop.simd<`, size.value, `, `, type.value, `>`]
    ](typed_zero)
    return SIMD[type, size] {value: lanes}


@lldb_formatter_wrapping_type
@register_passable("trivial")
struct SIMD[type: DType, size: Int = simdwidthof[type]()](
Expand Down Expand Up @@ -117,21 +134,7 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
SIMD vector whose elements are 0.
"""
_simd_construction_checks[type, size]()
var zero = __mlir_op.`pop.cast`[
_type = __mlir_type[`!pop.scalar<`, type.value, `>`]
](
__mlir_op.`kgen.param.constant`[
_type = __mlir_type[`!pop.scalar<index>`],
value = __mlir_attr[`#pop.simd<0> : !pop.scalar<index>`],
]()
)
return Self {
value: __mlir_op.`pop.simd.splat`[
_type = __mlir_type[
`!pop.simd<`, size.value, `, `, type.value, `>`
]
](zero)
}
return _unchecked_zero[type, size]()

@always_inline("nodebug")
fn __init__(value: SIMD[DType.float64, 1]) -> Self:
Expand Down Expand Up @@ -405,11 +408,25 @@ struct SIMD[type: DType, size: Int = simdwidthof[type]()](
element type.
"""

@parameter
if has_neon() and (type == DType.bfloat16 or target == DType.bfloat16):
# BF16 support on neon systems is not supported.
return _unchecked_zero[target, size]()

@parameter
if type == DType.bool:
return self.select(SIMD[target, size](1), SIMD[target, size](0))
elif target == DType.bool:
return rebind[SIMD[target, size]](self != 0)
elif type == DType.bfloat16:
var cast_result = _bfloat16_to_f32(
rebind[SIMD[DType.bfloat16, size]](self)
).cast[target]()
return rebind[SIMD[target, size]](cast_result)
elif target == DType.bfloat16:
return rebind[SIMD[target, size]](
_f32_to_bfloat16(self.cast[DType.float32]())
)
elif target == DType.address:
var index_val = __mlir_op.`pop.cast`[
_type = __mlir_type[`!pop.simd<`, size.value, `, index>`]
Expand Down Expand Up @@ -2175,3 +2192,104 @@ fn _floor[
return _floor(x.cast[DType.float32]()).cast[type]()

return llvm_intrinsic["llvm.floor", SIMD[type, simd_width]](x)


# ===----------------------------------------------------------------------===#
# bfloat16
# ===----------------------------------------------------------------------===#

# Difference between the float32 and bfloat16 mantissa widths, as reported by
# `FPUtils[...].mantissa_width()`. The conversion helpers below use it as the
# shift amount when moving bit patterns between the two formats.
alias _fp32_bf16_mantissa_diff = FPUtils[
DType.float32
].mantissa_width() - FPUtils[DType.bfloat16].mantissa_width()


@always_inline
fn _bfloat16_to_f32_scalar(
    val: Scalar[DType.bfloat16],
) -> Scalar[DType.float32]:
    """Converts a single bfloat16 value to float32.

    The integer bit pattern of the input is shifted left by
    `_fp32_bf16_mantissa_diff` and the result is reinterpreted as a float32.

    Args:
        val: The bfloat16 scalar to convert.

    Returns:
        The converted float32 scalar, or zero when `has_neon()` is true.
    """

    @parameter
    if has_neon():
        # bfloat16 is not supported on NEON targets, so produce zero instead.
        return _unchecked_zero[DType.float32, 1]()

    var source_bits = FPUtils.bitcast_to_integer(val)
    var widened_bits = source_bits << _fp32_bf16_mantissa_diff
    return FPUtils[DType.float32].bitcast_from_integer(widened_bits)


@always_inline
fn _bfloat16_to_f32[
    size: Int
](val: SIMD[DType.bfloat16, size]) -> SIMD[DType.float32, size]:
    """Converts a bfloat16 SIMD vector to float32, one element at a time.

    Each element is passed through `_bfloat16_to_f32_scalar` by way of
    `_simd_apply`.

    Parameters:
        size: The number of lanes in the vector.

    Args:
        val: The bfloat16 vector to convert.

    Returns:
        The converted float32 vector, or an all-zero vector when
        `has_neon()` is true.
    """

    @parameter
    if has_neon():
        # bfloat16 is not supported on NEON targets, so produce zeros instead.
        return _unchecked_zero[DType.float32, size]()

    # Adapter giving the scalar conversion the generic signature that
    # `_simd_apply` is invoked with here.
    @always_inline
    @parameter
    fn convert_lane[
        input_type: DType, result_type: DType
    ](lane: Scalar[input_type]) capturing -> Scalar[result_type]:
        var converted = _bfloat16_to_f32_scalar(
            rebind[Scalar[DType.bfloat16]](lane)
        )
        return rebind[Scalar[result_type]](converted)

    return _simd_apply[size, DType.bfloat16, DType.float32, convert_lane](val)


@always_inline
fn _f32_to_bfloat16_scalar(
    val: Scalar[DType.float32],
) -> Scalar[DType.bfloat16]:
    """Converts a single float32 value to bfloat16 with rounding.

    NaN inputs map to a bfloat16 NaN carrying the same sign. Every other
    input has a rounding bias added to its integer bit pattern before the
    bits are shifted right by `_fp32_bf16_mantissa_diff`.

    Args:
        val: The float32 scalar to convert.

    Returns:
        The converted bfloat16 scalar, or zero when `has_neon()` is true.
    """

    @parameter
    if has_neon():
        # bfloat16 is not supported on NEON targets, so produce zero instead.
        return _unchecked_zero[DType.bfloat16, 1]()

    if isnan(val):
        # Keep the sign of the incoming NaN.
        if FPUtils.get_sign(val):
            return -nan[DType.bfloat16]()
        return nan[DType.bfloat16]()

    var source_bits = FPUtils.bitcast_to_integer(val)

    # Lowest bit that survives the shift. Adding it to 0x7FFF makes ties
    # round towards an even result (assuming the shift amount is 16 bits).
    var kept_lsb = (source_bits >> _fp32_bf16_mantissa_diff) & 1
    var rounded_bits = source_bits + (0x7FFF + kept_lsb)

    return FPUtils[DType.bfloat16].bitcast_from_integer(
        rounded_bits >> _fp32_bf16_mantissa_diff
    )


@always_inline
fn _f32_to_bfloat16[
    size: Int
](val: SIMD[DType.float32, size]) -> SIMD[DType.bfloat16, size]:
    """Converts a float32 SIMD vector to bfloat16, one element at a time.

    Each element is passed through `_f32_to_bfloat16_scalar` by way of
    `_simd_apply`.

    Parameters:
        size: The number of lanes in the vector.

    Args:
        val: The float32 vector to convert.

    Returns:
        The converted bfloat16 vector, or an all-zero vector when
        `has_neon()` is true.
    """

    @parameter
    if has_neon():
        # bfloat16 is not supported on NEON targets, so produce zeros instead.
        return _unchecked_zero[DType.bfloat16, size]()

    # Adapter giving the scalar conversion the generic signature that
    # `_simd_apply` is invoked with here.
    @always_inline
    @parameter
    fn convert_lane[
        input_type: DType, result_type: DType
    ](lane: Scalar[input_type]) capturing -> Scalar[result_type]:
        var converted = _f32_to_bfloat16_scalar(
            rebind[Scalar[DType.float32]](lane)
        )
        return rebind[Scalar[result_type]](converted)

    return _simd_apply[size, DType.float32, DType.bfloat16, convert_lane](val)
4 changes: 4 additions & 0 deletions stdlib/test/builtin/test_bfloat16.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ def test_methods():
assert_equal(int(BFloat16(3.5)), 3)

assert_almost_equal(BFloat16(4.4).cast[DType.float32](), 4.40625)
assert_equal(BFloat16(3.0).cast[DType.float32](), 3)
assert_equal(BFloat16(-3.0).cast[DType.float32](), -3)

assert_almost_equal(Float32(4.4).cast[DType.bfloat16](), 4.4)

assert_almost_equal(BFloat16(2.0), 2.0)


Expand Down

0 comments on commit 87773c1

Please sign in to comment.