From 6861362404393664e2c7eb97b458be27ded144a1 Mon Sep 17 00:00:00 2001 From: Maarten Zuidhoorn Date: Mon, 13 Jan 2025 12:35:30 +0100 Subject: [PATCH] Cache snap state in memory (#2980) To improve performance when using state, especially when using the new state methods (#2916), we now cache state in memory, and write updates to disk asynchronously. The first time the state is fetched, it's cached in the Snap's runtime data, which is used in subsequent calls. --- .../packages/core-signer/package.json | 2 +- .../packages/core-signer/snap.manifest.json | 2 +- packages/snaps-controllers/coverage.json | 8 +- packages/snaps-controllers/package.json | 1 + .../src/snaps/SnapController.test.tsx | 214 +++++++++++++++++- .../src/snaps/SnapController.ts | 164 ++++++++++++-- .../src/test-utils/controller.ts | 34 ++- yarn.lock | 11 +- 8 files changed, 405 insertions(+), 31 deletions(-) diff --git a/packages/examples/packages/invoke-snap/packages/core-signer/package.json b/packages/examples/packages/invoke-snap/packages/core-signer/package.json index 1f11726d56..57edf5722f 100644 --- a/packages/examples/packages/invoke-snap/packages/core-signer/package.json +++ b/packages/examples/packages/invoke-snap/packages/core-signer/package.json @@ -47,7 +47,7 @@ "@metamask/snaps-sdk": "workspace:^", "@metamask/utils": "^10.0.0", "@noble/curves": "^1.1.0", - "async-mutex": "^0.4.0" + "async-mutex": "^0.5.0" }, "devDependencies": { "@jest/globals": "^29.5.0", diff --git a/packages/examples/packages/invoke-snap/packages/core-signer/snap.manifest.json b/packages/examples/packages/invoke-snap/packages/core-signer/snap.manifest.json index e90d64456b..1e901ffa31 100644 --- a/packages/examples/packages/invoke-snap/packages/core-signer/snap.manifest.json +++ b/packages/examples/packages/invoke-snap/packages/core-signer/snap.manifest.json @@ -7,7 +7,7 @@ "url": "https://github.com/MetaMask/snaps.git" }, "source": { - "shasum": "9+79ZuJLehTDLvMMK/dR0C29/5Q/GRdvTq8EaxTwQkU=", + "shasum": "5YpYX3b3wdRQEjPd2lUeNsNK7FwiflxMLCCPYtDeLnQ=", "location": { "npm": { "filePath": "dist/bundle.js", diff --git a/packages/snaps-controllers/coverage.json b/packages/snaps-controllers/coverage.json index 357e17f097..7a462547c5 100644 --- a/packages/snaps-controllers/coverage.json +++ b/packages/snaps-controllers/coverage.json @@ -1,6 +1,6 @@ { - "branches": 93.06, - "functions": 96.54, - "lines": 98.02, - "statements": 97.74 + "branches": 92.96, + "functions": 96.56, + "lines": 98.05, + "statements": 97.77 } diff --git a/packages/snaps-controllers/package.json b/packages/snaps-controllers/package.json index d69e2c9b5e..79f6cb4056 100644 --- a/packages/snaps-controllers/package.json +++ b/packages/snaps-controllers/package.json @@ -95,6 +95,7 @@ "@metamask/snaps-utils": "workspace:^", "@metamask/utils": "^10.0.0", "@xstate/fsm": "^2.0.0", + "async-mutex": "^0.5.0", "browserify-zlib": "^0.2.0", "concat-stream": "^2.0.0", "fast-deep-equal": "^3.1.3", diff --git a/packages/snaps-controllers/src/snaps/SnapController.test.tsx b/packages/snaps-controllers/src/snaps/SnapController.test.tsx index a99703d581..aaa463bfed 100644 --- a/packages/snaps-controllers/src/snaps/SnapController.test.tsx +++ b/packages/snaps-controllers/src/snaps/SnapController.test.tsx @@ -52,6 +52,7 @@ import { MOCK_SNAP_NAME, DEFAULT_SOURCE_PATH, DEFAULT_ICON_PATH, + TEST_SECRET_RECOVERY_PHRASE_BYTES, } from '@metamask/snaps-utils/test-utils'; import type { SemVerRange, SemVerVersion, Json } from '@metamask/utils'; import { @@ -60,6 +61,7 @@ import { AssertionError, base64ToBytes, stringToBytes, + createDeferredPromise, } from '@metamask/utils'; import { File } from 'buffer'; import { webcrypto } from 'crypto'; @@ -78,6 +80,7 @@ import { getNodeEESMessenger, getPersistedSnapsState, getSnapController, + getSnapControllerEncryptor, getSnapControllerMessenger, getSnapControllerOptions, getSnapControllerWithEES, @@ -97,6 +100,7 @@ import { MOCK_WALLET_SNAP_PERMISSION, MockSnapsRegistry, sleep, + waitForStateChange, } from '../test-utils'; import { delay } from '../utils'; import { LEGACY_ENCRYPTION_KEY_DERIVATION_OPTIONS } from './constants'; @@ -2117,6 +2121,59 @@ describe('SnapController', () => { await service.terminateAllSnaps(); }); + it('clears encrypted state of Snaps when the client is locked', async () => { + const rootMessenger = getControllerMessenger(); + const messenger = getSnapControllerMessenger(rootMessenger); + + const state = { myVariable: 1 }; + + const mockEncryptedState = await encrypt( + ENCRYPTION_KEY, + state, + undefined, + undefined, + DEFAULT_ENCRYPTION_KEY_DERIVATION_OPTIONS, + ); + + const getMnemonic = jest + .fn() + .mockReturnValue(TEST_SECRET_RECOVERY_PHRASE_BYTES); + + const snapController = getSnapController( + getSnapControllerOptions({ + messenger, + state: { + snaps: { + [MOCK_SNAP_ID]: getPersistedSnapObject(), + }, + snapStates: { + [MOCK_SNAP_ID]: mockEncryptedState, + }, + }, + getMnemonic, + }), + ); + + expect( + await messenger.call('SnapController:getSnapState', MOCK_SNAP_ID, true), + ).toStrictEqual(state); + expect(getMnemonic).toHaveBeenCalledTimes(1); + + rootMessenger.publish('KeyringController:lock'); + + expect( + await messenger.call('SnapController:getSnapState', MOCK_SNAP_ID, true), + ).toStrictEqual(state); + + // We assume `getMnemonic` is called again because the controller needs to + // decrypt the state again. This is not an ideal way to test this, but it + // is the easiest to test this without exposing the internal state of the + // `SnapController`. + expect(getMnemonic).toHaveBeenCalledTimes(2); + + snapController.destroy(); + }); + describe('handleRequest', () => { it.each( Object.keys(handlerEndowments).filter( @@ -8801,6 +8858,7 @@ describe('SnapController', () => { ); const newState = { myVariable: 2 }; + const promise = waitForStateChange(messenger); await messenger.call( 'SnapController:updateSnapState', @@ -8817,6 +8875,8 @@ describe('SnapController', () => { DEFAULT_ENCRYPTION_KEY_DERIVATION_OPTIONS, ); + await promise; + const result = await messenger.call( 'SnapController:getSnapState', MOCK_SNAP_ID, @@ -8831,7 +8891,7 @@ describe('SnapController', () => { snapController.destroy(); }); - it('different snaps use different encryption keys', async () => { + it('uses different encryption keys for different snaps', async () => { const messenger = getSnapControllerMessenger(); const state = { foo: 'bar' }; @@ -8857,6 +8917,8 @@ describe('SnapController', () => { true, ); + const promise = waitForStateChange(messenger); + await messenger.call( 'SnapController:updateSnapState', MOCK_LOCAL_SNAP_ID, @@ -8864,6 +8926,8 @@ describe('SnapController', () => { true, ); + await promise; + const encryptedState1 = await encrypt( ENCRYPTION_KEY, state, @@ -9073,6 +9137,8 @@ describe('SnapController', () => { undefined, DEFAULT_ENCRYPTION_KEY_DERIVATION_OPTIONS, ); + + const promise = waitForStateChange(messenger); await messenger.call( 'SnapController:updateSnapState', MOCK_SNAP_ID, @@ -9080,6 +9146,8 @@ describe('SnapController', () => { true, ); + await promise; + expect(updateSnapStateSpy).toHaveBeenCalledTimes(1); expect(snapController.state.snapStates[MOCK_SNAP_ID]).toStrictEqual( mockEncryptedState, @@ -9137,6 +9205,8 @@ describe('SnapController', () => { ); const state = { foo: 'bar' }; + + const promise = waitForStateChange(messenger); await messenger.call( 'SnapController:updateSnapState', MOCK_SNAP_ID, @@ -9144,10 +9214,117 @@ describe('SnapController', () => { true, ); + await promise; + expect(pbkdf2Sha512).toHaveBeenCalledTimes(1); snapController.destroy(); }); + + it('queues multiple state updates', async () => { + const messenger = getSnapControllerMessenger(); + + jest.useFakeTimers(); + + const encryptor = getSnapControllerEncryptor(); + const { promise, resolve } = createDeferredPromise(); + const encryptWithKey = jest + .fn< + ReturnType, + Parameters + >() + .mockImplementation(async (...args) => { + resolve(); + await sleep(1); + return await encryptor.encryptWithKey(...args); + }); + + const snapController = getSnapController( + getSnapControllerOptions({ + messenger, + state: { + snaps: getPersistedSnapsState(), + }, + encryptor: { + ...getSnapControllerEncryptor(), + // @ts-expect-error - Missing required properties. + encryptWithKey, + }, + }), + ); + + const firstStateChange = waitForStateChange(messenger); + await messenger.call( + 'SnapController:updateSnapState', + MOCK_SNAP_ID, + { foo: 'bar' }, + true, + ); + + await messenger.call( + 'SnapController:updateSnapState', + MOCK_SNAP_ID, + { bar: 'baz' }, + true, + ); + + // We await this promise to ensure the timer is queued. + await promise; + jest.advanceTimersByTime(1); + + // After this point the second update should be queued. + await firstStateChange; + const secondStateChange = waitForStateChange(messenger); + + expect(encryptWithKey).toHaveBeenCalledTimes(1); + + // This is a bit hacky, but we can't simply advance the timer by 1ms + // because the second timer is not running yet. + jest.useRealTimers(); + await secondStateChange; + + expect(encryptWithKey).toHaveBeenCalledTimes(2); + + expect( + await messenger.call('SnapController:getSnapState', MOCK_SNAP_ID, true), + ).toStrictEqual({ bar: 'baz' }); + + snapController.destroy(); + }); + + it('logs an error message if the state fails to persist', async () => { + const messenger = getSnapControllerMessenger(); + + const errorValue = new Error('Failed to persist state.'); + const snapController = getSnapController( + getSnapControllerOptions({ + messenger, + state: { + snaps: getPersistedSnapsState(), + }, + // @ts-expect-error - Missing required properties. + encryptor: { + ...getSnapControllerEncryptor(), + encryptWithKey: jest.fn().mockRejectedValue(errorValue), + }, + }), + ); + + const { promise, resolve } = createDeferredPromise(); + const error = jest.spyOn(console, 'error').mockImplementation(resolve); + + await messenger.call( + 'SnapController:updateSnapState', + MOCK_SNAP_ID, + { foo: 'bar' }, + true, + ); + + await promise; + expect(error).toHaveBeenCalledWith(errorValue); + + snapController.destroy(); + }); }); describe('SnapController:clearSnapState', () => { @@ -9206,6 +9383,41 @@ describe('SnapController', () => { snapController.destroy(); }); + + it('logs an error message if the state fails to persist', async () => { + const messenger = getSnapControllerMessenger(); + + const errorValue = new Error('Failed to persist state.'); + const snapController = getSnapController( + getSnapControllerOptions({ + messenger, + state: { + snaps: getPersistedSnapsState(), + }, + // @ts-expect-error - Missing required properties. + encryptor: { + ...getSnapControllerEncryptor(), + encryptWithKey: jest.fn().mockRejectedValue(errorValue), + }, + }), + ); + + const { promise, resolve } = createDeferredPromise(); + const error = jest.spyOn(console, 'error').mockImplementation(resolve); + + // @ts-expect-error - Property `update` is protected. + // eslint-disable-next-line jest/prefer-spy-on + snapController.update = jest.fn().mockImplementation(() => { + throw errorValue; + }); + + await messenger.call('SnapController:clearSnapState', MOCK_SNAP_ID, true); + + await promise; + expect(error).toHaveBeenCalledWith(errorValue); + + snapController.destroy(); + }); }); describe('SnapController:updateBlockedSnaps', () => { diff --git a/packages/snaps-controllers/src/snaps/SnapController.ts b/packages/snaps-controllers/src/snaps/SnapController.ts index 173e8ba0d1..655c1bcdec 100644 --- a/packages/snaps-controllers/src/snaps/SnapController.ts +++ b/packages/snaps-controllers/src/snaps/SnapController.ts @@ -109,6 +109,7 @@ import { } from '@metamask/utils'; import type { StateMachine } from '@xstate/fsm'; import { createMachine, interpret } from '@xstate/fsm'; +import { Mutex } from 'async-mutex'; import type { Patch } from 'immer'; import { nanoid } from 'nanoid'; import semver from 'semver'; @@ -252,6 +253,21 @@ export interface SnapRuntimeData { * A boolean flag to determine whether the Snap is currently being stopped. */ stopping: boolean; + + /** + * Cached encrypted state of the Snap. + */ + state?: Record | null; + + /** + * Cached unencrypted state of the Snap. + */ + unencryptedState?: Record | null; + + /** + * A mutex to prevent concurrent state updates. + */ + stateMutex: Mutex; } export type SnapError = { @@ -552,6 +568,11 @@ export type SnapControllerStateChangeEvent = ControllerStateChangeEvent< SnapControllerState >; +type KeyringControllerLock = { + type: 'KeyringController:lock'; + payload: []; +}; + export type SnapControllerEvents = | SnapBlocked | SnapInstalled @@ -596,7 +617,8 @@ export type AllowedActions = export type AllowedEvents = | ExecutionServiceEvents | SnapInstalled - | SnapUpdated; + | SnapUpdated + | KeyringControllerLock; type SnapControllerMessenger = RestrictedControllerMessenger< typeof controllerName, @@ -906,6 +928,7 @@ export class SnapController extends BaseController< this._onOutboundResponse = this._onOutboundResponse.bind(this); this.#rollbackSnapshots = new Map(); this.#snapsRuntimeData = new Map(); + this.#pollForLastRequestStatus(); /* eslint-disable @typescript-eslint/unbound-method */ @@ -955,6 +978,11 @@ export class SnapController extends BaseController< }, ); + this.messagingSystem.subscribe( + 'KeyringController:lock', + this.#handleLock.bind(this), + ); + this.#initializeStateMachine(); this.#registerMessageHandlers(); @@ -1820,6 +1848,7 @@ export class SnapController extends BaseController< const useCache = this.#hasCachedEncryptionKey(snapId) || this.#encryptor.isVaultUpdated(state); + const { key } = await this.#getSnapEncryptionKey({ snapId, salt, @@ -1860,6 +1889,73 @@ export class SnapController extends BaseController< return JSON.stringify(encryptedState); } + /** + * Get the new Snap state to persist based on the given state and encryption + * flag. + * + * - If the state is null, return null. + * - If the state should be encrypted, return the encrypted state. + * - Otherwise, if the state should not be encrypted, return the JSON- + * stringified state. + * + * @param snapId - The Snap ID. + * @param state - The state to persist. + * @param encrypted - A flag to indicate whether to use encrypted storage or + * not. + * @returns The state to persist. + */ + async #getStateToPersist( + snapId: SnapId, + state: Record | null, + encrypted: boolean, + ) { + if (state === null) { + return null; + } + + if (encrypted) { + return await this.#encryptSnapState(snapId, state); + } + + return JSON.stringify(state); + } + + /** + * Persist the state of a Snap. + * + * This is run with a mutex to ensure that only one state update per Snap is + * processed at a time, avoiding possible race conditions. + * + * @param snapId - The Snap ID. + * @param newSnapState - The new state of the Snap. + * @param encrypted - A flag to indicate whether to use encrypted storage or + * not. + */ + async #persistSnapState( + snapId: SnapId, + newSnapState: Record | null, + encrypted: boolean, + ) { + const runtime = this.#getRuntimeExpect(snapId); + await runtime.stateMutex.runExclusive(async () => { + const newState = await this.#getStateToPersist( + snapId, + newSnapState, + encrypted, + ); + + if (encrypted) { + return this.update((state) => { + state.snapStates[snapId] = newState; + }); + } + + return this.update((state) => { + state.unencryptedSnapStates[snapId] = newState; + }); + }); + } + /** * Updates the own state of the snap with the given id. * This is distinct from the state MetaMask uses to manage snaps. @@ -1873,17 +1969,19 @@ export class SnapController extends BaseController< newSnapState: Record, encrypted: boolean, ) { - if (encrypted) { - const encryptedState = await this.#encryptSnapState(snapId, newSnapState); + const runtime = this.#getRuntimeExpect(snapId); - this.update((state) => { - state.snapStates[snapId] = encryptedState; - }); + if (encrypted) { + runtime.state = newSnapState; } else { - this.update((state) => { - state.unencryptedSnapStates[snapId] = JSON.stringify(newSnapState); - }); + runtime.unencryptedState = newSnapState; } + + // This is intentionally run asynchronously to avoid blocking the main + // thread. + this.#persistSnapState(snapId, newSnapState, encrypted).catch((error) => { + logError(error); + }); } /** @@ -1894,12 +1992,17 @@ export class SnapController extends BaseController< * @param encrypted - A flag to indicate whether to use encrypted storage or not. */ clearSnapState(snapId: SnapId, encrypted: boolean) { - this.update((state) => { - if (encrypted) { - state.snapStates[snapId] = null; - } else { - state.unencryptedSnapStates[snapId] = null; - } + const runtime = this.#getRuntimeExpect(snapId); + if (encrypted) { + runtime.state = null; + } else { + runtime.unencryptedState = null; + } + + // This is intentionally run asynchronously to avoid blocking the main + // thread. + this.#persistSnapState(snapId, null, encrypted).catch((error) => { + logError(error); }); } @@ -1912,6 +2015,13 @@ export class SnapController extends BaseController< * @returns The requested snap state or null if no state exists. */ async getSnapState(snapId: SnapId, encrypted: boolean): Promise { + const runtime = this.#getRuntimeExpect(snapId); + const cachedState = encrypted ? runtime.state : runtime.unencryptedState; + + if (cachedState !== undefined) { + return cachedState; + } + const state = encrypted ? this.state.snapStates[snapId] : this.state.unencryptedSnapStates[snapId]; @@ -1921,11 +2031,17 @@ export class SnapController extends BaseController< } if (!encrypted) { - // For performance reasons, we do not validate that the state is JSON, since we control serialization. - return JSON.parse(state); + // For performance reasons, we do not validate that the state is JSON, + // since we control serialization. + const json = JSON.parse(state); + runtime.unencryptedState = json; + + return json; } const decrypted = await this.#decryptSnapState(snapId, state); + runtime.state = decrypted; + return decrypted; } @@ -3706,6 +3822,7 @@ export class SnapController extends BaseController< pendingOutboundRequests: 0, interpreter, stopping: false, + stateMutex: new Mutex(), }); } @@ -3913,4 +4030,17 @@ export class SnapController extends BaseController< }, }); } + + /** + * Handle the `KeyringController:lock` event. + * + * Currently this clears the cached encrypted state (if any) for all Snaps. + */ + #handleLock() { + for (const runtime of this.#snapsRuntimeData.values()) { + runtime.encryptionKey = null; + runtime.encryptionSalt = null; + runtime.state = undefined; + } + } } diff --git a/packages/snaps-controllers/src/test-utils/controller.ts b/packages/snaps-controllers/src/test-utils/controller.ts index d1d12b701c..c4b33a9f8b 100644 --- a/packages/snaps-controllers/src/test-utils/controller.ts +++ b/packages/snaps-controllers/src/test-utils/controller.ts @@ -1,4 +1,8 @@ import type { ApprovalRequest } from '@metamask/approval-controller'; +import type { + ControllerMessenger, + RestrictedControllerMessenger, +} from '@metamask/base-controller'; import { encryptWithKey, decryptWithKey, @@ -48,16 +52,17 @@ import type { SnapInterfaceControllerEvents, StoredInterface, } from '../interface/SnapInterfaceController'; +import { SnapController } from '../snaps'; import type { AllowedActions, AllowedEvents, PersistedSnapControllerState, SnapControllerActions, SnapControllerEvents, + SnapControllerStateChangeEvent, SnapsRegistryActions, SnapsRegistryEvents, } from '../snaps'; -import { SnapController } from '../snaps'; import type { KeyDerivationOptions } from '../types'; import { MOCK_CRONJOB_PERMISSION } from './cronjob'; import { getNodeEES, getNodeEESMessenger } from './execution-environment'; @@ -462,6 +467,7 @@ export const getSnapControllerMessenger = ( 'SnapController:snapUpdated', 'SnapController:stateChange', 'SnapController:snapRolledback', + 'KeyringController:lock', ], allowedActions: [ 'ApprovalController:addRequest', @@ -535,7 +541,7 @@ export const DEFAULT_ENCRYPTION_KEY_DERIVATION_OPTIONS = { }, }; -const getSnapControllerEncryptor = () => { +export const getSnapControllerEncryptor = () => { return { encryptWithKey, decryptWithKey, @@ -830,3 +836,27 @@ export const getRestrictedSnapInsightsControllerMessenger = ( return controllerMessenger; }; + +/** + * Wait for the state change event to be emitted by the messenger. + * + * @param messenger - The messenger to listen to. + * @returns A promise that resolves when the state change event is emitted. + */ +export async function waitForStateChange( + messenger: + | ControllerMessenger + | RestrictedControllerMessenger< + 'SnapController', + any, + SnapControllerStateChangeEvent, + any, + 'SnapController:stateChange' + >, +) { + return new Promise((resolve) => { + messenger.subscribe('SnapController:stateChange', () => { + resolve(); + }); + }); +} diff --git a/yarn.lock b/yarn.lock index 8d3f030af5..94d0ac018f 100644 --- a/yarn.lock +++ b/yarn.lock @@ -4204,7 +4204,7 @@ __metadata: "@swc/jest": "npm:^0.2.26" "@typescript-eslint/eslint-plugin": "npm:^5.42.1" "@typescript-eslint/parser": "npm:^6.21.0" - async-mutex: "npm:^0.4.0" + async-mutex: "npm:^0.5.0" deepmerge: "npm:^4.2.2" depcheck: "npm:^1.4.7" eslint: "npm:^8.27.0" @@ -5752,6 +5752,7 @@ __metadata: "@wdio/spec-reporter": "npm:^8.19.0" "@wdio/static-server-service": "npm:^8.19.0" "@xstate/fsm": "npm:^2.0.0" + async-mutex: "npm:^0.5.0" browserify-zlib: "npm:^0.2.0" concat-stream: "npm:^2.0.0" deepmerge: "npm:^4.2.2" @@ -9645,12 +9646,12 @@ __metadata: languageName: node linkType: hard -"async-mutex@npm:^0.4.0": - version: 0.4.0 - resolution: "async-mutex@npm:0.4.0" +"async-mutex@npm:^0.5.0": + version: 0.5.0 + resolution: "async-mutex@npm:0.5.0" dependencies: tslib: "npm:^2.4.0" - checksum: 10/4a55065aae8c7283e45e2a8ac38ba9812f030696640d650c4ec62cfd67e5d61bd698e67b758a81fcb845e2d5ea1d857106f9235cc4282ad40cd1944b26fde1b2 + checksum: 10/4c6bfce1cc9cd43f723c4d96403ac5f4757f885c953b839cde6956ec8817ff39623b82d67614de10c7933e21626925882fb9bac367db7d15d7cb4f84228722c9 languageName: node linkType: hard