Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 164 additions & 0 deletions app/scripts/lib/unlock-wrapper.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import {
withUnlockPrompt,
createUnlockedMethodWrappers,
} from './unlock-wrapper';

describe('unlock-wrapper', () => {
describe('withUnlockPrompt', () => {
it('should call the function immediately if wallet is unlocked', async () => {
const mockFn = jest.fn().mockResolvedValue('result');
const mockGetUnlockPromise = jest.fn();
const mockIsUnlocked = jest.fn().mockReturnValue(true);

const wrappedFn = withUnlockPrompt(
mockFn,
mockGetUnlockPromise,
mockIsUnlocked,
);

const result = await wrappedFn('arg1', 'arg2');

expect(result).toBe('result');
expect(mockFn).toHaveBeenCalledWith('arg1', 'arg2');
expect(mockGetUnlockPromise).not.toHaveBeenCalled();
expect(mockIsUnlocked).toHaveBeenCalled();
});

it('should wait for unlock if wallet is locked', async () => {
const mockFn = jest.fn().mockResolvedValue('result');
const mockGetUnlockPromise = jest.fn().mockResolvedValue(undefined);
const mockIsUnlocked = jest.fn().mockReturnValue(false);

const wrappedFn = withUnlockPrompt(
mockFn,
mockGetUnlockPromise,
mockIsUnlocked,
);

const result = await wrappedFn('arg1', 'arg2');

expect(result).toBe('result');
expect(mockGetUnlockPromise).toHaveBeenCalledWith(true);
expect(mockFn).toHaveBeenCalledWith('arg1', 'arg2');
});

it('should pass shouldShowUnlockRequest as true to trigger popup', async () => {
const mockFn = jest.fn().mockResolvedValue('result');
const mockGetUnlockPromise = jest.fn().mockResolvedValue(undefined);
const mockIsUnlocked = jest.fn().mockReturnValue(false);

const wrappedFn = withUnlockPrompt(
mockFn,
mockGetUnlockPromise,
mockIsUnlocked,
);

await wrappedFn();

expect(mockGetUnlockPromise).toHaveBeenCalledWith(true);
});

it('should propagate errors from the wrapped function', async () => {
const mockError = new Error('test error');
const mockFn = jest.fn().mockRejectedValue(mockError);
const mockGetUnlockPromise = jest.fn();
const mockIsUnlocked = jest.fn().mockReturnValue(true);

const wrappedFn = withUnlockPrompt(
mockFn,
mockGetUnlockPromise,
mockIsUnlocked,
);

await expect(wrappedFn()).rejects.toThrow('test error');
expect(mockFn).toHaveBeenCalled();
});

it('should preserve function context and arguments', async () => {
const mockFn = jest.fn(async (a, b, c) => `${a}-${b}-${c}`);
const mockGetUnlockPromise = jest.fn();
const mockIsUnlocked = jest.fn().mockReturnValue(true);

const wrappedFn = withUnlockPrompt(
mockFn,
mockGetUnlockPromise,
mockIsUnlocked,
);

const result = await wrappedFn('one', 'two', 'three');

expect(result).toBe('one-two-three');
expect(mockFn).toHaveBeenCalledWith('one', 'two', 'three');
});
});

describe('createUnlockedMethodWrappers', () => {
it('should create wrapper with bound methods', () => {
const mockAppStateController = {
getUnlockPromise: jest.fn(),
};
const mockKeyringController = {
state: {
isUnlocked: true,
},
};

const { wrapWithUnlock } = createUnlockedMethodWrappers({
appStateController: mockAppStateController,
keyringController: mockKeyringController,
});

expect(wrapWithUnlock).toBeInstanceOf(Function);
});

it('should correctly wrap methods with unlock logic', async () => {
const mockGetUnlockPromise = jest.fn().mockResolvedValue(undefined);
const mockAppStateController = {
getUnlockPromise: mockGetUnlockPromise,
};
const mockKeyringController = {
state: {
isUnlocked: false,
},
};

const { wrapWithUnlock } = createUnlockedMethodWrappers({
appStateController: mockAppStateController,
keyringController: mockKeyringController,
});

const mockFn = jest.fn().mockResolvedValue('wrapped-result');
const wrappedFn = wrapWithUnlock(mockFn);

const result = await wrappedFn('test-arg');

expect(result).toBe('wrapped-result');
expect(mockGetUnlockPromise).toHaveBeenCalledWith(true);
expect(mockFn).toHaveBeenCalledWith('test-arg');
});

it('should use keyring state to check unlock status', async () => {
const mockAppStateController = {
getUnlockPromise: jest.fn(),
};
const mockKeyringController = {
state: {
isUnlocked: true,
},
};

const { wrapWithUnlock } = createUnlockedMethodWrappers({
appStateController: mockAppStateController,
keyringController: mockKeyringController,
});

const mockFn = jest.fn().mockResolvedValue('result');
const wrappedFn = wrapWithUnlock(mockFn);

await wrappedFn();

expect(mockAppStateController.getUnlockPromise).not.toHaveBeenCalled();
expect(mockFn).toHaveBeenCalled();
});
});
});
63 changes: 63 additions & 0 deletions app/scripts/lib/unlock-wrapper.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/**
* Utility to wrap methods that require the wallet to be unlocked.
* If the wallet is locked, it will trigger the unlock prompt and wait for unlock.
*/

