Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ Bottom level categories:

### Changes

#### General

- Naga and `wgpu` now reject shaders with an `enable` directive for functionality that is not available, even if that functionality is not used by the shader. By @andyleiserson in [#8913](https://github.com/gfx-rs/wgpu/pull/8913).

#### Naga

- Prevent UB from incorrectly using ray queries on HLSL. By @Vecvec in [#8763](https://github.com/gfx-rs/wgpu/pull/8763).
Expand Down
50 changes: 48 additions & 2 deletions naga-cli/src/bin/naga.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,14 @@ struct Args {
/// defines to be passed to the parser (only glsl is supported)
#[argh(option, short = 'D')]
defines: Vec<Defines>,

/// capabilities for parsing and validation.
///
/// Can be a comma-separated list of capability names (e.g.,
/// "shader_float16,dual_source_blending"), a numeric bitflags value (e.g.,
/// "67108864"), the string "none", or the string "all".
#[argh(option, default = "CapabilitiesArg(naga::valid::Capabilities::all())")]
capabilities: CapabilitiesArg,
}

/// Newtype so we can implement [`FromStr`] for `BoundsCheckPolicy`.
Expand Down Expand Up @@ -336,6 +344,37 @@ impl FromStr for Defines {
}
}

#[derive(Debug, Clone, Copy)]
struct CapabilitiesArg(naga::valid::Capabilities);

impl FromStr for CapabilitiesArg {
type Err = String;

fn from_str(s: &str) -> Result<Self, Self::Err> {
use naga::valid::Capabilities;

let s = s.to_uppercase();

if s == "NONE" {
Ok(Self(Capabilities::empty()))
} else if s == "ALL" {
Ok(Self(Capabilities::all()))
} else if let Ok(bits) = s.parse::<u64>() {
Capabilities::from_bits(bits)
.map(Self)
.ok_or_else(|| format!("Invalid capabilities bitflags value: {bits}"))
} else {
s.split(',')
.try_fold(Capabilities::empty(), |acc, s| {
Capabilities::from_name(s.trim())
.map(|cap| acc | cap)
.ok_or(format!("Unknown capability {}", s.trim()))
})
.map(Self)
}
}
}

#[derive(Default)]
struct Parameters<'a> {
validation_flags: naga::valid::ValidationFlags,
Expand All @@ -352,6 +391,7 @@ struct Parameters<'a> {
input_kind: Option<InputKind>,
shader_stage: Option<ShaderStage>,
defines: FastHashMap<String, String>,
capabilities: naga::valid::Capabilities,

/// We use this copy of `args.compact` to know whether we should pass the
/// entrypoint to `process_overrides`, which will result in removal from
Expand Down Expand Up @@ -505,6 +545,7 @@ fn run() -> anyhow::Result<()> {
);

params.compact = args.compact;
params.capabilities = args.capabilities.0;

if args.bulk_validate {
return bulk_validate(&args, &params);
Expand Down Expand Up @@ -682,7 +723,12 @@ fn parse_input(input_path: &Path, input: Vec<u8>, params: &Parameters) -> anyhow
},
InputKind::Wgsl => {
let input = String::from_utf8(input)?;
let result = naga::front::wgsl::parse_str(&input);
let options = naga::front::wgsl::Options {
parse_doc_comments: false,
capabilities: params.capabilities,
};
let mut frontend = naga::front::wgsl::Frontend::new_with_options(options);
let result = frontend.parse(&input);
match result {
Ok(v) => Parsed {
module: v,
Expand Down Expand Up @@ -961,7 +1007,7 @@ fn bulk_validate(args: &Args, params: &Parameters) -> anyhow::Result<()> {
};

let mut validator =
naga::valid::Validator::new(params.validation_flags, naga::valid::Capabilities::all());
naga::valid::Validator::new(params.validation_flags, params.capabilities);
validator.subgroup_stages(naga::valid::ShaderStages::all());
validator.subgroup_operations(naga::valid::SubgroupOperationSet::all());

Expand Down
1 change: 1 addition & 0 deletions naga-test/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ impl From<&WgslInParameters> for naga::front::wgsl::Options {
fn from(value: &WgslInParameters) -> Self {
Self {
parse_doc_comments: value.parse_doc_comments,
capabilities: naga::valid::Capabilities::all(),
}
}
}
Expand Down
22 changes: 9 additions & 13 deletions naga/src/compact/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,19 +157,15 @@ pub fn compact(module: &mut crate::Module, keep_unused: KeepUnused) {
}
}
if e.stage == crate::ShaderStage::Task || e.stage == crate::ShaderStage::Mesh {
// u32 should always be there if the module is valid, as it is e.g. the type of some expressions
let u32_type = module
.types
.iter()
.find_map(|tuple| {
if tuple.1.inner == crate::TypeInner::Scalar(crate::Scalar::U32) {
Some(tuple.0)
} else {
None
}
})
.unwrap();
module_tracer.types_used.insert(u32_type);
// Mesh shaders always need a u32 type, as it is e.g. the type of some
// expressions. We tolerate its absence here because compaction is
// infallible, but the module will fail validation.
if let Some(u32_type) = module.types.iter().find_map(|tuple| {
(tuple.1.inner == crate::TypeInner::Scalar(crate::Scalar::U32))
.then_some(tuple.0)
}) {
module_tracer.types_used.insert(u32_type);
}
}

let mut used = module_tracer.as_function(&e.function);
Expand Down
15 changes: 15 additions & 0 deletions naga/src/front/wgsl/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,10 @@ pub(crate) enum Error<'a> {
kind: EnableExtension,
span: Span,
},
EnableExtensionNotSupported {
kind: EnableExtension,
span: Span,
},
LanguageExtensionNotYetImplemented {
kind: UnimplementedLanguageExtension,
span: Span,
Expand Down Expand Up @@ -1240,6 +1244,17 @@ impl<'a> Error<'a> {
]
},
},
Error::EnableExtensionNotSupported { kind, span } => ParseError {
message: format!(
"the `{}` extension is not supported in the current environment",
kind.to_ident()
),
labels: vec![(
span,
"unsupported enable-extension".into(),
)],
notes: vec![],
},
Error::LanguageExtensionNotYetImplemented { kind, span } => ParseError {
message: format!(
"the `{}` language extension is not yet supported",
Expand Down
16 changes: 16 additions & 0 deletions naga/src/front/wgsl/parse/directive/enable_extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,22 @@ pub enum ImplementedEnableExtension {
WgpuCooperativeMatrix,
}

impl ImplementedEnableExtension {
/// Returns the capability required for this enable extension.
pub const fn capability(self) -> crate::valid::Capabilities {
use crate::valid::Capabilities as C;
match self {
Self::F16 => C::SHADER_FLOAT16,
Self::DualSourceBlending => C::DUAL_SOURCE_BLENDING,
Self::ClipDistances => C::CLIP_DISTANCE,
Self::WgpuMeshShader => C::MESH_SHADER,
Self::WgpuRayQuery => C::RAY_QUERY,
Self::WgpuRayQueryVertexReturn => C::RAY_HIT_VERTEX_POSITION,
Self::WgpuCooperativeMatrix => C::COOPERATIVE_MATRIX,
}
}
}

/// A variant of [`EnableExtension::Unimplemented`].
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)]
pub enum UnimplementedEnableExtension {
Expand Down
13 changes: 12 additions & 1 deletion naga/src/front/wgsl/parse/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,13 +287,16 @@ impl<'a> BindingParser<'a> {
pub struct Options {
/// Controls whether the parser should parse doc comments.
pub parse_doc_comments: bool,
/// Capabilities to enable during parsing.
pub capabilities: crate::valid::Capabilities,
}

impl Options {
/// Creates a new [`Options`] without doc comments parsing.
/// Creates a new default [`Options`].
pub const fn new() -> Self {
Options {
parse_doc_comments: false,
capabilities: crate::valid::Capabilities::all(),
}
}
}
Expand Down Expand Up @@ -3280,6 +3283,14 @@ impl Parser {
}))
}
};
// Check if the required capability is supported
let required_capability = extension.capability();
if !options.capabilities.contains(required_capability) {
return Err(Box::new(Error::EnableExtensionNotSupported {
kind,
span,
}));
}
enable_extensions.add(extension);
Ok(())
})?;
Expand Down
111 changes: 108 additions & 3 deletions naga/tests/naga/validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,63 @@
)]

