Skip to content

Commit

Permalink
Add rendering of hits (and trace and timing etc) in llm rendering
Browse files Browse the repository at this point in the history
  • Loading branch information
Lester Solbakken committed Jun 10, 2024
1 parent 367b751 commit 3e31b1f
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 14 deletions.
2 changes: 1 addition & 1 deletion container-search/abi-spec.json
Original file line number Diff line number Diff line change
Expand Up @@ -9227,7 +9227,7 @@
"methods" : [
"public void <init>(ai.vespa.search.llm.LlmSearcherConfig, com.yahoo.component.provider.ComponentRegistry)",
"public com.yahoo.search.Result search(com.yahoo.search.Query, com.yahoo.search.searchchain.Execution)",
"protected com.yahoo.search.Result complete(com.yahoo.search.Query, ai.vespa.llm.completion.Prompt)",
"protected com.yahoo.search.Result complete(com.yahoo.search.Query, ai.vespa.llm.completion.Prompt, com.yahoo.search.Result, com.yahoo.search.searchchain.Execution)",
"public java.lang.String getPrompt(com.yahoo.search.Query)",
"public java.lang.String getPropertyPrefix()",
"public java.lang.String lookupProperty(java.lang.String, com.yahoo.search.Query)",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@
import com.yahoo.search.Query;
import com.yahoo.search.Result;
import com.yahoo.search.Searcher;
import com.yahoo.search.rendering.JsonRenderer;
import com.yahoo.search.result.ErrorMessage;
import com.yahoo.search.result.EventStream;
import com.yahoo.search.result.HitGroup;
import com.yahoo.search.searchchain.Execution;
import com.yahoo.text.Utf8;

import java.io.ByteArrayOutputStream;
import java.util.List;
import java.util.concurrent.RejectedExecutionException;
import java.util.function.Function;
Expand All @@ -38,6 +41,10 @@ public class LLMSearcher extends Searcher {
private static final String API_KEY_HEADER = "X-LLM-API-KEY";
private static final String STREAM_PROPERTY = "stream";
private static final String PROMPT_PROPERTY = "prompt";
private static final String INCLUDE_PROMPT_IN_RESULT = "includePrompt";
private static final String INCLUDE_HITS_IN_RESULT = "includeHits";

private final JsonRenderer jsonRenderer;

private final String propertyPrefix;
private final boolean stream;
Expand All @@ -50,11 +57,13 @@ public LLMSearcher(LlmSearcherConfig config, ComponentRegistry<LanguageModel> la
this.languageModelId = config.providerId();
this.languageModel = findLanguageModel(languageModelId, languageModels);
this.propertyPrefix = config.propertyPrefix();

this.jsonRenderer = new JsonRenderer();
}

@Override
public Result search(Query query, Execution execution) {
return complete(query, StringPrompt.from(getPrompt(query)));
return complete(query, StringPrompt.from(getPrompt(query)), null, execution);
}

private LanguageModel findLanguageModel(String providerId, ComponentRegistry<LanguageModel> languageModels)
Expand All @@ -81,30 +90,37 @@ private LanguageModel findLanguageModel(String providerId, ComponentRegistry<Lan
return languageModel;
}

protected Result complete(Query query, Prompt prompt) {
protected Result complete(Query query, Prompt prompt, Result result, Execution execution) {
var options = new InferenceParameters(getApiKeyHeader(query), s -> lookupProperty(s, query));
var stream = lookupPropertyBool(STREAM_PROPERTY, query, this.stream); // query value overwrites config
try {
return stream ? completeAsync(query, prompt, options) : completeSync(query, prompt, options);
if (stream) {
return completeAsync(query, prompt, options, result, execution);
}
return completeSync(query, prompt, options, result, execution);
} catch (RejectedExecutionException e) {
return new Result(query, new ErrorMessage(429, e.getMessage()));
}
}

private boolean shouldAddPrompt(Query query) {
return query.getTrace().getLevel() >= 1;
var includePrompt = lookupPropertyBool(INCLUDE_PROMPT_IN_RESULT, query, false);
return query.getTrace().getLevel() >= 1 || includePrompt;
}

private boolean shouldAddTokenStats(Query query) {
return query.getTrace().getLevel() >= 1;
}

private Result completeAsync(Query query, Prompt prompt, InferenceParameters options) {
private Result completeAsync(Query query, Prompt prompt, InferenceParameters options, Result result, Execution execution) {
final EventStream eventStream = new EventStream();

if (shouldAddPrompt(query)) {
eventStream.add(prompt.asString(), "prompt");
}
if (shouldAddHits(query) && result != null) {
eventStream.add(renderHits(result, execution), "hits");
}

final TokenStats tokenStats = new TokenStats();
languageModel.completeAsync(prompt, options, completion -> {
Expand Down Expand Up @@ -143,12 +159,15 @@ private void handleException(EventStream eventStream, Throwable exception) {
eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage()));
}

private Result completeSync(Query query, Prompt prompt, InferenceParameters options) {
private Result completeSync(Query query, Prompt prompt, InferenceParameters options, Result result, Execution execution) {
EventStream eventStream = new EventStream();

if (shouldAddPrompt(query)) {
eventStream.add(prompt.asString(), "prompt");
}
if (shouldAddHits(query) && result != null) {
eventStream.add(renderHits(result, execution), "hits");
}

List<Completion> completions = languageModel.complete(prompt, options);
eventStream.add(completions.get(0).text(), "completion");
Expand Down Expand Up @@ -200,6 +219,18 @@ public String getApiKeyHeader(Query query) {
return lookupPropertyWithOrWithoutPrefix(API_KEY_HEADER, p -> query.getHttpRequest().getHeader(p));
}

private boolean shouldAddHits(Query query) {
return lookupPropertyBool(INCLUDE_HITS_IN_RESULT, query, false);
}

private String renderHits(Result results, Execution execution) {
var bs = new ByteArrayOutputStream();
var renderer = jsonRenderer.clone();
renderer.init();
renderer.renderResponse(bs, results, execution, null).join(); // wait for renderer to complete
return Utf8.toString(bs.toByteArray());
}

private static class TokenStats {

private final long start;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public RAGSearcher(LlmSearcherConfig config, ComponentRegistry<LanguageModel> la
public Result search(Query query, Execution execution) {
Result result = execution.search(query);
execution.fill(result);
return complete(query, buildPrompt(query, result));
return complete(query, buildPrompt(query, result), result, execution);
}

protected Prompt buildPrompt(Query query, Result result) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,16 @@ public void data(Data data) throws IOException {
generator.writeRaw("event: " + event.type() + "\n");
}
generator.writeRaw("data: ");
generator.writeStartObject();
generator.writeStringField(event.type(), event.toString());
generator.writeEndObject();
if (event.type().equals("hits")) {
generator.writeRaw(event.toString());
} else {
generator.writeStartObject();
generator.writeStringField(event.type(), event.toString());
generator.writeEndObject();
}
generator.writeRaw("\n\n");
generator.flush();
}
// Todo: support other types of data such as search results (hits), timing and trace
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ static Result runMockSearch(Searcher searcher, Map<String, String> parameters) {
return execution.search(query);
}

private static Searcher createRAGSearcher(Map<String, LanguageModel> llms) {
static Searcher createRAGSearcher(Map<String, LanguageModel> llms) {
var config = new LlmSearcherConfig.Builder().stream(false).build();
ComponentRegistry<LanguageModel> models = new ComponentRegistry<>();
llms.forEach((key, value) -> models.register(ComponentId.fromString(key), value));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package ai.vespa.search.llm;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.yahoo.search.Result;
import com.yahoo.search.rendering.EventRenderer;
import com.yahoo.search.searchchain.Execution;
import com.yahoo.text.Utf8;
import org.junit.jupiter.api.Test;

import java.io.ByteArrayOutputStream;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

public class RAGWithEventRendererTest {

@Test
public void testPromptAndHitsAreRendered() throws Exception {
var params = Map.of(
"query", "why are ducks better than cats?",
"llm.stream", "false",
"llm.includePrompt", "true",
"llm.includeHits", "true"
);
var llm = LLMSearcherTest.createLLMClient();
var searcher = RAGSearcherTest.createRAGSearcher(Map.of("mock", llm));
var results = RAGSearcherTest.runMockSearch(searcher, params);

var result = render(results);

var promptEvent = extractEvent(result, "prompt");
assertNotNull(promptEvent);
assertTrue(promptEvent.has("prompt"));

var resultsEvent = extractEvent(result, "hits");
assertNotNull(resultsEvent);
assertTrue(resultsEvent.has("root"));
assertEquals(2, resultsEvent.get("root").get("children").size());
}

private JsonNode extractEvent(String result, String eventName) throws JsonProcessingException {
var lines = result.split("\n");
for (int i = 0; i < lines.length; i++) {
if (lines[i].startsWith("event: " + eventName)) {
var data = lines[i + 1].substring("data: ".length()).trim();
ObjectMapper objectMapper = new ObjectMapper();
return objectMapper.readTree(data);
}
}
return null;
}

private String render(Result r) throws InterruptedException, ExecutionException {
var execution = new Execution(Execution.Context.createContextStub());
return render(execution, r);
}

private String render(Execution execution, Result r) throws ExecutionException, InterruptedException {
var renderer = new EventRenderer();
try {
renderer.init();
ByteArrayOutputStream bs = new ByteArrayOutputStream();
CompletableFuture<Boolean> f = renderer.renderResponse(bs, r, execution, null);
assertTrue(f.get());
return Utf8.toString(bs.toByteArray());
} finally {
renderer.deconstruct();
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ public void testResultRenderingIsSkipped() throws ExecutionException, Interrupte
event: end
""";
assertEquals(expected, result); // Todo: support other types of data such as search results (hits), timing and trace
assertEquals(expected, result);
}

static HitGroup newHitGroup(EventStream eventStream, String id) {
Expand Down

0 comments on commit 3e31b1f

Please sign in to comment.