import type { AppStateController } from '../controllers/app-state-controller';
import type { KeyringController } from '@metamask/keyring-controller';

/**
* Wraps a function to ensure the wallet is unlocked before execution.
* If locked, triggers the unlock prompt and waits for the user to unlock.
*
* @param fn - The function to wrap
* @param getUnlockPromise - Function to get the unlock promise from AppStateController
* @param isUnlocked - Function to check if the wallet is unlocked
* @returns A wrapped function that waits for unlock before executing
*/
export function withUnlockPrompt<T extends (...args: any[]) => Promise<any>>(
fn: T,
getUnlockPromise: (shouldShowUnlockRequest: boolean) => Promise<void>,
isUnlocked: () => boolean,
): T {
return (async (...args: Parameters<T>): Promise<ReturnType<T>> => {
if (!isUnlocked()) {
await getUnlockPromise(true);
}

return fn(...args);
}) as T;
}

/**
* Creates wrapped versions of signature/decrypt/encryption methods
* that wait for unlock before processing.
*
* @param options - Configuration options
* @param options.appStateController - The AppStateController instance
* @param options.keyringController - The KeyringController instance
* @returns Object containing wrapped methods
*/
export function createUnlockedMethodWrappers({
appStateController,
keyringController,
}: {
appStateController: AppStateController;
keyringController: KeyringController;
}) {
const getUnlockPromise =
appStateController.getUnlockPromise.bind(appStateController);
const isUnlocked = () => keyringController.state.isUnlocked;

return {
/**
* Wraps a method to ensure unlock before execution
*
* @param fn - The function to wrap
* @returns Wrapped function
*/
wrapWithUnlock: <T extends (...args: any[]) => Promise<any>>(fn: T): T => {
return withUnlockPrompt(fn, getUnlockPromise, isUnlocked);
},
};
}
25 changes: 19 additions & 6 deletions app/scripts/metamask-controller.js
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ import { getAccountsBySnapId } from './lib/snap-keyring';
///: END:ONLY_INCLUDE_IF
import { addDappTransaction, addTransaction } from './lib/transaction/util';
import { addTypedMessage, addPersonalMessage } from './lib/signature/util';
import { createUnlockedMethodWrappers } from './lib/unlock-wrapper';
import {
METAMASK_CAIP_MULTICHAIN_PROVIDER,
METAMASK_COOKIE_HANDLER,
Expand Down Expand Up @@ -941,6 +942,12 @@ export default class MetamaskController extends EventEmitter {
),
});

// Create unlock wrapper for methods that require wallet to be unlocked
const { wrapWithUnlock } = createUnlockedMethodWrappers({
appStateController: this.appStateController,
keyringController: this.keyringController,
});

this.metamaskMiddleware = createMetamaskMiddleware({
static: {
eth_syncing: false,
Expand All @@ -956,36 +963,42 @@ export default class MetamaskController extends EventEmitter {
),
// msg signing

processTypedMessage: (...args) =>
processTypedMessage: wrapWithUnlock((...args) =>
addTypedMessage({
signatureController: this.signatureController,
signatureParams: args,
}),
processTypedMessageV3: (...args) =>
),
processTypedMessageV3: wrapWithUnlock((...args) =>
addTypedMessage({
signatureController: this.signatureController,
signatureParams: args,
}),
processTypedMessageV4: (...args) =>
),
processTypedMessageV4: wrapWithUnlock((...args) =>
addTypedMessage({
signatureController: this.signatureController,
signatureParams: args,
}),
processPersonalMessage: (...args) =>
),
processPersonalMessage: wrapWithUnlock((...args) =>
addPersonalMessage({
signatureController: this.signatureController,
signatureParams: args,
}),
),

processEncryptionPublicKey:
processEncryptionPublicKey: wrapWithUnlock(
this.encryptionPublicKeyController.newRequestEncryptionPublicKey.bind(
this.encryptionPublicKeyController,
),
),

processDecryptMessage:
processDecryptMessage: wrapWithUnlock(
this.decryptMessageController.newRequestDecryptMessage.bind(
this.decryptMessageController,
),
),
getPendingNonce: this.getPendingNonce.bind(this),
getPendingTransactionByHash: (hash) =>
this.txController.state.transactions.find(
Expand Down
Loading
Loading