Skip to content

Commit

Permalink
refactor: revamp starkNet_signMessage (#321)
Browse files Browse the repository at this point in the history
* chore: revamp sign message

* chore: add comment
  • Loading branch information
stanleyyconsensys authored Aug 15, 2024
1 parent 22524fb commit 1922348
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 190 deletions.
7 changes: 3 additions & 4 deletions packages/starknet-snap/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ import { getTransactions } from './getTransactions';
import { getTransactionStatus } from './getTransactionStatus';
import { getValue } from './getValue';
import { recoverAccounts } from './recoverAccounts';
import type { SignMessageParams } from './rpcs';
import { signMessage } from './rpcs';
import { sendTransaction } from './sendTransaction';
import { signDeclareTransaction } from './signDeclareTransaction';
import { signDeployAccountTransaction } from './signDeployAccountTransaction';
import type { SignMessageParams } from './signMessage';
import { signMessage } from './signMessage';
import { signTransaction } from './signTransaction';
import { switchNetwork } from './switchNetwork';
import type {
Expand Down Expand Up @@ -175,9 +175,8 @@ export const onRpcRequest: OnRpcRequestHandler = async ({ request }) => {
);

case 'starkNet_signMessage':
return await signMessage(
return await signMessage.execute(
apiParams.requestParams as unknown as SignMessageParams,
state,
);

case 'starkNet_signTransaction':
Expand Down
1 change: 1 addition & 0 deletions packages/starknet-snap/src/rpcs/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
export * from './signMessage';
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,22 @@ import {
} from '@metamask/snaps-sdk';
import { constants } from 'starknet';

import type { StarknetAccount } from '../test/utils';
import { generateAccounts } from '../test/utils';
import typedDataExample from './__tests__/fixture/typedDataExample.json';
import type { SignMessageParams } from './signMessage';
import type { StarknetAccount } from '../../test/utils';
import { generateAccounts } from '../../test/utils';
import typedDataExample from '../__tests__/fixture/typedDataExample.json';
import type { SnapState } from '../types/snapState';
import { toJson } from '../utils';
import { STARKNET_SEPOLIA_TESTNET_NETWORK } from '../utils/constants';
import * as snapHelper from '../utils/snap';
import * as snapUtils from '../utils/snapUtils';
import * as starknetUtils from '../utils/starknetUtils';
import { signMessage } from './signMessage';
import type { SnapState } from './types/snapState';
import { toJson } from './utils';
import { STARKNET_SEPOLIA_TESTNET_NETWORK } from './utils/constants';
import * as snapHelper from './utils/snap';
import * as snapUtils from './utils/snapUtils';
import * as starknetUtils from './utils/starknetUtils';
import type { SignMessageParams } from './signMessage';

jest.mock('./utils/snap');
jest.mock('./utils/logger');
jest.mock('../utils/snap');
jest.mock('../utils/logger');

describe('signMessage', function () {
describe('signMessage', () => {
const state: SnapState = {
accContracts: [],
erc20Tokens: [],
Expand All @@ -33,6 +33,7 @@ describe('signMessage', function () {
};

const prepareSignMessageMock = async (account: StarknetAccount) => {
const getStateDataSpy = jest.spyOn(snapHelper, 'getStateData');
const verifyIfAccountNeedUpgradeOrDeploySpy = jest.spyOn(
snapUtils,
'verifyIfAccountNeedUpgradeOrDeploy',
Expand All @@ -52,6 +53,7 @@ describe('signMessage', function () {

verifyIfAccountNeedUpgradeOrDeploySpy.mockReturnThis();
confirmDialogSpy.mockResolvedValue(true);
getStateDataSpy.mockResolvedValue(state);

return {
getKeysFromAddressSpy,
Expand All @@ -60,7 +62,7 @@ describe('signMessage', function () {
};
};

it('signs message correctly', async function () {
it('signs message correctly', async () => {
const account = await mockAccount(constants.StarknetChainId.SN_SEPOLIA);

await prepareSignMessageMock(account);
Expand All @@ -73,27 +75,27 @@ describe('signMessage', function () {

const request = {
chainId: constants.StarknetChainId.SN_SEPOLIA,
signerAddress: account.address,
address: account.address,
typedDataMessage: typedDataExample,
};
const result = await signMessage(request, state);
const result = await signMessage.execute(request);

expect(result).toStrictEqual(expectedResult);
});

it('renders confirmation dialog', async function () {
it('renders confirmation dialog', async () => {
const account = await mockAccount(constants.StarknetChainId.SN_SEPOLIA);

const { confirmDialogSpy } = await prepareSignMessageMock(account);

const request = {
chainId: constants.StarknetChainId.SN_SEPOLIA,
signerAddress: account.address,
address: account.address,
typedDataMessage: typedDataExample,
enableAuthorize: true,
};

await signMessage(request, state);
await signMessage.execute(request);

const calls = confirmDialogSpy.mock.calls[0][0];
expect(calls).toStrictEqual([
Expand All @@ -119,7 +121,7 @@ describe('signMessage', function () {
]);
});

it('throws `UserRejectedRequestError` if user denied the operation', async function () {
it('throws `UserRejectedRequestError` if user denied the operation', async () => {
const account = await mockAccount(constants.StarknetChainId.SN_SEPOLIA);

const { confirmDialogSpy } = await prepareSignMessageMock(account);
Expand All @@ -128,40 +130,19 @@ describe('signMessage', function () {

const request = {
chainId: constants.StarknetChainId.SN_SEPOLIA,
signerAddress: account.address,
address: account.address,
typedDataMessage: typedDataExample,
enableAuthorize: true,
};

await expect(signMessage(request, state)).rejects.toThrow(
await expect(signMessage.execute(request)).rejects.toThrow(
UserRejectedRequestError,
);
});

it('throws `InvalidParamsError` when request parameter is not correct', async function () {
const request = {
chainId: STARKNET_SEPOLIA_TESTNET_NETWORK.chainId,
};
it('throws `InvalidParamsError` when request parameter is not correct', async () => {
await expect(
signMessage(request as unknown as SignMessageParams, state),
signMessage.execute({} as unknown as SignMessageParams),
).rejects.toThrow(InvalidParamsError);
});

it('throws `Failed to sign the message` if another error was thrown', async function () {
const account = await mockAccount(constants.StarknetChainId.SN_SEPOLIA);

const { getKeysFromAddressSpy } = await prepareSignMessageMock(account);

getKeysFromAddressSpy.mockRejectedValue(new Error('some error'));

const request = {
chainId: constants.StarknetChainId.SN_SEPOLIA,
signerAddress: account.address,
typedDataMessage: typedDataExample,
};

await expect(signMessage(request, state)).rejects.toThrow(
'Failed to sign the message',
);
});
});
112 changes: 112 additions & 0 deletions packages/starknet-snap/src/rpcs/signMessage.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import type { Component } from '@metamask/snaps-sdk';
import {
heading,
row,
text,
UserRejectedRequestError,
} from '@metamask/snaps-sdk';
import type { Infer } from 'superstruct';
import { array, object, string, assign } from 'superstruct';

import {
confirmDialog,
AddressStruct,
toJson,
TypeDataStruct,
AuthorizableStruct,
BaseRequestStruct,
AccountRpcController,
} from '../utils';
import { signMessage as signMessageUtil } from '../utils/starknetUtils';

export const SignMessageRequestStruct = assign(
object({
address: AddressStruct,
typedDataMessage: TypeDataStruct,
}),
AuthorizableStruct,
BaseRequestStruct,
);

export const SignMessageResponseStruct = array(string());

export type SignMessageParams = Infer<typeof SignMessageRequestStruct>;

export type SignMessageResponse = Infer<typeof SignMessageResponseStruct>;

/**
* The RPC handler to sign a message.
*/
export class SignMessageRpc extends AccountRpcController<
SignMessageParams,
SignMessageResponse
> {
protected requestStruct = SignMessageRequestStruct;

protected responseStruct = SignMessageResponseStruct;

/**
* Execute the sign message request handler.
* It will show a confirmation dialog to the user before signing the message.
*
* @param params - The parameters of the request.
* @param params.address - The address of the signer.
* @param params.typedDataMessage - The Starknet type data message to sign.
* @param [params.enableAuthorize] - Optional, a flag to enable or display the confirmation dialog to the user.
* @param params.chainId - The chain id of the network.
* @returns the signature of the message in string array.
*/
async execute(params: SignMessageParams): Promise<SignMessageResponse> {
return super.execute(params);
}

protected async handleRequest(
params: SignMessageParams,
): Promise<SignMessageResponse> {
const { enableAuthorize, typedDataMessage, address } = params;
if (
// Get Starknet expected not to show the confirm dialog, therefore, `enableAuthorize` will set to false to bypass the confirmation
// TODO: enableAuthorize should set default to true
enableAuthorize &&
!(await this.getSignMessageConsensus(typedDataMessage, address))
) {
throw new UserRejectedRequestError() as unknown as Error;
}

return await signMessageUtil(
this.account.privateKey,
typedDataMessage,
address,
);
}

protected async getSignMessageConsensus(
typedDataMessage: Infer<typeof TypeDataStruct>,
address: string,
) {
const components: Component[] = [];
components.push(heading('Do you want to sign this message?'));
components.push(
row(
'Message',
text({
value: toJson(typedDataMessage),
markdown: false,
}),
),
);
components.push(
row(
'Signer Address',
text({
value: address,
markdown: false,
}),
),
);

return await confirmDialog(components);
}
}

export const signMessage = new SignMessageRpc();
Loading

0 comments on commit 1922348

Please sign in to comment.