Go: implement ASR and TTS for Xinference (#15096)

### What problem does this PR solve?

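Implement ASR (audio transcription) and TTS (speech synthesis) for the Xinference provider. The change also adds a balance lookup for Novita and tidies the Moonshot and FishAudio providers.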

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
- [x] Refactoring
This commit is contained in:
Haruko386
2026-05-21 18:28:06 +08:00
committed by GitHub
parent 111cdc77b5
commit a725e114f9
7 changed files with 266 additions and 44 deletions

View File

@ -7,6 +7,7 @@
"chat": "openai/v1/chat/completions",
"models": "openai/v1/models",
"embedding": "openai/v1/embeddings",
"balance": "openapi/v1/billing/balance/detail",
"rerank": "openai/v1/rerank"
},
"class": "novita",

View File

@ -66,7 +66,6 @@ func (f *FishAudioModel) Rerank(modelName *string, query string, documents []str
// TranscribeAudio transcribe audio
func (f *FishAudioModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) {
if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
return nil, fmt.Errorf("FishAudio API key is missing")
}
@ -151,11 +150,7 @@ func (f *FishAudioModel) TranscribeAudio(modelName *string, file *string, apiCon
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf(
"FishAudio ASR error: %s - %s",
resp.Status,
string(respBody),
)
return nil, fmt.Errorf("FishAudio ASR error: %s - %s", resp.Status, string(respBody))
}
// result

View File

@ -210,7 +210,7 @@ func (k *MoonshotModel) ChatStreamlyWithSender(modelName string, messages []Mess
region = *apiConfig.Region
}
url := fmt.Sprintf("%s/chat/completions", k.BaseURL[region])
url := fmt.Sprintf("%s/%s", k.BaseURL[region], k.URLSuffix.Chat)
// Convert messages to API format
apiMessages := make([]map[string]interface{}, len(messages))
@ -228,38 +228,40 @@ func (k *MoonshotModel) ChatStreamlyWithSender(modelName string, messages []Mess
"stream": true,
}
if chatModelConfig.Stream != nil {
reqBody["stream"] = *chatModelConfig.Stream
}
if chatModelConfig != nil {
if chatModelConfig.Stream != nil {
reqBody["stream"] = *chatModelConfig.Stream
}
if chatModelConfig.MaxTokens != nil {
reqBody["max_tokens"] = *chatModelConfig.MaxTokens
}
if chatModelConfig.MaxTokens != nil {
reqBody["max_tokens"] = *chatModelConfig.MaxTokens
}
if chatModelConfig.Temperature != nil {
reqBody["temperature"] = *chatModelConfig.Temperature
}
if chatModelConfig.Temperature != nil {
reqBody["temperature"] = *chatModelConfig.Temperature
}
if chatModelConfig.DoSample != nil {
reqBody["do_sample"] = *chatModelConfig.DoSample
}
if chatModelConfig.DoSample != nil {
reqBody["do_sample"] = *chatModelConfig.DoSample
}
if chatModelConfig.TopP != nil {
reqBody["top_p"] = *chatModelConfig.TopP
}
if chatModelConfig.TopP != nil {
reqBody["top_p"] = *chatModelConfig.TopP
}
if chatModelConfig.Stop != nil {
reqBody["stop"] = *chatModelConfig.Stop
}
if chatModelConfig.Stop != nil {
reqBody["stop"] = *chatModelConfig.Stop
}
if chatModelConfig.Thinking != nil {
if *chatModelConfig.Thinking {
reqBody["thinking"] = map[string]interface{}{
"type": "enabled",
}
} else {
reqBody["thinking"] = map[string]interface{}{
"type": "disabled",
if chatModelConfig.Thinking != nil {
if *chatModelConfig.Thinking {
reqBody["thinking"] = map[string]interface{}{
"type": "enabled",
}
} else {
reqBody["thinking"] = map[string]interface{}{
"type": "disabled",
}
}
}
}
@ -364,7 +366,7 @@ func (z *MoonshotModel) Embed(modelName *string, texts []string, apiConfig *APIC
func (z *MoonshotModel) ListModels(apiConfig *APIConfig) ([]string, error) {
var region = "default"
if apiConfig.Region != nil {
if apiConfig.Region != nil && *apiConfig.Region != "" {
region = *apiConfig.Region
}
@ -419,9 +421,8 @@ func (z *MoonshotModel) ListModels(apiConfig *APIConfig) ([]string, error) {
}
func (z *MoonshotModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) {
var region = "default"
if apiConfig.Region != nil {
if apiConfig.Region != nil && *apiConfig.Region != "" {
region = *apiConfig.Region
}

View File

@ -24,6 +24,7 @@ import (
"fmt"
"io"
"net/http"
"strconv"
"strings"
"time"
)
@ -841,9 +842,72 @@ func (n *NovitaModel) Rerank(modelName *string, query string, documents []string
return &rerankResponse, nil
}
// Balance returns the remaining credit of the Novita account that owns the
// API key in apiConfig.
//
// It sends a GET request to the configured balance endpoint of the selected
// region ("default" when apiConfig.Region is nil or empty) and reads the
// "availableBalance" string field of the JSON response.
//
// The returned map has two keys: "balance" (the float64 parsed from
// "availableBalance") and "currency" (always "USD").
//
// An error is returned when the API key is missing, the request cannot be
// built or sent, the server answers with a non-200 status, or the response
// does not contain a parsable "availableBalance" string.
func (n *NovitaModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) {
	// Guard the pointers that are dereferenced below; without this a missing
	// key panics instead of producing an error.
	if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
		return nil, fmt.Errorf("Novita API key is missing")
	}

	region := "default"
	if apiConfig.Region != nil && *apiConfig.Region != "" {
		region = *apiConfig.Region
	}

	url := fmt.Sprintf("%s/%s", n.BaseURL[region], n.URLSuffix.Balance)

	// A balance lookup carries no payload, so the GET is sent without a body.
	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}

	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))

	resp, err := n.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
	}

	// Parse response
	var result map[string]interface{}
	if err = json.Unmarshal(body, &result); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}

	balanceInterface, exists := result["availableBalance"]
	if !exists || balanceInterface == nil {
		return nil, fmt.Errorf("missing 'availableBalance' in response. Raw body: %s", string(body))
	}

	balanceStr, ok := balanceInterface.(string)
	if !ok {
		return nil, fmt.Errorf("'availableBalance' is not a string. Raw body: %s", string(body))
	}

	// NOTE(review): the value is reported as-is under currency "USD"; confirm
	// against the Novita billing API whether "availableBalance" is expressed in
	// whole dollars or in a smaller unit that needs scaling.
	balance, err := strconv.ParseFloat(balanceStr, 64)
	if err != nil {
		return nil, fmt.Errorf("failed to parse 'availableBalance' as float: %w. Raw body: %s", err, string(body))
	}

	return map[string]interface{}{
		"balance":  balance,
		"currency": "USD",
	}, nil
}
func (n *NovitaModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) {

View File

@ -12,7 +12,6 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package models
@ -23,7 +22,11 @@ import (
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
@ -589,15 +592,166 @@ func (x *XinferenceModel) Rerank(modelName *string, query string, documents []st
}
// TranscribeAudio uploads the audio file at *file to the Xinference ASR
// endpoint as multipart/form-data and returns the transcribed text.
//
// Parameters:
//   - modelName: name of the ASR model, sent as the "model" form field.
//   - file: path of the local audio file to upload, sent as the "file" part.
//   - apiConfig: optional region and API key. The region defaults to
//     "default"; the Authorization header is only sent when a non-empty key
//     is present.
//   - asrConfig: optional extra parameters; every entry of asrConfig.Params is
//     stringified and sent as an additional form field.
//
// The response is expected to be a JSON object with a "text" field, which is
// returned in ASRResponse.Text. An error is returned when the file or model
// name is missing, the file cannot be read, the request fails, the server
// answers with a non-200 status, or the response is not valid JSON.
func (x *XinferenceModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) {
	if file == nil || *file == "" {
		return nil, fmt.Errorf("file is missing")
	}
	if modelName == nil || *modelName == "" {
		return nil, fmt.Errorf("model name is missing")
	}

	region := "default"
	if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" {
		region = *apiConfig.Region
	}

	url := fmt.Sprintf("%s/%s", x.BaseURL[region], x.URLSuffix.ASR)

	var body bytes.Buffer
	writer := multipart.NewWriter(&body)

	// audio file
	audioFile, err := os.Open(*file)
	if err != nil {
		return nil, fmt.Errorf("failed to open audio file: %w", err)
	}
	defer audioFile.Close()

	part, err := writer.CreateFormFile("file", filepath.Base(*file))
	if err != nil {
		return nil, fmt.Errorf("failed to create multipart file: %w", err)
	}

	if _, err = io.Copy(part, audioFile); err != nil {
		return nil, fmt.Errorf("failed to copy audio data: %w", err)
	}

	if err = writer.WriteField("model", *modelName); err != nil {
		return nil, fmt.Errorf("failed to write model name: %w", err)
	}

	// extra params: form fields are plain strings, so each value is converted
	// with the cheapest exact formatter available for its type.
	if asrConfig != nil && asrConfig.Params != nil {
		for key, value := range asrConfig.Params {
			var val string

			switch v := value.(type) {
			case string:
				val = v
			case bool:
				val = strconv.FormatBool(v)
			case int:
				val = strconv.Itoa(v)
			case float64:
				val = strconv.FormatFloat(v, 'f', -1, 64)
			default:
				val = fmt.Sprintf("%v", v)
			}

			if err := writer.WriteField(key, val); err != nil {
				return nil, fmt.Errorf("failed to write field %s: %w", key, err)
			}
		}
	}

	// Close writes the trailing multipart boundary; it must happen before the
	// buffer is handed to the request.
	if err := writer.Close(); err != nil {
		return nil, fmt.Errorf("failed to close multipart writer: %w", err)
	}

	// request
	req, err := http.NewRequest("POST", url, &body)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}

	// Only send credentials when a key was actually configured, instead of
	// dereferencing a nil pointer.
	if apiConfig != nil && apiConfig.ApiKey != nil && *apiConfig.ApiKey != "" {
		req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
	}
	req.Header.Set("Content-Type", writer.FormDataContentType())

	resp, err := x.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response body: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("Xinference ASR error: %s - %s", resp.Status, string(respBody))
	}

	// result
	var result struct {
		Text string `json:"text"`
	}

	if err := json.Unmarshal(respBody, &result); err != nil {
		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
	}

	return &ASRResponse{
		Text: result.Text,
	}, nil
}
// TranscribeAudioWithSender is not implemented for this provider; it always
// returns an error that names the provider and ignores all of its arguments.
func (x *XinferenceModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error {
	unsupported := fmt.Errorf("%s, no such method", x.Name())
	return unsupported
}
func (x *XinferenceModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, asrConfig *TTSConfig) (*TTSResponse, error) {
return nil, fmt.Errorf("%s, no such method", x.Name())
func (x *XinferenceModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) {
if audioContent == nil || *audioContent == "" {
return nil, fmt.Errorf("text content is missing")
}
var region = "default"
if apiConfig.Region != nil && *apiConfig.Region != "" {
region = *apiConfig.Region
}
url := fmt.Sprintf("%s/%s", x.BaseURL[region], x.URLSuffix.TTS)
reqBody := map[string]interface{}{
"model": *modelName,
"input": *audioContent,
}
if ttsConfig != nil && ttsConfig.Params != nil {
for key, value := range ttsConfig.Params {
reqBody[key] = value
}
}
if ttsConfig != nil && ttsConfig.Format != "" {
reqBody["format"] = ttsConfig.Format
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
resp, err := x.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("%s - %s", resp.Status, string(body))
}
return &TTSResponse{Audio: body}, nil
}
func (x *XinferenceModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error {

View File

@ -290,7 +290,11 @@ def test_retrieval_success_with_metadata_and_kg(monkeypatch):
}
monkeypatch.setattr(module.settings, "kg_retriever", _DummyKgRetriever())
monkeypatch.setattr(module.DocumentService, "get_by_id", lambda doc_id: (True, SimpleNamespace(meta_fields={"origin": f"meta-{doc_id}"})))
monkeypatch.setattr(
module.DocumentService,
"get_by_ids",
lambda doc_ids, cols=None: [SimpleNamespace(id=doc_id, meta_fields={"origin": f"meta-{doc_id}"}) for doc_id in doc_ids],
)
monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: [])
res = _run(inspect.unwrap(module.retrieval)("tenant-1"))

View File

@ -97,7 +97,10 @@ def _load_dify_retrieval(monkeypatch, *, kb, accessible, request_body, chunks=No
_stub(
monkeypatch,
"api.db.services.document_service",
DocumentService=SimpleNamespace(get_by_id=lambda _id: (True, SimpleNamespace(meta_fields={}))),
DocumentService=SimpleNamespace(
get_by_id=lambda _id: (True, SimpleNamespace(id=_id, meta_fields={})),
get_by_ids=lambda ids, cols=None: [SimpleNamespace(id=doc_id, meta_fields={}) for doc_id in ids],
),
)
_stub(
monkeypatch,