11package com.grepp.quizy.matching.infra.match.repository
22
3- import com.grepp.quizy.matching.domain.match.MATCHING_K
43import com.grepp.quizy.matching.domain.match.MatchingPoolRepository
54import com.grepp.quizy.matching.domain.match.UserStatus
6- import com.grepp.quizy.matching.infra.match.converter.toByteArray
7- import com.grepp.quizy.matching.infra.match.converter.toFloatArray
85import com.grepp.quizy.matching.domain.user.UserId
96import com.grepp.quizy.matching.domain.user.UserVector
10- import org.springframework.data.redis.core.RedisTemplate
7+ import com.grepp.quizy.matching.infra.redis.util.toByteArray
8+ import com.grepp.quizy.matching.infra.redis.util.toFloatArray
9+ import io.github.oshai.kotlinlogging.KotlinLogging
10+ import jakarta.annotation.PostConstruct
1111import org.springframework.stereotype.Repository
12- import kotlin.math.sqrt
12+ import redis.clients.jedis.JedisPooled
13+ import redis.clients.jedis.search.IndexDefinition
14+ import redis.clients.jedis.search.IndexOptions
15+ import redis.clients.jedis.search.Query
16+ import redis.clients.jedis.search.Schema
1317
14- private const val MATCHING_THRESHOLD = 0.5
15- private const val MATCHING_POOL_KEY = " MATCHING_POOL"
18+ private const val MATCHING_INDEX_PREFIX = " MATCHING_VECTOR_POOL:"
19+ private const val MATCHING_INDEX = " MATCHING_INDEX"
20+ private const val MATCHING_K = 5
21+
22+ private const val ID_FIELD = " id"
23+ private const val VECTOR_FIELD = " vector"
1624
1725@Repository
1826class MatchingPoolRepositoryAdapter (
19- private val redisTemplate : RedisTemplate < String , Any >,
27+ private val jedis : JedisPooled
2028) : MatchingPoolRepository {
29+ private val log = KotlinLogging .logger {}
2130
22- override fun saveVector (userId : UserId , userVector : UserVector ) {
23- val key = createMatchingPoolKey(userId)
24- redisTemplate.opsForHash<String , Any >().putAll(
25- key, mapOf (
26- " id" to userId.value,
27- " vector" to userVector.value.toByteArray()
28- )
29- )
31+ @PostConstruct
32+ fun initialize () {
33+ log.info { " Initializing RediSearchRepository" }
34+ createIndex()
3035 }
3136
32- override fun findNearestUser (userStatus : UserStatus ): List <UserStatus > {
33- val keyPattern = " $MATCHING_POOL_KEY :*"
34- val keys = redisTemplate.keys(keyPattern) ? : return emptyList()
37+ private fun createIndex () {
38+ val definition = IndexDefinition ().setPrefixes(MATCHING_INDEX_PREFIX )
39+ val vectorAttr = mapOf (
40+ Pair (" TYPE" , " FLOAT32" ),
41+ Pair (" DIM" , 15 ),
42+ Pair (" DISTANCE_METRIC" , " COSINE" ),
43+ )
44+ val schema = Schema ()
45+ .addNumericField(ID_FIELD )
46+ .addHNSWVectorField(VECTOR_FIELD , vectorAttr)
3547
36- val results = keys.mapNotNull { key ->
37- val vectorBytes = redisTemplate.opsForHash<String , Any >().get(key, " vector" ) as ByteArray?
38- val id = redisTemplate.opsForHash<String , Any >().get(key, " id" ) as Long?
39- if (vectorBytes != null && id != null ) {
40- val vector = vectorBytes.toFloatArray()
41- val score = cosineSimilarity(userStatus.vector.value, vector)
42- Triple (id, vector, score) // score도 포함
43- } else null
48+ try {
49+ jedis.ftCreate(MATCHING_INDEX , IndexOptions .defaultOptions().setDefinition(definition), schema)
50+ } catch (e: Exception ) {
51+ log.error { " Could not create index: ${e.message} " }
4452 }
45-
46- return results
47- .filter { it.third >= MATCHING_THRESHOLD }
48- .sortedByDescending { it.third } // score를 기준으로 정렬
49- .take(MATCHING_K ) // 상위 K개 선택
50- .map { (id, vector, _) ->
51- UserStatus (
52- userId = UserId (id),
53- vector = UserVector (vector)
54- )
55- }
5653 }
5754
58- override fun remove (userId : UserId ) {
59- redisTemplate.delete(createMatchingPoolKey(userId))
55+ override fun saveVector (userStatus : UserStatus ) {
56+ val key = " $MATCHING_INDEX_PREFIX${userStatus.userId.value} "
57+ jedis.hset(key, ID_FIELD , userStatus.userId.value.toString())
58+ jedis.hset(key.toByteArray(), VECTOR_FIELD .toByteArray(), userStatus.vector.value.toByteArray())
6059 }
6160
62- private fun createMatchingPoolKey (userId : UserId ) = " $MATCHING_POOL_KEY :${userId.value} "
61+ override fun findNearestUser (pivot : UserStatus ): List <UserStatus > {
62+ val query = Query (" *=>[KNN ${' $' } K @$VECTOR_FIELD ${' $' } query_vector]" )
63+ .returnFields(ID_FIELD )
64+ .addParam(" K" , MATCHING_K )
65+ .addParam(" query_vector" , pivot.vector.value.toByteArray())
66+ .dialect(2 )
67+ val docs = jedis.ftSearch(MATCHING_INDEX , query).documents
6368
64- private fun cosineSimilarity (vec1 : FloatArray , vec2 : FloatArray ): Double {
65- val dotProduct = vec1.indices.fold(0.0 ) { acc, i -> acc + vec1[i] * vec2[i] }
66- val magnitude1 = sqrt(vec1.fold(0.0 ) { acc, v -> acc + v * v })
67- val magnitude2 = sqrt(vec2.fold(0.0 ) { acc, v -> acc + v * v })
68- return dotProduct / (magnitude1 * magnitude2)
69+ return docs.map { doc ->
70+ val id = doc.get(ID_FIELD ) as String
71+ val vector = jedis.hget(doc.id.toByteArray(), VECTOR_FIELD .toByteArray())
72+ UserStatus (
73+ userId = UserId (id.toLong()),
74+ vector = UserVector (vector.toFloatArray())
75+ )
76+ }
77+ }
78+
79+ override fun remove (userId : UserId ) {
80+ jedis.del(" $MATCHING_INDEX_PREFIX${userId.value} " )
6981 }
7082}
0 commit comments