Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: enable multi discovery strategy #492

Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import {
DeployRequiredError,
UpgradeRequiredError,
} from '../../utils/exceptions';
import type { AccountDiscoveryType } from '../../utils/factory';
import { createAccountService } from '../../utils/factory';
import {
showDeployRequestModal,
Expand All @@ -15,6 +16,7 @@ import { ChainRpcController } from './chain-rpc-controller';
export type AccountRpcParams = {
chainId: string;
address: string;
accountDiscoveryType?: AccountDiscoveryType;
};

export type AccountRpcControllerOptions = {
Expand Down Expand Up @@ -55,8 +57,11 @@ export abstract class AccountRpcController<
protected async preExecute(params: Request): Promise<void> {
await super.preExecute(params);
const { address } = params;

const accountService = createAccountService(this.network);
const accountService = createAccountService(
this.network,
undefined,
params.accountDiscoveryType,
);
this.account = await accountService.deriveAccountByAddress(address);

try {
Expand Down
17 changes: 13 additions & 4 deletions packages/starknet-snap/src/rpcs/add-account.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import { type Infer } from 'superstruct';
import { assign, enums, object, optional, type Infer } from 'superstruct';

import { BaseRequestStruct, AccountStruct, logger } from '../utils';
import { createAccountService } from '../utils/factory';
import { AccountDiscoveryType, createAccountService } from '../utils/factory';
import { ChainRpcController } from './abstract/chain-rpc-controller';

export const AddAccountRequestStruct = BaseRequestStruct;
export const AddAccountRequestStruct = assign(
object({
accountDiscoveryType: optional(enums(Object.values(AccountDiscoveryType))),
}),
BaseRequestStruct,
);

export const AddAccountResponseStruct = AccountStruct;

Expand Down Expand Up @@ -34,7 +39,11 @@ export class AddAccountRpc extends ChainRpcController<
// eslint-disable-next-line @typescript-eslint/no-unused-vars
params: AddAccountParams,
): Promise<AddAccountResponse> {
const accountService = createAccountService(this.network);
const accountService = createAccountService(
this.network,
undefined,
params.accountDiscoveryType,
);

const account = await accountService.deriveAccountByIndex();

Expand Down
18 changes: 15 additions & 3 deletions packages/starknet-snap/src/rpcs/get-current-account.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
import { boolean, object, optional, assign, type Infer } from 'superstruct';
import {
boolean,
object,
optional,
assign,
type Infer,
enums,
} from 'superstruct';

import { AccountStateManager } from '../state/account-state-manager';
import { BaseRequestStruct, AccountStruct } from '../utils';
import { createAccountService } from '../utils/factory';
import { AccountDiscoveryType, createAccountService } from '../utils/factory';
import { ChainRpcController } from './abstract/chain-rpc-controller';

export const GetCurrentAccountRequestStruct = assign(
BaseRequestStruct,
object({
fromState: optional(boolean()),
accountDiscoveryType: optional(enums(Object.values(AccountDiscoveryType))),
}),
);

Expand Down Expand Up @@ -44,7 +52,11 @@ export class GetCurrentAccountRpc extends ChainRpcController<
protected async handleRequest(
params: GetCurrentAccountParams,
): Promise<GetCurrentAccountResponse> {
const accountService = createAccountService(this.network);
const accountService = createAccountService(
this.network,
undefined,
params.accountDiscoveryType,
);

if (params.fromState) {
// Get the current account from the state if the flag is set.
Expand Down
11 changes: 8 additions & 3 deletions packages/starknet-snap/src/rpcs/switch-account.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import { assign, object, type Infer } from 'superstruct';
import { assign, enums, object, optional, type Infer } from 'superstruct';

import { BaseRequestStruct, AccountStruct, AddressStruct } from '../utils';
import { createAccountService } from '../utils/factory';
import { AccountDiscoveryType, createAccountService } from '../utils/factory';
import { AccountRpcController } from './abstract/account-rpc-controller';

export const SwitchAccountRequestStruct = assign(
BaseRequestStruct,
object({
address: AddressStruct,
accountDiscoveryType: optional(enums(Object.values(AccountDiscoveryType))),
}),
);

Expand Down Expand Up @@ -39,7 +40,11 @@ export class SwitchAccountRpc extends AccountRpcController<
protected async handleRequest(
params: SwitchAccountParams,
): Promise<SwitchAccountResponse> {
const accountService = createAccountService(this.network);
const accountService = createAccountService(
this.network,
undefined,
params.accountDiscoveryType,
);

await accountService.switchAccount(params.chainId, this.account);

Expand Down
9 changes: 9 additions & 0 deletions packages/starknet-snap/src/utils/factory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,28 @@ export function createTransactionService(
});
}

export enum AccountDiscoveryType {
DEFAULT = 'DEFAULT',
ForceCairo0 = 'FORCE_CAIRO_0',
ForceCairo1 = 'FORCE_CAIRO_1',
}

/**
* Create a AccountService object.
*
* @param network - The network.
* @param [accountStateMgr] - The `AccountStateManager`.
* @param accountDiscoveryType
* @returns A AccountService object.
*/
export function createAccountService(
network: Network,
accountStateMgr?: AccountStateManager,
accountDiscoveryType?: AccountDiscoveryType,
): AccountService {
return new AccountService({
network,
accountStateMgr,
accountDiscoveryType,
});
}
2 changes: 1 addition & 1 deletion packages/starknet-snap/src/wallet/account/account.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ export class Account {
// When a Account object discovery by the account service,
// it should already cached the status of requireDeploy and requireUpgrade.
const [upgradeRequired, deployRequired] = await Promise.all([
this.accountContract.isRequireDeploy(),
this.accountContract.isRequireUpgrade(),
this.accountContract.isRequireDeploy(),
]);
return {
addressSalt: this.publicKey,
Expand Down
16 changes: 15 additions & 1 deletion packages/starknet-snap/src/wallet/account/discovery.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import type { Network } from '../../types/snapState';
import { AccountDiscoveryType } from '../../utils/factory';
import { Cairo0Contract } from './cairo0';
import { Cairo1Contract } from './cairo1';
import type { CairoAccountContract } from './contract';
Expand All @@ -16,8 +17,21 @@ export class AccountContractDiscovery {

protected network: Network;

constructor(network: Network) {
constructor(network: Network, discoveryType?: AccountDiscoveryType) {
this.network = network;
if (discoveryType !== undefined) {
switch (discoveryType) {
case AccountDiscoveryType.ForceCairo1:
this.contractCtors = [Cairo1Contract];
this.defaultContractCtor = Cairo1Contract;
break;
// We default te Cairo0 discovery
default:
this.contractCtors = [Cairo0Contract];
this.defaultContractCtor = Cairo0Contract;
break;
}
}
}

/**
Expand Down
4 changes: 4 additions & 0 deletions packages/starknet-snap/src/wallet/account/service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import {
AccountNotFoundError,
MaxAccountLimitExceededError,
} from '../../utils/exceptions';
import { AccountDiscoveryType } from '../../utils/factory';
import { Account } from './account';
import { AccountContractDiscovery } from './discovery';
import { AccountKeyPair } from './keypair';
Expand All @@ -19,14 +20,17 @@ export class AccountService {
constructor({
network,
accountStateMgr = new AccountStateManager(),
accountDiscoveryType = AccountDiscoveryType.DEFAULT,
}: {
network: Network;
accountStateMgr?: AccountStateManager;
accountDiscoveryType?: AccountDiscoveryType;
}) {
this.network = network;
this.accountStateMgr = accountStateMgr;
this.accountContractDiscoveryService = new AccountContractDiscovery(
network,
accountDiscoveryType,
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ export const Wrapper = styled.div`
export const ColMiddle = styled.div`
width: 1040px;
margin: auto;

margin-top: 40px;
@media (max-width: 1024px) {
width: 896px;
}
Expand All @@ -34,14 +34,16 @@ export const Content = styled.div`
export const Banner = styled.div`
position: fixed;
left: 0px;
bottom: 0px;
top: 0px;
width: 100%;
background-color: ${(props) => props.theme.palette.primary.main};
background-color: red;
margin-bottom: 24px;
color: ${(props) => props.theme.palette.grey.grey3};
display: flex;
align-items: center;
padding: 13px 24px;
justify-content: space-between;
padding: 10px 24px;
justify-content: space-around;
z-index: 2;
`;

export const CloseIcon = styled(FontAwesomeIcon)`
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import { ReactNode, useState } from 'react';
import { ReactNode } from 'react';
import { Footer } from 'components/ui/organism/Footer';
import { Button, Stack, Typography } from '@mui/material'; // Importing MUI components
import {
Banner,
CloseIcon,
ColMiddle,
Content,
MenuStyled,
Expand All @@ -15,20 +15,72 @@ interface Props {
}

export const FrameworkView = ({ connected, children }: Props) => {
const [bannerOpen, setBannerOpen] = useState(true);
// Get the current `accountDiscovery` value from the URL
const urlParams = new URLSearchParams(window.location.search);
const accountDiscovery = urlParams.get('accountDiscovery') ?? 'FORCE_CAIRO_0';

const bannerMessage =
accountDiscovery === 'FORCE_CAIRO_1'
? 'This is a special version for recovering funds on a Cairo 1 account.'
: 'This is a special version for recovering funds on a Cairo 0 account.';

const handleAccountChange = (version: string) => {
// Update the URL without reloading the page
const newParams = new URLSearchParams(window.location.search);
newParams.set('accountDiscovery', version);
window.history.replaceState(
{},
'',
`${window.location.pathname}?${newParams.toString()}`,
);
window.location.reload(); // Reload to apply the updated query parameter
};

return (
<Wrapper>
<ColMiddle>
<MenuStyled connected={connected} />
<Content>{children}</Content>
<Footer />
</ColMiddle>
{bannerOpen && (
<Banner>
This is the Open Beta version of the dapp, updates are made regularly{' '}
<CloseIcon icon={'close'} onClick={() => setBannerOpen(false)} />
</Banner>
)}
<Banner>
<Typography variant="body1" sx={{ mb: 2 }}>
{bannerMessage}, click{' '}
<a
target="_blank"
href="https://github.com/Consensys/starknet-snap/blob/main/docs/tutorial-resolving-stuck-funds.md"
rel="noreferrer"
>
here
</a>{' '}
to access the tutorial
</Typography>
<Stack
direction="row"
spacing={2}
justifyContent="center"
alignItems="center"
>
<Button
variant={
accountDiscovery === 'FORCE_CAIRO_1' ? 'contained' : 'outlined'
}
color="warning"
onClick={() => handleAccountChange('FORCE_CAIRO_1')}
>
Force Cairo 1
</Button>
<Button
variant={
accountDiscovery === 'FORCE_CAIRO_0' ? 'contained' : 'outlined'
}
color="warning"
onClick={() => handleAccountChange('FORCE_CAIRO_0')}
>
Force Cairo 0
</Button>
</Stack>
</Banner>
</Wrapper>
);
};
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ export const AccountSwitchModalView = ({

return (
<Menu as="div" style={{ display: 'inline-block', position: 'relative' }}>
<Menu.Button style={{ background: 'none', border: 'none' }}>
<Wrapper backgroundTransparent iconRight="angle-down">
<Menu.Button style={{ background: 'none', border: 'none' }} disabled>
<Wrapper backgroundTransparent>
{full
? starkName ?? currentAddress
: starkName
Expand Down
19 changes: 16 additions & 3 deletions packages/wallet-ui/src/services/useSnap.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,22 @@ export const useSnap = () => {
};

const ping = async (): Promise<void> => {
await invokeSnap<null>({
method: 'ping',
});
// Extract the `accountDiscovery` parameter from the GET field (e.g., URL search params)
const urlParams = new URLSearchParams(window.location.search);
const accountDiscovery = urlParams.get('accountDiscovery');

if (accountDiscovery !== null) {
await invokeSnap<null>({
method: 'ping',
params: {
accountDiscoveryType: accountDiscovery,
},
});
} else {
await invokeSnap<null>({
method: 'ping',
});
}
};

return {
Expand Down
Loading
Loading