Skip to content

Commit 22524fb

Browse files
chore: add sharable rpc handler (#320)
* chore: add rpc share handler * chore: lint * chore: update rpc abstract class * chore: add test
1 parent f1c5a33 commit 22524fb

File tree

2 files changed

+307
-9
lines changed

2 files changed

+307
-9
lines changed
Lines changed: 177 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,39 @@
11
import { InvalidParamsError, SnapError } from '@metamask/snaps-sdk';
2-
import { object } from 'superstruct';
3-
import type { Struct } from 'superstruct';
2+
import { constants } from 'starknet';
3+
import { object, string } from 'superstruct';
4+
import type { Struct, Infer } from 'superstruct';
45

5-
import { validateRequest, validateResponse } from './rpc';
6+
import type { StarknetAccount } from '../../test/utils';
7+
import { generateAccounts } from '../../test/utils';
8+
import type { SnapState } from '../types/snapState';
9+
import { STARKNET_SEPOLIA_TESTNET_NETWORK } from './constants';
10+
import {
11+
AccountRpcController,
12+
RpcController,
13+
validateRequest,
14+
validateResponse,
15+
} from './rpc';
16+
import * as snapHelper from './snap';
17+
import * as snapUtils from './snapUtils';
18+
import * as starknetUtils from './starknetUtils';
619
import { AddressStruct } from './superstruct';
720

8-
const struct = object({
21+
jest.mock('./snap');
22+
jest.mock('./logger');
23+
24+
const validateStruct = object({
925
signerAddress: AddressStruct,
1026
});
1127

12-
const params = {
28+
const validateParam = {
1329
signerAddress:
1430
'0x04882a372da3dfe1c53170ad75893832469bf87b62b13e84662565c4a88f25cd',
1531
};
1632

1733
describe('validateRequest', () => {
1834
it('does not throw error if the request is valid', () => {
1935
expect(() =>
20-
validateRequest(params, struct as unknown as Struct),
36+
validateRequest(validateParam, validateStruct as unknown as Struct),
2137
).not.toThrow();
2238
});
2339

@@ -27,15 +43,15 @@ describe('validateRequest', () => {
2743
};
2844

2945
expect(() =>
30-
validateRequest(requestParams, struct as unknown as Struct),
46+
validateRequest(requestParams, validateStruct as unknown as Struct),
3147
).toThrow(InvalidParamsError);
3248
});
3349
});
3450

3551
describe('validateResponse', () => {
3652
it('does not throw error if the response is valid', () => {
3753
expect(() =>
38-
validateResponse(params, struct as unknown as Struct),
54+
validateResponse(validateParam, validateStruct as unknown as Struct),
3955
).not.toThrow();
4056
});
4157

@@ -45,7 +61,159 @@ describe('validateResponse', () => {
4561
};
4662

4763
expect(() =>
48-
validateResponse(response, struct as unknown as Struct),
64+
validateResponse(response, validateStruct as unknown as Struct),
4965
).toThrow(new SnapError('Invalid Response'));
5066
});
5167
});
68+
69+
describe('RpcController', () => {
70+
class MockRpc extends RpcController<string, string> {
71+
protected requestStruct = string();
72+
73+
protected responseStruct = string();
74+
75+
// Set it to public to be able to spy on it
76+
async handleRequest(params: string) {
77+
return `done ${params}`;
78+
}
79+
}
80+
81+
it('executes request', async () => {
82+
const rpc = new MockRpc();
83+
84+
const result = await rpc.execute('test');
85+
86+
expect(result).toBe('done test');
87+
});
88+
89+
it('throws `Failed to execute the rpc method` if an error was thrown', async () => {
90+
const rpc = new MockRpc();
91+
92+
jest
93+
.spyOn(MockRpc.prototype, 'handleRequest')
94+
.mockRejectedValue(new Error('error'));
95+
96+
await expect(rpc.execute('test')).rejects.toThrow(
97+
'Failed to execute the rpc method',
98+
);
99+
});
100+
101+
it('throws the actual error if an snap error was thrown', async () => {
102+
const rpc = new MockRpc();
103+
104+
await expect(rpc.execute(1 as unknown as string)).rejects.toThrow(
105+
'Expected a string, but received: 1',
106+
);
107+
});
108+
});
109+
110+
describe('AccountRpcController', () => {
111+
const state: SnapState = {
112+
accContracts: [],
113+
erc20Tokens: [],
114+
networks: [STARKNET_SEPOLIA_TESTNET_NETWORK],
115+
transactions: [],
116+
};
117+
118+
const RequestStruct = object({
119+
address: string(),
120+
chainId: string(),
121+
});
122+
123+
type Request = Infer<typeof RequestStruct>;
124+
125+
class MockAccountRpc extends AccountRpcController<Request, string> {
126+
protected requestStruct = RequestStruct;
127+
128+
protected responseStruct = string();
129+
130+
// Set it to public to be able to spy on it
131+
async handleRequest(param: Request) {
132+
return `done ${param.address} and ${param.chainId}`;
133+
}
134+
}
135+
136+
const mockAccount = async (network: constants.StarknetChainId) => {
137+
const accounts = await generateAccounts(network, 1);
138+
return accounts[0];
139+
};
140+
141+
const prepareExecute = async (account: StarknetAccount) => {
142+
const verifyIfAccountNeedUpgradeOrDeploySpy = jest.spyOn(
143+
snapUtils,
144+
'verifyIfAccountNeedUpgradeOrDeploy',
145+
);
146+
147+
const getKeysFromAddressSpy = jest.spyOn(
148+
starknetUtils,
149+
'getKeysFromAddress',
150+
);
151+
152+
const getStateDataSpy = jest.spyOn(snapHelper, 'getStateData');
153+
154+
getStateDataSpy.mockResolvedValue(state);
155+
156+
getKeysFromAddressSpy.mockResolvedValue({
157+
privateKey: account.privateKey,
158+
publicKey: account.publicKey,
159+
addressIndex: account.addressIndex,
160+
derivationPath: account.derivationPath as unknown as any,
161+
});
162+
163+
verifyIfAccountNeedUpgradeOrDeploySpy.mockReturnThis();
164+
165+
return {
166+
getKeysFromAddressSpy,
167+
getStateDataSpy,
168+
verifyIfAccountNeedUpgradeOrDeploySpy,
169+
};
170+
};
171+
172+
it('executes request', async () => {
173+
const chainId = constants.StarknetChainId.SN_SEPOLIA;
174+
const account = await mockAccount(chainId);
175+
await prepareExecute(account);
176+
const rpc = new MockAccountRpc();
177+
178+
const result = await rpc.execute({
179+
address: account.address,
180+
chainId,
181+
});
182+
183+
expect(result).toBe(`done ${account.address} and ${chainId}`);
184+
});
185+
186+
it('fetchs account before execute', async () => {
187+
const chainId = constants.StarknetChainId.SN_SEPOLIA;
188+
const account = await mockAccount(chainId);
189+
const { getKeysFromAddressSpy } = await prepareExecute(account);
190+
const rpc = new MockAccountRpc();
191+
192+
await rpc.execute({ address: account.address, chainId });
193+
194+
expect(getKeysFromAddressSpy).toHaveBeenCalled();
195+
});
196+
197+
it.each([true, false])(
198+
`assign verifyIfAccountNeedUpgradeOrDeploy's argument "showAlert" to %s if the constructor option 'showInvalidAccountAlert' is set to %s`,
199+
async (showInvalidAccountAlert: boolean) => {
200+
const chainId = constants.StarknetChainId.SN_SEPOLIA;
201+
const account = await mockAccount(chainId);
202+
const { verifyIfAccountNeedUpgradeOrDeploySpy } = await prepareExecute(
203+
account,
204+
);
205+
const rpc = new MockAccountRpc({
206+
showInvalidAccountAlert,
207+
});
208+
209+
await rpc.execute({ address: account.address, chainId });
210+
211+
expect(verifyIfAccountNeedUpgradeOrDeploySpy).toHaveBeenCalledWith(
212+
expect.any(Object),
213+
account.address,
214+
account.publicKey,
215+
showInvalidAccountAlert,
216+
);
217+
},
218+
);
219+
});

packages/starknet-snap/src/utils/rpc.ts

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,18 @@
1+
import type { getBIP44ChangePathString } from '@metamask/key-tree/dist/types/utils';
2+
import type { Json } from '@metamask/snaps-sdk';
13
import { InvalidParamsError, SnapError } from '@metamask/snaps-sdk';
24
import type { Struct } from 'superstruct';
35
import { assert } from 'superstruct';
46

7+
import type { Network, SnapState } from '../types/snapState';
8+
import { logger } from './logger';
9+
import { getBip44Deriver, getStateData } from './snap';
10+
import {
11+
getNetworkFromChainId,
12+
verifyIfAccountNeedUpgradeOrDeploy,
13+
} from './snapUtils';
14+
import { getKeysFromAddress } from './starknetUtils';
15+
516
/**
617
* Validates that the request parameters conform to the expected structure defined by the provided struct.
718
*
@@ -33,3 +44,122 @@ export function validateResponse<Params>(response: Params, struct: Struct) {
3344
throw new SnapError('Invalid Response') as unknown as Error;
3445
}
3546
}
47+
48+
export abstract class RpcController<
49+
Request extends Json,
50+
Response extends Json,
51+
> {
52+
/**
53+
* Superstruct for the request.
54+
*/
55+
protected abstract requestStruct: Struct;
56+
57+
/**
58+
* Superstruct for the response.
59+
*/
60+
protected abstract responseStruct: Struct;
61+
62+
protected abstract handleRequest(params: Request): Promise<Response>;
63+
64+
protected async preExecute(params: Request): Promise<void> {
65+
logger.info(`Request: ${JSON.stringify(params)}`);
66+
validateRequest(params, this.requestStruct);
67+
}
68+
69+
protected async postExecute(response: Response): Promise<void> {
70+
logger.info(`Response: ${JSON.stringify(response)}`);
71+
validateResponse(response, this.responseStruct);
72+
}
73+
74+
/**
75+
* A method to execute the rpc method.
76+
*
77+
* @param params - An struct contains the require parameter for the request.
78+
* @returns A promise that resolves to an json.
79+
*/
80+
async execute(params: Request): Promise<Response> {
81+
try {
82+
await this.preExecute(params);
83+
const resp = await this.handleRequest(params);
84+
await this.postExecute(resp);
85+
return resp;
86+
} catch (error) {
87+
logger.error('Failed to execute the rpc method', error);
88+
89+
if (error instanceof SnapError) {
90+
throw error as unknown as Error;
91+
}
92+
93+
throw new Error('Failed to execute the rpc method');
94+
}
95+
}
96+
}
97+
98+
// TODO: the Type should be moved to a common place
99+
export type AccountRpcParams = Json & {
100+
chainId: string;
101+
address: string;
102+
};
103+
104+
// TODO: the Account object should move into a account manager for generate account
105+
export type Account = {
106+
privateKey: string;
107+
publicKey: string;
108+
addressIndex: number;
109+
// This is the derivation path of the address, it is used in `getNextAddressIndex` to find the account in state where matching the same derivation path
110+
derivationPath: ReturnType<typeof getBIP44ChangePathString>;
111+
};
112+
113+
export type AccountRpcControllerOptions = Json & {
114+
showInvalidAccountAlert: boolean;
115+
};
116+
117+
export abstract class AccountRpcController<
118+
Request extends AccountRpcParams,
119+
Response extends Json,
120+
> extends RpcController<Request, Response> {
121+
protected account: Account;
122+
123+
protected network: Network;
124+
125+
protected options: AccountRpcControllerOptions;
126+
127+
protected defaultOptions: AccountRpcControllerOptions = {
128+
showInvalidAccountAlert: true,
129+
};
130+
131+
constructor(options?: AccountRpcControllerOptions) {
132+
super();
133+
this.options = Object.assign({}, this.defaultOptions, options);
134+
}
135+
136+
protected async preExecute(params: Request): Promise<void> {
137+
await super.preExecute(params);
138+
139+
const { chainId, address } = params;
140+
const { showInvalidAccountAlert } = this.options;
141+
142+
const deriver = await getBip44Deriver();
143+
// TODO: Instead of getting the state directly, we should implement state management to consolidate the state fetching
144+
const state = await getStateData<SnapState>();
145+
146+
// TODO: getNetworkFromChainId from state is still needed, due to it is supporting in get-starknet at this moment
147+
this.network = getNetworkFromChainId(state, chainId);
148+
149+
// TODO: This method should be refactored to get the account from an account manager
150+
this.account = await getKeysFromAddress(
151+
deriver,
152+
this.network,
153+
state,
154+
address,
155+
);
156+
157+
// TODO: rename this method to verifyAccount
158+
await verifyIfAccountNeedUpgradeOrDeploy(
159+
this.network,
160+
address,
161+
this.account.publicKey,
162+
showInvalidAccountAlert,
163+
);
164+
}
165+
}

0 commit comments

Comments
 (0)