use naga::{
ir,
valid::{self, ModuleInfo},
Expression, Function, Module, Scalar,
ir::{self, Expression, Function, Module, Scalar},
valid::{self, Capabilities, ModuleInfo, ValidationFlags},
};

fn test_span_generator() -> impl FnMut() -> naga::Span {
let mut index = 0;
move || {
let span = naga::Span::new(index, index + 1);
index += 1;
span
}
}

#[track_caller]
fn expect_validation_error_impl<I: IntoIterator<Item = naga::Span>>(
module: &Module,
validation_flags: valid::ValidationFlags,
capabilities: valid::Capabilities,
spans: Option<I>,
) -> naga::valid::ValidationError {
let err = valid::Validator::new(validation_flags, capabilities)
.validate(module)
.expect_err("module should be invalid");

if let Some(expected_spans_iter) = spans {
let actual_spans = err.spans().map(|sctx| sctx.0).collect::<Vec<_>>();
let expected_spans = expected_spans_iter.into_iter().collect::<Vec<_>>();
assert_eq!(
actual_spans, expected_spans,
"expected error spans to be {expected_spans:?}, got {actual_spans:?}",
);
}

err.into_inner()
}

/// Validate `module` with the given `validation_flags` and `capabilities`.
///
/// Panics if validation succeeds or fails with an error not associated with
/// `span`. Otherwise, returns the validation error.
///
/// Note that only the span is checked, not the associated context string.
#[track_caller]
fn expect_validation_error_with_span(
module: &Module,
validation_flags: valid::ValidationFlags,
capabilities: valid::Capabilities,
span: naga::Span,
) -> naga::valid::ValidationError {
expect_validation_error_impl(
module,
validation_flags,
capabilities,
Some(core::iter::once(span)),
)
}

