Skip to content

Commit 3fb3c2d

Browse files
authored
Simplify quickstart.py (#171)
1 parent ed6328d commit 3fb3c2d

File tree

1 file changed

+24
-151
lines changed

1 file changed

+24
-151
lines changed

evi/evi-python-quickstart/quickstart.py

Lines changed: 24 additions & 151 deletions
Original file line numberDiff line numberDiff line change
@@ -4,215 +4,88 @@
44
import os
55
from dotenv import load_dotenv
66
from hume.client import AsyncHumeClient
7-
from hume.empathic_voice.chat.socket_client import ChatConnectOptions, ChatWebsocketConnection
7+
from hume.empathic_voice.chat.socket_client import ChatConnectOptions
88
from hume.empathic_voice.chat.types import SubscribeEvent
9-
from hume.empathic_voice.types import UserInput
10-
from hume.core.api_error import ApiError
119
from hume import MicrophoneInterface, Stream
1210

1311
class WebSocketHandler:
    """Callback handler for the Hume EVI WebSocket.

    Logs chat events to the terminal and buffers assistant audio bytes so the
    microphone interface can play them back.
    """

    def __init__(self):
        # Byte stream that MicrophoneInterface drains to play assistant audio.
        self.byte_strs = Stream.new()

    async def on_open(self):
        """Log that the WebSocket connection has been established."""
        print("WebSocket connection opened.")

    async def on_message(self, message: SubscribeEvent):
        """Dispatch an incoming EVI message on its ``type`` field.

        Chat metadata and user/assistant messages are logged; audio output is
        decoded from base64 and queued for playback; error messages abort the
        session; any other type is logged by name only.
        """
        if message.type == "chat_metadata":
            self._print_prompt(f"<{message.type}> Chat ID: {message.chat_id}, Chat Group ID: {message.chat_group_id}")
            return
        if message.type in ("user_message", "assistant_message"):
            self._print_prompt(f"{message.message.role}: {message.message.content}")
            prosody = message.models.prosody
            if prosody is None:
                # Messages sent as text carry no expression inference.
                print("Emotion scores not available.")
            else:
                self._print_emotion_scores(
                    self._extract_top_n_emotions(dict(prosody.scores), 3)
                )
            return
        if message.type == "audio_output":
            # Audio arrives base64-encoded; decode before queueing for playback.
            await self.byte_strs.put(base64.b64decode(message.data.encode("utf-8")))
            return
        if message.type == "error":
            raise RuntimeError(f"Received error message from Hume websocket ({message.code}): {message.message}")
        # Unrecognized message type: log its name so nothing passes silently.
        self._print_prompt(f"<{message.type}>")

    async def on_close(self):
        """Log that the WebSocket connection has been closed."""
        print("WebSocket connection closed.")

    async def on_error(self, error):
        """Log an error raised by the WebSocket connection."""
        print(f"Error: {error}")

    def _print_prompt(self, text: str) -> None:
        """Print *text* prefixed with the current UTC time (HH:MM:SS)."""
        timestamp = datetime.datetime.now(tz=datetime.timezone.utc).strftime("%H:%M:%S")
        print(f"[{timestamp}] {text}")

    def _extract_top_n_emotions(self, emotion_scores: dict, n: int) -> dict:
        """Return the *n* highest-scoring emotions, highest first."""
        ranked = sorted(emotion_scores.items(), key=lambda kv: kv[1], reverse=True)
        return dict(ranked[:n])

    def _print_emotion_scores(self, emotion_scores: dict) -> None:
        """Print emotions and scores on one line, ' | '-separated."""
        formatted = [f"{emotion} ({score:.2f})" for emotion, score in emotion_scores.items()]
        print(' | '.join(formatted))
16962

17063
async def main() -> None:
    """Connect to Hume EVI and stream microphone audio until the session ends.

    Reads HUME_API_KEY, HUME_SECRET_KEY, and HUME_CONFIG_ID from the
    environment (optionally via a local .env file), opens the EVI chat
    WebSocket with a WebSocketHandler, and runs the microphone interface.

    Raises:
        RuntimeError: If the required HUME_API_KEY or HUME_SECRET_KEY
            environment variables are not set.
    """
    # Load HUME_* credentials from a local .env file, if present.
    load_dotenv()

    HUME_API_KEY = os.getenv("HUME_API_KEY")
    HUME_SECRET_KEY = os.getenv("HUME_SECRET_KEY")
    HUME_CONFIG_ID = os.getenv("HUME_CONFIG_ID")

    # Fail fast with a clear message instead of an opaque SDK auth error later.
    if not HUME_API_KEY or not HUME_SECRET_KEY:
        raise RuntimeError(
            "HUME_API_KEY and HUME_SECRET_KEY must be set (e.g. in a .env file)."
        )

    client = AsyncHumeClient(api_key=HUME_API_KEY)

    # config_id selects the EVI configuration; secret_key enables token auth.
    options = ChatConnectOptions(config_id=HUME_CONFIG_ID, secret_key=HUME_SECRET_KEY)

    websocket_handler = WebSocketHandler()

    async with client.empathic_voice.chat.connect_with_callbacks(
        options=options,
        on_open=websocket_handler.on_open,
        on_message=websocket_handler.on_message,
        on_close=websocket_handler.on_close,
        on_error=websocket_handler.on_error
    ) as socket:
        # Await the coroutine directly: wrapping it in asyncio.create_task()
        # only to immediately await it added nothing.
        await MicrophoneInterface.start(
            socket,
            allow_user_interrupt=False,
            byte_stream=websocket_handler.byte_strs
        )

216-
# Execute the main asynchronous function using asyncio's event loop
21790
# Run the asynchronous entry point when this file is executed as a script.
if __name__ == "__main__":
    asyncio.run(main())

0 commit comments

Comments
 (0)