Skip to content

Commit

Permalink
[const-eval] implement pow & clamp built-in functions properly
Browse files Browse the repository at this point in the history
  • Loading branch information
teoxoy committed Oct 12, 2023
1 parent c46a69d commit 4945b7a
Showing 1 changed file with 168 additions and 107 deletions.
275 changes: 168 additions & 107 deletions src/proc/constant_evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ pub enum ConstantEvaluatorError {
InvalidMathArg,
#[error("{0:?} built-in function expects {1:?} arguments but {2:?} were supplied")]
InvalidMathArgCount(crate::MathFunction, usize, usize),
#[error("value of `low` is greater than `high` for clamp built-in function")]
InvalidClamp,
#[error("Splat is defined only on scalar values")]
SplatScalarOnly,
#[error("Can only swizzle vector constants")]
Expand Down Expand Up @@ -501,60 +503,181 @@ impl<'a> ConstantEvaluator<'a> {
));
}

let const0 = &self.expressions[arg];
let const1 = arg1.map(|arg| &self.expressions[arg]);
let const2 = arg2.map(|arg| &self.expressions[arg]);
let _const3 = arg3.map(|arg| &self.expressions[arg]);

match fun {
crate::MathFunction::Pow => {
let literal = match (const0, const1.unwrap()) {
(&Expression::Literal(value0), &Expression::Literal(value1)) => {
match (value0, value1) {
(Literal::I32(a), Literal::I32(b)) => Literal::I32(a.pow(b as u32)),
(Literal::U32(a), Literal::U32(b)) => Literal::U32(a.pow(b)),
(Literal::F32(a), Literal::F32(b)) => Literal::F32(a.powf(b)),
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
}
crate::MathFunction::Pow => self.math_pow(arg, arg1.unwrap(), span),
crate::MathFunction::Clamp => self.math_clamp(arg, arg1.unwrap(), arg2.unwrap(), span),
fun => Err(ConstantEvaluatorError::NotImplemented(format!(
"{fun:?} built-in function"
))),
}
}

fn math_pow(
&mut self,
e1: Handle<Expression>,
e2: Handle<Expression>,
span: Span,
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
let e1 = self.eval_zero_value_and_splat(e1, span)?;
let e2 = self.eval_zero_value_and_splat(e2, span)?;

let expr = match (&self.expressions[e1], &self.expressions[e2]) {
(&Expression::Literal(Literal::F32(a)), &Expression::Literal(Literal::F32(b))) => {
Expression::Literal(Literal::F32(a.powf(b)))
}
(
&Expression::Compose {
components: ref src_components0,
ty: ty0,
},
&Expression::Compose {
components: ref src_components1,
ty: ty1,
},
) if ty0 == ty1
&& matches!(
self.types[ty0].inner,
crate::TypeInner::Vector {
kind: crate::ScalarKind::Float,
..
}
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
};
) =>
{
let mut components: Vec<_> = crate::proc::flatten_compose(
ty0,
src_components0,
self.expressions,
self.types,
)
.chain(crate::proc::flatten_compose(
ty1,
src_components1,
self.expressions,
self.types,
))
.collect();

let mid = components.len() / 2;
let (first, last) = components.split_at_mut(mid);
for (a, b) in first.iter_mut().zip(&*last) {
*a = self.math_pow(*a, *b, span)?;
}
components.truncate(mid);

let expr = Expression::Literal(literal);
Ok(self.register_evaluated_expr(expr, span))
Expression::Compose {
ty: ty0,
components,
}
}
crate::MathFunction::Clamp => {
let literal = match (const0, const1.unwrap(), const2.unwrap()) {
(
&Expression::Literal(value0),
&Expression::Literal(value1),
&Expression::Literal(value2),
) => match (value0, value1, value2) {
(Literal::I32(a), Literal::I32(b), Literal::I32(c)) => {
Literal::I32(a.clamp(b, c))
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
};

Ok(self.register_evaluated_expr(expr, span))
}

fn math_clamp(
&mut self,
e: Handle<Expression>,
low: Handle<Expression>,
high: Handle<Expression>,
span: Span,
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
let e = self.eval_zero_value_and_splat(e, span)?;
let low = self.eval_zero_value_and_splat(low, span)?;
let high = self.eval_zero_value_and_splat(high, span)?;

let expr = match (
&self.expressions[e],
&self.expressions[low],
&self.expressions[high],
) {
(&Expression::Literal(e), &Expression::Literal(low), &Expression::Literal(high)) => {
let literal = match (e, low, high) {
(Literal::I32(e), Literal::I32(low), Literal::I32(high)) => {
if low > high {
return Err(ConstantEvaluatorError::InvalidClamp);
} else {
Literal::I32(e.clamp(low, high))
}
(Literal::U32(a), Literal::U32(b), Literal::U32(c)) => {
Literal::U32(a.clamp(b, c))
}
(Literal::U32(e), Literal::U32(low), Literal::U32(high)) => {
if low > high {
return Err(ConstantEvaluatorError::InvalidClamp);
} else {
Literal::U32(e.clamp(low, high))
}
(Literal::F32(a), Literal::F32(b), Literal::F32(c)) => {
Literal::F32(glsl_float_clamp(a, b, c))
}
(Literal::F32(e), Literal::F32(low), Literal::F32(high)) => {
if low > high {
return Err(ConstantEvaluatorError::InvalidClamp);
} else {
Literal::F32(e.clamp(low, high))
}
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
},
_ => {
return Err(ConstantEvaluatorError::NotImplemented(
"clamp built-in function with vector values".into(),
))
}
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
};
Expression::Literal(literal)
}
(
&Expression::Compose {
components: ref src_components0,
ty: ty0,
},
&Expression::Compose {
components: ref src_components1,
ty: ty1,
},
&Expression::Compose {
components: ref src_components2,
ty: ty2,
},
) if ty0 == ty1
&& ty0 == ty2
&& matches!(
self.types[ty0].inner,
crate::TypeInner::Vector {
kind: crate::ScalarKind::Float,
..
}
) =>
{
let mut components: Vec<_> = crate::proc::flatten_compose(
ty0,
src_components0,
self.expressions,
self.types,
)
.chain(crate::proc::flatten_compose(
ty1,
src_components1,
self.expressions,
self.types,
))
.chain(crate::proc::flatten_compose(
ty2,
src_components2,
self.expressions,
self.types,
))
.collect();

let chunk_size = components.len() / 3;
let (es, rem) = components.split_at_mut(chunk_size);
let (lows, highs) = rem.split_at(chunk_size);
for ((e, low), high) in es.iter_mut().zip(lows).zip(highs) {
*e = self.math_clamp(*e, *low, *high, span)?;
}
components.truncate(chunk_size);

let expr = Expression::Literal(literal);
Ok(self.register_evaluated_expr(expr, span))
Expression::Compose {
ty: ty0,
components,
}
}
fun => Err(ConstantEvaluatorError::NotImplemented(format!(
"{fun:?} built-in function"
))),
}
_ => return Err(ConstantEvaluatorError::InvalidMathArg),
};

Ok(self.register_evaluated_expr(expr, span))
}

fn array_length(
Expand Down Expand Up @@ -1000,6 +1123,8 @@ impl<'a> ConstantEvaluator<'a> {
}

fn register_evaluated_expr(&mut self, expr: Expression, span: Span) -> Handle<Expression> {
// TODO: use the validate_literal function from https://github.com/gfx-rs/naga/pull/2508 here

if let Some(FunctionLocalData {
ref mut emitter,
ref mut block,
Expand All @@ -1026,57 +1151,6 @@ impl<'a> ConstantEvaluator<'a> {
}
}

/// Helper function to implement the GLSL `max` function for floats.
///
/// While Rust does provide a `f64::max` method, it has a different behavior than the
/// GLSL `max` for NaNs. In Rust, if any of the arguments is a NaN, then the other
/// is returned.
///
/// This leads to different results in the following example
/// ```
/// use std::cmp::max;
/// std::f64::NAN.max(1.0);
/// ```
///
/// Rust will return `1.0` while GLSL should return NaN.
fn glsl_float_max(x: f32, y: f32) -> f32 {
if x < y {
y
} else {
x
}
}

/// Helper function to implement the GLSL `min` function for floats.
///
/// While Rust does provide a `f64::min` method, it has a different behavior than the
/// GLSL `min` for NaNs. In Rust, if any of the arguments is a NaN, then the other
/// is returned.
///
/// This leads to different results in the following example
/// ```
/// use std::cmp::min;
/// std::f64::NAN.min(1.0);
/// ```
///
/// Rust will return `1.0` while GLSL should return NaN.
fn glsl_float_min(x: f32, y: f32) -> f32 {
if y < x {
y
} else {
x
}
}

/// Helper function to implement the GLSL `clamp` function for floats.
///
/// While Rust does provide a `f64::clamp` method, it panics if either
/// `min` or `max` are `NaN`s which is not the behavior specified by
/// the glsl specification.
fn glsl_float_clamp(value: f32, min: f32, max: f32) -> f32 {
glsl_float_min(glsl_float_max(value, min), max)
}

#[cfg(test)]
mod tests {
use std::vec;
Expand All @@ -1088,19 +1162,6 @@ mod tests {

use super::{Behavior, ConstantEvaluator};

#[test]
fn nan_handling() {
assert!(super::glsl_float_max(f32::NAN, 2.0).is_nan());
assert!(!super::glsl_float_max(2.0, f32::NAN).is_nan());

assert!(super::glsl_float_min(f32::NAN, 2.0).is_nan());
assert!(!super::glsl_float_min(2.0, f32::NAN).is_nan());

assert!(super::glsl_float_clamp(f32::NAN, 1.0, 2.0).is_nan());
assert!(!super::glsl_float_clamp(1.0, f32::NAN, 2.0).is_nan());
assert!(!super::glsl_float_clamp(1.0, 2.0, f32::NAN).is_nan());
}

#[test]
fn unary_op() {
let mut types = UniqueArena::new();
Expand Down

0 comments on commit 4945b7a

Please sign in to comment.