package models

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"time"
)

// CoHereModel is a model driver backed by the Cohere HTTP API.
//
// BaseURL maps a region name to an API base URL; the "default" entry is used
// when APIConfig.Region is unset or empty. URLSuffix holds the per-endpoint
// paths that are appended to the base URL.
type CoHereModel struct {
	BaseURL    map[string]string
	URLSuffix  URLSuffix
	httpClient *http.Client
}

// NewInstance returns a new driver that reuses this driver's URL suffixes
// with the given region -> base URL map and a fresh HTTP client.
func (c *CoHereModel) NewInstance(baseURL map[string]string) ModelDriver {
	return NewCoHereModel(baseURL, c.URLSuffix)
}

// NewCoHereModel builds a Cohere driver whose HTTP client has a 120-second
// timeout.
func NewCoHereModel(baseURL map[string]string, urlSuffix URLSuffix) *CoHereModel {
	return &CoHereModel{
		BaseURL:   baseURL,
		URLSuffix: urlSuffix,
		httpClient: &http.Client{
			Timeout: 120 * time.Second,
		},
	}
}

// Name returns the provider identifier "cohere".
func (c *CoHereModel) Name() string {
	return "cohere"
}

// endpoint resolves the absolute URL for suffix in the region selected by
// apiConfig (falling back to "default"). It errors when no base URL is
// configured for that region instead of issuing a request to a relative URL.
func (c *CoHereModel) endpoint(apiConfig *APIConfig, suffix string) (string, error) {
	region := "default"
	if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" {
		region = *apiConfig.Region
	}
	base := c.BaseURL[region]
	if base == "" {
		return "", fmt.Errorf("no base URL configured for region %q", region)
	}
	// Trim on both sides so "https://host/" + "/v2/chat" never yields "//".
	return strings.TrimSuffix(base, "/") + "/" + strings.TrimPrefix(suffix, "/"), nil
}

// apiKey returns the whitespace-trimmed API key from apiConfig, or an error
// when the config, the key pointer, or the trimmed key is empty.
func (c *CoHereModel) apiKey(apiConfig *APIConfig) (string, error) {
	if apiConfig == nil || apiConfig.ApiKey == nil {
		return "", fmt.Errorf("api key is nil or empty")
	}
	key := strings.TrimSpace(*apiConfig.ApiKey)
	if key == "" {
		return "", fmt.Errorf("api key is nil or empty")
	}
	return key, nil
}

// requireModel dereferences a model-name pointer, rejecting nil or "".
func (c *CoHereModel) requireModel(modelName *string) (string, error) {
	if modelName == nil || *modelName == "" {
		return "", fmt.Errorf("model name is nil or empty")
	}
	return *modelName, nil
}

// postJSON marshals payload and POSTs it to url with JSON content headers and
// bearer authentication. On success the caller owns resp and must close its
// body.
func (c *CoHereModel) postJSON(url, key string, payload interface{}) (*http.Response, error) {
	jsonData, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}
	req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(jsonData))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")
	req.Header.Set("Authorization", "Bearer "+key)
	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	return resp, nil
}

// chatRequestBody assembles the chat payload shared by ChatWithMessages and
// ChatStreamlyWithSender.
//
// stream is fixed by the calling method rather than read from cfg.Stream,
// because each method's response parser understands only one format.
// defaultTemperature applies when cfg.Temperature is nil.
func (c *CoHereModel) chatRequestBody(modelName string, messages []Message, cfg *ChatConfig, stream bool, defaultTemperature float64) map[string]interface{} {
	apiMessages := make([]map[string]interface{}, len(messages))
	for i, msg := range messages {
		apiMessages[i] = map[string]interface{}{
			"role":    msg.Role,
			"content": msg.Content,
		}
	}

	reqBody := map[string]interface{}{
		"model":       modelName,
		"messages":    apiMessages,
		"stream":      stream,
		"temperature": defaultTemperature,
	}
	if cfg == nil {
		return reqBody
	}
	if cfg.MaxTokens != nil {
		reqBody["max_tokens"] = *cfg.MaxTokens
	}
	if cfg.Temperature != nil {
		reqBody["temperature"] = *cfg.Temperature
	}
	if cfg.TopP != nil {
		// Cohere's v2 chat API names nucleus sampling "p", not "top_p".
		reqBody["p"] = *cfg.TopP
	}
	if cfg.Thinking != nil {
		thinkingType := "disabled"
		if *cfg.Thinking {
			thinkingType = "enabled"
		}
		reqBody["thinking"] = map[string]interface{}{"type": thinkingType}
	}
	return reqBody
}

