Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ Bottom level categories:

- The validator checks that override-sized arrays have a positive size, if overrides have been resolved. By @andyleiserson in [#8822](https://github.com/gfx-rs/wgpu/pull/8822).
- Fix some cases where f16 constants were not working. By @andyleiserson in [#8816](https://github.com/gfx-rs/wgpu/pull/8816).
- Naga now detects bitwise shifts by a constant exceeding the operand bit width at compile time, and disallows scalar-by-vector and vector-by-scalar shifts in constant evaluation. By @andyleiserson in [#8907](https://github.com/gfx-rs/wgpu/pull/8907).

#### Naga

Expand Down
2 changes: 1 addition & 1 deletion cts_runner/fail.lst
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ webgpu:shader,validation,expression,access,matrix:* // 90%
webgpu:shader,validation,expression,access,vector:* // 52%
webgpu:shader,validation,expression,binary,add_sub_mul:* // 95%
webgpu:shader,validation,expression,binary,and_or_xor:* // 96%
webgpu:shader,validation,expression,binary,bitwise_shift:* // 97%
webgpu:shader,validation,expression,binary,bitwise_shift:invalid_types:* // 93%, https://github.com/gfx-rs/wgpu/issues/5474
webgpu:shader,validation,expression,binary,comparison:* // 74%
webgpu:shader,validation,expression,binary,div_rem:* // 86%
webgpu:shader,validation,expression,binary,short_circuiting_and_or:* // 92%, https://github.com/gfx-rs/wgpu/issues/8440
Expand Down
6 changes: 6 additions & 0 deletions cts_runner/test.lst
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,12 @@ webgpu:shader,validation,expression,access,array:early_eval_errors:case="overrid
webgpu:shader,validation,expression,access,array:early_eval_errors:case="override_array_cnt_size_zero_signed"
webgpu:shader,validation,expression,access,array:early_eval_errors:case="override_array_cnt_size_zero_unsigned"
webgpu:shader,validation,expression,access,array:early_eval_errors:case="override_in_bounds"
webgpu:shader,validation,expression,binary,bitwise_shift:partial_eval_errors:*
webgpu:shader,validation,expression,binary,bitwise_shift:scalar_vector:*
webgpu:shader,validation,expression,binary,bitwise_shift:shift_left_abstract:*
webgpu:shader,validation,expression,binary,bitwise_shift:shift_left_concrete:*
webgpu:shader,validation,expression,binary,bitwise_shift:shift_right_abstract:*
webgpu:shader,validation,expression,binary,bitwise_shift:shift_right_concrete:*
webgpu:shader,validation,expression,binary,short_circuiting_and_or:array_override:op="%26%26";a_val=1;b_val=1
webgpu:shader,validation,expression,binary,short_circuiting_and_or:invalid_types:*
webgpu:shader,validation,expression,binary,short_circuiting_and_or:scalar_vector:op="%26%26";lhs="bool";rhs="bool"
Expand Down
121 changes: 113 additions & 8 deletions naga/src/proc/constant_evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1137,6 +1137,84 @@ impl<'a> ConstantEvaluator<'a> {
}
}

/// Return an error if a constant shift amount in `right` exceeds the scalar
/// bit width of `left`.
///
/// This function promises to return an error in cases where (1) the
/// expression is well-typed, (2) the type of `lhs` is `i32` or `u32`, and
/// (3) the shift will overflow. It does not return an error if there are
/// abstract integers on the LHS, because in that case the expression is an
/// ordinary const-expression and will be checked during constant evaluation
/// against slightly different rules. It also does not return an error in
/// cases where the expression is not well-typed (e.g. vector dimension
/// mismatch), because those will be rejected elsewhere.
fn validate_constant_shift_amounts(
&mut self,
left: Handle<Expression>,
right: Handle<Expression>,
) -> Result<(), ConstantEvaluatorError> {
fn is_overflowing_shift<'a>(
this: &mut ConstantEvaluator<'a>,
left: Handle<Expression>,
right: Handle<Expression>,
) -> bool {
if this
.expression_kind_tracker
.type_of_with_expr(&this.expressions[left])
== ExpressionKind::Const
{
// If the LHS is const, rely on constant evaluation to detect overflow.
return false;
}
let Ok(right_ty) = this.resolve_type(right) else {
return false;
};
match right_ty.inner_with(this.types).vector_size_and_scalar() {
Some((None, crate::Scalar::I32 | crate::Scalar::U32)) => {
let shift_amount = this
.to_ctx()
.get_const_val_from::<u32, _>(right, this.expressions);
shift_amount.ok().is_some_and(|s| s >= 32)
}
Some((Some(size), crate::Scalar::I32 | crate::Scalar::U32)) => {
match this.expressions[right] {
Expression::ZeroValue(_) => false, // zero shift does not overflow
Expression::Splat { value, .. } => this
.to_ctx()
.get_const_val_from::<u32, _>(value, this.expressions)
.ok()
.is_some_and(|s| s >= 32),
Expression::Compose {
ty: _,
ref components,
} => {
let len = components.len();
match size {
crate::VectorSize::Bi if len != 2 => false,
crate::VectorSize::Tri if len != 3 => false,
crate::VectorSize::Quad if len != 4 => false,
_ => components.iter().any(|comp| {
this.to_ctx()
.get_const_val_from::<u32, _>(*comp, this.expressions)
.ok()
.is_some_and(|s| s >= 32)
}),
}
}
_ => false,
}
}
_ => false,
}
}

if is_overflowing_shift(self, left, right) {
Err(ConstantEvaluatorError::ShiftedMoreThan32Bits)
} else {
Ok(())
}
}

/// Try to evaluate `expr` at compile time.
///
/// The `expr` argument can be any sort of Naga [`Expression`] you like. If
Expand Down Expand Up @@ -1165,6 +1243,19 @@ impl<'a> ConstantEvaluator<'a> {
expr: Expression,
span: Span,
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
// Most binary ops are runtime expressions if either of their operands
// is a runtime expression. Shifts are an exception: a shift `runtime <<
// constant` with a constant that is out of range for the data type can
// be detected during constant evaluation.
if let &Expression::Binary {
op: BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight,
left,
right,
} = &expr
{
self.validate_constant_shift_amounts(left, right)?;
}

match self.expression_kind_tracker.type_of_with_expr(&expr) {
ExpressionKind::Const => {
let eval_result = self.try_eval_and_append_impl(&expr, span);
Expand Down Expand Up @@ -2740,11 +2831,18 @@ impl<'a> ConstantEvaluator<'a> {
},
&Expression::Literal(_),
) => {
let mut components = src_components.clone();
for component in &mut components {
*component = self.binary_op(op, *component, right, span)?;
match op {
BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight => {
return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
}
_ => {
let mut components = src_components.clone();
for component in &mut components {
*component = self.binary_op(op, *component, right, span)?;
}
Expression::Compose { ty, components }
}
}
Expression::Compose { ty, components }
}
(
&Expression::Literal(_),
Expand All @@ -2753,11 +2851,18 @@ impl<'a> ConstantEvaluator<'a> {
ty,
},
) => {
let mut components = src_components.clone();
for component in &mut components {
*component = self.binary_op(op, left, *component, span)?;
match op {
BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight => {
return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
}
_ => {
let mut components = src_components.clone();
for component in &mut components {
*component = self.binary_op(op, left, *component, span)?;
}
Expression::Compose { ty, components }
}
}
Expression::Compose { ty, components }
}
(
&Expression::Compose {
Expand Down
7 changes: 6 additions & 1 deletion naga/src/proc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,12 @@ pub struct GlobalCtx<'a> {
}

impl GlobalCtx<'_> {
/// Try to evaluate the expression in `self.global_expressions` using its `handle` and return it as a `u32`.
/// Try to evaluate the expression in `self.global_expressions` using its
/// `handle` and return it as a `T: TryFrom<ir::Literal>`.
///
/// This currently only evaluates scalar expressions. If adding support for
/// vectors, consider changing [`constant_evaluator::validate_const_shift_amounts`]
/// to use that support.
#[cfg_attr(
not(any(
feature = "glsl-in",
Expand Down
46 changes: 23 additions & 23 deletions naga/src/valid/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ impl super::Validator {
| Ti::ValuePointer { size: Some(_), .. }
| Ti::BindingArray { .. } => {}
ref other => {
log::error!("Indexing of {other:?}");
log::debug!("Indexing of {other:?}");
return Err(ExpressionError::InvalidBaseType(base));
}
};
Expand All @@ -284,7 +284,7 @@ impl super::Validator {
..
}) => {}
ref other => {
log::error!("Indexing by {other:?}");
log::debug!("Indexing by {other:?}");
return Err(ExpressionError::InvalidIndexType(index));
}
}
Expand Down Expand Up @@ -342,7 +342,7 @@ impl super::Validator {
}
Ti::Struct { ref members, .. } => members.len() as u32,
ref other => {
log::error!("Indexing of {other:?}");
log::debug!("Indexing of {other:?}");
return Err(ExpressionError::InvalidBaseType(top));
}
};
Expand All @@ -358,7 +358,7 @@ impl super::Validator {
E::Splat { size: _, value } => match resolver[value] {
Ti::Scalar { .. } => ShaderStages::all(),
ref other => {
log::error!("Splat scalar type {other:?}");
log::debug!("Splat scalar type {other:?}");
return Err(ExpressionError::InvalidSplatType(value));
}
},
Expand All @@ -370,7 +370,7 @@ impl super::Validator {
let vec_size = match resolver[vector] {
Ti::Vector { size: vec_size, .. } => vec_size,
ref other => {
log::error!("Swizzle vector type {other:?}");
log::debug!("Swizzle vector type {other:?}");
return Err(ExpressionError::InvalidVectorType(vector));
}
};
Expand Down Expand Up @@ -414,7 +414,7 @@ impl super::Validator {
.contains(TypeFlags::SIZED | TypeFlags::DATA) => {}
Ti::ValuePointer { .. } => {}
ref other => {
log::error!("Loading {other:?}");
log::debug!("Loading {other:?}");
return Err(ExpressionError::InvalidPointerType(pointer));
}
}
Expand Down Expand Up @@ -786,7 +786,7 @@ impl super::Validator {
| (Uo::LogicalNot, Some(Sk::Bool))
| (Uo::BitwiseNot, Some(Sk::Sint | Sk::Uint)) => {}
other => {
log::error!("Op {op:?} kind {other:?}");
log::debug!("Op {op:?} kind {other:?}");
return Err(ExpressionError::InvalidUnaryOperandType(op, expr));
}
}
Expand Down Expand Up @@ -903,7 +903,7 @@ impl super::Validator {
Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
},
ref other => {
log::error!("Op {op:?} left type {other:?}");
log::debug!("Op {op:?} left type {other:?}");
false
}
}
Expand All @@ -915,7 +915,7 @@ impl super::Validator {
..
} => left_inner == right_inner,
ref other => {
log::error!("Op {op:?} left type {other:?}");
log::debug!("Op {op:?} left type {other:?}");
false
}
},
Expand All @@ -925,7 +925,7 @@ impl super::Validator {
Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false,
},
ref other => {
log::error!("Op {op:?} left type {other:?}");
log::debug!("Op {op:?} left type {other:?}");
false
}
},
Expand All @@ -935,7 +935,7 @@ impl super::Validator {
Sk::Bool | Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false,
},
ref other => {
log::error!("Op {op:?} left type {other:?}");
log::debug!("Op {op:?} left type {other:?}");
false
}
},
Expand All @@ -944,7 +944,7 @@ impl super::Validator {
Ti::Scalar(scalar) => (Ok(None), scalar),
Ti::Vector { size, scalar } => (Ok(Some(size)), scalar),
ref other => {
log::error!("Op {op:?} base type {other:?}");
log::debug!("Op {op:?} base type {other:?}");
(Err(()), Sc::BOOL)
}
};
Expand All @@ -955,7 +955,7 @@ impl super::Validator {
scalar: Sc { kind: Sk::Uint, .. },
} => Ok(Some(size)),
ref other => {
log::error!("Op {op:?} shift type {other:?}");
log::debug!("Op {op:?} shift type {other:?}");
Err(())
}
};
Expand All @@ -966,12 +966,12 @@ impl super::Validator {
}
};
if !good {
log::error!(
log::debug!(
"Left: {:?} of type {:?}",
function.expressions[left],
left_inner
);
log::error!(
log::debug!(
"Right: {:?} of type {:?}",
function.expressions[right],
right_inner
Expand Down Expand Up @@ -1060,15 +1060,15 @@ impl super::Validator {
..
} => {}
ref other => {
log::error!("All/Any of type {other:?}");
log::debug!("All/Any of type {other:?}");
return Err(ExpressionError::InvalidBooleanVector(argument));
}
},
Rf::IsNan | Rf::IsInf => match *argument_inner {
Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
if scalar.kind == Sk::Float => {}
ref other => {
log::error!("Float test of type {other:?}");
log::debug!("Float test of type {other:?}");
return Err(ExpressionError::InvalidFloatArgument(argument));
}
},
Expand Down Expand Up @@ -1206,7 +1206,7 @@ impl super::Validator {
}
}
ref other => {
log::error!("Array length of {other:?}");
log::debug!("Array length of {other:?}");
return Err(ExpressionError::InvalidArrayType(expr));
}
},
Expand All @@ -1221,12 +1221,12 @@ impl super::Validator {
} => match resolver.types[base].inner {
Ti::RayQuery { .. } => ShaderStages::all(),
ref other => {
log::error!("Intersection result of a pointer to {other:?}");
log::debug!("Intersection result of a pointer to {other:?}");
return Err(ExpressionError::InvalidRayQueryType(query));
}
},
ref other => {
log::error!("Intersection result of {other:?}");
log::debug!("Intersection result of {other:?}");
return Err(ExpressionError::InvalidRayQueryType(query));
}
},
Expand All @@ -1242,12 +1242,12 @@ impl super::Validator {
vertex_return: true,
} => ShaderStages::all(),
ref other => {
log::error!("Intersection result of a pointer to {other:?}");
log::debug!("Intersection result of a pointer to {other:?}");
return Err(ExpressionError::InvalidRayQueryType(query));
}
},
ref other => {
log::error!("Intersection result of {other:?}");
log::debug!("Intersection result of {other:?}");
return Err(ExpressionError::InvalidRayQueryType(query));
}
},
Expand All @@ -1272,7 +1272,7 @@ impl super::Validator {
match resolver[operand] {
Ti::CooperativeMatrix { role, .. } if role == expected_role => {}
ref other => {
log::error!("{expected_role:?} operand type: {other:?}");
log::debug!("{expected_role:?} operand type: {other:?}");
return Err(ExpressionError::InvalidCooperativeOperand(a));
}
}
Expand Down
Loading