Go: implement ASR and TTS in ZhipuAI driver (#15134)

### What problem does this PR solve?

This PR implements ASR and TTS support for the ZhipuAI Go driver.

The ZhipuAI model config already advertises `glm-asr-2512` as an ASR
model, but the Go driver returned `zhipu, no such method` from
`TranscribeAudio`. This adds the documented audio transcription endpoint
suffix and sends multipart transcription requests with `model`,
`stream=false`, and `file` fields.

Per maintainer review, this also adds the ZhipuAI TTS endpoint suffix
and implements `AudioSpeech` / `AudioSpeechWithSender` for `glm-tts`.

Closes #15133

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Jake Armstrong
2026-05-21 17:53:18 -10:00
committed by GitHub
parent b2053cc3c7
commit b2bf9155ed
3 changed files with 594 additions and 6 deletions

View File

@ -9,6 +9,8 @@
"async_result": "async-result",
"embedding": "embeddings",
"rerank": "rerank",
"asr": "audio/transcriptions",
"tts": "audio/speech",
"files": "files",
"models": "models"
},
@ -268,4 +270,4 @@
]
}
]
}
}

View File

@ -22,7 +22,10 @@ import (
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
"path/filepath"
"ragflow/internal/common"
"strings"
"time"
@ -668,8 +671,129 @@ func (z *ZhipuAIModel) Rerank(modelName *string, query string, documents []strin
}
// TranscribeAudio transcribe audio
func (o *ZhipuAIModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) {
return nil, fmt.Errorf("%s, no such method", o.Name())
// TranscribeAudio uploads the audio file at *file to the ZhipuAI transcription
// endpoint and returns the recognized text.
//
// The request is a multipart/form-data POST carrying the fixed fields "model"
// (*modelName) and "stream" ("false"), any extra parameters accepted by
// writeZhipuASRParams from asrConfig, and the file contents under "file". The
// endpoint is the base URL registered for apiConfig.Region (or "default" when
// no region is given) joined with the configured ASR URL suffix.
//
// An error is returned when the API key, model name or file path is missing,
// when the ASR suffix or the region's base URL is not configured, when the
// file cannot be opened or copied, when the HTTP exchange fails, when the
// server replies with a status other than 200, or when the reply is not
// parseable JSON.
func (z *ZhipuAIModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) {
	// Argument and configuration guards, checked in a fixed order.
	switch {
	case apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "":
		return nil, fmt.Errorf("api key is required")
	case modelName == nil || *modelName == "":
		return nil, fmt.Errorf("model name is required")
	case file == nil || *file == "":
		return nil, fmt.Errorf("file is required")
	case z.URLSuffix.ASR == "":
		return nil, fmt.Errorf("zhipu-ai: ASR URL suffix is not configured")
	}

	regionKey := "default"
	if r := apiConfig.Region; r != nil && *r != "" {
		regionKey = *r
	}
	// A missing map entry yields "", so one comparison covers both cases.
	base := z.BaseURL[regionKey]
	if base == "" {
		return nil, fmt.Errorf("zhipu-ai: no base URL configured for region %q", regionKey)
	}
	endpoint := strings.TrimSuffix(base, "/") + "/" + strings.TrimLeft(z.URLSuffix.ASR, "/")

	// Assemble the multipart payload in memory: fixed fields, optional
	// parameters, then the audio file itself.
	payload := new(bytes.Buffer)
	form := multipart.NewWriter(payload)
	if err := form.WriteField("model", *modelName); err != nil {
		return nil, fmt.Errorf("failed to write model field: %w", err)
	}
	if err := form.WriteField("stream", "false"); err != nil {
		return nil, fmt.Errorf("failed to write stream field: %w", err)
	}
	if err := writeZhipuASRParams(form, asrConfig); err != nil {
		return nil, err
	}

	src, err := os.Open(*file)
	if err != nil {
		return nil, fmt.Errorf("failed to open audio file: %w", err)
	}
	defer src.Close()

	filePart, err := form.CreateFormFile("file", filepath.Base(*file))
	if err != nil {
		return nil, fmt.Errorf("failed to create multipart file: %w", err)
	}
	if _, err = io.Copy(filePart, src); err != nil {
		return nil, fmt.Errorf("failed to copy audio data: %w", err)
	}
	// Closing the writer appends the terminating boundary; it must happen
	// before the buffer is handed to the request.
	if err = form.Close(); err != nil {
		return nil, fmt.Errorf("failed to close multipart writer: %w", err)
	}

	request, err := http.NewRequest(http.MethodPost, endpoint, payload)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	request.Header.Set("Authorization", "Bearer "+*apiConfig.ApiKey)
	request.Header.Set("Content-Type", form.FormDataContentType())

	response, err := z.httpClient.Do(request)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	defer response.Body.Close()

	raw, err := io.ReadAll(response.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response body: %w", err)
	}
	if response.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("ZhipuAI ASR API error: %s, body: %s", response.Status, string(raw))
	}

	// Only the "text" member of the reply is consumed.
	var parsed struct {
		Text string `json:"text"`
	}
	if err = json.Unmarshal(raw, &parsed); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}
	return &ASRResponse{Text: parsed.Text}, nil
}
// writeZhipuASRParams copies the caller-supplied ASR parameters into the
// multipart form via writeZhipuASRField.
//
// The keys "model", "stream", "file" and "file_base64" are skipped and never
// forwarded from Params. A nil asrConfig or an empty/nil Params map writes
// nothing. The first field-write error aborts the loop and is returned.
func writeZhipuASRParams(writer *multipart.Writer, asrConfig *ASRConfig) error {
	if asrConfig == nil {
		return nil
	}
	// Ranging over a nil map is a no-op, so no separate nil check is needed.
	for name, val := range asrConfig.Params {
		reserved := name == "model" || name == "stream" || name == "file" || name == "file_base64"
		if reserved {
			continue
		}
		if err := writeZhipuASRField(writer, name, val); err != nil {
			return err
		}
	}
	return nil
}
// writeZhipuASRField writes one optional ASR parameter to the multipart form.
//
// A nil value writes nothing. A []string or []interface{} value is written as
// repeated fields under the same key, one per element and in slice order. Any
// other value is written as a single field. Non-string values are rendered by
// formatZhipuASRValue so that whole numbers decoded from JSON (float64) are
// sent as plain integers rather than in exponent notation.
//
// It returns a wrapped error if the multipart writer rejects a field.
func writeZhipuASRField(writer *multipart.Writer, key string, value interface{}) error {
	var items []string
	switch v := value.(type) {
	case nil:
		return nil
	case []string:
		items = v
	case []interface{}:
		items = make([]string, 0, len(v))
		for _, item := range v {
			items = append(items, formatZhipuASRValue(item))
		}
	default:
		items = []string{formatZhipuASRValue(v)}
	}
	for _, item := range items {
		if err := writer.WriteField(key, item); err != nil {
			return fmt.Errorf("failed to write field %s: %w", key, err)
		}
	}
	return nil
}