// ChatWithMessages sends a non-streaming chat request and returns the
// concatenated "text" content blocks as Answer and the concatenated
// "thinking" blocks as ReasonContent.
//
// The request is always sent with stream=false (ChatConfig.Stream is ignored)
// because the response is parsed as a single JSON document. Temperature
// defaults to 0.3.
// NOTE(review): ChatStreamlyWithSender defaults to temperature 1 — confirm
// the differing defaults are intentional.
func (c *CoHereModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) {
	key, err := c.apiKey(apiConfig)
	if err != nil {
		return nil, err
	}
	if len(messages) == 0 {
		return nil, fmt.Errorf("messages is empty")
	}
	url, err := c.endpoint(apiConfig, c.URLSuffix.Chat)
	if err != nil {
		return nil, err
	}

	resp, err := c.postJSON(url, key, c.chatRequestBody(modelName, messages, chatModelConfig, false, 0.3))
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response body: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("Cohere chat API error: %d %s", resp.StatusCode, string(body))
	}

	var result map[string]interface{}
	if err := json.Unmarshal(body, &result); err != nil {
		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
	}
	messageMap, ok := result["message"].(map[string]interface{})
	if !ok {
		return nil, fmt.Errorf("no message found in Cohere response: %s", string(body))
	}
	contentArray, ok := messageMap["content"].([]interface{})
	if !ok {
		return nil, fmt.Errorf("content is not an array in Cohere response")
	}

	var answer, reasoning strings.Builder
	for _, block := range contentArray {
		cmap, ok := block.(map[string]interface{})
		if !ok {
			continue
		}
		if blockType, _ := cmap["type"].(string); blockType == "thinking" {
			if text, ok := cmap["thinking"].(string); ok {
				reasoning.WriteString(text)
			}
			continue
		}
		if text, ok := cmap["text"].(string); ok {
			answer.WriteString(text)
		}
	}

	fullContent := answer.String()
	reasonContent := reasoning.String()
	return &ChatResponse{
		Answer:        &fullContent,
		ReasonContent: &reasonContent,
	}, nil
}

// ChatStreamlyWithSender streams a chat completion. For each "content-delta"
// event it calls sender(nil, &thinking) for a non-empty thinking delta and
// sender(&text, nil) for a non-empty text delta. It stops at "message-end" or
// end of body and, if the stream was read without error, finally calls
// sender(&"[DONE]", nil). Any error returned by sender aborts the stream.
//
// The request is always sent with stream=true (ChatConfig.Stream is ignored)
// because the body is parsed as line-delimited events. Temperature defaults
// to 1.
func (c *CoHereModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error {
	if len(messages) == 0 {
		return fmt.Errorf("messages is empty")
	}
	if sender == nil {
		return fmt.Errorf("sender is nil")
	}
	key, err := c.apiKey(apiConfig)
	if err != nil {
		return err
	}
	url, err := c.endpoint(apiConfig, c.URLSuffix.Chat)
	if err != nil {
		return err
	}

	resp, err := c.postJSON(url, key, c.chatRequestBody(modelName, messages, modelConfig, true, 1))
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: the status code is reported even if the body is unreadable.
		body, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("Cohere stream API error %d: %s", resp.StatusCode, string(body))
	}

	scanner := bufio.NewScanner(resp.Body)
	// A single event line can exceed bufio's 64 KiB default token size.
	scanner.Buffer(make([]byte, 0, 64*1024), 4*1024*1024)
	for scanner.Scan() {
		data := strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(scanner.Text()), "data:"))
		if data == "" || data == "[DONE]" {
			continue
		}

		var event map[string]interface{}
		if err := json.Unmarshal([]byte(data), &event); err != nil {
			continue // non-JSON lines (e.g. "event: ...") carry no payload
		}
		eventType, _ := event["type"].(string)
		if eventType == "message-end" {
			break
		}
		if eventType != "content-delta" {
			continue
		}

		// Walk delta.message.content; a missing level yields a nil map,
		// which reads as empty and makes the assertions below fail safely.
		delta, _ := event["delta"].(map[string]interface{})
		msg, _ := delta["message"].(map[string]interface{})
		content, _ := msg["content"].(map[string]interface{})

		if thinking, ok := content["thinking"].(string); ok && thinking != "" {
			if err := sender(nil, &thinking); err != nil {
				return err
			}
		}
		if text, ok := content["text"].(string); ok && text != "" {
			if err := sender(&text, nil); err != nil {
				return err
			}
		}
	}
	// Report a broken stream instead of signalling a clean completion.
	if err := scanner.Err(); err != nil {
		return fmt.Errorf("failed to read stream: %w", err)
	}

	endOfStream := "[DONE]"
	return sender(&endOfStream, nil)
}

