Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: enable multi discovery strategy #492

Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import {
DeployRequiredError,
UpgradeRequiredError,
} from '../../utils/exceptions';
import type { AccountDiscoveryType } from '../../utils/factory';
import { createAccountService } from '../../utils/factory';
import {
showDeployRequestModal,
Expand All @@ -15,6 +16,7 @@ import { ChainRpcController } from './chain-rpc-controller';
export type AccountRpcParams = {
chainId: string;
address: string;
accountDiscoveryType?: AccountDiscoveryType;
};

export type AccountRpcControllerOptions = {
Expand Down Expand Up @@ -55,8 +57,11 @@ export abstract class AccountRpcController<
protected async preExecute(params: Request): Promise<void> {
await super.preExecute(params);
const { address } = params;

const accountService = createAccountService(this.network);
const accountService = createAccountService(
this.network,
undefined,
params.accountDiscoveryType,
);
this.account = await accountService.deriveAccountByAddress(address);

try {
Expand Down
17 changes: 13 additions & 4 deletions packages/starknet-snap/src/rpcs/add-account.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import { type Infer } from 'superstruct';
import { assign, enums, object, optional, type Infer } from 'superstruct';

import { BaseRequestStruct, AccountStruct, logger } from '../utils';
import { createAccountService } from '../utils/factory';
import { AccountDiscoveryType, createAccountService } from '../utils/factory';
import { ChainRpcController } from './abstract/chain-rpc-controller';

export const AddAccountRequestStruct = BaseRequestStruct;
export const AddAccountRequestStruct = assign(
object({
accountDiscoveryType: optional(enums(Object.values(AccountDiscoveryType))),
}),
BaseRequestStruct,
);

export const AddAccountResponseStruct = AccountStruct;

Expand Down Expand Up @@ -34,7 +39,11 @@ export class AddAccountRpc extends ChainRpcController<
// eslint-disable-next-line @typescript-eslint/no-unused-vars
params: AddAccountParams,
): Promise<AddAccountResponse> {
const accountService = createAccountService(this.network);
const accountService = createAccountService(
this.network,
undefined,
params.accountDiscoveryType,
);

const account = await accountService.deriveAccountByIndex();

Expand Down
18 changes: 15 additions & 3 deletions packages/starknet-snap/src/rpcs/get-current-account.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
import { boolean, object, optional, assign, type Infer } from 'superstruct';
import {
boolean,
object,
optional,
assign,
type Infer,
enums,
} from 'superstruct';

import { AccountStateManager } from '../state/account-state-manager';
import { BaseRequestStruct, AccountStruct } from '../utils';
import { createAccountService } from '../utils/factory';
import { AccountDiscoveryType, createAccountService } from '../utils/factory';
import { ChainRpcController } from './abstract/chain-rpc-controller';

export const GetCurrentAccountRequestStruct = assign(
BaseRequestStruct,
object({
fromState: optional(boolean()),
accountDiscoveryType: optional(enums(Object.values(AccountDiscoveryType))),
}),
);

Expand Down Expand Up @@ -44,7 +52,11 @@ export class GetCurrentAccountRpc extends ChainRpcController<
protected async handleRequest(
params: GetCurrentAccountParams,
): Promise<GetCurrentAccountResponse> {
const accountService = createAccountService(this.network);
const accountService = createAccountService(
this.network,
undefined,
params.accountDiscoveryType,
);

if (params.fromState) {
// Get the current account from the state if the flag is set.
Expand Down
11 changes: 8 additions & 3 deletions packages/starknet-snap/src/rpcs/switch-account.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import { assign, object, type Infer } from 'superstruct';
import { assign, enums, object, optional, type Infer } from 'superstruct';

import { BaseRequestStruct, AccountStruct, AddressStruct } from '../utils';
import { createAccountService } from '../utils/factory';
import { AccountDiscoveryType, createAccountService } from '../utils/factory';
import { AccountRpcController } from './abstract/account-rpc-controller';

export const SwitchAccountRequestStruct = assign(
BaseRequestStruct,
object({
address: AddressStruct,
accountDiscoveryType: optional(enums(Object.values(AccountDiscoveryType))),
}),
);

Expand Down Expand Up @@ -39,7 +40,11 @@ export class SwitchAccountRpc extends AccountRpcController<
protected async handleRequest(
params: SwitchAccountParams,
): Promise<SwitchAccountResponse> {
const accountService = createAccountService(this.network);
const accountService = createAccountService(
this.network,
undefined,
params.accountDiscoveryType,
);

await accountService.switchAccount(params.chainId, this.account);

Expand Down
9 changes: 9 additions & 0 deletions packages/starknet-snap/src/utils/factory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,28 @@ export function createTransactionService(
});
}

export enum AccountDiscoveryType {
DEFAULT = 'DEFAULT',
ForceCairo0 = 'FORCE_CAIRO_0',
ForceCairo1 = 'FORCE_CAIRO_1',
}

/**
* Create a AccountService object.
*
* @param network - The network.
* @param [accountStateMgr] - The `AccountStateManager`.
* @param accountDiscoveryType
* @returns A AccountService object.
*/
export function createAccountService(
network: Network,
accountStateMgr?: AccountStateManager,
accountDiscoveryType?: AccountDiscoveryType,
): AccountService {
return new AccountService({
network,
accountStateMgr,
accountDiscoveryType,
});
}
2 changes: 1 addition & 1 deletion packages/starknet-snap/src/wallet/account/account.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ export class Account {
// When a Account object discovery by the account service,
// it should already cached the status of requireDeploy and requireUpgrade.
const [upgradeRequired, deployRequired] = await Promise.all([
this.accountContract.isRequireDeploy(),
this.accountContract.isRequireUpgrade(),
this.accountContract.isRequireDeploy(),
]);
return {
addressSalt: this.publicKey,
Expand Down
18 changes: 17 additions & 1 deletion packages/starknet-snap/src/wallet/account/discovery.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import type { Network } from '../../types/snapState';
import { AccountDiscoveryType } from '../../utils/factory';
import { Cairo0Contract } from './cairo0';
import { Cairo1Contract } from './cairo1';
import type { CairoAccountContract } from './contract';
Expand All @@ -16,8 +17,23 @@ export class AccountContractDiscovery {

protected network: Network;

constructor(network: Network) {
constructor(network: Network, discoveryType?: AccountDiscoveryType) {
this.network = network;
if (discoveryType !== undefined) {
switch (discoveryType) {
case AccountDiscoveryType.ForceCairo0:
this.contractCtors = [Cairo0Contract];
this.defaultContractCtor = Cairo0Contract;
break;
case AccountDiscoveryType.ForceCairo1:
this.contractCtors = [Cairo1Contract];
this.defaultContractCtor = Cairo1Contract;
break;
default:
this.contractCtors = [Cairo1Contract, Cairo0Contract];
this.defaultContractCtor = Cairo1Contract;
}
}
}

/**
Expand Down
4 changes: 4 additions & 0 deletions packages/starknet-snap/src/wallet/account/service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import {
AccountNotFoundError,
MaxAccountLimitExceededError,
} from '../../utils/exceptions';
import { AccountDiscoveryType } from '../../utils/factory';
import { Account } from './account';
import { AccountContractDiscovery } from './discovery';
import { AccountKeyPair } from './keypair';
Expand All @@ -19,14 +20,17 @@ export class AccountService {
constructor({
network,
accountStateMgr = new AccountStateManager(),
accountDiscoveryType = AccountDiscoveryType.DEFAULT,
}: {
network: Network;
accountStateMgr?: AccountStateManager;
accountDiscoveryType?: AccountDiscoveryType;
}) {
this.network = network;
this.accountStateMgr = accountStateMgr;
this.accountContractDiscoveryService = new AccountContractDiscovery(
network,
accountDiscoveryType,
);
}

Expand Down
19 changes: 16 additions & 3 deletions packages/wallet-ui/src/services/useSnap.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,22 @@ export const useSnap = () => {
};

const ping = async (): Promise<void> => {
await invokeSnap<null>({
method: 'ping',
});
// Extract the `accountDiscovery` parameter from the GET field (e.g., URL search params)
const urlParams = new URLSearchParams(window.location.search);
const accountDiscovery = urlParams.get('accountDiscovery');

if (accountDiscovery !== null) {
await invokeSnap<null>({
method: 'ping',
params: {
accountDiscoveryType: accountDiscovery,
},
});
} else {
await invokeSnap<null>({
method: 'ping',
});
}
};

return {
Expand Down
6 changes: 6 additions & 0 deletions packages/wallet-ui/src/services/useStarkNetSnap.ts
Original file line number Diff line number Diff line change
Expand Up @@ -767,10 +767,12 @@ export const useStarkNetSnap = () => {
const addNewAccount = async (chainId: string) => {
dispatch(enableLoadingWithMessage('Adding new account...'));
try {
const urlParams = new URLSearchParams(window.location.search);
const account = await invokeSnap<Account>({
method: 'starkNet_addAccount',
params: {
chainId,
accountDiscoveryType: urlParams.get('accountDiscovery'),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not do those inject from the invokeSnap function? i guess when force deploy/upgrade, u still need that? or no

},
});

Expand All @@ -792,10 +794,12 @@ export const useStarkNetSnap = () => {
};

const getCurrentAccount = async (chainId: string) => {
const urlParams = new URLSearchParams(window.location.search);
return await invokeSnap<Account>({
method: 'starkNet_getCurrentAccount',
params: {
chainId,
accountDiscoveryType: urlParams.get('accountDiscovery') ?? 'DEFAULT',
},
});
};
Expand All @@ -807,11 +811,13 @@ export const useStarkNetSnap = () => {
),
);
try {
const urlParams = new URLSearchParams(window.location.search);
const account = await invokeSnap<Account>({
method: 'starkNet_swtichAccount',
params: {
chainId,
address,
accountDiscoveryType: urlParams.get('accountDiscovery'),
},
});

Expand Down
Loading