Skip to content

Commit

Permalink
refactor cosmosDB vector store
Browse files Browse the repository at this point in the history
  • Loading branch information
dingmeng-xue committed Mar 20, 2024
1 parent 28fc23c commit 5955887
Show file tree
Hide file tree
Showing 13 changed files with 169 additions and 1,230 deletions.
9 changes: 6 additions & 3 deletions apps/acme-assist/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,12 @@
<version>1.0.0-beta.7</version>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-mongodb</artifactId>
<version>3.1.2</version>
<groupId>org.springframework.data</groupId>
<artifactId>spring-data-mongodb</artifactId>
</dependency>
<dependency>
<groupId>org.mongodb</groupId>
<artifactId>mongodb-driver-sync</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SimpleVectorStore;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.Resource;
Expand All @@ -29,13 +30,7 @@
public class ChatService {

@Autowired
private SimpleVectorStore store;

@Autowired
private AzureOpenAIClient openAIClient;

@Autowired
private CosmosDBVectorStore cosmosDBVectorStore;
private VectorStore store;

@Autowired
private ProductRepository productRepository;
Expand All @@ -49,8 +44,6 @@ public class ChatService {
@Value("classpath:/prompts/chatWithProductId.st")
private Resource chatWithProductIdResource;

@Value("${spring.data.mongodb.enabled}")
private String cosmosEnabled;
/**
* Chat with the OpenAI API. Use the product details as the context.
*
Expand All @@ -75,23 +68,11 @@ private List<String> chatWithProductId(Product product, List<AcmeChatRequest.Mes
// We have a specific Product
String question = chatRequestMessages.get(chatRequestMessages.size() - 1).getContent();

var response = openAIClient.getEmbeddings(List.of(question));
var embedding = response.getData().get(0).getEmbedding();


List<Document> candidateDocuments = new ArrayList<>();;
// step 1. Query for documents that are related to the question from the vector store
if (cosmosEnabled.equals("true")) {
List<DocEntry> cosmosVectorStoreDocs = this.cosmosDBVectorStore.searchTopKNearest(embedding, 5, 0.4);
for (DocEntry docEntry : cosmosVectorStoreDocs) {
Document document = new Document(docEntry.getText());
candidateDocuments.add(document);
}
}
else
{
candidateDocuments = this.store.similaritySearch(question, 5, 0.4);
}
SearchRequest request = SearchRequest.query(question).
withTopK(5).
withSimilarityThreshold(0.4);
List<Document> candidateDocuments = this.store.similaritySearch(request);

// step 2. Create a SystemMessage that contains the product information in addition to related documents.
List<Message> messages = new ArrayList<>();
Expand All @@ -113,21 +94,11 @@ protected List<String> chatWithoutProductId(List<AcmeChatRequest.Message> acmeCh

String question = acmeChatRequestMessages.get(acmeChatRequestMessages.size() - 1).getContent();

var response = openAIClient.getEmbeddings(List.of(question));
var embedding = response.getData().get(0).getEmbedding();

// step 1. Query for documents that are related to the question from the vector store
List<Document> relatedDocuments = new ArrayList<>();;
if (cosmosEnabled.equals("true")) {
List<DocEntry> cosmosVectorStoreDocs = this.cosmosDBVectorStore.searchTopKNearest(embedding, 5, 0.4);
for (DocEntry docEntry : cosmosVectorStoreDocs) {
Document document = new Document(docEntry.getText());
relatedDocuments.add(document);
}
}
else {
relatedDocuments = this.store.similaritySearch(question, 5, 0.4);
}
SearchRequest request = SearchRequest.query(question).
withTopK(5).
withSimilarityThreshold(0.4);
List<Document> relatedDocuments = store.similaritySearch(request);


// step 2. Create the system message with the related documents;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,83 +1,29 @@
package com.example.acme.assist.config;

import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.core.credential.AzureKeyCredential;
import com.example.acme.assist.AzureOpenAIClient;
import com.example.acme.assist.vectorstore.CosmosDBVectorStore;
import org.springframework.ai.embedding.EmbeddingClient;
import org.springframework.ai.vectorstore.SimpleVectorStore;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.Resource;
import org.springframework.data.mongodb.core.MongoTemplate;

import java.io.IOException;

@Configuration
public class FitAssistConfiguration {

@Value("${spring.ai.azure.openai.embedding-model}")
private String embeddingDeploymentId;

@Value("${spring.ai.azure.openai.deployment-name}")
private String chatDeploymentId;

@Value("${spring.ai.azure.openai.endpoint}")
private String endpoint;

@Value("${spring.ai.azure.openai.api-key}")
private String apiKey;

@Value("${vector-store.file}")
private String cosmosVectorJsonFile;

@Value("${spring.data.mongodb.enabled}")
private String cosmosEnabled;

//@Autowired
private MongoTemplate mongoTemplate;
public FitAssistConfiguration(MongoTemplate mongoTemplate) {
this.mongoTemplate = mongoTemplate;
public FitAssistConfiguration() {

}



@Value("classpath:/vector_store.json")
private Resource vectorDbResource;

@Bean
@ConditionalOnProperty(value="vectorstore", havingValue = "simple", matchIfMissing = true)
public SimpleVectorStore simpleVectorStore(EmbeddingClient embeddingClient) {
SimpleVectorStore simpleVectorStore = new SimpleVectorStore(embeddingClient);
if (cosmosEnabled.equals("false")) {
simpleVectorStore.load(vectorDbResource);
}
simpleVectorStore.load(vectorDbResource);
return simpleVectorStore;
}

@Bean
public CosmosDBVectorStore vectorStore() throws IOException {
CosmosDBVectorStore store = null;
if (cosmosEnabled.equals("true")) {
store = new CosmosDBVectorStore(mongoTemplate);
String currentPath = new java.io.File(".").getCanonicalPath();
String path = currentPath + cosmosVectorJsonFile.replace("\\", "//");
store.loadFromJsonFile(path);
}
else {
store = new CosmosDBVectorStore(null);
}
return store;
}

@Bean
public AzureOpenAIClient AzureOpenAIClient() {
var innerClient = new OpenAIClientBuilder()
.endpoint(endpoint)
.credential(new AzureKeyCredential(apiKey))
.buildClient();
return new AzureOpenAIClient(innerClient, embeddingDeploymentId, chatDeploymentId);
}


}
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package com.example.acme.assist.mongodb;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingClient;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.core.io.Resource;
import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.data.mongodb.core.aggregation.Aggregation;
import org.springframework.data.mongodb.core.aggregation.AggregationResults;
import org.springframework.stereotype.Component;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;

import jakarta.annotation.PostConstruct;

/**
 * {@link VectorStore} implementation that keeps documents in a MongoDB collection and
 * searches them with the {@code cosmosSearch} aggregation operator of
 * Azure Cosmos DB for MongoDB vCore.
 *
 * <p>The bean is only registered when the {@code vectorstore} property is set to
 * {@code mongodb}. On startup ({@link #init()}) the backing collection is dropped and
 * re-populated from {@code classpath:/vector_store.json}.
 */
@Component
@ConditionalOnProperty(value = "vectorstore", havingValue = "mongodb", matchIfMissing = false)
public class CosmosDBVectorStore implements VectorStore {

    private static final Logger LOGGER = LoggerFactory.getLogger(CosmosDBVectorStore.class);

    /** Name of the MongoDB collection holding the documents and their embeddings. */
    private static final String COLLECTION = "vectorstore";

    /** Field of a stored document that carries the embedding vector. */
    private static final String EMBEDDING_FIELD = "embedding";

    /** Field of a stored document that carries the text content. */
    private static final String CONTENT_FIELD = "content";

    /** Name under which the search score is projected into each search hit. */
    private static final String SCORE_FIELD = "similarityScore";

    /** Name under which the stored document is projected into each search hit. */
    private static final String SOURCE_FIELD = "document";

    @Value("classpath:/vector_store.json")
    private Resource vectorDbResource;

    @Autowired
    private MongoTemplate template;

    @Autowired
    protected EmbeddingClient embeddingClient;

    /**
     * Rebuilds the collection from the bundled JSON file.
     *
     * <p>NOTE: this drops the existing collection first, so anything previously stored in it
     * (including documents added through {@link #add(List)}) is discarded on every startup.
     */
    @PostConstruct
    protected void init() {
        template.dropCollection(COLLECTION);
        this.load(vectorDbResource);
        LOGGER.info("initialized collection in mongodb");
    }

    /**
     * Reads a JSON object of {@code id -> Document} from {@code resource}, inserts all documents
     * into the collection and creates the vector index sized to the first document's embedding.
     * Nothing is inserted and no index is created when the JSON object is empty.
     *
     * @param resource JSON resource to read the documents from
     * @throws RuntimeException wrapping the {@link IOException} if the resource cannot be read
     *         or parsed
     */
    public void load(Resource resource) {
        TypeReference<HashMap<String, Document>> typeRef = new TypeReference<>() {
        };
        ObjectMapper objectMapper = new ObjectMapper();
        Map<String, Document> docs;
        // Close the stream explicitly instead of relying on the parser to do it.
        try (var in = resource.getInputStream()) {
            docs = objectMapper.readValue(in, typeRef);
        } catch (IOException ex) {
            throw new RuntimeException("Failed to load vector store documents from " + resource, ex);
        }
        if (docs.isEmpty()) {
            return;
        }
        // The index must be declared with the embedding size; take it from the first document.
        int dimensions = docs.values().iterator().next().getEmbedding().size();
        template.insert(docs.values(), COLLECTION);
        createVectorIndex(5, dimensions, "COS");
    }

    /**
     * Creates the {@code vector-ivf} index named {@code vectorsearch} on the embedding field.
     *
     * <p>Failures are logged and swallowed on purpose (best effort), so a failed index creation
     * does not abort the caller.
     *
     * @param numLists   value sent as the {@code numLists} index option
     * @param dimensions number of dimensions of the stored embedding vectors
     * @param similarity value sent as the {@code similarity} index option (e.g. {@code "COS"})
     */
    public void createVectorIndex(int numLists, int dimensions, String similarity) {
        // Build the command as a BSON document rather than concatenating JSON text, so that
        // the similarity argument can never break or alter the structure of the command.
        org.bson.Document searchOptions = new org.bson.Document("kind", "vector-ivf")
                .append("numLists", numLists)
                .append("similarity", similarity)
                .append("dimensions", dimensions);
        org.bson.Document index = new org.bson.Document("name", "vectorsearch")
                .append("key", new org.bson.Document(EMBEDDING_FIELD, "cosmosSearch"))
                .append("cosmosSearchOptions", searchOptions);
        // "createIndexes" has to stay the first key: it is the command name.
        org.bson.Document command = new org.bson.Document("createIndexes", COLLECTION)
                .append("indexes", List.of(index));
        LOGGER.info("creating vector index in Cosmos DB Mongo vCore...");
        try {
            template.executeCommand(command);
        } catch (Exception e) {
            LOGGER.warn("Failed to create vector index in Cosmos DB Mongo vCore", e);
        }
    }

    /**
     * Stores the given documents in the collection. Documents that do not carry an embedding yet
     * are embedded with the configured {@link EmbeddingClient} first.
     *
     * @param documents documents to store; {@code null} or an empty list is a no-op
     */
    @Override
    public void add(List<Document> documents) {
        if (documents == null || documents.isEmpty()) {
            return;
        }
        for (Document document : documents) {
            List<Double> existing = document.getEmbedding();
            if (existing == null || existing.isEmpty()) {
                document.setEmbedding(this.embeddingClient.embed(document));
            }
        }
        template.insert(documents, COLLECTION);
    }

    /**
     * Not implemented: nothing is removed from the collection.
     *
     * @param idList ids of the documents to delete (ignored)
     * @return always {@link Optional#empty()}
     */
    @Override
    public Optional<Boolean> delete(List<String> idList) {
        return Optional.empty();
    }

    /** Embeds the user query with the configured {@link EmbeddingClient}. */
    private List<Double> getUserQueryEmbedding(String query) {
        return this.embeddingClient.embed(query);
    }

    /**
     * Runs a vector search for the request's query and returns the content of the matching
     * documents.
     *
     * <p>At most {@code request.getTopK()} candidates are fetched; candidates whose search score
     * is below {@code request.getSimilarityThreshold()} are dropped. A hit without a numeric
     * score is kept, because it cannot be compared against the threshold.
     *
     * @param request query text, number of results and similarity threshold
     * @return documents built from the {@code content} field of each accepted hit
     */
    @Override
    public List<Document> similaritySearch(SearchRequest request) {
        List<Double> embedding = getUserQueryEmbedding(request.getQuery());
        double threshold = request.getSimilarityThreshold();

        // perform vector search in Cosmos DB Mongo API - vCore
        // List#toString renders the vector as "[a, b, ...]", which is a valid JSON array.
        String searchStage = "{\"$search\":{\"cosmosSearch\":{\"vector\":" + embedding
                + ",\"path\":\"" + EMBEDDING_FIELD + "\",\"k\":" + request.getTopK()
                + "},\"returnStoredSource\":true}}";
        // Expose the score next to the stored document so the threshold can be applied below.
        String projectStage = "{\"$project\":{\"" + SCORE_FIELD + "\":{\"$meta\":\"searchScore\"},\""
                + SOURCE_FIELD + "\":\"$$ROOT\"}}";
        Aggregation agg = Aggregation.newAggregation(Aggregation.stage(searchStage),
                Aggregation.stage(projectStage));
        AggregationResults<org.bson.Document> results = template.aggregate(agg, COLLECTION, org.bson.Document.class);

        List<Document> matches = new ArrayList<>();
        for (org.bson.Document hit : results.getMappedResults()) {
            Object rawScore = hit.get(SCORE_FIELD);
            if (rawScore instanceof Number && ((Number) rawScore).doubleValue() < threshold) {
                continue;
            }
            org.bson.Document stored = hit.get(SOURCE_FIELD, org.bson.Document.class);
            if (stored == null) {
                // Fall back to the hit itself if the projection did not nest the source.
                stored = hit;
            }
            matches.add(new Document(stored.getString(CONTENT_FIELD)));
        }
        return matches;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package com.example.acme.assist.mongodb;

import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.mongodb.core.MongoTemplate;

import com.mongodb.ConnectionString;
import com.mongodb.MongoClientSettings;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoClients;

import jakarta.annotation.PreDestroy;

/**
 * Provides the {@link MongoTemplate} used by the MongoDB-backed vector store.
 *
 * <p>Only active when the {@code vectorstore} property is set to {@code mongodb}.
 */
@Configuration
@ConditionalOnProperty(value = "vectorstore", havingValue = "mongodb", matchIfMissing = false)
public class MongoDBConfiguration {

    /** MongoDB connection string, read from {@code spring.data.mongodb.uri}. */
    @Value("${spring.data.mongodb.uri}")
    private String url;

    /** Name of the database to use, read from {@code spring.data.mongodb.database}. */
    @Value("${spring.data.mongodb.database}")
    private String database;

    /**
     * Client created by {@link #mongoTemplate()}. It is created here and not registered as a
     * bean, so this class keeps a reference in order to close it on shutdown.
     */
    private MongoClient mongoClient;

    /**
     * Builds a {@link MongoTemplate} for the configured connection string and database.
     *
     * @return template bound to a client created from {@code spring.data.mongodb.uri}
     */
    @Bean
    public MongoTemplate mongoTemplate() {
        ConnectionString cs = new ConnectionString(url);
        MongoClientSettings settings = MongoClientSettings.builder().applyConnectionString(cs).build();

        this.mongoClient = MongoClients.create(settings);
        return new MongoTemplate(this.mongoClient, database);
    }

    /** Closes the client created by {@link #mongoTemplate()}, releasing its connections. */
    @PreDestroy
    public void closeMongoClient() {
        if (this.mongoClient != null) {
            this.mongoClient.close();
            this.mongoClient = null;
        }
    }
}
Loading

0 comments on commit 5955887

Please sign in to comment.