Skip to content

Commit 12a50ec

Browse files
authored
Merge pull request #51 from ConsenSys/fix-selector-computation
Fix canonical signature hashes (`selector`s) computation
2 parents 3a7bb01 + 08fe8e8 commit 12a50ec

File tree

8 files changed

+714
-20
lines changed

8 files changed

+714
-20
lines changed

src/ast/implementation/declaration/variable_declaration.ts

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
import { getUserDefinedTypeFQName } from "../../../types";
12
import { ASTNode } from "../../ast_node";
2-
import { DataLocation, Mutability, StateVariableVisibility } from "../../constants";
3+
import { ContractKind, DataLocation, Mutability, StateVariableVisibility } from "../../constants";
34
import { encodeSignature } from "../../utils";
45
import { Expression } from "../expression/expression";
56
import { OverrideSpecifier } from "../meta/override_specifier";
@@ -158,33 +159,55 @@ export class VariableDeclaration extends ASTNode {
158159
const type = this.vType;
159160

160161
if (type instanceof UserDefinedTypeName) {
161-
const declaration = type.vReferencedDeclaration;
162+
const site = this.getClosestParentByType(ContractDefinition);
162163

163-
if (declaration instanceof StructDefinition) {
164-
const signatures = declaration.vMembers.map(
165-
(member) => member.canonicalSignatureType
164+
if (site === undefined) {
165+
throw new Error(
166+
`Unable to compute canonical signature type for variables outside of contract: ${this.print()}`
166167
);
167-
168-
return "(" + signatures.join(",") + ")";
169168
}
170169

171-
if (declaration instanceof ContractDefinition) {
172-
return "address";
173-
}
170+
const declaration = type.vReferencedDeclaration;
174171

175-
if (declaration instanceof EnumDefinition) {
176-
const length = declaration.children.length;
172+
if (site.kind === ContractKind.Library) {
173+
if (
174+
declaration instanceof ContractDefinition ||
175+
declaration instanceof StructDefinition ||
176+
declaration instanceof EnumDefinition
177+
) {
178+
return getUserDefinedTypeFQName(declaration);
179+
}
180+
} else {
181+
if (declaration instanceof StructDefinition) {
182+
const types = declaration.vMembers.map(
183+
(member) => member.canonicalSignatureType
184+
);
177185

178-
for (let n = 8; n <= 32; n += 8) {
179-
if (length < 2 ** n) {
180-
return "uint" + n;
181-
}
186+
return "(" + types.join(",") + ")";
182187
}
183188

184-
throw new Error("Unable to detect enum type size - member count exceeds 2 ** 32");
189+
if (declaration instanceof ContractDefinition) {
190+
return "address";
191+
}
192+
193+
if (declaration instanceof EnumDefinition) {
194+
const length = declaration.children.length;
195+
196+
for (let n = 8; n <= 32; n += 8) {
197+
if (length < 2 ** n) {
198+
return "uint" + n;
199+
}
200+
}
201+
202+
throw new Error(
203+
"Unable to detect enum type size - member count exceeds 2 ** 32"
204+
);
205+
}
185206
}
186207

187-
throw new Error("Unknown user defined type");
208+
throw new Error(
209+
`Unhandled user-defined type when computing canonical signature type: ${declaration.print()}`
210+
);
188211
}
189212

190213
return this.typeString;

src/types/utils.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ export function generalizeType(type: TypeNode): [TypeNode, DataLocation | undefi
130130
return [type, undefined];
131131
}
132132

133-
function getUserDefinedTypeFQName(
133+
export function getUserDefinedTypeFQName(
134134
def: ContractDefinition | StructDefinition | EnumDefinition
135135
): string {
136136
return def.vScope instanceof ContractDefinition ? `${def.vScope.name}.${def.name}` : def.name;

test/integration/factory/copy.spec.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ describe(`ASTNodeFactory.copy() validation`, () => {
5151
.join("\n")
5252
.replace(new RegExp(process.cwd(), "g"), ".");
5353

54+
// Uncomment next line to update snapshots
55+
// fse.writeFileSync(snapshot, result, { encoding: "utf-8" });
56+
5457
const content = fse.readFileSync(snapshot, { encoding: "utf-8" });
5558

5659
expect(result).toEqual(content);

test/integration/sol-ast-compile/tree.spec.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@ const cases = [
77
"test/samples/solidity/declarations/interface_060.sol",
88
"test/samples/solidity/declarations/interface_060.tree.txt"
99
],
10-
["test/samples/solidity/interface_id.sol", "test/samples/solidity/interface_id.tree.txt"]
10+
["test/samples/solidity/interface_id.sol", "test/samples/solidity/interface_id.tree.txt"],
11+
[
12+
"test/samples/solidity/library_fun_overloads.sol",
13+
"test/samples/solidity/library_fun_overloads.tree.txt"
14+
],
15+
["test/samples/solidity/fun_selectors.sol", "test/samples/solidity/fun_selectors.tree.txt"]
1116
];
1217

1318
for (const [sample, snapshot] of cases) {
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
contract D {
2+
uint a;
3+
}
4+
5+
struct T {
6+
address y;
7+
}
8+
9+
enum X {
10+
A
11+
}
12+
13+
library Foo {
14+
enum Y {
15+
A
16+
}
17+
18+
struct S {
19+
uint x;
20+
}
21+
22+
function funD(D d) public {}
23+
function funS(S memory s) public {}
24+
function funT(T memory t) public {}
25+
function funX(X x) public {}
26+
function funY(Y y) public {}
27+
}
28+
29+
interface Bar {
30+
function funD(D d) external;
31+
function funS(Foo.S calldata s) external;
32+
function funT(T calldata s) external;
33+
function funX(X x) external;
34+
function funY(Foo.Y y) external;
35+
}
36+
37+
contract Baz {
38+
function funD(D d) external {}
39+
function funS(Foo.S calldata s) external {}
40+
function funT(T calldata s) external {}
41+
function funX(X x) external {}
42+
function funY(Foo.Y y) external {}
43+
44+
function main() public {
45+
assert(Foo.funD.selector == bytes4(keccak256("funD(D)")));
46+
assert(Foo.funD.selector == 0x46467911);
47+
48+
assert(Foo.funS.selector == bytes4(keccak256("funS(Foo.S)")));
49+
assert(Foo.funS.selector == 0x2d9de9c3);
50+
51+
assert(Foo.funT.selector == bytes4(keccak256("funT(T)")));
52+
assert(Foo.funT.selector == 0x8c551140);
53+
54+
assert(Foo.funX.selector == bytes4(keccak256("funX(X)")));
55+
assert(Foo.funX.selector == 0x20c5a75c);
56+
57+
assert(Foo.funY.selector == bytes4(keccak256("funY(Foo.Y)")));
58+
assert(Foo.funY.selector == 0xc79a4d37);
59+
60+
assert(Bar.funD.selector == bytes4(keccak256("funD(address)")));
61+
assert(Bar.funD.selector == 0x4e209091);
62+
63+
assert(Bar.funS.selector == bytes4(keccak256("funS((uint256))")));
64+
assert(Bar.funS.selector == 0xe373f962);
65+
66+
assert(Bar.funT.selector == bytes4(keccak256("funT((address))")));
67+
assert(Bar.funT.selector == 0x3793b6f0);
68+
69+
assert(Bar.funX.selector == bytes4(keccak256("funX(uint8)")));
70+
assert(Bar.funX.selector == 0x0a42a215);
71+
72+
assert(Bar.funY.selector == bytes4(keccak256("funY(uint8)")));
73+
assert(Bar.funY.selector == 0x0a035664);
74+
75+
assert(type(Bar).interfaceId == 0x9a812b72);
76+
77+
assert(Baz.funD.selector == bytes4(keccak256("funD(address)")));
78+
assert(Baz.funD.selector == 0x4e209091);
79+
80+
assert(Baz.funS.selector == bytes4(keccak256("funS((uint256))")));
81+
assert(Baz.funS.selector == 0xe373f962);
82+
83+
assert(Baz.funT.selector == bytes4(keccak256("funT((address))")));
84+
assert(Baz.funT.selector == 0x3793b6f0);
85+
86+
assert(Baz.funX.selector == bytes4(keccak256("funX(uint8)")));
87+
assert(Baz.funX.selector == 0x0a42a215);
88+
89+
assert(Baz.funY.selector == bytes4(keccak256("funY(uint8)")));
90+
assert(Baz.funY.selector == 0x0a035664);
91+
}
92+
}

0 commit comments

Comments
 (0)