// Embed returns float embeddings for texts, computed with
// input_type "search_document". Results are indexed in response order.
// An empty texts slice returns an empty, non-nil result without a request.
// embeddingConfig is currently not used.
func (c *CoHereModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) {
	if len(texts) == 0 {
		return []EmbeddingData{}, nil
	}
	model, err := c.requireModel(modelName)
	if err != nil {
		return nil, err
	}
	key, err := c.apiKey(apiConfig)
	if err != nil {
		return nil, err
	}
	url, err := c.endpoint(apiConfig, c.URLSuffix.Embedding)
	if err != nil {
		return nil, err
	}

	reqBody := map[string]interface{}{
		"model":           model,
		"texts":           texts,
		"input_type":      "search_document",
		"embedding_types": []string{"float"},
	}
	resp, err := c.postJSON(url, key, reqBody)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("Cohere embedding API error: status %d, body: %s", resp.StatusCode, string(body))
	}

	var result struct {
		Embeddings struct {
			Float [][]float64 `json:"float"`
		} `json:"embeddings"`
	}
	if err := json.Unmarshal(body, &result); err != nil {
		return nil, fmt.Errorf("failed to decode response: %w", err)
	}
	if len(result.Embeddings.Float) == 0 {
		return nil, fmt.Errorf("Cohere embedding response contains no float data: %s", string(body))
	}

	embeddings := make([]EmbeddingData, 0, len(result.Embeddings.Float))
	for i, vector := range result.Embeddings.Float {
		embeddings = append(embeddings, EmbeddingData{
			Embedding: vector,
			Index:     i,
		})
	}
	return embeddings, nil
}

// Rerank scores documents against query and returns the results the API
// reports (index into documents plus relevance score).
//
// top_n defaults to len(documents) when rerankConfig is nil or TopN is not
// positive, and is capped at len(documents). An empty documents slice returns
// an empty response without a request.
func (c *CoHereModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) {
	if len(documents) == 0 {
		return &RerankResponse{}, nil
	}
	model, err := c.requireModel(modelName)
	if err != nil {
		return nil, err
	}
	key, err := c.apiKey(apiConfig)
	if err != nil {
		return nil, err
	}
	url, err := c.endpoint(apiConfig, c.URLSuffix.Rerank)
	if err != nil {
		return nil, err
	}

	topN := len(documents)
	if rerankConfig != nil && rerankConfig.TopN > 0 && rerankConfig.TopN < topN {
		topN = rerankConfig.TopN
	}
	reqBody := map[string]interface{}{
		"model":     model,
		"query":     query,
		"documents": documents,
		"top_n":     topN,
	}
	resp, err := c.postJSON(url, key, reqBody)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("Cohere rerank API error: status %d, body: %s", resp.StatusCode, string(body))
	}

	var rerankResp struct {
		Results []struct {
			Index          int     `json:"index"`
			RelevanceScore float64 `json:"relevance_score"`
		} `json:"results"`
	}
	if err := json.Unmarshal(body, &rerankResp); err != nil {
		return nil, fmt.Errorf("failed to decode response: %w", err)
	}

	var rerankResponse RerankResponse
	for _, r := range rerankResp.Results {
		rerankResponse.Data = append(rerankResponse.Data, RerankResult{
			Index:          r.Index,
			RelevanceScore: r.RelevanceScore,
		})
	}
	return &rerankResponse, nil
}

