14
14
import com .yahoo .search .Query ;
15
15
import com .yahoo .search .Result ;
16
16
import com .yahoo .search .Searcher ;
17
+ import com .yahoo .search .rendering .JsonRenderer ;
17
18
import com .yahoo .search .result .ErrorMessage ;
18
19
import com .yahoo .search .result .EventStream ;
19
20
import com .yahoo .search .result .HitGroup ;
20
21
import com .yahoo .search .searchchain .Execution ;
22
+ import com .yahoo .text .Utf8 ;
21
23
24
+ import java .io .ByteArrayOutputStream ;
22
25
import java .util .List ;
23
26
import java .util .concurrent .RejectedExecutionException ;
24
27
import java .util .function .Function ;
@@ -38,6 +41,10 @@ public class LLMSearcher extends Searcher {
38
41
private static final String API_KEY_HEADER = "X-LLM-API-KEY" ;
39
42
private static final String STREAM_PROPERTY = "stream" ;
40
43
private static final String PROMPT_PROPERTY = "prompt" ;
44
+ private static final String INCLUDE_PROMPT_IN_RESULT = "includePrompt" ;
45
+ private static final String INCLUDE_HITS_IN_RESULT = "includeHits" ;
46
+
47
+ private final JsonRenderer jsonRenderer ;
41
48
42
49
private final String propertyPrefix ;
43
50
private final boolean stream ;
@@ -50,11 +57,13 @@ public LLMSearcher(LlmSearcherConfig config, ComponentRegistry<LanguageModel> la
50
57
this .languageModelId = config .providerId ();
51
58
this .languageModel = findLanguageModel (languageModelId , languageModels );
52
59
this .propertyPrefix = config .propertyPrefix ();
60
+
61
+ this .jsonRenderer = new JsonRenderer ();
53
62
}
54
63
55
64
@ Override
56
65
public Result search (Query query , Execution execution ) {
57
- return complete (query , StringPrompt .from (getPrompt (query )));
66
+ return complete (query , StringPrompt .from (getPrompt (query )), null , execution );
58
67
}
59
68
60
69
private LanguageModel findLanguageModel (String providerId , ComponentRegistry <LanguageModel > languageModels )
@@ -81,30 +90,37 @@ private LanguageModel findLanguageModel(String providerId, ComponentRegistry<Lan
81
90
return languageModel ;
82
91
}
83
92
84
- protected Result complete (Query query , Prompt prompt ) {
93
+ protected Result complete (Query query , Prompt prompt , Result result , Execution execution ) {
85
94
var options = new InferenceParameters (getApiKeyHeader (query ), s -> lookupProperty (s , query ));
86
95
var stream = lookupPropertyBool (STREAM_PROPERTY , query , this .stream ); // query value overwrites config
87
96
try {
88
- return stream ? completeAsync (query , prompt , options ) : completeSync (query , prompt , options );
97
+ if (stream ) {
98
+ return completeAsync (query , prompt , options , result , execution );
99
+ }
100
+ return completeSync (query , prompt , options , result , execution );
89
101
} catch (RejectedExecutionException e ) {
90
102
return new Result (query , new ErrorMessage (429 , e .getMessage ()));
91
103
}
92
104
}
93
105
94
106
private boolean shouldAddPrompt (Query query ) {
95
- return query .getTrace ().getLevel () >= 1 ;
107
+ var includePrompt = lookupPropertyBool (INCLUDE_PROMPT_IN_RESULT , query , false );
108
+ return query .getTrace ().getLevel () >= 1 || includePrompt ;
96
109
}
97
110
98
111
private boolean shouldAddTokenStats (Query query ) {
99
112
return query .getTrace ().getLevel () >= 1 ;
100
113
}
101
114
102
- private Result completeAsync (Query query , Prompt prompt , InferenceParameters options ) {
115
+ private Result completeAsync (Query query , Prompt prompt , InferenceParameters options , Result result , Execution execution ) {
103
116
final EventStream eventStream = new EventStream ();
104
117
105
118
if (shouldAddPrompt (query )) {
106
119
eventStream .add (prompt .asString (), "prompt" );
107
120
}
121
+ if (shouldAddHits (query ) && result != null ) {
122
+ eventStream .add (renderHits (result , execution ), "hits" );
123
+ }
108
124
109
125
final TokenStats tokenStats = new TokenStats ();
110
126
languageModel .completeAsync (prompt , options , completion -> {
@@ -143,12 +159,15 @@ private void handleException(EventStream eventStream, Throwable exception) {
143
159
eventStream .error (languageModelId , new ErrorMessage (errorCode , exception .getMessage ()));
144
160
}
145
161
146
- private Result completeSync (Query query , Prompt prompt , InferenceParameters options ) {
162
+ private Result completeSync (Query query , Prompt prompt , InferenceParameters options , Result result , Execution execution ) {
147
163
EventStream eventStream = new EventStream ();
148
164
149
165
if (shouldAddPrompt (query )) {
150
166
eventStream .add (prompt .asString (), "prompt" );
151
167
}
168
+ if (shouldAddHits (query ) && result != null ) {
169
+ eventStream .add (renderHits (result , execution ), "hits" );
170
+ }
152
171
153
172
List <Completion > completions = languageModel .complete (prompt , options );
154
173
eventStream .add (completions .get (0 ).text (), "completion" );
@@ -200,6 +219,18 @@ public String getApiKeyHeader(Query query) {
200
219
return lookupPropertyWithOrWithoutPrefix (API_KEY_HEADER , p -> query .getHttpRequest ().getHeader (p ));
201
220
}
202
221
222
+ private boolean shouldAddHits (Query query ) {
223
+ return lookupPropertyBool (INCLUDE_HITS_IN_RESULT , query , false );
224
+ }
225
+
226
+ private String renderHits (Result results , Execution execution ) {
227
+ var bs = new ByteArrayOutputStream ();
228
+ var renderer = jsonRenderer .clone ();
229
+ renderer .init ();
230
+ renderer .renderResponse (bs , results , execution , null ).join (); // wait for renderer to complete
231
+ return Utf8 .toString (bs .toByteArray ());
232
+ }
233
+
203
234
private static class TokenStats {
204
235
205
236
private final long start ;
0 commit comments