Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(compiler): apply limits to recursive functions in validation #748

Merged
merged 16 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/apollo-compiler/src/diagnostics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,8 @@ pub(crate) enum DiagnosticData {
/// Name of the argument where variable is used
arg_name: String,
},
#[error("too much recursion")]
RecursionError {},
}

impl ApolloDiagnostic {
Expand Down
70 changes: 47 additions & 23 deletions crates/apollo-compiler/src/validation/directive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use crate::validation::{NodeLocation, RecursionGuard, RecursionStack};
use crate::{ast, schema, Node, ValidationDatabase};
use std::collections::{HashMap, HashSet};

use super::CycleError;

/// This struct just groups functions that are used to find self-referential directives.
/// The way to use it is to call `FindRecursiveDirective::check`.
struct FindRecursiveDirective<'s> {
Expand All @@ -14,7 +16,7 @@ impl FindRecursiveDirective<'_> {
&self,
seen: &mut RecursionGuard<'_>,
def: &schema::ExtendedType,
) -> Result<(), Node<ast::Directive>> {
) -> Result<(), CycleError<ast::Directive>> {
match def {
schema::ExtendedType::Scalar(scalar_type_definition) => {
self.directives(seen, &scalar_type_definition.directives)?;
Expand Down Expand Up @@ -49,7 +51,7 @@ impl FindRecursiveDirective<'_> {
&self,
seen: &mut RecursionGuard<'_>,
input_value: &Node<ast::InputValueDefinition>,
) -> Result<(), Node<ast::Directive>> {
) -> Result<(), CycleError<ast::Directive>> {
for directive in &input_value.directives {
self.directive(seen, directive)?;
}
Expand All @@ -66,7 +68,7 @@ impl FindRecursiveDirective<'_> {
&self,
seen: &mut RecursionGuard<'_>,
enum_value: &Node<ast::EnumValueDefinition>,
) -> Result<(), Node<ast::Directive>> {
) -> Result<(), CycleError<ast::Directive>> {
for directive in &enum_value.directives {
self.directive(seen, directive)?;
}
Expand All @@ -78,7 +80,7 @@ impl FindRecursiveDirective<'_> {
&self,
seen: &mut RecursionGuard<'_>,
directives: &[schema::Component<ast::Directive>],
) -> Result<(), Node<ast::Directive>> {
) -> Result<(), CycleError<ast::Directive>> {
for directive in directives {
self.directive(seen, directive)?;
}
Expand All @@ -89,17 +91,18 @@ impl FindRecursiveDirective<'_> {
&self,
seen: &mut RecursionGuard<'_>,
directive: &Node<ast::Directive>,
) -> Result<(), Node<ast::Directive>> {
) -> Result<(), CycleError<ast::Directive>> {
if !seen.contains(&directive.name) {
if let Some(def) = self.schema.directive_definitions.get(&directive.name) {
self.directive_definition(seen.push(&directive.name), def)?;
self.directive_definition(seen.push(&directive.name)?, def)
.map_err(|error| error.trace(directive))?;
}
} else if seen.first() == Some(&directive.name) {
// Only report an error & bail out early if this is the *initial* directive.
// This prevents raising confusing errors when a directive `@b` which is not
// self-referential uses a directive `@a` that *is*. The error with `@a` should
// only be reported on its definition, not on `@b`'s.
return Err(directive.clone());
return Err(CycleError::Recursed(vec![directive.clone()]));
}

Ok(())
Expand All @@ -109,7 +112,7 @@ impl FindRecursiveDirective<'_> {
&self,
mut seen: RecursionGuard<'_>,
def: &Node<ast::DirectiveDefinition>,
) -> Result<(), Node<ast::Directive>> {
) -> Result<(), CycleError<ast::Directive>> {
for input_value in &def.arguments {
self.input_value(&mut seen, input_value)?;
}
Expand All @@ -120,7 +123,7 @@ impl FindRecursiveDirective<'_> {
fn check(
schema: &schema::Schema,
directive_def: &Node<ast::DirectiveDefinition>,
) -> Result<(), Node<ast::Directive>> {
) -> Result<(), CycleError<ast::Directive>> {
let mut recursion_stack = RecursionStack::with_root(directive_def.name.clone());
FindRecursiveDirective { schema }
.directive_definition(recursion_stack.guard(), directive_def)
Expand All @@ -143,22 +146,43 @@ pub(crate) fn validate_directive_definition(
// references itself directly.
//
// Returns Recursive Definition error.
if let Err(directive) = FindRecursiveDirective::check(&db.schema(), &def) {
if let Err(error) = FindRecursiveDirective::check(&db.schema(), &def) {
let definition_location = def.location();
let head_location = NodeLocation::recompose(def.location(), def.name.location());
let directive_location = directive.location();

diagnostics.push(
ApolloDiagnostic::new(
db,
definition_location,
DiagnosticData::RecursiveDirectiveDefinition {
name: def.name.to_string(),
},
)
.label(Label::new(head_location, "recursive directive definition"))
.label(Label::new(directive_location, "refers to itself here")),
);
let mut diagnostic = ApolloDiagnostic::new(
db,
definition_location,
DiagnosticData::RecursiveDirectiveDefinition {
name: def.name.to_string(),
},
)
.label(Label::new(head_location, "recursive directive definition"));

if let CycleError::Recursed(trace) = error {
if let Some((cyclical_application, path)) = trace.split_first() {
let mut prev_name = &def.name;
for directive_application in path.iter().rev() {
diagnostic = diagnostic.label(Label::new(
directive_application.location(),
format!(
"`{}` references `{}` here...",
prev_name, directive_application.name
),
));
prev_name = &directive_application.name;
}

diagnostic = diagnostic.label(Label::new(
cyclical_application.location(),
format!(
"`{}` circularly references `{}` here",
prev_name, cyclical_application.name
),
));
}
}

diagnostics.push(diagnostic);
}

diagnostics
Expand Down
48 changes: 23 additions & 25 deletions crates/apollo-compiler/src/validation/fragment.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
use crate::diagnostics::{ApolloDiagnostic, DiagnosticData, Label};
use crate::validation::{FileId, NodeLocation, RecursionGuard, RecursionStack};
use crate::validation::operation::OperationValidationConfig;
use crate::validation::{CycleError, FileId, NodeLocation, RecursionGuard, RecursionStack};
use crate::{ast, schema, Node, ValidationDatabase};
use std::collections::{HashMap, HashSet};

use super::operation::OperationValidationConfig;

/// Given a type definition, find all the type names that can be used for fragment spreading.
///
/// Spec: https://spec.graphql.org/October2021/#GetPossibleTypes()
Expand Down Expand Up @@ -300,13 +299,13 @@ pub(crate) fn validate_fragment_cycles(
named_fragments: &HashMap<ast::Name, Node<ast::FragmentDefinition>>,
selection_set: &[ast::Selection],
visited: &mut RecursionGuard<'_>,
) -> Result<(), Vec<Node<ast::FragmentSpread>>> {
) -> Result<(), CycleError<ast::FragmentSpread>> {
for selection in selection_set {
match selection {
ast::Selection::FragmentSpread(spread) => {
if visited.contains(&spread.fragment_name) {
if visited.first() == Some(&spread.fragment_name) {
return Err(vec![spread.clone()]);
return Err(CycleError::Recursed(vec![spread.clone()]));
}
continue;
}
Expand All @@ -315,12 +314,9 @@ pub(crate) fn validate_fragment_cycles(
detect_fragment_cycles(
named_fragments,
&fragment.selection_set,
&mut visited.push(&fragment.name),
&mut visited.push(&fragment.name)?,
)
.map_err(|mut trace| {
trace.push(spread.clone());
trace
})?;
.map_err(|error| error.trace(spread))?;
}
}
ast::Selection::InlineFragment(inline) => {
Expand Down Expand Up @@ -351,26 +347,28 @@ pub(crate) fn validate_fragment_cycles(
)
.label(Label::new(head_location, "recursive fragment definition"));

if let Some((cyclical_spread, path)) = cycle.split_first() {
let mut prev_name = &def.name;
for spread in path.iter().rev() {
if let CycleError::Recursed(trace) = cycle {
if let Some((cyclical_spread, path)) = trace.split_first() {
let mut prev_name = &def.name;
for spread in path.iter().rev() {
diagnostic = diagnostic.label(Label::new(
spread.location(),
format!(
"`{}` references `{}` here...",
prev_name, spread.fragment_name
),
));
prev_name = &spread.fragment_name;
}

diagnostic = diagnostic.label(Label::new(
spread.location(),
cyclical_spread.location(),
format!(
"`{}` references `{}` here...",
prev_name, spread.fragment_name
"`{}` circularly references `{}` here",
prev_name, cyclical_spread.fragment_name
),
));
prev_name = &spread.fragment_name;
}

diagnostic = diagnostic.label(Label::new(
cyclical_spread.location(),
format!(
"`{}` circularly references `{}` here",
prev_name, cyclical_spread.fragment_name
),
));
}

diagnostics.push(diagnostic);
Expand Down
63 changes: 41 additions & 22 deletions crates/apollo-compiler/src/validation/input_object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::{
ast,
diagnostics::{ApolloDiagnostic, DiagnosticData, Label},
schema,
validation::{RecursionGuard, RecursionStack},
validation::{CycleError, RecursionGuard, RecursionStack},
Node, ValidationDatabase,
};
use std::collections::HashMap;
Expand All @@ -18,19 +18,20 @@ impl FindRecursiveInputValue<'_> {
&self,
seen: &mut RecursionGuard<'_>,
def: &Node<ast::InputValueDefinition>,
) -> Result<(), Node<ast::InputValueDefinition>> {
) -> Result<(), CycleError<ast::InputValueDefinition>> {
return match &*def.ty {
// NonNull type followed by Named type is the one that's not allowed
// to be cyclical, so this is only case we care about.
//
// Everything else may be a cyclical input value.
ast::Type::NonNullNamed(name) => {
if !seen.contains(name) {
if let Some(def) = self.db.ast_types().input_objects.get(name) {
self.input_object_definition(seen.push(name), def)?
if let Some(object_def) = self.db.ast_types().input_objects.get(name) {
self.input_object_definition(seen.push(name)?, object_def)
.map_err(|err| err.trace(def))?
}
} else if seen.first() == Some(name) {
return Err(def.clone());
return Err(CycleError::Recursed(vec![def.clone()]));
}

Ok(())
Expand All @@ -43,7 +44,7 @@ impl FindRecursiveInputValue<'_> {
&self,
mut seen: RecursionGuard<'_>,
input_object: &ast::TypeWithExtensions<ast::InputObjectTypeDefinition>,
) -> Result<(), Node<ast::InputValueDefinition>> {
) -> Result<(), CycleError<ast::InputValueDefinition>> {
for input_value in input_object.fields() {
self.input_value_definition(&mut seen, input_value)?;
}
Expand All @@ -54,7 +55,7 @@ impl FindRecursiveInputValue<'_> {
fn check(
db: &dyn ValidationDatabase,
input_object: &ast::TypeWithExtensions<ast::InputObjectTypeDefinition>,
) -> Result<(), Node<ast::InputValueDefinition>> {
) -> Result<(), CycleError<ast::InputValueDefinition>> {
let mut recursion_stack = RecursionStack::with_root(input_object.definition.name.clone());
FindRecursiveInputValue { db }
.input_object_definition(recursion_stack.guard(), input_object)
Expand Down Expand Up @@ -85,23 +86,41 @@ pub(crate) fn validate_input_object_definition(
Default::default(),
);

if let Err(input_val) = FindRecursiveInputValue::check(db, &input_object) {
let mut labels = vec![Label::new(
if let Err(error) = FindRecursiveInputValue::check(db, &input_object) {
let mut diagnostic = ApolloDiagnostic::new(
db,
input_object.definition.location(),
"cyclical input object definition",
)];
let loc = input_val.location();
labels.push(Label::new(loc, "refers to itself here"));
diagnostics.push(
ApolloDiagnostic::new(
db,
input_object.definition.location(),
DiagnosticData::RecursiveInputObjectDefinition {
name: input_object.definition.name.to_string(),
},
)
.labels(labels),
DiagnosticData::RecursiveInputObjectDefinition {
name: input_object.definition.name.to_string(),
},
)
.label(Label::new(
input_object.definition.location(),
"cyclical input object definition",
));

if let CycleError::Recursed(trace) = error {
if let Some((cyclical_reference, path)) = trace.split_first() {
let mut prev_name = &input_object.definition.name;
for reference in path.iter().rev() {
diagnostic = diagnostic.label(Label::new(
reference.location(),
format!("`{}` references `{}` here...", prev_name, reference.name),
));
prev_name = &reference.name;
}

diagnostic = diagnostic.label(Label::new(
cyclical_reference.location(),
format!(
"`{}` circularly references `{}` here",
prev_name, cyclical_reference.name
),
));
}
}

diagnostics.push(diagnostic);
}

// Fields in an Input Object Definition must be unique
Expand Down
Loading