diff --git a/conf/models/zhipu-ai.json b/conf/models/zhipu-ai.json index b10f18b5d..2587b82f1 100644 --- a/conf/models/zhipu-ai.json +++ b/conf/models/zhipu-ai.json @@ -9,6 +9,8 @@ "async_result": "async-result", "embedding": "embeddings", "rerank": "rerank", + "asr": "audio/transcriptions", + "tts": "audio/speech", "files": "files", "models": "models" }, @@ -268,4 +270,4 @@ ] } ] -} \ No newline at end of file +} diff --git a/internal/entity/models/zhipu-ai.go b/internal/entity/models/zhipu-ai.go index d90d63559..cd0bf86fc 100644 --- a/internal/entity/models/zhipu-ai.go +++ b/internal/entity/models/zhipu-ai.go @@ -22,7 +22,10 @@ import ( "encoding/json" "fmt" "io" + "mime/multipart" "net/http" + "os" + "path/filepath" "ragflow/internal/common" "strings" "time" @@ -668,8 +671,129 @@ func (z *ZhipuAIModel) Rerank(modelName *string, query string, documents []strin } // TranscribeAudio transcribe audio -func (o *ZhipuAIModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { - return nil, fmt.Errorf("%s, no such method", o.Name()) +func (z *ZhipuAIModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } + if file == nil || *file == "" { + return nil, fmt.Errorf("file is required") + } + if z.URLSuffix.ASR == "" { + return nil, fmt.Errorf("zhipu-ai: ASR URL suffix is not configured") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + baseURL, ok := z.BaseURL[region] + if !ok || baseURL == "" { + return nil, fmt.Errorf("zhipu-ai: no base URL configured for region %q", region) + } + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), strings.TrimLeft(z.URLSuffix.ASR, "/")) + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + if err := writer.WriteField("model", *modelName); err != nil { + return nil, fmt.Errorf("failed to write model field: %w", err) + } + if err := writer.WriteField("stream", "false"); err != nil { + return nil, fmt.Errorf("failed to write stream field: %w", err) + } + if err := writeZhipuASRParams(writer, asrConfig); err != nil { + return nil, err + } + + audioFile, err := os.Open(*file) + if err != nil { + return nil, fmt.Errorf("failed to open audio file: %w", err) + } + defer audioFile.Close() + + part, err := writer.CreateFormFile("file", filepath.Base(*file)) + if err != nil { + return nil, fmt.Errorf("failed to create multipart file: %w", err) + } + if _, err = io.Copy(part, audioFile); err != nil { + return nil, fmt.Errorf("failed to copy audio data: %w", err) + } + if err = writer.Close(); err != nil { + return nil, fmt.Errorf("failed to close multipart writer: %w", err) + } + + req, err := http.NewRequest(http.MethodPost, url, &body) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + req.Header.Set("Content-Type", writer.FormDataContentType()) + + resp, err := z.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("ZhipuAI ASR API error: %s, body: %s", resp.Status, string(respBody)) + } + + var result struct { + Text string `json:"text"` + } + if err = json.Unmarshal(respBody, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + return &ASRResponse{Text: result.Text}, nil +} + +func writeZhipuASRParams(writer *multipart.Writer, asrConfig *ASRConfig) error { + if asrConfig == nil || asrConfig.Params == nil { + return nil + } + for key, value := range asrConfig.Params { + switch key { + case "model", "stream", "file", "file_base64": + continue + } + if err := writeZhipuASRField(writer, key, value); err != nil { + return err + } + } + return nil +} + +func writeZhipuASRField(writer *multipart.Writer, key string, value interface{}) error { + switch v := value.(type) { + case nil: + return nil + case []string: + for _, item := range v { + if err := writer.WriteField(key, item); err != nil { + return fmt.Errorf("failed to write field %s: %w", key, err) + } + } + return nil + case []interface{}: + for _, item := range v { + if err := writer.WriteField(key, fmt.Sprint(item)); err != nil { + return fmt.Errorf("failed to write field %s: %w", key, err) + } + } + return nil + default: + if err := writer.WriteField(key, fmt.Sprint(v)); err != nil { + return fmt.Errorf("failed to write field %s: %w", key, err) + } + return nil + } } func (z *ZhipuAIModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { @@ -677,12 +801,135 @@ func (z *ZhipuAIModel) TranscribeAudioWithSender(modelName *string, file *string } // AudioSpeech convert text to audio -func (o *ZhipuAIModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { - return nil, fmt.Errorf("%s, no such method", o.Name()) +func (z *ZhipuAIModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + reqBody, url, err := z.buildTTSRequest(modelName, audioContent, apiConfig, ttsConfig, false) + if err != nil { + return nil, err + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := z.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("ZhipuAI TTS API error: %s, body: %s", resp.Status, string(body)) + } + + return &TTSResponse{Audio: body}, nil } func (z *ZhipuAIModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { - return fmt.Errorf("%s, no such method", z.Name()) + if sender == nil { + return fmt.Errorf("sender is required") + } + reqBody, url, err := z.buildTTSRequest(modelName, audioContent, apiConfig, ttsConfig, true) + if err != nil { + return err + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := z.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("ZhipuAI stream TTS API error: %s, body: %s", resp.Status, string(body)) + } + + buf := make([]byte, 32*1024) + for { + n, err := resp.Body.Read(buf) + if n > 0 { + chunk := string(buf[:n]) + if errSend := sender(&chunk, nil); errSend != nil { + return errSend + } + } + if err != nil { + if err == io.EOF { + break + } + return fmt.Errorf("error reading ZhipuAI binary audio stream: %w", err) + } + } + return nil +} + +func (z *ZhipuAIModel) buildTTSRequest(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, stream bool) (map[string]interface{}, string, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, "", fmt.Errorf("api key is required") + } + if modelName == nil || *modelName == "" { + return nil, "", fmt.Errorf("model name is required") + } + if audioContent == nil || *audioContent == "" { + return nil, "", fmt.Errorf("audio content is empty") + } + if z.URLSuffix.TTS == "" { + return nil, "", fmt.Errorf("zhipu-ai: TTS URL suffix is not configured") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + baseURL, ok := z.BaseURL[region] + if !ok || baseURL == "" { + return nil, "", fmt.Errorf("zhipu-ai: no base URL configured for region %q", region) + } + + reqBody := map[string]interface{}{ + "model": *modelName, + "input": *audioContent, + "stream": stream, + } + if ttsConfig != nil { + for key, value := range ttsConfig.Params { + switch key { + case "model", "input", "stream", "response_format": + continue + } + reqBody[key] = value + } + if ttsConfig.Format != "" { + reqBody["response_format"] = ttsConfig.Format + } + } + + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), strings.TrimLeft(z.URLSuffix.TTS, "/")) + return reqBody, url, nil } // OCRFile OCR file diff --git a/internal/entity/models/zhipu_ai_test.go b/internal/entity/models/zhipu_ai_test.go new file mode 100644 index 000000000..f68d6aa3c --- /dev/null +++ b/internal/entity/models/zhipu_ai_test.go @@ -0,0 +1,339 @@ +package models + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" +) + +func newZhipuAIForTest(baseURL string) *ZhipuAIModel { + return NewZhipuAIModel( + map[string]string{"default": baseURL}, + URLSuffix{ASR: "audio/transcriptions", TTS: "audio/speech"}, + ) +} + +func writeZhipuAITestAudio(t *testing.T) string { + t.Helper() + + file, err := os.CreateTemp(t.TempDir(), "speech-*.mp3") + if err != nil { + t.Fatalf("create temp audio: %v", err) + } + defer file.Close() + + if _, err = file.WriteString("fake audio"); err != nil { + t.Fatalf("write temp audio: %v", err) + } + return file.Name() +} + +func TestZhipuAITranscribeAudio(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("method=%s", r.Method) + return + } + if r.URL.Path != "/audio/transcriptions" { + t.Errorf("path=%s", r.URL.Path) + return + } + if got := r.Header.Get("Authorization"); got != "Bearer test-key" { + t.Errorf("Authorization=%q", got) + return + } + if got := r.Header.Get("Content-Type"); !strings.HasPrefix(got, "multipart/form-data") { + t.Errorf("Content-Type=%q", got) + return + } + + if err := r.ParseMultipartForm(1024 * 1024); err != nil { + t.Errorf("ParseMultipartForm: %v", err) + return + } + if got := r.FormValue("model"); got != "glm-asr-2512" { + t.Errorf("model=%q", got) + } + if got := r.MultipartForm.Value["model"]; len(got) != 1 { + t.Errorf("model values=%v", got) + } + if got := r.FormValue("stream"); got != "false" { + t.Errorf("stream=%q", got) + } + if got := r.MultipartForm.Value["stream"]; len(got) != 1 { + t.Errorf("stream values=%v", got) + } + if got := r.FormValue("prompt"); got != "previous transcript" { + t.Errorf("prompt=%q", got) + } + if got := r.FormValue("user_id"); got != "12345" { + t.Errorf("user_id=%q", got) + } + if got := r.MultipartForm.Value["hotwords"]; len(got) != 2 || got[0] != "RAGFlow" || got[1] != "ZhipuAI" { + t.Errorf("hotwords=%v", got) + } + if got := r.MultipartForm.Value["file"]; len(got) != 0 { + t.Errorf("file values=%v", got) + } + if _, _, err := r.FormFile("file"); err != nil { + t.Errorf("file field: %v", err) + return + } + + _ = json.NewEncoder(w).Encode(map[string]string{"text": "hello world"}) + })) + defer srv.Close() + + apiKey := "test-key" + modelName := "glm-asr-2512" + file := writeZhipuAITestAudio(t) + resp, err := newZhipuAIForTest(srv.URL).TranscribeAudio( + &modelName, + &file, + &APIConfig{ApiKey: &apiKey}, + &ASRConfig{Params: map[string]interface{}{ + "prompt": "previous transcript", + "hotwords": []string{"RAGFlow", "ZhipuAI"}, + "model": "ignored-model", + "stream": true, + "file": "ignored-file", + "user_id": 12345, + "nil_value": nil, + }}, + ) + if err != nil { + t.Fatalf("TranscribeAudio: %v", err) + } + if resp.Text != "hello world" { + t.Errorf("Text=%q", resp.Text) + } +} + +func TestZhipuAITranscribeAudioValidation(t *testing.T) { + apiKey := "test-key" + modelName := "glm-asr-2512" + file := "speech.mp3" + + tests := []struct { + name string + modelName *string + file *string + apiConfig *APIConfig + want string + }{ + {name: "missing api key", modelName: &modelName, file: &file, apiConfig: &APIConfig{}, want: "api key is required"}, + {name: "missing model", modelName: nil, file: &file, apiConfig: &APIConfig{ApiKey: &apiKey}, want: "model name is required"}, + {name: "missing file", modelName: &modelName, file: nil, apiConfig: &APIConfig{ApiKey: &apiKey}, want: "file is required"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := newZhipuAIForTest("http://unused").TranscribeAudio(tt.modelName, tt.file, tt.apiConfig, nil) + if err == nil || !strings.Contains(err.Error(), tt.want) { + t.Fatalf("error=%v, want %q", err, tt.want) + } + }) + } +} + +func TestZhipuAITranscribeAudioRequiresASRSuffix(t *testing.T) { + apiKey := "test-key" + modelName := "glm-asr-2512" + file := writeZhipuAITestAudio(t) + _, err := NewZhipuAIModel(map[string]string{"default": "http://unused"}, URLSuffix{}).TranscribeAudio( + &modelName, + &file, + &APIConfig{ApiKey: &apiKey}, + nil, + ) + if err == nil || !strings.Contains(err.Error(), "ASR URL suffix is not configured") { + t.Fatalf("error=%v", err) + } +} + +func TestZhipuAITranscribeAudioHTTPError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = io.WriteString(w, `{"error":"bad request"}`) + })) + defer srv.Close() + + apiKey := "test-key" + modelName := "glm-asr-2512" + file := writeZhipuAITestAudio(t) + _, err := newZhipuAIForTest(srv.URL).TranscribeAudio(&modelName, &file, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "ZhipuAI ASR API error: 400 Bad Request") { + t.Fatalf("error=%v", err) + } +} + +func TestZhipuAIAudioSpeech(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("method=%s", r.Method) + return + } + if r.URL.Path != "/audio/speech" { + t.Errorf("path=%s", r.URL.Path) + return + } + if got := r.Header.Get("Authorization"); got != "Bearer test-key" { + t.Errorf("Authorization=%q", got) + return + } + if got := r.Header.Get("Content-Type"); !strings.HasPrefix(got, "application/json") { + t.Errorf("Content-Type=%q", got) + return + } + + var body map[string]interface{} + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Errorf("Decode: %v", err) + return + } + if body["model"] != "glm-tts" { + t.Errorf("model=%v", body["model"]) + } + if body["input"] != "hello" { + t.Errorf("input=%v", body["input"]) + } + if body["stream"] != false { + t.Errorf("stream=%v", body["stream"]) + } + if body["voice"] != "zhipu" { + t.Errorf("voice=%v", body["voice"]) + } + if body["response_format"] != "mp3" { + t.Errorf("response_format=%v", body["response_format"]) + } + + _, _ = w.Write([]byte("audio-bytes")) + })) + defer srv.Close() + + apiKey := "test-key" + modelName := "glm-tts" + content := "hello" + resp, err := newZhipuAIForTest(srv.URL).AudioSpeech( + &modelName, + &content, + &APIConfig{ApiKey: &apiKey}, + &TTSConfig{Format: "mp3", Params: map[string]interface{}{ + "voice": "zhipu", + "model": "ignored-model", + "input": "ignored-input", + "stream": true, + "response_format": "wav", + }}, + ) + if err != nil { + t.Fatalf("AudioSpeech: %v", err) + } + if string(resp.Audio) != "audio-bytes" { + t.Errorf("Audio=%q", string(resp.Audio)) + } +} + +func TestZhipuAIAudioSpeechWithSender(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body map[string]interface{} + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Errorf("Decode: %v", err) + return + } + if body["stream"] != true { + t.Errorf("stream=%v", body["stream"]) + } + _, _ = w.Write([]byte("chunk-one")) + _, _ = w.Write([]byte("chunk-two")) + })) + defer srv.Close() + + apiKey := "test-key" + modelName := "glm-tts" + content := "hello" + var chunks []string + err := newZhipuAIForTest(srv.URL).AudioSpeechWithSender( + &modelName, + &content, + &APIConfig{ApiKey: &apiKey}, + nil, + func(content *string, reasoning *string) error { + if content != nil { + chunks = append(chunks, *content) + } + if reasoning != nil { + t.Errorf("reasoning=%q", *reasoning) + } + return nil + }, + ) + if err != nil { + t.Fatalf("AudioSpeechWithSender: %v", err) + } + if strings.Join(chunks, "") != "chunk-onechunk-two" { + t.Errorf("chunks=%q", strings.Join(chunks, "")) + } +} + +func TestZhipuAIAudioSpeechValidation(t *testing.T) { + apiKey := "test-key" + modelName := "glm-tts" + content := "hello" + + tests := []struct { + name string + modelName *string + audioContent *string + apiConfig *APIConfig + want string + }{ + {name: "missing api key", modelName: &modelName, audioContent: &content, apiConfig: &APIConfig{}, want: "api key is required"}, + {name: "missing model", modelName: nil, audioContent: &content, apiConfig: &APIConfig{ApiKey: &apiKey}, want: "model name is required"}, + {name: "missing content", modelName: &modelName, audioContent: nil, apiConfig: &APIConfig{ApiKey: &apiKey}, want: "audio content is empty"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := newZhipuAIForTest("http://unused").AudioSpeech(tt.modelName, tt.audioContent, tt.apiConfig, nil) + if err == nil || !strings.Contains(err.Error(), tt.want) { + t.Fatalf("error=%v, want %q", err, tt.want) + } + }) + } +} + +func TestZhipuAIAudioSpeechRequiresTTSSuffix(t *testing.T) { + apiKey := "test-key" + modelName := "glm-tts" + content := "hello" + _, err := NewZhipuAIModel(map[string]string{"default": "http://unused"}, URLSuffix{}).AudioSpeech( + &modelName, + &content, + &APIConfig{ApiKey: &apiKey}, + nil, + ) + if err == nil || !strings.Contains(err.Error(), "TTS URL suffix is not configured") { + t.Fatalf("error=%v", err) + } +} + +func TestZhipuAIAudioSpeechHTTPError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = io.WriteString(w, `{"error":"bad request"}`) + })) + defer srv.Close() + + apiKey := "test-key" + modelName := "glm-tts" + content := "hello" + _, err := newZhipuAIForTest(srv.URL).AudioSpeech(&modelName, &content, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "ZhipuAI TTS API error: 400 Bad Request") { + t.Fatalf("error=%v", err) + } +}