diff --git a/src/main/java/spring/memewikibe/infrastructure/ai/NoopCrossEncoderReranker.java b/src/main/java/spring/memewikibe/infrastructure/ai/NoopCrossEncoderReranker.java index ddb9dbf..9aa8ab2 100644 --- a/src/main/java/spring/memewikibe/infrastructure/ai/NoopCrossEncoderReranker.java +++ b/src/main/java/spring/memewikibe/infrastructure/ai/NoopCrossEncoderReranker.java @@ -4,7 +4,6 @@ import java.util.Comparator; import java.util.List; -import java.util.stream.Collectors; /** * Default no-op Cross-Encoder reranker that preserves prior order by score. @@ -16,6 +15,6 @@ public List rerank(String query, List candidates) { return candidates.stream() .sorted(Comparator.comparingDouble(Candidate::priorScore).reversed()) .map(Candidate::id) - .collect(Collectors.toList()); + .toList(); } } diff --git a/src/test/java/spring/memewikibe/infrastructure/ai/NaverCrossEncoderRerankerTest.java b/src/test/java/spring/memewikibe/infrastructure/ai/NaverCrossEncoderRerankerTest.java new file mode 100644 index 0000000..8738566 --- /dev/null +++ b/src/test/java/spring/memewikibe/infrastructure/ai/NaverCrossEncoderRerankerTest.java @@ -0,0 +1,401 @@ +package spring.memewikibe.infrastructure.ai; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.http.HttpEntity; +import org.springframework.test.util.ReflectionTestUtils; +import org.springframework.web.client.RestClientException; +import org.springframework.web.client.RestTemplate; +import spring.memewikibe.annotation.UnitTest; +import spring.memewikibe.infrastructure.ai.CrossEncoderReranker.Candidate; + +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@UnitTest +@ExtendWith(MockitoExtension.class) +@DisplayName("NaverCrossEncoderReranker 단위 테스트") +class NaverCrossEncoderRerankerTest { + + @Mock + private RestTemplate mockRestTemplate; + + private ObjectMapper objectMapper; + private NaverCrossEncoderReranker sut; + + @BeforeEach + void setUp() { + objectMapper = new ObjectMapper(); + sut = new NaverCrossEncoderReranker(mockRestTemplate, objectMapper); + + // Set default config values + ReflectionTestUtils.setField(sut, "naverApiKey", "test-api-key"); + ReflectionTestUtils.setField(sut, "naverRequestId", "test-request-id"); + ReflectionTestUtils.setField(sut, "rerankerApiEndpoint", "https://test.api.endpoint"); + } + + @Test + @DisplayName("rerank: AI API가 문서를 재정렬하여 반환") + void rerank_succeeds_withValidApiResponse() throws Exception { + // given + String query = "재미있는 밈"; + List candidates = List.of( + new Candidate(1L, "제목1", "사용맥락1", "#태그1", 0.3), + new Candidate(2L, "제목2", "사용맥락2", "#태그2", 0.8), + new Candidate(3L, "제목3", "사용맥락3", "#태그3", 0.5) + ); + + String apiResponse = """ + { + "result": { + "citedDocuments": [ + {"id": "3"}, + {"id": "1"}, + {"id": "2"} + ] + } + } + """; + + when(mockRestTemplate.postForObject(any(String.class), any(HttpEntity.class), eq(String.class))) + .thenReturn(apiResponse); + + // when + List result = sut.rerank(query, candidates); + + // then + assertThat(result).containsExactly(3L, 1L, 2L); + verify(mockRestTemplate).postForObject(any(String.class), any(HttpEntity.class), eq(String.class)); + } + + @Test + @DisplayName("rerank: AI가 일부 문서만 인용한 경우 나머지는 원래 순서로 뒤에 배치") + void rerank_succeeds_appendsUnrankedDocumentsAfterRankedOnes() throws Exception { + // given + String query = "재미있는 밈"; + List candidates = List.of( + new Candidate(1L, "제목1", "사용맥락1", "#태그1", 0.8), + new Candidate(2L, "제목2", "사용맥락2", "#태그2", 0.7), + new Candidate(3L, "제목3", "사용맥락3", "#태그3", 0.6), + new Candidate(4L, "제목4", "사용맥락4", "#태그4", 0.5) + ); + + // AI가 2번과 4번만 선택 + String apiResponse = """ + { + "result": { + "citedDocuments": [ + {"id": "2"}, + {"id": "4"} + ] + } + } + """; + + when(mockRestTemplate.postForObject(any(String.class), any(HttpEntity.class), eq(String.class))) + .thenReturn(apiResponse); + + // when + List result = sut.rerank(query, candidates); + + // then - AI가 선택한 2, 4가 먼저 오고, 나머지 1, 3이 원래 순서로 뒤에 옴 + assertThat(result).containsExactly(2L, 4L, 1L, 3L); + } + + @Test + @DisplayName("rerank: AI가 문서를 인용하지 않은 경우 fallback 순서 반환") + void rerank_fallsBack_whenNoDocumentsCited() throws Exception { + // given + String query = "재미있는 밈"; + List candidates = List.of( + new Candidate(1L, "제목1", "사용맥락1", "#태그1", 0.3), + new Candidate(2L, "제목2", "사용맥락2", "#태그2", 0.8), + new Candidate(3L, "제목3", "사용맥락3", "#태그3", 0.5) + ); + + String apiResponse = """ + { + "result": { + "citedDocuments": [] + } + } + """; + + when(mockRestTemplate.postForObject(any(String.class), any(HttpEntity.class), eq(String.class))) + .thenReturn(apiResponse); + + // when + List result = sut.rerank(query, candidates); + + // then - priorScore 순으로 정렬 (0.8, 0.5, 0.3) + assertThat(result).containsExactly(2L, 3L, 1L); + } + + @Test + @DisplayName("rerank: null 쿼리인 경우 원래 순서 반환") + void rerank_returnsOriginalOrder_withNullQuery() { + // given + String query = null; + List candidates = List.of( + new Candidate(1L, "제목1", "사용맥락1", "#태그1", 0.3), + new Candidate(2L, "제목2", "사용맥락2", "#태그2", 0.8) + ); + + // when + List result = sut.rerank(query, candidates); + + // then + assertThat(result).containsExactly(1L, 2L); + } + + @Test + @DisplayName("rerank: 빈 쿼리인 경우 원래 순서 반환") + void rerank_returnsOriginalOrder_withBlankQuery() { + // given + String query = " "; + List candidates = List.of( + new Candidate(1L, "제목1", "사용맥락1", "#태그1", 0.3), + new Candidate(2L, "제목2", "사용맥락2", "#태그2", 0.8) + ); + + // when + List result = sut.rerank(query, candidates); + + // then + assertThat(result).containsExactly(1L, 2L); + } + + @Test + @DisplayName("rerank: 빈 후보 리스트인 경우 빈 리스트 반환") + void rerank_returnsEmptyList_withEmptyCandidates() { + // given + String query = "재미있는 밈"; + List candidates = List.of(); + + // when + List result = sut.rerank(query, candidates); + + // then + assertThat(result).isEmpty(); + } + + @Test + @DisplayName("rerank: API 키가 설정되지 않은 경우 원래 순서 반환") + void rerank_returnsOriginalOrder_withoutApiKey() { + // given + ReflectionTestUtils.setField(sut, "naverApiKey", ""); + String query = "재미있는 밈"; + List candidates = List.of( + new Candidate(1L, "제목1", "사용맥락1", "#태그1", 0.3), + new Candidate(2L, "제목2", "사용맥락2", "#태그2", 0.8) + ); + + // when + List result = sut.rerank(query, candidates); + + // then + assertThat(result).containsExactly(1L, 2L); + } + + @Test + @DisplayName("rerank: API 호출 실패 시 fallback 순서 반환") + void rerank_fallsBack_onApiError() { + // given + String query = "재미있는 밈"; + List candidates = List.of( + new Candidate(1L, "제목1", "사용맥락1", "#태그1", 0.3), + new Candidate(2L, "제목2", "사용맥락2", "#태그2", 0.8), + new Candidate(3L, "제목3", "사용맥락3", "#태그3", 0.5) + ); + + when(mockRestTemplate.postForObject(any(String.class), any(HttpEntity.class), eq(String.class))) + .thenThrow(new RestClientException("API 호출 실패")); + + // when + List result = sut.rerank(query, candidates); + + // then - priorScore 순으로 정렬 (0.8, 0.5, 0.3) + assertThat(result).containsExactly(2L, 3L, 1L); + } + + @Test + @DisplayName("rerank: 잘못된 JSON 응답 시 fallback 순서 반환") + void rerank_fallsBack_onInvalidJsonResponse() { + // given + String query = "재미있는 밈"; + List candidates = List.of( + new Candidate(1L, "제목1", "사용맥락1", "#태그1", 0.3), + new Candidate(2L, "제목2", "사용맥락2", "#태그2", 0.8) + ); + + when(mockRestTemplate.postForObject(any(String.class), any(HttpEntity.class), eq(String.class))) + .thenReturn("{ invalid json }"); + + // when + List result = sut.rerank(query, candidates); + + // then - priorScore 순으로 정렬 + assertThat(result).containsExactly(2L, 1L); + } + + @Test + @DisplayName("rerank: API 호출 시 올바른 요청 본문 전송") + void rerank_sendsCorrectRequestBody() throws Exception { + // given + String query = "재미있는 밈"; + List candidates = List.of( + new Candidate(1L, "제목1", "사용맥락1", "#태그1", 0.3), + new Candidate(2L, "제목2", "사용맥락2", "#태그2", 0.8) + ); + + String apiResponse = """ + { + "result": { + "citedDocuments": [ + {"id": "2"}, + {"id": "1"} + ] + } + } + """; + + when(mockRestTemplate.postForObject(any(String.class), any(HttpEntity.class), eq(String.class))) + .thenReturn(apiResponse); + + // when + sut.rerank(query, candidates); + + // then + ArgumentCaptor captor = ArgumentCaptor.forClass(HttpEntity.class); + verify(mockRestTemplate).postForObject(any(String.class), captor.capture(), eq(String.class)); + + HttpEntity> capturedEntity = captor.getValue(); + Map body = capturedEntity.getBody(); + + assertThat(body).isNotNull(); + assertThat(body.get("query")).isEqualTo(query); + assertThat(body.get("documents")).isInstanceOf(List.class); + + List documents = (List) body.get("documents"); + assertThat(documents).hasSize(2); + } + + @Test + @DisplayName("rerank: API 호출 시 올바른 헤더 전송") + void rerank_sendsCorrectHeaders() throws Exception { + // given + String query = "재미있는 밈"; + List candidates = List.of( + new Candidate(1L, "제목1", "사용맥락1", "#태그1", 0.3) + ); + + String apiResponse = """ + { + "result": { + "citedDocuments": [ + {"id": "1"} + ] + } + } + """; + + when(mockRestTemplate.postForObject(any(String.class), any(HttpEntity.class), eq(String.class))) + .thenReturn(apiResponse); + + // when + sut.rerank(query, candidates); + + // then + ArgumentCaptor captor = ArgumentCaptor.forClass(HttpEntity.class); + verify(mockRestTemplate).postForObject(any(String.class), captor.capture(), eq(String.class)); + + HttpEntity capturedEntity = captor.getValue(); + assertThat(capturedEntity.getHeaders().getContentType()).hasToString("application/json"); + assertThat(capturedEntity.getHeaders().get("Authorization")).containsExactly("Bearer test-api-key"); + assertThat(capturedEntity.getHeaders().get("X-NCP-CLOVASTUDIO-REQUEST-ID")).containsExactly("test-request-id"); + } + + @Test + @DisplayName("rerank: 후보의 title과 usageContext를 조합하여 문서 생성") + void rerank_combinesTitleAndUsageContext() throws Exception { + // given + String query = "재미있는 밈"; + List candidates = List.of( + new Candidate(1L, "밈 제목", "사용 맥락", "#태그", 0.8) + ); + + String apiResponse = """ + { + "result": { + "citedDocuments": [ + {"id": "1"} + ] + } + } + """; + + when(mockRestTemplate.postForObject(any(String.class), any(HttpEntity.class), eq(String.class))) + .thenReturn(apiResponse); + + // when + sut.rerank(query, candidates); + + // then + ArgumentCaptor captor = ArgumentCaptor.forClass(HttpEntity.class); + verify(mockRestTemplate).postForObject(any(String.class), captor.capture(), eq(String.class)); + + HttpEntity> capturedEntity = captor.getValue(); + Map body = capturedEntity.getBody(); + + assertThat(body).isNotNull(); + assertThat(body.get("documents")).isInstanceOf(List.class); + + List documents = (List) body.get("documents"); + assertThat(documents).hasSize(1); + + // Verify the document structure by serializing to JSON and back + String documentsJson = objectMapper.writeValueAsString(documents.get(0)); + Map firstDocAsMap = objectMapper.readValue(documentsJson, Map.class); + + assertThat(firstDocAsMap.get("id")).isEqualTo("1"); + assertThat(firstDocAsMap.get("doc")).isEqualTo("밈 제목. 사용 맥락"); + } + + @Test + @DisplayName("rerank: null result 응답 시 fallback 순서 반환") + void rerank_fallsBack_withNullResult() throws Exception { + // given + String query = "재미있는 밈"; + List candidates = List.of( + new Candidate(1L, "제목1", "사용맥락1", "#태그1", 0.3), + new Candidate(2L, "제목2", "사용맥락2", "#태그2", 0.8) + ); + + String apiResponse = """ + { + "result": null + } + """; + + when(mockRestTemplate.postForObject(any(String.class), any(HttpEntity.class), eq(String.class))) + .thenReturn(apiResponse); + + // when + List result = sut.rerank(query, candidates); + + // then - priorScore 순으로 정렬 + assertThat(result).containsExactly(2L, 1L); + } +} diff --git a/src/test/java/spring/memewikibe/infrastructure/ai/NoopCrossEncoderRerankerTest.java b/src/test/java/spring/memewikibe/infrastructure/ai/NoopCrossEncoderRerankerTest.java new file mode 100644 index 0000000..50df2d3 --- /dev/null +++ b/src/test/java/spring/memewikibe/infrastructure/ai/NoopCrossEncoderRerankerTest.java @@ -0,0 +1,204 @@ +package spring.memewikibe.infrastructure.ai; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import spring.memewikibe.annotation.UnitTest; +import spring.memewikibe.infrastructure.ai.CrossEncoderReranker.Candidate; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +@UnitTest +@DisplayName("NoopCrossEncoderReranker 단위 테스트") +class NoopCrossEncoderRerankerTest { + + private NoopCrossEncoderReranker sut; + + @BeforeEach + void setUp() { + sut = new NoopCrossEncoderReranker(); + } + + @Test + @DisplayName("rerank: 후보들을 priorScore 기준으로 내림차순 정렬하여 반환") + void rerank_succeeds_sortsByPriorScoreDescending() { + // given + String query = "테스트 쿼리"; + List candidates = List.of( + new Candidate(1L, "제목1", "사용맥락1", "#태그1", 0.3), + new Candidate(2L, "제목2", "사용맥락2", "#태그2", 0.8), + new Candidate(3L, "제목3", "사용맥락3", "#태그3", 0.5), + new Candidate(4L, "제목4", "사용맥락4", "#태그4", 0.9) + ); + + // when + List result = sut.rerank(query, candidates); + + // then + assertThat(result).containsExactly(4L, 2L, 3L, 1L); + } + + @Test + @DisplayName("rerank: 동일한 priorScore를 가진 후보들의 순서 보존") + void rerank_succeeds_preservesOrderForEqualScores() { + // given + String query = "테스트 쿼리"; + List candidates = List.of( + new Candidate(1L, "제목1", "사용맥락1", "#태그1", 0.5), + new Candidate(2L, "제목2", "사용맥락2", "#태그2", 0.5), + new Candidate(3L, "제목3", "사용맥락3", "#태그3", 0.5) + ); + + // when + List result = sut.rerank(query, candidates); + + // then - stable sort preserves input order for equal scores + assertThat(result).containsExactly(1L, 2L, 3L); + } + + @Test + @DisplayName("rerank: 단일 후보도 정상 처리") + void rerank_succeeds_withSingleCandidate() { + // given + String query = "테스트 쿼리"; + List candidates = List.of( + new Candidate(1L, "제목1", "사용맥락1", "#태그1", 0.7) + ); + + // when + List result = sut.rerank(query, candidates); + + // then + assertThat(result).containsExactly(1L); + } + + @Test + @DisplayName("rerank: 빈 후보 리스트도 정상 처리") + void rerank_succeeds_withEmptyCandidates() { + // given + String query = "테스트 쿼리"; + List candidates = List.of(); + + // when + List result = sut.rerank(query, candidates); + + // then + assertThat(result).isEmpty(); + } + + @Test + @DisplayName("rerank: null 쿼리도 정상 처리") + void rerank_succeeds_withNullQuery() { + // given + String query = null; + List candidates = List.of( + new Candidate(1L, "제목1", "사용맥락1", "#태그1", 0.3), + new Candidate(2L, "제목2", "사용맥락2", "#태그2", 0.8) + ); + + // when + List result = sut.rerank(query, candidates); + + // then + assertThat(result).containsExactly(2L, 1L); + } + + @Test + @DisplayName("rerank: 빈 쿼리도 정상 처리") + void rerank_succeeds_withBlankQuery() { + // given + String query = " "; + List candidates = List.of( + new Candidate(1L, "제목1", "사용맥락1", "#태그1", 0.4), + new Candidate(2L, "제목2", "사용맥락2", "#태그2", 0.6) + ); + + // when + List result = sut.rerank(query, candidates); + + // then + assertThat(result).containsExactly(2L, 1L); + } + + @Test + @DisplayName("rerank: 음수 priorScore도 정상 처리") + void rerank_succeeds_withNegativePriorScores() { + // given + String query = "테스트 쿼리"; + List candidates = List.of( + new Candidate(1L, "제목1", "사용맥락1", "#태그1", -0.5), + new Candidate(2L, "제목2", "사용맥락2", "#태그2", 0.2), + new Candidate(3L, "제목3", "사용맥락3", "#태그3", -0.1) + ); + + // when + List result = sut.rerank(query, candidates); + + // then + assertThat(result).containsExactly(2L, 3L, 1L); + } + + @Test + @DisplayName("rerank: 0.0 priorScore도 정상 처리") + void rerank_succeeds_withZeroPriorScores() { + // given + String query = "테스트 쿼리"; + List candidates = List.of( + new Candidate(1L, "제목1", "사용맥락1", "#태그1", 0.0), + new Candidate(2L, "제목2", "사용맥락2", "#태그2", 0.0), + new Candidate(3L, "제목3", "사용맥락3", "#태그3", 0.1) + ); + + // when + List result = sut.rerank(query, candidates); + + // then + assertThat(result).containsExactly(3L, 1L, 2L); + } + + @Test + @DisplayName("rerank: 매우 큰 priorScore 값도 정상 처리") + void rerank_succeeds_withLargePriorScores() { + // given + String query = "테스트 쿼리"; + List candidates = List.of( + new Candidate(1L, "제목1", "사용맥락1", "#태그1", 999.9), + new Candidate(2L, "제목2", "사용맥락2", "#태그2", 1000.0), + new Candidate(3L, "제목3", "사용맥락3", "#태그3", 999.8) + ); + + // when + List result = sut.rerank(query, candidates); + + // then + assertThat(result).containsExactly(2L, 1L, 3L); + } + + @Test + @DisplayName("rerank: 많은 수의 후보들도 정상 처리") + void rerank_succeeds_withManyCandidates() { + // given + String query = "테스트 쿼리"; + List candidates = List.of( + new Candidate(1L, "제목1", "사용맥락1", "#태그1", 0.1), + new Candidate(2L, "제목2", "사용맥락2", "#태그2", 0.2), + new Candidate(3L, "제목3", "사용맥락3", "#태그3", 0.3), + new Candidate(4L, "제목4", "사용맥락4", "#태그4", 0.4), + new Candidate(5L, "제목5", "사용맥락5", "#태그5", 0.5), + new Candidate(6L, "제목6", "사용맥락6", "#태그6", 0.6), + new Candidate(7L, "제목7", "사용맥락7", "#태그7", 0.7), + new Candidate(8L, "제목8", "사용맥락8", "#태그8", 0.8), + new Candidate(9L, "제목9", "사용맥락9", "#태그9", 0.9), + new Candidate(10L, "제목10", "사용맥락10", "#태그10", 1.0) + ); + + // when + List result = sut.rerank(query, candidates); + + // then + assertThat(result).hasSize(10); + assertThat(result).containsExactly(10L, 9L, 8L, 7L, 6L, 5L, 4L, 3L, 2L, 1L); + } +}