[Stdlib] Move the memcpy implementation to operate on Pointer (#34391)
This switches the code so that the memcpy implementation operates on
Pointer, with the DTypePointer overload just performing the forwarding.
This simplifies the code a bit and sets the stage for us to enable
memcpy on static shapes.

modular-orig-commit: ab091589aca0e2d1c062762ebedae08f1021e40c
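For context, here is a minimal usage sketch of the API this commit touches. It is illustrative only, not part of the commit, and it assumes the March 2024 Mojo stdlib layout (memcpy exported from the memory package, DTypePointer from memory.unsafe):

from memory import memcpy
from memory.unsafe import DTypePointer

fn main():
    # Two distinct 16-element uint8 buffers.
    var src = DTypePointer[DType.uint8].alloc(16)
    var dest = DTypePointer[DType.uint8].alloc(16)
    for i in range(16):
        src.store(i, UInt8(i))

    # The count is in elements, not bytes. After this commit the DTypePointer
    # overload simply forwards to the Pointer-based implementation via
    # dest.address / src.address.
    memcpy(dest, src, 16)

    print(dest.load(15))  # 15

    src.free()
    dest.free()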
abduld authored Mar 12, 2024
1 parent 138c654 commit 9be39b5
Showing 1 changed file with 38 additions and 48 deletions.

stdlib/src/memory/memory.mojo
@@ -123,40 +123,10 @@ fn memcpy[
         src: The source pointer.
         count: The number of elements to copy.
     """
-    var byte_count = count * sizeof[type]()
-    memcpy(
-        DTypePointer[DType.uint8, address_space=address_space](
-            dest.bitcast[UInt8]()
-        ),
-        DTypePointer[DType.uint8, address_space=address_space](
-            src.bitcast[UInt8]()
-        ),
-        byte_count,
-    )
-
-
-fn memcpy[
-    type: DType, address_space: AddressSpace
-](
-    dest: DTypePointer[type, address_space],
-    src: DTypePointer[type, address_space],
-    count: Int,
-):
-    """Copies a memory area.
-
-    Parameters:
-        type: The element dtype.
-        address_space: The address space of the pointer.
-
-    Args:
-        dest: The destination pointer.
-        src: The source pointer.
-        count: The number of elements to copy (not bytes!).
-    """
     var n = count * sizeof[type]()
 
-    var dest_data = dest.bitcast[DType.uint8]()
-    var src_data = src.bitcast[DType.uint8]()
+    var dest_data = dest.bitcast[Int8]()
+    var src_data = src.bitcast[Int8]()
 
     if n < 5:
         if n == 0:
@@ -171,27 +141,23 @@ fn memcpy[

     if n <= 16:
         if n >= 8:
-            var ui64_size = sizeof[DType.uint64]()
-            dest_data.bitcast[DType.uint64]().store(
-                src_data.bitcast[DType.uint64]().load()
-            )
-            dest_data.offset(n - ui64_size).bitcast[DType.uint64]().store(
-                src_data.offset(n - ui64_size).bitcast[DType.uint64]().load()
+            var ui64_size = sizeof[Int64]()
+            dest_data.bitcast[Int64]().store(src_data.bitcast[Int64]()[0])
+            dest_data.offset(n - ui64_size).bitcast[Int64]().store(
+                src_data.offset(n - ui64_size).bitcast[Int64]()[0]
             )
             return
-        var ui32_size = sizeof[DType.uint32]()
-        dest_data.bitcast[DType.uint32]().store(
-            src_data.bitcast[DType.uint32]().load()
-        )
-        dest_data.offset(n - ui32_size).bitcast[DType.uint32]().store(
-            src_data.offset(n - ui32_size).bitcast[DType.uint32]().load()
+        var ui32_size = sizeof[Int32]()
+        dest_data.bitcast[Int32]().store(src_data.bitcast[Int32]()[0])
+        dest_data.offset(n - ui32_size).bitcast[Int32]().store(
+            src_data.offset(n - ui32_size).bitcast[Int32]()[0]
        )
         return
 
     # TODO (#10566): This branch appears to cause a 12% regression in BERT by
     # slowing down broadcast ops
     # if n <= 32:
-    #     alias simd_16xui8_size = 16 * sizeof[DType.uint8]()
+    #     alias simd_16xui8_size = 16 * sizeof[Int8]()
     #     dest_data.simd_store[16](src_data.simd_load[16]())
     #     # note that some of these bytes may have already been written by the
     #     # previous simd_store
@@ -200,18 +166,42 @@ fn memcpy[
     #     )
     #     return
 
+    var dest_dtype_ptr = DTypePointer[DType.int8, address_space](dest_data)
+    var src_dtype_ptr = DTypePointer[DType.int8, address_space](src_data)
+
     @always_inline
     @__copy_capture(dest_data, src_data)
     @parameter
     fn _copy[simd_width: Int](idx: Int):
-        dest_data.simd_store[simd_width](
-            idx, src_data.load[width=simd_width](idx)
+        dest_dtype_ptr.simd_store(
+            idx, src_dtype_ptr.load[width=simd_width](idx)
         )
 
-    # Copy in 32-bit chunks
+    # Copy in 32-byte chunks.
     vectorize[_copy, 32](n)
+
+
+fn memcpy[
+    type: DType, address_space: AddressSpace
+](
+    dest: DTypePointer[type, address_space],
+    src: DTypePointer[type, address_space],
+    count: Int,
+):
+    """Copies a memory area.
+
+    Parameters:
+        type: The element dtype.
+        address_space: The address space of the pointer.
+
+    Args:
+        dest: The destination pointer.
+        src: The source pointer.
+        count: The number of elements to copy (not bytes!).
+    """
+    memcpy(dest.address, src.address, count)
 
 
 # ===----------------------------------------------------------------------===#
 # memset
 # ===----------------------------------------------------------------------===#
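The general case leans on vectorize, which the diff invokes as vectorize[_copy, 32](n): the closure is called with simd_width == 32 for each full 32-byte chunk, then with smaller widths for the remaining bytes. A hedged, self-contained sketch of that shape, assuming the March 2024 algorithm.vectorize signature; fill_ones is a hypothetical example function:

from algorithm import vectorize
from memory.unsafe import DTypePointer

# Hypothetical example: fill n bytes with 1 using the same vectorize pattern.
fn fill_ones(buf: DTypePointer[DType.int8], n: Int):
    @parameter
    fn _fill[simd_width: Int](idx: Int):
        # Called with simd_width == 32 for each full 32-byte chunk starting at
        # idx, then with smaller widths to cover the tail of the buffer.
        buf.simd_store[simd_width](idx, SIMD[DType.int8, simd_width](1))

    vectorize[_fill, 32](n)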
