diff --git a/eslint-warning-thresholds.json b/eslint-warning-thresholds.json index 6c2cd98ba93..b05522c2a37 100644 --- a/eslint-warning-thresholds.json +++ b/eslint-warning-thresholds.json @@ -416,11 +416,8 @@ "packages/permission-log-controller/tests/PermissionLogController.test.ts": { "import-x/order": 1 }, - "packages/phishing-controller/src/PhishingController.test.ts": { - "import-x/no-named-as-default-member": 1 - }, "packages/phishing-controller/src/PhishingController.ts": { - "jsdoc/check-tag-names": 42, + "jsdoc/check-tag-names": 38, "jsdoc/tag-lines": 1 }, "packages/phishing-controller/src/PhishingDetector.ts": { diff --git a/packages/phishing-controller/CHANGELOG.md b/packages/phishing-controller/CHANGELOG.md index 5bfda2c58d5..f9c66f6c409 100644 --- a/packages/phishing-controller/CHANGELOG.md +++ b/packages/phishing-controller/CHANGELOG.md @@ -9,6 +9,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Add URL scan cache functionality to improve performance ([#5625](https://github.com/MetaMask/core/pull/5625)) + - Added `UrlScanCache` class for caching phishing detection scan results + - Added methods to `PhishingController`: `setUrlScanCacheTTL`, `setUrlScanCacheMaxSize`, `clearUrlScanCache` + - Added URL scan cache state to `PhishingControllerState` + - Added configuration options: `urlScanCacheTTL` and `urlScanCacheMaxSize` - Add `bulkScanUrls` method to `PhishingController` for scanning multiple URLs for phishing in bulk ([#5682](https://github.com/MetaMask/core/pull/5682)) - Add `BulkPhishingDetectionScanResponse` type for bulk URL scan results ([#5682](https://github.com/MetaMask/core/pull/5682)) - Add `PHISHING_DETECTION_BULK_SCAN_ENDPOINT` constant ([#5682](https://github.com/MetaMask/core/pull/5682)) diff --git a/packages/phishing-controller/src/PhishingController.test.ts b/packages/phishing-controller/src/PhishingController.test.ts index e8b3a14a7e3..c1347b6ada2 100644 --- 
/**
 * Integration tests for the URL scan cache as seen through
 * `PhishingController.scanUrl`: hits, TTL expiry, size-based eviction,
 * manual clearing, runtime reconfiguration, and the rule that failed scans
 * are never cached.
 */
describe('URL Scan Cache', () => {
  let clock: sinon.SinonFakeTimers;

  /**
   * Build the scan-endpoint path (with query string) that these tests mock
   * for a given hostname.
   *
   * @param hostname - The hostname being scanned.
   * @returns The request path relative to `PHISHING_DETECTION_BASE_URL`.
   */
  const scanPath = (hostname: string): string =>
    `/${PHISHING_DETECTION_SCAN_ENDPOINT}?url=${encodeURIComponent(hostname)}`;

  beforeEach(() => {
    clock = sinon.useFakeTimers();
  });

  afterEach(() => {
    sinon.restore();
    // Restore any jest spy (e.g. on `fetch`) here rather than at the end of
    // a test body, so a failing assertion cannot leak the spy into later tests.
    jest.restoreAllMocks();
    cleanAll();
  });

  it('should cache scan results and return them on subsequent calls', async () => {
    const testDomain = 'example.com';

    // Spy on the fetch function to track calls
    const fetchSpy = jest.spyOn(global, 'fetch');

    nock(PHISHING_DETECTION_BASE_URL)
      .get(scanPath(testDomain))
      .reply(200, {
        recommendedAction: RecommendedAction.None,
      });

    const controller = getPhishingController();

    const result1 = await controller.scanUrl(`https://${testDomain}`);
    expect(result1).toStrictEqual({
      domainName: testDomain,
      recommendedAction: RecommendedAction.None,
    });

    const result2 = await controller.scanUrl(`https://${testDomain}`);
    expect(result2).toStrictEqual({
      domainName: testDomain,
      recommendedAction: RecommendedAction.None,
    });

    // Verify that fetch was called exactly once
    expect(fetchSpy).toHaveBeenCalledTimes(1);
  });

  it('should expire cache entries after TTL', async () => {
    const testDomain = 'example.com';
    const cacheTTL = 300; // 5 minutes

    nock(PHISHING_DETECTION_BASE_URL)
      .get(scanPath(testDomain))
      .reply(200, {
        recommendedAction: RecommendedAction.None,
      })
      .get(scanPath(testDomain))
      .reply(200, {
        recommendedAction: RecommendedAction.None,
      });

    const controller = getPhishingController({
      urlScanCacheTTL: cacheTTL,
    });

    await controller.scanUrl(`https://${testDomain}`);

    // Before TTL expires, should use cache
    clock.tick((cacheTTL - 10) * 1000);
    await controller.scanUrl(`https://${testDomain}`);
    expect(pendingMocks()).toHaveLength(1); // One mock remaining

    // After TTL expires, should fetch again
    clock.tick(11 * 1000);
    await controller.scanUrl(`https://${testDomain}`);
    expect(pendingMocks()).toHaveLength(0); // All mocks used
  });

  it('should evict oldest entries when cache exceeds max size', async () => {
    const maxCacheSize = 2;
    const domains = ['domain1.com', 'domain2.com', 'domain3.com'];

    // Setup nock to respond to all three domains
    domains.forEach((domain) => {
      nock(PHISHING_DETECTION_BASE_URL)
        .get(scanPath(domain))
        .reply(200, {
          recommendedAction: RecommendedAction.None,
        });
    });

    // Setup a second request for the first domain
    nock(PHISHING_DETECTION_BASE_URL)
      .get(scanPath(domains[0]))
      .reply(200, {
        recommendedAction: RecommendedAction.Warn,
      });

    const controller = getPhishingController({
      urlScanCacheMaxSize: maxCacheSize,
    });

    // Fill the cache
    await controller.scanUrl(`https://${domains[0]}`);
    clock.tick(1000); // Ensure different timestamps
    await controller.scanUrl(`https://${domains[1]}`);

    // This should evict the oldest entry (domain1)
    clock.tick(1000);
    await controller.scanUrl(`https://${domains[2]}`);

    // Now domain1 should not be in cache and require a new fetch
    await controller.scanUrl(`https://${domains[0]}`);

    // All mocks should be used
    expect(isDone()).toBe(true);
  });

  it('should clear the cache when clearUrlScanCache is called', async () => {
    const testDomain = 'example.com';

    nock(PHISHING_DETECTION_BASE_URL)
      .get(scanPath(testDomain))
      .reply(200, {
        recommendedAction: RecommendedAction.None,
      })
      .get(scanPath(testDomain))
      .reply(200, {
        recommendedAction: RecommendedAction.None,
      });

    const controller = getPhishingController();

    // First call should fetch from API
    await controller.scanUrl(`https://${testDomain}`);

    // Clear the cache
    controller.clearUrlScanCache();

    // Should fetch again
    await controller.scanUrl(`https://${testDomain}`);

    // All mocks should be used
    expect(isDone()).toBe(true);
  });

  it('should allow changing the TTL', async () => {
    const testDomain = 'example.com';
    const initialTTL = 300; // 5 minutes
    const newTTL = 60; // 1 minute

    nock(PHISHING_DETECTION_BASE_URL)
      .get(scanPath(testDomain))
      .reply(200, {
        recommendedAction: RecommendedAction.None,
      })
      .get(scanPath(testDomain))
      .reply(200, {
        recommendedAction: RecommendedAction.None,
      });

    const controller = getPhishingController({
      urlScanCacheTTL: initialTTL,
    });

    // First call should fetch from API
    await controller.scanUrl(`https://${testDomain}`);

    // Change TTL
    controller.setUrlScanCacheTTL(newTTL);

    // Before new TTL expires, should use cache
    clock.tick((newTTL - 10) * 1000);
    await controller.scanUrl(`https://${testDomain}`);
    expect(pendingMocks()).toHaveLength(1); // One mock remaining

    // After new TTL expires, should fetch again
    clock.tick(11 * 1000);
    await controller.scanUrl(`https://${testDomain}`);
    expect(pendingMocks()).toHaveLength(0); // All mocks used
  });

  it('should allow changing the max cache size', async () => {
    const initialMaxSize = 3;
    const newMaxSize = 2;
    const domains = [
      'domain1.com',
      'domain2.com',
      'domain3.com',
      'domain4.com',
    ];

    // Setup nock to respond to all domains
    domains.forEach((domain) => {
      nock(PHISHING_DETECTION_BASE_URL)
        .get(scanPath(domain))
        .reply(200, {
          recommendedAction: RecommendedAction.None,
        });
    });

    const controller = getPhishingController({
      urlScanCacheMaxSize: initialMaxSize,
    });

    // Fill the cache to initial size
    await controller.scanUrl(`https://${domains[0]}`);
    clock.tick(1000); // Ensure different timestamps
    await controller.scanUrl(`https://${domains[1]}`);
    clock.tick(1000);
    await controller.scanUrl(`https://${domains[2]}`);

    // Verify initial cache size
    expect(Object.keys(controller.state.urlScanCache)).toHaveLength(
      initialMaxSize,
    );

    // Reduce the max size
    controller.setUrlScanCacheMaxSize(newMaxSize);

    // Add another entry which should trigger eviction
    await controller.scanUrl(`https://${domains[3]}`);

    // Verify the cache size doesn't exceed new max size
    expect(
      Object.keys(controller.state.urlScanCache).length,
    ).toBeLessThanOrEqual(newMaxSize);
  });

  it('should handle fetch errors and not cache them', async () => {
    const testDomain = 'example.com';

    nock(PHISHING_DETECTION_BASE_URL)
      .get(scanPath(testDomain))
      .reply(500, { error: 'Internal Server Error' })
      .get(scanPath(testDomain))
      .reply(200, {
        recommendedAction: RecommendedAction.None,
      });

    const controller = getPhishingController();

    // First call should result in an error response
    const result1 = await controller.scanUrl(`https://${testDomain}`);
    expect(result1.fetchError).toBeDefined();

    // Second call should try again (not use cache since errors aren't cached)
    const result2 = await controller.scanUrl(`https://${testDomain}`);
    expect(result2.fetchError).toBeUndefined();
    expect(result2.recommendedAction).toBe(RecommendedAction.None);

    // All mocks should be used
    expect(isDone()).toBe(true);
  });

  it('should handle timeout errors and not cache them', async () => {
    const testDomain = 'example.com';

    // First mock a timeout/error response
    nock(PHISHING_DETECTION_BASE_URL)
      .get(scanPath(testDomain))
      .replyWithError('connection timeout')
      .get(scanPath(testDomain))
      .reply(200, {
        recommendedAction: RecommendedAction.None,
      });

    const controller = getPhishingController();

    // First call should result in an error
    const result1 = await controller.scanUrl(`https://${testDomain}`);
    expect(result1.fetchError).toBeDefined();

    // Second call should succeed (not use cache since errors aren't cached)
    const result2 = await controller.scanUrl(`https://${testDomain}`);
    expect(result2.fetchError).toBeUndefined();
    expect(result2.recommendedAction).toBe(RecommendedAction.None);

    // All mocks should be used
    expect(isDone()).toBe(true);
  });

  it('should handle invalid URLs and not cache them', async () => {
    const invalidUrl = 'not-a-valid-url';

    const controller = getPhishingController();

    // First call should return an error for invalid URL
    const result1 = await controller.scanUrl(invalidUrl);
    expect(result1.fetchError).toBeDefined();

    // Second call should also return an error (not from cache)
    const result2 = await controller.scanUrl(invalidUrl);
    expect(result2.fetchError).toBeDefined();

    // Actually verify the "not cached" half of the test name: nothing may
    // have been written to the persisted cache by either failed call.
    expect(controller.state.urlScanCache).toStrictEqual({});
  });
});
+ * c2DomainBlocklistRefreshInterval - Polling interval used to fetch c2 domain blocklist. + * urlScanCacheTTL - Time to live in seconds for cached scan results. + * urlScanCacheMaxSize - Maximum number of entries in the scan cache. */ export type PhishingControllerOptions = { stalelistRefreshInterval?: number; hotlistRefreshInterval?: number; c2DomainBlocklistRefreshInterval?: number; + urlScanCacheTTL?: number; + urlScanCacheMaxSize?: number; messenger: PhishingControllerMessenger; state?: Partial; }; @@ -318,6 +331,8 @@ export class PhishingController extends BaseController< #c2DomainBlocklistRefreshInterval: number; + readonly #urlScanCache: UrlScanCache; + #inProgressHotlistUpdate?: Promise; #inProgressStalelistUpdate?: Promise; @@ -331,6 +346,8 @@ export class PhishingController extends BaseController< * @param config.stalelistRefreshInterval - Polling interval used to fetch stale list. * @param config.hotlistRefreshInterval - Polling interval used to fetch hotlist diff list. * @param config.c2DomainBlocklistRefreshInterval - Polling interval used to fetch c2 domain blocklist. + * @param config.urlScanCacheTTL - Time to live in seconds for cached scan results. + * @param config.urlScanCacheMaxSize - Maximum number of entries in the scan cache. * @param config.messenger - The controller restricted messenger. * @param config.state - Initial state to set on this controller. 
*/ @@ -338,6 +355,8 @@ export class PhishingController extends BaseController< stalelistRefreshInterval = STALELIST_REFRESH_INTERVAL, hotlistRefreshInterval = HOTLIST_REFRESH_INTERVAL, c2DomainBlocklistRefreshInterval = C2_DOMAIN_BLOCKLIST_REFRESH_INTERVAL, + urlScanCacheTTL = DEFAULT_URL_SCAN_CACHE_TTL, + urlScanCacheMaxSize = DEFAULT_URL_SCAN_CACHE_MAX_SIZE, messenger, state = {}, }: PhishingControllerOptions) { @@ -354,6 +373,17 @@ export class PhishingController extends BaseController< this.#stalelistRefreshInterval = stalelistRefreshInterval; this.#hotlistRefreshInterval = hotlistRefreshInterval; this.#c2DomainBlocklistRefreshInterval = c2DomainBlocklistRefreshInterval; + this.#urlScanCache = new UrlScanCache({ + cacheTTL: urlScanCacheTTL, + maxCacheSize: urlScanCacheMaxSize, + initialCache: this.state.urlScanCache, + updateState: (cache) => { + this.update((draftState) => { + draftState.urlScanCache = cache; + }); + }, + }); + this.#registerMessageHandlers(); this.updatePhishingDetector(); @@ -415,6 +445,31 @@ export class PhishingController extends BaseController< this.#c2DomainBlocklistRefreshInterval = interval; } + /** + * Set the time-to-live for URL scan cache entries. + * + * @param ttl - The TTL in seconds. + */ + setUrlScanCacheTTL(ttl: number) { + this.#urlScanCache.setTTL(ttl); + } + + /** + * Set the maximum number of entries in the URL scan cache. + * + * @param maxSize - The maximum cache size. + */ + setUrlScanCacheMaxSize(maxSize: number) { + this.#urlScanCache.setMaxSize(maxSize); + } + + /** + * Clear the URL scan cache. + */ + clearUrlScanCache() { + this.#urlScanCache.clear(); + } + /** * Determine if an update to the stalelist configuration is needed. 
* @@ -607,6 +662,11 @@ export class PhishingController extends BaseController< }; } + const cachedResult = this.#urlScanCache.get(hostname); + if (cachedResult) { + return cachedResult; + } + const apiResponse = await safelyExecuteWithTimeout( async () => { const res = await fetch( @@ -645,10 +705,14 @@ export class PhishingController extends BaseController< }; } - return { + const result = { domainName: hostname, recommendedAction: apiResponse.recommendedAction, } as PhishingDetectionScanResult; + + this.#urlScanCache.add(hostname, result); + + return result; }; /** @@ -681,23 +745,44 @@ export class PhishingController extends BaseController< }; } + const urlsToScan: string[] = []; + const results: Record = {}; + const errors: Record = {}; + const MAX_URL_LENGTH = 2048; + // Process each URL: validate and check cache for (const url of urls) { if (url.length > MAX_URL_LENGTH) { - return { - results: {}, - errors: { - [url]: [`URL length must not exceed ${MAX_URL_LENGTH} characters`], - }, - }; + errors[url] = [ + `URL length must not exceed ${MAX_URL_LENGTH} characters`, + ]; + continue; } + + const [hostname, ok] = getHostnameFromWebUrl(url); + if (!ok) { + errors[url] = ['url is not a valid web URL']; + continue; + } + + // Check if the result is already cached + const cachedResult = this.#urlScanCache.get(hostname); + if (cachedResult) { + results[hostname] = cachedResult; + } else { + urlsToScan.push(hostname); + } + } + + if (urlsToScan.length === 0) { + return { results, errors }; } // The API has a limit of 50 URLs per request, so we batch the requests const MAX_URLS_PER_BATCH = 50; const batches: string[][] = []; - for (let i = 0; i < urls.length; i += MAX_URLS_PER_BATCH) { - batches.push(urls.slice(i, i + MAX_URLS_PER_BATCH)); + for (let i = 0; i < urlsToScan.length; i += MAX_URLS_PER_BATCH) { + batches.push(urlsToScan.slice(i, i + MAX_URLS_PER_BATCH)); } // Process each batch in parallel @@ -707,11 +792,17 @@ export class PhishingController extends 
BaseController< // Combine all batch results const combinedResponse: BulkPhishingDetectionScanResponse = { - results: {}, - errors: {}, + results, + errors, }; + // Merge results and errors from all batches batchResults.forEach((batchResponse) => { + // Add API results to the cache + Object.entries(batchResponse.results).forEach(([hostname, result]) => { + this.#urlScanCache.add(hostname, result); + }); + Object.assign(combinedResponse.results, batchResponse.results); Object.entries(batchResponse.errors).forEach(([key, messages]) => { combinedResponse.errors[key] = [ diff --git a/packages/phishing-controller/src/UrlScanCache.test.ts b/packages/phishing-controller/src/UrlScanCache.test.ts new file mode 100644 index 00000000000..fe22255abee --- /dev/null +++ b/packages/phishing-controller/src/UrlScanCache.test.ts @@ -0,0 +1,222 @@ +import sinon from 'sinon'; + +import { RecommendedAction } from './types'; +import { UrlScanCache } from './UrlScanCache'; +import * as utils from './utils'; + +describe('UrlScanCache', () => { + let clock: sinon.SinonFakeTimers; + let updateStateSpy: sinon.SinonSpy; + let cache: UrlScanCache; + + beforeEach(() => { + clock = sinon.useFakeTimers(); + sinon + .stub(utils, 'fetchTimeNow') + .callsFake(() => Math.floor(Date.now() / 1000)); + updateStateSpy = sinon.spy(); + cache = new UrlScanCache({ + cacheTTL: 300, // 5 minutes + maxCacheSize: 3, + updateState: updateStateSpy, + }); + }); + + afterEach(() => { + sinon.restore(); + }); + + describe('constructor', () => { + it('should initialize with empty cache when no initialCache provided', () => { + const emptyCache = new UrlScanCache({ + // eslint-disable-next-line no-empty-function + updateState: () => {}, + }); + expect(emptyCache.get('example.com')).toBeUndefined(); + }); + + it('should initialize with provided initialCache data', () => { + const now = Math.floor(Date.now() / 1000); + const initialCache = { + 'example.com': { + result: { + domainName: 'example.com', + 
recommendedAction: RecommendedAction.None, + }, + timestamp: now, + }, + }; + + const cacheWithInitialData = new UrlScanCache({ + initialCache, + // eslint-disable-next-line no-empty-function + updateState: () => {}, + }); + + expect(cacheWithInitialData.get('example.com')).toStrictEqual({ + domainName: 'example.com', + recommendedAction: RecommendedAction.None, + }); + }); + }); + + describe('get', () => { + it('returns undefined for non-existent entries', () => { + expect(cache.get('example.com')).toBeUndefined(); + }); + + it('returns valid entries', () => { + const result = { + domainName: 'example.com', + recommendedAction: RecommendedAction.None, + }; + + cache.add('example.com', result); + + expect(cache.get('example.com')).toStrictEqual(result); + }); + + it('removes and returns undefined for expired entries', () => { + const result = { + domainName: 'example.com', + recommendedAction: RecommendedAction.None, + }; + + cache.add('example.com', result); + + clock.tick(301 * 1000); + + expect(cache.get('example.com')).toBeUndefined(); + + expect(updateStateSpy.callCount).toBe(2); + }); + }); + + describe('add', () => { + it('adds entries to the cache', () => { + const result = { + domainName: 'example.com', + recommendedAction: RecommendedAction.None, + }; + + cache.add('example.com', result); + + expect(cache.get('example.com')).toStrictEqual(result); + expect(updateStateSpy.callCount).toBe(1); + }); + + it('evicts oldest entries when exceeding max size', () => { + cache.add('domain1.com', { + domainName: 'domain1.com', + recommendedAction: RecommendedAction.None, + }); + clock.tick(1000); + cache.add('domain2.com', { + domainName: 'domain2.com', + recommendedAction: RecommendedAction.None, + }); + clock.tick(1000); + cache.add('domain3.com', { + domainName: 'domain3.com', + recommendedAction: RecommendedAction.None, + }); + + expect(cache.get('domain1.com')).toBeDefined(); + expect(cache.get('domain2.com')).toBeDefined(); + 
expect(cache.get('domain3.com')).toBeDefined(); + + cache.add('domain4.com', { + domainName: 'domain4.com', + recommendedAction: RecommendedAction.None, + }); + + expect(cache.get('domain1.com')).toBeUndefined(); + expect(cache.get('domain2.com')).toBeDefined(); + expect(cache.get('domain3.com')).toBeDefined(); + expect(cache.get('domain4.com')).toBeDefined(); + }); + + it('properly handles multiple evictions', () => { + cache.setMaxSize(2); + + cache.add('domain1.com', { + domainName: 'domain1.com', + recommendedAction: RecommendedAction.None, + }); + cache.add('domain2.com', { + domainName: 'domain2.com', + recommendedAction: RecommendedAction.None, + }); + cache.add('domain3.com', { + domainName: 'domain3.com', + recommendedAction: RecommendedAction.None, + }); + + expect(cache.get('domain1.com')).toBeUndefined(); + expect(cache.get('domain2.com')).toBeDefined(); + expect(cache.get('domain3.com')).toBeDefined(); + }); + }); + + describe('clear', () => { + it('removes all entries from the cache', () => { + cache.add('domain1.com', { + domainName: 'domain1.com', + recommendedAction: RecommendedAction.None, + }); + cache.add('domain2.com', { + domainName: 'domain2.com', + recommendedAction: RecommendedAction.None, + }); + + cache.clear(); + + expect(cache.get('domain1.com')).toBeUndefined(); + expect(cache.get('domain2.com')).toBeUndefined(); + + expect(updateStateSpy.callCount).toBe(3); + }); + }); + + describe('setTTL', () => { + it('updates the cache TTL', () => { + const result = { + domainName: 'example.com', + recommendedAction: RecommendedAction.None, + }; + + cache.add('example.com', result); + + cache.setTTL(60); + + clock.tick(61 * 1000); + + expect(cache.get('example.com')).toBeUndefined(); + }); + }); + + describe('setMaxSize', () => { + it('updates the max cache size and evicts entries if needed', () => { + cache.add('domain1.com', { + domainName: 'domain1.com', + recommendedAction: RecommendedAction.None, + }); + clock.tick(1000); + 
/**
 * Cache entry for URL scan results.
 *
 * `timestamp` is whatever `fetchTimeNow()` returned when the entry was
 * stored; it is compared against the cache TTL, which is expressed in seconds.
 */
export type UrlScanCacheEntry = {
  result: PhishingDetectionScanResult;
  timestamp: number;
};

/**
 * Default values for URL scan cache
 */
export const DEFAULT_URL_SCAN_CACHE_TTL = 300; // 5 minutes in seconds
export const DEFAULT_URL_SCAN_CACHE_MAX_SIZE = 100;

/**
 * UrlScanCache class
 *
 * Handles caching of URL scan results with TTL and size limits.
 *
 * Entries live in a `Map` whose iteration order doubles as the age order:
 * the first key is always the least recently written one, which is what
 * size-based eviction removes. Every mutation made through the public API
 * is mirrored to the owner via the `updateState` callback.
 */
export class UrlScanCache {
  #cacheTTL: number;

  #maxCacheSize: number;

  readonly #cache: Map<string, UrlScanCacheEntry>;

  readonly #updateState: (cache: Record<string, UrlScanCacheEntry>) => void;

  /**
   * Constructor for UrlScanCache
   *
   * Entries beyond `maxCacheSize` in `initialCache` are dropped from the
   * in-memory cache, but `updateState` is not invoked during construction;
   * the trimmed cache is reported on the next mutation.
   *
   * @param options - Cache configuration options
   * @param options.cacheTTL - Time to live in seconds for cached entries
   * @param options.maxCacheSize - Maximum number of entries in the cache
   * @param options.initialCache - Initial cache state
   * @param options.updateState - Function to update the state when cache changes
   */
  constructor({
    cacheTTL = DEFAULT_URL_SCAN_CACHE_TTL,
    maxCacheSize = DEFAULT_URL_SCAN_CACHE_MAX_SIZE,
    initialCache = {},
    updateState,
  }: {
    cacheTTL?: number;
    maxCacheSize?: number;
    initialCache?: Record<string, UrlScanCacheEntry>;
    updateState: (cache: Record<string, UrlScanCacheEntry>) => void;
  }) {
    this.#cacheTTL = cacheTTL;
    this.#maxCacheSize = maxCacheSize;
    this.#cache = new Map(Object.entries(initialCache));
    this.#updateState = updateState;
    this.#evictEntries();
  }

  /**
   * Set the time-to-live for cached entries
   *
   * Nothing is removed eagerly; the new TTL is applied the next time an
   * entry is looked up with `get`.
   *
   * @param ttl - The TTL in seconds
   */
  setTTL(ttl: number): void {
    this.#cacheTTL = ttl;
  }

  /**
   * Set the maximum cache size
   *
   * If the cache currently holds more entries than the new limit, the
   * oldest ones are evicted immediately and the shrunken cache is reported
   * through `updateState`, so the owner's state never keeps entries the
   * cache itself has already dropped.
   *
   * @param maxSize - The maximum cache size
   */
  setMaxSize(maxSize: number): void {
    this.#maxCacheSize = maxSize;
    if (this.#evictEntries()) {
      this.#persistCache();
    }
  }

  /**
   * Clear the cache
   */
  clear(): void {
    this.#cache.clear();
    this.#persistCache();
  }

  /**
   * Get a cached result if it exists and is not expired
   *
   * An expired entry is deleted as a side effect, and that deletion is
   * reported through `updateState`.
   *
   * @param hostname - The hostname to check
   * @returns The cached scan result or undefined if not found or expired
   */
  get(hostname: string): PhishingDetectionScanResult | undefined {
    const cacheEntry = this.#cache.get(hostname);
    if (!cacheEntry) {
      return undefined;
    }

    // Check if the entry is expired
    const now = fetchTimeNow();
    if (now - cacheEntry.timestamp > this.#cacheTTL) {
      // Entry expired, remove it from cache
      this.#cache.delete(hostname);
      this.#persistCache();
      return undefined;
    }

    return cacheEntry.result;
  }

  /**
   * Add an entry to the cache, evicting oldest entries if necessary
   *
   * Re-adding a hostname that is already cached replaces its result,
   * refreshes its timestamp and makes it the newest entry for eviction
   * purposes.
   *
   * @param hostname - The hostname to cache
   * @param result - The scan result to cache
   */
  add(hostname: string, result: PhishingDetectionScanResult): void {
    // `Map#set` on an existing key keeps the key's original position, which
    // would leave a just-refreshed entry first in line for eviction. Delete
    // first so the entry is re-inserted at the end of the iteration order.
    this.#cache.delete(hostname);
    this.#cache.set(hostname, {
      result,
      timestamp: fetchTimeNow(),
    });

    this.#evictEntries();

    this.#persistCache();
  }

  /**
   * Persist the current cache state
   */
  #persistCache(): void {
    this.#updateState(Object.fromEntries(this.#cache));
  }

  /**
   * Evict oldest entries if cache exceeds max size
   *
   * @returns Whether at least one entry was removed.
   */
  #evictEntries(): boolean {
    if (this.#cache.size <= this.#maxCacheSize) {
      return false;
    }

    const entriesToRemove = this.#cache.size - this.#maxCacheSize;
    let removed = 0;
    // Keys iterate oldest-first; deleting the current key while iterating a
    // Map is safe.
    for (const key of this.#cache.keys()) {
      if (removed >= entriesToRemove) {
        break;
      }
      this.#cache.delete(key);
      removed += 1;
    }
    return removed > 0;
  }
}