diff --git a/packages/examples/packages/invoke-snap/packages/core-signer/package.json b/packages/examples/packages/invoke-snap/packages/core-signer/package.json index 1f11726d56..57edf5722f 100644 --- a/packages/examples/packages/invoke-snap/packages/core-signer/package.json +++ b/packages/examples/packages/invoke-snap/packages/core-signer/package.json @@ -47,7 +47,7 @@ "@metamask/snaps-sdk": "workspace:^", "@metamask/utils": "^10.0.0", "@noble/curves": "^1.1.0", - "async-mutex": "^0.4.0" + "async-mutex": "^0.5.0" }, "devDependencies": { "@jest/globals": "^29.5.0", diff --git a/packages/examples/packages/invoke-snap/packages/core-signer/snap.manifest.json b/packages/examples/packages/invoke-snap/packages/core-signer/snap.manifest.json index e90d64456b..1e901ffa31 100644 --- a/packages/examples/packages/invoke-snap/packages/core-signer/snap.manifest.json +++ b/packages/examples/packages/invoke-snap/packages/core-signer/snap.manifest.json @@ -7,7 +7,7 @@ "url": "https://github.com/MetaMask/snaps.git" }, "source": { - "shasum": "9+79ZuJLehTDLvMMK/dR0C29/5Q/GRdvTq8EaxTwQkU=", + "shasum": "5YpYX3b3wdRQEjPd2lUeNsNK7FwiflxMLCCPYtDeLnQ=", "location": { "npm": { "filePath": "dist/bundle.js", diff --git a/packages/snaps-controllers/coverage.json b/packages/snaps-controllers/coverage.json index 357e17f097..7a462547c5 100644 --- a/packages/snaps-controllers/coverage.json +++ b/packages/snaps-controllers/coverage.json @@ -1,6 +1,6 @@ { - "branches": 93.06, - "functions": 96.54, - "lines": 98.02, - "statements": 97.74 + "branches": 92.96, + "functions": 96.56, + "lines": 98.05, + "statements": 97.77 } diff --git a/packages/snaps-controllers/package.json b/packages/snaps-controllers/package.json index d69e2c9b5e..79f6cb4056 100644 --- a/packages/snaps-controllers/package.json +++ b/packages/snaps-controllers/package.json @@ -95,6 +95,7 @@ "@metamask/snaps-utils": "workspace:^", "@metamask/utils": "^10.0.0", "@xstate/fsm": "^2.0.0", + "async-mutex": "^0.5.0", "browserify-zlib": "^0.2.0", "concat-stream": "^2.0.0", "fast-deep-equal": "^3.1.3", diff --git a/packages/snaps-controllers/src/snaps/SnapController.test.tsx b/packages/snaps-controllers/src/snaps/SnapController.test.tsx index a99703d581..aaa463bfed 100644 --- a/packages/snaps-controllers/src/snaps/SnapController.test.tsx +++ b/packages/snaps-controllers/src/snaps/SnapController.test.tsx @@ -52,6 +52,7 @@ import { MOCK_SNAP_NAME, DEFAULT_SOURCE_PATH, DEFAULT_ICON_PATH, + TEST_SECRET_RECOVERY_PHRASE_BYTES, } from '@metamask/snaps-utils/test-utils'; import type { SemVerRange, SemVerVersion, Json } from '@metamask/utils'; import { @@ -60,6 +61,7 @@ import { AssertionError, base64ToBytes, stringToBytes, + createDeferredPromise, } from '@metamask/utils'; import { File } from 'buffer'; import { webcrypto } from 'crypto'; @@ -78,6 +80,7 @@ import { getNodeEESMessenger, getPersistedSnapsState, getSnapController, + getSnapControllerEncryptor, getSnapControllerMessenger, getSnapControllerOptions, getSnapControllerWithEES, @@ -97,6 +100,7 @@ import { MOCK_WALLET_SNAP_PERMISSION, MockSnapsRegistry, sleep, + waitForStateChange, } from '../test-utils'; import { delay } from '../utils'; import { LEGACY_ENCRYPTION_KEY_DERIVATION_OPTIONS } from './constants'; @@ -2117,6 +2121,59 @@ describe('SnapController', () => { await service.terminateAllSnaps(); }); + it('clears encrypted state of Snaps when the client is locked', async () => { + const rootMessenger = getControllerMessenger(); + const messenger = getSnapControllerMessenger(rootMessenger); + + const state = { myVariable: 1 }; + + const mockEncryptedState = await encrypt( + ENCRYPTION_KEY, + state, + undefined, + undefined, + DEFAULT_ENCRYPTION_KEY_DERIVATION_OPTIONS, + ); + + const getMnemonic = jest + .fn() + .mockReturnValue(TEST_SECRET_RECOVERY_PHRASE_BYTES); + + const snapController = getSnapController( + getSnapControllerOptions({ + messenger, + state: { + snaps: { + [MOCK_SNAP_ID]: getPersistedSnapObject(), + }, + snapStates: { + [MOCK_SNAP_ID]: mockEncryptedState, + }, + }, + getMnemonic, + }), + ); + + expect( + await messenger.call('SnapController:getSnapState', MOCK_SNAP_ID, true), + ).toStrictEqual(state); + expect(getMnemonic).toHaveBeenCalledTimes(1); + + rootMessenger.publish('KeyringController:lock'); + + expect( + await messenger.call('SnapController:getSnapState', MOCK_SNAP_ID, true), + ).toStrictEqual(state); + + // We assume `getMnemonic` is called again because the controller needs to + // decrypt the state again. This is not an ideal way to test this, but it + // is the easiest to test this without exposing the internal state of the + // `SnapController`. + expect(getMnemonic).toHaveBeenCalledTimes(2); + + snapController.destroy(); + }); + describe('handleRequest', () => { it.each( Object.keys(handlerEndowments).filter( @@ -8801,6 +8858,7 @@ describe('SnapController', () => { ); const newState = { myVariable: 2 }; + const promise = waitForStateChange(messenger); await messenger.call( 'SnapController:updateSnapState', @@ -8817,6 +8875,8 @@ describe('SnapController', () => { DEFAULT_ENCRYPTION_KEY_DERIVATION_OPTIONS, ); + await promise; + const result = await messenger.call( 'SnapController:getSnapState', MOCK_SNAP_ID, @@ -8831,7 +8891,7 @@ describe('SnapController', () => { snapController.destroy(); }); - it('different snaps use different encryption keys', async () => { + it('uses different encryption keys for different snaps', async () => { const messenger = getSnapControllerMessenger(); const state = { foo: 'bar' }; @@ -8857,6 +8917,8 @@ describe('SnapController', () => { true, ); + const promise = waitForStateChange(messenger); + await messenger.call( 'SnapController:updateSnapState', MOCK_LOCAL_SNAP_ID, @@ -8864,6 +8926,8 @@ describe('SnapController', () => { true, ); + await promise; + const encryptedState1 = await encrypt( ENCRYPTION_KEY, state, @@ -9073,6 +9137,8 @@ describe('SnapController', () => { undefined, DEFAULT_ENCRYPTION_KEY_DERIVATION_OPTIONS, ); + + const promise = waitForStateChange(messenger); await messenger.call( 'SnapController:updateSnapState', MOCK_SNAP_ID, @@ -9080,6 +9146,8 @@ describe('SnapController', () => { true, ); + await promise; + expect(updateSnapStateSpy).toHaveBeenCalledTimes(1); expect(snapController.state.snapStates[MOCK_SNAP_ID]).toStrictEqual( mockEncryptedState, @@ -9137,6 +9205,8 @@ describe('SnapController', () => { ); const state = { foo: 'bar' }; + + const promise = waitForStateChange(messenger); await messenger.call( 'SnapController:updateSnapState', MOCK_SNAP_ID, @@ -9144,10 +9214,117 @@ describe('SnapController', () => { true, ); + await promise; + expect(pbkdf2Sha512).toHaveBeenCalledTimes(1); snapController.destroy(); }); + + it('queues multiple state updates', async () => { + const messenger = getSnapControllerMessenger(); + + jest.useFakeTimers(); + + const encryptor = getSnapControllerEncryptor(); + const { promise, resolve } = createDeferredPromise(); + const encryptWithKey = jest + .fn< + ReturnType, + Parameters + >() + .mockImplementation(async (...args) => { + resolve(); + await sleep(1); + return await encryptor.encryptWithKey(...args); + }); + + const snapController = getSnapController( + getSnapControllerOptions({ + messenger, + state: { + snaps: getPersistedSnapsState(), + }, + encryptor: { + ...getSnapControllerEncryptor(), + // @ts-expect-error - Missing required properties. + encryptWithKey, + }, + }), + ); + + const firstStateChange = waitForStateChange(messenger); + await messenger.call( + 'SnapController:updateSnapState', + MOCK_SNAP_ID, + { foo: 'bar' }, + true, + ); + + await messenger.call( + 'SnapController:updateSnapState', + MOCK_SNAP_ID, + { bar: 'baz' }, + true, + ); + + // We await this promise to ensure the timer is queued. + await promise; + jest.advanceTimersByTime(1); + + // After this point the second update should be queued. + await firstStateChange; + const secondStateChange = waitForStateChange(messenger); + + expect(encryptWithKey).toHaveBeenCalledTimes(1); + + // This is a bit hacky, but we can't simply advance the timer by 1ms + // because the second timer is not running yet. + jest.useRealTimers(); + await secondStateChange; + + expect(encryptWithKey).toHaveBeenCalledTimes(2); + + expect( + await messenger.call('SnapController:getSnapState', MOCK_SNAP_ID, true), + ).toStrictEqual({ bar: 'baz' }); + + snapController.destroy(); + }); + + it('logs an error message if the state fails to persist', async () => { + const messenger = getSnapControllerMessenger(); + + const errorValue = new Error('Failed to persist state.'); + const snapController = getSnapController( + getSnapControllerOptions({ + messenger, + state: { + snaps: getPersistedSnapsState(), + }, + // @ts-expect-error - Missing required properties. + encryptor: { + ...getSnapControllerEncryptor(), + encryptWithKey: jest.fn().mockRejectedValue(errorValue), + }, + }), + ); + + const { promise, resolve } = createDeferredPromise(); + const error = jest.spyOn(console, 'error').mockImplementation(resolve); + + await messenger.call( + 'SnapController:updateSnapState', + MOCK_SNAP_ID, + { foo: 'bar' }, + true, + ); + + await promise; + expect(error).toHaveBeenCalledWith(errorValue); + + snapController.destroy(); + }); }); describe('SnapController:clearSnapState', () => { @@ -9206,6 +9383,41 @@ describe('SnapController', () => { snapController.destroy(); }); + + it('logs an error message if the state fails to persist', async () => { + const messenger = getSnapControllerMessenger(); + + const errorValue = new Error('Failed to persist state.'); + const snapController = getSnapController( + getSnapControllerOptions({ + messenger, + state: { + snaps: getPersistedSnapsState(), + }, + // @ts-expect-error - Missing required properties. + encryptor: { + ...getSnapControllerEncryptor(), + encryptWithKey: jest.fn().mockRejectedValue(errorValue), + }, + }), + ); + + const { promise, resolve } = createDeferredPromise(); + const error = jest.spyOn(console, 'error').mockImplementation(resolve); + + // @ts-expect-error - Property `update` is protected. + // eslint-disable-next-line jest/prefer-spy-on + snapController.update = jest.fn().mockImplementation(() => { + throw errorValue; + }); + + await messenger.call('SnapController:clearSnapState', MOCK_SNAP_ID, true); + + await promise; + expect(error).toHaveBeenCalledWith(errorValue); + + snapController.destroy(); + }); }); describe('SnapController:updateBlockedSnaps', () => { diff --git a/packages/snaps-controllers/src/snaps/SnapController.ts b/packages/snaps-controllers/src/snaps/SnapController.ts index 173e8ba0d1..655c1bcdec 100644 --- a/packages/snaps-controllers/src/snaps/SnapController.ts +++ b/packages/snaps-controllers/src/snaps/SnapController.ts @@ -109,6 +109,7 @@ import { } from '@metamask/utils'; import type { StateMachine } from '@xstate/fsm'; import { createMachine, interpret } from '@xstate/fsm'; +import { Mutex } from 'async-mutex'; import type { Patch } from 'immer'; import { nanoid } from 'nanoid'; import semver from 'semver'; @@ -252,6 +253,21 @@ export interface SnapRuntimeData { * A boolean flag to determine whether the Snap is currently being stopped. */ stopping: boolean; + + /** + * Cached encrypted state of the Snap. + */ + state?: Record | null; + + /** + * Cached unencrypted state of the Snap. + */ + unencryptedState?: Record | null; + + /** + * A mutex to prevent concurrent state updates. + */ + stateMutex: Mutex; } export type SnapError = { @@ -552,6 +568,11 @@ export type SnapControllerStateChangeEvent = ControllerStateChangeEvent< SnapControllerState >; +type KeyringControllerLock = { + type: 'KeyringController:lock'; + payload: []; +}; + export type SnapControllerEvents = | SnapBlocked | SnapInstalled @@ -596,7 +617,8 @@ export type AllowedActions = export type AllowedEvents = | ExecutionServiceEvents | SnapInstalled - | SnapUpdated; + | SnapUpdated + | KeyringControllerLock; type SnapControllerMessenger = RestrictedControllerMessenger< typeof controllerName, @@ -906,6 +928,7 @@ export class SnapController extends BaseController< this._onOutboundResponse = this._onOutboundResponse.bind(this); this.#rollbackSnapshots = new Map(); this.#snapsRuntimeData = new Map(); + this.#pollForLastRequestStatus(); /* eslint-disable @typescript-eslint/unbound-method */ @@ -955,6 +978,11 @@ export class SnapController extends BaseController< }, ); + this.messagingSystem.subscribe( + 'KeyringController:lock', + this.#handleLock.bind(this), + ); + this.#initializeStateMachine(); this.#registerMessageHandlers(); @@ -1820,6 +1848,7 @@ export class SnapController extends BaseController< const useCache = this.#hasCachedEncryptionKey(snapId) || this.#encryptor.isVaultUpdated(state); + const { key } = await this.#getSnapEncryptionKey({ snapId, salt, @@ -1860,6 +1889,73 @@ export class SnapController extends BaseController< return JSON.stringify(encryptedState); } + /** + * Get the new Snap state to persist based on the given state and encryption + * flag. + * + * - If the state is null, return null. + * - If the state should be encrypted, return the encrypted state. + * - Otherwise, if the state should not be encrypted, return the JSON- + * stringified state. + * + * @param snapId - The Snap ID. + * @param state - The state to persist. + * @param encrypted - A flag to indicate whether to use encrypted storage or + * not. + * @returns The state to persist. + */ + async #getStateToPersist( + snapId: SnapId, + state: Record | null, + encrypted: boolean, + ) { + if (state === null) { + return null; + } + + if (encrypted) { + return await this.#encryptSnapState(snapId, state); + } + + return JSON.stringify(state); + } + + /** + * Persist the state of a Snap. + * + * This is run with a mutex to ensure that only one state update per Snap is + * processed at a time, avoiding possible race conditions. + * + * @param snapId - The Snap ID. + * @param newSnapState - The new state of the Snap. + * @param encrypted - A flag to indicate whether to use encrypted storage or + * not. + */ + async #persistSnapState( + snapId: SnapId, + newSnapState: Record | null, + encrypted: boolean, + ) { + const runtime = this.#getRuntimeExpect(snapId); + await runtime.stateMutex.runExclusive(async () => { + const newState = await this.#getStateToPersist( + snapId, + newSnapState, + encrypted, + ); + + if (encrypted) { + return this.update((state) => { + state.snapStates[snapId] = newState; + }); + } + + return this.update((state) => { + state.unencryptedSnapStates[snapId] = newState; + }); + }); + } + /** * Updates the own state of the snap with the given id. * This is distinct from the state MetaMask uses to manage snaps. @@ -1873,17 +1969,19 @@ export class SnapController extends BaseController< newSnapState: Record, encrypted: boolean, ) { - if (encrypted) { - const encryptedState = await this.#encryptSnapState(snapId, newSnapState); + const runtime = this.#getRuntimeExpect(snapId); - this.update((state) => { - state.snapStates[snapId] = encryptedState; - }); + if (encrypted) { + runtime.state = newSnapState; } else { - this.update((state) => { - state.unencryptedSnapStates[snapId] = JSON.stringify(newSnapState); - }); + runtime.unencryptedState = newSnapState; } + + // This is intentionally run asynchronously to avoid blocking the main + // thread. + this.#persistSnapState(snapId, newSnapState, encrypted).catch((error) => { + logError(error); + }); } /** @@ -1894,12 +1992,17 @@ export class SnapController extends BaseController< * @param encrypted - A flag to indicate whether to use encrypted storage or not. */ clearSnapState(snapId: SnapId, encrypted: boolean) { - this.update((state) => { - if (encrypted) { - state.snapStates[snapId] = null; - } else { - state.unencryptedSnapStates[snapId] = null; - } + const runtime = this.#getRuntimeExpect(snapId); + if (encrypted) { + runtime.state = null; + } else { + runtime.unencryptedState = null; + } + + // This is intentionally run asynchronously to avoid blocking the main + // thread. + this.#persistSnapState(snapId, null, encrypted).catch((error) => { + logError(error); }); } @@ -1912,6 +2015,13 @@ export class SnapController extends BaseController< * @returns The requested snap state or null if no state exists. */ async getSnapState(snapId: SnapId, encrypted: boolean): Promise { + const runtime = this.#getRuntimeExpect(snapId); + const cachedState = encrypted ? runtime.state : runtime.unencryptedState; + + if (cachedState !== undefined) { + return cachedState; + } + const state = encrypted ? this.state.snapStates[snapId] : this.state.unencryptedSnapStates[snapId]; @@ -1921,11 +2031,17 @@ export class SnapController extends BaseController< } if (!encrypted) { - // For performance reasons, we do not validate that the state is JSON, since we control serialization. - return JSON.parse(state); + // For performance reasons, we do not validate that the state is JSON, + // since we control serialization. + const json = JSON.parse(state); + runtime.unencryptedState = json; + + return json; } const decrypted = await this.#decryptSnapState(snapId, state); + runtime.state = decrypted; + return decrypted; } @@ -3706,6 +3822,7 @@ export class SnapController extends BaseController< pendingOutboundRequests: 0, interpreter, stopping: false, + stateMutex: new Mutex(), }); } @@ -3913,4 +4030,17 @@ export class SnapController extends BaseController< }, }); } + + /** + * Handle the `KeyringController:lock` event. + * + * Currently this clears the cached encrypted state (if any) for all Snaps. + */ + #handleLock() { + for (const runtime of this.#snapsRuntimeData.values()) { + runtime.encryptionKey = null; + runtime.encryptionSalt = null; + runtime.state = undefined; + } + } } diff --git a/packages/snaps-controllers/src/test-utils/controller.ts b/packages/snaps-controllers/src/test-utils/controller.ts index d1d12b701c..c4b33a9f8b 100644 --- a/packages/snaps-controllers/src/test-utils/controller.ts +++ b/packages/snaps-controllers/src/test-utils/controller.ts @@ -1,4 +1,8 @@ import type { ApprovalRequest } from '@metamask/approval-controller'; +import type { + ControllerMessenger, + RestrictedControllerMessenger, +} from '@metamask/base-controller'; import { encryptWithKey, decryptWithKey, @@ -48,16 +52,17 @@ import type { SnapInterfaceControllerEvents, StoredInterface, } from '../interface/SnapInterfaceController'; +import { SnapController } from '../snaps'; import type { AllowedActions, AllowedEvents, PersistedSnapControllerState, SnapControllerActions, SnapControllerEvents, + SnapControllerStateChangeEvent, SnapsRegistryActions, SnapsRegistryEvents, } from '../snaps'; -import { SnapController } from '../snaps'; import type { KeyDerivationOptions } from '../types'; import { MOCK_CRONJOB_PERMISSION } from './cronjob'; import { getNodeEES, getNodeEESMessenger } from './execution-environment'; @@ -462,6 +467,7 @@ export const getSnapControllerMessenger = ( 'SnapController:snapUpdated', 'SnapController:stateChange', 'SnapController:snapRolledback', + 'KeyringController:lock', ], allowedActions: [ 'ApprovalController:addRequest', @@ -535,7 +541,7 @@ export const DEFAULT_ENCRYPTION_KEY_DERIVATION_OPTIONS = { }, }; -const getSnapControllerEncryptor = () => { +export const getSnapControllerEncryptor = () => { return { encryptWithKey, decryptWithKey, @@ -830,3 +836,27 @@ export const getRestrictedSnapInsightsControllerMessenger = ( return controllerMessenger; }; + +/** + * Wait for the state change event to be emitted by the messenger. + * + * @param messenger - The messenger to listen to. + * @returns A promise that resolves when the state change event is emitted. + */ +export async function waitForStateChange( + messenger: + | ControllerMessenger + | RestrictedControllerMessenger< + 'SnapController', + any, + SnapControllerStateChangeEvent, + any, + 'SnapController:stateChange' + >, +) { + return new Promise((resolve) => { + messenger.subscribe('SnapController:stateChange', () => { + resolve(); + }); + }); +} diff --git a/yarn.lock b/yarn.lock index 8d3f030af5..94d0ac018f 100644 --- a/yarn.lock +++ b/yarn.lock @@ -4204,7 +4204,7 @@ __metadata: "@swc/jest": "npm:^0.2.26" "@typescript-eslint/eslint-plugin": "npm:^5.42.1" "@typescript-eslint/parser": "npm:^6.21.0" - async-mutex: "npm:^0.4.0" + async-mutex: "npm:^0.5.0" deepmerge: "npm:^4.2.2" depcheck: "npm:^1.4.7" eslint: "npm:^8.27.0" @@ -5752,6 +5752,7 @@ __metadata: "@wdio/spec-reporter": "npm:^8.19.0" "@wdio/static-server-service": "npm:^8.19.0" "@xstate/fsm": "npm:^2.0.0" + async-mutex: "npm:^0.5.0" browserify-zlib: "npm:^0.2.0" concat-stream: "npm:^2.0.0" deepmerge: "npm:^4.2.2" @@ -9645,12 +9646,12 @@ __metadata: languageName: node linkType: hard -"async-mutex@npm:^0.4.0": - version: 0.4.0 - resolution: "async-mutex@npm:0.4.0" +"async-mutex@npm:^0.5.0": + version: 0.5.0 + resolution: "async-mutex@npm:0.5.0" dependencies: tslib: "npm:^2.4.0" - checksum: 10/4a55065aae8c7283e45e2a8ac38ba9812f030696640d650c4ec62cfd67e5d61bd698e67b758a81fcb845e2d5ea1d857106f9235cc4282ad40cd1944b26fde1b2 + checksum: 10/4c6bfce1cc9cd43f723c4d96403ac5f4757f885c953b839cde6956ec8817ff39623b82d67614de10c7933e21626925882fb9bac367db7d15d7cb4f84228722c9 languageName: node linkType: hard