[GPU] Release unused internal memory from pool (openvinotoolkit#18917)
* Do not reuse internal memory for dynamic shapes because of the current inefficiency in the pool
* Added a new debug config to dump the runtime memory pool

* Apply DisableMemoryReuse to all usages

* Resolved a perf issue with memory reuse from the pool: previously, the original internal-buffer (ibuf) record was not released when new memory was allocated for that buffer.
After releasing the memory, the number of memory pool records no longer increases, so memory pool retrieval is no longer inefficient.

* Added a test
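The perf fix described above boils down to releasing the pool record that still points at a primitive's previous internal buffer before a larger one is requested. A minimal, self-contained sketch of that pattern (toy types and names, not the plugin's actual memory_pool API) might look like this:

```cpp
#include <cstddef>
#include <map>
#include <memory>
#include <string>

// Toy stand-ins for cldnn::memory and a size-keyed pool record (hypothetical names,
// not the plugin's real types).
struct toy_memory { size_t bytes; };
using toy_memory_ptr = std::shared_ptr<toy_memory>;

class toy_pool {
public:
    // Release the record that still references the caller's current buffer,
    // then hand out (or create) a buffer large enough for the new layout.
    toy_memory_ptr realloc_for(const std::string& prim_id, size_t bytes, toy_memory_ptr current) {
        if (current) {
            // Without this erase, the stale record stays in the pool forever, so the
            // record count (and lookup cost) grows with every reallocation.
            for (auto it = _records.begin(); it != _records.end(); ++it) {
                if (it->second.mem == current) {
                    _records.erase(it);
                    break;
                }
            }
        }
        // Reuse the smallest record that is large enough (a real pool additionally
        // checks memory dependencies before handing a buffer to another primitive).
        auto it = _records.lower_bound(bytes);
        if (it != _records.end()) {
            it->second.user = prim_id;
            return it->second.mem;
        }
        auto mem = std::make_shared<toy_memory>(toy_memory{bytes});
        _records.emplace(bytes, record{mem, prim_id});
        return mem;
    }
    // Analogous role to the get_non_padded_pool_size() helper added in this commit.
    size_t record_count() const { return _records.size(); }

private:
    struct record { toy_memory_ptr mem; std::string user; };
    std::multimap<size_t, record> _records;
};
```

With the release step, repeatedly growing the buffer for the same primitive keeps the record count constant; without it, every reallocation leaves a stale record behind. The new test added below uses get_non_padded_pool_size() to verify that the pool does not accumulate stale records across reruns.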
yeonbok authored Aug 2, 2023
1 parent 0e13b99 commit db8c29e
Showing 9 changed files with 122 additions and 37 deletions.
@@ -114,6 +114,7 @@ class debug_configuration {
int dump_layers_limit_batch; // Limit the size of batch to dump
int dump_layers_raw; // Dump raw data.
int dump_layers_binary; // Dump binary data.
int dump_runtime_memory_pool; // Dump memory pool status at each iteration
int base_batch_for_memory_estimation; // Base batch size to be used in memory estimation
std::vector<std::string> after_proc; // Start inference after the listed processes
int serialize_compile; // Serialize creating primitives and compiling kernels
@@ -126,6 +126,12 @@ class memory_pool {
allocation_type type);
void clear_pool_for_network(uint32_t network_id);
void release_memory(memory* memory, const primitive_id& id, uint32_t network_id);

size_t get_non_padded_pool_size() {
return _non_padded_pool.size();
}

void dump(uint32_t id);
};

} // namespace cldnn
4 changes: 2 additions & 2 deletions src/plugins/intel_gpu/src/graph/include/primitive_inst.h
@@ -236,7 +236,7 @@ class primitive_inst {
static memory::ptr allocate_output(engine& engine, memory_pool& pool, const program_node& _node, const kernel_impl_params& impl_params, uint32_t net_id,
bool is_internal, size_t idx = 0, bool reset_mem = true, bool is_output_buffer = false, memory* curr_memory = nullptr, bool runtime_alloc = false);

std::vector<memory::cptr> get_intermediates_memories() const { return _intermediates_memory; }
std::vector<memory::ptr> get_intermediates_memories() const { return _intermediates_memory; }

virtual void save(cldnn::BinaryOutputBuffer& ob) const;
virtual void load(cldnn::BinaryInputBuffer& ib);
@@ -307,7 +307,7 @@ class primitive_inst {
// depending on reshape_node.is_in_place())
std::vector<memory::ptr> _outputs;

std::vector<memory::cptr> _intermediates_memory;
std::vector<memory::ptr> _intermediates_memory;

mutable LruCache<layout, memory::ptr, layout::Hasher> _reordered_weights_cache;

26 changes: 17 additions & 9 deletions src/plugins/intel_gpu/src/graph/network.cpp
@@ -1172,10 +1172,21 @@ std::map<primitive_id, network_output> network::execute(const std::vector<event:

void network::execute_impl(const std::vector<event::ptr>& events) {
OV_ITT_SCOPED_TASK(ov::intel_gpu::itt::domains::intel_gpu_plugin, "NetworkImpl::Execute");
int64_t curr_iter = -1;
GPU_DEBUG_GET_INSTANCE(debug_config);
#ifdef GPU_DEBUG_CONFIG
curr_iter = iteration++;
#endif

// Wait for previous execution completion
reset_execution(false);
GPU_DEBUG_TRACE << "----------------------------------------------" << std::endl;
GPU_DEBUG_TRACE << "Start network execution" << std::endl;
GPU_DEBUG_IF(debug_config->dump_runtime_memory_pool > 0) {
GPU_DEBUG_COUT << "----------------------------------------------" << std::endl;
GPU_DEBUG_COUT << "Start network execution (net_id : " << get_id() << ", iter :" << curr_iter << ")" << std::endl;
} else {
GPU_DEBUG_TRACE << "----------------------------------------------" << std::endl;
GPU_DEBUG_TRACE << "Start network execution (net_id : " << get_id() << ", iter :" << curr_iter << ")" << std::endl;
}

std::vector<memory::ptr> in_out_mem;
auto is_surface_lock_check_needed = [&](const shared_mem_type& shared_mem_type) {
@@ -1211,7 +1222,6 @@ void network::execute_impl(const std::vector<event::ptr>& events) {
auto surf_lock = surfaces_lock::create(get_engine().type(), in_out_mem, get_stream());

set_arguments();
GPU_DEBUG_GET_INSTANCE(debug_config);
GPU_DEBUG_IF(debug_config->list_layers == 1) {
for (auto& inst : _exec_order) {
GPU_DEBUG_COUT << inst->id() << std::endl;
@@ -1225,12 +1235,6 @@ void network::execute_impl(const std::vector<event::ptr>& events) {
}
if (!is_internal()) exit(0);
}
int64_t curr_iter = -1;
#ifdef GPU_DEBUG_CONFIG
GPU_DEBUG_IF(!debug_config->dump_iteration.empty()) {
curr_iter = iteration++;
}
#endif
auto get_iteration_prefix = [](int64_t iter) {
if (iter < 0)
return std::string("");
@@ -1435,6 +1439,10 @@ void network::execute_impl(const std::vector<event::ptr>& events) {
// provide proper event to execution. Flushing pipeline should prevent this kind of issues.
// In scenarios with a big number of very small networks it can provide performance drop.
get_stream().flush();

GPU_DEBUG_IF(debug_config->dump_runtime_memory_pool > 0) {
get_memory_pool().dump(get_id());
}
}

std::vector<primitive_id> network::get_input_ids() const {
62 changes: 39 additions & 23 deletions src/plugins/intel_gpu/src/graph/primitive_inst.cpp
@@ -133,6 +133,26 @@ bool is_any_user_cpu(const std::list<const program_node*>& users) {
return false;
}

static memory::ptr get_memory_from_pool(engine& _engine,
uint32_t net_id,
memory_pool& pool,
const program_node& _node,
const layout& layout,
allocation_type type,
bool reusable_across_network,
bool reset = true,
memory* curr_memory = nullptr) {
OPENVINO_ASSERT(!layout.is_dynamic() || layout.has_upper_bound(),
"[GPU] Can't allocate output for dynamic layout without upper bound");
// Use layout with max tensor for dynamic shape with upper bound
if (_node.get_program().get_config().get_property(ov::intel_gpu::enable_memory_pool)) {
if (curr_memory != nullptr)
pool.release_memory(curr_memory, _node.id(), net_id);
return pool.get_memory(layout, _node.id(), net_id, _node.get_memory_dependencies(), type, reusable_across_network, reset);
}
return pool.get_memory(layout, type, reset);
}

std::shared_ptr<kernel_impl_params> primitive_impl::get_weights_reorder_kernel_params() const {
if (!need_weights_reorder())
return nullptr;
@@ -1002,9 +1022,19 @@ memory::ptr primitive_inst::allocate_internal_buffer(size_t idx, bool reset) {
GPU_DEBUG_LOG << "=> allocate to " << alloc_type << std::endl;

// Reuse intermediate buffer like output buffer.
auto ret_mem = _network.get_memory_pool().get_memory(layout, _node->id(), _network.get_id(), _node->get_memory_dependencies(), alloc_type, true, reset);
bool reuse_internal_buf = true;
auto ret_mem =
get_memory_from_pool(_network.get_engine(),
_network.get_id(),
_network.get_memory_pool(),
*_node,
layout,
alloc_type,
reuse_internal_buf,
reset,
_intermediates_memory.size() > idx ? _intermediates_memory[idx].get() : nullptr);
GPU_DEBUG_LOG << " [" << _network.get_id() << ":" << _node->id() << ": internal buf " << idx << "] " << alloc_type
<< " " << ret_mem->buffer_ptr() << std::endl;
<< " " << ret_mem->buffer_ptr() << std::endl;
return ret_mem;
}

@@ -1016,7 +1046,7 @@ void primitive_inst::allocate_internal_buffers(bool reset) {
return;

// allocate intermediate memory for the updated layout of buffer
std::vector<memory::cptr> intermediates_memory;
std::vector<memory::ptr> intermediates_memory;
for (size_t i = 0; i < ibuf_layouts.size(); ++i) {
if (ibuf_layouts[i].get_linear_size() == 0)
continue;
@@ -1150,18 +1180,6 @@ static bool user_requesting_mem_reuse_false(const program_node& node) {

memory::ptr primitive_inst::allocate_output(engine& _engine, memory_pool& pool, const program_node& _node, const kernel_impl_params& impl_params,
uint32_t net_id, bool is_internal, size_t idx, bool reset, bool is_output_buffer, memory* curr_memory, bool runtime_alloc) {
auto get_memory_from_pool = [&](engine& _engine, const layout& layout, const primitive_id id, std::set<primitive_id> dependencies,
allocation_type type, bool reusable_across_network, bool reset = true, memory* curr_memory = nullptr) {
OPENVINO_ASSERT(!layout.is_dynamic() || layout.has_upper_bound(), "[GPU] Can't allocate output for dynamic layout without upper bound");
// Use layout with max tensor for dynamic shape with upper bound
if (_node.get_program().get_config().get_property(ov::intel_gpu::enable_memory_pool)) {
if (curr_memory != nullptr)
pool.release_memory(curr_memory, id, net_id);
return pool.get_memory(layout, id, net_id, dependencies, type, reusable_across_network, reset);
}
return pool.get_memory(layout, type, reset);
};

auto layout = impl_params.get_output_layout(idx);
OPENVINO_ASSERT(layout.is_static() || layout.has_upper_bound(), "[GPU] Can't allocate output for dynamic layout");
auto device_mem_acc = [&](size_t a, const cldnn::layout& l) {
@@ -1187,10 +1205,6 @@ memory::ptr primitive_inst::allocate_output(engine& _engine, memory_pool& pool,
if (_node.is_in_shape_of_subgraph())
reusable_across_network = false;

GPU_DEBUG_GET_INSTANCE(debug_config);
GPU_DEBUG_IF(debug_config->disable_memory_reuse) {
reusable_across_network = false;
}
// For outputs, cpu prim we want to have lockable alloc type
// Also if the successor of a node is an cpu, then memory needs to be lockable.
bool is_cpu = _node.get_selected_impl() ? _node.get_selected_impl()->is_cpu() : false;
@@ -1213,9 +1227,10 @@ memory::ptr primitive_inst::allocate_output(engine& _engine, memory_pool& pool,
_engine.supports_allocation(allocation_type::usm_device))
alloc_type = allocation_type::usm_device;
return get_memory_from_pool(_engine,
net_id,
pool,
_node,
layout,
_node.id(),
_node.get_memory_dependencies(),
alloc_type,
false,
reset,
@@ -1231,9 +1246,10 @@ memory::ptr primitive_inst::allocate_output(engine& _engine, memory_pool& pool,
return _engine.allocate_memory(layout, alloc_type, reset);
} else {
return get_memory_from_pool(_engine,
net_id,
pool,
_node,
layout,
_node.id(),
_node.get_memory_dependencies(),
alloc_type,
reusable_across_network,
reset,
2 changes: 2 additions & 0 deletions src/plugins/intel_gpu/src/runtime/debug_configuration.cpp
@@ -182,6 +182,7 @@ debug_configuration::debug_configuration()
, dump_layers_limit_batch(std::numeric_limits<int>::max())
, dump_layers_raw(0)
, dump_layers_binary(0)
, dump_runtime_memory_pool(0)
, base_batch_for_memory_estimation(-1)
, serialize_compile(0)
, max_kernels_per_batch(0)
@@ -210,6 +211,7 @@ debug_configuration::debug_configuration()
get_gpu_debug_env_var("DisableOnednnOptPostOps", disable_onednn_opt_post_ops);
get_gpu_debug_env_var("DumpProfilingData", dump_profiling_data);
get_gpu_debug_env_var("DryRunPath", dry_run_path);
get_gpu_debug_env_var("DumpRuntimeMemoryPool", dump_runtime_memory_pool);
get_gpu_debug_env_var("BaseBatchForMemEstimation", base_batch_for_memory_estimation);
std::string dump_layers_str;
get_gpu_debug_env_var("DumpLayers", dump_layers_str);
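Assuming the plugin's usual environment-variable handling (get_gpu_debug_env_var resolves names with the OV_GPU_ prefix), the new dump can be enabled at runtime by setting OV_GPU_DumpRuntimeMemoryPool=1; memory_pool::dump() is then invoked after every network::execute_impl() iteration, as shown in the network.cpp hunk above.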
2 changes: 1 addition & 1 deletion src/plugins/intel_gpu/src/runtime/layout.cpp
@@ -215,7 +215,7 @@ std::string layout::to_short_string() const {
dump_shape(s, size);

if (data_padding.get_dynamic_pad_dims() != tensor(0)) {
s << ":dyn_pad_dims" << data_padding.get_dynamic_pad_dims().to_string();
s << ":dyn_pad_dims";
} else {
if (data_padding)
s << ":pad";
28 changes: 27 additions & 1 deletion src/plugins/intel_gpu/src/runtime/memory_pool.cpp
@@ -223,7 +223,12 @@ memory::ptr memory_pool::get_memory(const layout& layout,
allocation_type type,
bool reusable_across_network,
bool reset) {
if (reusable_across_network) {
bool do_reuse = reusable_across_network;
GPU_DEBUG_GET_INSTANCE(debug_config);
GPU_DEBUG_IF(debug_config->disable_memory_reuse) {
do_reuse = false;
}
if (do_reuse) {
// reusable within the same network
if (!layout.format.is_image() && layout.data_padding == padding{{0, 0, 0, 0}, 0}) {
// non-padded buffers
@@ -298,4 +303,25 @@ void memory_pool::clear_pool_for_network(uint32_t network_id) {

memory_pool::memory_pool(engine& engine) : _engine(&engine) { }

void memory_pool::dump(uint32_t net_id) {
GPU_DEBUG_COUT << "Dump memory pool of network " << net_id << std::endl;
GPU_DEBUG_COUT << "========== non-padded pool ( " << _non_padded_pool.size() << " records) ==========" << std::endl;
for (auto mem : _non_padded_pool) {
GPU_DEBUG_COUT << mem.second._memory->buffer_ptr() << " (size: " << mem.first << ", type: " << mem.second._type
<< ")'s users: " << std::endl;
for (auto user : mem.second._users) {
GPU_DEBUG_COUT << " -- " << user._id << std::endl;
}
}
GPU_DEBUG_COUT << "========== padded pool (" << _padded_pool.size() << " records) ==========" << std::endl;
for (auto mem : _padded_pool) {
GPU_DEBUG_COUT << " layout: " << mem.first.to_short_string() << std::endl;
for (auto record : mem.second) {
GPU_DEBUG_COUT << " " << record._memory->buffer_ptr() << ", type: " << record._type << ", users : " << std::endl;
for (auto user : record._users) {
GPU_DEBUG_COUT << " --- " << user._id << std::endl;
}
}
}
}
} // namespace cldnn
@@ -131,7 +131,10 @@ TEST(softmax_gpu_dynamic_f32_test_upper_bound, input_same_values) {
out_size_1 = output_x_1 * output_b_1,
output_x_2 = 10, output_b_2 = 4,
input_x_2 = 10, input_b_2 = 4,
out_size_2 = output_x_2 * output_b_2;
out_size_2 = output_x_2 * output_b_2,
output_x_3 = 10, output_b_3 = 16,
input_x_3 = 10, input_b_3 = 16,
out_size_3 = output_x_3 * output_b_3;

cldnn::engine& engine = get_test_engine();

@@ -196,6 +199,29 @@ TEST(softmax_gpu_dynamic_f32_test_upper_bound, input_same_values) {
ASSERT_EQ(internal_mems_1[i]->get_allocation_type(), allocation_type::usm_device);
}
}
// Third run
float out_buffer_3[out_size_3];
std::vector<float> in_b_3(out_size_3, 2.0f);
std::vector<float> expected_buffer_3(out_size_3, 0.1f);
cldnn::memory::ptr input_3 = engine.allocate_memory({ data_types::f32, format::bfyx, {input_b_3, 1, input_x_3, 1}});
set_values(input_3, in_b_3);
network.set_input_data("input", input_3);
auto outputs_3 = network.execute();
auto output_mem_3 = outputs_3.begin()->second.get_memory();
cldnn::mem_lock<float> output_ptr_3(output_mem_3, get_test_stream());
for (uint32_t i = 0; i < out_size_3; i++) {
out_buffer_3[i] = output_ptr_3[i];
}
compare_out_buffer_with_expected(out_buffer_3, expected_buffer_3, out_size_3);
auto internal_mems_3 = network.get_primitive("softmax")->get_intermediates_memories();
for (size_t i = 0; i < internal_mems_3.size(); ++i) {
if (engine.get_device_info().supports_immad) {
ASSERT_EQ(internal_mems_3[i]->get_allocation_type(), allocation_type::usm_device);
}
}
auto& pool = network.get_memory_pool();
// check if previously allocated internal buffer is released
ASSERT_EQ(pool.get_non_padded_pool_size(), 3);
}

TEST(dyn_shape_mem_test, igpu_shape_infer_dep_mem_type) {
