Skip to content

Commit 3126884

Browse files
daiyiplangfun authors
authored andcommitted
Keep response metadata in QueryInvocation.
PiperOrigin-RevId: 735521190
1 parent 900dc85 commit 3126884

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

langfun/core/structured/querying.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,13 @@ def _result(message: lf.Message):
391391
if pg.MISSING_VALUE != prompt and not skip_lm:
392392
trackers = lf.context_value('__query_trackers__', [])
393393
if trackers:
394+
# To minimize payload for serialization, we remove the result and usage
395+
# fields from the metadata. They will be computed on the fly when the
396+
# invocation is rendered.
397+
metadata = dict(output_message.metadata)
398+
metadata.pop('result', None)
399+
metadata.pop('usage', None)
400+
394401
invocation = QueryInvocation(
395402
input=pg.Ref(query_input),
396403
schema=(
@@ -399,7 +406,7 @@ def _result(message: lf.Message):
399406
),
400407
lm=pg.Ref(lm),
401408
examples=pg.Ref(examples) if examples else [],
402-
lm_response=lf.AIMessage(output_message.text),
409+
lm_response=lf.AIMessage(output_message.text, metadata=metadata),
403410
usage_summary=usage_summary,
404411
start_time=start_time,
405412
end_time=end_time,

langfun/core/structured/querying_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1069,6 +1069,8 @@ def test_to_html(self):
10691069
querying.query('foo', Activity, lm=lm)
10701070

10711071
self.assertIn('schema', queries[0].to_html_str())
1072+
self.assertEqual(queries[0].lm_response.score, 1.0)
1073+
self.assertFalse(queries[0].lm_response.is_cached)
10721074

10731075

10741076
class TrackQueriesTest(unittest.TestCase):

0 commit comments

Comments
 (0)