diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index df25829013..31eb62cdf1 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -1027,6 +1027,59 @@ pub fn simple_eval(
                 };
                 values.insert(node.output[0].clone(), output);
             }
+            "ArgMin" => {
+                let input = get(&node.input[0])?;
+                let axis_i64: i64 = get_attr_opt(node, "axis")?.copied().unwrap_or(0);
+                let rank_i64: i64 = input.rank().try_into().unwrap();
+                if axis_i64 < -rank_i64 || axis_i64 >= rank_i64 {
+                    bail!("axis ({}) out of accepted range [-rank, rank-1] which was [{}, {}]", axis_i64, -rank_i64, rank_i64 - 1)
+                }
+                let axis = input.normalize_axis(axis_i64)?;
+                let keepdims: i64 = get_attr_opt(node, "keepdims")?.copied().unwrap_or(1);
+                let select_last_index: i64 = get_attr_opt(node, "select_last_index")?.copied().unwrap_or(0);
+                if select_last_index == 1 {
+                    bail!("select_last_index for ArgMin is currently not supported")
+                }
+                let output = if keepdims == 1 {
+                    input.argmin_keepdim(axis)?
+                } else {
+                    input.argmin(axis)?
+                }.to_dtype(DType::I64)?;
+                values.insert(node.output[0].clone(), output);
+            }
+            "ArgMax" => {
+                let input = get(&node.input[0])?;
+                let axis_i64: i64 = get_attr_opt(node, "axis")?.copied().unwrap_or(0);
+                let rank_i64: i64 = input.rank().try_into().unwrap();
+                if axis_i64 < -rank_i64 || axis_i64 >= rank_i64 {
+                    bail!("axis ({}) out of accepted range [-rank, rank-1] which was [{}, {}]", axis_i64, -rank_i64, rank_i64 - 1)
+                }
+                let axis = input.normalize_axis(axis_i64)?;
+                let keepdims: i64 = get_attr_opt(node, "keepdims")?.copied().unwrap_or(1);
+                let select_last_index: i64 = get_attr_opt(node, "select_last_index")?.copied().unwrap_or(0);
+                if select_last_index == 1 {
+                    bail!("select_last_index for ArgMax is currently not supported")
+                }
+                let output = if keepdims == 1 {
+                    input.argmax_keepdim(axis)?
+                } else {
+                    input.argmax(axis)?
+                }.to_dtype(DType::I64)?;
+                values.insert(node.output[0].clone(), output);
+            }
+            "LeakyRelu" => {
+                let input = get(&node.input[0])?;
+                let dt = input.dtype();
+                match dt {
+                    DType::U8 | DType::U32 | DType::I64 => {
+                        bail!("unsupported dtype {}, only float types are allowed for LeakyRelu", dt.as_str())
+                    }
+                    DType::BF16 | DType::F16 | DType::F32 | DType::F64 => {}
+                }
+                let alpha = get_attr_opt::<f32>(node, "alpha")?.copied().unwrap_or(0.01);
+                let output = candle_nn::ops::leaky_relu(input, alpha.into())?;
+                values.insert(node.output[0].clone(), output);
+            }
             op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
         }
     }
diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs
index f58aeccfac..ffafd7a7b2 100644
--- a/candle-onnx/tests/ops.rs
+++ b/candle-onnx/tests/ops.rs
@@ -2708,3 +2708,473 @@ fn test_ceil() -> Result<()> {
     Ok(())
 }
+
+// "ArgMin"
+#[test]
+fn test_argmin() -> Result<()> {
+    // tests from https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-7
+    // default_axes_keepdims
+    test(
+        &[
+            [2u32, 1u32],
+            [3u32, 10u32]
+        ],
+        None,
+        Some(1),
+        None,
+        &[
+            [0i64, 0i64],
+        ],
+    )?;
+    // keepdims
+    test(
+        &[
+            [2u32, 1u32],
+            [3u32, 10u32]
+        ],
+        Some(1),
+        Some(1),
+        None,
+        &[
+            [1i64],
+            [0i64]
+        ],
+    )?;
+    // negative_axis_keepdims
+    test(
+        &[
+            [2u32, 1u32],
+            [3u32, 10u32]
+        ],
+        Some(-1),
+        Some(1),
+        None,
+        &[
+            [1i64],
+            [0i64]
+        ],
+    )?;
+    // no_keepdims
+    test(
+        &[
+            [2u32, 1u32],
+            [3u32, 10u32]
+        ],
+        None,
+        Some(0),
+        None,
+        &[0i64, 0i64],
+    )?;
+    // tests from https://pytorch.org/docs/stable/generated/torch.argmin.html#torch.argmin
+    test(
+        &[
+            [0.1139, 0.2254, -0.1381, 0.3687],
+            [1.0100, -1.1975, -0.0102, -0.4732],
+            [-0.9240, 0.1207, -0.7506, -1.0213],
+            [1.7809, -1.2960, 0.9384, 0.1438]
+        ],
+        Some(1),
+        Some(0),
+        None,
+        &[2i64, 1i64, 3i64, 1i64],
+    )?;
+    test(
+        &[
+            [0.1139, 0.2254, -0.1381, 0.3687],
+            [1.0100, -1.1975, -0.0102, -0.4732],
+            [-0.9240, 0.1207, -0.7506, -1.0213],
+            [1.7809, -1.2960, 0.9384, 0.1438]
+        ],
+        Some(1),
+        None,
+        None,
+        &[[2i64], [1i64], [3i64], [1i64]],
+    )?;
+    fn test(data: impl NdArray, axis: Option<i64>, keepdims: Option<i64>, select_last_index: Option<i64>, expected: impl NdArray) -> Result<()> {
+        let att_axis = AttributeProto {
+            name: "axis".to_string(),
+            ref_attr_name: "axis".to_string(),
+            i: axis.unwrap_or(0),
+            doc_string: "axis".to_string(),
+            r#type: 2, // INT
+            f: 0.0,
+            s: vec![],
+            t: None,
+            g: None,
+            sparse_tensor: None,
+            tp: None,
+            floats: vec![],
+            ints: vec![],
+            strings: vec![],
+            tensors: vec![],
+            graphs: vec![],
+            sparse_tensors: vec![],
+            type_protos: vec![],
+        };
+        let att_keepdims = AttributeProto {
+            name: "keepdims".to_string(),
+            ref_attr_name: "keepdims".to_string(),
+            i: keepdims.unwrap_or(1),
+            doc_string: "keepdims".to_string(),
+            r#type: 2, // INT
+            f: 0.0,
+            s: vec![],
+            t: None,
+            g: None,
+            sparse_tensor: None,
+            tp: None,
+            floats: vec![],
+            ints: vec![],
+            strings: vec![],
+            tensors: vec![],
+            graphs: vec![],
+            sparse_tensors: vec![],
+            type_protos: vec![],
+        };
+        let att_select_last_index = AttributeProto {
+            name: "select_last_index".to_string(),
+            ref_attr_name: "select_last_index".to_string(),
+            i: select_last_index.unwrap_or(0),
+            doc_string: "select_last_index".to_string(),
+            r#type: 2, // INT
+            f: 0.0,
+            s: vec![],
+            t: None,
+            g: None,
+            sparse_tensor: None,
+            tp: None,
+            floats: vec![],
+            ints: vec![],
+            strings: vec![],
+            tensors: vec![],
+            graphs: vec![],
+            sparse_tensors: vec![],
+            type_protos: vec![],
+        };
+        let attrs = {
+            let mut mut_attrs = vec![];
+            if axis.is_some() {
+                mut_attrs.push(att_axis);
+            }
+            if keepdims.is_some() {
+                mut_attrs.push(att_keepdims);
+            }
+            if select_last_index.is_some() {
+                mut_attrs.push(att_select_last_index);
+            }
+            mut_attrs
+        };
+        let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+            node: vec![NodeProto {
+                op_type: "ArgMin".to_string(),
+                domain: "".to_string(),
+                attribute: attrs,
+                input: vec![INPUT_X.to_string()],
+                output: vec![OUTPUT_Z.to_string()],
+                name: "".to_string(),
+                doc_string: "".to_string(),
+            }],
+            name: "".to_string(),
+            initializer: vec![],
+            input: vec![],
+            output: vec![ValueInfoProto {
+                name: OUTPUT_Z.to_string(),
+                doc_string: "".to_string(),
+                r#type: None,
+            }],
+            value_info: vec![],
+            doc_string: "".to_string(),
+            sparse_initializer: vec![],
+            quantization_annotation: vec![],
+        }));
+        let mut inputs: HashMap<String, Tensor> = HashMap::new();
+        inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
+        let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+        let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+
+        let expected = Tensor::new(expected, &Device::Cpu)?;
+        match expected.dims().len() {
+            1 => assert_eq!(z.to_vec1::<i64>()?, expected.to_vec1::<i64>()?),
+            2 => assert_eq!(z.to_vec2::<i64>()?, expected.to_vec2::<i64>()?),
+            _ => unreachable!(),
+        };
+
+        Ok(())
+    }
+
+    Ok(())
+}
+
+// "ArgMax"
+#[test]
+fn test_argmax() -> Result<()> {
+    // tests from https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-6
+    // default_axes_keepdims
+    test(
+        &[
+            [2u32, 1u32],
+            [3u32, 10u32]
+        ],
+        None,
+        Some(1),
+        None,
+        &[
+            [1i64, 1i64],
+        ],
+    )?;
+    // keepdims
+    test(
+        &[
+            [2u32, 1u32],
+            [3u32, 10u32]
+        ],
+        Some(1),
+        Some(1),
+        None,
+        &[
+            [0i64],
+            [1i64]
+        ],
+    )?;
+    // negative_axis_keepdims
+    test(
+        &[
+            [2u32, 1u32],
+            [3u32, 10u32]
+        ],
+        Some(-1),
+        Some(1),
+        None,
+        &[
+            [0i64],
+            [1i64]
+        ],
+    )?;
+    // no_keepdims
+    test(
+        &[
+            [2u32, 1u32],
+            [3u32, 10u32]
+        ],
+        None,
+        Some(0),
+        None,
+        &[1i64, 1i64],
+    )?;
+    // tests from https://pytorch.org/docs/stable/generated/torch.argmax.html
+    test(
+        &[
+            [1.3398, 0.2663, -0.2686, 0.2450],
+            [-0.7401, -0.8805, -0.3402, -1.1936],
+            [0.4907, -1.3948, -1.0691, -0.3132],
+            [-1.6092, 0.5419, -0.2993, 0.3195]
+        ],
+        Some(1),
+        Some(0),
+        None,
+        &[0i64, 2i64, 0i64, 1i64],
+    )?;
+    test(
+        &[
+            [1.3398, 0.2663, -0.2686, 0.2450],
+            [-0.7401, -0.8805, -0.3402, -1.1936],
+            [0.4907, -1.3948, -1.0691, -0.3132],
+            [-1.6092, 0.5419, -0.2993, 0.3195]
+        ],
+        Some(1),
+        None,
+        None,
+        &[[0i64], [2i64], [0i64], [1i64]],
+    )?;
+    fn test(data: impl NdArray, axis: Option<i64>, keepdims: Option<i64>, select_last_index: Option<i64>, expected: impl NdArray) -> Result<()> {
+        let att_axis = AttributeProto {
+            name: "axis".to_string(),
+            ref_attr_name: "axis".to_string(),
+            i: axis.unwrap_or(0),
+            doc_string: "axis".to_string(),
+            r#type: 2, // INT
+            f: 0.0,
+            s: vec![],
+            t: None,
+            g: None,
+            sparse_tensor: None,
+            tp: None,
+            floats: vec![],
+            ints: vec![],
+            strings: vec![],
+            tensors: vec![],
+            graphs: vec![],
+            sparse_tensors: vec![],
+            type_protos: vec![],
+        };
+        let att_keepdims = AttributeProto {
+            name: "keepdims".to_string(),
+            ref_attr_name: "keepdims".to_string(),
+            i: keepdims.unwrap_or(1),
+            doc_string: "keepdims".to_string(),
+            r#type: 2, // INT
+            f: 0.0,
+            s: vec![],
+            t: None,
+            g: None,
+            sparse_tensor: None,
+            tp: None,
+            floats: vec![],
+            ints: vec![],
+            strings: vec![],
+            tensors: vec![],
+            graphs: vec![],
+            sparse_tensors: vec![],
+            type_protos: vec![],
+        };
+        let att_select_last_index = AttributeProto {
+ name: "select_last_index".to_string(), + ref_attr_name: "select_last_index".to_string(), + i: select_last_index.unwrap_or(0), + doc_string: "select_last_index".to_string(), + r#type: 2, // INT + f: 0.0, + s: vec![], + t: None, + g: None, + sparse_tensor: None, + tp: None, + floats: vec![], + ints: vec![], + strings: vec![], + tensors: vec![], + graphs: vec![], + sparse_tensors: vec![], + type_protos: vec![], + }; + let attrs = { + let mut mut_attrs = vec![]; + if axis.is_some() { + mut_attrs.push(att_axis); + } + if keepdims.is_some() { + mut_attrs.push(att_keepdims); + } + if select_last_index.is_some() { + mut_attrs.push(att_select_last_index); + } + mut_attrs + }; + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "ArgMax".to_string(), + domain: "".to_string(), + attribute: attrs, + input: vec![INPUT_X.to_string()], + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![], + doc_string: "".to_string(), + sparse_initializer: vec![], + quantization_annotation: vec![], + })); + let mut inputs: HashMap = HashMap::new(); + inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?); + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); + + let expected = Tensor::new(expected, &Device::Cpu)?; + match expected.dims().len() { + 1 => assert_eq!(z.to_vec1::()?, expected.to_vec1::()?), + 2 => assert_eq!(z.to_vec2::()?, expected.to_vec2::()?), + _ => unreachable!(), + }; + + Ok(()) + } + + Ok(()) +} + +// "LeakyRelu" +#[test] +fn test_leakyrelu() -> Result<()> { + // tests from https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-80 + // leakyrelu + test( + &[-1.0, 0.0, 1.0], + Some(0.1), + &[-0.1, 0.0, 1.0] + )?; + fn test(data: impl NdArray, alpha: Option, expected: impl NdArray) -> Result<()> { + let att_alpha = AttributeProto { + name: "alpha".to_string(), + ref_attr_name: "alpha".to_string(), + i: 0, + doc_string: "alpha".to_string(), + r#type: 1, // FLOAT + f: alpha.unwrap_or(0.01), + s: vec![], + t: None, + g: None, + sparse_tensor: None, + tp: None, + floats: vec![], + ints: vec![], + strings: vec![], + tensors: vec![], + graphs: vec![], + sparse_tensors: vec![], + type_protos: vec![], + }; + let attrs = { + let mut mut_attrs = vec![]; + if alpha.is_some() { + mut_attrs.push(att_alpha); + } + mut_attrs + }; + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "LeakyRelu".to_string(), + domain: "".to_string(), + attribute: attrs, + input: vec![INPUT_X.to_string()], + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![], + doc_string: "".to_string(), + sparse_initializer: vec![], + quantization_annotation: vec![], + })); + let mut inputs: HashMap = HashMap::new(); + inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?); + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); + + let expected = Tensor::new(expected, &Device::Cpu)?; + for both 
+            let (act, exp) = both;
+            assert!(f64::abs(act - exp) < f32::EPSILON.into());
+        }
+
+        Ok(())
+    }
+
+    Ok(())
+}
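
Note (not part of the patch): the only behavioural branch in the new ArgMin/ArgMax arms is the keepdims handling; the rest is attribute parsing and validation. Below is a minimal standalone sketch of that distinction using the same candle tensor ops the eval code calls (argmin_keepdim vs argmin, followed by the I64 cast), run on the 2x2 example from the tests. It assumes candle-core is available under the `candle` alias that candle-onnx itself uses; it is an illustration, not code from the PR.

use candle::{DType, Device, Tensor};

fn main() -> candle::Result<()> {
    // Same 2x2 input as the ONNX operator examples in the tests above.
    let x = Tensor::new(&[[2u32, 1u32], [3u32, 10u32]], &Device::Cpu)?;

    // keepdims = 1 (the ONNX default): the reduced axis is kept with size 1.
    let kept = x.argmin_keepdim(1)?;
    assert_eq!(kept.dims(), &[2, 1]); // values: [[1], [0]]

    // keepdims = 0: the reduced axis is dropped.
    let dropped = x.argmin(1)?;
    assert_eq!(dropped.dims(), &[2]); // values: [1, 0]

    // The eval code additionally casts the U32 indices to I64, as ONNX requires.
    let as_i64 = dropped.to_dtype(DType::I64)?;
    println!("{kept}\n{dropped}\n{as_i64}");
    Ok(())
}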