Skip to content

Commit 3e31b1f

Browse files
committed
Add rendering of hits (and trace and timing etc) in llm rendering
1 parent 367b751 commit 3e31b1f

File tree

7 files changed

+125
-14
lines changed

7 files changed

+125
-14
lines changed

container-search/abi-spec.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9227,7 +9227,7 @@
92279227
"methods" : [
92289228
"public void <init>(ai.vespa.search.llm.LlmSearcherConfig, com.yahoo.component.provider.ComponentRegistry)",
92299229
"public com.yahoo.search.Result search(com.yahoo.search.Query, com.yahoo.search.searchchain.Execution)",
9230-
"protected com.yahoo.search.Result complete(com.yahoo.search.Query, ai.vespa.llm.completion.Prompt)",
9230+
"protected com.yahoo.search.Result complete(com.yahoo.search.Query, ai.vespa.llm.completion.Prompt, com.yahoo.search.Result, com.yahoo.search.searchchain.Execution)",
92319231
"public java.lang.String getPrompt(com.yahoo.search.Query)",
92329232
"public java.lang.String getPropertyPrefix()",
92339233
"public java.lang.String lookupProperty(java.lang.String, com.yahoo.search.Query)",

container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,14 @@
1414
import com.yahoo.search.Query;
1515
import com.yahoo.search.Result;
1616
import com.yahoo.search.Searcher;
17+
import com.yahoo.search.rendering.JsonRenderer;
1718
import com.yahoo.search.result.ErrorMessage;
1819
import com.yahoo.search.result.EventStream;
1920
import com.yahoo.search.result.HitGroup;
2021
import com.yahoo.search.searchchain.Execution;
22+
import com.yahoo.text.Utf8;
2123

24+
import java.io.ByteArrayOutputStream;
2225
import java.util.List;
2326
import java.util.concurrent.RejectedExecutionException;
2427
import java.util.function.Function;
@@ -38,6 +41,10 @@ public class LLMSearcher extends Searcher {
3841
private static final String API_KEY_HEADER = "X-LLM-API-KEY";
3942
private static final String STREAM_PROPERTY = "stream";
4043
private static final String PROMPT_PROPERTY = "prompt";
44+
private static final String INCLUDE_PROMPT_IN_RESULT = "includePrompt";
45+
private static final String INCLUDE_HITS_IN_RESULT = "includeHits";
46+
47+
private final JsonRenderer jsonRenderer;
4148

4249
private final String propertyPrefix;
4350
private final boolean stream;
@@ -50,11 +57,13 @@ public LLMSearcher(LlmSearcherConfig config, ComponentRegistry<LanguageModel> la
5057
this.languageModelId = config.providerId();
5158
this.languageModel = findLanguageModel(languageModelId, languageModels);
5259
this.propertyPrefix = config.propertyPrefix();
60+
61+
this.jsonRenderer = new JsonRenderer();
5362
}
5463

5564
@Override
5665
public Result search(Query query, Execution execution) {
57-
return complete(query, StringPrompt.from(getPrompt(query)));
66+
return complete(query, StringPrompt.from(getPrompt(query)), null, execution);
5867
}
5968

6069
private LanguageModel findLanguageModel(String providerId, ComponentRegistry<LanguageModel> languageModels)
@@ -81,30 +90,37 @@ private LanguageModel findLanguageModel(String providerId, ComponentRegistry<Lan
8190
return languageModel;
8291
}
8392

84-
protected Result complete(Query query, Prompt prompt) {
93+
protected Result complete(Query query, Prompt prompt, Result result, Execution execution) {
8594
var options = new InferenceParameters(getApiKeyHeader(query), s -> lookupProperty(s, query));
8695
var stream = lookupPropertyBool(STREAM_PROPERTY, query, this.stream); // query value overwrites config
8796
try {
88-
return stream ? completeAsync(query, prompt, options) : completeSync(query, prompt, options);
97+
if (stream) {
98+
return completeAsync(query, prompt, options, result, execution);
99+
}
100+
return completeSync(query, prompt, options, result, execution);
89101
} catch (RejectedExecutionException e) {
90102
return new Result(query, new ErrorMessage(429, e.getMessage()));
91103
}
92104
}
93105

94106
private boolean shouldAddPrompt(Query query) {
95-
return query.getTrace().getLevel() >= 1;
107+
var includePrompt = lookupPropertyBool(INCLUDE_PROMPT_IN_RESULT, query, false);
108+
return query.getTrace().getLevel() >= 1 || includePrompt;
96109
}
97110

98111
private boolean shouldAddTokenStats(Query query) {
99112
return query.getTrace().getLevel() >= 1;
100113
}
101114

102-
private Result completeAsync(Query query, Prompt prompt, InferenceParameters options) {
115+
private Result completeAsync(Query query, Prompt prompt, InferenceParameters options, Result result, Execution execution) {
103116
final EventStream eventStream = new EventStream();
104117

105118
if (shouldAddPrompt(query)) {
106119
eventStream.add(prompt.asString(), "prompt");
107120
}
121+
if (shouldAddHits(query) && result != null) {
122+
eventStream.add(renderHits(result, execution), "hits");
123+
}
108124

109125
final TokenStats tokenStats = new TokenStats();
110126
languageModel.completeAsync(prompt, options, completion -> {
@@ -143,12 +159,15 @@ private void handleException(EventStream eventStream, Throwable exception) {
143159
eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage()));
144160
}
145161

146-
private Result completeSync(Query query, Prompt prompt, InferenceParameters options) {
162+
private Result completeSync(Query query, Prompt prompt, InferenceParameters options, Result result, Execution execution) {
147163
EventStream eventStream = new EventStream();
148164

149165
if (shouldAddPrompt(query)) {
150166
eventStream.add(prompt.asString(), "prompt");
151167
}
168+
if (shouldAddHits(query) && result != null) {
169+
eventStream.add(renderHits(result, execution), "hits");
170+
}
152171

153172
List<Completion> completions = languageModel.complete(prompt, options);
154173
eventStream.add(completions.get(0).text(), "completion");
@@ -200,6 +219,18 @@ public String getApiKeyHeader(Query query) {
200219
return lookupPropertyWithOrWithoutPrefix(API_KEY_HEADER, p -> query.getHttpRequest().getHeader(p));
201220
}
202221

222+
private boolean shouldAddHits(Query query) {
223+
return lookupPropertyBool(INCLUDE_HITS_IN_RESULT, query, false);
224+
}
225+
226+
private String renderHits(Result results, Execution execution) {
227+
var bs = new ByteArrayOutputStream();
228+
var renderer = jsonRenderer.clone();
229+
renderer.init();
230+
renderer.renderResponse(bs, results, execution, null).join(); // wait for renderer to complete
231+
return Utf8.toString(bs.toByteArray());
232+
}
233+
203234
private static class TokenStats {
204235

205236
private final long start;

container-search/src/main/java/ai/vespa/search/llm/RAGSearcher.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ public RAGSearcher(LlmSearcherConfig config, ComponentRegistry<LanguageModel> la
3737
public Result search(Query query, Execution execution) {
3838
Result result = execution.search(query);
3939
execution.fill(result);
40-
return complete(query, buildPrompt(query, result));
40+
return complete(query, buildPrompt(query, result), result, execution);
4141
}
4242

4343
protected Prompt buildPrompt(Query query, Result result) {

container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,16 @@ public void data(Data data) throws IOException {
7979
generator.writeRaw("event: " + event.type() + "\n");
8080
}
8181
generator.writeRaw("data: ");
82-
generator.writeStartObject();
83-
generator.writeStringField(event.type(), event.toString());
84-
generator.writeEndObject();
82+
if (event.type().equals("hits")) {
83+
generator.writeRaw(event.toString());
84+
} else {
85+
generator.writeStartObject();
86+
generator.writeStringField(event.type(), event.toString());
87+
generator.writeEndObject();
88+
}
8589
generator.writeRaw("\n\n");
8690
generator.flush();
8791
}
88-
// Todo: support other types of data such as search results (hits), timing and trace
8992
}
9093

9194
@Override

container-search/src/test/java/ai/vespa/search/llm/RAGSearcherTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ static Result runMockSearch(Searcher searcher, Map<String, String> parameters) {
115115
return execution.search(query);
116116
}
117117

118-
private static Searcher createRAGSearcher(Map<String, LanguageModel> llms) {
118+
static Searcher createRAGSearcher(Map<String, LanguageModel> llms) {
119119
var config = new LlmSearcherConfig.Builder().stream(false).build();
120120
ComponentRegistry<LanguageModel> models = new ComponentRegistry<>();
121121
llms.forEach((key, value) -> models.register(ComponentId.fromString(key), value));
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
package ai.vespa.search.llm;
2+
3+
import com.fasterxml.jackson.core.JsonProcessingException;
4+
import com.fasterxml.jackson.databind.JsonNode;
5+
import com.fasterxml.jackson.databind.ObjectMapper;
6+
import com.yahoo.search.Result;
7+
import com.yahoo.search.rendering.EventRenderer;
8+
import com.yahoo.search.searchchain.Execution;
9+
import com.yahoo.text.Utf8;
10+
import org.junit.jupiter.api.Test;
11+
12+
import java.io.ByteArrayOutputStream;
13+
import java.util.Map;
14+
import java.util.concurrent.CompletableFuture;
15+
import java.util.concurrent.ExecutionException;
16+
17+
import static org.junit.jupiter.api.Assertions.assertEquals;
18+
import static org.junit.jupiter.api.Assertions.assertNotNull;
19+
import static org.junit.jupiter.api.Assertions.assertTrue;
20+
21+
public class RAGWithEventRendererTest {
22+
23+
@Test
24+
public void testPromptAndHitsAreRendered() throws Exception {
25+
var params = Map.of(
26+
"query", "why are ducks better than cats?",
27+
"llm.stream", "false",
28+
"llm.includePrompt", "true",
29+
"llm.includeHits", "true"
30+
);
31+
var llm = LLMSearcherTest.createLLMClient();
32+
var searcher = RAGSearcherTest.createRAGSearcher(Map.of("mock", llm));
33+
var results = RAGSearcherTest.runMockSearch(searcher, params);
34+
35+
var result = render(results);
36+
37+
var promptEvent = extractEvent(result, "prompt");
38+
assertNotNull(promptEvent);
39+
assertTrue(promptEvent.has("prompt"));
40+
41+
var resultsEvent = extractEvent(result, "hits");
42+
assertNotNull(resultsEvent);
43+
assertTrue(resultsEvent.has("root"));
44+
assertEquals(2, resultsEvent.get("root").get("children").size());
45+
}
46+
47+
private JsonNode extractEvent(String result, String eventName) throws JsonProcessingException {
48+
var lines = result.split("\n");
49+
for (int i = 0; i < lines.length; i++) {
50+
if (lines[i].startsWith("event: " + eventName)) {
51+
var data = lines[i + 1].substring("data: ".length()).trim();
52+
ObjectMapper objectMapper = new ObjectMapper();
53+
return objectMapper.readTree(data);
54+
}
55+
}
56+
return null;
57+
}
58+
59+
private String render(Result r) throws InterruptedException, ExecutionException {
60+
var execution = new Execution(Execution.Context.createContextStub());
61+
return render(execution, r);
62+
}
63+
64+
private String render(Execution execution, Result r) throws ExecutionException, InterruptedException {
65+
var renderer = new EventRenderer();
66+
try {
67+
renderer.init();
68+
ByteArrayOutputStream bs = new ByteArrayOutputStream();
69+
CompletableFuture<Boolean> f = renderer.renderResponse(bs, r, execution, null);
70+
assertTrue(f.get());
71+
return Utf8.toString(bs.toByteArray());
72+
} finally {
73+
renderer.deconstruct();
74+
}
75+
}
76+
77+
}

container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ public void testResultRenderingIsSkipped() throws ExecutionException, Interrupte
232232
233233
event: end
234234
""";
235-
assertEquals(expected, result); // Todo: support other types of data such as search results (hits), timing and trace
235+
assertEquals(expected, result);
236236
}
237237

238238
static HitGroup newHitGroup(EventStream eventStream, String id) {

0 commit comments

Comments
 (0)