Make sure the allocated MTLBuffer is as large as it can be.
liuliu committed Nov 21, 2022
commit 99e0ba4 (1 parent: e8453ea)
Showing 1 changed file with 3 additions and 3 deletions.
lib/nnc/ccv_nnc_symbolic_graph_compile.c
@@ -1072,7 +1072,7 @@ static void _ccv_nnc_tensor_arena_obj_dispose(void* ptr, void* userdata)
#endif
KHASH_MAP_INIT_INT64(obj_ptr, void*)

-static inline void* _ccv_nnc_tensor_arena_obj_create(khash_t(obj_ptr)* obj_ptr_map, void* ptr, const ccv_nnc_tensor_param_t params, ccv_nnc_tensor_arena_t* tensor_arena)
+static inline void* _ccv_nnc_tensor_arena_obj_create(khash_t(obj_ptr)* obj_ptr_map, void* ptr, const size_t size, const ccv_nnc_tensor_param_t params, ccv_nnc_tensor_arena_t* tensor_arena)
{
#ifdef HAVE_MPS
if (CCV_TENSOR_GET_MEMORY(params.type) == CCV_TENSOR_GPU_MEMORY)
@@ -1081,7 +1081,7 @@ static inline void* _ccv_nnc_tensor_arena_obj_create(khash_t(obj_ptr)* obj_ptr_m
khiter_t k = kh_put(obj_ptr, obj_ptr_map, (uint64_t)(uintptr_t)ptr, &ret);
if (ret != 0)
{
-void* obj = mpobjcreate(ptr, CCV_GET_DATA_TYPE_SIZE(params.datatype) * ccv_nnc_tensor_count(params));
+void* obj = mpobjcreate(ptr, size);
if (!tensor_arena->disposers)
tensor_arena->disposers = ccv_array_new(sizeof(ccv_nnc_arena_disposer_t), 1, 0);
ccv_nnc_arena_disposer_t disposer = {
@@ -1307,7 +1307,7 @@ static ccv_nnc_tensor_arena_t* _ccv_nnc_tensor_arena_new(ccv_nnc_symbolic_graph_
ccv_nnc_tensor_t* const tensor = _ccv_nnc_tensor_metadata_get(tensor_arena->tensor_metadata, pos);
// Also, set its allocations.
// Since tensor view is bit compatible with tensor, we can just cast.
-void* obj = _ccv_nnc_tensor_arena_obj_create(obj_ptr_map, tensor_arena->buffers[buffer_ref].ptr + offset, tensor_symbol_info[i].info, tensor_arena);
+void* obj = _ccv_nnc_tensor_arena_obj_create(obj_ptr_map, tensor_arena->buffers[buffer_ref].ptr + offset, tensor_arena->buffers[buffer_ref].size - offset, tensor_symbol_info[i].info, tensor_arena);
*tensor = ccv_nnc_tensor(obj, tensor_symbol_info[i].info, 0);
assert(offset + tensor_blocks[i].size <= tensor_arena->buffers[buffer_ref].size);
// If we need to force broadcast, we need to wrap it in a multiview.
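Why the change matters: the object returned by mpobjcreate is cached in obj_ptr_map, keyed by the tensor's base pointer, so the first tensor mapped at a given pointer fixes the size of the wrapped MTLBuffer. Sizing it to that first tensor's exact byte count (CCV_GET_DATA_TYPE_SIZE(params.datatype) * ccv_nnc_tensor_count(params)) could presumably leave the buffer too small for a later, larger tensor aliasing the same pointer; sizing it to the remaining span of the arena buffer (buffers[buffer_ref].size - offset) makes it as large as it can be. Below is a minimal standalone sketch of the two size computations, not code from the repository; the constants are hypothetical.

/* Sketch: old vs. new size passed to the buffer-wrapping call.
 * All values here are made up for illustration. */
#include <stddef.h>
#include <stdio.h>

int main(void)
{
	/* Assume one arena buffer of 1 MiB, with a tensor placed at offset 256 KiB. */
	const size_t buffer_size = 1024 * 1024;
	const size_t offset = 256 * 1024;

	/* Old behavior: size derived from the tensor's element size and count,
	 * e.g. a 4-byte float tensor with 1024 elements. */
	const size_t datatype_size = 4;
	const size_t tensor_count = 1024;
	const size_t old_size = datatype_size * tensor_count; /* 4096 bytes */

	/* New behavior: everything from the tensor's base pointer to the end
	 * of the arena buffer, so the wrapped MTLBuffer covers the largest
	 * span any aliasing tensor could need. */
	const size_t new_size = buffer_size - offset; /* 786432 bytes */

	printf("old object size: %zu bytes\n", old_size);
	printf("new object size: %zu bytes\n", new_size);
	return 0;
}

Note that the assert at the end of the last hunk (offset + tensor_blocks[i].size <= tensor_arena->buffers[buffer_ref].size) already guarantees the new size is at least as large as the tensor it backs.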
