Skip to content

Commit

Permalink
ONNX: add ArgMin, ArgMax and LeakyRelu (huggingface#2246)
Browse files Browse the repository at this point in the history
* Add basic RandomUniform implementation

* Use is_some to check if seed is present

* Added Exp operator implementation

* Added ArgMin operator implementation

* Added tests for ArgMin

* ArgMin now returns a tensor with i64

* Added tests from pytorch examples

* Added ArgMax operator implementation

* Added tests for ArgMax

* Added LeakyRelu implementation

* Added a test for LeakyRelu

* Typo fix

* Fix a weird automatic RustRover change

---------

Co-authored-by: Mateusz Okulus <[email protected]>
  • Loading branch information
B1rtek and mokulus authored Jun 4, 2024
1 parent 9182c82 commit cb180eb
Show file tree
Hide file tree
Showing 2 changed files with 523 additions and 0 deletions.
53 changes: 53 additions & 0 deletions candle-onnx/src/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1027,6 +1027,59 @@ pub fn simple_eval(
};
values.insert(node.output[0].clone(), output);
}
"ArgMin" => {
let input = get(&node.input[0])?;
let axis_i64: i64 = get_attr_opt(node, "axis")?.copied().unwrap_or(0);
let rank_i64: i64 = input.rank().try_into().unwrap();
if axis_i64 < -rank_i64 || axis_i64 >= rank_i64 {
bail!("axis ({}) out of accepted range [-rank, rank-1] which was [{}, {}]", axis_i64, -rank_i64, rank_i64-1)
}
let axis = input.normalize_axis(axis_i64)?;
let keepdims: i64 = get_attr_opt(node, "keepdims")?.copied().unwrap_or(1);
let select_last_index: i64 = get_attr_opt(node, "select_last_index")?.copied().unwrap_or(0);
if select_last_index == 1 {
bail!("select_last_index for ArgMin is currently not supported")
}
let output = if keepdims == 1 {
input.argmin_keepdim(axis)?
} else {
input.argmin(axis)?
}.to_dtype(DType::I64)?;
values.insert(node.output[0].clone(), output);
}
"ArgMax" => {
let input = get(&node.input[0])?;
let axis_i64: i64 = get_attr_opt(node, "axis")?.copied().unwrap_or(0);
let rank_i64: i64 = input.rank().try_into().unwrap();
if axis_i64 < -rank_i64 || axis_i64 >= rank_i64 {
bail!("axis ({}) out of accepted range [-rank, rank-1] which was [{}, {}]", axis_i64, -rank_i64, rank_i64-1)
}
let axis = input.normalize_axis(axis_i64)?;
let keepdims: i64 = get_attr_opt(node, "keepdims")?.copied().unwrap_or(1);
let select_last_index: i64 = get_attr_opt(node, "select_last_index")?.copied().unwrap_or(0);
if select_last_index == 1 {
bail!("select_last_index for ArgMin is currently not supported")
}
let output = if keepdims == 1 {
input.argmax_keepdim(axis)?
} else {
input.argmax(axis)?
}.to_dtype(DType::I64)?;
values.insert(node.output[0].clone(), output);
}
"LeakyRelu" => {
let input = get(&node.input[0])?;
let dt = input.dtype();
match dt {
DType::U8 | DType::U32 | DType::I64 => {
bail!("unsupported dtype {}, only float types are allowed for LeakyRelu", dt.as_str())
}
DType::BF16 | DType::F16 | DType::F32 | DType::F64 => {}
}
let alpha = get_attr_opt::<f32>(node, "alpha")?.copied().unwrap_or(0.01);
let output = candle_nn::ops::leaky_relu(input, alpha.into())?;
values.insert(node.output[0].clone(), output);
}
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
}
}
Expand Down
Loading

0 comments on commit cb180eb

Please sign in to comment.