mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-22 08:58:23 +08:00
### What problem does this PR solve? Support stream for multimodal chat ### Type of change - [x] Refactoring
906 lines
20 KiB
Go
906 lines
20 KiB
Go
//
|
|
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
//
|
|
|
|
package handler
|
|
|
|
import (
|
|
"fmt"
|
|
"net/http"
|
|
"ragflow/internal/common"
|
|
"ragflow/internal/dao"
|
|
"ragflow/internal/entity/models"
|
|
"ragflow/internal/service"
|
|
"strings"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
// ProviderHandler provider handler
|
|
type ProviderHandler struct {
|
|
userService *service.UserService
|
|
modelProviderService *service.ModelProviderService
|
|
userTenantDAO *dao.UserTenantDAO
|
|
}
|
|
|
|
// NewProviderHandler create provider handler
|
|
func NewProviderHandler(userService *service.UserService, modelProviderService *service.ModelProviderService) *ProviderHandler {
|
|
return &ProviderHandler{
|
|
userService: userService,
|
|
modelProviderService: modelProviderService,
|
|
userTenantDAO: dao.NewUserTenantDAO(),
|
|
}
|
|
}
|
|
|
|
func (h *ProviderHandler) ListProviders(c *gin.Context) {
|
|
|
|
keywords := ""
|
|
if queryKeywords := c.Query("available"); queryKeywords != "" {
|
|
keywords = queryKeywords
|
|
}
|
|
|
|
// convert keywords to small case
|
|
keywords = strings.ToLower(keywords)
|
|
if keywords == "true" {
|
|
// list pool providers
|
|
providers, err := dao.GetModelProviderManager().ListProviders()
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": common.CodeNotFound,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
for _, provider := range providers {
|
|
delete(provider, "url_suffix")
|
|
delete(provider, "tags")
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": 0,
|
|
"message": "success",
|
|
"data": providers,
|
|
})
|
|
return
|
|
}
|
|
|
|
userID := c.GetString("user_id")
|
|
|
|
// list tenant providers
|
|
providers, errorCode, err := h.modelProviderService.ListProvidersOfTenant(userID)
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": errorCode,
|
|
"message": err.Error(),
|
|
"data": nil,
|
|
})
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": 0,
|
|
"message": "success",
|
|
"data": providers,
|
|
})
|
|
return
|
|
}
|
|
|
|
type AddProviderRequest struct {
|
|
ProviderName string `json:"provider_name" binding:"required"`
|
|
}
|
|
|
|
func (h *ProviderHandler) AddProvider(c *gin.Context) {
|
|
|
|
var req AddProviderRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": common.CodeBadRequest,
|
|
"message": err.Error(),
|
|
"data": false,
|
|
})
|
|
return
|
|
}
|
|
|
|
userID := c.GetString("user_id")
|
|
|
|
errorCode, err := h.modelProviderService.AddModelProvider(req.ProviderName, userID)
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": errorCode,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": 0,
|
|
"message": "success",
|
|
})
|
|
}
|
|
|
|
func (h *ProviderHandler) DeleteProvider(c *gin.Context) {
|
|
providerName := c.Param("provider_name")
|
|
if providerName == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"code": 400,
|
|
"message": "Provider name is required",
|
|
})
|
|
return
|
|
}
|
|
|
|
userID := c.GetString("user_id")
|
|
|
|
errorCode, err := h.modelProviderService.DeleteModelProvider(providerName, userID)
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": errorCode,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": 0,
|
|
"message": "success",
|
|
})
|
|
}
|
|
|
|
func (h *ProviderHandler) ShowProvider(c *gin.Context) {
|
|
providerName := c.Param("provider_name")
|
|
if providerName == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"code": 400,
|
|
"message": "Provider name is required",
|
|
})
|
|
return
|
|
}
|
|
|
|
provider, err := dao.GetModelProviderManager().GetProviderByName(providerName)
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": common.CodeNotFound,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": 0,
|
|
"message": "success",
|
|
"data": provider,
|
|
})
|
|
}
|
|
|
|
func (h *ProviderHandler) ListModels(c *gin.Context) {
|
|
providerName := c.Param("provider_name")
|
|
if providerName == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"code": 400,
|
|
"message": "Provider name is required",
|
|
})
|
|
return
|
|
}
|
|
providerModels, err := dao.GetModelProviderManager().ListModels(providerName)
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": common.CodeNotFound,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": 0,
|
|
"message": "success",
|
|
"data": providerModels,
|
|
})
|
|
}
|
|
|
|
func (h *ProviderHandler) ShowModel(c *gin.Context) {
|
|
providerName := c.Param("provider_name")
|
|
if providerName == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"code": 400,
|
|
"message": "Provider name is required",
|
|
})
|
|
return
|
|
}
|
|
modelName := c.Param("model_name")
|
|
if modelName == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"code": 400,
|
|
"message": "Model name is required",
|
|
})
|
|
return
|
|
}
|
|
model, err := dao.GetModelProviderManager().GetModelByName(providerName, modelName)
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": common.CodeNotFound,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": 0,
|
|
"message": "success",
|
|
"data": model,
|
|
})
|
|
}
|
|
|
|
type CreateProviderInstanceRequest struct {
|
|
InstanceName string `json:"instance_name" binding:"required"`
|
|
APIKey string `json:"api_key"`
|
|
BaseURL string `json:"base_url"`
|
|
Region string `json:"region"`
|
|
}
|
|
|
|
func (h *ProviderHandler) CreateProviderInstance(c *gin.Context) {
|
|
providerName := c.Param("provider_name")
|
|
if providerName == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"code": 400,
|
|
"message": "Provider name is required",
|
|
})
|
|
return
|
|
}
|
|
|
|
var req CreateProviderInstanceRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": common.CodeBadRequest,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
// Check if instance name is "default"
|
|
if req.InstanceName == "default" {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": common.CodeBadRequest,
|
|
"message": "Instance name cannot be 'default'",
|
|
})
|
|
return
|
|
}
|
|
|
|
userID := c.GetString("user_id")
|
|
|
|
_, err := h.modelProviderService.CreateProviderInstance(providerName, req.InstanceName, req.APIKey, req.BaseURL, req.Region, userID)
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": common.CodeServerError,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": 0,
|
|
"message": "success",
|
|
})
|
|
}
|
|
|
|
func (h *ProviderHandler) ListProviderInstances(c *gin.Context) {
|
|
providerName := c.Param("provider_name")
|
|
if providerName == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"code": 400,
|
|
"message": "Provider name is required",
|
|
})
|
|
return
|
|
}
|
|
|
|
userID := c.GetString("user_id")
|
|
|
|
instances, errorCode, err := h.modelProviderService.ListProviderInstances(providerName, userID)
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": errorCode,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": 0,
|
|
"message": "success",
|
|
"data": instances,
|
|
})
|
|
}
|
|
|
|
func (h *ProviderHandler) ShowProviderInstance(c *gin.Context) {
|
|
providerName := c.Param("provider_name")
|
|
if providerName == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"code": 400,
|
|
"message": "Provider name is required",
|
|
})
|
|
return
|
|
}
|
|
|
|
instanceName := c.Param("instance_name")
|
|
if instanceName == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"code": 400,
|
|
"message": "Instance name is required",
|
|
})
|
|
return
|
|
}
|
|
|
|
userID := c.GetString("user_id")
|
|
|
|
// Get tenant ID from user
|
|
instance, errorCode, err := h.modelProviderService.ShowProviderInstance(providerName, instanceName, userID)
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": errorCode,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": 0,
|
|
"message": "success",
|
|
"data": instance,
|
|
})
|
|
}
|
|
|
|
func (h *ProviderHandler) ShowInstanceBalance(c *gin.Context) {
|
|
providerName := c.Param("provider_name")
|
|
if providerName == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"code": 400,
|
|
"message": "Provider name is required",
|
|
})
|
|
return
|
|
}
|
|
|
|
instanceName := c.Param("instance_name")
|
|
if instanceName == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"code": 400,
|
|
"message": "Instance name is required",
|
|
})
|
|
return
|
|
}
|
|
|
|
userID := c.GetString("user_id")
|
|
|
|
// Get tenant ID from user
|
|
balance, errorCode, err := h.modelProviderService.ShowInstanceBalance(providerName, instanceName, userID)
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": errorCode,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": 0,
|
|
"message": "success",
|
|
"data": balance,
|
|
})
|
|
}
|
|
|
|
func (h *ProviderHandler) CheckProviderConnection(c *gin.Context) {
|
|
providerName := c.Param("provider_name")
|
|
if providerName == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"code": 400,
|
|
"message": "Provider name is required",
|
|
})
|
|
return
|
|
}
|
|
|
|
instanceName := c.Param("instance_name")
|
|
if instanceName == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"code": 400,
|
|
"message": "Instance name is required",
|
|
})
|
|
return
|
|
}
|
|
|
|
userID := c.GetString("user_id")
|
|
|
|
// Get tenant ID from user
|
|
errorCode, err := h.modelProviderService.CheckProviderConnection(providerName, instanceName, userID)
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": errorCode,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": 0,
|
|
"message": "success",
|
|
})
|
|
}
|
|
|
|
type AlterProviderInstanceRequest struct {
|
|
LLMName string `json:"llm_name" binding:"required"`
|
|
}
|
|
|
|
func (h *ProviderHandler) AlterProviderInstance(c *gin.Context) {
|
|
providerName := c.Param("provider_name")
|
|
if providerName == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"code": 400,
|
|
"message": "Provider name is required",
|
|
})
|
|
return
|
|
}
|
|
|
|
instanceName := c.Param("instance_name")
|
|
if instanceName == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"code": 400,
|
|
"message": "Instance name is required",
|
|
})
|
|
return
|
|
}
|
|
|
|
var req AlterProviderInstanceRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": common.CodeBadRequest,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
userID := c.GetString("user_id")
|
|
if userID == "" {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": common.CodeUnauthorized,
|
|
"message": "Unauthorized",
|
|
})
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": common.CodeNotFound,
|
|
"message": "success",
|
|
})
|
|
}
|
|
|
|
type DropProviderInstanceRequest struct {
|
|
Instances []string `json:"instances" binding:"required"`
|
|
}
|
|
|
|
func (h *ProviderHandler) DropProviderInstance(c *gin.Context) {
|
|
providerName := c.Param("provider_name")
|
|
if providerName == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"code": 400,
|
|
"message": "Provider name is required",
|
|
})
|
|
return
|
|
}
|
|
var req DropProviderInstanceRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": common.CodeBadRequest,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
userID := c.GetString("user_id")
|
|
|
|
_, err := h.modelProviderService.DropProviderInstances(providerName, userID, req.Instances)
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": common.CodeServerError,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": 0,
|
|
"message": "success",
|
|
})
|
|
}
|
|
|
|
func (h *ProviderHandler) ListInstanceModels(c *gin.Context) {
|
|
providerName := c.Param("provider_name")
|
|
if providerName == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"code": 400,
|
|
"message": "Provider name is required",
|
|
})
|
|
return
|
|
}
|
|
instanceName := c.Param("instance_name")
|
|
if instanceName == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"code": 400,
|
|
"message": "Instance name is required",
|
|
})
|
|
return
|
|
}
|
|
|
|
keywords := ""
|
|
if queryKeywords := c.Query("supported"); queryKeywords != "" {
|
|
keywords = queryKeywords
|
|
}
|
|
|
|
// convert keywords to small case
|
|
keywords = strings.ToLower(keywords)
|
|
if keywords == "true" {
|
|
// list supported models
|
|
|
|
modelList, err := h.modelProviderService.ListSupportedModels(providerName, instanceName, c.GetString("user_id"))
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": common.CodeServerError,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
var modelResponse []map[string]string
|
|
for _, modelName := range modelList {
|
|
modelResponse = append(modelResponse, map[string]string{
|
|
"model_name": modelName,
|
|
})
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": 0,
|
|
"message": "success",
|
|
"data": modelResponse,
|
|
})
|
|
return
|
|
}
|
|
|
|
modelInstances, err := h.modelProviderService.ListInstanceModels(providerName, instanceName, c.GetString("user_id"))
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": common.CodeNotFound,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": 0,
|
|
"message": "success",
|
|
"data": modelInstances,
|
|
})
|
|
}
|
|
|
|
type EnableOrDisableModelRequest struct {
|
|
Status string `json:"status" binding:"required"`
|
|
}
|
|
|
|
func (h *ProviderHandler) EnableOrDisableModel(c *gin.Context) {
|
|
providerName := c.Param("provider_name")
|
|
if providerName == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"code": 400,
|
|
"message": "Provider name is required",
|
|
})
|
|
return
|
|
}
|
|
|
|
instanceName := c.Param("instance_name")
|
|
if instanceName == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"code": 400,
|
|
"message": "Instance name is required",
|
|
})
|
|
return
|
|
}
|
|
|
|
modelName := c.Param("model_name")
|
|
if modelName != "" {
|
|
modelName = strings.TrimPrefix(modelName, "/")
|
|
}
|
|
if modelName == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"code": 400,
|
|
"message": "Model name is required",
|
|
})
|
|
return
|
|
}
|
|
|
|
var req EnableOrDisableModelRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
println("JSON bind error: %v (type: %T)", err, err)
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": common.CodeBadRequest,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
userID := c.GetString("user_id")
|
|
|
|
_, err := h.modelProviderService.UpdateModelStatus(providerName, instanceName, modelName, userID, req.Status)
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": common.CodeServerError,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": 0,
|
|
"message": "success",
|
|
})
|
|
}
|
|
|
|
func (h *ProviderHandler) AddCustomModel(c *gin.Context) {
|
|
var req service.AddCustomModelRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
println("JSON bind error: %v (type: %T)", err, err)
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": common.CodeBadRequest,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
if req.ProviderName == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"code": 400,
|
|
"message": "Provider name is required",
|
|
})
|
|
return
|
|
}
|
|
|
|
if req.InstanceName == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"code": 400,
|
|
"message": "Instance name is required",
|
|
})
|
|
return
|
|
}
|
|
|
|
if req.ModelName == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"code": 400,
|
|
"message": "Model name is required",
|
|
})
|
|
return
|
|
}
|
|
|
|
if req.ModelTypes == nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"code": 400,
|
|
"message": "Model type is required",
|
|
})
|
|
return
|
|
}
|
|
|
|
userID := c.GetString("user_id")
|
|
|
|
errorCode, err := h.modelProviderService.AddCustomModel(&req, userID)
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": errorCode,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": common.CodeSuccess,
|
|
})
|
|
|
|
}
|
|
|
|
type DropInstanceModelRequest struct {
|
|
Models []string `json:"models" binding:"required"`
|
|
}
|
|
|
|
func (h *ProviderHandler) DropInstanceModels(c *gin.Context) {
|
|
providerName := c.Param("provider_name")
|
|
if providerName == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"code": 400,
|
|
"message": "Provider name is required",
|
|
})
|
|
return
|
|
}
|
|
instanceName := c.Param("instance_name")
|
|
if instanceName == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"code": 400,
|
|
"message": "Instance name is required",
|
|
})
|
|
return
|
|
}
|
|
|
|
var req DropInstanceModelRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": common.CodeBadRequest,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
userID := c.GetString("user_id")
|
|
|
|
_, err := h.modelProviderService.DropInstanceModels(providerName, instanceName, userID, req.Models)
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": common.CodeServerError,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": 0,
|
|
"message": "success",
|
|
})
|
|
}
|
|
|
|
type ChatToModelRequest struct {
|
|
ProviderName *string `json:"provider_name"`
|
|
InstanceName *string `json:"instance_name"`
|
|
ModelName *string `json:"model_name"`
|
|
Messages []map[string]interface{} `json:"messages"`
|
|
Stream bool `json:"stream"`
|
|
Thinking bool `json:"thinking"`
|
|
Effort *string `json:"effort"`
|
|
Verbosity *string `json:"verbosity"`
|
|
}
|
|
|
|
func (h *ProviderHandler) ChatToModel(c *gin.Context) {
|
|
var req ChatToModelRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
println("JSON bind error: %v (type: %T)", err, err)
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": common.CodeBadRequest,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
if req.ProviderName == nil || *req.ProviderName == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"code": 400,
|
|
"message": "Provider name is required",
|
|
})
|
|
return
|
|
}
|
|
|
|
if req.InstanceName == nil || *req.InstanceName == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"code": 400,
|
|
"message": "Instance name is required",
|
|
})
|
|
return
|
|
}
|
|
|
|
if req.ModelName == nil || *req.ModelName == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"code": 400,
|
|
"message": "Model name is required",
|
|
})
|
|
return
|
|
}
|
|
|
|
userID := c.GetString("user_id")
|
|
|
|
if !req.Thinking {
|
|
req.Effort = nil
|
|
req.Verbosity = nil
|
|
}
|
|
|
|
apiConfig := models.APIConfig{
|
|
ApiKey: nil,
|
|
Region: nil,
|
|
}
|
|
|
|
chatConfig := models.ChatConfig{
|
|
Thinking: &req.Thinking,
|
|
Stream: &req.Stream,
|
|
Vision: nil,
|
|
Stop: &[]string{},
|
|
DoSample: nil,
|
|
MaxTokens: nil,
|
|
Temperature: nil,
|
|
TopP: nil,
|
|
Effort: req.Effort,
|
|
Verbosity: req.Verbosity,
|
|
}
|
|
|
|
// Check if it's a stream request
|
|
if req.Stream {
|
|
// Set SSE headers
|
|
c.Header("Content-Type", "text/event-stream")
|
|
c.Header("Cache-Control", "no-cache")
|
|
c.Header("Connection", "keep-alive")
|
|
c.Writer.WriteHeader(http.StatusOK)
|
|
c.Writer.Flush()
|
|
|
|
// Create sender function that writes directly to response
|
|
sender := func(content, reasoningContent *string) error {
|
|
// Check for [DONE] marker (OpenAI compatible)
|
|
if content != nil {
|
|
if *content == "[DONE]" {
|
|
c.SSEvent("done", "[DONE]")
|
|
return nil
|
|
}
|
|
message := fmt.Sprintf("[MESSAGE]%s", *content)
|
|
c.SSEvent("message", message)
|
|
c.Writer.Flush()
|
|
}
|
|
|
|
if reasoningContent != nil {
|
|
message := fmt.Sprintf("[REASONING]%s", *reasoningContent)
|
|
c.SSEvent("message", message)
|
|
c.Writer.Flush()
|
|
}
|
|
|
|
//logger.Info(data)
|
|
return nil
|
|
}
|
|
|
|
// Convert []map[string]interface{} to []models.Message
|
|
messages := make([]models.Message, len(req.Messages))
|
|
for i, msg := range req.Messages {
|
|
role, _ := msg["role"].(string)
|
|
content := msg["content"]
|
|
messages[i] = models.Message{Role: role, Content: content}
|
|
}
|
|
|
|
// Stream response using sender function (best performance, no channel)
|
|
errorCode, err := h.modelProviderService.ChatToModelStreamWithSender(*req.ProviderName, *req.InstanceName, *req.ModelName, userID, messages, &apiConfig, &chatConfig, sender)
|
|
|
|
if errorCode != common.CodeSuccess {
|
|
c.SSEvent("error", err.Error())
|
|
}
|
|
return
|
|
}
|
|
|
|
// Non-stream response
|
|
var response *models.ChatResponse
|
|
var errorCode common.ErrorCode
|
|
var err error
|
|
|
|
// Convert []map[string]interface{} to []models.Message
|
|
messages := make([]models.Message, len(req.Messages))
|
|
for i, msg := range req.Messages {
|
|
role, _ := msg["role"].(string)
|
|
content := msg["content"]
|
|
messages[i] = models.Message{Role: role, Content: content}
|
|
}
|
|
response, errorCode, err = h.modelProviderService.ChatToModelWithMessages(*req.ProviderName, *req.InstanceName, *req.ModelName, userID, messages, &apiConfig, &chatConfig)
|
|
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": errorCode,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"code": 0,
|
|
"reasoning_content": response.ReasonContent,
|
|
"answer": response.Answer,
|
|
})
|
|
}
|