diff --git a/packages/shield-controller/CHANGELOG.md b/packages/shield-controller/CHANGELOG.md index 325e8f0f216..da6198508ea 100644 --- a/packages/shield-controller/CHANGELOG.md +++ b/packages/shield-controller/CHANGELOG.md @@ -10,6 +10,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - Bump `@metamask/transaction-controller` from `^60.7.0` to `^60.8.0` ([#6883](https://github.com/MetaMask/core/pull/6883)) +- Updated internal coverage result polling and log logic. ([#6847](https://github.com/MetaMask/core/pull/6847)) + - Added cancellation logic to the polling. + - Updated implementation of timeout. + - Cancel any pending requests before starting new polling or logging. +- Updated TransactionMeta comparison in `TransactionController:stateChange` subscriber to avoid triggering multiple check coverage result unnecessarily. ([#6847](https://github.com/MetaMask/core/pull/6847)) +- Removed `Personal Sign` check from the check signature coverage result. ([#6847](https://github.com/MetaMask/core/pull/6847)) ## [0.3.2] diff --git a/packages/shield-controller/src/ShieldController.test.ts b/packages/shield-controller/src/ShieldController.test.ts index b741aaa5dfc..f223645421f 100644 --- a/packages/shield-controller/src/ShieldController.test.ts +++ b/packages/shield-controller/src/ShieldController.test.ts @@ -11,6 +11,7 @@ import { } from '@metamask/transaction-controller'; import { ShieldController } from './ShieldController'; +import { TX_META_SIMULATION_DATA_MOCKS } from '../tests/data'; import { createMockBackend, MOCK_COVERAGE_ID } from '../tests/mocks/backend'; import { createMockMessenger } from '../tests/mocks/messenger'; import { @@ -169,6 +170,47 @@ describe('ShieldController', () => { }); }); + TX_META_SIMULATION_DATA_MOCKS.forEach( + ({ description, previousSimulationData, newSimulationData }) => { + it(`should check coverage when ${description}`, async () => { + const { baseMessenger, backend } = setup(); + const previousTxMeta = { + ...generateMockTxMeta(), + simulationData: previousSimulationData, + }; + const coverageResultReceived = + setupCoverageResultReceived(baseMessenger); + + // Add transaction. + baseMessenger.publish( + 'TransactionController:stateChange', + { transactions: [previousTxMeta] } as TransactionControllerState, + undefined as never, + ); + expect(await coverageResultReceived).toBeUndefined(); + expect(backend.checkCoverage).toHaveBeenCalledWith({ + txMeta: previousTxMeta, + }); + + // Simulate transaction. + const txMeta2 = { ...previousTxMeta }; + txMeta2.simulationData = newSimulationData; + const coverageResultReceived2 = + setupCoverageResultReceived(baseMessenger); + baseMessenger.publish( + 'TransactionController:stateChange', + { transactions: [txMeta2] } as TransactionControllerState, + undefined as never, + ); + expect(await coverageResultReceived2).toBeUndefined(); + expect(backend.checkCoverage).toHaveBeenCalledWith({ + coverageId: MOCK_COVERAGE_ID, + txMeta: txMeta2, + }); + }); + }, + ); + it('throws an error when the coverage ID has changed', async () => { const { controller, backend } = setup(); backend.checkCoverage.mockResolvedValueOnce({ diff --git a/packages/shield-controller/src/ShieldController.ts b/packages/shield-controller/src/ShieldController.ts index 470f2982709..1b04e73fe67 100644 --- a/packages/shield-controller/src/ShieldController.ts +++ b/packages/shield-controller/src/ShieldController.ts @@ -5,7 +5,6 @@ import type { } from '@metamask/base-controller'; import { SignatureRequestStatus, - SignatureRequestType, type SignatureRequest, type SignatureStateChange, } from '@metamask/signature-controller'; @@ -236,10 +235,7 @@ export class ShieldController extends BaseController< // Check coverage if the signature request is new and has type // `personal_sign`. - if ( - !previousSignatureRequest && - signatureRequest.type === SignatureRequestType.PersonalSign - ) { + if (!previousSignatureRequest) { this.checkSignatureCoverage(signatureRequest).catch( // istanbul ignore next (error) => log('Error checking coverage:', error), @@ -268,15 +264,15 @@ export class ShieldController extends BaseController< ); for (const transaction of transactions) { const previousTransaction = previousTransactionsById.get(transaction.id); + // Check if the simulation data has changed. + const simulationDataChanged = this.#compareTransactionSimulationData( + transaction.simulationData, + previousTransaction?.simulationData, + ); // Check coverage if the transaction is new or if the simulation data has // changed. - if ( - !previousTransaction || - // Checking reference equality is sufficient because this object is - // replaced if the simulation data has changed. - previousTransaction.simulationData !== transaction.simulationData - ) { + if (!previousTransaction || simulationDataChanged) { this.checkCoverage(transaction).catch( // istanbul ignore next (error) => log('Error checking coverage:', error), @@ -443,4 +439,61 @@ export class ShieldController extends BaseController< #getLatestCoverageId(itemId: string): string | undefined { return this.state.coverageResults[itemId]?.results[0]?.coverageId; } + + /** + * Compares the simulation data of a transaction. + * + * @param simulationData - The simulation data of the transaction. + * @param previousSimulationData - The previous simulation data of the transaction. + * @returns Whether the simulation data has changed. + */ + #compareTransactionSimulationData( + simulationData?: TransactionMeta['simulationData'], + previousSimulationData?: TransactionMeta['simulationData'], + ) { + if (!simulationData && !previousSimulationData) { + return false; + } + + // check the simulation error + if ( + simulationData?.error?.code !== previousSimulationData?.error?.code || + simulationData?.error?.message !== previousSimulationData?.error?.message + ) { + return true; + } + + // check the native balance change + if ( + simulationData?.nativeBalanceChange?.difference !== + previousSimulationData?.nativeBalanceChange?.difference || + simulationData?.nativeBalanceChange?.newBalance !== + previousSimulationData?.nativeBalanceChange?.newBalance || + simulationData?.nativeBalanceChange?.previousBalance !== + previousSimulationData?.nativeBalanceChange?.previousBalance || + simulationData?.nativeBalanceChange?.isDecrease !== + previousSimulationData?.nativeBalanceChange?.isDecrease + ) { + return true; + } + + // check the token balance changes + if ( + simulationData?.tokenBalanceChanges?.length !== + previousSimulationData?.tokenBalanceChanges?.length || + simulationData?.tokenBalanceChanges?.some( + (tokenBalanceChange, index) => + tokenBalanceChange.difference !== + previousSimulationData?.tokenBalanceChanges?.[index]?.difference, + ) + ) { + return true; + } + + // check the isUpdatedAfterSecurityCheck + return ( + simulationData?.isUpdatedAfterSecurityCheck !== + previousSimulationData?.isUpdatedAfterSecurityCheck + ); + } } diff --git a/packages/shield-controller/src/backend.test.ts b/packages/shield-controller/src/backend.test.ts index b176059b61e..41d69550ec6 100644 --- a/packages/shield-controller/src/backend.test.ts +++ b/packages/shield-controller/src/backend.test.ts @@ -45,6 +45,11 @@ function setup({ } describe('ShieldRemoteBackend', () => { + afterEach(() => { + // Clean up mocks after each test + jest.clearAllMocks(); + }); + it('should check coverage', async () => { const { backend, fetchMock, getAccessToken } = setup(); @@ -143,7 +148,7 @@ describe('ShieldRemoteBackend', () => { const txMeta = generateMockTxMeta(); await expect(backend.checkCoverage({ txMeta })).rejects.toThrow( - 'Timeout waiting for coverage result', + 'getCoverageResult: Request timed out', ); // Waiting here ensures coverage of the unexpected error and lets us know diff --git a/packages/shield-controller/src/backend.ts b/packages/shield-controller/src/backend.ts index dcc863850de..6858476a0a9 100644 --- a/packages/shield-controller/src/backend.ts +++ b/packages/shield-controller/src/backend.ts @@ -1,6 +1,7 @@ import type { SignatureRequest } from '@metamask/signature-controller'; import type { TransactionMeta } from '@metamask/transaction-controller'; +import { PollingWithTimeoutAndAbort } from './polling-with-timeout-abort'; import type { CheckCoverageRequest, CheckSignatureCoverageRequest, @@ -58,6 +59,8 @@ export class ShieldRemoteBackend implements ShieldBackend { readonly #fetch: typeof globalThis.fetch; + readonly #pollingWithTimeout: PollingWithTimeoutAndAbort; + constructor({ getAccessToken, getCoverageResultTimeout = 5000, // milliseconds @@ -76,6 +79,10 @@ export class ShieldRemoteBackend implements ShieldBackend { this.#getCoverageResultPollInterval = getCoverageResultPollInterval; this.#baseUrl = baseUrl; this.#fetch = fetchFn; + this.#pollingWithTimeout = new PollingWithTimeoutAndAbort({ + timeout: getCoverageResultTimeout, + pollInterval: getCoverageResultPollInterval, + }); } async checkCoverage(req: CheckCoverageRequest): Promise { @@ -90,6 +97,7 @@ export class ShieldRemoteBackend implements ShieldBackend { const txCoverageResultUrl = `${this.#baseUrl}/v1/transaction/coverage/result`; const coverageResult = await this.#getCoverageResult(coverageId, { + requestId: req.txMeta.id, coverageResultUrl: txCoverageResultUrl, }); return { @@ -114,6 +122,7 @@ export class ShieldRemoteBackend implements ShieldBackend { const signatureCoverageResultUrl = `${this.#baseUrl}/v1/signature/coverage/result`; const coverageResult = await this.#getCoverageResult(coverageId, { + requestId: req.signatureRequest.id, coverageResultUrl: signatureCoverageResultUrl, }); return { @@ -132,6 +141,9 @@ export class ShieldRemoteBackend implements ShieldBackend { ...initBody, }; + // clean up the pending coverage result polling + this.#pollingWithTimeout.abortPendingRequests(req.signatureRequest.id); + const res = await this.#fetch( `${this.#baseUrl}/v1/signature/coverage/log`, { @@ -153,6 +165,9 @@ export class ShieldRemoteBackend implements ShieldBackend { ...initBody, }; + // clean up the pending coverage result polling + this.#pollingWithTimeout.abortPendingRequests(req.txMeta.id); + const res = await this.#fetch( `${this.#baseUrl}/v1/transaction/coverage/log`, { @@ -183,7 +198,8 @@ export class ShieldRemoteBackend implements ShieldBackend { async #getCoverageResult( coverageId: string, - configs: { + config: { + requestId: string; coverageResultUrl: string; timeout?: number; pollInterval?: number; @@ -192,40 +208,33 @@ export class ShieldRemoteBackend implements ShieldBackend { const reqBody: GetCoverageResultRequest = { coverageId, }; - - const timeout = configs?.timeout ?? this.#getCoverageResultTimeout; - const pollInterval = - configs?.pollInterval ?? this.#getCoverageResultPollInterval; - + const pollingOptions = { + timeout: config.timeout ?? this.#getCoverageResultTimeout, + pollInterval: config.pollInterval ?? this.#getCoverageResultPollInterval, + fnName: 'getCoverageResult', + }; const headers = await this.#createHeaders(); - return await new Promise((resolve, reject) => { - let timeoutReached = false; - setTimeout(() => { - timeoutReached = true; - reject(new Error('Timeout waiting for coverage result')); - }, timeout); - const poll = async (): Promise => { - // The timeoutReached variable is modified in the timeout callback. - // eslint-disable-next-line no-unmodified-loop-condition - while (!timeoutReached) { - const startTime = Date.now(); - const res = await this.#fetch(configs.coverageResultUrl, { - method: 'POST', - headers, - body: JSON.stringify(reqBody), - }); - if (res.status === 200) { - return (await res.json()) as GetCoverageResultResponse; - } - await sleep(pollInterval - (Date.now() - startTime)); + return await new Promise((resolve, reject) => { + const requestCoverageFn = async ( + signal: AbortSignal, + ): Promise => { + const res = await this.#fetch(config.coverageResultUrl, { + method: 'POST', + headers, + body: JSON.stringify(reqBody), + signal, + }); + if (res.status === 200) { + return (await res.json()) as GetCoverageResultResponse; } - // The following line will not have an effect as the upper level promise - // will already be rejected by now. - throw new Error('unexpected error'); + throw new Error(`Failed to get coverage result: ${res.status}`); }; - poll().then(resolve).catch(reject); + this.#pollingWithTimeout + .pollRequest(config.requestId, requestCoverageFn, pollingOptions) + .then(resolve) + .catch(reject); }); } @@ -238,16 +247,6 @@ export class ShieldRemoteBackend implements ShieldBackend { } } -/** - * Sleep for a specified amount of time. - * - * @param ms - The number of milliseconds to sleep. - * @returns A promise that resolves after the specified amount of time. - */ -async function sleep(ms: number) { - return new Promise((resolve) => setTimeout(resolve, ms)); -} - /** * Make the body for the init coverage check request. * diff --git a/packages/shield-controller/src/polling-with-timeout-abort.test.ts b/packages/shield-controller/src/polling-with-timeout-abort.test.ts new file mode 100644 index 00000000000..d77e0110abb --- /dev/null +++ b/packages/shield-controller/src/polling-with-timeout-abort.test.ts @@ -0,0 +1,105 @@ +import { PollingWithTimeoutAndAbort } from './polling-with-timeout-abort'; +import { delay } from '../tests/utils'; + +describe('PollingWithTimeoutAndAbort', () => { + it('should timeout when the request does not resolve within the timeout period', async () => { + const pollingWithTimeout = new PollingWithTimeoutAndAbort({ + timeout: 100, + pollInterval: 10, + }); + + const requestFn = jest + .fn() + .mockImplementation(async (_signal: AbortSignal) => { + return new Promise((_resolve, reject) => { + setTimeout(() => { + reject(new Error('test error')); + }, 10); + }); + }); + + await expect( + pollingWithTimeout.pollRequest('test', requestFn, { + fnName: 'test', + timeout: 100, + }), + ).rejects.toThrow('test: Request timed out'); + }); + + it('should timeout with default polling options', async () => { + const pollingWithTimeout = new PollingWithTimeoutAndAbort({ + timeout: 100, + pollInterval: 10, + }); + const requestFn = jest + .fn() + .mockImplementation(async (_signal: AbortSignal) => { + return new Promise((_resolve, reject) => { + setTimeout(() => { + reject(new Error('test error')); + }, 10); + }); + }); + + await expect( + pollingWithTimeout.pollRequest('test', requestFn), + ).rejects.toThrow(': Request timed out'); + }); + + it('should abort pending requests when new request is made', async () => { + const pollingWithTimeout = new PollingWithTimeoutAndAbort({ + timeout: 1000, + pollInterval: 20, + }); + + const requestFn = jest + .fn() + .mockImplementation(async (signal: AbortSignal) => { + return new Promise((resolve, reject) => { + setTimeout(() => { + // eslint-disable-next-line jest/no-conditional-in-test -- we want to simulate the abort signal being triggered during the request + if (signal.aborted) { + reject(new Error('test error')); + } + resolve('test result'); + }, 100); + }); + }); + + const firstAttempt = pollingWithTimeout.pollRequest('test', requestFn, { + fnName: 'test', + }); + await delay(15); // small delay to let the first request start + const secondAttempt = pollingWithTimeout.pollRequest('test', requestFn, { + fnName: 'test', + }); + + await expect(firstAttempt).rejects.toThrow('test: Request cancelled'); // first request should be aborted by the second request + const result = await secondAttempt; + expect(result).toBe('test result'); // second request should succeed + }); + + it('should abort pending requests when abortPendingRequests is called', async () => { + const pollingWithTimeout = new PollingWithTimeoutAndAbort({ + timeout: 1000, + pollInterval: 20, + }); + + const requestFn = jest + .fn() + .mockImplementation(async (_signal: AbortSignal) => { + return new Promise((_resolve, reject) => { + setTimeout(() => { + reject(new Error('test error')); + }, 100); + }); + }); + + const request = pollingWithTimeout.pollRequest('test', requestFn, { + fnName: 'test', + }); + await delay(15); // small delay to let the request start + pollingWithTimeout.abortPendingRequests('test'); + await expect(request).rejects.toThrow('test: Request cancelled'); + }); +}); diff --git a/packages/shield-controller/src/polling-with-timeout-abort.ts b/packages/shield-controller/src/polling-with-timeout-abort.ts new file mode 100644 index 00000000000..68d7f9f3018 --- /dev/null +++ b/packages/shield-controller/src/polling-with-timeout-abort.ts @@ -0,0 +1,181 @@ +export type RequestEntry = { + abortController: AbortController; // The abort controller for the request + abortHandler: () => void; // The abort handler for the request + timerId: NodeJS.Timeout; // The timer ID for the request timeout +}; + +export type RequestFn = ( + signal: AbortSignal, +) => Promise; + +export class PollingWithTimeoutAndAbort { + readonly ABORT_REASON_TIMEOUT = 'Timeout'; + + readonly ABORT_REASON_USER_CANCELLATION = 'User cancellation'; + + // Map of request ID to request entry + readonly #requestEntries: Map = new Map(); + + readonly #timeout: number; + + readonly #pollInterval: number; + + constructor(config: { timeout: number; pollInterval: number }) { + this.#timeout = config.timeout; + this.#pollInterval = config.pollInterval; + } + + async pollRequest( + requestId: string, + requestFn: RequestFn, + pollingOptions: { + timeout?: number; + pollInterval?: number; + fnName?: string; + } = { + fnName: '', + }, + ) { + const timeout = pollingOptions.timeout ?? this.#timeout; + const pollInterval = pollingOptions.pollInterval ?? this.#pollInterval; + + // clean up the request entry if it exists + this.abortPendingRequests(requestId); + + // insert the request entry for the next polling cycle + const { abortController } = this.#insertRequestEntry(requestId, timeout); + + while (!abortController.signal.aborted) { + try { + const result = await requestFn(abortController.signal); + // polling success, we just need to clean up the request entry and return the result + this.#cleanUpOnFinished(requestId); + return result; + } catch { + // polling failed due to timeout or cancelled, + // we need to clean up the request entry and throw the error + if (this.#isAbortedAndNotTimeoutReason(abortController.signal)) { + throw new Error(`${pollingOptions.fnName}: Request cancelled`); + } + } + await this.#delay(pollInterval); + } + + // The following line will not have an effect as the upper level promise + // will already be rejected by now. + throw new Error(`${pollingOptions.fnName}: Request timed out`); + } + + /** + * Abort the pending requests. + * This will clean up the request entry if it exists, and abort the pending request if it exists. + * + * @param requestId - The ID of the request to abort. + */ + abortPendingRequests(requestId: string) { + // firstly clean up the request entry if it exists + // note: this does not abort the request, it only cleans up the request entry for the next polling cycle + const existingEntry = this.#cleanUpRequestEntryIfExists(requestId); + // then abort the request if it exists + // note: this does abort the request, but it will not trigger the abort handler (hence, {@link cleanUpRequestEntryIfExists} will not be called) + // coz the AbortHandler event listener is already removed from the AbortSignal + existingEntry?.abortController.abort(this.ABORT_REASON_USER_CANCELLATION); + } + + /** + * Insert a new request entry. + * This will create a new abort controller, set a timeout to abort the request if it takes too long, and set the abort handler. + * + * @param requestId - The ID of the request to insert the entry for. + * @param timeout - The timeout for the request. + * @returns The request entry that was inserted. + */ + #insertRequestEntry(requestId: string, timeout: number) { + const abortController = new AbortController(); + + // Set a timeout to abort the request if it takes too long + const timerId = setTimeout( + () => this.#handleRequestTimeout(requestId), + timeout, + ); + + // Set the abort handler and listen to the `abort` event + const abortHandler = () => { + this.#cleanUpOnFinished(requestId); + }; + abortController.signal.addEventListener('abort', abortHandler); + + const requestEntry: RequestEntry = { + abortController, + abortHandler, + timerId, + }; + + // Insert the request entry + this.#requestEntries.set(requestId, requestEntry); + + return requestEntry; + } + + /** + * Handle the request timeout. + * This will abort the request, this will also trigger the abort handler (hence, {@link #cleanUpRequestEntryIfExists} will be called) + * + * @param requestId - The ID of the request to handle the timeout for. + */ + #handleRequestTimeout(requestId: string) { + const requestEntry = this.#cleanUpOnFinished(requestId); + if (requestEntry) { + // Abort the signal, so that the polling loop will exit + requestEntry.abortController.abort(this.ABORT_REASON_TIMEOUT); + } + } + + /** + * Clean up the request entry upon finished (success or failure). + * This will remove the abort handler from the AbortSignal, clear the timeout, and remove the request entry. + * + * @param requestId - The ID of the request to clean up for. + * @returns The request entry that was cleaned up, if it exists. + */ + #cleanUpOnFinished(requestId: string): RequestEntry | undefined { + const requestEntry = this.#cleanUpRequestEntryIfExists(requestId); + if (requestEntry) { + requestEntry.abortController.signal.removeEventListener( + 'abort', + requestEntry.abortHandler, + ); + } + return requestEntry; + } + + /** + * Clean up the request entry if it exists. + * This will clear the pending timeout, remove the event listener from the AbortSignal, and remove the request entry. + * + * @param requestId - The ID of the request to handle the abort for. + * @returns The request entry that was aborted, if it exists. + */ + #cleanUpRequestEntryIfExists(requestId: string): RequestEntry | undefined { + const requestEntry = this.#requestEntries.get(requestId); + if (requestEntry) { + clearTimeout(requestEntry.timerId); // Clear the timeout + this.#requestEntries.delete(requestId); // Remove the request entry + } + return requestEntry; + } + + /** + * Check if the abort signal is aborted and not due to timeout. + * + * @param signal - The abort signal to check. + * @returns True if the abort signal is aborted and not due to timeout, false otherwise. + */ + #isAbortedAndNotTimeoutReason(signal: AbortSignal) { + return signal.aborted && signal.reason !== this.ABORT_REASON_TIMEOUT; + } + + async #delay(ms: number) { + return new Promise((resolve) => setTimeout(resolve, ms)); + } +} diff --git a/packages/shield-controller/tests/data.ts b/packages/shield-controller/tests/data.ts new file mode 100644 index 00000000000..f1c7959a3e7 --- /dev/null +++ b/packages/shield-controller/tests/data.ts @@ -0,0 +1,70 @@ +import type { SimulationData } from '@metamask/transaction-controller'; +import { SimulationTokenStandard } from '@metamask/transaction-controller'; + +export const TX_META_SIMULATION_DATA_MOCKS: { + description: string; + previousSimulationData: SimulationData | undefined; + newSimulationData: SimulationData; +}[] = [ + { + description: '`SimulationData.nativeBalanceChange` has changed', + previousSimulationData: undefined, + newSimulationData: { + nativeBalanceChange: { + difference: '0x1', + previousBalance: '0x1', + newBalance: '0x2', + isDecrease: true, + }, + tokenBalanceChanges: [], + }, + }, + { + description: '`SimulationData.tokenBalanceChanges` has changed', + previousSimulationData: { + tokenBalanceChanges: [ + { + difference: '0x1', + previousBalance: '0x1', + standard: SimulationTokenStandard.erc20, + address: '0x1', + newBalance: '0x2', + isDecrease: true, + }, + ], + }, + newSimulationData: { + tokenBalanceChanges: [ + { + difference: '0x2', + previousBalance: '0x1', + standard: SimulationTokenStandard.erc20, + address: '0x1', + newBalance: '0x3', + isDecrease: true, + }, + ], + }, + }, + { + description: '`SimulationData.error` has changed', + previousSimulationData: undefined, + newSimulationData: { + error: { + code: '-123', + message: 'Reverted', + }, + tokenBalanceChanges: [], + }, + }, + { + description: '`SimulationData.isUpdatedAfterSecurityCheck` has changed', + previousSimulationData: { + tokenBalanceChanges: [], + }, + newSimulationData: { + isUpdatedAfterSecurityCheck: true, + tokenBalanceChanges: [], + }, + }, +]; diff --git a/packages/shield-controller/tests/utils.ts b/packages/shield-controller/tests/utils.ts index 8f40bfe94f1..d26c1ed0448 100644 --- a/packages/shield-controller/tests/utils.ts +++ b/packages/shield-controller/tests/utils.ts @@ -99,3 +99,15 @@ export function setupCoverageResultReceived( baseMessenger.subscribe('ShieldController:coverageResultReceived', handler); }); } + +/** + * Delay for a specified amount of time. + * + * @param ms - The number of milliseconds to delay. + * @returns A promise that resolves after the specified amount of time. + */ +export function delay(ms: number): Promise { + return new Promise((resolve) => { + setTimeout(resolve, ms); + }); +}