mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-24 01:47:36 +08:00
Go: implement ASR in ZhipuAI driver (#15134)
### What problem does this PR solve?

This PR implements ASR and TTS support for the ZhipuAI Go driver.

The ZhipuAI model config already advertises `glm-asr-2512` as an ASR model, but the Go driver returned `zhipu, no such method` from `TranscribeAudio`. This PR adds the documented audio transcription endpoint suffix and sends multipart transcription requests with `model`, `stream=false`, and `file` fields.

Per maintainer review, it also adds the ZhipuAI TTS endpoint suffix and implements `AudioSpeech` / `AudioSpeechWithSender` for `glm-tts`.

Closes #15133

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -9,6 +9,8 @@
|
||||
"async_result": "async-result",
|
||||
"embedding": "embeddings",
|
||||
"rerank": "rerank",
|
||||
"asr": "audio/transcriptions",
|
||||
"tts": "audio/speech",
|
||||
"files": "files",
|
||||
"models": "models"
|
||||
},
|
||||
@ -268,4 +270,4 @@
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@ -22,7 +22,10 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"ragflow/internal/common"
|
||||
"strings"
|
||||
"time"
|
||||
@ -668,8 +671,129 @@ func (z *ZhipuAIModel) Rerank(modelName *string, query string, documents []strin
|
||||
}
|
||||
|
||||
// TranscribeAudio transcribe audio
|
||||
func (o *ZhipuAIModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) {
|
||||
return nil, fmt.Errorf("%s, no such method", o.Name())
|
||||
func (z *ZhipuAIModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) {
|
||||
if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
|
||||
return nil, fmt.Errorf("api key is required")
|
||||
}
|
||||
if modelName == nil || *modelName == "" {
|
||||
return nil, fmt.Errorf("model name is required")
|
||||
}
|
||||
if file == nil || *file == "" {
|
||||
return nil, fmt.Errorf("file is required")
|
||||
}
|
||||
if z.URLSuffix.ASR == "" {
|
||||
return nil, fmt.Errorf("zhipu-ai: ASR URL suffix is not configured")
|
||||
}
|
||||
|
||||
region := "default"
|
||||
if apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
baseURL, ok := z.BaseURL[region]
|
||||
if !ok || baseURL == "" {
|
||||
return nil, fmt.Errorf("zhipu-ai: no base URL configured for region %q", region)
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), strings.TrimLeft(z.URLSuffix.ASR, "/"))
|
||||
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
if err := writer.WriteField("model", *modelName); err != nil {
|
||||
return nil, fmt.Errorf("failed to write model field: %w", err)
|
||||
}
|
||||
if err := writer.WriteField("stream", "false"); err != nil {
|
||||
return nil, fmt.Errorf("failed to write stream field: %w", err)
|
||||
}
|
||||
if err := writeZhipuASRParams(writer, asrConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
audioFile, err := os.Open(*file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open audio file: %w", err)
|
||||
}
|
||||
defer audioFile.Close()
|
||||
|
||||
part, err := writer.CreateFormFile("file", filepath.Base(*file))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create multipart file: %w", err)
|
||||
}
|
||||
if _, err = io.Copy(part, audioFile); err != nil {
|
||||
return nil, fmt.Errorf("failed to copy audio data: %w", err)
|
||||
}
|
||||
if err = writer.Close(); err != nil {
|
||||
return nil, fmt.Errorf("failed to close multipart writer: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, url, &body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
|
||||
resp, err := z.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("ZhipuAI ASR API error: %s, body: %s", resp.Status, string(respBody))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
if err = json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
return &ASRResponse{Text: result.Text}, nil
|
||||
}
|
||||
|
||||
func writeZhipuASRParams(writer *multipart.Writer, asrConfig *ASRConfig) error {
|
||||
if asrConfig == nil || asrConfig.Params == nil {
|
||||
return nil
|
||||
}
|
||||
for key, value := range asrConfig.Params {
|
||||
switch key {
|
||||
case "model", "stream", "file", "file_base64":
|
||||
continue
|
||||
}
|
||||
if err := writeZhipuASRField(writer, key, value); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeZhipuASRField writes one caller-supplied parameter to the multipart
// form under the name key.
//
// Encoding rules:
//   - nil writes nothing;
//   - []string writes one field per element, all sharing the same key;
//   - []interface{} writes one field per non-nil element, formatted with
//     fmt.Sprint (nil elements are skipped, matching the scalar nil case,
//     instead of being sent as the literal "<nil>");
//   - any other value is written once, formatted with fmt.Sprint.
//
// A failed write is returned wrapped with the offending key.
func writeZhipuASRField(writer *multipart.Writer, key string, value interface{}) error {
	switch v := value.(type) {
	case nil:
		return nil
	case []string:
		for _, item := range v {
			if err := writer.WriteField(key, item); err != nil {
				return fmt.Errorf("failed to write field %s: %w", key, err)
			}
		}
		return nil
	case []interface{}:
		for _, item := range v {
			if item == nil {
				// Same treatment as a top-level nil: nothing to send.
				continue
			}
			if err := writer.WriteField(key, fmt.Sprint(item)); err != nil {
				return fmt.Errorf("failed to write field %s: %w", key, err)
			}
		}
		return nil
	default:
		if err := writer.WriteField(key, fmt.Sprint(v)); err != nil {
			return fmt.Errorf("failed to write field %s: %w", key, err)
		}
		return nil
	}
}
|
||||
|
||||
func (z *ZhipuAIModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error {
|
||||
@ -677,12 +801,135 @@ func (z *ZhipuAIModel) TranscribeAudioWithSender(modelName *string, file *string
|
||||
}
|
||||
|
||||
// AudioSpeech convert text to audio
|
||||
func (o *ZhipuAIModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) {
|
||||
return nil, fmt.Errorf("%s, no such method", o.Name())
|
||||
func (z *ZhipuAIModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) {
|
||||
reqBody, url, err := z.buildTTSRequest(modelName, audioContent, apiConfig, ttsConfig, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
|
||||
resp, err := z.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("ZhipuAI TTS API error: %s, body: %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
return &TTSResponse{Audio: body}, nil
|
||||
}
|
||||
|
||||
func (z *ZhipuAIModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error {
|
||||
return fmt.Errorf("%s, no such method", z.Name())
|
||||
if sender == nil {
|
||||
return fmt.Errorf("sender is required")
|
||||
}
|
||||
reqBody, url, err := z.buildTTSRequest(modelName, audioContent, apiConfig, ttsConfig, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
|
||||
resp, err := z.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("ZhipuAI stream TTS API error: %s, body: %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
buf := make([]byte, 32*1024)
|
||||
for {
|
||||
n, err := resp.Body.Read(buf)
|
||||
if n > 0 {
|
||||
chunk := string(buf[:n])
|
||||
if errSend := sender(&chunk, nil); errSend != nil {
|
||||
return errSend
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return fmt.Errorf("error reading ZhipuAI binary audio stream: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (z *ZhipuAIModel) buildTTSRequest(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, stream bool) (map[string]interface{}, string, error) {
|
||||
if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
|
||||
return nil, "", fmt.Errorf("api key is required")
|
||||
}
|
||||
if modelName == nil || *modelName == "" {
|
||||
return nil, "", fmt.Errorf("model name is required")
|
||||
}
|
||||
if audioContent == nil || *audioContent == "" {
|
||||
return nil, "", fmt.Errorf("audio content is empty")
|
||||
}
|
||||
if z.URLSuffix.TTS == "" {
|
||||
return nil, "", fmt.Errorf("zhipu-ai: TTS URL suffix is not configured")
|
||||
}
|
||||
|
||||
region := "default"
|
||||
if apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
baseURL, ok := z.BaseURL[region]
|
||||
if !ok || baseURL == "" {
|
||||
return nil, "", fmt.Errorf("zhipu-ai: no base URL configured for region %q", region)
|
||||
}
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"model": *modelName,
|
||||
"input": *audioContent,
|
||||
"stream": stream,
|
||||
}
|
||||
if ttsConfig != nil {
|
||||
for key, value := range ttsConfig.Params {
|
||||
switch key {
|
||||
case "model", "input", "stream", "response_format":
|
||||
continue
|
||||
}
|
||||
reqBody[key] = value
|
||||
}
|
||||
if ttsConfig.Format != "" {
|
||||
reqBody["response_format"] = ttsConfig.Format
|
||||
}
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), strings.TrimLeft(z.URLSuffix.TTS, "/"))
|
||||
return reqBody, url, nil
|
||||
}
|
||||
|
||||
// OCRFile OCR file
|
||||
|
||||
339
internal/entity/models/zhipu_ai_test.go
Normal file
339
internal/entity/models/zhipu_ai_test.go
Normal file
@ -0,0 +1,339 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func newZhipuAIForTest(baseURL string) *ZhipuAIModel {
|
||||
return NewZhipuAIModel(
|
||||
map[string]string{"default": baseURL},
|
||||
URLSuffix{ASR: "audio/transcriptions", TTS: "audio/speech"},
|
||||
)
|
||||
}
|
||||
|
||||
// writeZhipuAITestAudio creates a temporary "speech-*.mp3" file inside
// t.TempDir(), fills it with the placeholder content "fake audio" and returns
// its path. Any failure aborts the calling test.
//
// The file is closed explicitly, and the Close error checked, before the path
// is handed out, so callers always receive a fully written and released file.
func writeZhipuAITestAudio(t *testing.T) string {
	t.Helper()

	file, err := os.CreateTemp(t.TempDir(), "speech-*.mp3")
	if err != nil {
		t.Fatalf("create temp audio: %v", err)
	}
	if _, err = file.WriteString("fake audio"); err != nil {
		// The write error is the one worth reporting; Close is only cleanup here.
		_ = file.Close()
		t.Fatalf("write temp audio: %v", err)
	}
	if err = file.Close(); err != nil {
		t.Fatalf("close temp audio: %v", err)
	}
	return file.Name()
}
|
||||
|
||||
func TestZhipuAITranscribeAudio(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("method=%s", r.Method)
|
||||
return
|
||||
}
|
||||
if r.URL.Path != "/audio/transcriptions" {
|
||||
t.Errorf("path=%s", r.URL.Path)
|
||||
return
|
||||
}
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer test-key" {
|
||||
t.Errorf("Authorization=%q", got)
|
||||
return
|
||||
}
|
||||
if got := r.Header.Get("Content-Type"); !strings.HasPrefix(got, "multipart/form-data") {
|
||||
t.Errorf("Content-Type=%q", got)
|
||||
return
|
||||
}
|
||||
|
||||
if err := r.ParseMultipartForm(1024 * 1024); err != nil {
|
||||
t.Errorf("ParseMultipartForm: %v", err)
|
||||
return
|
||||
}
|
||||
if got := r.FormValue("model"); got != "glm-asr-2512" {
|
||||
t.Errorf("model=%q", got)
|
||||
}
|
||||
if got := r.MultipartForm.Value["model"]; len(got) != 1 {
|
||||
t.Errorf("model values=%v", got)
|
||||
}
|
||||
if got := r.FormValue("stream"); got != "false" {
|
||||
t.Errorf("stream=%q", got)
|
||||
}
|
||||
if got := r.MultipartForm.Value["stream"]; len(got) != 1 {
|
||||
t.Errorf("stream values=%v", got)
|
||||
}
|
||||
if got := r.FormValue("prompt"); got != "previous transcript" {
|
||||
t.Errorf("prompt=%q", got)
|
||||
}
|
||||
if got := r.FormValue("user_id"); got != "12345" {
|
||||
t.Errorf("user_id=%q", got)
|
||||
}
|
||||
if got := r.MultipartForm.Value["hotwords"]; len(got) != 2 || got[0] != "RAGFlow" || got[1] != "ZhipuAI" {
|
||||
t.Errorf("hotwords=%v", got)
|
||||
}
|
||||
if got := r.MultipartForm.Value["file"]; len(got) != 0 {
|
||||
t.Errorf("file values=%v", got)
|
||||
}
|
||||
if _, _, err := r.FormFile("file"); err != nil {
|
||||
t.Errorf("file field: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"text": "hello world"})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
apiKey := "test-key"
|
||||
modelName := "glm-asr-2512"
|
||||
file := writeZhipuAITestAudio(t)
|
||||
resp, err := newZhipuAIForTest(srv.URL).TranscribeAudio(
|
||||
&modelName,
|
||||
&file,
|
||||
&APIConfig{ApiKey: &apiKey},
|
||||
&ASRConfig{Params: map[string]interface{}{
|
||||
"prompt": "previous transcript",
|
||||
"hotwords": []string{"RAGFlow", "ZhipuAI"},
|
||||
"model": "ignored-model",
|
||||
"stream": true,
|
||||
"file": "ignored-file",
|
||||
"user_id": 12345,
|
||||
"nil_value": nil,
|
||||
}},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("TranscribeAudio: %v", err)
|
||||
}
|
||||
if resp.Text != "hello world" {
|
||||
t.Errorf("Text=%q", resp.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestZhipuAITranscribeAudioValidation(t *testing.T) {
|
||||
apiKey := "test-key"
|
||||
modelName := "glm-asr-2512"
|
||||
file := "speech.mp3"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
modelName *string
|
||||
file *string
|
||||
apiConfig *APIConfig
|
||||
want string
|
||||
}{
|
||||
{name: "missing api key", modelName: &modelName, file: &file, apiConfig: &APIConfig{}, want: "api key is required"},
|
||||
{name: "missing model", modelName: nil, file: &file, apiConfig: &APIConfig{ApiKey: &apiKey}, want: "model name is required"},
|
||||
{name: "missing file", modelName: &modelName, file: nil, apiConfig: &APIConfig{ApiKey: &apiKey}, want: "file is required"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := newZhipuAIForTest("http://unused").TranscribeAudio(tt.modelName, tt.file, tt.apiConfig, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), tt.want) {
|
||||
t.Fatalf("error=%v, want %q", err, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestZhipuAITranscribeAudioRequiresASRSuffix(t *testing.T) {
|
||||
apiKey := "test-key"
|
||||
modelName := "glm-asr-2512"
|
||||
file := writeZhipuAITestAudio(t)
|
||||
_, err := NewZhipuAIModel(map[string]string{"default": "http://unused"}, URLSuffix{}).TranscribeAudio(
|
||||
&modelName,
|
||||
&file,
|
||||
&APIConfig{ApiKey: &apiKey},
|
||||
nil,
|
||||
)
|
||||
if err == nil || !strings.Contains(err.Error(), "ASR URL suffix is not configured") {
|
||||
t.Fatalf("error=%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestZhipuAITranscribeAudioHTTPError(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = io.WriteString(w, `{"error":"bad request"}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
apiKey := "test-key"
|
||||
modelName := "glm-asr-2512"
|
||||
file := writeZhipuAITestAudio(t)
|
||||
_, err := newZhipuAIForTest(srv.URL).TranscribeAudio(&modelName, &file, &APIConfig{ApiKey: &apiKey}, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "ZhipuAI ASR API error: 400 Bad Request") {
|
||||
t.Fatalf("error=%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestZhipuAIAudioSpeech(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("method=%s", r.Method)
|
||||
return
|
||||
}
|
||||
if r.URL.Path != "/audio/speech" {
|
||||
t.Errorf("path=%s", r.URL.Path)
|
||||
return
|
||||
}
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer test-key" {
|
||||
t.Errorf("Authorization=%q", got)
|
||||
return
|
||||
}
|
||||
if got := r.Header.Get("Content-Type"); !strings.HasPrefix(got, "application/json") {
|
||||
t.Errorf("Content-Type=%q", got)
|
||||
return
|
||||
}
|
||||
|
||||
var body map[string]interface{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
t.Errorf("Decode: %v", err)
|
||||
return
|
||||
}
|
||||
if body["model"] != "glm-tts" {
|
||||
t.Errorf("model=%v", body["model"])
|
||||
}
|
||||
if body["input"] != "hello" {
|
||||
t.Errorf("input=%v", body["input"])
|
||||
}
|
||||
if body["stream"] != false {
|
||||
t.Errorf("stream=%v", body["stream"])
|
||||
}
|
||||
if body["voice"] != "zhipu" {
|
||||
t.Errorf("voice=%v", body["voice"])
|
||||
}
|
||||
if body["response_format"] != "mp3" {
|
||||
t.Errorf("response_format=%v", body["response_format"])
|
||||
}
|
||||
|
||||
_, _ = w.Write([]byte("audio-bytes"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
apiKey := "test-key"
|
||||
modelName := "glm-tts"
|
||||
content := "hello"
|
||||
resp, err := newZhipuAIForTest(srv.URL).AudioSpeech(
|
||||
&modelName,
|
||||
&content,
|
||||
&APIConfig{ApiKey: &apiKey},
|
||||
&TTSConfig{Format: "mp3", Params: map[string]interface{}{
|
||||
"voice": "zhipu",
|
||||
"model": "ignored-model",
|
||||
"input": "ignored-input",
|
||||
"stream": true,
|
||||
"response_format": "wav",
|
||||
}},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("AudioSpeech: %v", err)
|
||||
}
|
||||
if string(resp.Audio) != "audio-bytes" {
|
||||
t.Errorf("Audio=%q", string(resp.Audio))
|
||||
}
|
||||
}
|
||||
|
||||
func TestZhipuAIAudioSpeechWithSender(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var body map[string]interface{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
t.Errorf("Decode: %v", err)
|
||||
return
|
||||
}
|
||||
if body["stream"] != true {
|
||||
t.Errorf("stream=%v", body["stream"])
|
||||
}
|
||||
_, _ = w.Write([]byte("chunk-one"))
|
||||
_, _ = w.Write([]byte("chunk-two"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
apiKey := "test-key"
|
||||
modelName := "glm-tts"
|
||||
content := "hello"
|
||||
var chunks []string
|
||||
err := newZhipuAIForTest(srv.URL).AudioSpeechWithSender(
|
||||
&modelName,
|
||||
&content,
|
||||
&APIConfig{ApiKey: &apiKey},
|
||||
nil,
|
||||
func(content *string, reasoning *string) error {
|
||||
if content != nil {
|
||||
chunks = append(chunks, *content)
|
||||
}
|
||||
if reasoning != nil {
|
||||
t.Errorf("reasoning=%q", *reasoning)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("AudioSpeechWithSender: %v", err)
|
||||
}
|
||||
if strings.Join(chunks, "") != "chunk-onechunk-two" {
|
||||
t.Errorf("chunks=%q", strings.Join(chunks, ""))
|
||||
}
|
||||
}
|
||||
|
||||
func TestZhipuAIAudioSpeechValidation(t *testing.T) {
|
||||
apiKey := "test-key"
|
||||
modelName := "glm-tts"
|
||||
content := "hello"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
modelName *string
|
||||
audioContent *string
|
||||
apiConfig *APIConfig
|
||||
want string
|
||||
}{
|
||||
{name: "missing api key", modelName: &modelName, audioContent: &content, apiConfig: &APIConfig{}, want: "api key is required"},
|
||||
{name: "missing model", modelName: nil, audioContent: &content, apiConfig: &APIConfig{ApiKey: &apiKey}, want: "model name is required"},
|
||||
{name: "missing content", modelName: &modelName, audioContent: nil, apiConfig: &APIConfig{ApiKey: &apiKey}, want: "audio content is empty"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := newZhipuAIForTest("http://unused").AudioSpeech(tt.modelName, tt.audioContent, tt.apiConfig, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), tt.want) {
|
||||
t.Fatalf("error=%v, want %q", err, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestZhipuAIAudioSpeechRequiresTTSSuffix(t *testing.T) {
|
||||
apiKey := "test-key"
|
||||
modelName := "glm-tts"
|
||||
content := "hello"
|
||||
_, err := NewZhipuAIModel(map[string]string{"default": "http://unused"}, URLSuffix{}).AudioSpeech(
|
||||
&modelName,
|
||||
&content,
|
||||
&APIConfig{ApiKey: &apiKey},
|
||||
nil,
|
||||
)
|
||||
if err == nil || !strings.Contains(err.Error(), "TTS URL suffix is not configured") {
|
||||
t.Fatalf("error=%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestZhipuAIAudioSpeechHTTPError(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = io.WriteString(w, `{"error":"bad request"}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
apiKey := "test-key"
|
||||
modelName := "glm-tts"
|
||||
content := "hello"
|
||||
_, err := newZhipuAIForTest(srv.URL).AudioSpeech(&modelName, &content, &APIConfig{ApiKey: &apiKey}, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "ZhipuAI TTS API error: 400 Bad Request") {
|
||||
t.Fatalf("error=%v", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user