Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 66 additions & 71 deletions crates/dyn-abi/src/eip712/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,13 @@ struct DfsContext<'a> {
/// type graph and traversing the dep graph.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct Resolver {
/// Nodes in the graph
// NOTE: Non-duplication of names must be enforced. See note on impl of Ord
// for TypeDef
// INVARIANT: if a type name is in `nodes`, then it is also in `edges`.
// NOTE: Non-duplication of names must be enforced. See note on `impl Ord for TypeDef`.
/// Nodes in the graph.
/// Type name to definition.
nodes: BTreeMap<String, TypeDef>,
/// Edges from a type name to its dependencies.
/// Type name => directly dependent type names.
edges: BTreeMap<String, Vec<String>>,
}

Expand Down Expand Up @@ -262,34 +264,36 @@ impl Resolver {
}

/// Detect cycles in the subgraph rooted at `type_name`
fn detect_cycle<'a>(&'a self, type_name: &str, context: &mut DfsContext<'a>) -> bool {
let ty = match self.nodes.get(type_name) {
Some(ty) => ty,
None => return false,
};
fn detect_cycle(&self, type_name: &str) -> Result<()> {
match self.detect_cycle_inner(type_name, &mut DfsContext::default()) {
true => Err(Error::circular_dependency(type_name)),
false => Ok(()),
}
}

fn detect_cycle_inner<'a>(&'a self, type_name: &str, context: &mut DfsContext<'a>) -> bool {
let Some(ty) = self.nodes.get(type_name) else { return false };

