Skip to content

Commit

Permalink
chore: add sharable rpc handler (#320)
Browse files Browse the repository at this point in the history
* chore: add rpc share handler

* chore: lint

* chore: update rpc abstract class

* chore: add test
  • Loading branch information
stanleyyconsensys authored Aug 15, 2024
1 parent f1c5a33 commit 22524fb
Show file tree
Hide file tree
Showing 2 changed files with 307 additions and 9 deletions.
186 changes: 177 additions & 9 deletions packages/starknet-snap/src/utils/rpc.test.ts
Original file line number Diff line number Diff line change
@@ -1,23 +1,39 @@
import { InvalidParamsError, SnapError } from '@metamask/snaps-sdk';
import { object } from 'superstruct';
import type { Struct } from 'superstruct';
import { constants } from 'starknet';
import { object, string } from 'superstruct';
import type { Struct, Infer } from 'superstruct';

import { validateRequest, validateResponse } from './rpc';
import type { StarknetAccount } from '../../test/utils';
import { generateAccounts } from '../../test/utils';
import type { SnapState } from '../types/snapState';
import { STARKNET_SEPOLIA_TESTNET_NETWORK } from './constants';
import {
AccountRpcController,
RpcController,
validateRequest,
validateResponse,
} from './rpc';
import * as snapHelper from './snap';
import * as snapUtils from './snapUtils';
import * as starknetUtils from './starknetUtils';
import { AddressStruct } from './superstruct';

const struct = object({
jest.mock('./snap');
jest.mock('./logger');

const validateStruct = object({
signerAddress: AddressStruct,
});

const params = {
const validateParam = {
signerAddress:
'0x04882a372da3dfe1c53170ad75893832469bf87b62b13e84662565c4a88f25cd',
};

describe('validateRequest', () => {
it('does not throw error if the request is valid', () => {
expect(() =>
validateRequest(params, struct as unknown as Struct),
validateRequest(validateParam, validateStruct as unknown as Struct),
).not.toThrow();
});

Expand All @@ -27,15 +43,15 @@ describe('validateRequest', () => {
};

expect(() =>
validateRequest(requestParams, struct as unknown as Struct),
validateRequest(requestParams, validateStruct as unknown as Struct),
).toThrow(InvalidParamsError);
});
});

describe('validateResponse', () => {
it('does not throw error if the response is valid', () => {
expect(() =>
validateResponse(params, struct as unknown as Struct),
validateResponse(validateParam, validateStruct as unknown as Struct),
).not.toThrow();
});

Expand All @@ -45,7 +61,159 @@ describe('validateResponse', () => {
};

expect(() =>
validateResponse(response, struct as unknown as Struct),
validateResponse(response, validateStruct as unknown as Struct),
).toThrow(new SnapError('Invalid Response'));
});
});

describe('RpcController', () => {
class MockRpc extends RpcController<string, string> {
protected requestStruct = string();

protected responseStruct = string();

// Set it to public to be able to spy on it
async handleRequest(params: string) {
return `done ${params}`;
}
}

it('executes request', async () => {
const rpc = new MockRpc();

const result = await rpc.execute('test');

expect(result).toBe('done test');
});

it('throws `Failed to execute the rpc method` if an error was thrown', async () => {
const rpc = new MockRpc();

jest
.spyOn(MockRpc.prototype, 'handleRequest')
.mockRejectedValue(new Error('error'));

await expect(rpc.execute('test')).rejects.toThrow(
'Failed to execute the rpc method',
);
});

it('throws the actual error if an snap error was thrown', async () => {
const rpc = new MockRpc();

await expect(rpc.execute(1 as unknown as string)).rejects.toThrow(
'Expected a string, but received: 1',
);
});
});

describe('AccountRpcController', () => {
const state: SnapState = {
accContracts: [],
erc20Tokens: [],
networks: [STARKNET_SEPOLIA_TESTNET_NETWORK],
transactions: [],
};

const RequestStruct = object({
address: string(),
chainId: string(),
});

type Request = Infer<typeof RequestStruct>;

class MockAccountRpc extends AccountRpcController<Request, string> {
protected requestStruct = RequestStruct;

protected responseStruct = string();

// Set it to public to be able to spy on it
async handleRequest(param: Request) {
return `done ${param.address} and ${param.chainId}`;
}
}

const mockAccount = async (network: constants.StarknetChainId) => {
const accounts = await generateAccounts(network, 1);
return accounts[0];
};

const prepareExecute = async (account: StarknetAccount) => {
const verifyIfAccountNeedUpgradeOrDeploySpy = jest.spyOn(
snapUtils,
'verifyIfAccountNeedUpgradeOrDeploy',
);

const getKeysFromAddressSpy = jest.spyOn(
starknetUtils,
'getKeysFromAddress',
);

const getStateDataSpy = jest.spyOn(snapHelper, 'getStateData');

getStateDataSpy.mockResolvedValue(state);

getKeysFromAddressSpy.mockResolvedValue({
privateKey: account.privateKey,
publicKey: account.publicKey,
addressIndex: account.addressIndex,
derivationPath: account.derivationPath as unknown as any,
});

verifyIfAccountNeedUpgradeOrDeploySpy.mockReturnThis();

return {
getKeysFromAddressSpy,
getStateDataSpy,
verifyIfAccountNeedUpgradeOrDeploySpy,
};
};

it('executes request', async () => {
const chainId = constants.StarknetChainId.SN_SEPOLIA;
const account = await mockAccount(chainId);
await prepareExecute(account);
const rpc = new MockAccountRpc();

const result = await rpc.execute({
address: account.address,
chainId,
});

expect(result).toBe(`done ${account.address} and ${chainId}`);
});

it('fetchs account before execute', async () => {
const chainId = constants.StarknetChainId.SN_SEPOLIA;
const account = await mockAccount(chainId);
const { getKeysFromAddressSpy } = await prepareExecute(account);
const rpc = new MockAccountRpc();

await rpc.execute({ address: account.address, chainId });

expect(getKeysFromAddressSpy).toHaveBeenCalled();
});

it.each([true, false])(
`assign verifyIfAccountNeedUpgradeOrDeploy's argument "showAlert" to %s if the constructor option 'showInvalidAccountAlert' is set to %s`,
async (showInvalidAccountAlert: boolean) => {
const chainId = constants.StarknetChainId.SN_SEPOLIA;
const account = await mockAccount(chainId);
const { verifyIfAccountNeedUpgradeOrDeploySpy } = await prepareExecute(
account,
);
const rpc = new MockAccountRpc({
showInvalidAccountAlert,
});

await rpc.execute({ address: account.address, chainId });

expect(verifyIfAccountNeedUpgradeOrDeploySpy).toHaveBeenCalledWith(
expect.any(Object),
account.address,
account.publicKey,
showInvalidAccountAlert,
);
},
);
});
130 changes: 130 additions & 0 deletions packages/starknet-snap/src/utils/rpc.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
import type { getBIP44ChangePathString } from '@metamask/key-tree/dist/types/utils';
import type { Json } from '@metamask/snaps-sdk';
import { InvalidParamsError, SnapError } from '@metamask/snaps-sdk';
import type { Struct } from 'superstruct';
import { assert } from 'superstruct';

import type { Network, SnapState } from '../types/snapState';
import { logger } from './logger';
import { getBip44Deriver, getStateData } from './snap';
import {
getNetworkFromChainId,
verifyIfAccountNeedUpgradeOrDeploy,
} from './snapUtils';
import { getKeysFromAddress } from './starknetUtils';

/**
* Validates that the request parameters conform to the expected structure defined by the provided struct.
*
Expand Down Expand Up @@ -33,3 +44,122 @@ export function validateResponse<Params>(response: Params, struct: Struct) {
throw new SnapError('Invalid Response') as unknown as Error;
}
}

export abstract class RpcController<
Request extends Json,
Response extends Json,
> {
/**
* Superstruct for the request.
*/
protected abstract requestStruct: Struct;

/**
* Superstruct for the response.
*/
protected abstract responseStruct: Struct;

protected abstract handleRequest(params: Request): Promise<Response>;

protected async preExecute(params: Request): Promise<void> {
logger.info(`Request: ${JSON.stringify(params)}`);
validateRequest(params, this.requestStruct);
}

protected async postExecute(response: Response): Promise<void> {
logger.info(`Response: ${JSON.stringify(response)}`);
validateResponse(response, this.responseStruct);
}

/**
* A method to execute the rpc method.
*
* @param params - An struct contains the require parameter for the request.
* @returns A promise that resolves to an json.
*/
async execute(params: Request): Promise<Response> {
try {
await this.preExecute(params);
const resp = await this.handleRequest(params);
await this.postExecute(resp);
return resp;
} catch (error) {
logger.error('Failed to execute the rpc method', error);

if (error instanceof SnapError) {
throw error as unknown as Error;
}

throw new Error('Failed to execute the rpc method');
}
}
}

// TODO: the Type should be moved to a common place
export type AccountRpcParams = Json & {
chainId: string;
address: string;
};

// TODO: the Account object should move into a account manager for generate account
export type Account = {
privateKey: string;
publicKey: string;
addressIndex: number;
// This is the derivation path of the address, it is used in `getNextAddressIndex` to find the account in state where matching the same derivation path
derivationPath: ReturnType<typeof getBIP44ChangePathString>;
};

export type AccountRpcControllerOptions = Json & {
showInvalidAccountAlert: boolean;
};

export abstract class AccountRpcController<
Request extends AccountRpcParams,
Response extends Json,
> extends RpcController<Request, Response> {
protected account: Account;

protected network: Network;

protected options: AccountRpcControllerOptions;

protected defaultOptions: AccountRpcControllerOptions = {
showInvalidAccountAlert: true,
};

constructor(options?: AccountRpcControllerOptions) {
super();
this.options = Object.assign({}, this.defaultOptions, options);
}

protected async preExecute(params: Request): Promise<void> {
await super.preExecute(params);

const { chainId, address } = params;
const { showInvalidAccountAlert } = this.options;

const deriver = await getBip44Deriver();
// TODO: Instead of getting the state directly, we should implement state management to consolidate the state fetching
const state = await getStateData<SnapState>();

// TODO: getNetworkFromChainId from state is still needed, due to it is supporting in get-starknet at this moment
this.network = getNetworkFromChainId(state, chainId);

// TODO: This method should be refactored to get the account from an account manager
this.account = await getKeysFromAddress(
deriver,
this.network,
state,
address,
);

// TODO: rename this method to verifyAccount
await verifyIfAccountNeedUpgradeOrDeploy(
this.network,
address,
this.account.publicKey,
showInvalidAccountAlert,
);
}
}

0 comments on commit 22524fb

Please sign in to comment.