Skip to content

Commit

Permalink
fix: zero out tensors created with Tensor::new on CPU
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Dec 26, 2024
1 parent 1c900df commit 7a95f98
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions src/value/impl_tensor/create.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ impl Tensor<String> {
}

impl<T: PrimitiveTensorElementType + Debug> Tensor<T> {
/// Construct a tensor in a given allocator with a given shape and datatype. The data contained in the
/// value will be zero-allocated on the allocation device.
/// Construct a tensor via a given allocator with a given shape and datatype. The data in the tensor will be
/// **uninitialized**.
///
/// This can be used to create a tensor with data on a certain device. For example, to create a tensor with pinned
/// (CPU) memory for use with CUDA:
Expand Down Expand Up @@ -129,6 +129,16 @@ impl<T: PrimitiveTensorElementType + Debug> Tensor<T> {
nonNull(value_ptr)
];

// `CreateTensorAsOrtValue` actually does not guarantee that the data allocated is zero'd out, so if we can, we should
// do it manually.
let memory_info = MemoryInfo::from_value(value_ptr).expect("CreateTensorAsOrtValue returned non-tensor");
if memory_info.is_cpu_accessible() {
let mut buffer_ptr: *mut ort_sys::c_void = std::ptr::null_mut();
ortsys![unsafe GetTensorMutableData(value_ptr, &mut buffer_ptr)?; nonNull(buffer_ptr)];

unsafe { buffer_ptr.write_bytes(0, calculate_tensor_size(&shape) * size_of::<T>()) };
}

Ok(Value {
inner: Arc::new(ValueInner {
ptr: unsafe { NonNull::new_unchecked(value_ptr) },
Expand Down

0 comments on commit 7a95f98

Please sign in to comment.