@@ -133,6 +133,7 @@ pub enum MessageEvent {
133133 } ,
134134 Finish {
135135 reason : String ,
136+ token_state : TokenState ,
136137 } ,
137138 ModelChange {
138139 model : String ,
@@ -149,6 +150,27 @@ pub enum MessageEvent {
149150 Ping ,
150151}
151152
153+ async fn get_token_state ( session_id : & str ) -> TokenState {
154+ SessionManager :: get_session ( session_id, false )
155+ . await
156+ . map ( |session| TokenState {
157+ input_tokens : session. input_tokens . unwrap_or ( 0 ) ,
158+ output_tokens : session. output_tokens . unwrap_or ( 0 ) ,
159+ total_tokens : session. total_tokens . unwrap_or ( 0 ) ,
160+ accumulated_input_tokens : session. accumulated_input_tokens . unwrap_or ( 0 ) ,
161+ accumulated_output_tokens : session. accumulated_output_tokens . unwrap_or ( 0 ) ,
162+ accumulated_total_tokens : session. accumulated_total_tokens . unwrap_or ( 0 ) ,
163+ } )
164+ . inspect_err ( |e| {
165+ tracing:: warn!(
166+ "Failed to fetch session token state for {}: {}" ,
167+ session_id,
168+ e
169+ ) ;
170+ } )
171+ . unwrap_or_default ( )
172+ }
173+
152174async fn stream_event (
153175 event : MessageEvent ,
154176 tx : & mpsc:: Sender < String > ,
@@ -321,29 +343,7 @@ pub async fn reply(
321343
322344 all_messages. push( message. clone( ) ) ;
323345
324- let token_state = match SessionManager :: get_session( & session_id, false ) . await {
325- Ok ( session) => {
326- TokenState {
327- input_tokens: session. input_tokens. unwrap_or( 0 ) ,
328- output_tokens: session. output_tokens. unwrap_or( 0 ) ,
329- total_tokens: session. total_tokens. unwrap_or( 0 ) ,
330- accumulated_input_tokens: session. accumulated_input_tokens. unwrap_or( 0 ) ,
331- accumulated_output_tokens: session. accumulated_output_tokens. unwrap_or( 0 ) ,
332- accumulated_total_tokens: session. accumulated_total_tokens. unwrap_or( 0 ) ,
333- }
334- } ,
335- Err ( e) => {
336- tracing:: warn!( "Failed to fetch session for token state: {}" , e) ;
337- TokenState {
338- input_tokens: 0 ,
339- output_tokens: 0 ,
340- total_tokens: 0 ,
341- accumulated_input_tokens: 0 ,
342- accumulated_output_tokens: 0 ,
343- accumulated_total_tokens: 0 ,
344- }
345- }
346- } ;
346+ let token_state = get_token_state( & session_id) . await ;
347347
348348 stream_event( MessageEvent :: Message { message, token_state } , & tx, & cancel_token) . await ;
349349 }
@@ -437,9 +437,12 @@ pub async fn reply(
437437 ) ;
438438 }
439439
440+ let final_token_state = get_token_state ( & session_id) . await ;
441+
440442 let _ = stream_event (
441443 MessageEvent :: Finish {
442444 reason : "stop" . to_string ( ) ,
445+ token_state : final_token_state,
443446 } ,
444447 & task_tx,
445448 & cancel_token,
0 commit comments