Skip to content

Commit b48879d

Browse files
committed
fix bug: unbalanced hash table leads to illegal memory access in global Adam update
1 parent 0674e83 commit b48879d

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

HugeCTR/src/embeddings/localized_slot_sparse_embedding_hash.cu

+6
Original file line number | Diff line number | Diff line change
@@ -204,16 +204,22 @@ LocalizedSlotSparseEmbeddingHash<TypeHashKey, TypeEmbeddingComp>::LocalizedSlotS
204204
} else {
205205
const std::shared_ptr<BufferBlock2<float>> &block = buf->create_block<float>();
206206
Tensors2<float> tensors;
207+
size_t vocabulary_size_in_current_gpu = 0;
207208
for (size_t i = 0; i < slot_size_array_.size(); i++) {
208209
if ((i % embedding_data_.get_resource_manager().get_global_gpu_count()) == gid) {
209210
Tensor2<float> tensor;
210211
block->reserve(
211212
{slot_size_array_[i], embedding_data_.embedding_params_.embedding_vec_size},
212213
&tensor);
213214
tensors.push_back(tensor);
215+
vocabulary_size_in_current_gpu += slot_size_array_[i];
214216
}
215217
}
216218
value_table_tensors_.push_back(tensors);
219+
if (max_vocabulary_size_per_gpu_ > vocabulary_size_in_current_gpu) {
220+
Tensor2<float> padding_tensor_for_optimizer;
221+
block->reserve({max_vocabulary_size_per_gpu_ - vocabulary_size_in_current_gpu, embedding_data_.embedding_params_.embedding_vec_size}, &padding_tensor_for_optimizer);
222+
}
217223
hash_table_value_tensors_.push_back(block->as_tensor());
218224
}
219225

0 commit comments

Comments
 (0)