/// Validation should fail if `AtomicResult` expressions are not
/// populated by `Atomic` statements.
#[test]
Expand Down Expand Up @@ -1184,3 +1236,56 @@ fn main() {
"#,
);
}

#[test]
fn unexpected_task_payload() {
let mut make_test_span = test_span_generator();
let mut module = Module::default();

let ty_payload = module.types.insert(
ir::Type {
name: Some("u32".into()),
inner: ir::TypeInner::Scalar(naga::Scalar::U32),
},
make_test_span(),
);

let err_span = make_test_span();
let payload_handle = module.global_variables.append(
ir::GlobalVariable {
name: Some("task_payload".into()),
space: ir::AddressSpace::TaskPayload,
binding: None,
ty: ty_payload,
init: None,
},
err_span,
);

let entry_point = ir::EntryPoint {
name: "main".into(),
stage: ir::ShaderStage::Compute,
early_depth_test: None,
workgroup_size: [1, 1, 1],
workgroup_size_overrides: None,
function: ir::Function::default(),
mesh_info: None,
task_payload: Some(payload_handle), // invalid for compute stage
};
module.entry_points.push(entry_point);

let err = expect_validation_error_with_span(
&module,
ValidationFlags::default(),
Capabilities::MESH_SHADER,
err_span,
);

assert!(matches!(
err,
valid::ValidationError::EntryPoint {
source: valid::EntryPointError::UnexpectedTaskPayload,
..
}
));
}
Loading
Loading