mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-08 08:07:21 +08:00
Go: implement ASR and TTS for Xinference (#15096)
### What problem does this PR solve? implement ASR and TTS for Xinference ### Type of change - [x] New Feature (non-breaking change which adds functionality) - [x] Refactoring
This commit is contained in:
@ -7,6 +7,7 @@
|
||||
"chat": "openai/v1/chat/completions",
|
||||
"models": "openai/v1/models",
|
||||
"embedding": "openai/v1/embeddings",
|
||||
"balance": "openapi/v1/billing/balance/detail",
|
||||
"rerank": "openai/v1/rerank"
|
||||
},
|
||||
"class": "novita",
|
||||
|
||||
@ -66,7 +66,6 @@ func (f *FishAudioModel) Rerank(modelName *string, query string, documents []str
|
||||
|
||||
// TranscribeAudio transcribe audio
|
||||
func (f *FishAudioModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) {
|
||||
|
||||
if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" {
|
||||
return nil, fmt.Errorf("FishAudio API key is missing")
|
||||
}
|
||||
@ -151,11 +150,7 @@ func (f *FishAudioModel) TranscribeAudio(modelName *string, file *string, apiCon
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf(
|
||||
"FishAudio ASR error: %s - %s",
|
||||
resp.Status,
|
||||
string(respBody),
|
||||
)
|
||||
return nil, fmt.Errorf("FishAudio ASR error: %s - %s", resp.Status, string(respBody))
|
||||
}
|
||||
|
||||
// result
|
||||
|
||||
@ -210,7 +210,7 @@ func (k *MoonshotModel) ChatStreamlyWithSender(modelName string, messages []Mess
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/chat/completions", k.BaseURL[region])
|
||||
url := fmt.Sprintf("%s/%s", k.BaseURL[region], k.URLSuffix.Chat)
|
||||
|
||||
// Convert messages to API format
|
||||
apiMessages := make([]map[string]interface{}, len(messages))
|
||||
@ -228,38 +228,40 @@ func (k *MoonshotModel) ChatStreamlyWithSender(modelName string, messages []Mess
|
||||
"stream": true,
|
||||
}
|
||||
|
||||
if chatModelConfig.Stream != nil {
|
||||
reqBody["stream"] = *chatModelConfig.Stream
|
||||
}
|
||||
if chatModelConfig != nil {
|
||||
if chatModelConfig.Stream != nil {
|
||||
reqBody["stream"] = *chatModelConfig.Stream
|
||||
}
|
||||
|
||||
if chatModelConfig.MaxTokens != nil {
|
||||
reqBody["max_tokens"] = *chatModelConfig.MaxTokens
|
||||
}
|
||||
if chatModelConfig.MaxTokens != nil {
|
||||
reqBody["max_tokens"] = *chatModelConfig.MaxTokens
|
||||
}
|
||||
|
||||
if chatModelConfig.Temperature != nil {
|
||||
reqBody["temperature"] = *chatModelConfig.Temperature
|
||||
}
|
||||
if chatModelConfig.Temperature != nil {
|
||||
reqBody["temperature"] = *chatModelConfig.Temperature
|
||||
}
|
||||
|
||||
if chatModelConfig.DoSample != nil {
|
||||
reqBody["do_sample"] = *chatModelConfig.DoSample
|
||||
}
|
||||
if chatModelConfig.DoSample != nil {
|
||||
reqBody["do_sample"] = *chatModelConfig.DoSample
|
||||
}
|
||||
|
||||
if chatModelConfig.TopP != nil {
|
||||
reqBody["top_p"] = *chatModelConfig.TopP
|
||||
}
|
||||
if chatModelConfig.TopP != nil {
|
||||
reqBody["top_p"] = *chatModelConfig.TopP
|
||||
}
|
||||
|
||||
if chatModelConfig.Stop != nil {
|
||||
reqBody["stop"] = *chatModelConfig.Stop
|
||||
}
|
||||
if chatModelConfig.Stop != nil {
|
||||
reqBody["stop"] = *chatModelConfig.Stop
|
||||
}
|
||||
|
||||
if chatModelConfig.Thinking != nil {
|
||||
if *chatModelConfig.Thinking {
|
||||
reqBody["thinking"] = map[string]interface{}{
|
||||
"type": "enabled",
|
||||
}
|
||||
} else {
|
||||
reqBody["thinking"] = map[string]interface{}{
|
||||
"type": "disabled",
|
||||
if chatModelConfig.Thinking != nil {
|
||||
if *chatModelConfig.Thinking {
|
||||
reqBody["thinking"] = map[string]interface{}{
|
||||
"type": "enabled",
|
||||
}
|
||||
} else {
|
||||
reqBody["thinking"] = map[string]interface{}{
|
||||
"type": "disabled",
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -364,7 +366,7 @@ func (z *MoonshotModel) Embed(modelName *string, texts []string, apiConfig *APIC
|
||||
|
||||
func (z *MoonshotModel) ListModels(apiConfig *APIConfig) ([]string, error) {
|
||||
var region = "default"
|
||||
if apiConfig.Region != nil {
|
||||
if apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
@ -419,9 +421,8 @@ func (z *MoonshotModel) ListModels(apiConfig *APIConfig) ([]string, error) {
|
||||
}
|
||||
|
||||
func (z *MoonshotModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) {
|
||||
|
||||
var region = "default"
|
||||
if apiConfig.Region != nil {
|
||||
if apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
|
||||
@ -24,6 +24,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@ -841,9 +842,72 @@ func (n *NovitaModel) Rerank(modelName *string, query string, documents []string
|
||||
return &rerankResponse, nil
|
||||
}
|
||||
|
||||
// Balance is not exposed by the Novita API.
|
||||
// Balance Get remaining credit
|
||||
func (n *NovitaModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) {
|
||||
return nil, fmt.Errorf("%s, no such method", n.Name())
|
||||
var region = "default"
|
||||
if apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/%s", n.BaseURL[region], n.URLSuffix.Balance)
|
||||
|
||||
// Build request body
|
||||
reqBody := map[string]interface{}{}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
|
||||
resp, err := n.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Parse response
|
||||
var result map[string]interface{}
|
||||
if err = json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
balanceInterface, exists := result["availableBalance"]
|
||||
if !exists || balanceInterface == nil {
|
||||
return nil, fmt.Errorf("missing 'availableBalance' in response. Raw body: %s", string(body))
|
||||
}
|
||||
|
||||
balanceStr, ok := balanceInterface.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("'availableBalance' is not a string. Raw body: %s", string(body))
|
||||
}
|
||||
balance, err := strconv.ParseFloat(balanceStr, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse 'availableBalance' as float: %w. Raw body: %s", err, string(body))
|
||||
}
|
||||
|
||||
var response = map[string]interface{}{
|
||||
"balance": balance,
|
||||
"currency": "USD",
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (n *NovitaModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) {
|
||||
|
||||
@ -12,7 +12,6 @@
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package models
|
||||
|
||||
@ -23,7 +22,11 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@ -589,15 +592,166 @@ func (x *XinferenceModel) Rerank(modelName *string, query string, documents []st
|
||||
}
|
||||
|
||||
func (x *XinferenceModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) {
|
||||
return nil, fmt.Errorf("%s, no such method", x.Name())
|
||||
if file == nil || *file == "" {
|
||||
return nil, fmt.Errorf("file is missing")
|
||||
}
|
||||
|
||||
region := "default"
|
||||
if apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/%s", x.BaseURL[region], x.URLSuffix.ASR)
|
||||
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
|
||||
// audio file
|
||||
audioFile, err := os.Open(*file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open audio file: %w", err)
|
||||
}
|
||||
defer audioFile.Close()
|
||||
|
||||
part, err := writer.CreateFormFile("file", filepath.Base(*file))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create multipart file: %w", err)
|
||||
}
|
||||
|
||||
if _, err = io.Copy(part, audioFile); err != nil {
|
||||
return nil, fmt.Errorf("failed to copy audio data: %w", err)
|
||||
}
|
||||
|
||||
if err = writer.WriteField("model", *modelName); err != nil {
|
||||
return nil, fmt.Errorf("failed to write model name: %w", err)
|
||||
}
|
||||
|
||||
// extra params
|
||||
if asrConfig != nil && asrConfig.Params != nil {
|
||||
for key, value := range asrConfig.Params {
|
||||
|
||||
var val string
|
||||
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
val = v
|
||||
case bool:
|
||||
val = strconv.FormatBool(v)
|
||||
case int:
|
||||
val = strconv.Itoa(v)
|
||||
case float64:
|
||||
val = strconv.FormatFloat(v, 'f', -1, 64)
|
||||
default:
|
||||
val = fmt.Sprintf("%v", v)
|
||||
}
|
||||
|
||||
if err := writer.WriteField(key, val); err != nil {
|
||||
return nil, fmt.Errorf("failed to write field %s: %w", key, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := writer.Close(); err != nil {
|
||||
return nil, fmt.Errorf("failed to close multipart writer: %w", err)
|
||||
}
|
||||
|
||||
// request
|
||||
req, err := http.NewRequest("POST", url, &body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
|
||||
resp, err := x.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("FishAudio ASR error: %s - %s", resp.Status, string(respBody))
|
||||
}
|
||||
|
||||
// result
|
||||
var result struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
}
|
||||
|
||||
return &ASRResponse{
|
||||
Text: result.Text,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (x *XinferenceModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error {
|
||||
return fmt.Errorf("%s, no such method", x.Name())
|
||||
}
|
||||
|
||||
func (x *XinferenceModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, asrConfig *TTSConfig) (*TTSResponse, error) {
|
||||
return nil, fmt.Errorf("%s, no such method", x.Name())
|
||||
func (x *XinferenceModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) {
|
||||
if audioContent == nil || *audioContent == "" {
|
||||
return nil, fmt.Errorf("text content is missing")
|
||||
}
|
||||
|
||||
var region = "default"
|
||||
if apiConfig.Region != nil && *apiConfig.Region != "" {
|
||||
region = *apiConfig.Region
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/%s", x.BaseURL[region], x.URLSuffix.TTS)
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"model": *modelName,
|
||||
"input": *audioContent,
|
||||
}
|
||||
|
||||
if ttsConfig != nil && ttsConfig.Params != nil {
|
||||
for key, value := range ttsConfig.Params {
|
||||
reqBody[key] = value
|
||||
}
|
||||
}
|
||||
if ttsConfig != nil && ttsConfig.Format != "" {
|
||||
reqBody["format"] = ttsConfig.Format
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey))
|
||||
|
||||
resp, err := x.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("%s - %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
return &TTSResponse{Audio: body}, nil
|
||||
}
|
||||
|
||||
func (x *XinferenceModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error {
|
||||
|
||||
@ -290,7 +290,11 @@ def test_retrieval_success_with_metadata_and_kg(monkeypatch):
|
||||
}
|
||||
|
||||
monkeypatch.setattr(module.settings, "kg_retriever", _DummyKgRetriever())
|
||||
monkeypatch.setattr(module.DocumentService, "get_by_id", lambda doc_id: (True, SimpleNamespace(meta_fields={"origin": f"meta-{doc_id}"})))
|
||||
monkeypatch.setattr(
|
||||
module.DocumentService,
|
||||
"get_by_ids",
|
||||
lambda doc_ids, cols=None: [SimpleNamespace(id=doc_id, meta_fields={"origin": f"meta-{doc_id}"}) for doc_id in doc_ids],
|
||||
)
|
||||
monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: [])
|
||||
|
||||
res = _run(inspect.unwrap(module.retrieval)("tenant-1"))
|
||||
|
||||
@ -97,7 +97,10 @@ def _load_dify_retrieval(monkeypatch, *, kb, accessible, request_body, chunks=No
|
||||
_stub(
|
||||
monkeypatch,
|
||||
"api.db.services.document_service",
|
||||
DocumentService=SimpleNamespace(get_by_id=lambda _id: (True, SimpleNamespace(meta_fields={}))),
|
||||
DocumentService=SimpleNamespace(
|
||||
get_by_id=lambda _id: (True, SimpleNamespace(id=_id, meta_fields={})),
|
||||
get_by_ids=lambda ids, cols=None: [SimpleNamespace(id=doc_id, meta_fields={}) for doc_id in ids],
|
||||
),
|
||||
)
|
||||
_stub(
|
||||
monkeypatch,
|
||||
|
||||
Reference in New Issue
Block a user