Skip to content

Commit

Permalink
chore: relocate base RPC controllers (#444)
Browse files Browse the repository at this point in the history
* chore: add chain rpc controller

* chore: relocate base rpc controller
  • Loading branch information
stanleyyconsensys authored Nov 29, 2024
1 parent 2a37d50 commit 6d972dc
Show file tree
Hide file tree
Showing 18 changed files with 309 additions and 295 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import { constants } from 'starknet';
import { object, string } from 'superstruct';
import type { Infer } from 'superstruct';

import type { StarknetAccount } from '../../__tests__/helper';
import { generateAccounts } from '../../__tests__/helper';
import type { SnapState } from '../../types/snapState';
import { STARKNET_SEPOLIA_TESTNET_NETWORK } from '../../utils/constants';
import * as snapHelper from '../../utils/snap';
import * as snapUtils from '../../utils/snapUtils';
import * as starknetUtils from '../../utils/starknetUtils';
import { AccountRpcController } from './account-rpc-controller';

jest.mock('../../utils/snap');
jest.mock('../../utils/logger');

describe('AccountRpcController', () => {
const state: SnapState = {
accContracts: [],
erc20Tokens: [],
networks: [STARKNET_SEPOLIA_TESTNET_NETWORK],
transactions: [],
};

const RequestStruct = object({
address: string(),
chainId: string(),
});

type Request = Infer<typeof RequestStruct>;

class MockAccountRpc extends AccountRpcController<Request, string> {
protected requestStruct = RequestStruct;

protected responseStruct = string();

// Set it to public to be able to spy on it
async handleRequest(param: Request) {
return `done ${param.address} and ${param.chainId}`;
}
}

const mockAccount = async (network: constants.StarknetChainId) => {
const accounts = await generateAccounts(network, 1);
return accounts[0];
};

const prepareExecute = async (account: StarknetAccount) => {
const verifyIfAccountNeedUpgradeOrDeploySpy = jest.spyOn(
snapUtils,
'verifyIfAccountNeedUpgradeOrDeploy',
);

const getKeysFromAddressSpy = jest.spyOn(
starknetUtils,
'getKeysFromAddress',
);

const getStateDataSpy = jest.spyOn(snapHelper, 'getStateData');

getStateDataSpy.mockResolvedValue(state);

getKeysFromAddressSpy.mockResolvedValue({
privateKey: account.privateKey,
publicKey: account.publicKey,
addressIndex: account.addressIndex,
derivationPath: account.derivationPath as unknown as any,
});

verifyIfAccountNeedUpgradeOrDeploySpy.mockReturnThis();

return {
getKeysFromAddressSpy,
getStateDataSpy,
verifyIfAccountNeedUpgradeOrDeploySpy,
};
};

it('executes request', async () => {
const chainId = constants.StarknetChainId.SN_SEPOLIA;
const account = await mockAccount(chainId);
await prepareExecute(account);
const rpc = new MockAccountRpc();

const result = await rpc.execute({
address: account.address,
chainId,
});

expect(result).toBe(`done ${account.address} and ${chainId}`);
});

it('fetchs account before execute', async () => {
const chainId = constants.StarknetChainId.SN_SEPOLIA;
const account = await mockAccount(chainId);
const { getKeysFromAddressSpy } = await prepareExecute(account);
const rpc = new MockAccountRpc();

await rpc.execute({ address: account.address, chainId });

expect(getKeysFromAddressSpy).toHaveBeenCalled();
});

it.each([true, false])(
`assign verifyIfAccountNeedUpgradeOrDeploy's argument "showAlert" to %s if the constructor option 'showInvalidAccountAlert' is set to %s`,
async (showInvalidAccountAlert: boolean) => {
const chainId = constants.StarknetChainId.SN_SEPOLIA;
const account = await mockAccount(chainId);
const { verifyIfAccountNeedUpgradeOrDeploySpy } = await prepareExecute(
account,
);
const rpc = new MockAccountRpc({
showInvalidAccountAlert,
});

await rpc.execute({ address: account.address, chainId });

expect(verifyIfAccountNeedUpgradeOrDeploySpy).toHaveBeenCalledWith(
expect.any(Object),
account.address,
account.publicKey,
showInvalidAccountAlert,
);
},
);
});
86 changes: 86 additions & 0 deletions packages/starknet-snap/src/rpcs/abstract/account-rpc-controller.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import type { getBIP44ChangePathString } from '@metamask/key-tree/dist/types/utils';
import type { Json } from '@metamask/snaps-sdk';

import type { Network, SnapState } from '../../types/snapState';
import { getBip44Deriver, getStateData } from '../../utils';
import {
getNetworkFromChainId,
verifyIfAccountNeedUpgradeOrDeploy,
} from '../../utils/snapUtils';
import { getKeysFromAddress } from '../../utils/starknetUtils';
import { RpcController } from './base-rpc-controller';

export type AccountRpcParams = {
chainId: string;
address: string;
};

// TODO: the Account object should move into a account manager for generate account
export type Account = {
privateKey: string;
publicKey: string;
addressIndex: number;
// This is the derivation path of the address, it is used in `getNextAddressIndex` to find the account in state where matching the same derivation path
derivationPath: ReturnType<typeof getBIP44ChangePathString>;
};

export type AccountRpcControllerOptions = {
showInvalidAccountAlert: boolean;
};

/**
* A base class for rpc controllers that require account discovery.
*
* @template Request - The expected structure of the request parameters.
* @template Response - The expected structure of the response.
* @class AccountRpcController
*/
export abstract class AccountRpcController<
Request extends AccountRpcParams,
Response extends Json,
> extends RpcController<Request, Response> {
protected account: Account;

protected network: Network;

protected options: AccountRpcControllerOptions;

protected defaultOptions: AccountRpcControllerOptions = {
showInvalidAccountAlert: true,
};

constructor(options?: AccountRpcControllerOptions) {
super();
this.options = Object.assign({}, this.defaultOptions, options);
}

protected async preExecute(params: Request): Promise<void> {
await super.preExecute(params);

const { chainId, address } = params;
const { showInvalidAccountAlert } = this.options;

const deriver = await getBip44Deriver();
// TODO: Instead of getting the state directly, we should implement state management to consolidate the state fetching
const state = await getStateData<SnapState>();

// TODO: getNetworkFromChainId from state is still needed, due to it is supporting in get-starknet at this moment
this.network = getNetworkFromChainId(state, chainId);

// TODO: This method should be refactored to get the account from an account manager
this.account = await getKeysFromAddress(
deriver,
this.network,
state,
address,
);

// TODO: rename this method to verifyAccount
await verifyIfAccountNeedUpgradeOrDeploy(
this.network,
address,
this.account.publicKey,
showInvalidAccountAlert,
);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import { string } from 'superstruct';

import { RpcController } from './base-rpc-controller';

jest.mock('../../utils/logger');

describe('RpcController', () => {
class MockRpc extends RpcController<string, string> {
protected requestStruct = string();

protected responseStruct = string();

// Set it to public to be able to spy on it
async handleRequest(params: string) {
return `done ${params}`;
}
}

it('executes request', async () => {
const rpc = new MockRpc();

const result = await rpc.execute('test');

expect(result).toBe('done test');
});
});
51 changes: 51 additions & 0 deletions packages/starknet-snap/src/rpcs/abstract/base-rpc-controller.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import type { Json } from '@metamask/snaps-sdk';
import type { Struct } from 'superstruct';

import { logger, validateRequest, validateResponse } from '../../utils';

/**
* A base class for rpc controllers.
*
* @template Request - The expected structure of the request parameters.
* @template Response - The expected structure of the response.
* @class RpcController
*/
export abstract class RpcController<
Request extends Json,
Response extends Json,
> {
/**
* Superstruct for the request.
*/
protected abstract requestStruct: Struct;

/**
* Superstruct for the response.
*/
protected abstract responseStruct: Struct;

protected abstract handleRequest(params: Request): Promise<Response>;

protected async preExecute(params: Request): Promise<void> {
logger.info(`Request: ${JSON.stringify(params)}`);
validateRequest(params, this.requestStruct);
}

protected async postExecute(response: Response): Promise<void> {
logger.info(`Response: ${JSON.stringify(response)}`);
validateResponse(response, this.responseStruct);
}

/**
* A method to execute the rpc method.
*
* @param params - An struct contains the require parameter for the request.
* @returns A promise that resolves to an json.
*/
async execute(params: Request): Promise<Response> {
await this.preExecute(params);
const resp = await this.handleRequest(params);
await this.postExecute(resp);
return resp;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import type { Json } from '@metamask/snaps-sdk';
import { NetworkStateManager } from '../../state/network-state-manager';
import type { Network } from '../../types/snapState';
import { InvalidNetworkError } from '../../utils/exceptions';
import { RpcController } from '../../utils/rpc';
import { RpcController } from './base-rpc-controller';

/**
* A base class for all RPC controllers that require a chainId to be provided in the request parameters.
Expand Down
2 changes: 1 addition & 1 deletion packages/starknet-snap/src/rpcs/declare-contract.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import {
mapDeprecatedParams,
UniversalDetailsStruct,
confirmDialog,
AccountRpcController,
signerUI,
networkUI,
rowUI,
Expand All @@ -19,6 +18,7 @@ import {
} from '../utils';
import { UserRejectedOpError } from '../utils/exceptions';
import { declareContract as declareContractUtil } from '../utils/starknetUtils';
import { AccountRpcController } from './abstract/account-rpc-controller';

// Define the DeclareContractRequestStruct
export const DeclareContractRequestStruct = assign(
Expand Down
7 changes: 2 additions & 5 deletions packages/starknet-snap/src/rpcs/display-private-key.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,9 @@ import {
renderDisplayPrivateKeyAlertUI,
renderDisplayPrivateKeyConfirmUI,
} from '../ui/utils';
import {
AccountRpcController,
AddressStruct,
BaseRequestStruct,
} from '../utils';
import { AddressStruct, BaseRequestStruct } from '../utils';
import { UserRejectedOpError } from '../utils/exceptions';
import { AccountRpcController } from './abstract/account-rpc-controller';

export const DisplayPrivateKeyRequestStruct = assign(
object({
Expand Down
2 changes: 1 addition & 1 deletion packages/starknet-snap/src/rpcs/estimate-fee.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ import { FeeTokenUnit } from '../types/snapApi';
import {
AddressStruct,
BaseRequestStruct,
AccountRpcController,
UniversalDetailsStruct,
InvocationsStruct,
} from '../utils';
import { getEstimatedFees } from '../utils/starknetUtils';
import { AccountRpcController } from './abstract/account-rpc-controller';

export const EstimateFeeRequestStruct = assign(
object({
Expand Down
4 changes: 2 additions & 2 deletions packages/starknet-snap/src/rpcs/execute-txn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,9 @@ import { FeeToken } from '../types/snapApi';
import type { TransactionRequest } from '../types/snapState';
import { VoyagerTransactionType, type Transaction } from '../types/snapState';
import { generateExecuteTxnFlow } from '../ui/utils';
import type { AccountRpcControllerOptions } from '../utils';
import {
AddressStruct,
BaseRequestStruct,
AccountRpcController,
UniversalDetailsStruct,
CallsStruct,
mapDeprecatedParams,
Expand All @@ -30,6 +28,8 @@ import {
executeTxn as executeTxnUtil,
getEstimatedFees,
} from '../utils/starknetUtils';
import type { AccountRpcControllerOptions } from './abstract/account-rpc-controller';
import { AccountRpcController } from './abstract/account-rpc-controller';

export const ExecuteTxnRequestStruct = assign(
object({
Expand Down
8 changes: 2 additions & 6 deletions packages/starknet-snap/src/rpcs/get-deployment-data.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
import type { Infer } from 'superstruct';
import { object, string, assign, array } from 'superstruct';

import {
AddressStruct,
BaseRequestStruct,
AccountRpcController,
CairoVersionStruct,
} from '../utils';
import { AddressStruct, BaseRequestStruct, CairoVersionStruct } from '../utils';
import { ACCOUNT_CLASS_HASH, CAIRO_VERSION } from '../utils/constants';
import { AccountAlreadyDeployedError } from '../utils/exceptions';
import {
getDeployAccountCallData,
isAccountDeployed,
} from '../utils/starknetUtils';
import { AccountRpcController } from './abstract/account-rpc-controller';

export const GetDeploymentDataRequestStruct = assign(
object({
Expand Down
Loading

0 comments on commit 6d972dc

Please sign in to comment.