Skip to content

Commit

Permalink
Prepare APIs for non-default streams (#93)
Browse files Browse the repository at this point in the history
Details:
- instead of using the default null stream in HIP queue functions, get
  the stream from the rocblas_handle associated with the thread
- add a stream sync function
- rename the HIP write API to write_async to correctly signal intent
- change the device allocation to be asynchronous (and rename API)
- write something unique into the buffer_hip pointer in the managed memory
  case to ensure comparisons are correct
  • Loading branch information
iotamudelta authored Apr 28, 2023
1 parent 7e94bbd commit 450c109
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 49 deletions.
7 changes: 4 additions & 3 deletions src/base/flamec/supermatrix/hip/include/FLASH_Queue_hip.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,12 @@ void FLASH_Queue_set_hip_num_blocks( dim_t n_blocks );
dim_t FLASH_Queue_get_hip_num_blocks( void );

FLA_Error FLASH_Queue_bind_hip( int thread );
FLA_Error FLASH_Queue_alloc_hip( dim_t size, FLA_Datatype datatype, void** buffer_hip );
FLA_Error FLASH_Queue_free_async_hip( void* buffer_hip );
FLA_Error FLASH_Queue_write_hip( FLA_Obj obj, void* buffer_hip );
FLA_Error FLASH_Queue_alloc_async_hip( int thread, dim_t size, FLA_Datatype datatype, void** buffer_hip );
FLA_Error FLASH_Queue_free_async_hip( int thread, void* buffer_hip );
FLA_Error FLASH_Queue_write_async_hip( int thread, FLA_Obj obj, void* buffer_hip );
FLA_Error FLASH_Queue_read_hip( int thread, FLA_Obj obj, void* buffer_hip );
FLA_Error FLASH_Queue_read_async_hip( int thread, FLA_Obj obj, void* buffer_hip );
FLA_Error FLASH_Queue_sync_stream_hip( int thread );
FLA_Error FLASH_Queue_sync_device_hip( int device );
FLA_Error FLASH_Queue_sync_hip( );

Expand Down
98 changes: 55 additions & 43 deletions src/base/flamec/supermatrix/hip/main/FLASH_Queue_hip.c
Original file line number Diff line number Diff line change
Expand Up @@ -250,17 +250,23 @@ FLA_Error FLASH_Queue_bind_hip( int thread )
{

// Bind a HIP device to this thread.
hipSetDevice( thread );
if ( hipSetDevice( thread ) != hipSuccess ) return FLA_FAILURE;

// initialize its rocBLAS handle
if ( handles[thread] == NULL )
rocblas_create_handle( &(handles[thread]) );
{
//hipStream_t stream;
//if ( hipStreamCreate(&stream) != hipSuccess ) return FLA_FAILURE;
if ( rocblas_create_handle( &(handles[thread]) ) != rocblas_status_success ) return FLA_FAILURE;
//if ( rocblas_set_stream( handles[thread], stream ) != rocblas_status_success ) return FLA_FAILURE;
}

return FLA_SUCCESS;
}


FLA_Error FLASH_Queue_alloc_hip( dim_t size,
FLA_Error FLASH_Queue_alloc_async_hip( int thread,
dim_t size,
FLA_Datatype datatype,
void** buffer_hip )
/*----------------------------------------------------------------------------
Expand All @@ -269,11 +275,14 @@ FLA_Error FLASH_Queue_alloc_hip( dim_t size,
----------------------------------------------------------------------------*/
{
hipError_t status;

// Allocate memory for a block on HIP.
status = hipMalloc( buffer_hip,
size * FLA_Obj_datatype_size( datatype ) );
hipStream_t stream;
rocblas_get_stream( handles[thread], &stream );
//fprintf(stdout, "Trying to allocate %ld bytes on device %d\n", size, thread);
hipError_t status = hipMallocAsync( buffer_hip,
size * FLA_Obj_datatype_size( datatype ),
stream );

// Check to see if the allocation was successful.
if ( status != hipSuccess )
Expand All @@ -284,28 +293,34 @@ FLA_Error FLASH_Queue_alloc_hip( dim_t size,
FLA_Check_error_code( FLA_MALLOC_GPU_RETURNED_NULL_POINTER );
}

//fprintf( stdout, "allocating on thread %d for pointer %p\n", thread, *buffer_hip );

return FLA_SUCCESS;
}


FLA_Error FLASH_Queue_free_async_hip( void* buffer_hip )
FLA_Error FLASH_Queue_free_async_hip( int thread, void* buffer_hip )
/*----------------------------------------------------------------------------
FLASH_Queue_free_async_hip
----------------------------------------------------------------------------*/
{
// Free memory for a block on HIP.
hipFreeAsync( (hipStream_t) 0, buffer_hip );
hipStream_t stream;
rocblas_get_stream( handles[thread], &stream );
hipFreeAsync( stream, buffer_hip );

return FLA_SUCCESS;
}


FLA_Error FLASH_Queue_write_hip( FLA_Obj obj, void* buffer_hip )
FLA_Error FLASH_Queue_write_async_hip( int thread,
FLA_Obj obj,
void* buffer_hip )
/*----------------------------------------------------------------------------
FLASH_Queue_write_hip
FLASH_Queue_write_async_hip
----------------------------------------------------------------------------*/
{
Expand All @@ -317,11 +332,13 @@ FLA_Error FLASH_Queue_write_hip( FLA_Obj obj, void* buffer_hip )
const size_t count = FLA_Obj_elem_size( obj )
* FLA_Obj_col_stride( obj )
* FLA_Obj_width( obj );
hipStream_t stream;
rocblas_get_stream( handles[thread], &stream );
const hipError_t err = hipMemcpyAsync( buffer_hip,
FLA_Obj_buffer_at_view( obj ),
count,
hipMemcpyHostToDevice,
(hipStream_t) 0 );
stream );

if ( err != hipSuccess )
{
Expand All @@ -342,39 +359,11 @@ FLA_Error FLASH_Queue_read_hip( int thread, FLA_Obj obj, void* buffer_hip )
----------------------------------------------------------------------------*/
{
if ( flash_malloc_managed_hip )
{
// inject a stream sync on the rocBLAS stream to ensure completion
hipError_t err = hipStreamSynchronize( (hipStream_t) 0 );
if ( err != hipSuccess )
{
fprintf( stderr,
"Failure to synchronize on HIP stream. err=%d\n",
err );
return FLA_FAILURE;
}
return FLA_SUCCESS;
}

// Read the memory of a block on HIP to main memory.
hipSetDevice( thread );
const size_t count = FLA_Obj_elem_size( obj )
* FLA_Obj_col_stride( obj )
* FLA_Obj_width( obj );
const hipError_t err = hipMemcpy( FLA_Obj_buffer_at_view( obj ),
buffer_hip,
count,
hipMemcpyDeviceToHost );
FLA_Error err1 = FLASH_Queue_read_async_hip( thread, obj, buffer_hip );
if ( err1 != FLA_SUCCESS ) return err1;

if ( err != hipSuccess )
{
fprintf( stderr,
"Failure to read block from HIP device. Size=%ld, err=%d\n",
count, err );
return FLA_FAILURE;
}

return FLA_SUCCESS;
return FLASH_Queue_sync_stream_hip( thread );
}

FLA_Error FLASH_Queue_read_async_hip( int thread, FLA_Obj obj, void* buffer_hip )
Expand All @@ -394,11 +383,13 @@ FLA_Error FLASH_Queue_read_async_hip( int thread, FLA_Obj obj, void* buffer_hip
const size_t count = FLA_Obj_elem_size( obj )
* FLA_Obj_col_stride( obj )
* FLA_Obj_width( obj );
hipStream_t stream;
rocblas_get_stream( handles[thread], &stream );
const hipError_t err = hipMemcpyAsync( FLA_Obj_buffer_at_view( obj ),
buffer_hip,
count,
hipMemcpyDeviceToHost,
(hipStream_t) 0 );
stream );

if ( err != hipSuccess )
{
Expand Down Expand Up @@ -431,6 +422,27 @@ FLA_Error FLASH_Queue_sync_device_hip( int device )
return FLA_SUCCESS;
}

FLA_Error FLASH_Queue_sync_stream_hip( int thread )
/*----------------------------------------------------------------------------
FLASH_Queue_sync_stream_hip
----------------------------------------------------------------------------*/
{
hipStream_t stream;
rocblas_get_stream( handles[thread], &stream );
const hipError_t err = hipStreamSynchronize( stream );
if ( err != hipSuccess )
{
fprintf( stderr,
"Failure to sync HIP stream. Thread=%d, err=%d\n",
thread, err );
return FLA_FAILURE;
}

return FLA_SUCCESS;
}


FLA_Error FLASH_Queue_sync_hip( )
/*----------------------------------------------------------------------------
Expand Down
15 changes: 12 additions & 3 deletions src/base/flamec/supermatrix/main/FLASH_Queue_exec.c
Original file line number Diff line number Diff line change
Expand Up @@ -2243,7 +2243,16 @@ void FLASH_Queue_create_hip( int thread, void *arg )
{
// Allocate the memory on the HIP device for all the blocks a priori.
for ( i = 0; i < hip_n_blocks; i++ )
FLASH_Queue_alloc_hip( block_size, datatype, &(args->hip[thread * hip_n_blocks + i].buffer_hip) );
FLASH_Queue_alloc_async_hip( thread,
block_size,
datatype,
&(args->hip[thread * hip_n_blocks + i].buffer_hip) );
}
else
{
// write something into the buffer_hip pointer to make it unique for tracking
for ( i = 0; i < hip_n_blocks; i++ )
args->hip[thread * hip_n_blocks + i].buffer_hip = (void*) (thread * hip_n_blocks + i);
}

return;
Expand Down Expand Up @@ -2279,7 +2288,7 @@ void FLASH_Queue_destroy_hip( int thread, void *arg )
if ( hip_obj.obj.base != NULL && !hip_obj.clean )
FLASH_Queue_read_async_hip( thread, hip_obj.obj, hip_obj.buffer_hip );
// Free the memory on the HIP for all the blocks.
FLASH_Queue_free_async_hip( hip_obj.buffer_hip );
FLASH_Queue_free_async_hip( thread, hip_obj.buffer_hip );
}

return;
Expand Down Expand Up @@ -2786,7 +2795,7 @@ void FLASH_Queue_update_block_hip( FLA_Obj obj,

// Move the block to the HIP device.
if ( transfer )
FLASH_Queue_write_hip( hip_obj.obj, hip_obj.buffer_hip );
FLASH_Queue_write_async_hip( thread, hip_obj.obj, hip_obj.buffer_hip );

return;
}
Expand Down

0 comments on commit 450c109

Please sign in to comment.