Skip to content

Commit b44142a

Browse files
authoredApr 18, 2024··
Merge pull request #101 from AssemblyAI/niels/close-properly
Terminate streaming session properly
2 parents 478309b + eecc79a commit b44142a

File tree

3 files changed

+115
-15
lines changed

3 files changed

+115
-15
lines changed
 

‎sample-app/src/main/java/sample/App.java

+15-6
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,20 @@
66
import com.assemblyai.api.resources.lemur.requests.LemurTaskParams;
77
import com.assemblyai.api.resources.lemur.types.LemurTaskResponse;
88
import com.assemblyai.api.resources.realtime.types.AudioEncoding;
9+
import com.assemblyai.api.resources.realtime.types.SessionInformation;
910
import com.assemblyai.api.resources.transcripts.requests.*;
1011
import com.assemblyai.api.resources.transcripts.types.*;
1112
import java.io.File;
1213
import java.io.FileInputStream;
1314
import java.io.IOException;
1415
import java.nio.file.Files;
1516
import java.util.List;
17+
import java.util.concurrent.ExecutionException;
18+
import java.util.concurrent.Future;
1619

1720
public final class App {
1821

19-
public static void main(String... args) throws IOException, InterruptedException {
22+
public static void main(String... args) throws IOException, InterruptedException, ExecutionException {
2023
AssemblyAI client = AssemblyAI.builder()
2124
.apiKey(System.getenv("ASSEMBLYAI_API_KEY"))
2225
.build();
@@ -86,18 +89,24 @@ public static void main(String... args) throws IOException, InterruptedException
8689
TranscriptList transcripts = client.transcripts().list();
8790
System.out.println("List transcript. " + transcripts);
8891

89-
RealtimeTranscriber realtimeTranscriber = RealtimeTranscriber.builder()
92+
try (RealtimeTranscriber realtimeTranscriber = RealtimeTranscriber.builder()
9093
.apiKey(System.getenv("ASSEMBLYAI_API_KEY"))
9194
.encoding(AudioEncoding.PCM_S16LE)
9295
.onSessionBegins(System.out::println)
9396
.onPartialTranscript(System.out::println)
9497
.onFinalTranscript(System.out::println)
9598
.onError((err) -> System.out.println(err.getMessage()))
9699
.onClose((code, reason) -> System.out.printf("%s: %s", code, reason))
97-
.build();
98-
realtimeTranscriber.connect();
99-
streamFile("sample-app/src/main/resources/gore-short.wav", realtimeTranscriber);
100-
realtimeTranscriber.close();
100+
.onSessionInformation(System.out::println)
101+
.build()) {
102+
realtimeTranscriber.connect();
103+
streamFile("sample-app/src/main/resources/gore-short.wav", realtimeTranscriber);
104+
Future<SessionInformation> closeFuture = realtimeTranscriber.closeWithSessionTermination();
105+
SessionInformation info = closeFuture.get();
106+
// Force exit is necessary for some reason.
107+
// The program will end after a while, but not immediately as it should.
108+
System.exit(0);
109+
}
101110
}
102111

103112
public static void streamFile(String filePath, RealtimeTranscriber realtimeTranscriber) {

‎src/main/java/com/assemblyai/api/RealtimeTranscriber.java

+91-9
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import java.util.List;
1111
import java.util.Optional;
1212
import java.util.*;
13+
import java.util.concurrent.CompletableFuture;
14+
import java.util.concurrent.Future;
1315
import java.util.function.BiConsumer;
1416
import java.util.function.Consumer;
1517

@@ -40,7 +42,11 @@ public final class RealtimeTranscriber implements AutoCloseable {
4042
private final Consumer<Throwable> onError;
4143
private final BiConsumer<Integer, String> onClose;
4244
private final RealtimeMessageVisitor realtimeMessageVisitor;
45+
private final Consumer<SessionInformation> onSessionInformation;
4346
private WebSocket webSocket;
47+
private SessionInformation sessionInformation;
48+
private CompletableFuture<SessionInformation> sessionTerminatedFuture;
49+
private boolean isConnected;
4450

4551
private RealtimeTranscriber(
4652
String apiKey,
@@ -55,6 +61,7 @@ private RealtimeTranscriber(
5561
Consumer<FinalTranscript> onFinalTranscript,
5662
Consumer<RealtimeTranscript> onTranscript,
5763
Consumer<Throwable> onError,
64+
Consumer<SessionInformation> onSessionInformation,
5865
BiConsumer<Integer, String> onClose) {
5966
this.apiKey = apiKey;
6067
this.token = token;
@@ -68,6 +75,7 @@ private RealtimeTranscriber(
6875
this.onFinalTranscript = onFinalTranscript;
6976
this.onTranscript = onTranscript;
7077
this.onError = onError;
78+
this.onSessionInformation = onSessionInformation;
7179
this.onClose = onClose;
7280
this.realtimeMessageVisitor = new RealtimeMessageVisitor();
7381
}
@@ -83,6 +91,10 @@ public void connect() {
8391
if (disablePartialTranscripts) {
8492
url += "&disable_partial_transcripts=true";
8593
}
94+
95+
// always set so it can be return from closeWithSessionTermination
96+
url += "&enable_extra_session_information=true";
97+
8698
if (wordBoost.isPresent() && !wordBoost.get().isEmpty()) {
8799
try {
88100
url += "&word_boost=" + ObjectMappers.JSON_MAPPER.writeValueAsString(wordBoost.get());
@@ -144,15 +156,33 @@ public void configureEndUtteranceSilenceThreshold(int threshold) {
144156
));
145157
}
146158

159+
public Future<SessionInformation> closeWithSessionTermination() {
160+
this.sessionTerminatedFuture = new CompletableFuture<SessionInformation>();
161+
this.webSocket.send("{\"terminate_session\":true}");
162+
sessionTerminatedFuture.whenComplete((sessionInformation1, throwable) -> this.closeSocket());
163+
return this.sessionTerminatedFuture;
164+
}
165+
147166
/**
148-
* Closes the websocket connection.
167+
* Closes the websocket connection immediately, without waiting for session termination.
168+
* Use closeWithSessionTermination() if possible.
169+
*
170+
* @see #closeWithSessionTermination
171+
* Terminate the session, wait for session termination, and then close the connection.
149172
*/
150173
@Override
151174
public void close() {
152-
boolean closed = this.webSocket.close(1000, "Shutting down");
153-
if (!closed) {
154-
this.webSocket.cancel();
175+
if (isConnected) {
176+
this.webSocket.send("{\"terminate_session\":true}");
155177
}
178+
this.closeSocket();
179+
}
180+
181+
private void closeSocket() {
182+
if(webSocket == null) return;
183+
this.webSocket.close(1000, "Shutting down");
184+
this.webSocket.cancel();
185+
this.webSocket = null;
156186
}
157187

158188
public static RealtimeTranscriber.Builder builder() {
@@ -174,6 +204,7 @@ public static final class Builder {
174204
private Consumer<RealtimeTranscript> onTranscript;
175205
private Consumer<Throwable> onError;
176206
private BiConsumer<Integer, String> onClose;
207+
private Consumer<SessionInformation> onSessionInformation;
177208

178209
/**
179210
* Sets the AssemblyAI API key used to authenticate the RealtimeTranscriber
@@ -323,6 +354,19 @@ public RealtimeTranscriber.Builder onError(Consumer<Throwable> onError) {
323354
return this;
324355
}
325356

357+
/**
358+
* Sets onSessionInformation
359+
*
360+
* @param onSessionInformation an event handler for the session information event.
361+
* This message is sent at the end of the session, before the SessionTerminated message.
362+
* Defaults to a noop.
363+
* @return this
364+
*/
365+
public RealtimeTranscriber.Builder onSessionInformation(Consumer<SessionInformation> onSessionInformation) {
366+
this.onSessionInformation = onSessionInformation;
367+
return this;
368+
}
369+
326370
/**
327371
* Sets onClose
328372
*
@@ -351,6 +395,7 @@ public RealtimeTranscriber build() {
351395
onFinalTranscript,
352396
onTranscript,
353397
onError,
398+
onSessionInformation,
354399
onClose);
355400
}
356401
}
@@ -364,6 +409,7 @@ public Listener(Consumer<Response> onOpen) {
364409

365410
@Override
366411
public void onOpen(@NotNull WebSocket webSocket, @NotNull Response response) {
412+
isConnected = true;
367413
if (onOpen != null) {
368414
onOpen.accept(response);
369415
}
@@ -372,12 +418,29 @@ public void onOpen(@NotNull WebSocket webSocket, @NotNull Response response) {
372418
@Override
373419
public void onMessage(@NotNull WebSocket webSocket, @NotNull String text) {
374420
try {
375-
RealtimeMessage realtimeMessage = ObjectMappers.JSON_MAPPER.readValue(text, RealtimeMessage.class);
376-
try {
377-
realtimeMessage.visit(realtimeMessageVisitor);
378-
} catch (IllegalStateException ignored) {
379-
// when a new message is added to the API, this should not throw an exception
421+
RealtimeBaseMessage baseMessage = ObjectMappers.parseOrThrow(text, RealtimeBaseMessage.class);
422+
MessageType messageType = baseMessage.getMessageType();
423+
if (messageType == MessageType.SESSION_BEGINS) {
424+
realtimeMessageVisitor.visit(
425+
ObjectMappers.JSON_MAPPER.readValue(text, SessionBegins.class)
426+
);
427+
} else if (messageType == MessageType.PARTIAL_TRANSCRIPT) {
428+
realtimeMessageVisitor.visit(
429+
ObjectMappers.JSON_MAPPER.readValue(text, PartialTranscript.class)
430+
);
431+
} else if (messageType == MessageType.FINAL_TRANSCRIPT) {
432+
realtimeMessageVisitor.visit(
433+
ObjectMappers.JSON_MAPPER.readValue(text, FinalTranscript.class)
434+
);
435+
} else if (messageType == MessageType.SESSION_INFORMATION) {
436+
realtimeMessageVisitor.visit(
437+
ObjectMappers.JSON_MAPPER.readValue(text, SessionInformation.class)
438+
);
439+
} else if (messageType == MessageType.SESSION_TERMINATED) {
440+
realtimeMessageVisitor.visit((SessionTerminated) null);
380441
}
442+
// Intentionally don't throw an exception for unknown message type.
443+
// New message types shouldn't cause this to break.
381444
} catch (JsonProcessingException e) {
382445
if (onError == null) return;
383446
onError.accept(e);
@@ -386,6 +449,7 @@ public void onMessage(@NotNull WebSocket webSocket, @NotNull String text) {
386449

387450
@Override
388451
public void onFailure(@NotNull WebSocket webSocket, @NotNull Throwable t, @Nullable Response response) {
452+
isConnected = false;
389453
if (onError == null) return;
390454
onError.accept(t);
391455
}
@@ -399,6 +463,12 @@ public void onClosing(@NotNull WebSocket webSocket, int code, String reason) {
399463
onClose.accept(code, reason);
400464
super.onClosing(webSocket, code, reason);
401465
}
466+
467+
@Override
468+
public void onClosed(@NotNull WebSocket webSocket, int code, @NotNull String reason) {
469+
isConnected = false;
470+
super.onClosed(webSocket, code, reason);
471+
}
402472
}
403473

404474
private final class RealtimeMessageVisitor implements RealtimeMessage.Visitor<Void> {
@@ -423,8 +493,20 @@ public Void visit(FinalTranscript value) {
423493
return null;
424494
}
425495

496+
@Override
497+
public Void visit(SessionInformation value) {
498+
sessionInformation = value;
499+
if (onSessionInformation == null) return null;
500+
onSessionInformation.accept(value);
501+
return null;
502+
}
503+
504+
426505
@Override
427506
public Void visit(SessionTerminated value) {
507+
if (sessionTerminatedFuture != null) {
508+
sessionTerminatedFuture.complete(sessionInformation);
509+
}
428510
return null;
429511
}
430512

‎src/main/java/com/assemblyai/api/resources/realtime/types/RealtimeMessage.java

+9
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import com.fasterxml.jackson.databind.DeserializationContext;
1111
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
1212
import com.fasterxml.jackson.databind.deser.std.StdDeserializer;
13+
1314
import java.io.IOException;
1415
import java.util.Objects;
1516
import java.util.Optional;
@@ -41,6 +42,8 @@ public <T> T visit(Visitor<T> visitor) {
4142
return visitor.visit((SessionTerminated) this.value);
4243
} else if (this.type == 4) {
4344
return visitor.visit((RealtimeError) this.value);
45+
} else if (this.type == 5) {
46+
return visitor.visit((SessionInformation) this.value);
4447
}
4548
throw new IllegalStateException("Failed to visit value. This should never happen.");
4649
}
@@ -85,6 +88,10 @@ public static RealtimeMessage of(RealtimeError value) {
8588
return new RealtimeMessage(value, 4);
8689
}
8790

91+
public static RealtimeMessage of(SessionInformation value) {
92+
return new RealtimeMessage(value, 5);
93+
}
94+
8895
public interface Visitor<T> {
8996
T visit(SessionBegins value);
9097

@@ -95,6 +102,8 @@ public interface Visitor<T> {
95102
T visit(SessionTerminated value);
96103

97104
T visit(RealtimeError value);
105+
106+
T visit(SessionInformation value);
98107
}
99108

100109
static final class Deserializer extends StdDeserializer<RealtimeMessage> {

0 commit comments

Comments
 (0)