diff --git a/pkg/aiusechat/openai/openai-convertmessage.go b/pkg/aiusechat/openai/openai-convertmessage.go index 3accb354fa..045cc65cf2 100644 --- a/pkg/aiusechat/openai/openai-convertmessage.go +++ b/pkg/aiusechat/openai/openai-convertmessage.go @@ -403,6 +403,8 @@ func ConvertAIMessageToOpenAIChatMessage(aiMsg uctypes.AIMessage) (*OpenAIChatMe } var contentBlocks []OpenAIMessageContent + imageCount := 0 + imageFailCount := 0 for i, part := range aiMsg.Parts { switch part.Type { @@ -416,8 +418,14 @@ func ConvertAIMessageToOpenAIChatMessage(aiMsg uctypes.AIMessage) (*OpenAIChatMe }) case uctypes.AIMessagePartTypeFile: + if strings.HasPrefix(part.MimeType, "image/") { + imageCount++ + } block, err := convertFileAIMessagePart(part) if err != nil { + if strings.HasPrefix(part.MimeType, "image/") { + imageFailCount++ + } log.Printf("openai: %v", err) continue } @@ -430,6 +438,13 @@ func ConvertAIMessageToOpenAIChatMessage(aiMsg uctypes.AIMessage) (*OpenAIChatMe } } + if len(contentBlocks) == 0 { + if imageCount > 0 && imageFailCount == imageCount { + return nil, fmt.Errorf("all %d image conversions failed", imageCount) + } + return nil, errors.New("message has no valid content after processing all parts") + } + return &OpenAIChatMessage{ MessageId: aiMsg.MessageId, Message: &OpenAIMessage{ diff --git a/pkg/aiusechat/openaichat/openaichat-backend.go b/pkg/aiusechat/openaichat/openaichat-backend.go index 04df1a65d3..7b90aee674 100644 --- a/pkg/aiusechat/openaichat/openaichat-backend.go +++ b/pkg/aiusechat/openaichat/openaichat-backend.go @@ -46,14 +46,6 @@ func RunChatStep( // Convert stored messages to chat completions format var messages []ChatRequestMessage - // Add system prompt if provided - if len(chatOpts.SystemPrompt) > 0 { - messages = append(messages, ChatRequestMessage{ - Role: "system", - Content: strings.Join(chatOpts.SystemPrompt, "\n"), - }) - } - // Convert native messages for _, genMsg := range chat.NativeMessages { chatMsg, ok := genMsg.(*StoredChatMessage) diff --git a/pkg/aiusechat/openaichat/openaichat-convertmessage.go b/pkg/aiusechat/openaichat/openaichat-convertmessage.go index c2a7dcd070..90fe8c9f1f 100644 --- a/pkg/aiusechat/openaichat/openaichat-convertmessage.go +++ b/pkg/aiusechat/openaichat/openaichat-convertmessage.go @@ -28,7 +28,14 @@ const ( func appendToLastUserMessage(messages []ChatRequestMessage, text string) { for i := len(messages) - 1; i >= 0; i-- { if messages[i].Role == "user" { - messages[i].Content += "\n\n" + text + if len(messages[i].ContentParts) > 0 { + messages[i].ContentParts = append(messages[i].ContentParts, ChatContentPart{ + Type: "text", + Text: text, + }) + } else { + messages[i].Content += "\n\n" + text + } break } } @@ -167,6 +174,21 @@ func ConvertAIMessageToStoredChatMessage(aiMsg uctypes.AIMessage) (*StoredChatMe return nil, fmt.Errorf("invalid AIMessage: %w", err) } + hasImages := false + for _, part := range aiMsg.Parts { + if strings.HasPrefix(part.MimeType, "image/") { + hasImages = true + break + } + } + + if hasImages { + return convertAIMessageMultimodal(aiMsg) + } + return convertAIMessageTextOnly(aiMsg) +} + +func convertAIMessageTextOnly(aiMsg uctypes.AIMessage) (*StoredChatMessage, error) { var textBuilder strings.Builder firstText := true for _, part := range aiMsg.Parts { @@ -213,6 +235,89 @@ func ConvertAIMessageToStoredChatMessage(aiMsg uctypes.AIMessage) (*StoredChatMe }, nil } +func convertAIMessageMultimodal(aiMsg uctypes.AIMessage) (*StoredChatMessage, error) { + var contentParts []ChatContentPart + imageCount := 0 + imageFailCount := 0 + + for _, part := range aiMsg.Parts { + switch { + case part.Type == uctypes.AIMessagePartTypeText: + if part.Text != "" { + contentParts = append(contentParts, ChatContentPart{ + Type: "text", + Text: part.Text, + }) + } + + case strings.HasPrefix(part.MimeType, "image/"): + imageCount++ + imageUrl, err := aiutil.ExtractImageUrl(part.Data, part.URL, part.MimeType) + if err != nil { + imageFailCount++ + log.Printf("openaichat: error extracting image URL for %s: %v\n", part.FileName, err) + continue + } + contentParts = append(contentParts, ChatContentPart{ + Type: "image_url", + ImageUrl: &ChatImageUrl{Url: imageUrl}, + FileName: part.FileName, + PreviewUrl: part.PreviewUrl, + MimeType: part.MimeType, + }) + + case part.MimeType == "text/plain": + textData, err := aiutil.ExtractTextData(part.Data, part.URL) + if err != nil { + log.Printf("openaichat: error extracting text data for %s: %v\n", part.FileName, err) + continue + } + formattedText := aiutil.FormatAttachedTextFile(part.FileName, textData) + if formattedText != "" { + contentParts = append(contentParts, ChatContentPart{ + Type: "text", + Text: formattedText, + }) + } + + case part.MimeType == "directory": + if len(part.Data) == 0 { + log.Printf("openaichat: directory listing part missing data for %s\n", part.FileName) + continue + } + formattedText := aiutil.FormatAttachedDirectoryListing(part.FileName, string(part.Data)) + if formattedText != "" { + contentParts = append(contentParts, ChatContentPart{ + Type: "text", + Text: formattedText, + }) + } + + case part.MimeType == "application/pdf": + log.Printf("openaichat: PDF attachments are not supported by Chat Completions API, skipping %s\n", part.FileName) + continue + + default: + continue + } + } + + if len(contentParts) == 0 { + if imageCount > 0 && imageFailCount == imageCount { + return nil, fmt.Errorf("all %d image conversions failed", imageCount) + } + return nil, errors.New("message has no valid content after processing all parts") + } + + return &StoredChatMessage{ + MessageId: aiMsg.MessageId, + Message: ChatRequestMessage{ + Role: "user", + ContentParts: contentParts, + }, + }, nil +} + // ConvertToolResultsToNativeChatMessage converts tool results to OpenAI tool messages func ConvertToolResultsToNativeChatMessage(toolResults []uctypes.AIToolResult) ([]uctypes.GenAIMessage, error) { if len(toolResults) == 0 { @@ -261,8 +366,36 @@ func ConvertAIChatToUIChat(aiChat uctypes.AIChat) (*uctypes.UIChat, error) { var parts []uctypes.UIMessagePart - // Add text content if present - if chatMsg.Message.Content != "" { + if len(chatMsg.Message.ContentParts) > 0 { + for _, cp := range chatMsg.Message.ContentParts { + switch cp.Type { + case "text": + if found, part := aiutil.ConvertDataUserFile(cp.Text); found { + if part != nil { + parts = append(parts, *part) + } + } else { + parts = append(parts, uctypes.UIMessagePart{ + Type: "text", + Text: cp.Text, + }) + } + case "image_url": + mimeType := cp.MimeType + if mimeType == "" { + mimeType = "image/*" + } + parts = append(parts, uctypes.UIMessagePart{ + Type: "data-userfile", + Data: uctypes.UIMessageDataUserFile{ + FileName: cp.FileName, + MimeType: mimeType, + PreviewUrl: cp.PreviewUrl, + }, + }) + } + } + } else if chatMsg.Message.Content != "" { parts = append(parts, uctypes.UIMessagePart{ Type: "text", Text: chatMsg.Message.Content, diff --git a/pkg/aiusechat/openaichat/openaichat-types.go b/pkg/aiusechat/openaichat/openaichat-types.go index f0bcc41614..18d28e3b20 100644 --- a/pkg/aiusechat/openaichat/openaichat-types.go +++ b/pkg/aiusechat/openaichat/openaichat-types.go @@ -4,6 +4,9 @@ package openaichat import ( + "bytes" + "encoding/json" + "github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes" ) @@ -20,22 +23,115 @@ type ChatRequest struct { ToolChoice any `json:"tool_choice,omitempty"` // "auto", "none", or struct } +type ChatContentPart struct { + Type string `json:"type"` // "text" or "image_url" + Text string `json:"text,omitempty"` // for type "text" + ImageUrl *ChatImageUrl `json:"image_url,omitempty"` // for type "image_url" + + FileName string `json:"filename,omitempty"` // internal: original filename + PreviewUrl string `json:"previewurl,omitempty"` // internal: 128x128 webp preview + MimeType string `json:"mimetype,omitempty"` // internal: original mimetype +} + +func (cp *ChatContentPart) clean() *ChatContentPart { + if cp.FileName == "" && cp.PreviewUrl == "" && cp.MimeType == "" { + return cp + } + rtn := *cp + rtn.FileName = "" + rtn.PreviewUrl = "" + rtn.MimeType = "" + return &rtn +} + +type ChatImageUrl struct { + Url string `json:"url"` + Detail string `json:"detail,omitempty"` // "auto", "low", "high" +} + type ChatRequestMessage struct { - Role string `json:"role"` // "system","user","assistant","tool" - Content string `json:"content,omitempty"` // normal text messages - ToolCalls []ToolCall `json:"tool_calls,omitempty"` // assistant tool-call message - ToolCallID string `json:"tool_call_id,omitempty"` // for role:"tool" - Name string `json:"name,omitempty"` // tool name on role:"tool" + Role string `json:"role"` // "system","user","assistant","tool" + Content string `json:"-"` // plain text (used when ContentParts is nil) + ContentParts []ChatContentPart `json:"-"` // multimodal parts (used when images present) + ToolCalls []ToolCall `json:"tool_calls,omitempty"` // assistant tool-call message + ToolCallID string `json:"tool_call_id,omitempty"` // for role:"tool" + Name string `json:"name,omitempty"` // tool name on role:"tool" } -func (cm *ChatRequestMessage) clean() *ChatRequestMessage { - if len(cm.ToolCalls) == 0 { - return cm +// chatRequestMessageJSON is the wire format for ChatRequestMessage +type chatRequestMessageJSON struct { + Role string `json:"role"` + Content json.RawMessage `json:"content"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + Name string `json:"name,omitempty"` +} + +func (cm ChatRequestMessage) MarshalJSON() ([]byte, error) { + raw := chatRequestMessageJSON{ + Role: cm.Role, + ToolCalls: cm.ToolCalls, + ToolCallID: cm.ToolCallID, + Name: cm.Name, + } + if len(cm.ContentParts) > 0 { + b, err := json.Marshal(cm.ContentParts) + if err != nil { + return nil, err + } + raw.Content = b + } else if cm.Content != "" { + b, err := json.Marshal(cm.Content) + if err != nil { + return nil, err + } + raw.Content = b } + return json.Marshal(raw) +} + +func (cm *ChatRequestMessage) UnmarshalJSON(data []byte) error { + var raw chatRequestMessageJSON + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + cm.Role = raw.Role + cm.ToolCalls = raw.ToolCalls + cm.ToolCallID = raw.ToolCallID + cm.Name = raw.Name + cm.Content = "" + cm.ContentParts = nil + if len(raw.Content) == 0 || bytes.Equal(raw.Content, []byte("null")) { + return nil + } + // try array first + var parts []ChatContentPart + if err := json.Unmarshal(raw.Content, &parts); err == nil { + cm.ContentParts = parts + return nil + } + // fall back to string + var s string + if err := json.Unmarshal(raw.Content, &s); err != nil { + return err + } + cm.Content = s + return nil +} + +func (cm *ChatRequestMessage) clean() *ChatRequestMessage { rtn := *cm - rtn.ToolCalls = make([]ToolCall, len(cm.ToolCalls)) - for i, tc := range cm.ToolCalls { - rtn.ToolCalls[i] = *tc.clean() + if len(cm.ToolCalls) > 0 { + rtn.ToolCalls = make([]ToolCall, len(cm.ToolCalls)) + for i, tc := range cm.ToolCalls { + rtn.ToolCalls[i] = *tc.clean() + } + } + if len(cm.ContentParts) > 0 { + rtn.ContentParts = make([]ChatContentPart, len(cm.ContentParts)) + for i, cp := range cm.ContentParts { + rtn.ContentParts[i] = *cp.clean() + } } return &rtn } @@ -163,6 +259,10 @@ func (m *StoredChatMessage) Copy() *StoredChatMessage { } } } + if len(m.Message.ContentParts) > 0 { + copied.Message.ContentParts = make([]ChatContentPart, len(m.Message.ContentParts)) + copy(copied.Message.ContentParts, m.Message.ContentParts) + } if m.Usage != nil { usageCopy := *m.Usage copied.Usage = &usageCopy