Skip to content

Commit

Permalink
Fix VarBuilder::from_slice_safetensors (huggingface#2180)
Browse files Browse the repository at this point in the history
Also implement SimpleBackend for SliceSafetensors

Signed-off-by: Harry Stern <[email protected]>
  • Loading branch information
boustrophedon authored May 12, 2024
1 parent 21f82a5 commit 13c64f6
Showing 1 changed file with 30 additions and 4 deletions.
34 changes: 30 additions & 4 deletions candle-nn/src/var_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,32 @@ impl SimpleBackend for candle::safetensors::BufferedSafetensors {
}
}

impl<'a> SimpleBackend for candle::safetensors::SliceSafetensors<'a> {
fn get(
&self,
s: Shape,
name: &str,
_: crate::Init,
dtype: DType,
dev: &Device,
) -> Result<Tensor> {
let tensor = self.load(name, dev)?.to_dtype(dtype)?;
if tensor.shape() != &s {
Err(candle::Error::UnexpectedShape {
msg: format!("shape mismatch for {name}"),
expected: s,
got: tensor.shape().clone(),
}
.bt())?
}
Ok(tensor)
}

fn contains_tensor(&self, name: &str) -> bool {
self.get(name).is_ok()
}
}

impl<'a> VarBuilder<'a> {
/// Initializes a `VarBuilder` using a custom backend.
///
Expand Down Expand Up @@ -481,15 +507,15 @@ impl<'a> VarBuilder<'a> {
Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
}

/// Initializes a `VarBuilder` from a binary builder in the safetensor format.
/// Initializes a `VarBuilder` from a binary buffer in the safetensor format.
pub fn from_buffered_safetensors(data: Vec<u8>, dtype: DType, dev: &Device) -> Result<Self> {
let tensors = candle::safetensors::BufferedSafetensors::new(data)?;
Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
}

/// Initializes a `VarBuilder` from a binary builder in the safetensor format.
pub fn from_slice_safetensors(data: Vec<u8>, dtype: DType, dev: &Device) -> Result<Self> {
let tensors = candle::safetensors::BufferedSafetensors::new(data)?;
/// Initializes a `VarBuilder` from a binary slice in the safetensor format.
pub fn from_slice_safetensors(data: &'a [u8], dtype: DType, dev: &Device) -> Result<Self> {
let tensors = candle::safetensors::SliceSafetensors::new(data)?;
Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
}

Expand Down

0 comments on commit 13c64f6

Please sign in to comment.