Compare commits
10 Commits
fix/tool-r
...
v0.2.2
| Author | SHA1 | Date | |
|---|---|---|---|
| 0367e66eca | |||
| 38b63f00a3 | |||
| 72656e4fd1 | |||
| f80d4f757b | |||
| 36923bd0a4 | |||
| adc5986d13 | |||
| c7bf6bbdec | |||
| a44b4e8f7e | |||
| f78d297311 | |||
| 3fe4031531 |
@ -33,9 +33,8 @@ RUN apk add --no-cache --virtual .python-build-deps build-base py3-pip git && \
|
||||
# Activate venv and install packages
|
||||
. /app/.venv/bin/activate && \
|
||||
# If you want to use other third-party libraries, you can install them here.
|
||||
pip install git+https://gitcode.com/gh_mirrors/re/requests-async.git@master && \
|
||||
pip install urllib3==1.26.16 && \
|
||||
pip install --no-cache-dir pillow==11.2.1 pdfplumber==0.11.7 python-docx==1.2.0 numpy==2.3.1 && \
|
||||
pip install --no-cache-dir h11==0.16.0 httpx==0.28.1 pillow==11.2.1 pdfplumber==0.11.7 python-docx==1.2.0 numpy==2.3.1 && \
|
||||
# Deactivate (optional, as RUN is a new shell)
|
||||
# deactivate && \
|
||||
# Remove build dependencies
|
||||
|
||||
@ -19,6 +19,7 @@ package plugin
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
api "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop_common"
|
||||
@ -74,7 +75,12 @@ func (mf *PluginManifest) EncryptAuthPayload() (*PluginManifest, error) {
|
||||
return mf_, nil
|
||||
}
|
||||
|
||||
payload_, err := utils.EncryptByAES([]byte(mf_.Auth.Payload), utils.AuthSecretKey)
|
||||
secret := os.Getenv(utils.AuthSecretEnv)
|
||||
if secret == "" {
|
||||
secret = utils.DefaultAuthSecret
|
||||
}
|
||||
|
||||
payload_, err := utils.EncryptByAES([]byte(mf_.Auth.Payload), secret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -357,7 +363,12 @@ func (au *AuthV2) UnmarshalJSON(data []byte) error {
|
||||
}
|
||||
|
||||
if auth.Payload != "" {
|
||||
payload_, err := utils.DecryptByAES(auth.Payload, utils.AuthSecretKey)
|
||||
secret := os.Getenv(utils.AuthSecretEnv)
|
||||
if secret == "" {
|
||||
secret = utils.DefaultAuthSecret
|
||||
}
|
||||
|
||||
payload_, err := utils.DecryptByAES(auth.Payload, secret)
|
||||
if err == nil {
|
||||
auth.Payload = string(payload_)
|
||||
}
|
||||
|
||||
@ -37,12 +37,13 @@ type AgentRuntime struct {
|
||||
type EventType string
|
||||
|
||||
const (
|
||||
EventTypeOfChatModelAnswer EventType = "chatmodel_answer"
|
||||
EventTypeOfToolsMessage EventType = "tools_message"
|
||||
EventTypeOfFuncCall EventType = "func_call"
|
||||
EventTypeOfSuggest EventType = "suggest"
|
||||
EventTypeOfKnowledge EventType = "knowledge"
|
||||
EventTypeOfInterrupt EventType = "interrupt"
|
||||
EventTypeOfChatModelAnswer EventType = "chatmodel_answer"
|
||||
EventTypeOfToolsAsChatModelStream EventType = "tools_as_chatmodel_answer"
|
||||
EventTypeOfToolsMessage EventType = "tools_message"
|
||||
EventTypeOfFuncCall EventType = "func_call"
|
||||
EventTypeOfSuggest EventType = "suggest"
|
||||
EventTypeOfKnowledge EventType = "knowledge"
|
||||
EventTypeOfInterrupt EventType = "interrupt"
|
||||
)
|
||||
|
||||
type AgentEvent struct {
|
||||
|
||||
@ -17,7 +17,6 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
redisV9 "github.com/redis/go-redis/v9"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/app/repository"
|
||||
@ -26,6 +25,7 @@ import (
|
||||
variables "github.com/coze-dev/coze-studio/backend/domain/memory/variables/service"
|
||||
search "github.com/coze-dev/coze-studio/backend/domain/search/service"
|
||||
user "github.com/coze-dev/coze-studio/backend/domain/user/service"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
@ -35,7 +35,7 @@ type ServiceComponents struct {
|
||||
IDGen idgen.IDGenerator
|
||||
DB *gorm.DB
|
||||
OSS storage.Storage
|
||||
CacheCli *redisV9.Client
|
||||
CacheCli cache.Cmdable
|
||||
ProjectEventBus search.ProjectEventBus
|
||||
|
||||
ModelMgr modelmgr.Manager
|
||||
|
||||
@ -338,8 +338,9 @@ func getEmbedding(ctx context.Context) (embedding.Embedder, error) {
|
||||
arkEmbeddingApiKey = os.Getenv("ARK_EMBEDDING_API_KEY")
|
||||
// deprecated: use ARK_EMBEDDING_API_KEY instead
|
||||
// ARK_EMBEDDING_AK will be removed in the future
|
||||
arkEmbeddingAK = os.Getenv("ARK_EMBEDDING_AK")
|
||||
arkEmbeddingDims = os.Getenv("ARK_EMBEDDING_DIMS")
|
||||
arkEmbeddingAK = os.Getenv("ARK_EMBEDDING_AK")
|
||||
arkEmbeddingDims = os.Getenv("ARK_EMBEDDING_DIMS")
|
||||
arkEmbeddingAPIType = os.Getenv("ARK_EMBEDDING_API_TYPE")
|
||||
)
|
||||
|
||||
dims, err := strconv.ParseInt(arkEmbeddingDims, 10, 64)
|
||||
@ -347,6 +348,15 @@ func getEmbedding(ctx context.Context) (embedding.Embedder, error) {
|
||||
return nil, fmt.Errorf("init ark embedding dims failed, err=%w", err)
|
||||
}
|
||||
|
||||
apiType := ark.APITypeText
|
||||
if arkEmbeddingAPIType != "" {
|
||||
if t := ark.APIType(arkEmbeddingAPIType); t != ark.APITypeText && t != ark.APITypeMultiModal {
|
||||
return nil, fmt.Errorf("init ark embedding api_type failed, invalid api_type=%s", t)
|
||||
} else {
|
||||
apiType = t
|
||||
}
|
||||
}
|
||||
|
||||
emb, err = arkemb.NewArkEmbedder(ctx, &ark.EmbeddingConfig{
|
||||
APIKey: func() string {
|
||||
if arkEmbeddingApiKey != "" {
|
||||
@ -356,6 +366,7 @@ func getEmbedding(ctx context.Context) (embedding.Embedder, error) {
|
||||
}(),
|
||||
Model: arkEmbeddingModel,
|
||||
BaseURL: arkEmbeddingBaseURL,
|
||||
APIType: &apiType,
|
||||
}, dims, batchSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init ark embedding client failed, err=%w", err)
|
||||
|
||||
@ -19,12 +19,11 @@ package memory
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
database "github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/memory/variables/repository"
|
||||
variables "github.com/coze-dev/coze-studio/backend/domain/memory/variables/service"
|
||||
search "github.com/coze-dev/coze-studio/backend/domain/search/service"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/rdb"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
@ -43,7 +42,7 @@ type ServiceComponents struct {
|
||||
EventBus search.ResourceEventBus
|
||||
TosClient storage.Storage
|
||||
ResourceDomainNotifier search.ResourceEventBus
|
||||
CacheCli *redis.Client
|
||||
CacheCli cache.Cmdable
|
||||
}
|
||||
|
||||
func InitService(c *ServiceComponents) *MemoryApplicationServices {
|
||||
|
||||
@ -23,6 +23,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@ -1703,7 +1704,12 @@ func (p *PluginApplicationService) OauthAuthorizationCode(ctx context.Context, r
|
||||
return nil, errorx.WrapByCode(err, errno.ErrPluginOAuthFailed, errorx.KV(errno.PluginMsgKey, "invalid state"))
|
||||
}
|
||||
|
||||
stateBytes, err := utils.DecryptByAES(stateStr, utils.StateSecretKey)
|
||||
secret := os.Getenv(utils.StateSecretEnv)
|
||||
if secret == "" {
|
||||
secret = utils.DefaultStateSecret
|
||||
}
|
||||
|
||||
stateBytes, err := utils.DecryptByAES(stateStr, secret)
|
||||
if err != nil {
|
||||
return nil, errorx.WrapByCode(err, errno.ErrPluginOAuthFailed, errorx.KV(errno.PluginMsgKey, "invalid state"))
|
||||
}
|
||||
|
||||
@ -18,9 +18,10 @@ package singleagent
|
||||
|
||||
import (
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/repository"
|
||||
singleagent "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/service"
|
||||
@ -50,7 +51,7 @@ var SingleAgentSVC *SingleAgentApplicationService
|
||||
type ServiceComponents struct {
|
||||
IDGen idgen.IDGenerator
|
||||
DB *gorm.DB
|
||||
Cache *redis.Client
|
||||
Cache cache.Cmdable
|
||||
TosClient storage.Storage
|
||||
ImageX imagex.ImageX
|
||||
EventBus search.ProjectEventBus
|
||||
|
||||
@ -20,12 +20,9 @@ import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/application/internal"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
|
||||
wfdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/database"
|
||||
wfknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/knowledge"
|
||||
wfmodel "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/model"
|
||||
@ -46,17 +43,19 @@ import (
|
||||
crosssearch "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/search"
|
||||
crossvariable "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/service"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/coderunner"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
type ServiceComponents struct {
|
||||
IDGen idgen.IDGenerator
|
||||
DB *gorm.DB
|
||||
Cache *redis.Client
|
||||
Cache cache.Cmdable
|
||||
DatabaseDomainSVC dbservice.Database
|
||||
VariablesDomainSVC variables.Variables
|
||||
PluginDomainSVC plugin.PluginService
|
||||
|
||||
@ -2520,6 +2520,12 @@ func (w *ApplicationService) GetApiDetail(ctx context.Context, req *workflow.Get
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, v := range outputVars {
|
||||
if err := crossplugin.GetPluginService().UnwrapArrayItemFieldsInVariable(v); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
toolDetailInfo := &vo.ToolDetailInfo{
|
||||
ApiDetailData: &workflow.ApiDetailData{
|
||||
PluginID: req.GetPluginID(),
|
||||
@ -3701,7 +3707,7 @@ func toVariable(p *workflow.APIParameter) (*vo.Variable, error) {
|
||||
case workflow.ParameterType_Array:
|
||||
v.Type = vo.VariableTypeList
|
||||
if len(p.SubParameters) > 0 {
|
||||
subVs := make([]any, 0)
|
||||
subVs := make([]*vo.Variable, 0)
|
||||
for _, ap := range p.SubParameters {
|
||||
av, err := toVariable(ap)
|
||||
if err != nil {
|
||||
|
||||
@ -201,7 +201,7 @@ func (t *pluginService) GetPluginToolsInfo(ctx context.Context, req *crossplugin
|
||||
)
|
||||
if toolExample != nil {
|
||||
requestExample = toolExample.RequestExample
|
||||
responseExample = toolExample.RequestExample
|
||||
responseExample = toolExample.ResponseExample
|
||||
}
|
||||
|
||||
response.ToolInfoList[tf.ID] = crossplugin.ToolInfo{
|
||||
@ -220,6 +220,63 @@ func (t *pluginService) GetPluginToolsInfo(ctx context.Context, req *crossplugin
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (t *pluginService) UnwrapArrayItemFieldsInVariable(v *vo.Variable) error {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if v.Type == vo.VariableTypeObject {
|
||||
subVars, ok := v.Schema.([]*vo.Variable)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
newSubVars := make([]*vo.Variable, 0, len(subVars))
|
||||
for _, subVar := range subVars {
|
||||
if subVar.Name == "[Array Item]" {
|
||||
if err := t.UnwrapArrayItemFieldsInVariable(subVar); err != nil {
|
||||
return err
|
||||
}
|
||||
// If the array item is an object, append its children
|
||||
if subVar.Type == vo.VariableTypeObject {
|
||||
if innerSubVars, ok := subVar.Schema.([]*vo.Variable); ok {
|
||||
newSubVars = append(newSubVars, innerSubVars...)
|
||||
}
|
||||
} else {
|
||||
// If the array item is a primitive type, clear its name and append it
|
||||
subVar.Name = ""
|
||||
newSubVars = append(newSubVars, subVar)
|
||||
}
|
||||
} else {
|
||||
// For other sub-variables, recursively unwrap and append
|
||||
if err := t.UnwrapArrayItemFieldsInVariable(subVar); err != nil {
|
||||
return err
|
||||
}
|
||||
newSubVars = append(newSubVars, subVar)
|
||||
}
|
||||
}
|
||||
v.Schema = newSubVars
|
||||
|
||||
} else if v.Type == vo.VariableTypeList {
|
||||
if v.Schema != nil {
|
||||
subVar, ok := v.Schema.(*vo.Variable)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := t.UnwrapArrayItemFieldsInVariable(subVar); err != nil {
|
||||
return err
|
||||
}
|
||||
// If the array item definition itself has "[Array Item]" name, clear it
|
||||
if subVar.Name == "[Array Item]" {
|
||||
subVar.Name = ""
|
||||
}
|
||||
v.Schema = subVar
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *pluginService) GetPluginInvokableTools(ctx context.Context, req *crossplugin.ToolsInvokableRequest) (
|
||||
_ map[int64]crossplugin.InvokableTool, err error) {
|
||||
defer func() {
|
||||
@ -327,7 +384,7 @@ func (t *pluginService) ExecutePlugin(ctx context.Context, input map[string]any,
|
||||
}
|
||||
|
||||
var output map[string]any
|
||||
err = sonic.UnmarshalString(r.RawResp, &output)
|
||||
err = sonic.UnmarshalString(r.TrimmedResp, &output)
|
||||
if err != nil {
|
||||
return nil, vo.WrapError(errno.ErrSerializationDeserializationFail, err)
|
||||
}
|
||||
|
||||
217
backend/crossdomain/workflow/plugin/plugin_test.go
Normal file
217
backend/crossdomain/workflow/plugin/plugin_test.go
Normal file
@ -0,0 +1,217 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPluginService_UnwrapArrayItemFieldsInVariable(t *testing.T) {
|
||||
s := &pluginService{}
|
||||
t.Run("unwraps a simple array item", func(t *testing.T) {
|
||||
input := &vo.Variable{
|
||||
Name: "root",
|
||||
Type: vo.VariableTypeObject,
|
||||
Schema: []*vo.Variable{
|
||||
{
|
||||
Name: "[Array Item]",
|
||||
Type: vo.VariableTypeObject,
|
||||
Schema: []*vo.Variable{
|
||||
{Name: "field1", Type: vo.VariableTypeString},
|
||||
{Name: "field2", Type: vo.VariableTypeInteger},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
expected := &vo.Variable{
|
||||
Name: "root",
|
||||
Type: vo.VariableTypeObject,
|
||||
Schema: []*vo.Variable{
|
||||
{Name: "field1", Type: vo.VariableTypeString},
|
||||
{Name: "field2", Type: vo.VariableTypeInteger},
|
||||
},
|
||||
}
|
||||
|
||||
err := s.UnwrapArrayItemFieldsInVariable(input)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, input)
|
||||
})
|
||||
|
||||
t.Run("handles nested array items", func(t *testing.T) {
|
||||
input := &vo.Variable{
|
||||
Name: "root",
|
||||
Type: vo.VariableTypeObject,
|
||||
Schema: []*vo.Variable{
|
||||
{
|
||||
Name: "[Array Item]",
|
||||
Type: vo.VariableTypeObject,
|
||||
Schema: []*vo.Variable{
|
||||
{Name: "field1", Type: vo.VariableTypeString},
|
||||
{
|
||||
Name: "[Array Item]",
|
||||
Type: vo.VariableTypeObject,
|
||||
Schema: []*vo.Variable{
|
||||
{Name: "nestedField", Type: vo.VariableTypeBoolean},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
expected := &vo.Variable{
|
||||
Name: "root",
|
||||
Type: vo.VariableTypeObject,
|
||||
Schema: []*vo.Variable{
|
||||
{Name: "field1", Type: vo.VariableTypeString},
|
||||
{Name: "nestedField", Type: vo.VariableTypeBoolean},
|
||||
},
|
||||
}
|
||||
|
||||
err := s.UnwrapArrayItemFieldsInVariable(input)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, input)
|
||||
})
|
||||
|
||||
t.Run("handles array item within a list", func(t *testing.T) {
|
||||
input := &vo.Variable{
|
||||
Name: "rootList",
|
||||
Type: vo.VariableTypeList,
|
||||
Schema: &vo.Variable{
|
||||
Type: vo.VariableTypeObject,
|
||||
Schema: []*vo.Variable{
|
||||
{
|
||||
Name: "[Array Item]",
|
||||
Type: vo.VariableTypeObject,
|
||||
Schema: []*vo.Variable{
|
||||
{Name: "itemField", Type: vo.VariableTypeString},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
expected := &vo.Variable{
|
||||
Name: "rootList",
|
||||
Type: vo.VariableTypeList,
|
||||
Schema: &vo.Variable{
|
||||
Type: vo.VariableTypeObject,
|
||||
Schema: []*vo.Variable{
|
||||
{Name: "itemField", Type: vo.VariableTypeString},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := s.UnwrapArrayItemFieldsInVariable(input)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, input)
|
||||
})
|
||||
|
||||
t.Run("does nothing if no array item is present", func(t *testing.T) {
|
||||
input := &vo.Variable{
|
||||
Name: "root",
|
||||
Type: vo.VariableTypeObject,
|
||||
Schema: []*vo.Variable{
|
||||
{Name: "field1", Type: vo.VariableTypeString},
|
||||
{Name: "field2", Type: vo.VariableTypeInteger},
|
||||
},
|
||||
}
|
||||
|
||||
// Create a copy for comparison as the input will be modified in place.
|
||||
expected := &vo.Variable{
|
||||
Name: "root",
|
||||
Type: vo.VariableTypeObject,
|
||||
Schema: []*vo.Variable{
|
||||
{Name: "field1", Type: vo.VariableTypeString},
|
||||
{Name: "field2", Type: vo.VariableTypeInteger},
|
||||
},
|
||||
}
|
||||
|
||||
err := s.UnwrapArrayItemFieldsInVariable(input)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, input)
|
||||
})
|
||||
|
||||
t.Run("handles primitive type array item in object", func(t *testing.T) {
|
||||
input := &vo.Variable{
|
||||
Name: "root",
|
||||
Type: vo.VariableTypeObject,
|
||||
Schema: []*vo.Variable{
|
||||
{
|
||||
Name: "[Array Item]",
|
||||
Type: vo.VariableTypeString,
|
||||
},
|
||||
{
|
||||
Name: "anotherField",
|
||||
Type: vo.VariableTypeInteger,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
expected := &vo.Variable{
|
||||
Name: "root",
|
||||
Type: vo.VariableTypeObject,
|
||||
Schema: []*vo.Variable{
|
||||
{
|
||||
Name: "",
|
||||
Type: vo.VariableTypeString,
|
||||
},
|
||||
{
|
||||
Name: "anotherField",
|
||||
Type: vo.VariableTypeInteger,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := s.UnwrapArrayItemFieldsInVariable(input)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, input)
|
||||
})
|
||||
|
||||
t.Run("handles list of primitives", func(t *testing.T) {
|
||||
input := &vo.Variable{
|
||||
Name: "listOfStrings",
|
||||
Type: vo.VariableTypeList,
|
||||
Schema: &vo.Variable{
|
||||
Name: "[Array Item]",
|
||||
Type: vo.VariableTypeString,
|
||||
},
|
||||
}
|
||||
|
||||
expected := &vo.Variable{
|
||||
Name: "listOfStrings",
|
||||
Type: vo.VariableTypeList,
|
||||
Schema: &vo.Variable{
|
||||
Name: "",
|
||||
Type: vo.VariableTypeString,
|
||||
},
|
||||
}
|
||||
|
||||
err := s.UnwrapArrayItemFieldsInVariable(input)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, input)
|
||||
})
|
||||
|
||||
t.Run("handles nil input", func(t *testing.T) {
|
||||
err := s.UnwrapArrayItemFieldsInVariable(nil)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
@ -113,7 +113,7 @@ func BuildAgent(ctx context.Context, conf *Config) (r *AgentRunner, err error) {
|
||||
}
|
||||
tr := newPreToolRetriever(&toolPreCallConf{})
|
||||
|
||||
wfTools, toolsReturnDirectly, err := newWorkflowTools(ctx, &workflowConfig{
|
||||
wfTools, returnDirectlyTools, err := newWorkflowTools(ctx, &workflowConfig{
|
||||
wfInfos: conf.Agent.Workflow,
|
||||
})
|
||||
if err != nil {
|
||||
@ -176,7 +176,7 @@ func BuildAgent(ctx context.Context, conf *Config) (r *AgentRunner, err error) {
|
||||
ToolsConfig: compose.ToolsNodeConfig{
|
||||
Tools: agentTools,
|
||||
},
|
||||
ToolReturnDirectly: toolsReturnDirectly,
|
||||
ToolReturnDirectly: returnDirectlyTools,
|
||||
ModelNodeName: keyOfReActAgentChatModel,
|
||||
ToolsNodeName: keyOfReActAgentToolsNode,
|
||||
})
|
||||
@ -273,10 +273,11 @@ func BuildAgent(ctx context.Context, conf *Config) (r *AgentRunner, err error) {
|
||||
}
|
||||
|
||||
return &AgentRunner{
|
||||
runner: runner,
|
||||
requireCheckpoint: requireCheckpoint,
|
||||
modelInfo: modelInfo,
|
||||
containWfTool: containWfTool,
|
||||
runner: runner,
|
||||
requireCheckpoint: requireCheckpoint,
|
||||
modelInfo: modelInfo,
|
||||
containWfTool: containWfTool,
|
||||
returnDirectlyTools: returnDirectlyTools,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@ -57,8 +57,9 @@ type AgentRunner struct {
|
||||
runner compose.Runnable[*AgentRequest, *schema.Message]
|
||||
requireCheckpoint bool
|
||||
|
||||
containWfTool bool
|
||||
modelInfo *modelmgr.Model
|
||||
returnDirectlyTools map[string]struct{}
|
||||
containWfTool bool
|
||||
modelInfo *modelmgr.Model
|
||||
}
|
||||
|
||||
func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) (
|
||||
@ -66,7 +67,7 @@ func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) (
|
||||
) {
|
||||
executeID := uuid.New()
|
||||
|
||||
hdl, sr, sw := newReplyCallback(ctx, executeID.String())
|
||||
hdl, sr, sw := newReplyCallback(ctx, executeID.String(), r.returnDirectlyTools)
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
|
||||
@ -38,14 +38,15 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
func newReplyCallback(_ context.Context, executeID string) (clb callbacks.Handler,
|
||||
func newReplyCallback(_ context.Context, executeID string, returnDirectlyTools map[string]struct{}) (clb callbacks.Handler,
|
||||
sr *schema.StreamReader[*entity.AgentEvent], sw *schema.StreamWriter[*entity.AgentEvent],
|
||||
) {
|
||||
sr, sw = schema.Pipe[*entity.AgentEvent](10)
|
||||
|
||||
rcc := &replyChunkCallback{
|
||||
sw: sw,
|
||||
executeID: executeID,
|
||||
sw: sw,
|
||||
executeID: executeID,
|
||||
returnDirectlyTools: returnDirectlyTools,
|
||||
}
|
||||
|
||||
clb = callbacks.NewHandlerBuilder().
|
||||
@ -59,8 +60,9 @@ func newReplyCallback(_ context.Context, executeID string) (clb callbacks.Handle
|
||||
}
|
||||
|
||||
type replyChunkCallback struct {
|
||||
sw *schema.StreamWriter[*entity.AgentEvent]
|
||||
executeID string
|
||||
sw *schema.StreamWriter[*entity.AgentEvent]
|
||||
executeID string
|
||||
returnDirectlyTools map[string]struct{}
|
||||
}
|
||||
|
||||
func (r *replyChunkCallback) OnError(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
|
||||
@ -186,11 +188,6 @@ func (r *replyChunkCallback) OnEndWithStreamOutput(ctx context.Context, info *ca
|
||||
logs.CtxInfof(ctx, "info-OnEndWithStreamOutput, info=%v, output=%v", conv.DebugJsonToStr(info), conv.DebugJsonToStr(output))
|
||||
switch info.Component {
|
||||
case compose.ComponentOfGraph, components.ComponentOfChatModel:
|
||||
if info.Name == keyOfReActAgent {
|
||||
r.processToolsReturnDirectlyStreamWithLazyInit(ctx, output)
|
||||
return ctx
|
||||
}
|
||||
|
||||
if info.Name != keyOfReActAgentChatModel && info.Name != keyOfLLM {
|
||||
output.Close()
|
||||
return ctx
|
||||
@ -206,7 +203,7 @@ func (r *replyChunkCallback) OnEndWithStreamOutput(ctx context.Context, info *ca
|
||||
}, nil)
|
||||
return ctx
|
||||
case compose.ComponentOfToolsNode:
|
||||
toolsMessage, err := concatToolsNodeOutput(ctx, output)
|
||||
toolsMessage, err := r.concatToolsNodeOutput(ctx, output)
|
||||
if err != nil {
|
||||
r.sw.Send(nil, err)
|
||||
return ctx
|
||||
@ -275,9 +272,21 @@ func convInterruptEventType(interruptEvent any) singleagent.InterruptEventType {
|
||||
return interruptEventType
|
||||
}
|
||||
|
||||
func concatToolsNodeOutput(ctx context.Context, output *schema.StreamReader[callbacks.CallbackOutput]) ([]*schema.Message, error) {
|
||||
func (r *replyChunkCallback) concatToolsNodeOutput(ctx context.Context, output *schema.StreamReader[callbacks.CallbackOutput]) ([]*schema.Message, error) {
|
||||
defer output.Close()
|
||||
toolsMsgChunks := make([][]*schema.Message, 0, 5)
|
||||
var toolsMsgChunks [][]*schema.Message
|
||||
var sr *schema.StreamReader[*schema.Message]
|
||||
var sw *schema.StreamWriter[*schema.Message]
|
||||
defer func() {
|
||||
if sw != nil {
|
||||
sw.Close()
|
||||
}
|
||||
}()
|
||||
var streamInitialized bool
|
||||
returnDirectToolsMap := make(map[int]bool)
|
||||
isReturnDirectToolsFirstCheck := true
|
||||
isToolsMsgChunksInit := false
|
||||
|
||||
for {
|
||||
cbOut, err := output.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
@ -285,27 +294,48 @@ func concatToolsNodeOutput(ctx context.Context, output *schema.StreamReader[call
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if sw != nil {
|
||||
sw.Send(nil, err)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msgs := convToolsNodeCallbackOutput(cbOut)
|
||||
|
||||
for _, msg := range msgs {
|
||||
if !isToolsMsgChunksInit {
|
||||
isToolsMsgChunksInit = true
|
||||
toolsMsgChunks = make([][]*schema.Message, len(msgs))
|
||||
}
|
||||
|
||||
for mIndex, msg := range msgs {
|
||||
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
if len(r.returnDirectlyTools) > 0 {
|
||||
if isReturnDirectToolsFirstCheck {
|
||||
isReturnDirectToolsFirstCheck = false
|
||||
if _, ok := r.returnDirectlyTools[msg.ToolName]; ok {
|
||||
returnDirectToolsMap[mIndex] = true
|
||||
}
|
||||
}
|
||||
|
||||
findSameMsg := false
|
||||
for i, msgChunks := range toolsMsgChunks {
|
||||
if msg.ToolCallID == msgChunks[0].ToolCallID {
|
||||
toolsMsgChunks[i] = append(toolsMsgChunks[i], msg)
|
||||
findSameMsg = true
|
||||
break
|
||||
if _, ok := returnDirectToolsMap[mIndex]; ok {
|
||||
if !streamInitialized {
|
||||
sr, sw = schema.Pipe[*schema.Message](5)
|
||||
r.sw.Send(&entity.AgentEvent{
|
||||
EventType: singleagent.EventTypeOfToolsAsChatModelStream,
|
||||
ChatModelAnswer: sr,
|
||||
}, nil)
|
||||
streamInitialized = true
|
||||
}
|
||||
sw.Send(msg, nil)
|
||||
}
|
||||
}
|
||||
|
||||
if !findSameMsg {
|
||||
toolsMsgChunks = append(toolsMsgChunks, []*schema.Message{msg})
|
||||
if toolsMsgChunks[mIndex] == nil {
|
||||
toolsMsgChunks[mIndex] = []*schema.Message{msg}
|
||||
} else {
|
||||
toolsMsgChunks[mIndex] = append(toolsMsgChunks[mIndex], msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -331,56 +361,6 @@ func convToolsNodeCallbackInput(input callbacks.CallbackInput) *schema.Message {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
func convToolsNodeCallbackOutputMessage(output callbacks.CallbackOutput) *schema.Message {
|
||||
switch t := output.(type) {
|
||||
case *schema.Message:
|
||||
return t
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (r *replyChunkCallback) processToolsReturnDirectlyStreamWithLazyInit(_ context.Context, output *schema.StreamReader[callbacks.CallbackOutput]) {
|
||||
var streamInitialized bool
|
||||
var sr *schema.StreamReader[*schema.Message]
|
||||
var sw *schema.StreamWriter[*schema.Message]
|
||||
|
||||
for {
|
||||
cbOut, err := output.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
if sw != nil {
|
||||
sw.Close()
|
||||
}
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
if sw != nil {
|
||||
sw.Send(nil, err)
|
||||
sw.Close()
|
||||
}
|
||||
break
|
||||
}
|
||||
msg := convToolsNodeCallbackOutputMessage(cbOut)
|
||||
|
||||
if msg == nil {
|
||||
break
|
||||
}
|
||||
if msg.Role != schema.Tool {
|
||||
break
|
||||
}
|
||||
if msg.Role == schema.Tool {
|
||||
if !streamInitialized {
|
||||
sr, sw = schema.Pipe[*schema.Message](5)
|
||||
r.sw.Send(&entity.AgentEvent{
|
||||
EventType: singleagent.EventTypeOfChatModelAnswer,
|
||||
ChatModelAnswer: sr,
|
||||
}, nil)
|
||||
streamInitialized = true
|
||||
}
|
||||
sw.Send(msg, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func convToolsNodeCallbackOutput(output callbacks.CallbackOutput) []*schema.Message {
|
||||
switch t := output.(type) {
|
||||
|
||||
@ -22,10 +22,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/developer_api"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
@ -53,7 +52,7 @@ func (sa *SingleAgentDraftDAO) UpdateDisplayInfo(ctx context.Context, userID int
|
||||
func (sa *SingleAgentDraftDAO) GetDisplayInfo(ctx context.Context, userID, agentID int64) (*entity.AgentDraftDisplayInfo, error) {
|
||||
key := makeAgentDisplayInfoKey(userID, agentID)
|
||||
data, err := sa.cacheClient.Get(ctx, key).Result()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
if errors.Is(err, cache.Nil) {
|
||||
tabStatusDefault := developer_api.TabStatus_Default
|
||||
return &entity.AgentDraftDisplayInfo{
|
||||
AgentID: agentID,
|
||||
|
||||
@ -20,24 +20,23 @@ import (
|
||||
"context"
|
||||
"strconv"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
)
|
||||
|
||||
func NewCountRepo(cli *redis.Client) *CounterImpl {
|
||||
func NewCountRepo(cli cache.Cmdable) *CounterImpl {
|
||||
return &CounterImpl{
|
||||
cacheClient: cli,
|
||||
}
|
||||
}
|
||||
|
||||
type CounterImpl struct {
|
||||
cacheClient *redis.Client
|
||||
cacheClient cache.Cmdable
|
||||
}
|
||||
|
||||
func (c *CounterImpl) Get(ctx context.Context, key string) (int64, error) {
|
||||
val, err := c.cacheClient.Get(ctx, key).Result()
|
||||
if err == redis.Nil {
|
||||
if err == cache.Nil {
|
||||
return 0, nil
|
||||
}
|
||||
if err != nil {
|
||||
|
||||
@ -20,13 +20,13 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/internal/dal/model"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/internal/dal/query"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
@ -35,10 +35,10 @@ import (
|
||||
type SingleAgentDraftDAO struct {
|
||||
idGen idgen.IDGenerator
|
||||
dbQuery *query.Query
|
||||
cacheClient *redis.Client
|
||||
cacheClient cache.Cmdable
|
||||
}
|
||||
|
||||
func NewSingleAgentDraftDAO(db *gorm.DB, idGen idgen.IDGenerator, cli *redis.Client) *SingleAgentDraftDAO {
|
||||
func NewSingleAgentDraftDAO(db *gorm.DB, idGen idgen.IDGenerator, cli cache.Cmdable) *SingleAgentDraftDAO {
|
||||
query.SetDefault(db)
|
||||
|
||||
return &SingleAgentDraftDAO{
|
||||
|
||||
@ -19,15 +19,15 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/internal/dal"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
)
|
||||
|
||||
func NewSingleAgentRepo(db *gorm.DB, idGen idgen.IDGenerator, cli *redis.Client) SingleAgentDraftRepo {
|
||||
func NewSingleAgentRepo(db *gorm.DB, idGen idgen.IDGenerator, cli cache.Cmdable) SingleAgentDraftRepo {
|
||||
return dal.NewSingleAgentDraftDAO(db, idGen, cli)
|
||||
}
|
||||
|
||||
@ -35,7 +35,7 @@ func NewSingleAgentVersionRepo(db *gorm.DB, idGen idgen.IDGenerator) SingleAgent
|
||||
return dal.NewSingleAgentVersion(db, idGen)
|
||||
}
|
||||
|
||||
func NewCounterRepo(cli *redis.Client) CounterRepository {
|
||||
func NewCounterRepo(cli cache.Cmdable) CounterRepository {
|
||||
return dal.NewCountRepo(cli)
|
||||
}
|
||||
|
||||
|
||||
@ -21,16 +21,15 @@ import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
redisV9 "github.com/redis/go-redis/v9"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
type AppCache struct {
|
||||
cacheCli *redisV9.Client
|
||||
cacheCli cache.Cmdable
|
||||
}
|
||||
|
||||
func NewAppCache(cacheCli *redisV9.Client) *AppCache {
|
||||
func NewAppCache(cacheCli cache.Cmdable) *AppCache {
|
||||
return &AppCache{
|
||||
cacheCli: cacheCli,
|
||||
}
|
||||
@ -39,7 +38,7 @@ func NewAppCache(cacheCli *redisV9.Client) *AppCache {
|
||||
func (a *AppCache) Get(ctx context.Context, key string) (value string, exist bool, err error) {
|
||||
cmd := a.cacheCli.Get(ctx, key)
|
||||
if cmd.Err() != nil {
|
||||
if errors.Is(cmd.Err(), redisV9.Nil) {
|
||||
if errors.Is(cmd.Err(), cache.Nil) {
|
||||
return "", false, nil
|
||||
}
|
||||
return "", false, cmd.Err()
|
||||
|
||||
@ -25,12 +25,12 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
redisV9 "github.com/redis/go-redis/v9"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/app/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/app/internal/dal"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/app/internal/dal/query"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
@ -50,7 +50,7 @@ type appRepoImpl struct {
|
||||
type APPRepoComponents struct {
|
||||
IDGen idgen.IDGenerator
|
||||
DB *gorm.DB
|
||||
CacheCli *redisV9.Client
|
||||
CacheCli cache.Cmdable
|
||||
}
|
||||
|
||||
func NewAPPRepo(components *APPRepoComponents) AppRepository {
|
||||
|
||||
@ -200,7 +200,7 @@ func transformEventMap(eventType singleagent.EventType) (message.MessageType, er
|
||||
return message.MessageTypeKnowledge, nil
|
||||
case singleagent.EventTypeOfToolsMessage:
|
||||
return message.MessageTypeToolResponse, nil
|
||||
case singleagent.EventTypeOfChatModelAnswer:
|
||||
case singleagent.EventTypeOfChatModelAnswer, singleagent.EventTypeOfToolsAsChatModelStream:
|
||||
return message.MessageTypeAnswer, nil
|
||||
case singleagent.EventTypeOfSuggest:
|
||||
return message.MessageTypeFlowUp, nil
|
||||
|
||||
@ -32,21 +32,19 @@ import (
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
redisV9 "github.com/redis/go-redis/v9"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
knowledgeModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/developer_api"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/knowledge/repository"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/knowledge/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/knowledge/internal/consts"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/knowledge/internal/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/knowledge/internal/dal/model"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/knowledge/internal/events"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/knowledge/processor/impl"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/knowledge/repository"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/nl2sql"
|
||||
@ -1484,7 +1482,7 @@ func (k *knowledgeSVC) getObjectURL(ctx context.Context, uri string) (string, er
|
||||
if err != nil {
|
||||
return "", errorx.New(errno.ErrKnowledgeGetObjectURLFailCode, errorx.KV("msg", fmt.Sprintf("get object url failed, %v", err)))
|
||||
}
|
||||
if errors.Is(cmd.Err(), redisV9.Nil) {
|
||||
if errors.Is(cmd.Err(), cache.Nil) {
|
||||
err = k.cacheCli.Set(ctx, uri, url, cacheTime*time.Second).Err()
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "[getObjectURL] set cache failed, %v", err)
|
||||
|
||||
@ -32,6 +32,8 @@ import (
|
||||
"github.com/tealeg/xlsx/v3"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/bot_common"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/table"
|
||||
@ -64,10 +66,10 @@ type databaseService struct {
|
||||
onlineDAO repository.OnlineDAO
|
||||
agentToDatabaseDAO repository.AgentToDatabaseDAO
|
||||
storage storage.Storage
|
||||
cache *redis.Client
|
||||
cache cache.Cmdable
|
||||
}
|
||||
|
||||
func NewService(rdb rdb.RDB, db *gorm.DB, generator idgen.IDGenerator, storage storage.Storage, cacheCli *redis.Client) Database {
|
||||
func NewService(rdb rdb.RDB, db *gorm.DB, generator idgen.IDGenerator, storage storage.Storage, cacheCli cache.Cmdable) Database {
|
||||
return &databaseService{
|
||||
rdb: rdb,
|
||||
db: db,
|
||||
@ -641,7 +643,7 @@ func (d databaseService) UpdateDatabaseRecord(ctx context.Context, req *UpdateDa
|
||||
cond := &rdb.Condition{
|
||||
Field: database.DefaultUidColName,
|
||||
Operator: entity3.OperatorEqual,
|
||||
Value: req.UserID,
|
||||
Value: strconv.FormatInt(req.UserID, 10),
|
||||
}
|
||||
|
||||
condition.Conditions = append(condition.Conditions, cond)
|
||||
@ -711,7 +713,7 @@ func (d databaseService) DeleteDatabaseRecord(ctx context.Context, req *DeleteDa
|
||||
cond := &rdb.Condition{
|
||||
Field: database.DefaultUidColName,
|
||||
Operator: entity3.OperatorEqual,
|
||||
Value: req.UserID,
|
||||
Value: strconv.FormatInt(req.UserID, 10),
|
||||
}
|
||||
|
||||
condition.Conditions = append(condition.Conditions, cond)
|
||||
@ -773,20 +775,21 @@ func (d databaseService) ListDatabaseRecord(ctx context.Context, req *ListDataba
|
||||
Conditions: []*rdb.Condition{cond},
|
||||
}
|
||||
}
|
||||
|
||||
if tableInfo.RwMode == table.BotTableRWMode_LimitedReadWrite {
|
||||
cond := &rdb.Condition{
|
||||
Field: database.DefaultUidColName,
|
||||
Operator: entity3.OperatorEqual,
|
||||
Value: req.UserID,
|
||||
}
|
||||
|
||||
if complexCondition == nil {
|
||||
complexCondition = &rdb.ComplexCondition{
|
||||
Conditions: []*rdb.Condition{cond},
|
||||
if req.TableType == table.TableType_DraftTable {
|
||||
if tableInfo.RwMode == table.BotTableRWMode_LimitedReadWrite {
|
||||
cond := &rdb.Condition{
|
||||
Field: database.DefaultUidColName,
|
||||
Operator: entity3.OperatorEqual,
|
||||
Value: strconv.FormatInt(req.UserID, 10),
|
||||
}
|
||||
|
||||
if complexCondition == nil {
|
||||
complexCondition = &rdb.ComplexCondition{
|
||||
Conditions: []*rdb.Condition{cond},
|
||||
}
|
||||
} else {
|
||||
complexCondition.Conditions = append(complexCondition.Conditions, cond)
|
||||
}
|
||||
} else {
|
||||
complexCondition.Conditions = append(complexCondition.Conditions, cond)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -21,16 +21,15 @@ import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
redisV9 "github.com/redis/go-redis/v9"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
type OAuthCache struct {
|
||||
cacheCli *redisV9.Client
|
||||
cacheCli cache.Cmdable
|
||||
}
|
||||
|
||||
func NewOAuthCache(cacheCli *redisV9.Client) *OAuthCache {
|
||||
func NewOAuthCache(cacheCli cache.Cmdable) *OAuthCache {
|
||||
return &OAuthCache{
|
||||
cacheCli: cacheCli,
|
||||
}
|
||||
@ -39,7 +38,7 @@ func NewOAuthCache(cacheCli *redisV9.Client) *OAuthCache {
|
||||
func (o *OAuthCache) Get(ctx context.Context, key string) (value string, exist bool, err error) {
|
||||
cmd := o.cacheCli.Get(ctx, key)
|
||||
if cmd.Err() != nil {
|
||||
if errors.Is(cmd.Err(), redisV9.Nil) {
|
||||
if errors.Is(cmd.Err(), cache.Nil) {
|
||||
return "", false, nil
|
||||
}
|
||||
return "", false, cmd.Err()
|
||||
|
||||
@ -21,6 +21,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
@ -42,14 +43,19 @@ func NewPluginOAuthAuthDAO(db *gorm.DB, idGen idgen.IDGenerator) *PluginOAuthAut
|
||||
type pluginOAuthAuthPO model.PluginOauthAuth
|
||||
|
||||
func (p pluginOAuthAuthPO) ToDO() *entity.AuthorizationCodeInfo {
|
||||
secret := os.Getenv(utils.OAuthTokenSecretEnv)
|
||||
if secret == "" {
|
||||
secret = utils.DefaultOAuthTokenSecret
|
||||
}
|
||||
|
||||
if p.RefreshToken != "" {
|
||||
refreshToken, err := utils.DecryptByAES(p.RefreshToken, utils.OAuthTokenSecretKey)
|
||||
refreshToken, err := utils.DecryptByAES(p.RefreshToken, secret)
|
||||
if err == nil {
|
||||
p.RefreshToken = string(refreshToken)
|
||||
}
|
||||
}
|
||||
if p.AccessToken != "" {
|
||||
accessToken, err := utils.DecryptByAES(p.AccessToken, utils.OAuthTokenSecretKey)
|
||||
accessToken, err := utils.DecryptByAES(p.AccessToken, secret)
|
||||
if err == nil {
|
||||
p.AccessToken = string(accessToken)
|
||||
}
|
||||
@ -103,16 +109,20 @@ func (p *PluginOAuthAuthDAO) Upsert(ctx context.Context, info *entity.Authorizat
|
||||
}
|
||||
|
||||
meta := info.Meta
|
||||
secret := os.Getenv(utils.OAuthTokenSecretEnv)
|
||||
if secret == "" {
|
||||
secret = utils.DefaultOAuthTokenSecret
|
||||
}
|
||||
|
||||
var accessToken, refreshToken string
|
||||
if info.AccessToken != "" {
|
||||
accessToken, err = utils.EncryptByAES([]byte(info.AccessToken), utils.OAuthTokenSecretKey)
|
||||
accessToken, err = utils.EncryptByAES([]byte(info.AccessToken), secret)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if info.RefreshToken != "" {
|
||||
refreshToken, err = utils.EncryptByAES([]byte(info.RefreshToken), utils.OAuthTokenSecretKey)
|
||||
refreshToken, err = utils.EncryptByAES([]byte(info.RefreshToken), secret)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -437,7 +437,13 @@ func genAuthURL(info *entity.AuthorizationCodeInfo) (string, error) {
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal state failed, err=%v", err)
|
||||
}
|
||||
encryptState, err := utils.EncryptByAES(stateStr, utils.StateSecretKey)
|
||||
|
||||
secret := os.Getenv(utils.StateSecretEnv)
|
||||
if secret == "" {
|
||||
secret = utils.DefaultStateSecret
|
||||
}
|
||||
|
||||
encryptState, err := utils.EncryptByAES(stateStr, secret)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("encrypt state failed, err=%v", err)
|
||||
}
|
||||
|
||||
@ -20,18 +20,131 @@ import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/bytedance/gopkg/util/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
AuthSecretKey = "^*6x3hdu2nc%-p38"
|
||||
StateSecretKey = "osj^kfhsd*(z!sno"
|
||||
OAuthTokenSecretKey = "cn+$PJ(HhJ[5d*z9"
|
||||
AuthSecretEnv = "PLUGIN_AES_AUTH_SECRET"
|
||||
StateSecretEnv = "PLUGIN_AES_STATE_SECRET"
|
||||
OAuthTokenSecretEnv = "PLUGIN_AES_OAUTH_TOKEN_SECRET"
|
||||
)
|
||||
|
||||
func EncryptByAES(val []byte, secretKey string) (string, error) {
|
||||
sb := []byte(secretKey)
|
||||
const encryptVersion = "aes-cbc-v1"
|
||||
|
||||
// In order to be compatible with the problem of no existing env configuration,
|
||||
// these default values are temporarily retained.
|
||||
const (
|
||||
// Deprecated. Configuring AuthSecretEnv in env instead.
|
||||
DefaultAuthSecret = "^*6x3hdu2nc%-p38"
|
||||
// Deprecated. Configuring StateSecretEnv in env instead.
|
||||
DefaultStateSecret = "osj^kfhsd*(z!sno"
|
||||
// Deprecated. Configuring OAuthTokenSecretEnv in env instead.
|
||||
DefaultOAuthTokenSecret = "cn+$PJ(HhJ[5d*z9"
|
||||
)
|
||||
|
||||
type AESEncryption struct {
|
||||
Version string `json:"version"`
|
||||
IV []byte `json:"iv"`
|
||||
EncryptedData []byte `json:"encrypted_data"`
|
||||
}
|
||||
|
||||
func EncryptByAES(val []byte, secret string) (string, error) {
|
||||
if secret == "" {
|
||||
return "", fmt.Errorf("secret is required")
|
||||
}
|
||||
|
||||
sb := []byte(secret)
|
||||
|
||||
block, err := aes.NewCipher(sb)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
blockSize := block.BlockSize()
|
||||
paddingData := pkcs7Padding(val, blockSize)
|
||||
|
||||
iv := make([]byte, blockSize)
|
||||
if _, err = io.ReadFull(rand.Reader, iv); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
encrypted := make([]byte, len(paddingData))
|
||||
blockMode := cipher.NewCBCEncrypter(block, iv)
|
||||
blockMode.CryptBlocks(encrypted, paddingData)
|
||||
|
||||
en := &AESEncryption{
|
||||
Version: encryptVersion,
|
||||
IV: iv,
|
||||
EncryptedData: encrypted,
|
||||
}
|
||||
|
||||
encrypted, err = json.Marshal(en)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return base64.RawURLEncoding.EncodeToString(encrypted), nil
|
||||
}
|
||||
|
||||
func pkcs7Padding(data []byte, blockSize int) []byte {
|
||||
padding := blockSize - len(data)%blockSize
|
||||
padText := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||
|
||||
return append(data, padText...)
|
||||
}
|
||||
|
||||
func DecryptByAES(data, secret string) ([]byte, error) {
|
||||
if secret == "" {
|
||||
return nil, fmt.Errorf("secret is required")
|
||||
}
|
||||
|
||||
enBytes, err := base64.RawURLEncoding.DecodeString(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
en := &AESEncryption{}
|
||||
err = json.Unmarshal(enBytes, &en)
|
||||
if err != nil { // fallback to unsafeEncryptByAES
|
||||
logger.Warnf("failed to unmarshal encrypted data, fallback to unsafeEncryptByAES: %v", err)
|
||||
return UnsafeDecryptByAES(data, secret)
|
||||
}
|
||||
|
||||
sb := []byte(secret)
|
||||
|
||||
block, err := aes.NewCipher(sb)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
blockMode := cipher.NewCBCDecrypter(block, en.IV)
|
||||
|
||||
if len(en.EncryptedData)%blockMode.BlockSize() != 0 {
|
||||
return nil, fmt.Errorf("invalid block size")
|
||||
}
|
||||
|
||||
decrypted := make([]byte, len(en.EncryptedData))
|
||||
blockMode.CryptBlocks(decrypted, en.EncryptedData)
|
||||
|
||||
decrypted, err = pkcs7UnPadding(decrypted)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return decrypted, nil
|
||||
}
|
||||
|
||||
// Deprecated: use EncryptByAES instead
|
||||
// UnsafeEncryptByAES is an insecure encryption method,
|
||||
// because the iv is fixed using the first 16 bits of the secret.
|
||||
func UnsafeEncryptByAES(val []byte, secret string) (string, error) {
|
||||
sb := []byte(secret)
|
||||
|
||||
block, err := aes.NewCipher(sb)
|
||||
if err != nil {
|
||||
@ -48,20 +161,18 @@ func EncryptByAES(val []byte, secretKey string) (string, error) {
|
||||
return base64.RawURLEncoding.EncodeToString(encrypted), nil
|
||||
}
|
||||
|
||||
func pkcs7Padding(data []byte, blockSize int) []byte {
|
||||
padding := blockSize - len(data)%blockSize
|
||||
padText := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||
|
||||
return append(data, padText...)
|
||||
}
|
||||
|
||||
func DecryptByAES(data, secretKey string) ([]byte, error) {
|
||||
// Deprecated: use DecryptByAES instead
|
||||
// UnsafeDecryptByAES is an insecure decryption method,
|
||||
// because the iv is fixed using the first 16 bits of the secret.
|
||||
// In order to be compatible with existing data that has been encrypted by UnsafeEncryptByAES,
|
||||
// this method is retained as a fallback decryption method.
|
||||
func UnsafeDecryptByAES(data, secret string) ([]byte, error) {
|
||||
dataBytes, err := base64.RawURLEncoding.DecodeString(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sb := []byte(secretKey)
|
||||
sb := []byte(secret)
|
||||
|
||||
block, err := aes.NewCipher(sb)
|
||||
if err != nil {
|
||||
|
||||
50
backend/domain/plugin/utils/aes_test.go
Normal file
50
backend/domain/plugin/utils/aes_test.go
Normal file
@ -0,0 +1,50 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/bytedance/mockey"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDecryptByAES(t *testing.T) {
|
||||
mockey.PatchConvey("unsafe encryption compatibility", t, func() {
|
||||
secret := "test_secret_1234"
|
||||
plaintext := []byte("test_plaintext")
|
||||
|
||||
encrypted, err := UnsafeEncryptByAES(plaintext, secret)
|
||||
assert.NoError(t, err)
|
||||
|
||||
decrypted, err := DecryptByAES(encrypted, secret)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, plaintext, decrypted)
|
||||
})
|
||||
|
||||
mockey.PatchConvey("safe encryption", t, func() {
|
||||
secret := "test_secret_1234"
|
||||
plaintext := []byte("test_plaintext")
|
||||
|
||||
encrypted, err := EncryptByAES(plaintext, secret)
|
||||
assert.NoError(t, err)
|
||||
|
||||
decrypted, err := DecryptByAES(encrypted, secret)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, plaintext, decrypted)
|
||||
})
|
||||
}
|
||||
@ -30,6 +30,7 @@ import (
|
||||
//go:generate mockgen -destination pluginmock/plugin_mock.go --package pluginmock -source plugin.go
|
||||
type Service interface {
|
||||
GetPluginToolsInfo(ctx context.Context, req *ToolsInfoRequest) (*ToolsInfoResponse, error)
|
||||
UnwrapArrayItemFieldsInVariable(v *vo.Variable) error
|
||||
GetPluginInvokableTools(ctx context.Context, req *ToolsInvokableRequest) (map[int64]InvokableTool, error)
|
||||
ExecutePlugin(ctx context.Context, input map[string]any, pe *Entity,
|
||||
toolID int64, cfg ExecConfig) (map[string]any, error)
|
||||
|
||||
@ -29,8 +29,9 @@ import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
plugin "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
|
||||
schema "github.com/cloudwego/eino/schema"
|
||||
plugin "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
|
||||
vo "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
@ -103,6 +104,20 @@ func (mr *MockServiceMockRecorder) GetPluginToolsInfo(ctx, req any) *gomock.Call
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPluginToolsInfo", reflect.TypeOf((*MockService)(nil).GetPluginToolsInfo), ctx, req)
|
||||
}
|
||||
|
||||
// UnwrapArrayItemFieldsInVariable mocks base method.
|
||||
func (m *MockService) UnwrapArrayItemFieldsInVariable(v *vo.Variable) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UnwrapArrayItemFieldsInVariable", v)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UnwrapArrayItemFieldsInVariable indicates an expected call of UnwrapArrayItemFieldsInVariable.
|
||||
func (mr *MockServiceMockRecorder) UnwrapArrayItemFieldsInVariable(v any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnwrapArrayItemFieldsInVariable", reflect.TypeOf((*MockService)(nil).UnwrapArrayItemFieldsInVariable), v)
|
||||
}
|
||||
|
||||
// MockInvokableTool is a mock of InvokableTool interface.
|
||||
type MockInvokableTool struct {
|
||||
ctrl *gomock.Controller
|
||||
|
||||
@ -77,8 +77,8 @@ type FunctionInfo struct {
|
||||
|
||||
type FunctionCallInfo struct {
|
||||
FunctionInfo
|
||||
CallID string `json:"-"`
|
||||
Arguments string `json:"arguments"`
|
||||
CallID string `json:"-"`
|
||||
Arguments map[string]any `json:"arguments"`
|
||||
}
|
||||
|
||||
type ToolResponseInfo struct {
|
||||
|
||||
@ -655,11 +655,10 @@ func TestKnowledgeNodes(t *testing.T) {
|
||||
mockKnowledgeOperator.EXPECT().Retrieve(gomock.Any(), gomock.Any()).Return(rResponse, nil)
|
||||
mockGlobalAppVarStore := mockvar.NewMockStore(ctrl)
|
||||
mockGlobalAppVarStore.EXPECT().Get(gomock.Any(), gomock.Any()).Return("v1", nil).AnyTimes()
|
||||
mockGlobalAppVarStore.EXPECT().Set(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
|
||||
|
||||
variable.SetVariableHandler(&variable.Handler{
|
||||
AppVarStore: mockGlobalAppVarStore,
|
||||
})
|
||||
variable.SetVariableHandler(&variable.Handler{AppVarStore: mockGlobalAppVarStore})
|
||||
|
||||
mockey.Mock(execute.GetAppVarStore).Return(&execute.AppVariables{Vars: map[string]any{}}).Build()
|
||||
|
||||
ctx := t.Context()
|
||||
ctx = ctxcache.Init(ctx)
|
||||
|
||||
@ -2061,7 +2061,7 @@ func buildClauseFromParams(params []*vo.Param) (*database.Clause, error) {
|
||||
|
||||
func parseBatchMode(n *vo.Node) (
|
||||
batchN *vo.Node, // the new batch node
|
||||
enabled bool, // whether the node has enabled batch mode
|
||||
enabled bool, // whether the node has enabled batch mode
|
||||
err error) {
|
||||
if n.Data == nil || n.Data.Inputs == nil {
|
||||
return nil, false, nil
|
||||
|
||||
@ -34,7 +34,6 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/receiver"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableassigner"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
)
|
||||
|
||||
@ -53,7 +52,6 @@ type State struct {
|
||||
|
||||
ToolInterruptEvents map[vo.NodeKey]map[string] /*ToolCallID*/ *entity.ToolInterruptEvent `json:"tool_interrupt_events,omitempty"`
|
||||
LLMToResumeData map[vo.NodeKey]string `json:"llm_to_resume_data,omitempty"`
|
||||
AppVariableStore *variableassigner.AppVariables `json:"variable_app_store,omitempty"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
@ -85,15 +83,7 @@ func init() {
|
||||
_ = compose.RegisterSerializableType[vo.SyncPattern]("sync_pattern")
|
||||
_ = compose.RegisterSerializableType[vo.Locator]("wf_locator")
|
||||
_ = compose.RegisterSerializableType[vo.BizType]("biz_type")
|
||||
_ = compose.RegisterSerializableType[*variableassigner.AppVariables]("app_variables")
|
||||
}
|
||||
|
||||
func (s *State) SetAppVariableValue(key string, value any) {
|
||||
s.AppVariableStore.Set(key, value)
|
||||
}
|
||||
|
||||
func (s *State) GetAppVariableValue(key string) (any, bool) {
|
||||
return s.AppVariableStore.Get(key)
|
||||
_ = compose.RegisterSerializableType[*execute.AppVariables]("app_variables")
|
||||
}
|
||||
|
||||
func (s *State) AddQuestion(nodeKey vo.NodeKey, question *qa.Question) {
|
||||
@ -271,19 +261,6 @@ func (s *State) NodeExecuted(key vo.NodeKey) bool {
|
||||
|
||||
func GenState() compose.GenLocalState[*State] {
|
||||
return func(ctx context.Context) (state *State) {
|
||||
var parentState *State
|
||||
_ = compose.ProcessState(ctx, func(ctx context.Context, s *State) error {
|
||||
parentState = s
|
||||
return nil
|
||||
})
|
||||
|
||||
var appVariableStore *variableassigner.AppVariables
|
||||
if parentState == nil {
|
||||
appVariableStore = variableassigner.NewAppVariables()
|
||||
} else {
|
||||
appVariableStore = parentState.AppVariableStore
|
||||
}
|
||||
|
||||
return &State{
|
||||
Answers: make(map[vo.NodeKey][]string),
|
||||
Questions: make(map[vo.NodeKey][]*qa.Question),
|
||||
@ -296,7 +273,6 @@ func GenState() compose.GenLocalState[*State] {
|
||||
GroupChoices: make(map[vo.NodeKey]map[string]int),
|
||||
ToolInterruptEvents: make(map[vo.NodeKey]map[string]*entity.ToolInterruptEvent),
|
||||
LLMToResumeData: make(map[vo.NodeKey]string),
|
||||
AppVariableStore: appVariableStore,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -422,10 +398,9 @@ func (s *NodeSchema) statePreHandlerForVars() compose.StatePreHandler[map[string
|
||||
intermediateVarStore := &nodes.ParentIntermediateStore{}
|
||||
|
||||
return func(ctx context.Context, in map[string]any, state *State) (map[string]any, error) {
|
||||
|
||||
opts := make([]variable.OptionFn, 0, 1)
|
||||
|
||||
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil {
|
||||
var exeCtx *execute.Context
|
||||
if exeCtx = execute.GetExeCtx(ctx); exeCtx != nil {
|
||||
exeCfg := execute.GetExeCtx(ctx).RootCtx.ExeCfg
|
||||
opts = append(opts, variable.WithStoreInfo(variable.StoreInfo{
|
||||
AgentID: exeCfg.AgentID,
|
||||
@ -452,13 +427,20 @@ func (s *NodeSchema) statePreHandlerForVars() compose.StatePreHandler[map[string
|
||||
case vo.GlobalAPP:
|
||||
var ok bool
|
||||
path := strings.Join(input.Source.Ref.FromPath, ".")
|
||||
if v, ok = state.GetAppVariableValue(path); !ok {
|
||||
if exeCtx == nil || exeCtx.AppVarStore == nil {
|
||||
v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if v, ok = exeCtx.AppVarStore.Get(path); !ok {
|
||||
v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
state.SetAppVariableValue(path, v)
|
||||
exeCtx.AppVarStore.Set(path, v)
|
||||
}
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid variable type: %v", *input.Source.Ref.VariableType)
|
||||
@ -494,15 +476,18 @@ func (s *NodeSchema) streamStatePreHandlerForVars() compose.StreamStatePreHandle
|
||||
var (
|
||||
variables = make(map[string]any)
|
||||
opts = make([]variable.OptionFn, 0, 1)
|
||||
exeCfg = execute.GetExeCtx(ctx).RootCtx.ExeCfg
|
||||
exeCtx *execute.Context
|
||||
)
|
||||
|
||||
opts = append(opts, variable.WithStoreInfo(variable.StoreInfo{
|
||||
AgentID: exeCfg.AgentID,
|
||||
AppID: exeCfg.AppID,
|
||||
ConnectorID: exeCfg.ConnectorID,
|
||||
ConnectorUID: exeCfg.ConnectorUID,
|
||||
}))
|
||||
if exeCtx = execute.GetExeCtx(ctx); exeCtx != nil {
|
||||
exeCfg := exeCtx.RootCtx.ExeCfg
|
||||
opts = append(opts, variable.WithStoreInfo(variable.StoreInfo{
|
||||
AgentID: exeCfg.AgentID,
|
||||
AppID: exeCfg.AppID,
|
||||
ConnectorID: exeCfg.ConnectorID,
|
||||
ConnectorUID: exeCfg.ConnectorUID,
|
||||
}))
|
||||
}
|
||||
|
||||
for _, input := range vars {
|
||||
if input == nil {
|
||||
@ -518,13 +503,20 @@ func (s *NodeSchema) streamStatePreHandlerForVars() compose.StreamStatePreHandle
|
||||
case vo.GlobalAPP:
|
||||
var ok bool
|
||||
path := strings.Join(input.Source.Ref.FromPath, ".")
|
||||
if v, ok = state.GetAppVariableValue(path); !ok {
|
||||
if exeCtx == nil || exeCtx.AppVarStore == nil {
|
||||
v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if v, ok = exeCtx.AppVarStore.Get(path); !ok {
|
||||
v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
state.SetAppVariableValue(path, v)
|
||||
exeCtx.AppVarStore.Set(path, v)
|
||||
}
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid variable type: %v", *input.Source.Ref.VariableType)
|
||||
@ -776,7 +768,8 @@ func (s *NodeSchema) statePostHandlerForVars() compose.StatePostHandler[map[stri
|
||||
return func(ctx context.Context, in map[string]any, state *State) (map[string]any, error) {
|
||||
opts := make([]variable.OptionFn, 0, 1)
|
||||
|
||||
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil {
|
||||
var exeCtx *execute.Context
|
||||
if exeCtx = execute.GetExeCtx(ctx); exeCtx != nil {
|
||||
exeCfg := execute.GetExeCtx(ctx).RootCtx.ExeCfg
|
||||
opts = append(opts, variable.WithStoreInfo(variable.StoreInfo{
|
||||
AgentID: exeCfg.AgentID,
|
||||
@ -801,13 +794,20 @@ func (s *NodeSchema) statePostHandlerForVars() compose.StatePostHandler[map[stri
|
||||
case vo.GlobalAPP:
|
||||
var ok bool
|
||||
path := strings.Join(input.Source.Ref.FromPath, ".")
|
||||
if v, ok = state.GetAppVariableValue(path); !ok {
|
||||
if exeCtx == nil || exeCtx.AppVarStore == nil {
|
||||
v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if v, ok = exeCtx.AppVarStore.Get(path); !ok {
|
||||
v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
state.SetAppVariableValue(path, v)
|
||||
exeCtx.AppVarStore.Set(path, v)
|
||||
}
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid variable type: %v", *input.Source.Ref.VariableType)
|
||||
@ -845,9 +845,10 @@ func (s *NodeSchema) streamStatePostHandlerForVars() compose.StreamStatePostHand
|
||||
var (
|
||||
variables = make(map[string]any)
|
||||
opts = make([]variable.OptionFn, 0, 1)
|
||||
exeCtx *execute.Context
|
||||
)
|
||||
|
||||
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil {
|
||||
if exeCtx = execute.GetExeCtx(ctx); exeCtx != nil {
|
||||
exeCfg := execute.GetExeCtx(ctx).RootCtx.ExeCfg
|
||||
opts = append(opts, variable.WithStoreInfo(variable.StoreInfo{
|
||||
AgentID: exeCfg.AgentID,
|
||||
@ -869,13 +870,20 @@ func (s *NodeSchema) streamStatePostHandlerForVars() compose.StreamStatePostHand
|
||||
case vo.GlobalAPP:
|
||||
var ok bool
|
||||
path := strings.Join(input.Source.Ref.FromPath, ".")
|
||||
if v, ok = state.GetAppVariableValue(path); !ok {
|
||||
if exeCtx == nil || exeCtx.AppVarStore == nil {
|
||||
v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if v, ok = exeCtx.AppVarStore.Get(path); !ok {
|
||||
v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
state.SetAppVariableValue(path, v)
|
||||
exeCtx.AppVarStore.Set(path, v)
|
||||
}
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid variable type: %v", *input.Source.Ref.VariableType)
|
||||
|
||||
@ -27,6 +27,8 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
|
||||
"github.com/cloudwego/eino/callbacks"
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
@ -1271,13 +1273,21 @@ func (t *ToolHandler) OnStart(ctx context.Context, info *callbacks.RunInfo,
|
||||
return ctx
|
||||
}
|
||||
|
||||
var args map[string]any
|
||||
if input.ArgumentsInJSON != "" {
|
||||
if err := sonic.UnmarshalString(input.ArgumentsInJSON, &args); err != nil {
|
||||
logs.Errorf("failed to unmarshal arguments: %v", err)
|
||||
return ctx
|
||||
}
|
||||
}
|
||||
|
||||
t.ch <- &Event{
|
||||
Type: FunctionCall,
|
||||
Context: GetExeCtx(ctx),
|
||||
functionCall: &entity.FunctionCallInfo{
|
||||
FunctionInfo: t.info,
|
||||
CallID: compose.GetToolCallID(ctx),
|
||||
Arguments: input.ArgumentsInJSON,
|
||||
Arguments: args,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@ -22,6 +22,7 @@ import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
@ -45,6 +46,8 @@ type Context struct {
|
||||
StartTime int64 // UnixMilli
|
||||
|
||||
CheckPointID string
|
||||
|
||||
AppVarStore *AppVariables
|
||||
}
|
||||
|
||||
type RootCtx struct {
|
||||
@ -106,12 +109,15 @@ func restoreWorkflowCtx(ctx context.Context, h *WorkflowHandler) (context.Contex
|
||||
}
|
||||
|
||||
storedCtx.ResumeEvent = h.resumeEvent
|
||||
currentC := GetExeCtx(ctx)
|
||||
if currentC != nil {
|
||||
// restore the parent-child relationship between token collectors
|
||||
if storedCtx.TokenCollector != nil && storedCtx.TokenCollector.Parent != nil {
|
||||
currentTokenCollector := currentC.TokenCollector
|
||||
storedCtx.TokenCollector.Parent = currentTokenCollector
|
||||
}
|
||||
|
||||
// restore the parent-child relationship between token collectors
|
||||
if storedCtx.TokenCollector != nil && storedCtx.TokenCollector.Parent != nil {
|
||||
currentC := GetExeCtx(ctx)
|
||||
currentTokenCollector := currentC.TokenCollector
|
||||
storedCtx.TokenCollector.Parent = currentTokenCollector
|
||||
storedCtx.AppVarStore = currentC.AppVarStore
|
||||
}
|
||||
|
||||
return context.WithValue(ctx, contextKey{}, storedCtx), nil
|
||||
@ -150,13 +156,16 @@ func restoreNodeCtx(ctx context.Context, nodeKey vo.NodeKey, resumeEvent *entity
|
||||
storedCtx.RootCtx.ResumeEvent = existingC.RootCtx.ResumeEvent
|
||||
}
|
||||
|
||||
currentC := GetExeCtx(ctx)
|
||||
|
||||
// restore the parent-child relationship between token collectors
|
||||
if storedCtx.TokenCollector != nil && storedCtx.TokenCollector.Parent != nil {
|
||||
currentC := GetExeCtx(ctx)
|
||||
currentTokenCollector := currentC.TokenCollector
|
||||
storedCtx.TokenCollector.Parent = currentTokenCollector
|
||||
}
|
||||
|
||||
storedCtx.AppVarStore = currentC.AppVarStore
|
||||
|
||||
storedCtx.NodeCtx.CurrentRetryCount = 0
|
||||
|
||||
return context.WithValue(ctx, contextKey{}, storedCtx), nil
|
||||
@ -184,6 +193,7 @@ func tryRestoreNodeCtx(ctx context.Context, nodeKey vo.NodeKey) (context.Context
|
||||
existingC := GetExeCtx(ctx)
|
||||
if existingC != nil {
|
||||
storedCtx.RootCtx.ResumeEvent = existingC.RootCtx.ResumeEvent
|
||||
storedCtx.AppVarStore = existingC.AppVarStore
|
||||
}
|
||||
|
||||
// restore the parent-child relationship between token collectors
|
||||
@ -213,6 +223,7 @@ func PrepareRootExeCtx(ctx context.Context, h *WorkflowHandler) (context.Context
|
||||
|
||||
TokenCollector: newTokenCollector(fmt.Sprintf("wf_%d", h.rootWorkflowBasic.ID), parentTokenCollector),
|
||||
StartTime: time.Now().UnixMilli(),
|
||||
AppVarStore: NewAppVariables(),
|
||||
}
|
||||
|
||||
if h.requireCheckpoint {
|
||||
@ -266,6 +277,7 @@ func PrepareSubExeCtx(ctx context.Context, wb *entity.WorkflowBasic, requireChec
|
||||
TokenCollector: newTokenCollector(fmt.Sprintf("sub_wf_%d", wb.ID), c.TokenCollector),
|
||||
CheckPointID: newCheckpointID,
|
||||
StartTime: time.Now().UnixMilli(),
|
||||
AppVarStore: c.AppVarStore,
|
||||
}
|
||||
|
||||
if requireCheckpoint {
|
||||
@ -308,6 +320,7 @@ func PrepareNodeExeCtx(ctx context.Context, nodeKey vo.NodeKey, nodeName string,
|
||||
BatchInfo: c.BatchInfo,
|
||||
StartTime: time.Now().UnixMilli(),
|
||||
CheckPointID: c.CheckPointID,
|
||||
AppVarStore: c.AppVarStore,
|
||||
}
|
||||
|
||||
if c.NodeCtx == nil { // node within top level workflow, also not under composite node
|
||||
@ -354,6 +367,7 @@ func InheritExeCtxWithBatchInfo(ctx context.Context, index int, items map[string
|
||||
CompositeNodeKey: c.NodeCtx.NodeKey,
|
||||
},
|
||||
CheckPointID: newCheckpointID,
|
||||
AppVarStore: c.AppVarStore,
|
||||
}), newCheckpointID
|
||||
}
|
||||
|
||||
@ -363,3 +377,38 @@ type ExeContextStore interface {
|
||||
GetWorkflowCtx() (*Context, bool, error)
|
||||
SetWorkflowCtx(value *Context) error
|
||||
}
|
||||
|
||||
type AppVariables struct {
|
||||
Vars map[string]any
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewAppVariables() *AppVariables {
|
||||
return &AppVariables{
|
||||
Vars: make(map[string]any),
|
||||
}
|
||||
}
|
||||
|
||||
func (av *AppVariables) Set(key string, value any) {
|
||||
av.mu.Lock()
|
||||
av.Vars[key] = value
|
||||
av.mu.Unlock()
|
||||
}
|
||||
|
||||
func (av *AppVariables) Get(key string) (any, bool) {
|
||||
av.mu.RLock()
|
||||
defer av.mu.RUnlock()
|
||||
|
||||
if value, ok := av.Vars[key]; ok {
|
||||
return value, ok
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func GetAppVarStore(ctx context.Context) *AppVariables {
|
||||
c := ctx.Value(contextKey{})
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
return c.(*Context).AppVarStore
|
||||
}
|
||||
|
||||
@ -887,7 +887,12 @@ func getFCInfos(ctx context.Context, nodeKey vo.NodeKey) map[string]*fcInfo {
|
||||
}
|
||||
|
||||
func (f *fcInfo) inputString() string {
|
||||
if f.input == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
m, err := sonic.MarshalString(f.input)
|
||||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@ -899,12 +904,5 @@ func (f *fcInfo) outputString() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"data": f.output.Response, // TODO: traceID, code, message?
|
||||
}
|
||||
b, err := sonic.MarshalString(m)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return b
|
||||
return f.output.Response
|
||||
}
|
||||
|
||||
@ -108,8 +108,8 @@ var pythonBuiltinBlacklist = map[string]struct{}{
|
||||
// If you want to use other third-party libraries, you can add them to this whitelist.
|
||||
// And you also need to install them in `/scripts/setup/python.sh` and `/backend/Dockerfile` via `pip install`.
|
||||
var pythonThirdPartyWhitelist = map[string]struct{}{
|
||||
"requests_async": {},
|
||||
"numpy": {},
|
||||
"httpx": {},
|
||||
"numpy": {},
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
|
||||
@ -18,9 +18,9 @@ package variableassigner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
@ -32,38 +32,6 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
type AppVariables struct {
|
||||
vars map[string]any
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewAppVariables() *AppVariables {
|
||||
return &AppVariables{
|
||||
vars: make(map[string]any),
|
||||
}
|
||||
}
|
||||
|
||||
func (av *AppVariables) Set(key string, value any) {
|
||||
av.mu.Lock()
|
||||
av.vars[key] = value
|
||||
av.mu.Unlock()
|
||||
}
|
||||
|
||||
func (av *AppVariables) Get(key string) (any, bool) {
|
||||
av.mu.RLock()
|
||||
defer av.mu.RUnlock()
|
||||
|
||||
if value, ok := av.vars[key]; ok {
|
||||
return value, ok
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
type AppVariableStore interface {
|
||||
GetAppVariableValue(key string) (any, bool)
|
||||
SetAppVariableValue(key string, value any)
|
||||
}
|
||||
|
||||
type VariableAssigner struct {
|
||||
config *Config
|
||||
}
|
||||
@ -109,16 +77,16 @@ func (v *VariableAssigner) Assign(ctx context.Context, in map[string]any) (map[s
|
||||
vType := *pair.Left.VariableType
|
||||
switch vType {
|
||||
case vo.GlobalAPP:
|
||||
err := compose.ProcessState(ctx, func(ctx context.Context, appVarsStore AppVariableStore) error {
|
||||
if len(pair.Left.FromPath) != 1 {
|
||||
return fmt.Errorf("can only assign to top level variable: %v", pair.Left.FromPath)
|
||||
}
|
||||
appVarsStore.SetAppVariableValue(pair.Left.FromPath[0], right)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
appVS := execute.GetAppVarStore(ctx)
|
||||
if appVS == nil {
|
||||
return nil, errors.New("exeCtx or AppVarStore not found for variable assigner")
|
||||
}
|
||||
|
||||
if len(pair.Left.FromPath) != 1 {
|
||||
return nil, fmt.Errorf("can only assign to top level variable: %v", pair.Left.FromPath)
|
||||
}
|
||||
|
||||
appVS.Set(pair.Left.FromPath[0], right)
|
||||
case vo.GlobalUser:
|
||||
opts := make([]variable.OptionFn, 0, 1)
|
||||
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil {
|
||||
|
||||
@ -21,14 +21,13 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
type cancelSignalStoreImpl struct {
|
||||
redis *redis.Client
|
||||
redis cache.Cmdable
|
||||
}
|
||||
|
||||
const workflowExecutionCancelStatusKey = "workflow:cancel:status:%d"
|
||||
|
||||
@ -23,13 +23,13 @@ import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/repo/dal/model"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/repo/dal/query"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
|
||||
@ -40,7 +40,7 @@ import (
|
||||
|
||||
type executeHistoryStoreImpl struct {
|
||||
query *query.Query
|
||||
redis *redis.Client
|
||||
redis cache.Cmdable
|
||||
}
|
||||
|
||||
func (e *executeHistoryStoreImpl) CreateWorkflowExecution(ctx context.Context, execution *entity.WorkflowExecution) (err error) {
|
||||
@ -457,7 +457,7 @@ func (e *executeHistoryStoreImpl) loadNodeExecutionFromRedis(ctx context.Context
|
||||
|
||||
result, err := e.redis.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
if errors.Is(err, cache.Nil) {
|
||||
return nil
|
||||
}
|
||||
return vo.WrapError(errno.ErrRedisError, err)
|
||||
@ -523,7 +523,7 @@ func (e *executeHistoryStoreImpl) GetTestRunLatestExeID(ctx context.Context, wfI
|
||||
key := fmt.Sprintf(testRunLastExeKey, wfID, uID)
|
||||
exeIDStr, err := e.redis.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
if errors.Is(err, cache.Nil) {
|
||||
return 0, nil
|
||||
}
|
||||
return 0, vo.WrapError(errno.ErrRedisError, err)
|
||||
@ -548,7 +548,7 @@ func (e *executeHistoryStoreImpl) GetNodeDebugLatestExeID(ctx context.Context, w
|
||||
key := fmt.Sprintf(nodeDebugLastExeKey, wfID, nodeID, uID)
|
||||
exeIDStr, err := e.redis.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
if errors.Is(err, cache.Nil) {
|
||||
return 0, nil
|
||||
}
|
||||
return 0, vo.WrapError(errno.ErrRedisError, err)
|
||||
|
||||
@ -22,16 +22,15 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
type interruptEventStoreImpl struct {
|
||||
redis *redis.Client
|
||||
redis cache.Cmdable
|
||||
}
|
||||
|
||||
const (
|
||||
@ -81,7 +80,7 @@ func (i *interruptEventStoreImpl) SaveInterruptEvents(ctx context.Context, wfExe
|
||||
|
||||
previousEventStr, err := i.redis.Get(ctx, previousResumedEventKey).Result()
|
||||
if err != nil {
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
if !errors.Is(err, cache.Nil) {
|
||||
return fmt.Errorf("failed to get previous resumed event for wfExeID %d: %w", wfExeID, err)
|
||||
}
|
||||
}
|
||||
@ -154,7 +153,7 @@ func (i *interruptEventStoreImpl) GetFirstInterruptEvent(ctx context.Context, wf
|
||||
|
||||
eventJSON, err := i.redis.LIndex(ctx, listKey, 0).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
if errors.Is(err, cache.Nil) {
|
||||
return nil, false, nil // List is empty or key does not exist
|
||||
}
|
||||
return nil, false, fmt.Errorf("failed to get first interrupt event from Redis list for wfExeID %d: %w", wfExeID, err)
|
||||
@ -203,7 +202,7 @@ func (i *interruptEventStoreImpl) PopFirstInterruptEvent(ctx context.Context, wf
|
||||
|
||||
eventJSON, err := i.redis.LPop(ctx, listKey).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
if errors.Is(err, cache.Nil) {
|
||||
return nil, false, nil // List is empty or key does not exist
|
||||
}
|
||||
return nil, false, vo.WrapError(errno.ErrRedisError,
|
||||
@ -227,7 +226,7 @@ func (i *interruptEventStoreImpl) ListInterruptEvents(ctx context.Context, wfExe
|
||||
|
||||
eventJSONs, err := i.redis.LRange(ctx, listKey, 0, -1).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
if errors.Is(err, cache.Nil) {
|
||||
return nil, nil // List is empty or key does not exist
|
||||
}
|
||||
return nil, vo.WrapError(errno.ErrRedisError,
|
||||
|
||||
@ -25,7 +25,6 @@ import (
|
||||
|
||||
einoCompose "github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"golang.org/x/exp/maps"
|
||||
"gorm.io/gen"
|
||||
"gorm.io/gen/field"
|
||||
@ -41,6 +40,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/repo/dal/model"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/repo/dal/query"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
||||
cm "github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
@ -61,7 +61,7 @@ const (
|
||||
type RepositoryImpl struct {
|
||||
idgen.IDGenerator
|
||||
query *query.Query
|
||||
redis *redis.Client
|
||||
redis cache.Cmdable
|
||||
tos storage.Storage
|
||||
einoCompose.CheckPointStore
|
||||
workflow.InterruptEventStore
|
||||
@ -70,7 +70,7 @@ type RepositoryImpl struct {
|
||||
builtinModel cm.BaseChatModel
|
||||
}
|
||||
|
||||
func NewRepository(idgen idgen.IDGenerator, db *gorm.DB, redis *redis.Client, tos storage.Storage,
|
||||
func NewRepository(idgen idgen.IDGenerator, db *gorm.DB, redis cache.Cmdable, tos storage.Storage,
|
||||
cpStore einoCompose.CheckPointStore, chatModel cm.BaseChatModel) workflow.Repository {
|
||||
return &RepositoryImpl{
|
||||
IDGenerator: idgen,
|
||||
|
||||
@ -24,7 +24,7 @@ import (
|
||||
"time"
|
||||
|
||||
einoCompose "github.com/cloudwego/eino/compose"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"github.com/spf13/cast"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/sync/errgroup"
|
||||
@ -39,6 +39,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/adaptor"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/repo"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
@ -67,7 +68,7 @@ func NewWorkflowService(repo workflow.Repository) workflow.Service {
|
||||
}
|
||||
}
|
||||
|
||||
func NewWorkflowRepository(idgen idgen.IDGenerator, db *gorm.DB, redis *redis.Client, tos storage.Storage,
|
||||
func NewWorkflowRepository(idgen idgen.IDGenerator, db *gorm.DB, redis cache.Cmdable, tos storage.Storage,
|
||||
cpStore einoCompose.CheckPointStore, chatModel chatmodel.BaseChatModel) workflow.Repository {
|
||||
return repo.NewRepository(idgen, db, redis, tos, cpStore, chatModel)
|
||||
}
|
||||
|
||||
@ -39,7 +39,7 @@ require (
|
||||
go.uber.org/mock v0.5.1
|
||||
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394
|
||||
golang.org/x/mod v0.25.0
|
||||
golang.org/x/sync v0.15.0
|
||||
golang.org/x/sync v0.16.0
|
||||
gopkg.in/yaml.v2 v2.4.0
|
||||
gorm.io/driver/mysql v1.5.7
|
||||
gorm.io/driver/sqlite v1.4.3
|
||||
@ -55,7 +55,7 @@ require github.com/alicebob/miniredis/v2 v2.34.0
|
||||
require (
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.84.1
|
||||
github.com/cloudwego/eino-ext/components/embedding/ark v0.0.0-20250522060253-ddb617598b09
|
||||
github.com/cloudwego/eino-ext/components/embedding/ark v0.1.0
|
||||
github.com/cloudwego/eino-ext/components/embedding/ollama v0.0.0-20250728060543-79ec300857b8
|
||||
github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20250522060253-ddb617598b09
|
||||
github.com/cloudwego/eino-ext/components/model/gemini v0.1.2
|
||||
|
||||
@ -926,8 +926,8 @@ github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCy
|
||||
github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
|
||||
github.com/cloudwego/eino v0.3.55 h1:lMZrGtEh0k3qykQTLNXSXuAa98OtF2tS43GMHyvN7nA=
|
||||
github.com/cloudwego/eino v0.3.55/go.mod h1:wUjz990apdsaOraOXdh6CdhVXq8DJsOvLsVlxNTcNfY=
|
||||
github.com/cloudwego/eino-ext/components/embedding/ark v0.0.0-20250522060253-ddb617598b09 h1:hZScBE/Etiji2RqjlABcAkq6n1uzYPu+jo4GV5TF8Hc=
|
||||
github.com/cloudwego/eino-ext/components/embedding/ark v0.0.0-20250522060253-ddb617598b09/go.mod h1:pLtH5BZKgb7/bB8+P3W5/f1d46gTl9K77+08j88Gb4k=
|
||||
github.com/cloudwego/eino-ext/components/embedding/ark v0.1.0 h1:AuJsMdaTXc+dGUDQp82MifLYK8oiJf4gLQPUETmKISM=
|
||||
github.com/cloudwego/eino-ext/components/embedding/ark v0.1.0/go.mod h1:0FZG/KRBl3hGWkNsm55UaXyVa6PDVIy5u+QvboAB+cY=
|
||||
github.com/cloudwego/eino-ext/components/embedding/ollama v0.0.0-20250728060543-79ec300857b8 h1:uJrs6SmfYnca8A+k9+3qJ4MYwYHMncUlGac1mYQT+Ak=
|
||||
github.com/cloudwego/eino-ext/components/embedding/ollama v0.0.0-20250728060543-79ec300857b8/go.mod h1:nav79aUcd+UR24dLA+7l7RcHCMlg26zbDAKvjONdrw0=
|
||||
github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20250522060253-ddb617598b09 h1:C8RjF193iguUuevkuv0q4SC+XGlM/DlJEgic7l8OUAI=
|
||||
@ -2261,8 +2261,8 @@ golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8=
|
||||
golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
||||
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
|
||||
2
backend/infra/contract/cache/cache.go
vendored
2
backend/infra/contract/cache/cache.go
vendored
@ -21,3 +21,5 @@ import (
|
||||
)
|
||||
|
||||
type Cmdable = redis.Cmdable
|
||||
|
||||
const Nil = redis.Nil
|
||||
|
||||
@ -24,10 +24,12 @@ import (
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
||||
)
|
||||
|
||||
type redisStore struct {
|
||||
client *redis.Client
|
||||
client cache.Cmdable
|
||||
}
|
||||
|
||||
const (
|
||||
@ -38,7 +40,7 @@ const (
|
||||
func (r *redisStore) Get(ctx context.Context, checkPointID string) ([]byte, bool, error) {
|
||||
v, err := r.client.Get(ctx, fmt.Sprintf(checkpointKeyTpl, checkPointID)).Bytes()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
if errors.Is(err, cache.Nil) {
|
||||
return nil, false, nil
|
||||
}
|
||||
return nil, false, err
|
||||
|
||||
@ -22,8 +22,6 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/progressbar"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
@ -90,7 +88,7 @@ func (p *ProgressBarImpl) GetProgress(ctx context.Context) (percent int, remainS
|
||||
err error
|
||||
)
|
||||
errMsg, err = p.CacheCli.Get(ctx, fmt.Sprintf(ProgressBarErrMsgRedisKey, p.PrimaryKeyID)).Result()
|
||||
if err == redis.Nil {
|
||||
if err == cache.Nil {
|
||||
errMsg = ""
|
||||
} else if err != nil {
|
||||
return ProcessDone, 0, err.Error()
|
||||
@ -99,7 +97,7 @@ func (p *ProgressBarImpl) GetProgress(ctx context.Context) (percent int, remainS
|
||||
return ProcessDone, 0, errMsg
|
||||
}
|
||||
totalNumStr, err := p.CacheCli.Get(ctx, fmt.Sprintf(ProgressBarTotalNumRedisKey, p.PrimaryKeyID)).Result()
|
||||
if err == redis.Nil || len(totalNumStr) == 0 {
|
||||
if err == cache.Nil || len(totalNumStr) == 0 {
|
||||
totalNum = ptr.Of(int64(0))
|
||||
} else if err != nil {
|
||||
return ProcessDone, 0, err.Error()
|
||||
@ -112,7 +110,7 @@ func (p *ProgressBarImpl) GetProgress(ctx context.Context) (percent int, remainS
|
||||
}
|
||||
}
|
||||
processedNumStr, err := p.CacheCli.Get(ctx, fmt.Sprintf(ProgressBarProcessedNumRedisKey, p.PrimaryKeyID)).Result()
|
||||
if err == redis.Nil || len(processedNumStr) == 0 {
|
||||
if err == cache.Nil || len(processedNumStr) == 0 {
|
||||
processedNum = ptr.Of(int64(0))
|
||||
} else if err != nil {
|
||||
return ProcessDone, 0, err.Error()
|
||||
@ -128,7 +126,7 @@ func (p *ProgressBarImpl) GetProgress(ctx context.Context) (percent int, remainS
|
||||
return ProcessInit, DefaultProcessTime, ""
|
||||
}
|
||||
startTimeStr, err := p.CacheCli.Get(ctx, fmt.Sprintf(ProgressBarStartTimeRedisKey, p.PrimaryKeyID)).Result()
|
||||
if err == redis.Nil || len(startTimeStr) == 0 {
|
||||
if err == cache.Nil || len(startTimeStr) == 0 {
|
||||
startTime = ptr.Of(int64(0))
|
||||
} else if err != nil {
|
||||
return ProcessDone, 0, err.Error()
|
||||
|
||||
@ -72,6 +72,12 @@ func (v *vkSearchStore) Store(ctx context.Context, docs []*schema.Document, opts
|
||||
if err := v.collection.UpsertData(docsWithVector); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if implSpecOptions.ProgressBar != nil {
|
||||
if err = implSpecOptions.ProgressBar.AddN(len(part)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ids = slices.Transform(docs, func(a *schema.Document) string { return a.ID })
|
||||
|
||||
@ -21,7 +21,7 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
)
|
||||
@ -33,7 +33,7 @@ const (
|
||||
|
||||
type IDGenerator = idgen.IDGenerator
|
||||
|
||||
func New(client *redis.Client) (idgen.IDGenerator, error) {
|
||||
func New(client cache.Cmdable) (idgen.IDGenerator, error) {
|
||||
// Initialization code.
|
||||
return &idGenImpl{
|
||||
cli: client,
|
||||
@ -41,7 +41,7 @@ func New(client *redis.Client) (idgen.IDGenerator, error) {
|
||||
}
|
||||
|
||||
type idGenImpl struct {
|
||||
cli *redis.Client
|
||||
cli cache.Cmdable
|
||||
namespace string
|
||||
}
|
||||
|
||||
|
||||
@ -21,15 +21,15 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
||||
)
|
||||
|
||||
type JsonCache[T any] struct {
|
||||
cache *redis.Client
|
||||
cache cache.Cmdable
|
||||
prefix string
|
||||
}
|
||||
|
||||
func New[T any](prefix string, cache *redis.Client) *JsonCache[T] {
|
||||
func New[T any](prefix string, cache cache.Cmdable) *JsonCache[T] {
|
||||
return &JsonCache[T]{
|
||||
prefix: prefix,
|
||||
cache: cache,
|
||||
@ -59,7 +59,7 @@ func (g *JsonCache[T]) Get(ctx context.Context, k string) (*T, error) {
|
||||
var obj T
|
||||
|
||||
data, err := g.cache.Get(ctx, key).Result()
|
||||
if err == redis.Nil {
|
||||
if err == cache.Nil {
|
||||
return &obj, nil
|
||||
}
|
||||
|
||||
|
||||
@ -96,33 +96,35 @@ export VIKING_DB_MODEL_NAME="" # if vikingdb model name is not set, you need to
|
||||
# Settings for Embedding
|
||||
# The Embedding model relied on by knowledge base vectorization does not need to be configured
|
||||
# if the vector database comes with built-in Embedding functionality (such as VikingDB). Currently,
|
||||
# Coze Studio supports three access methods: openai, ark, ollama, and custom http. Users can simply choose one of them when using
|
||||
# Coze Studio supports four access methods: openai, ark, ollama, and custom http. Users can simply choose one of them when using
|
||||
# embedding type: openai / ark / ollama / http
|
||||
export EMBEDDING_TYPE="ark"
|
||||
export EMBEDDING_MAX_BATCH_SIZE=100
|
||||
# openai embedding
|
||||
export OPENAI_EMBEDDING_BASE_URL="" # (string) OpenAI base_url
|
||||
export OPENAI_EMBEDDING_MODEL="" # (string) OpenAI embedding model
|
||||
export OPENAI_EMBEDDING_API_KEY="" # (string) OpenAI api_key
|
||||
export OPENAI_EMBEDDING_BY_AZURE=false # (bool) OpenAI by_azure
|
||||
export OPENAI_EMBEDDING_API_VERSION="" # OpenAI azure api version
|
||||
export OPENAI_EMBEDDING_DIMS=1024 # (int) 向量维度
|
||||
export OPENAI_EMBEDDING_REQUEST_DIMS=1024
|
||||
|
||||
# ark embedding
|
||||
export ARK_EMBEDDING_MODEL=""
|
||||
export ARK_EMBEDDING_API_KEY=""
|
||||
export ARK_EMBEDDING_DIMS="2048"
|
||||
export ARK_EMBEDDING_BASE_URL=""
|
||||
# openai embedding
|
||||
export OPENAI_EMBEDDING_BASE_URL="" # (string, required) OpenAI embedding base_url
|
||||
export OPENAI_EMBEDDING_MODEL="" # (string, required) OpenAI embedding model
|
||||
export OPENAI_EMBEDDING_API_KEY="" # (string, required) OpenAI embedding api_key
|
||||
export OPENAI_EMBEDDING_BY_AZURE=false # (bool, optional) OpenAI embedding by_azure
|
||||
export OPENAI_EMBEDDING_API_VERSION="" # (string, optional) OpenAI embedding azure api version
|
||||
export OPENAI_EMBEDDING_DIMS=1024 # (int, required) OpenAI embedding dimensions
|
||||
export OPENAI_EMBEDDING_REQUEST_DIMS=1024 # (int, optional) OpenAI embedding dimensions in requests, need to be empty if api doesn't support specifying dimensions.
|
||||
|
||||
# ark embedding by volcengine / byteplus
|
||||
export ARK_EMBEDDING_MODEL="" # (string, required) Ark embedding model
|
||||
export ARK_EMBEDDING_API_KEY="" # (string, required) Ark embedding api_key
|
||||
export ARK_EMBEDDING_DIMS="2048" # (int, required) Ark embedding dimensions
|
||||
export ARK_EMBEDDING_BASE_URL="" # (string, required) Ark embedding base_url
|
||||
export ARK_EMBEDDING_API_TYPE="" # (string, optional) Ark embedding api type, should be "text_api" / "multi_modal_api". Default "text_api".
|
||||
|
||||
# ollama embedding
|
||||
export OLLAMA_EMBEDDING_BASE_URL=""
|
||||
export OLLAMA_EMBEDDING_MODEL=""
|
||||
export OLLAMA_EMBEDDING_DIMS=""
|
||||
export OLLAMA_EMBEDDING_BASE_URL="" # (string, required) Ollama embedding base_url
|
||||
export OLLAMA_EMBEDDING_MODEL="" # (string, required) Ollama embedding model
|
||||
export OLLAMA_EMBEDDING_DIMS="" # (int, required) Ollama embedding dimensions
|
||||
|
||||
# http embedding
|
||||
export HTTP_EMBEDDING_ADDR=""
|
||||
export HTTP_EMBEDDING_DIMS=1024
|
||||
export HTTP_EMBEDDING_ADDR="" # (string, required) http embedding address
|
||||
export HTTP_EMBEDDING_DIMS=1024 # (string, required) http embedding dimensions
|
||||
|
||||
# Settings for OCR
|
||||
# If you want to use the OCR-related functions in the knowledge base feature,You need to set up the OCR configuration.
|
||||
@ -217,3 +219,10 @@ export CODE_RUNNER_MEMORY_LIMIT_MB=""
|
||||
export DISABLE_USER_REGISTRATION="" # default "", if you want to disable, set to true
|
||||
export ALLOW_REGISTRATION_EMAIL="" # is a list of email addresses, separated by ",". Example: "11@example.com,22@example.com"
|
||||
|
||||
# Plugin AES secret
|
||||
# PLUGIN_AES_AUTH_SECRET is the secret of used to encrypt plugin authorization payload.
|
||||
export PLUGIN_AES_AUTH_SECRET="^*6x3hdu2nc%-p38"
|
||||
# PLUGIN_AES_STATE_SECRET is the secret of used to encrypt oauth state.
|
||||
export PLUGIN_AES_STATE_SECRET="osj^kfhsd*(z!sno"
|
||||
# PLUGIN_AES_OAUTH_TOKEN_SECRET is the secret of used to encrypt oauth refresh token and access token.
|
||||
export PLUGIN_AES_OAUTH_TOKEN_SECRET="cn+$PJ(HhJ[5d*z9"
|
||||
|
||||
@ -96,33 +96,35 @@ export VIKING_DB_MODEL_NAME="" # if vikingdb model name is not set, you need to
|
||||
# Settings for Embedding
|
||||
# The Embedding model relied on by knowledge base vectorization does not need to be configured
|
||||
# if the vector database comes with built-in Embedding functionality (such as VikingDB). Currently,
|
||||
# Coze Studio supports three access methods: openai, ark, ollama, and custom http. Users can simply choose one of them when using
|
||||
# Coze Studio supports four access methods: openai, ark, ollama, and custom http. Users can simply choose one of them when using
|
||||
# embedding type: openai / ark / ollama / http
|
||||
export EMBEDDING_TYPE="ark"
|
||||
export EMBEDDING_MAX_BATCH_SIZE=100
|
||||
# openai embedding
|
||||
export OPENAI_EMBEDDING_BASE_URL="" # (string) OpenAI base_url
|
||||
export OPENAI_EMBEDDING_MODEL="" # (string) OpenAI embedding model
|
||||
export OPENAI_EMBEDDING_API_KEY="" # (string) OpenAI api_key
|
||||
export OPENAI_EMBEDDING_BY_AZURE=false # (bool) OpenAI by_azure
|
||||
export OPENAI_EMBEDDING_API_VERSION="" # OpenAI azure api version
|
||||
export OPENAI_EMBEDDING_DIMS=1024 # (int) 向量维度
|
||||
export OPENAI_EMBEDDING_REQUEST_DIMS=1024
|
||||
|
||||
# ark embedding
|
||||
export ARK_EMBEDDING_MODEL=""
|
||||
export ARK_EMBEDDING_API_KEY=""
|
||||
export ARK_EMBEDDING_DIMS="2048"
|
||||
export ARK_EMBEDDING_BASE_URL=""
|
||||
# openai embedding
|
||||
export OPENAI_EMBEDDING_BASE_URL="" # (string, required) OpenAI embedding base_url
|
||||
export OPENAI_EMBEDDING_MODEL="" # (string, required) OpenAI embedding model
|
||||
export OPENAI_EMBEDDING_API_KEY="" # (string, required) OpenAI embedding api_key
|
||||
export OPENAI_EMBEDDING_BY_AZURE=false # (bool, optional) OpenAI embedding by_azure
|
||||
export OPENAI_EMBEDDING_API_VERSION="" # (string, optional) OpenAI embedding azure api version
|
||||
export OPENAI_EMBEDDING_DIMS=1024 # (int, required) OpenAI embedding dimensions
|
||||
export OPENAI_EMBEDDING_REQUEST_DIMS=1024 # (int, optional) OpenAI embedding dimensions in requests, need to be empty if api doesn't support specifying dimensions.
|
||||
|
||||
# ark embedding by volcengine / byteplus
|
||||
export ARK_EMBEDDING_MODEL="" # (string, required) Ark embedding model
|
||||
export ARK_EMBEDDING_API_KEY="" # (string, required) Ark embedding api_key
|
||||
export ARK_EMBEDDING_DIMS="2048" # (int, required) Ark embedding dimensions
|
||||
export ARK_EMBEDDING_BASE_URL="" # (string, required) Ark embedding base_url
|
||||
export ARK_EMBEDDING_API_TYPE="" # (string, optional) Ark embedding api type, should be "text_api" / "multi_modal_api". Default "text_api".
|
||||
|
||||
# ollama embedding
|
||||
export OLLAMA_EMBEDDING_BASE_URL=""
|
||||
export OLLAMA_EMBEDDING_MODEL=""
|
||||
export OLLAMA_EMBEDDING_DIMS=""
|
||||
export OLLAMA_EMBEDDING_BASE_URL="" # (string, required) Ollama embedding base_url
|
||||
export OLLAMA_EMBEDDING_MODEL="" # (string, required) Ollama embedding model
|
||||
export OLLAMA_EMBEDDING_DIMS="" # (int, required) Ollama embedding dimensions
|
||||
|
||||
# http embedding
|
||||
export HTTP_EMBEDDING_ADDR=""
|
||||
export HTTP_EMBEDDING_DIMS=1024
|
||||
export HTTP_EMBEDDING_ADDR="" # (string, required) http embedding address
|
||||
export HTTP_EMBEDDING_DIMS=1024 # (string, required) http embedding dimensions
|
||||
|
||||
# Settings for OCR
|
||||
# If you want to use the OCR-related functions in the knowledge base feature,You need to set up the OCR configuration.
|
||||
@ -217,3 +219,10 @@ export CODE_RUNNER_MEMORY_LIMIT_MB=""
|
||||
export DISABLE_USER_REGISTRATION="" # default "", if you want to disable, set to true
|
||||
export ALLOW_REGISTRATION_EMAIL="" # is a list of email addresses, separated by ",". Example: "11@example.com,22@example.com"
|
||||
|
||||
# Plugin AES secret
|
||||
# PLUGIN_AES_AUTH_SECRET is the secret of used to encrypt plugin authorization payload.
|
||||
export PLUGIN_AES_AUTH_SECRET="^*6x3hdu2nc%-p38"
|
||||
# PLUGIN_AES_STATE_SECRET is the secret of used to encrypt oauth state.
|
||||
export PLUGIN_AES_STATE_SECRET="osj^kfhsd*(z!sno"
|
||||
# PLUGIN_AES_OAUTH_TOKEN_SECRET is the secret of used to encrypt oauth refresh token and access token.
|
||||
export PLUGIN_AES_OAUTH_TOKEN_SECRET="cn+$PJ(HhJ[5d*z9"
|
||||
@ -29,24 +29,7 @@ source "$VENV_DIR/bin/activate"
|
||||
pip install --upgrade pip
|
||||
# If you want to use other third-party libraries, you can install them here.
|
||||
pip install urllib3==1.26.16
|
||||
|
||||
REQUESTS_ASYNC_REPO_URL="https://gitcode.com/gh_mirrors/re/requests-async.git"
|
||||
REQUESTS_ASYNC_DIR="$BIN_DIR/requests-async"
|
||||
|
||||
if [ ! -d "$REQUESTS_ASYNC_DIR/.git" ]; then
|
||||
echo "Cloning requests-async repository..."
|
||||
rm -rf "$REQUESTS_ASYNC_DIR"
|
||||
git clone "$REQUESTS_ASYNC_REPO_URL" "$REQUESTS_ASYNC_DIR"
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed to clone requests-async repository - aborting startup"
|
||||
deactivate
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
echo "requests-async repository already exists."
|
||||
fi
|
||||
|
||||
pip install pillow==11.2.1 pdfplumber==0.11.7 python-docx==1.2.0 numpy==2.3.1 "$REQUESTS_ASYNC_DIR"
|
||||
pip install h11==0.16.0 httpx==0.28.1 pillow==11.2.1 pdfplumber==0.11.7 python-docx==1.2.0 numpy==2.3.1
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed to install Python packages - aborting startup"
|
||||
|
||||
Reference in New Issue
Block a user