// formatZhipuASRValue renders a scalar parameter value as form-field text.
//
// fmt.Sprint formats floats with %g, which turns 1000000 into "1e+06". Floats
// that hold a whole number within the exactly-representable integer range are
// therefore printed as integers; everything else keeps fmt.Sprint's output.
func formatZhipuASRValue(value interface{}) string {
	// Largest magnitude at which every integer is exactly representable in a
	// float64; it also keeps the int64 conversion below well defined.
	const maxExact = float64(1 << 53)

	var f float64
	switch v := value.(type) {
	case float64:
		f = v
	case float32:
		f = float64(v)
	default:
		return fmt.Sprint(value)
	}
	// NaN fails both range comparisons and falls through to fmt.Sprint.
	if f >= -maxExact && f <= maxExact && f == float64(int64(f)) {
		return fmt.Sprint(int64(f))
	}
	return fmt.Sprint(value)
}
func (z *ZhipuAIModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error {
@ -677,12 +801,135 @@ func (z *ZhipuAIModel) TranscribeAudioWithSender(modelName *string, file *string
}
// AudioSpeech convert text to audio
func (o *ZhipuAIModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) {
return nil, fmt.Errorf("%s, no such method", o.Name())
// AudioSpeech synthesizes speech for *audioContent with a single,
// non-streaming request and returns the complete response body as audio.
//
// Arguments are validated and the JSON payload and endpoint URL are produced
// by buildTTSRequest (called with stream=false). The payload is POSTed as
// application/json with a Bearer Authorization header. Any status other than
// 200 is reported as an error that includes the status line and response body.
func (z *ZhipuAIModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) {
	payload, endpoint, err := z.buildTTSRequest(modelName, audioContent, apiConfig, ttsConfig, false)
	if err != nil {
		return nil, err
	}

	encoded, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	request, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(encoded))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	request.Header.Set("Content-Type", "application/json")
	// buildTTSRequest has already rejected a nil or empty API key.
	request.Header.Set("Authorization", "Bearer "+*apiConfig.ApiKey)

	response, err := z.httpClient.Do(request)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	defer response.Body.Close()

	audio, err := io.ReadAll(response.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response body: %w", err)
	}
	if response.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("ZhipuAI TTS API error: %s, body: %s", response.Status, string(audio))
	}
	return &TTSResponse{Audio: audio}, nil
}
// AudioSpeechWithSender synthesizes speech for *audioContent and forwards the
// response body to sender as it arrives.
//
// The request payload and endpoint come from buildTTSRequest (called with
// stream=true) and are POSTed as application/json with a Bearer Authorization
// header. The response body is read in chunks of up to 32 KiB; each non-empty
// chunk is passed to sender as a string holding the raw bytes, with a nil
// second argument.
//
// It returns an error when sender is nil, when buildTTSRequest rejects the
// arguments, when the HTTP exchange fails, when the status is not 200, when
// sender returns an error (returned unchanged), or when reading the body fails
// before EOF.
func (z *ZhipuAIModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error {
	if sender == nil {
		return fmt.Errorf("sender is required")
	}
	reqBody, url, err := z.buildTTSRequest(modelName, audioContent, apiConfig, ttsConfig, true)
	if err != nil {
		return err
	}
	jsonData, err := json.Marshal(reqBody)
	if err != nil {
		return fmt.Errorf("failed to marshal request: %w", err)
	}
	req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonData))
	if err != nil {
		return fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	// buildTTSRequest has already rejected a nil or empty API key.
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
	resp, err := z.httpClient.Do(req)
	if err != nil {
		return fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		// Best effort: the body is only used to enrich the error message.
		body, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("ZhipuAI stream TTS API error: %s, body: %s", resp.Status, string(body))
	}
	buf := make([]byte, 32*1024)
	for {
		n, err := resp.Body.Read(buf)
		// A Read may return data together with an error, so deliver the
		// bytes before inspecting err.
		if n > 0 {
			chunk := string(buf[:n])
			if errSend := sender(&chunk, nil); errSend != nil {
				return errSend
			}
		}
		if err != nil {
			if err == io.EOF {
				break
			}
			return fmt.Errorf("error reading ZhipuAI binary audio stream: %w", err)
		}
	}
	return nil
}
// buildTTSRequest validates the TTS arguments and returns the JSON payload
// (as a map) together with the endpoint URL to POST it to.
//
// The payload always contains "model", "input" and "stream". Entries from
// ttsConfig.Params are copied in, except for the keys "model", "input",
// "stream" and "response_format", which callers cannot override through
// Params. A non-empty ttsConfig.Format is sent as "response_format". The URL
// is the base URL for apiConfig.Region (or "default") joined with the
// configured TTS suffix.
//
// It returns an error when the API key, model name or content is missing, or
// when the TTS suffix or the region's base URL is not configured.
func (z *ZhipuAIModel) buildTTSRequest(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, stream bool) (map[string]interface{}, string, error) {
	// Guards are evaluated top to bottom; the first failing one wins.
	switch {
	case apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "":
		return nil, "", fmt.Errorf("api key is required")
	case modelName == nil || *modelName == "":
		return nil, "", fmt.Errorf("model name is required")
	case audioContent == nil || *audioContent == "":
		return nil, "", fmt.Errorf("audio content is empty")
	case z.URLSuffix.TTS == "":
		return nil, "", fmt.Errorf("zhipu-ai: TTS URL suffix is not configured")
	}

	regionKey := "default"
	if r := apiConfig.Region; r != nil && *r != "" {
		regionKey = *r
	}
	// A missing map entry yields "", so one comparison covers both cases.
	base := z.BaseURL[regionKey]
	if base == "" {
		return nil, "", fmt.Errorf("zhipu-ai: no base URL configured for region %q", regionKey)
	}
	endpoint := strings.TrimSuffix(base, "/") + "/" + strings.TrimLeft(z.URLSuffix.TTS, "/")

	payload := map[string]interface{}{
		"model":  *modelName,
		"input":  *audioContent,
		"stream": stream,
	}
	if ttsConfig == nil {
		return payload, endpoint, nil
	}
	for name, val := range ttsConfig.Params {
		if name == "model" || name == "input" || name == "stream" || name == "response_format" {
			continue
		}
		payload[name] = val
	}
	if format := ttsConfig.Format; format != "" {
		payload["response_format"] = format
	}
	return payload, endpoint, nil
}
// OCRFile OCR file

View File

@ -0,0 +1,339 @@
package models
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
)
// newZhipuAIForTest builds a ZhipuAI driver whose "default" region points at
// baseURL and whose ASR and TTS suffixes are "audio/transcriptions" and
// "audio/speech".
func newZhipuAIForTest(baseURL string) *ZhipuAIModel {
	regions := map[string]string{"default": baseURL}
	suffixes := URLSuffix{ASR: "audio/transcriptions", TTS: "audio/speech"}
	return NewZhipuAIModel(regions, suffixes)
}
// writeZhipuAITestAudio creates a file named like "speech-*.mp3" inside the
// test's temporary directory, fills it with the placeholder bytes
// "fake audio", and returns its path. Any failure aborts the calling test.
func writeZhipuAITestAudio(t *testing.T) string {
	t.Helper()

	audio, err := os.CreateTemp(t.TempDir(), "speech-*.mp3")
	if err != nil {
		t.Fatalf("create temp audio: %v", err)
	}
	defer audio.Close()

	path := audio.Name()
	if _, writeErr := audio.WriteString("fake audio"); writeErr != nil {
		t.Fatalf("write temp audio: %v", writeErr)
	}
	return path
}
// TestZhipuAITranscribeAudio checks the happy path of TranscribeAudio against
// a stub server: the request must be a multipart POST to /audio/transcriptions
// with the Bearer key, exactly one "model" and one "stream" field holding the
// driver's values (not the caller's overrides), the forwarded optional
// parameters, and the audio as a file part; the "text" of the JSON reply must
// be returned.
func TestZhipuAITranscribeAudio(t *testing.T) {
	handler := func(w http.ResponseWriter, r *http.Request) {
		auth := r.Header.Get("Authorization")
		contentType := r.Header.Get("Content-Type")
		// Envelope checks: the first mismatch is reported and ends the handler.
		switch {
		case r.Method != http.MethodPost:
			t.Errorf("method=%s", r.Method)
			return
		case r.URL.Path != "/audio/transcriptions":
			t.Errorf("path=%s", r.URL.Path)
			return
		case auth != "Bearer test-key":
			t.Errorf("Authorization=%q", auth)
			return
		case !strings.HasPrefix(contentType, "multipart/form-data"):
			t.Errorf("Content-Type=%q", contentType)
			return
		}
		if err := r.ParseMultipartForm(1 << 20); err != nil {
			t.Errorf("ParseMultipartForm: %v", err)
			return
		}

		fields := r.MultipartForm.Value
		if model := r.FormValue("model"); model != "glm-asr-2512" {
			t.Errorf("model=%q", model)
		}
		if models := fields["model"]; len(models) != 1 {
			t.Errorf("model values=%v", models)
		}
		if stream := r.FormValue("stream"); stream != "false" {
			t.Errorf("stream=%q", stream)
		}
		if streams := fields["stream"]; len(streams) != 1 {
			t.Errorf("stream values=%v", streams)
		}
		if prompt := r.FormValue("prompt"); prompt != "previous transcript" {
			t.Errorf("prompt=%q", prompt)
		}
		if userID := r.FormValue("user_id"); userID != "12345" {
			t.Errorf("user_id=%q", userID)
		}
		if hot := fields["hotwords"]; len(hot) != 2 || hot[0] != "RAGFlow" || hot[1] != "ZhipuAI" {
			t.Errorf("hotwords=%v", hot)
		}
		// "file" must arrive only as a file part, never as a plain value.
		if plain := fields["file"]; len(plain) != 0 {
			t.Errorf("file values=%v", plain)
		}
		if _, _, err := r.FormFile("file"); err != nil {
			t.Errorf("file field: %v", err)
			return
		}
		_ = json.NewEncoder(w).Encode(map[string]string{"text": "hello world"})
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	key, model := "test-key", "glm-asr-2512"
	audioPath := writeZhipuAITestAudio(t)
	extra := map[string]interface{}{
		"prompt":    "previous transcript",
		"hotwords":  []string{"RAGFlow", "ZhipuAI"},
		"model":     "ignored-model",
		"stream":    true,
		"file":      "ignored-file",
		"user_id":   12345,
		"nil_value": nil,
	}

	driver := newZhipuAIForTest(srv.URL)
	got, err := driver.TranscribeAudio(&model, &audioPath, &APIConfig{ApiKey: &key}, &ASRConfig{Params: extra})
	if err != nil {
		t.Fatalf("TranscribeAudio: %v", err)
	}
	if got.Text != "hello world" {
		t.Errorf("Text=%q", got.Text)
	}
}
// TestZhipuAITranscribeAudioValidation verifies that TranscribeAudio rejects a
// missing API key, model name or file path with the expected message. The
// base URL is a placeholder; each case expects the validation error itself.
func TestZhipuAITranscribeAudioValidation(t *testing.T) {
	key, model, path := "test-key", "glm-asr-2512", "speech.mp3"
	withKey := func() *APIConfig { return &APIConfig{ApiKey: &key} }

	type validationCase struct {
		name      string
		modelName *string
		file      *string
		apiConfig *APIConfig
		want      string
	}
	cases := []validationCase{
		{"missing api key", &model, &path, &APIConfig{}, "api key is required"},
		{"missing model", nil, &path, withKey(), "model name is required"},
		{"missing file", &model, nil, withKey(), "file is required"},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			_, err := newZhipuAIForTest("http://unused").TranscribeAudio(tc.modelName, tc.file, tc.apiConfig, nil)
			if err != nil && strings.Contains(err.Error(), tc.want) {
				return
			}
			t.Fatalf("error=%v, want %q", err, tc.want)
		})
	}
}
// TestZhipuAITranscribeAudioRequiresASRSuffix verifies that a driver built
// with an empty URLSuffix refuses to transcribe and reports the missing ASR
// suffix.
func TestZhipuAITranscribeAudioRequiresASRSuffix(t *testing.T) {
	key, model := "test-key", "glm-asr-2512"
	audioPath := writeZhipuAITestAudio(t)

	driver := NewZhipuAIModel(map[string]string{"default": "http://unused"}, URLSuffix{})
	_, err := driver.TranscribeAudio(&model, &audioPath, &APIConfig{ApiKey: &key}, nil)
	if err != nil && strings.Contains(err.Error(), "ASR URL suffix is not configured") {
		return
	}
	t.Fatalf("error=%v", err)
}
// TestZhipuAITranscribeAudioHTTPError verifies that a 400 reply from the
// server is turned into an error that names the ASR API and the status line.
func TestZhipuAITranscribeAudioHTTPError(t *testing.T) {
	rejectAll := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusBadRequest)
		_, _ = io.WriteString(w, `{"error":"bad request"}`)
	}
	srv := httptest.NewServer(http.HandlerFunc(rejectAll))
	defer srv.Close()

	key, model := "test-key", "glm-asr-2512"
	audioPath := writeZhipuAITestAudio(t)

	driver := newZhipuAIForTest(srv.URL)
	_, err := driver.TranscribeAudio(&model, &audioPath, &APIConfig{ApiKey: &key}, nil)
	if err != nil && strings.Contains(err.Error(), "ZhipuAI ASR API error: 400 Bad Request") {
		return
	}
	t.Fatalf("error=%v", err)
}
func TestZhipuAIAudioSpeech(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("method=%s", r.Method)
return
}
if r.URL.Path != "/audio/speech" {
t.Errorf("path=%s", r.URL.Path)
return
}
if got := r.Header.Get("Authorization"); got != "Bearer test-key" {
t.Errorf("Authorization=%q", got)
return
}
if got := r.Header.Get("Content-Type"); !strings.HasPrefix(got, "application/json") {
t.Errorf("Content-Type=%q", got)
return
}
var body map[string]interface{}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
t.Errorf("Decode: %v", err)
return
}
if body["model"] != "glm-tts" {
t.Errorf("model=%v", body["model"])
}
if body["input"] != "hello" {
t.Errorf("input=%v", body["input"])
}
if body["stream"] != false {
t.Errorf("stream=%v", body["stream"])
}
if body["voice"] != "zhipu" {
t.Errorf("voice=%v", body["voice"])
}
if body["response_format"] != "mp3" {
t.Errorf("response_format=%v", body["response_format"])
}
_, _ = w.Write([]byte("audio-bytes"))
}))
defer srv.Close()
apiKey := "test-key"
modelName := "glm-tts"
content := "hello"
resp, err := newZhipuAIForTest(srv.URL).AudioSpeech(
&modelName,
&content,
&APIConfig{ApiKey: &apiKey},
&TTSConfig{Format: "mp3", Params: map[string]interface{}{
"voice": "zhipu",
"model": "ignored-model",
"input": "ignored-input",
"stream": true,
"response_format": "wav",
}},
)
if err != nil {
t.Fatalf("AudioSpeech: %v", err)
}
if string(resp.Audio) != "audio-bytes" {
t.Errorf("Audio=%q", string(resp.Audio))
}
}
// TestZhipuAIAudioSpeechWithSender checks the streaming TTS path: the request
// body must carry "stream": true, and the bytes written by the server must
// reach the sender callback as content (never as reasoning) and concatenate to
// the full payload.
func TestZhipuAIAudioSpeechWithSender(t *testing.T) {
	handler := func(w http.ResponseWriter, r *http.Request) {
		var payload map[string]interface{}
		if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
			t.Errorf("Decode: %v", err)
			return
		}
		if got := payload["stream"]; got != true {
			t.Errorf("stream=%v", got)
		}
		_, _ = w.Write([]byte("chunk-one"))
		_, _ = w.Write([]byte("chunk-two"))
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	key, model, text := "test-key", "glm-tts", "hello"

	// The callback may be invoked any number of times; only the concatenation
	// of the delivered pieces is asserted.
	var received []string
	collect := func(content *string, reasoning *string) error {
		if content != nil {
			received = append(received, *content)
		}
		if reasoning != nil {
			t.Errorf("reasoning=%q", *reasoning)
		}
		return nil
	}

	driver := newZhipuAIForTest(srv.URL)
	if err := driver.AudioSpeechWithSender(&model, &text, &APIConfig{ApiKey: &key}, nil, collect); err != nil {
		t.Fatalf("AudioSpeechWithSender: %v", err)
	}
	if joined := strings.Join(received, ""); joined != "chunk-onechunk-two" {
		t.Errorf("chunks=%q", joined)
	}
}
// TestZhipuAIAudioSpeechValidation verifies that AudioSpeech rejects a missing
// API key, model name or input text with the expected message. The base URL
// is a placeholder; each case expects the validation error itself.
func TestZhipuAIAudioSpeechValidation(t *testing.T) {
	key, model, text := "test-key", "glm-tts", "hello"
	withKey := func() *APIConfig { return &APIConfig{ApiKey: &key} }

	type validationCase struct {
		name         string
		modelName    *string
		audioContent *string
		apiConfig    *APIConfig
		want         string
	}
	cases := []validationCase{
		{"missing api key", &model, &text, &APIConfig{}, "api key is required"},
		{"missing model", nil, &text, withKey(), "model name is required"},
		{"missing content", &model, nil, withKey(), "audio content is empty"},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			_, err := newZhipuAIForTest("http://unused").AudioSpeech(tc.modelName, tc.audioContent, tc.apiConfig, nil)
			if err != nil && strings.Contains(err.Error(), tc.want) {
				return
			}
			t.Fatalf("error=%v, want %q", err, tc.want)
		})
	}
}
// TestZhipuAIAudioSpeechRequiresTTSSuffix verifies that a driver built with an
// empty URLSuffix refuses to synthesize speech and reports the missing TTS
// suffix.
func TestZhipuAIAudioSpeechRequiresTTSSuffix(t *testing.T) {
	key, model, text := "test-key", "glm-tts", "hello"

	driver := NewZhipuAIModel(map[string]string{"default": "http://unused"}, URLSuffix{})
	_, err := driver.AudioSpeech(&model, &text, &APIConfig{ApiKey: &key}, nil)
	if err != nil && strings.Contains(err.Error(), "TTS URL suffix is not configured") {
		return
	}
	t.Fatalf("error=%v", err)
}
// TestZhipuAIAudioSpeechHTTPError verifies that a 400 reply from the server is
// turned into an error that names the TTS API and the status line.
func TestZhipuAIAudioSpeechHTTPError(t *testing.T) {
	rejectAll := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusBadRequest)
		_, _ = io.WriteString(w, `{"error":"bad request"}`)
	}
	srv := httptest.NewServer(http.HandlerFunc(rejectAll))
	defer srv.Close()

	key, model, text := "test-key", "glm-tts", "hello"

	driver := newZhipuAIForTest(srv.URL)
	_, err := driver.AudioSpeech(&model, &text, &APIConfig{ApiKey: &key}, nil)
	if err != nil && strings.Contains(err.Error(), "ZhipuAI TTS API error: 400 Bad Request") {
		return
	}
	t.Fatalf("error=%v", err)
}