Skip to content

Commit 5955887

Browse files
committed
refactor cosmosDB vector store
1 parent 28fc23c commit 5955887

File tree

13 files changed

+169
-1230
lines changed

13 files changed

+169
-1230
lines changed

apps/acme-assist/pom.xml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,12 @@
3535
<version>1.0.0-beta.7</version>
3636
</dependency>
3737
<dependency>
38-
<groupId>org.springframework.boot</groupId>
39-
<artifactId>spring-boot-starter-data-mongodb</artifactId>
40-
<version>3.1.2</version>
38+
<groupId>org.springframework.data</groupId>
39+
<artifactId>spring-data-mongodb</artifactId>
40+
</dependency>
41+
<dependency>
42+
<groupId>org.mongodb</groupId>
43+
<artifactId>mongodb-driver-sync</artifactId>
4144
</dependency>
4245
<dependency>
4346
<groupId>org.springframework.ai</groupId>

apps/acme-assist/src/main/java/com/example/acme/assist/AzureOpenAIClient.java

Lines changed: 0 additions & 36 deletions
This file was deleted.

apps/acme-assist/src/main/java/com/example/acme/assist/ChatService.java

Lines changed: 11 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
import org.springframework.ai.chat.prompt.Prompt;
1414
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
1515
import org.springframework.ai.document.Document;
16-
import org.springframework.ai.vectorstore.SimpleVectorStore;
16+
import org.springframework.ai.vectorstore.SearchRequest;
17+
import org.springframework.ai.vectorstore.VectorStore;
1718
import org.springframework.beans.factory.annotation.Autowired;
1819
import org.springframework.beans.factory.annotation.Value;
1920
import org.springframework.core.io.Resource;
@@ -29,13 +30,7 @@
2930
public class ChatService {
3031

3132
@Autowired
32-
private SimpleVectorStore store;
33-
34-
@Autowired
35-
private AzureOpenAIClient openAIClient;
36-
37-
@Autowired
38-
private CosmosDBVectorStore cosmosDBVectorStore;
33+
private VectorStore store;
3934

4035
@Autowired
4136
private ProductRepository productRepository;
@@ -49,8 +44,6 @@ public class ChatService {
4944
@Value("classpath:/prompts/chatWithProductId.st")
5045
private Resource chatWithProductIdResource;
5146

52-
@Value("${spring.data.mongodb.enabled}")
53-
private String cosmosEnabled;
5447
/**
5548
* Chat with the OpenAI API. Use the product details as the context.
5649
*
@@ -75,23 +68,11 @@ private List<String> chatWithProductId(Product product, List<AcmeChatRequest.Mes
7568
// We have a specific Product
7669
String question = chatRequestMessages.get(chatRequestMessages.size() - 1).getContent();
7770

78-
var response = openAIClient.getEmbeddings(List.of(question));
79-
var embedding = response.getData().get(0).getEmbedding();
80-
81-
82-
List<Document> candidateDocuments = new ArrayList<>();;
8371
// step 1. Query for documents that are related to the question from the vector store
84-
if (cosmosEnabled.equals("true")) {
85-
List<DocEntry> cosmosVectorStoreDocs = this.cosmosDBVectorStore.searchTopKNearest(embedding, 5, 0.4);
86-
for (DocEntry docEntry : cosmosVectorStoreDocs) {
87-
Document document = new Document(docEntry.getText());
88-
candidateDocuments.add(document);
89-
}
90-
}
91-
else
92-
{
93-
candidateDocuments = this.store.similaritySearch(question, 5, 0.4);
94-
}
72+
SearchRequest request = SearchRequest.query(question).
73+
withTopK(5).
74+
withSimilarityThreshold(0.4);
75+
List<Document> candidateDocuments = this.store.similaritySearch(request);
9576

9677
// step 2. Create a SystemMessage that contains the product information in addition to related documents.
9778
List<Message> messages = new ArrayList<>();
@@ -113,21 +94,11 @@ protected List<String> chatWithoutProductId(List<AcmeChatRequest.Message> acmeCh
11394

11495
String question = acmeChatRequestMessages.get(acmeChatRequestMessages.size() - 1).getContent();
11596

116-
var response = openAIClient.getEmbeddings(List.of(question));
117-
var embedding = response.getData().get(0).getEmbedding();
118-
11997
// step 1. Query for documents that are related to the question from the vector store
120-
List<Document> relatedDocuments = new ArrayList<>();;
121-
if (cosmosEnabled.equals("true")) {
122-
List<DocEntry> cosmosVectorStoreDocs = this.cosmosDBVectorStore.searchTopKNearest(embedding, 5, 0.4);
123-
for (DocEntry docEntry : cosmosVectorStoreDocs) {
124-
Document document = new Document(docEntry.getText());
125-
relatedDocuments.add(document);
126-
}
127-
}
128-
else {
129-
relatedDocuments = this.store.similaritySearch(question, 5, 0.4);
130-
}
98+
SearchRequest request = SearchRequest.query(question).
99+
withTopK(5).
100+
withSimilarityThreshold(0.4);
101+
List<Document> relatedDocuments = store.similaritySearch(request);
131102

132103

133104
// step 2. Create the system message with the related documents;
Lines changed: 5 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,83 +1,29 @@
11
package com.example.acme.assist.config;
22

3-
import com.azure.ai.openai.OpenAIClientBuilder;
4-
import com.azure.core.credential.AzureKeyCredential;
5-
import com.example.acme.assist.AzureOpenAIClient;
6-
import com.example.acme.assist.vectorstore.CosmosDBVectorStore;
73
import org.springframework.ai.embedding.EmbeddingClient;
84
import org.springframework.ai.vectorstore.SimpleVectorStore;
95
import org.springframework.beans.factory.annotation.Value;
6+
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
107
import org.springframework.context.annotation.Bean;
118
import org.springframework.context.annotation.Configuration;
129
import org.springframework.core.io.Resource;
13-
import org.springframework.data.mongodb.core.MongoTemplate;
14-
15-
import java.io.IOException;
1610

1711
@Configuration
1812
public class FitAssistConfiguration {
1913

20-
@Value("${spring.ai.azure.openai.embedding-model}")
21-
private String embeddingDeploymentId;
22-
23-
@Value("${spring.ai.azure.openai.deployment-name}")
24-
private String chatDeploymentId;
25-
26-
@Value("${spring.ai.azure.openai.endpoint}")
27-
private String endpoint;
28-
29-
@Value("${spring.ai.azure.openai.api-key}")
30-
private String apiKey;
31-
32-
@Value("${vector-store.file}")
33-
private String cosmosVectorJsonFile;
34-
35-
@Value("${spring.data.mongodb.enabled}")
36-
private String cosmosEnabled;
37-
38-
//@Autowired
39-
private MongoTemplate mongoTemplate;
40-
public FitAssistConfiguration(MongoTemplate mongoTemplate) {
41-
this.mongoTemplate = mongoTemplate;
14+
public FitAssistConfiguration() {
15+
4216
}
4317

44-
45-
4618
@Value("classpath:/vector_store.json")
4719
private Resource vectorDbResource;
4820

4921
@Bean
22+
@ConditionalOnProperty(value="vectorstore", havingValue = "simple", matchIfMissing = true)
5023
public SimpleVectorStore simpleVectorStore(EmbeddingClient embeddingClient) {
5124
SimpleVectorStore simpleVectorStore = new SimpleVectorStore(embeddingClient);
52-
if (cosmosEnabled.equals("false")) {
53-
simpleVectorStore.load(vectorDbResource);
54-
}
25+
simpleVectorStore.load(vectorDbResource);
5526
return simpleVectorStore;
5627
}
5728

58-
@Bean
59-
public CosmosDBVectorStore vectorStore() throws IOException {
60-
CosmosDBVectorStore store = null;
61-
if (cosmosEnabled.equals("true")) {
62-
store = new CosmosDBVectorStore(mongoTemplate);
63-
String currentPath = new java.io.File(".").getCanonicalPath();
64-
String path = currentPath + cosmosVectorJsonFile.replace("\\", "//");
65-
store.loadFromJsonFile(path);
66-
}
67-
else {
68-
store = new CosmosDBVectorStore(null);
69-
}
70-
return store;
71-
}
72-
73-
@Bean
74-
public AzureOpenAIClient AzureOpenAIClient() {
75-
var innerClient = new OpenAIClientBuilder()
76-
.endpoint(endpoint)
77-
.credential(new AzureKeyCredential(apiKey))
78-
.buildClient();
79-
return new AzureOpenAIClient(innerClient, embeddingDeploymentId, chatDeploymentId);
80-
}
81-
82-
8329
}
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
package com.example.acme.assist.mongodb;
2+
3+
import java.io.IOException;
4+
import java.util.ArrayList;
5+
import java.util.HashMap;
6+
import java.util.List;
7+
import java.util.Map;
8+
import java.util.Optional;
9+
10+
import org.slf4j.Logger;
11+
import org.slf4j.LoggerFactory;
12+
import org.springframework.ai.document.Document;
13+
import org.springframework.ai.embedding.EmbeddingClient;
14+
import org.springframework.ai.vectorstore.SearchRequest;
15+
import org.springframework.ai.vectorstore.VectorStore;
16+
import org.springframework.beans.factory.annotation.Autowired;
17+
import org.springframework.beans.factory.annotation.Value;
18+
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
19+
import org.springframework.core.io.Resource;
20+
import org.springframework.data.mongodb.core.MongoTemplate;
21+
import org.springframework.data.mongodb.core.aggregation.Aggregation;
22+
import org.springframework.data.mongodb.core.aggregation.AggregationResults;
23+
import org.springframework.stereotype.Component;
24+
25+
import com.fasterxml.jackson.core.type.TypeReference;
26+
import com.fasterxml.jackson.databind.ObjectMapper;
27+
28+
import jakarta.annotation.PostConstruct;
29+
30+
@Component
31+
@ConditionalOnProperty(value = "vectorstore", havingValue = "mongodb", matchIfMissing = false)
32+
public class CosmosDBVectorStore implements VectorStore {
33+
34+
private static Logger LOGGER = LoggerFactory.getLogger(CosmosDBVectorStore.class);
35+
36+
private static String COLLECTION = "vectorstore";
37+
38+
@Value("classpath:/vector_store.json")
39+
private Resource vectorDbResource;
40+
41+
@Autowired
42+
private MongoTemplate template;
43+
44+
@Autowired
45+
protected EmbeddingClient embeddingClient;
46+
47+
@PostConstruct
48+
protected void init() {
49+
template.dropCollection(COLLECTION);
50+
this.load(vectorDbResource);
51+
LOGGER.info("initialized collection in mongodb");
52+
}
53+
54+
public void load(Resource resource) {
55+
TypeReference<HashMap<String, Document>> typeRef = new TypeReference<>() {
56+
};
57+
ObjectMapper objectMapper = new ObjectMapper();
58+
try {
59+
Map<String, Document> docs = objectMapper.readValue(resource.getInputStream(), typeRef);
60+
Optional<Document> doc = docs.values().stream().findFirst();
61+
if (doc.isPresent()) {
62+
int dimensions = doc.get().getEmbedding().size();
63+
template.insert(docs.values(), COLLECTION);
64+
createVectorIndex(5, dimensions, "COS");
65+
}
66+
} catch (IOException ex) {
67+
throw new RuntimeException(ex);
68+
}
69+
}
70+
71+
public void createVectorIndex(int numLists, int dimensions, String similarity) {
72+
String bsonCmd = "{\"createIndexes\":\"" + COLLECTION + "\",\"indexes\":"
73+
+ "[{\"name\":\"vectorsearch\",\"key\":{\"embedding\":\"cosmosSearch\"},\"cosmosSearchOptions\":"
74+
+ "{\"kind\":\"vector-ivf\",\"numLists\":" + numLists + ",\"similarity\":\"" + similarity
75+
+ "\",\"dimensions\":" + dimensions + "}}]}";
76+
LOGGER.info("creating vector index in Cosmos DB Mongo vCore...");
77+
try {
78+
template.executeCommand(bsonCmd);
79+
} catch (Exception e) {
80+
LOGGER.warn("Failed to create vector index in Cosmos DB Mongo vCore", e);
81+
}
82+
}
83+
84+
@Override
85+
public void add(List<Document> documents) {
86+
// TODO Auto-generated method stub
87+
}
88+
89+
@Override
90+
public Optional<Boolean> delete(List<String> idList) {
91+
return Optional.empty();
92+
}
93+
94+
private List<Double> getUserQueryEmbedding(String query) {
95+
return this.embeddingClient.embed(query);
96+
}
97+
98+
@Override
99+
public List<Document> similaritySearch(SearchRequest request) {
100+
List<Double> embedding = getUserQueryEmbedding(request.getQuery());
101+
102+
// perform vector search in Cosmos DB Mongo API - vCore
103+
String command = "{\"$search\":{\"cosmosSearch\":{\"vector\":" + embedding + ",\"path\":\"embedding\",\"k\":"
104+
+ request.getTopK() + "}}}";
105+
Aggregation agg = Aggregation.newAggregation(Aggregation.stage(command));
106+
AggregationResults<org.bson.Document> results = template.aggregate(agg, COLLECTION, org.bson.Document.class);
107+
List<Document> ret = new ArrayList<>();
108+
results.getMappedResults().forEach(bDoc -> {
109+
String content = bDoc.getString("content");
110+
Document doc = new Document(content);
111+
ret.add(doc);
112+
});
113+
return ret;
114+
}
115+
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package com.example.acme.assist.mongodb;
2+
3+
import org.springframework.beans.factory.annotation.Value;
4+
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
5+
import org.springframework.context.annotation.Bean;
6+
import org.springframework.context.annotation.Configuration;
7+
import org.springframework.data.mongodb.core.MongoTemplate;
8+
9+
import com.mongodb.ConnectionString;
10+
import com.mongodb.MongoClientSettings;
11+
import com.mongodb.client.MongoClients;
12+
13+
@Configuration
14+
@ConditionalOnProperty(value = "vectorstore", havingValue = "mongodb", matchIfMissing = false)
15+
public class MongoDBConfiguration {
16+
17+
@Value("${spring.data.mongodb.uri}")
18+
private String url;
19+
20+
@Value("${spring.data.mongodb.database}")
21+
private String database;
22+
23+
@Bean
24+
public MongoTemplate mongoTemplate() {
25+
ConnectionString cs = new ConnectionString(url);
26+
MongoClientSettings settings = MongoClientSettings.builder().applyConnectionString(cs).build();
27+
28+
return new MongoTemplate(MongoClients.create(settings), database);
29+
}
30+
}

0 commit comments

Comments
 (0)