// TranscribeAudio uploads the audio file at path *file as multipart form data
// (fields "model", any asrConfig.Params, then "file") and returns the "text"
// field of the JSON response.
func (c *CoHereModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) {
	if file == nil || *file == "" {
		return nil, fmt.Errorf("file is missing")
	}
	model, err := c.requireModel(modelName)
	if err != nil {
		return nil, err
	}
	key, err := c.apiKey(apiConfig)
	if err != nil {
		return nil, err
	}
	url, err := c.endpoint(apiConfig, c.URLSuffix.ASR)
	if err != nil {
		return nil, err
	}

	audioFile, err := os.Open(*file)
	if err != nil {
		return nil, fmt.Errorf("failed to open audio file: %w", err)
	}
	defer audioFile.Close()

	var body bytes.Buffer
	writer := multipart.NewWriter(&body)

	if err := writer.WriteField("model", model); err != nil {
		return nil, fmt.Errorf("failed to write model name: %w", err)
	}

	// Extra params are written as plain form fields.
	if asrConfig != nil && asrConfig.Params != nil {
		for fieldName, value := range asrConfig.Params {
			var val string
			switch v := value.(type) {
			case string:
				val = v
			case bool:
				val = strconv.FormatBool(v)
			case int:
				val = strconv.Itoa(v)
			case int64:
				val = strconv.FormatInt(v, 10)
			case float32:
				val = strconv.FormatFloat(float64(v), 'f', -1, 32)
			case float64:
				val = strconv.FormatFloat(v, 'f', -1, 64)
			default:
				val = fmt.Sprintf("%v", v)
			}
			if err := writer.WriteField(fieldName, val); err != nil {
				return nil, fmt.Errorf("failed to write field %s: %w", fieldName, err)
			}
		}
	}

	// All form fields (model, language) must appear before the file part in the multipart body.
	part, err := writer.CreateFormFile("file", filepath.Base(*file))
	if err != nil {
		return nil, fmt.Errorf("failed to create form file: %w", err)
	}
	if _, err := io.Copy(part, audioFile); err != nil {
		return nil, fmt.Errorf("failed to copy audio file: %w", err)
	}
	if err := writer.Close(); err != nil {
		return nil, fmt.Errorf("failed to close writer: %w", err)
	}

	req, err := http.NewRequest(http.MethodPost, url, &body)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+key)
	req.Header.Set("Content-Type", writer.FormDataContentType())

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("Cohere ASR API error: status %d, body: %s", resp.StatusCode, string(respBody))
	}

	var result struct {
		Text string `json:"text"`
	}
	if err := json.Unmarshal(respBody, &result); err != nil {
		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
	}
	return &ASRResponse{Text: result.Text}, nil
}

// TranscribeAudioWithSender is not supported by this driver.
func (c *CoHereModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error {
	return fmt.Errorf("%s, no such method", c.Name())
}

// AudioSpeech is not supported by this driver.
func (c *CoHereModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) {
	return nil, fmt.Errorf("%s, no such method", c.Name())
}

// AudioSpeechWithSender is not supported by this driver.
func (c *CoHereModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error {
	return fmt.Errorf("%s, no such method", c.Name())
}

// OCRFile is not supported by this driver.
func (c *CoHereModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) {
	return nil, fmt.Errorf("%s, no such method", c.Name())
}

// ParseFile is not supported by this driver.
func (c *CoHereModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) {
	return nil, fmt.Errorf("%s, no such method", c.Name())
}

// ListModels returns the "name" of every entry in the response's "models"
// array. The Authorization header is attached only when a non-empty key is
// supplied.
// NOTE(review): Cohere's list-models endpoint appears to be paginated
// (next_page_token); only the first page is read here — confirm sufficiency.
func (c *CoHereModel) ListModels(apiConfig *APIConfig) ([]string, error) {
	url, err := c.endpoint(apiConfig, c.URLSuffix.Models)
	if err != nil {
		return nil, err
	}
	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Accept", "application/json")
	if key, err := c.apiKey(apiConfig); err == nil {
		req.Header.Set("Authorization", "Bearer "+key)
	}

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response body: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("Cohere API request failed with status %d: %s", resp.StatusCode, string(body))
	}

	var result map[string]interface{}
	if err := json.Unmarshal(body, &result); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}
	modelsRaw, ok := result["models"].([]interface{})
	if !ok {
		return nil, fmt.Errorf("failed to find 'models' array in response")
	}

	models := make([]string, 0, len(modelsRaw))
	for _, entry := range modelsRaw {
		modelMap, ok := entry.(map[string]interface{})
		if !ok {
			continue
		}
		if name, ok := modelMap["name"].(string); ok {
			models = append(models, name)
		}
	}
	return models, nil
}

// Balance is not supported by this driver.
func (c *CoHereModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) {
	return nil, fmt.Errorf("%s, no such method", c.Name())
}

// CheckConnection verifies connectivity by listing models.
func (c *CoHereModel) CheckConnection(apiConfig *APIConfig) error {
	_, err := c.ListModels(apiConfig)
	return err
}

// ListTasks is not supported by this driver.
func (c *CoHereModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) {
	return nil, fmt.Errorf("%s, no such method", c.Name())
}

// ShowTask is not supported by this driver.
func (c *CoHereModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) {
	return nil, fmt.Errorf("%s, no such method", c.Name())
}