// Detect cycle.
if context.stack.contains(type_name) {
return true;
}
if context.visited.contains(ty) {
// Mark as visited.
if !context.visited.insert(ty) {
return false;
}

// update visited and stack
context.visited.insert(ty);
context.stack.insert(&ty.type_name);

if self
.edges
.get(&ty.type_name)
.unwrap()
.iter()
.any(|edge| self.detect_cycle(edge, context))
{
return true;
let edges = self.edges(ty);
if !edges.is_empty() {
context.stack.insert(&ty.type_name);
for edge in edges {
if self.detect_cycle_inner(edge, context) {
return true;
}
}
context.stack.remove(type_name);
}

context.stack.remove(type_name);
false
}

Expand All @@ -310,13 +314,12 @@ impl Resolver {
/// Ingest a type.
pub fn ingest(&mut self, type_def: TypeDef) {
let type_name = type_def.type_name.to_owned();

// Insert the edges into the graph
{
let entry = self.edges.entry(type_name.clone()).or_default();
for prop in &type_def.props {
entry.push(prop.root_type_name().to_owned());
}
} // entry dropped here
let entry = self.edges.entry(type_name.clone()).or_default();
for prop in &type_def.props {
entry.push(prop.root_type_name().to_owned());
}

// Insert the node into the graph
self.nodes.insert(type_name, type_def);
Expand All @@ -335,60 +338,50 @@ impl Resolver {
fn linearize_into<'a>(
&'a self,
resolution: &mut Vec<&'a TypeDef>,
root_type: RootType<'_>,
root_type: &str,
) -> Result<()> {
if root_type.try_basic_solidity().is_ok() {
return Ok(());
}

let this_type = self
.nodes
.get(root_type.span())
.ok_or_else(|| Error::missing_type(root_type.span()))?;

let edges: &Vec<String> = self.edges.get(root_type.span()).unwrap();

let this_type = match self.nodes.get(root_type) {
Some(ty) => ty,
None if RootType::parse_eip712(root_type)
.is_ok_and(|rt| rt.try_basic_solidity().is_ok()) =>
{
return Ok(());
}
None => return Err(Error::missing_type(root_type)),
};
if !resolution.contains(&this_type) {
resolution.push(this_type);
for edge in edges {
let rt = edge.as_str().try_into()?;
self.linearize_into(resolution, rt)?;
for edge in self.edges(this_type) {
self.linearize_into(resolution, edge)?;
}
}

Ok(())
}

/// This function linearizes a type into a list of typedefs of its
/// dependencies.
/// This function linearizes a type into a list of typedefs of its dependencies.
pub fn linearize(&self, type_name: &str) -> Result<Vec<&TypeDef>> {
let mut context = DfsContext::default();
if self.detect_cycle(type_name, &mut context) {
return Err(Error::circular_dependency(type_name));
}
let root_type = RootType::parse_eip712(type_name)?;
self.detect_cycle(type_name)?;
let mut resolution = vec![];
self.linearize_into(&mut resolution, root_type)?;
self.linearize_into(&mut resolution, type_name)?;
Ok(resolution)
}

/// Resolve a typename into a [`crate::DynSolType`] or return an error if
/// the type is missing, or contains a circular dependency.
pub fn resolve(&self, type_name: &str) -> Result<DynSolType> {
if self.detect_cycle(type_name, &mut Default::default()) {
return Err(Error::circular_dependency(type_name));
}
self.unchecked_resolve(&TypeSpecifier::parse_eip712(type_name)?)
self.detect_cycle(type_name)?;
self.resolve_unchecked(&TypeSpecifier::parse_eip712(type_name)?)
}

/// Resolve a type into a [`crate::DynSolType`] without checking for cycles.
fn unchecked_resolve(&self, type_spec: &TypeSpecifier<'_>) -> Result<DynSolType> {
fn resolve_unchecked(&self, type_spec: &TypeSpecifier<'_>) -> Result<DynSolType> {
let ty = match &type_spec.stem {
TypeStem::Root(root) => self.resolve_root_type(*root),
TypeStem::Tuple(tuple) => tuple
.types
.iter()
.map(|ty| self.unchecked_resolve(ty))
.map(|ty| self.resolve_unchecked(ty))
.collect::<Result<_, _>>()
.map(DynSolType::Tuple),
}?;
Expand All @@ -410,31 +403,33 @@ impl Resolver {
let prop_names: Vec<_> = ty.prop_names().map(str::to_string).collect();
let tuple: Vec<_> = ty
.prop_types()
.map(|ty| self.unchecked_resolve(&TypeSpecifier::parse_eip712(ty)?))
.map(|ty| self.resolve_unchecked(&TypeSpecifier::parse_eip712(ty)?))
.collect::<Result<_, _>>()?;

Ok(DynSolType::CustomStruct { name: ty.type_name.clone(), prop_names, tuple })
}

fn edges(&self, ty: &TypeDef) -> &[String] {
self.edges.get(&ty.type_name).expect("no edges for node")
}

/// Encode the type into an EIP-712 `encodeType` string
///
/// <https://eips.ethereum.org/EIPS/eip-712#definition-of-encodetype>
pub fn encode_type(&self, name: &str) -> Result<String> {
let linear = self.linearize(name)?;
if linear.is_empty() {
let defs = self.linearize(name)?;
if defs.is_empty() {
return Err(Error::missing_type(name));
}
let first = linear.first().unwrap().eip712_encode_type();

// Sort references by name (eip-712 encodeType spec)
let mut sorted_refs =
linear[1..].iter().map(|t| t.eip712_encode_type()).collect::<Vec<String>>();
sorted_refs.sort();
let mut encoded = defs.into_iter().map(|x| x.eip712_encode_type()).collect::<Vec<_>>();
encoded[1..].sort_unstable();

Ok(sorted_refs.iter().fold(first, |mut acc, s| {
acc.push_str(s);
acc
}))
let mut output = String::new();
for ty in encoded {
output.push_str(&ty);
}
debug_assert!(!output.is_empty());
Ok(output)
}

/// Compute the keccak256 hash of the EIP-712 `encodeType` string.
Expand Down Expand Up @@ -539,7 +534,7 @@ mod tests {
vec![PropertyDef::new_unchecked("A", "myA")],
));

assert!(graph.detect_cycle("A", &mut DfsContext::default()));
assert!(graph.detect_cycle_inner("A", &mut DfsContext::default()));
}

#[test]
Expand Down
14 changes: 8 additions & 6 deletions crates/dyn-abi/src/eip712/typed_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -810,19 +810,21 @@ mod tests {
// Therefore the `try_as_basic_solidity` skips over them, returning `MissingType` because
// the list of linearized types is empty.
for primary in ["T.", "T.U", "bool", "uint256"] {
for set_types in [false, true] {
let typed_data = get_typed_data(primary, set_types);
let err = typed_data.eip712_signing_hash().unwrap_err();
assert_eq!(err, Error::missing_type(primary));
}
let typed_data = get_typed_data(primary, false);
let err = typed_data.eip712_signing_hash().unwrap_err();
assert_eq!(err, Error::missing_type(primary));

let typed_data = get_typed_data(primary, true);
let err = typed_data.eip712_signing_hash().unwrap_err();
assert!(err.to_string().contains("mismatch"), "{err}");
}

// Invalid syntax.
for primary in ["T[]", "string[]", "uint256[]", "(bool,string)", "(bool,string)[]"] {
for set_types in [false, true] {
let typed_data = get_typed_data(primary, set_types);
let err = typed_data.eip712_signing_hash().unwrap_err();
assert!(err.to_string().contains("parser error"), "{err}");
assert_eq!(err, Error::missing_type(primary));
}
}
}
Expand Down