Skip to content

Commit f95d781

Browse files
authored
implement openaichat images for APIs that support them (using content parts) (#2849)
1 parent 458fcc2 commit f95d781

File tree

4 files changed: +262 additions, -22 deletions

pkg/aiusechat/openai/openai-convertmessage.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,8 @@ func ConvertAIMessageToOpenAIChatMessage(aiMsg uctypes.AIMessage) (*OpenAIChatMe
403403
}
404404

405405
var contentBlocks []OpenAIMessageContent
406+
imageCount := 0
407+
imageFailCount := 0
406408

407409
for i, part := range aiMsg.Parts {
408410
switch part.Type {
@@ -416,8 +418,14 @@ func ConvertAIMessageToOpenAIChatMessage(aiMsg uctypes.AIMessage) (*OpenAIChatMe
416418
})
417419

418420
case uctypes.AIMessagePartTypeFile:
421+
if strings.HasPrefix(part.MimeType, "image/") {
422+
imageCount++
423+
}
419424
block, err := convertFileAIMessagePart(part)
420425
if err != nil {
426+
if strings.HasPrefix(part.MimeType, "image/") {
427+
imageFailCount++
428+
}
421429
log.Printf("openai: %v", err)
422430
continue
423431
}
@@ -430,6 +438,13 @@ func ConvertAIMessageToOpenAIChatMessage(aiMsg uctypes.AIMessage) (*OpenAIChatMe
430438
}
431439
}
432440

441+
if len(contentBlocks) == 0 {
442+
if imageCount > 0 && imageFailCount == imageCount {
443+
return nil, fmt.Errorf("all %d image conversions failed", imageCount)
444+
}
445+
return nil, errors.New("message has no valid content after processing all parts")
446+
}
447+
433448
return &OpenAIChatMessage{
434449
MessageId: aiMsg.MessageId,
435450
Message: &OpenAIMessage{

pkg/aiusechat/openaichat/openaichat-backend.go

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,6 @@ func RunChatStep(
4646
// Convert stored messages to chat completions format
4747
var messages []ChatRequestMessage
4848

49-
// Add system prompt if provided
50-
if len(chatOpts.SystemPrompt) > 0 {
51-
messages = append(messages, ChatRequestMessage{
52-
Role: "system",
53-
Content: strings.Join(chatOpts.SystemPrompt, "\n"),
54-
})
55-
}
56-
5749
// Convert native messages
5850
for _, genMsg := range chat.NativeMessages {
5951
chatMsg, ok := genMsg.(*StoredChatMessage)

pkg/aiusechat/openaichat/openaichat-convertmessage.go

Lines changed: 136 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,14 @@ const (
2828
func appendToLastUserMessage(messages []ChatRequestMessage, text string) {
2929
for i := len(messages) - 1; i >= 0; i-- {
3030
if messages[i].Role == "user" {
31-
messages[i].Content += "\n\n" + text
31+
if len(messages[i].ContentParts) > 0 {
32+
messages[i].ContentParts = append(messages[i].ContentParts, ChatContentPart{
33+
Type: "text",
34+
Text: text,
35+
})
36+
} else {
37+
messages[i].Content += "\n\n" + text
38+
}
3239
break
3340
}
3441
}
@@ -167,6 +174,21 @@ func ConvertAIMessageToStoredChatMessage(aiMsg uctypes.AIMessage) (*StoredChatMe
167174
return nil, fmt.Errorf("invalid AIMessage: %w", err)
168175
}
169176

177+
hasImages := false
178+
for _, part := range aiMsg.Parts {
179+
if strings.HasPrefix(part.MimeType, "image/") {
180+
hasImages = true
181+
break
182+
}
183+
}
184+
185+
if hasImages {
186+
return convertAIMessageMultimodal(aiMsg)
187+
}
188+
return convertAIMessageTextOnly(aiMsg)
189+
}
190+
191+
func convertAIMessageTextOnly(aiMsg uctypes.AIMessage) (*StoredChatMessage, error) {
170192
var textBuilder strings.Builder
171193
firstText := true
172194
for _, part := range aiMsg.Parts {
@@ -213,6 +235,89 @@ func ConvertAIMessageToStoredChatMessage(aiMsg uctypes.AIMessage) (*StoredChatMe
213235
}, nil
214236
}
215237

238+
func convertAIMessageMultimodal(aiMsg uctypes.AIMessage) (*StoredChatMessage, error) {
239+
var contentParts []ChatContentPart
240+
imageCount := 0
241+
imageFailCount := 0
242+
243+
for _, part := range aiMsg.Parts {
244+
switch {
245+
case part.Type == uctypes.AIMessagePartTypeText:
246+
if part.Text != "" {
247+
contentParts = append(contentParts, ChatContentPart{
248+
Type: "text",
249+
Text: part.Text,
250+
})
251+
}
252+
253+
case strings.HasPrefix(part.MimeType, "image/"):
254+
imageCount++
255+
imageUrl, err := aiutil.ExtractImageUrl(part.Data, part.URL, part.MimeType)
256+
if err != nil {
257+
imageFailCount++
258+
log.Printf("openaichat: error extracting image URL for %s: %v\n", part.FileName, err)
259+
continue
260+
}
261+
contentParts = append(contentParts, ChatContentPart{
262+
Type: "image_url",
263+
ImageUrl: &ChatImageUrl{Url: imageUrl},
264+
FileName: part.FileName,
265+
PreviewUrl: part.PreviewUrl,
266+
MimeType: part.MimeType,
267+
})
268+
269+
case part.MimeType == "text/plain":
270+
textData, err := aiutil.ExtractTextData(part.Data, part.URL)
271+
if err != nil {
272+
log.Printf("openaichat: error extracting text data for %s: %v\n", part.FileName, err)
273+
continue
274+
}
275+
formattedText := aiutil.FormatAttachedTextFile(part.FileName, textData)
276+
if formattedText != "" {
277+
contentParts = append(contentParts, ChatContentPart{
278+
Type: "text",
279+
Text: formattedText,
280+
})
281+
}
282+
283+
case part.MimeType == "directory":
284+
if len(part.Data) == 0 {
285+
log.Printf("openaichat: directory listing part missing data for %s\n", part.FileName)
286+
continue
287+
}
288+
formattedText := aiutil.FormatAttachedDirectoryListing(part.FileName, string(part.Data))
289+
if formattedText != "" {
290+
contentParts = append(contentParts, ChatContentPart{
291+
Type: "text",
292+
Text: formattedText,
293+
})
294+
}
295+
296+
case part.MimeType == "application/pdf":
297+
log.Printf("openaichat: PDF attachments are not supported by Chat Completions API, skipping %s\n", part.FileName)
298+
continue
299+
300+
default:
301+
continue
302+
}
303+
}
304+
305+
if len(contentParts) == 0 {
306+
if imageCount > 0 && imageFailCount == imageCount {
307+
return nil, fmt.Errorf("all %d image conversions failed", imageCount)
308+
}
309+
return nil, errors.New("message has no valid content after processing all parts")
310+
}
311+
312+
return &StoredChatMessage{
313+
MessageId: aiMsg.MessageId,
314+
Message: ChatRequestMessage{
315+
Role: "user",
316+
ContentParts: contentParts,
317+
},
318+
}, nil
319+
}
320+
216321
// ConvertToolResultsToNativeChatMessage converts tool results to OpenAI tool messages
217322
func ConvertToolResultsToNativeChatMessage(toolResults []uctypes.AIToolResult) ([]uctypes.GenAIMessage, error) {
218323
if len(toolResults) == 0 {
@@ -261,8 +366,36 @@ func ConvertAIChatToUIChat(aiChat uctypes.AIChat) (*uctypes.UIChat, error) {
261366

262367
var parts []uctypes.UIMessagePart
263368

264-
// Add text content if present
265-
if chatMsg.Message.Content != "" {
369+
if len(chatMsg.Message.ContentParts) > 0 {
370+
for _, cp := range chatMsg.Message.ContentParts {
371+
switch cp.Type {
372+
case "text":
373+
if found, part := aiutil.ConvertDataUserFile(cp.Text); found {
374+
if part != nil {
375+
parts = append(parts, *part)
376+
}
377+
} else {
378+
parts = append(parts, uctypes.UIMessagePart{
379+
Type: "text",
380+
Text: cp.Text,
381+
})
382+
}
383+
case "image_url":
384+
mimeType := cp.MimeType
385+
if mimeType == "" {
386+
mimeType = "image/*"
387+
}
388+
parts = append(parts, uctypes.UIMessagePart{
389+
Type: "data-userfile",
390+
Data: uctypes.UIMessageDataUserFile{
391+
FileName: cp.FileName,
392+
MimeType: mimeType,
393+
PreviewUrl: cp.PreviewUrl,
394+
},
395+
})
396+
}
397+
}
398+
} else if chatMsg.Message.Content != "" {
266399
parts = append(parts, uctypes.UIMessagePart{
267400
Type: "text",
268401
Text: chatMsg.Message.Content,

pkg/aiusechat/openaichat/openaichat-types.go

Lines changed: 111 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
package openaichat
55

66
import (
7+
"bytes"
8+
"encoding/json"
9+
710
"github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes"
811
)
912

@@ -20,22 +23,115 @@ type ChatRequest struct {
2023
ToolChoice any `json:"tool_choice,omitempty"` // "auto", "none", or struct
2124
}
2225

26+
// ChatContentPart is one element of a multimodal message content array.
// Type selects the payload: Text for "text", ImageUrl for "image_url".
// The fields marked internal are bookkeeping; clean returns a copy without them.
type ChatContentPart struct {
	Type     string        `json:"type"`                // "text" or "image_url"
	Text     string        `json:"text,omitempty"`      // for type "text"
	ImageUrl *ChatImageUrl `json:"image_url,omitempty"` // for type "image_url"

	FileName   string `json:"filename,omitempty"`   // internal: original filename
	PreviewUrl string `json:"previewurl,omitempty"` // internal: 128x128 webp preview
	MimeType   string `json:"mimetype,omitempty"`   // internal: original mimetype
}

// clean returns a content part with the internal fields (FileName, PreviewUrl,
// MimeType) blanked. When none of them is set the receiver itself is returned;
// otherwise a new value is returned and the receiver is left unmodified. The
// ImageUrl pointer is shared with the receiver, not deep-copied.
func (cp *ChatContentPart) clean() *ChatContentPart {
	hasInternal := cp.FileName != "" || cp.PreviewUrl != "" || cp.MimeType != ""
	if !hasInternal {
		return cp
	}
	return &ChatContentPart{
		Type:     cp.Type,
		Text:     cp.Text,
		ImageUrl: cp.ImageUrl,
	}
}

// ChatImageUrl is the payload of an "image_url" content part.
type ChatImageUrl struct {
	Url    string `json:"url"`
	Detail string `json:"detail,omitempty"` // "auto", "low", "high"
}
51+
2352
type ChatRequestMessage struct {
24-
Role string `json:"role"` // "system","user","assistant","tool"
25-
Content string `json:"content,omitempty"` // normal text messages
26-
ToolCalls []ToolCall `json:"tool_calls,omitempty"` // assistant tool-call message
27-
ToolCallID string `json:"tool_call_id,omitempty"` // for role:"tool"
28-
Name string `json:"name,omitempty"` // tool name on role:"tool"
53+
Role string `json:"role"` // "system","user","assistant","tool"
54+
Content string `json:"-"` // plain text (used when ContentParts is nil)
55+
ContentParts []ChatContentPart `json:"-"` // multimodal parts (used when images present)
56+
ToolCalls []ToolCall `json:"tool_calls,omitempty"` // assistant tool-call message
57+
ToolCallID string `json:"tool_call_id,omitempty"` // for role:"tool"
58+
Name string `json:"name,omitempty"` // tool name on role:"tool"
2959
}
3060

31-
func (cm *ChatRequestMessage) clean() *ChatRequestMessage {
32-
if len(cm.ToolCalls) == 0 {
33-
return cm
61+
// chatRequestMessageJSON is the wire format for ChatRequestMessage
62+
type chatRequestMessageJSON struct {
63+
Role string `json:"role"`
64+
Content json.RawMessage `json:"content"`
65+
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
66+
ToolCallID string `json:"tool_call_id,omitempty"`
67+
Name string `json:"name,omitempty"`
68+
}
69+
70+
func (cm ChatRequestMessage) MarshalJSON() ([]byte, error) {
71+
raw := chatRequestMessageJSON{
72+
Role: cm.Role,
73+
ToolCalls: cm.ToolCalls,
74+
ToolCallID: cm.ToolCallID,
75+
Name: cm.Name,
76+
}
77+
if len(cm.ContentParts) > 0 {
78+
b, err := json.Marshal(cm.ContentParts)
79+
if err != nil {
80+
return nil, err
81+
}
82+
raw.Content = b
83+
} else if cm.Content != "" {
84+
b, err := json.Marshal(cm.Content)
85+
if err != nil {
86+
return nil, err
87+
}
88+
raw.Content = b
3489
}
90+
return json.Marshal(raw)
91+
}
92+
93+
func (cm *ChatRequestMessage) UnmarshalJSON(data []byte) error {
94+
var raw chatRequestMessageJSON
95+
if err := json.Unmarshal(data, &raw); err != nil {
96+
return err
97+
}
98+
cm.Role = raw.Role
99+
cm.ToolCalls = raw.ToolCalls
100+
cm.ToolCallID = raw.ToolCallID
101+
cm.Name = raw.Name
102+
cm.Content = ""
103+
cm.ContentParts = nil
104+
if len(raw.Content) == 0 || bytes.Equal(raw.Content, []byte("null")) {
105+
return nil
106+
}
107+
// try array first
108+
var parts []ChatContentPart
109+
if err := json.Unmarshal(raw.Content, &parts); err == nil {
110+
cm.ContentParts = parts
111+
return nil
112+
}
113+
// fall back to string
114+
var s string
115+
if err := json.Unmarshal(raw.Content, &s); err != nil {
116+
return err
117+
}
118+
cm.Content = s
119+
return nil
120+
}
121+
122+
func (cm *ChatRequestMessage) clean() *ChatRequestMessage {
35123
rtn := *cm
36-
rtn.ToolCalls = make([]ToolCall, len(cm.ToolCalls))
37-
for i, tc := range cm.ToolCalls {
38-
rtn.ToolCalls[i] = *tc.clean()
124+
if len(cm.ToolCalls) > 0 {
125+
rtn.ToolCalls = make([]ToolCall, len(cm.ToolCalls))
126+
for i, tc := range cm.ToolCalls {
127+
rtn.ToolCalls[i] = *tc.clean()
128+
}
129+
}
130+
if len(cm.ContentParts) > 0 {
131+
rtn.ContentParts = make([]ChatContentPart, len(cm.ContentParts))
132+
for i, cp := range cm.ContentParts {
133+
rtn.ContentParts[i] = *cp.clean()
134+
}
39135
}
40136
return &rtn
41137
}
@@ -163,6 +259,10 @@ func (m *StoredChatMessage) Copy() *StoredChatMessage {
163259
}
164260
}
165261
}
262+
if len(m.Message.ContentParts) > 0 {
263+
copied.Message.ContentParts = make([]ChatContentPart, len(m.Message.ContentParts))
264+
copy(copied.Message.ContentParts, m.Message.ContentParts)
265+
}
166266
if m.Usage != nil {
167267
usageCopy := *m.Usage
168268
copied.Usage = &usageCopy

0 commit comments

Comments
 (0)