10
10
import java .util .List ;
11
11
import java .util .Optional ;
12
12
import java .util .*;
13
+ import java .util .concurrent .CompletableFuture ;
14
+ import java .util .concurrent .Future ;
13
15
import java .util .function .BiConsumer ;
14
16
import java .util .function .Consumer ;
15
17
@@ -40,7 +42,11 @@ public final class RealtimeTranscriber implements AutoCloseable {
40
42
private final Consumer <Throwable > onError ;
41
43
private final BiConsumer <Integer , String > onClose ;
42
44
private final RealtimeMessageVisitor realtimeMessageVisitor ;
45
+ private final Consumer <SessionInformation > onSessionInformation ;
43
46
private WebSocket webSocket ;
47
+ private SessionInformation sessionInformation ;
48
+ private CompletableFuture <SessionInformation > sessionTerminatedFuture ;
49
+ private boolean isConnected ;
44
50
45
51
private RealtimeTranscriber (
46
52
String apiKey ,
@@ -55,6 +61,7 @@ private RealtimeTranscriber(
55
61
Consumer <FinalTranscript > onFinalTranscript ,
56
62
Consumer <RealtimeTranscript > onTranscript ,
57
63
Consumer <Throwable > onError ,
64
+ Consumer <SessionInformation > onSessionInformation ,
58
65
BiConsumer <Integer , String > onClose ) {
59
66
this .apiKey = apiKey ;
60
67
this .token = token ;
@@ -68,6 +75,7 @@ private RealtimeTranscriber(
68
75
this .onFinalTranscript = onFinalTranscript ;
69
76
this .onTranscript = onTranscript ;
70
77
this .onError = onError ;
78
+ this .onSessionInformation = onSessionInformation ;
71
79
this .onClose = onClose ;
72
80
this .realtimeMessageVisitor = new RealtimeMessageVisitor ();
73
81
}
@@ -83,6 +91,10 @@ public void connect() {
83
91
if (disablePartialTranscripts ) {
84
92
url += "&disable_partial_transcripts=true" ;
85
93
}
94
+
95
+ // always set so it can be return from closeWithSessionTermination
96
+ url += "&enable_extra_session_information=true" ;
97
+
86
98
if (wordBoost .isPresent () && !wordBoost .get ().isEmpty ()) {
87
99
try {
88
100
url += "&word_boost=" + ObjectMappers .JSON_MAPPER .writeValueAsString (wordBoost .get ());
@@ -144,15 +156,33 @@ public void configureEndUtteranceSilenceThreshold(int threshold) {
144
156
));
145
157
}
146
158
159
+ public Future <SessionInformation > closeWithSessionTermination () {
160
+ this .sessionTerminatedFuture = new CompletableFuture <SessionInformation >();
161
+ this .webSocket .send ("{\" terminate_session\" :true}" );
162
+ sessionTerminatedFuture .whenComplete ((sessionInformation1 , throwable ) -> this .closeSocket ());
163
+ return this .sessionTerminatedFuture ;
164
+ }
165
+
147
166
/**
148
- * Closes the websocket connection.
167
+ * Closes the websocket connection immediately, without waiting for session termination.
168
+ * Use closeWithSessionTermination() if possible.
169
+ *
170
+ * @see #closeWithSessionTermination
171
+ * Terminate the session, wait for session termination, and then close the connection.
149
172
*/
150
173
@ Override
151
174
public void close () {
152
- boolean closed = this .webSocket .close (1000 , "Shutting down" );
153
- if (!closed ) {
154
- this .webSocket .cancel ();
175
+ if (isConnected ) {
176
+ this .webSocket .send ("{\" terminate_session\" :true}" );
155
177
}
178
+ this .closeSocket ();
179
+ }
180
+
181
+ private void closeSocket () {
182
+ if (webSocket == null ) return ;
183
+ this .webSocket .close (1000 , "Shutting down" );
184
+ this .webSocket .cancel ();
185
+ this .webSocket = null ;
156
186
}
157
187
158
188
public static RealtimeTranscriber .Builder builder () {
@@ -174,6 +204,7 @@ public static final class Builder {
174
204
private Consumer <RealtimeTranscript > onTranscript ;
175
205
private Consumer <Throwable > onError ;
176
206
private BiConsumer <Integer , String > onClose ;
207
+ private Consumer <SessionInformation > onSessionInformation ;
177
208
178
209
/**
179
210
* Sets the AssemblyAI API key used to authenticate the RealtimeTranscriber
@@ -323,6 +354,19 @@ public RealtimeTranscriber.Builder onError(Consumer<Throwable> onError) {
323
354
return this ;
324
355
}
325
356
357
+ /**
358
+ * Sets onSessionInformation
359
+ *
360
+ * @param onSessionInformation an event handler for the session information event.
361
+ * This message is sent at the end of the session, before the SessionTerminated message.
362
+ * Defaults to a noop.
363
+ * @return this
364
+ */
365
+ public RealtimeTranscriber .Builder onSessionInformation (Consumer <SessionInformation > onSessionInformation ) {
366
+ this .onSessionInformation = onSessionInformation ;
367
+ return this ;
368
+ }
369
+
326
370
/**
327
371
* Sets onClose
328
372
*
@@ -351,6 +395,7 @@ public RealtimeTranscriber build() {
351
395
onFinalTranscript ,
352
396
onTranscript ,
353
397
onError ,
398
+ onSessionInformation ,
354
399
onClose );
355
400
}
356
401
}
@@ -364,6 +409,7 @@ public Listener(Consumer<Response> onOpen) {
364
409
365
410
@ Override
366
411
public void onOpen (@ NotNull WebSocket webSocket , @ NotNull Response response ) {
412
+ isConnected = true ;
367
413
if (onOpen != null ) {
368
414
onOpen .accept (response );
369
415
}
@@ -372,12 +418,29 @@ public void onOpen(@NotNull WebSocket webSocket, @NotNull Response response) {
372
418
@ Override
373
419
public void onMessage (@ NotNull WebSocket webSocket , @ NotNull String text ) {
374
420
try {
375
- RealtimeMessage realtimeMessage = ObjectMappers .JSON_MAPPER .readValue (text , RealtimeMessage .class );
376
- try {
377
- realtimeMessage .visit (realtimeMessageVisitor );
378
- } catch (IllegalStateException ignored ) {
379
- // when a new message is added to the API, this should not throw an exception
421
+ RealtimeBaseMessage baseMessage = ObjectMappers .parseOrThrow (text , RealtimeBaseMessage .class );
422
+ MessageType messageType = baseMessage .getMessageType ();
423
+ if (messageType == MessageType .SESSION_BEGINS ) {
424
+ realtimeMessageVisitor .visit (
425
+ ObjectMappers .JSON_MAPPER .readValue (text , SessionBegins .class )
426
+ );
427
+ } else if (messageType == MessageType .PARTIAL_TRANSCRIPT ) {
428
+ realtimeMessageVisitor .visit (
429
+ ObjectMappers .JSON_MAPPER .readValue (text , PartialTranscript .class )
430
+ );
431
+ } else if (messageType == MessageType .FINAL_TRANSCRIPT ) {
432
+ realtimeMessageVisitor .visit (
433
+ ObjectMappers .JSON_MAPPER .readValue (text , FinalTranscript .class )
434
+ );
435
+ } else if (messageType == MessageType .SESSION_INFORMATION ) {
436
+ realtimeMessageVisitor .visit (
437
+ ObjectMappers .JSON_MAPPER .readValue (text , SessionInformation .class )
438
+ );
439
+ } else if (messageType == MessageType .SESSION_TERMINATED ) {
440
+ realtimeMessageVisitor .visit ((SessionTerminated ) null );
380
441
}
442
+ // Intentionally don't throw an exception for unknown message type.
443
+ // New message types shouldn't cause this to break.
381
444
} catch (JsonProcessingException e ) {
382
445
if (onError == null ) return ;
383
446
onError .accept (e );
@@ -386,6 +449,7 @@ public void onMessage(@NotNull WebSocket webSocket, @NotNull String text) {
386
449
387
450
@ Override
388
451
public void onFailure (@ NotNull WebSocket webSocket , @ NotNull Throwable t , @ Nullable Response response ) {
452
+ isConnected = false ;
389
453
if (onError == null ) return ;
390
454
onError .accept (t );
391
455
}
@@ -399,6 +463,12 @@ public void onClosing(@NotNull WebSocket webSocket, int code, String reason) {
399
463
onClose .accept (code , reason );
400
464
super .onClosing (webSocket , code , reason );
401
465
}
466
+
467
+ @ Override
468
+ public void onClosed (@ NotNull WebSocket webSocket , int code , @ NotNull String reason ) {
469
+ isConnected = false ;
470
+ super .onClosed (webSocket , code , reason );
471
+ }
402
472
}
403
473
404
474
private final class RealtimeMessageVisitor implements RealtimeMessage .Visitor <Void > {
@@ -423,8 +493,20 @@ public Void visit(FinalTranscript value) {
423
493
return null ;
424
494
}
425
495
496
+ @ Override
497
+ public Void visit (SessionInformation value ) {
498
+ sessionInformation = value ;
499
+ if (onSessionInformation == null ) return null ;
500
+ onSessionInformation .accept (value );
501
+ return null ;
502
+ }
503
+
504
+
426
505
@ Override
427
506
public Void visit (SessionTerminated value ) {
507
+ if (sessionTerminatedFuture != null ) {
508
+ sessionTerminatedFuture .complete (sessionInformation );
509
+ }
428
510
return null ;
429
511
}
430
512
0 commit comments