Skip to content

Commit 085b2e8

Browse files
committed
Merge branch 'fix-unbalance_memory_for_localized_embedding' into 'master'
fix unbalanced memory for localized embedding See merge request dl/hugectr/hugectr!501
2 parents efccfbb + f122d69 commit 085b2e8

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

HugeCTR/src/embeddings/localized_slot_sparse_embedding_hash.cu

+6
Original file line number | Diff line number | Diff line change
@@ -405,16 +405,22 @@ LocalizedSlotSparseEmbeddingHash<TypeHashKey, TypeEmbeddingComp>::LocalizedSlotS
405405
} else {
406406
const std::shared_ptr<BufferBlock2<float>> &block = buf->create_block<float>();
407407
Tensors2<float> tensors;
408+
size_t vocabulary_size_in_current_gpu = 0;
408409
for (size_t i = 0; i < slot_size_array_.size(); i++) {
409410
if ((i % embedding_data_.get_resource_manager().get_global_gpu_count()) == gid) {
410411
Tensor2<float> tensor;
411412
block->reserve(
412413
{slot_size_array_[i], embedding_data_.embedding_params_.embedding_vec_size},
413414
&tensor);
414415
tensors.push_back(tensor);
416+
vocabulary_size_in_current_gpu += slot_size_array_[i];
415417
}
416418
}
417419
value_table_tensors_.push_back(tensors);
420+
if (max_vocabulary_size_per_gpu_ > vocabulary_size_in_current_gpu) {
421+
Tensor2<float> padding_tensor_for_optimizer;
422+
block->reserve({max_vocabulary_size_per_gpu_ - vocabulary_size_in_current_gpu, embedding_data_.embedding_params_.embedding_vec_size}, &padding_tensor_for_optimizer);
423+
}
418424
hash_table_value_tensors_.push_back(block->as_tensor());
419425
}
420426
{

0 commit comments

Comments
 (0)