Compare commits
55 Commits
release/v0
...
release/v0
| Author | SHA1 | Date | |
|---|---|---|---|
| 24e8dfdb45 | |||
| aae865dafb | |||
| 77e1931494 | |||
| 5562800958 | |||
| 263a75b1c0 | |||
| 901d0252e8 | |||
| 5ecdddbacb | |||
| 2a704fc873 | |||
| f19761fa31 | |||
| 14ce6bc112 | |||
| 6fa2acf05a | |||
| fc47e4096c | |||
| 035ed2450b | |||
| 0fef5a1634 | |||
| f09c624988 | |||
| 59c1d9aa03 | |||
| 19c63a1150 | |||
| 09d00c26cb | |||
| 3d53aaa785 | |||
| 5044cb2b85 | |||
| e7070b419c | |||
| f956c18a09 | |||
| a4b11729a6 | |||
| 1dc00e4df8 | |||
| 5e9740c047 | |||
| f940edf585 | |||
| 23a468c72c | |||
| 85e6926a14 | |||
| a9b87c188b | |||
| ee03b41ad5 | |||
| 18e45b333f | |||
| f040a511e4 | |||
| dfa9eb44e1 | |||
| 4ff734f15f | |||
| ff00dcb31b | |||
| 710bbbff2b | |||
| a734d9d8af | |||
| 174da78c78 | |||
| d58783b11c | |||
| 3030d4d627 | |||
| c79ee64fe8 | |||
| 8994cec367 | |||
| dce313b8e3 | |||
| 5d98e8ef93 | |||
| 8c3ae99643 | |||
| e0800abb99 | |||
| ffbc108875 | |||
| 6b60c07c22 | |||
| 708a6ed0c0 | |||
| 99c759addc | |||
| b38ab95623 | |||
| 9ff065cebd | |||
| e7011f2549 | |||
| 643a448157 | |||
| e03cf4cc87 |
121
.github/workflows/ci@backend.yml
vendored
Normal file
121
.github/workflows/ci@backend.yml
vendored
Normal file
@ -0,0 +1,121 @@
|
||||
name: Backend Tests
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- 'backend/**'
|
||||
- 'docker/atlas/**'
|
||||
- '.github/workflows/ci@backend.yml'
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- 'backend/**'
|
||||
- 'docker/atlas/**'
|
||||
- '.github/workflows/ci@backend.yml'
|
||||
# Allows you to run this workflow manually from the Actions tab
|
||||
workflow_dispatch:
|
||||
|
||||
|
||||
env:
|
||||
DEFAULT_GO_VERSION: "1.24"
|
||||
|
||||
jobs:
|
||||
backend-unit-test:
|
||||
name: backend-unit-test
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
repository-projects: write
|
||||
env:
|
||||
COVERAGE_FILE: coverage.out
|
||||
BREAKDOWN_FILE: main.breakdown
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ${{ env.DEFAULT_GO_VERSION }}
|
||||
|
||||
# - name: Shutdown Ubuntu MySQL
|
||||
# run: service mysql stop
|
||||
|
||||
- name: Set Up MySQL
|
||||
uses: mirromutth/mysql-action@v1.1
|
||||
with:
|
||||
host port: 3306
|
||||
container port: 3306
|
||||
character set server: 'utf8mb4'
|
||||
collation server: 'utf8mb4_general_ci'
|
||||
mysql version: '8.4.5'
|
||||
mysql database: 'opencoze'
|
||||
mysql root password: 'root'
|
||||
|
||||
- name: Verify MySQL Startup
|
||||
run: |
|
||||
echo "Waiting for MySQL to be ready..."
|
||||
for i in {1..60}; do
|
||||
if cat /proc/net/tcp | grep 0CEA; then
|
||||
echo "MySQL port 3306 is listening!"
|
||||
break
|
||||
fi
|
||||
echo "Waiting for MySQL port... ($i/60)"
|
||||
sleep 1
|
||||
done
|
||||
echo "Final verification: MySQL port 3306 is accessible"
|
||||
|
||||
- name: Install MySQL Client
|
||||
run: sudo apt-get update && sudo apt-get install -y mysql-client
|
||||
|
||||
- name: Initialize Database
|
||||
run: mysql -h 127.0.0.1 -P 3306 -u root -proot opencoze < docker/volumes/mysql/schema.sql
|
||||
|
||||
- name: Run Go Test
|
||||
run: |
|
||||
modules=`find . -name "go.mod" -exec dirname {} \;`
|
||||
echo $modules
|
||||
list=""
|
||||
coverpkg=""
|
||||
if [[ ! -f "go.work" ]];then go work init;fi
|
||||
for module in $modules; do go work use $module; list=$module"/... "$list; coverpkg=$module"/...,"$coverpkg; done
|
||||
go work sync
|
||||
go test -race -v -coverprofile=${{ env.COVERAGE_FILE }} -gcflags="all=-l -N" -coverpkg=$coverpkg $list
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
name: coze-studio-backend
|
||||
env_vars: GOLANG,Coze-Studio,BACKEND
|
||||
fail_ci_if_error: 'false'
|
||||
files: ${{ env.COVERAGE_FILE }}
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
||||
- name: Shutdown MySQL
|
||||
if: always()
|
||||
continue-on-error: true
|
||||
run: docker rm -f $(docker ps -q --filter "ancestor=mysql:8.4.5")
|
||||
benchmark-test:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
repository-projects: write
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ${{ env.DEFAULT_GO_VERSION }}
|
||||
|
||||
- name: Run Go Benchmark
|
||||
run: |
|
||||
modules=`find . -name "go.mod" -exec dirname {} \;`
|
||||
echo $modules
|
||||
list=""
|
||||
coverpkg=""
|
||||
if [[ ! -f "go.work" ]];then go work init;fi
|
||||
for module in $modules; do go work use $module; list=$module"/... "$list; coverpkg=$module"/...,"$coverpkg; done
|
||||
go work sync
|
||||
go test -race -v -bench=. -benchmem -run=none -gcflags="all=-l -N" $list
|
||||
1
Makefile
1
Makefile
@ -116,6 +116,7 @@ help:
|
||||
@echo " middleware - Setup middlewares docker environment, but exclude the server app."
|
||||
@echo " web - Setup web docker environment, include middlewares docker."
|
||||
@echo " down - Stop the docker containers."
|
||||
@echo " down_web - Stop the web docker containers."
|
||||
@echo " clean - Stop the docker containers and clean volumes."
|
||||
@echo " python - Setup python environment."
|
||||
@echo " atlas-hash - Rehash atlas migration files."
|
||||
|
||||
12
README.md
12
README.md
@ -37,7 +37,6 @@ The backend of Coze Studio is developed using Golang, the frontend uses React +
|
||||
|
||||
## Quickstart
|
||||
Learn how to obtain and deploy the open-source version of Coze Studio, quickly build projects, and experience Coze Studio's open-source version.
|
||||
> Detailed steps and deployment requirements can be found in [Quickstart](https://github.com/coze-dev/coze-studio/wiki/2.-Quickstart).
|
||||
|
||||
Environment requirements:
|
||||
|
||||
@ -63,9 +62,10 @@ Deployment steps:
|
||||
2. Modify the template file in the configuration file directory.
|
||||
1. Enter the directory `backend/conf/model`. Open the file `ark_doubao-seed-1.6.yaml`.
|
||||
2. Set the fields `id`, `meta.conn_config.api_key`, `meta.conn_config.model`, and save the file.
|
||||
* **id**: The model ID in Coze Studio, defined by the developers themselves, must be a non-zero integer and globally unique. Do not modify the model ID after the model goes online.
|
||||
* **meta.conn_config.api_key**: The API Key for the model service, which in this example is the API Key for Volcengine Ark. Refer to [Retrieve Volcengine Ark API Key](https://www.volcengine.com/docs/82379/1541594) for the acquisition method.
|
||||
* **meta.conn_config.model**: The model ID of the model service, which in this example is the Endpoint ID of the Volcengine Ark doubao-seed-1.6 model access point. For retrieval methods, refer to [Retrieve Endpoint ID](https://www.volcengine.com/docs/82379/1099522).
|
||||
* **id**: The model ID in Coze Studio, defined by the developer, must be a non-zero integer and globally unique. Agents or workflows call models based on model IDs. For models that have already been launched, do not modify their IDs; otherwise, it may result in model call failures.
|
||||
* **meta.conn_config.api_key**: The API Key for the model service. In this example, it is the API Key for Ark API Key. For more information, see [Get Volcengine Ark API Key](https://www.volcengine.com/docs/82379/1541594) or [Get BytePlus ModelArk API Key](https://docs.byteplus.com/en/docs/ModelArk/1361424?utm_source=github&utm_medium=readme&utm_campaign=coze_open_source).
|
||||
* **meta.conn_config.model**: The Model name for the model service. In this example, it refers to the Model ID or Endpoint ID of Ark. For more information, see [Get Volcengine Ark Model ID](https://www.volcengine.com/docs/82379/1513689) / [Get Volcengine Ark Endpoint ID](https://www.volcengine.com/docs/82379/1099522) or [Get BytePlus ModelArk Model ID](https://docs.byteplus.com/en/docs/ModelArk/model_id?utm_source=github&utm_medium=readme&utm_campaign=coze_open_source) / [Get BytePlus ModelArk Endpoint ID](https://docs.byteplus.com/en/docs/ModelArk/1099522?utm_source=github&utm_medium=readme&utm_campaign=coze_open_source).
|
||||
> For users in China, you may use Volcengine Ark; for users outside China, you may use BytePlus ModelArk instead.
|
||||
3. Deploy and start the service.
|
||||
When deploying and starting Coze Studio for the first time, it may take a while to retrieve images and build local images. Please be patient. During deployment, you will see the following log information. If you see the message "Container coze-server Started," it means the Coze Studio service has started successfully.
|
||||
```Bash
|
||||
@ -77,6 +77,8 @@ Deployment steps:
|
||||
For common startup failure issues, **please refer to the [FAQ](https://github.com/coze-dev/coze-studio/wiki/9.-FAQ)**.
|
||||
4. After starting the service, you can open Coze Studio by accessing `http://localhost:8888/` through your browser.
|
||||
|
||||
> [!WARNING]
|
||||
> If you want to deploy Coze Studio in a public network environment, it is recommended to assess security risks before you begin, and take corresponding protection measures. Possible security risks include account registration functions, Python execution environments in workflow code nodes, Coze Server listening address configurations, SSRF (Server - Side Request Forgery), and some horizontal privilege escalations in APIs. For more details, refer to [Quickstart](https://github.com/coze-dev/coze-studio/wiki/2.-Quickstart#security-risks-in-public-networks).
|
||||
|
||||
## Developer Guide
|
||||
|
||||
@ -106,7 +108,7 @@ This project uses the Apache 2.0 license. For details, please refer to the [LICE
|
||||
## Community contributions
|
||||
We welcome community contributions. For contribution guidelines, please refer to [CONTRIBUTING](https://github.com/coze-dev/coze-studio/blob/main/CONTRIBUTING.md) and [Code of conduct](https://github.com/coze-dev/coze-studio/blob/main/CODE_OF_CONDUCT.md). We look forward to your contributions!
|
||||
## Security and privacy
|
||||
If you discover potential security issues in the project, or believe you may have found a security issue, please notify the ByteDance security team through our [security center](https://security.bytedance.com/src) or [vulnerability reporting email](sec@bytedance.com).
|
||||
If you discover potential security issues in the project, or believe you may have found a security issue, please notify the ByteDance security team through our [security center](https://security.bytedance.com/src) or [vulnerability reporting email](mailto:sec@bytedance.com).
|
||||
Please **do not** create public GitHub Issues.
|
||||
## Join Community
|
||||
|
||||
|
||||
@ -37,7 +37,6 @@ Coze Studio 的后端采用 Golang 开发,前端使用 React + TypeScript,
|
||||
| API 与 SDK | * 创建会话、发起对话等 OpenAPI <br> * 通过 Chat SDK 将智能体或应用集成到自己的应用 |
|
||||
## 快速开始
|
||||
了解如何获取并部署 Coze Studio 开源版,快速构建项目、体验 Coze Studio 开源版。
|
||||
> 详细步骤及部署要求可参考[快速开始](https://github.com/coze-dev/coze-studio/wiki/2.-快速开始)。
|
||||
|
||||
环境要求:
|
||||
|
||||
@ -63,9 +62,10 @@ Coze Studio 的后端采用 Golang 开发,前端使用 React + TypeScript,
|
||||
2. 在配置文件目录下,修改模版文件。
|
||||
1. 进入目录 `backend/conf/model`。打开复制后的文件`ark_doubao-seed-1.6.yaml`。
|
||||
2. 设置 `id`、`meta.conn_config.api_key`、`meta.conn_config.model` 字段,并保存文件。
|
||||
* **id**:Coze Studio 中的模型 ID,由开发者自行定义,必须是非 0 的整数,且全局唯一。模型上线后请勿修改模型 id 。
|
||||
* **meta.conn_config.api_key**:模型服务的 API Key,在本示例中为火山方舟的 API Key,获取方式可参考[获取火山方舟 API Key](https://www.volcengine.com/docs/82379/1541594)。
|
||||
* **meta.conn_config.model**:模型服务的 model ID,在本示例中为火山方舟 doubao-seed-1.6 模型接入点的 Endpoint ID,获取方式可参考[获取 Endpoint ID](https://www.volcengine.com/docs/82379/1099522)。
|
||||
* **id**:Coze Studio 中的模型 ID,由开发者自行定义,必须是非 0 的整数,且全局唯一。智能体或工作流根据模型 ID 来调用模型。对于已上线的模型,请勿修改模型 ID,否则可能导致模型调用失败。
|
||||
* **meta.conn_config.api_key**:模型服务的 API Key,在本示例中为火山方舟的 API Key,获取方式可参考[获取火山方舟 API Key](https://www.volcengine.com/docs/82379/1541594) 或[获取 Byteplus ModelArk API Key](https://docs.byteplus.com/en/docs/ModelArk/1361424?utm_source=github&utm_medium=readme&utm_campaign=coze_open_source)。
|
||||
* **meta.conn_config.model**:模型服务的 Model name,在本示例中为火山方舟的 Model ID 或 Endpoint ID,获取方式可参考 [获取火山方舟 Model ID](https://www.volcengine.com/docs/82379/1513689) / [获取火山方舟 Endpoint ID](https://www.volcengine.com/docs/82379/1099522),或者参考[获取 BytePlus ModelArk Model ID](https://docs.byteplus.com/en/docs/ModelArk/model_id?utm_source=github&utm_medium=readme&utm_campaign=coze_open_source) / [获取 BytePlus ModelArk Endpoint ID](https://docs.byteplus.com/en/docs/ModelArk/1099522?utm_source=github&utm_medium=readme&utm_campaign=coze_open_source)。
|
||||
> 中国境内用户可选用火山方舟(Volcengine Ark),非中国境内的用户则可用 BytePlus ModelArk。
|
||||
3. 部署并启动服务。
|
||||
首次部署并启动 Coze Studio 需要拉取镜像、构建本地镜像,可能耗时较久,请耐心等待。部署过程中,你会看到以下日志信息。如果看到提示 "Container coze-server Started",表示 Coze Studio 服务已成功启动。
|
||||
```Bash
|
||||
@ -78,6 +78,8 @@ Coze Studio 的后端采用 Golang 开发,前端使用 React + TypeScript,
|
||||
|
||||
4. 启动服务后,通过浏览器访问 `http://localhost:8888/` 即可打开 Coze Studio。
|
||||
|
||||
> [!WARNING]
|
||||
> 如果要将 Coze Studio 部署到公网环境,建议在部署前评估整体评估安全风险,例如账号注册功能、工作流代码节点 Python执行环境、Coze Server 监听地址配置、SSRF 和部分 API 水平越权的风险,并采取相应防护措施。详细信息可参考[快速开始](https://github.com/coze-dev/coze-studio/wiki/2.-%E5%BF%AB%E9%80%9F%E5%BC%80%E5%A7%8B#%E5%85%AC%E7%BD%91%E5%AE%89%E5%85%A8%E9%A3%8E%E9%99%A9)。
|
||||
## 开发指南
|
||||
|
||||
* **项目配置**:
|
||||
@ -106,7 +108,7 @@ Coze Studio 的后端采用 Golang 开发,前端使用 React + TypeScript,
|
||||
## 社区贡献
|
||||
我们欢迎社区贡献,贡献指南参见 [CONTRIBUTING](https://github.com/coze-dev/coze-studio/blob/main/CONTRIBUTING.md) 和 [Code of conduct](https://github.com/coze-dev/coze-studio/blob/main/CODE_OF_CONDUCT.md),期待您的贡献!
|
||||
## 安全与隐私
|
||||
如果你在该项目中发现潜在的安全问题,或你认为可能发现了安全问题,请通过我们的[安全中心](https://security.bytedance.com/src) 或[漏洞报告邮箱](sec@bytedance.com)通知字节跳动安全团队。
|
||||
如果你在该项目中发现潜在的安全问题,或你认为可能发现了安全问题,请通过我们的[安全中心](https://security.bytedance.com/src) 或[漏洞报告邮箱](mailto:sec@bytedance.com)通知字节跳动安全团队。
|
||||
请**不要**创建公开的 GitHub Issue。
|
||||
## 加入社区
|
||||
|
||||
|
||||
@ -31,6 +31,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
|
||||
"github.com/bytedance/mockey"
|
||||
"github.com/cloudwego/eino/callbacks"
|
||||
model2 "github.com/cloudwego/eino/components/model"
|
||||
@ -47,9 +48,12 @@ import (
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
modelknowledge "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
|
||||
plugin2 "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
pluginmodel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/playground"
|
||||
pluginAPI "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/workflow"
|
||||
@ -59,30 +63,29 @@ import (
|
||||
appplugin "github.com/coze-dev/coze-studio/backend/application/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/application/user"
|
||||
appworkflow "github.com/coze-dev/coze-studio/backend/application/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossuser"
|
||||
plugin3 "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/plugin"
|
||||
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/database/databasemock"
|
||||
crossknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/contract/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/knowledge/knowledgemock"
|
||||
crossmodelmgr "github.com/coze-dev/coze-studio/backend/crossdomain/contract/modelmgr"
|
||||
mockmodel "github.com/coze-dev/coze-studio/backend/crossdomain/contract/modelmgr/modelmock"
|
||||
crossplugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/pluginmock"
|
||||
crossuser "github.com/coze-dev/coze-studio/backend/crossdomain/contract/user"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/impl/code"
|
||||
pluginImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/plugin"
|
||||
entity4 "github.com/coze-dev/coze-studio/backend/domain/memory/database/entity"
|
||||
entity2 "github.com/coze-dev/coze-studio/backend/domain/openauth/openapiauth/entity"
|
||||
entity3 "github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
|
||||
entity5 "github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
|
||||
search "github.com/coze-dev/coze-studio/backend/domain/search/entity"
|
||||
userentity "github.com/coze-dev/coze-studio/backend/domain/user/entity"
|
||||
workflow2 "github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/code"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database/databasemock"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge/knowledgemock"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
|
||||
mockmodel "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model/modelmock"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin/pluginmock"
|
||||
crosssearch "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/search"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/search/searchmock"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
|
||||
mockvar "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable/varmock"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/service"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/variable"
|
||||
mockvar "github.com/coze-dev/coze-studio/backend/domain/workflow/variable/varmock"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/coderunner"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/cache/redis"
|
||||
@ -110,23 +113,23 @@ func TestMain(m *testing.M) {
|
||||
}
|
||||
|
||||
type wfTestRunner struct {
|
||||
t *testing.T
|
||||
h *server.Hertz
|
||||
ctrl *gomock.Controller
|
||||
idGen *mock.MockIDGenerator
|
||||
search *searchmock.MockNotifier
|
||||
appVarS *mockvar.MockStore
|
||||
userVarS *mockvar.MockStore
|
||||
varGetter *mockvar.MockVariablesMetaGetter
|
||||
modelManage *mockmodel.MockManager
|
||||
plugin *mockPlugin.MockPluginService
|
||||
tos *storageMock.MockStorage
|
||||
knowledge *knowledgemock.MockKnowledgeOperator
|
||||
database *databasemock.MockDatabaseOperator
|
||||
pluginSrv *pluginmock.MockService
|
||||
internalModel *testutil.UTChatModel
|
||||
ctx context.Context
|
||||
closeFn func()
|
||||
t *testing.T
|
||||
h *server.Hertz
|
||||
ctrl *gomock.Controller
|
||||
idGen *mock.MockIDGenerator
|
||||
appVarS *mockvar.MockStore
|
||||
userVarS *mockvar.MockStore
|
||||
varGetter *mockvar.MockVariablesMetaGetter
|
||||
modelManage *mockmodel.MockManager
|
||||
plugin *mockPlugin.MockPluginService
|
||||
tos *storageMock.MockStorage
|
||||
knowledge *knowledgemock.MockKnowledge
|
||||
database *databasemock.MockDatabase
|
||||
pluginSrv *pluginmock.MockPluginService
|
||||
internalModel *testutil.UTChatModel
|
||||
publishPatcher *mockey.Mocker
|
||||
ctx context.Context
|
||||
closeFn func()
|
||||
}
|
||||
|
||||
var req2URL = map[reflect.Type]string{
|
||||
@ -247,13 +250,10 @@ func newWfTestRunner(t *testing.T) *wfTestRunner {
|
||||
|
||||
mockTos := storageMock.NewMockStorage(ctrl)
|
||||
mockTos.EXPECT().GetObjectUrl(gomock.Any(), gomock.Any(), gomock.Any()).Return("", nil).AnyTimes()
|
||||
workflowRepo := service.NewWorkflowRepository(mockIDGen, db, redisClient, mockTos, cpStore, utChatModel)
|
||||
workflowRepo := service.NewWorkflowRepository(mockIDGen, db, redisClient, mockTos, cpStore, utChatModel, nil)
|
||||
mockey.Mock(appworkflow.GetWorkflowDomainSVC).Return(service.NewWorkflowService(workflowRepo)).Build()
|
||||
mockey.Mock(workflow2.GetRepository).Return(workflowRepo).Build()
|
||||
|
||||
mockSearchNotify := searchmock.NewMockNotifier(ctrl)
|
||||
mockey.Mock(crosssearch.GetNotifier).Return(mockSearchNotify).Build()
|
||||
mockSearchNotify.EXPECT().PublishWorkflowResource(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
|
||||
publishPatcher := mockey.Mock(appworkflow.PublishWorkflowResource).Return(nil).Build()
|
||||
|
||||
mockCU := mockCrossUser.NewMockUser(ctrl)
|
||||
mockCU.EXPECT().GetUserSpaceList(gomock.Any(), gomock.Any()).Return([]*crossuser.EntitySpace{
|
||||
@ -276,12 +276,12 @@ func newWfTestRunner(t *testing.T) *wfTestRunner {
|
||||
|
||||
mPlugin := mockPlugin.NewMockPluginService(ctrl)
|
||||
|
||||
mockKwOperator := knowledgemock.NewMockKnowledgeOperator(ctrl)
|
||||
knowledge.SetKnowledgeOperator(mockKwOperator)
|
||||
mockKwOperator := knowledgemock.NewMockKnowledge(ctrl)
|
||||
crossknowledge.SetDefaultSVC(mockKwOperator)
|
||||
|
||||
mockModelManage := mockmodel.NewMockManager(ctrl)
|
||||
mockModelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(nil, nil, nil).AnyTimes()
|
||||
m3 := mockey.Mock(model.GetManager).Return(mockModelManage).Build()
|
||||
m3 := mockey.Mock(crossmodelmgr.DefaultSVC).Return(mockModelManage).Build()
|
||||
|
||||
m := mockey.Mock(crossuser.DefaultSVC).Return(mockCU).Build()
|
||||
m1 := mockey.Mock(ctxutil.GetApiAuthFromCtx).Return(&entity2.ApiKey{
|
||||
@ -291,17 +291,18 @@ func newWfTestRunner(t *testing.T) *wfTestRunner {
|
||||
m4 := mockey.Mock(ctxutil.MustGetUIDFromCtx).Return(int64(1)).Build()
|
||||
m5 := mockey.Mock(ctxutil.GetUIDFromCtx).Return(ptr.Of(int64(1))).Build()
|
||||
|
||||
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
|
||||
database.SetDatabaseOperator(mockDatabaseOperator)
|
||||
mockDatabaseOperator := databasemock.NewMockDatabase(ctrl)
|
||||
crossdatabase.SetDefaultSVC(mockDatabaseOperator)
|
||||
|
||||
mockPluginSrv := pluginmock.NewMockService(ctrl)
|
||||
plugin.SetPluginService(mockPluginSrv)
|
||||
mockPluginSrv := pluginmock.NewMockPluginService(ctrl)
|
||||
crossplugin.SetDefaultSVC(mockPluginSrv)
|
||||
|
||||
mockey.Mock((*user.UserApplicationService).MGetUserBasicInfo).Return(&playground.MGetUserBasicInfoResponse{
|
||||
UserBasicInfoMap: make(map[string]*playground.UserBasicInfo),
|
||||
}, nil).Build()
|
||||
|
||||
f := func() {
|
||||
publishPatcher.UnPatch()
|
||||
m.UnPatch()
|
||||
m1.UnPatch()
|
||||
m2.UnPatch()
|
||||
@ -314,23 +315,23 @@ func newWfTestRunner(t *testing.T) *wfTestRunner {
|
||||
}
|
||||
|
||||
return &wfTestRunner{
|
||||
t: t,
|
||||
h: h,
|
||||
ctrl: ctrl,
|
||||
idGen: mockIDGen,
|
||||
search: mockSearchNotify,
|
||||
appVarS: mockGlobalAppVarStore,
|
||||
userVarS: mockGlobalUserVarStore,
|
||||
varGetter: mockVarGetter,
|
||||
modelManage: mockModelManage,
|
||||
plugin: mPlugin,
|
||||
tos: mockTos,
|
||||
knowledge: mockKwOperator,
|
||||
database: mockDatabaseOperator,
|
||||
internalModel: utChatModel,
|
||||
ctx: context.Background(),
|
||||
closeFn: f,
|
||||
pluginSrv: mockPluginSrv,
|
||||
t: t,
|
||||
h: h,
|
||||
ctrl: ctrl,
|
||||
idGen: mockIDGen,
|
||||
appVarS: mockGlobalAppVarStore,
|
||||
userVarS: mockGlobalUserVarStore,
|
||||
varGetter: mockVarGetter,
|
||||
modelManage: mockModelManage,
|
||||
plugin: mPlugin,
|
||||
tos: mockTos,
|
||||
knowledge: mockKwOperator,
|
||||
database: mockDatabaseOperator,
|
||||
internalModel: utChatModel,
|
||||
ctx: context.Background(),
|
||||
closeFn: f,
|
||||
pluginSrv: mockPluginSrv,
|
||||
publishPatcher: publishPatcher,
|
||||
}
|
||||
}
|
||||
|
||||
@ -2253,24 +2254,22 @@ func TestNodeWithBatchEnabled(t *testing.T) {
|
||||
})
|
||||
e := r.getProcess(id, exeID)
|
||||
e.assertSuccess()
|
||||
assert.Equal(t, map[string]any{
|
||||
outputMap := mustUnmarshalToMap(t, e.output)
|
||||
assert.Contains(t, outputMap["output"], map[string]any{
|
||||
"output": []any{
|
||||
map[string]any{
|
||||
"output": []any{
|
||||
"answer",
|
||||
"for index 0",
|
||||
},
|
||||
"input": "answer。for index 0",
|
||||
},
|
||||
map[string]any{
|
||||
"output": []any{
|
||||
"answer",
|
||||
"for index 1",
|
||||
},
|
||||
"input": "answer,for index 1",
|
||||
},
|
||||
"answer",
|
||||
"for index 0",
|
||||
},
|
||||
}, mustUnmarshalToMap(t, e.output))
|
||||
"input": "answer。for index 0",
|
||||
})
|
||||
assert.Contains(t, outputMap["output"], map[string]any{
|
||||
"output": []any{
|
||||
"answer",
|
||||
"for index 1",
|
||||
},
|
||||
"input": "answer,for index 1",
|
||||
})
|
||||
assert.Equal(t, 2, len(outputMap["output"].([]any)))
|
||||
e.tokenEqual(10, 12)
|
||||
|
||||
// verify this workflow has previously succeeded a test run
|
||||
@ -2874,9 +2873,8 @@ func TestLLMWithSkills(t *testing.T) {
|
||||
{ID: int64(7509353598782816256), Operation: operation},
|
||||
}, nil).AnyTimes()
|
||||
|
||||
pluginSrv := plugin3.NewPluginService(r.plugin, r.tos)
|
||||
|
||||
plugin.SetPluginService(pluginSrv)
|
||||
pluginSrv := pluginImpl.InitDomainService(r.plugin, r.tos)
|
||||
crossplugin.SetDefaultSVC(pluginSrv)
|
||||
|
||||
t.Run("llm with plugin tool", func(t *testing.T) {
|
||||
id := r.load("llm_node_with_skills/llm_node_with_plugin_tool.json")
|
||||
@ -2998,22 +2996,22 @@ func TestLLMWithSkills(t *testing.T) {
|
||||
},
|
||||
}, nil).AnyTimes()
|
||||
|
||||
r.knowledge.EXPECT().Retrieve(gomock.Any(), gomock.Any()).Return(&knowledge.RetrieveResponse{
|
||||
Slices: []*knowledge.Slice{
|
||||
{DocumentID: "1", Output: "天安门广场 :中国政治文化中心,见证了近现代重大历史事件"},
|
||||
{DocumentID: "2", Output: "八达岭长城 :明代长城的精华段,被誉为“不到长城非好汉"},
|
||||
},
|
||||
}, nil).AnyTimes()
|
||||
// r.knowledge.EXPECT().Retrieve(gomock.Any(), gomock.Any()).Return(&knowledge.RetrieveResponse{
|
||||
// RetrieveSlices: []*knowledge.RetrieveSlice{
|
||||
// {Slice: &knowledge.Slice{DocumentID: 1, Output: "天安门广场 :中国政治文化中心,见证了近现代重大历史事件"}, Score: 0.9},
|
||||
// {Slice: &knowledge.Slice{DocumentID: 2, Output: "八达岭长城 :明代长城的精华段,被誉为“不到长城非好汉"}, Score: 0.8},
|
||||
// },
|
||||
// }, nil).AnyTimes()
|
||||
|
||||
t.Run("llm node with knowledge skill", func(t *testing.T) {
|
||||
id := r.load("llm_node_with_skills/llm_with_knowledge_skill.json")
|
||||
exeID := r.testRun(id, map[string]string{
|
||||
"input": "北京有哪些著名的景点",
|
||||
})
|
||||
e := r.getProcess(id, exeID)
|
||||
e.assertSuccess()
|
||||
assert.Equal(t, `{"output":"八达岭长城 :明代长城的精华段,被誉为“不到长城非好汉"}`, e.output)
|
||||
})
|
||||
// t.Run("llm node with knowledge skill", func(t *testing.T) {
|
||||
// id := r.load("llm_node_with_skills/llm_with_knowledge_skill.json")
|
||||
// exeID := r.testRun(id, map[string]string{
|
||||
// "input": "北京有哪些著名的景点",
|
||||
// })
|
||||
// e := r.getProcess(id, exeID)
|
||||
// e.assertSuccess()
|
||||
// assert.Equal(t, `{"output":"八达岭长城 :明代长城的精华段,被誉为“不到长城非好汉"}`, e.output)
|
||||
// })
|
||||
})
|
||||
}
|
||||
|
||||
@ -3412,8 +3410,8 @@ func TestGetLLMNodeFCSettingsDetailAndMerged(t *testing.T) {
|
||||
{ID: 123, Operation: operation},
|
||||
}, nil).AnyTimes()
|
||||
|
||||
pluginSrv := plugin3.NewPluginService(r.plugin, r.tos)
|
||||
plugin.SetPluginService(pluginSrv)
|
||||
pluginSrv := pluginImpl.InitDomainService(r.plugin, r.tos)
|
||||
crossplugin.SetDefaultSVC(pluginSrv)
|
||||
|
||||
t.Run("plugin tool info ", func(t *testing.T) {
|
||||
fcSettingDetailReq := &workflow.GetLLMNodeFCSettingDetailRequest{
|
||||
@ -3529,8 +3527,8 @@ func TestGetLLMNodeFCSettingsDetailAndMerged(t *testing.T) {
|
||||
{ID: 123, Operation: operation},
|
||||
}, nil).AnyTimes()
|
||||
|
||||
pluginSrv := plugin3.NewPluginService(r.plugin, r.tos)
|
||||
plugin.SetPluginService(pluginSrv)
|
||||
pluginSrv := pluginImpl.InitDomainService(r.plugin, r.tos)
|
||||
crossplugin.SetDefaultSVC(pluginSrv)
|
||||
|
||||
t.Run("plugin merge", func(t *testing.T) {
|
||||
fcSettingMergedReq := &workflow.GetLLMNodeFCSettingsMergedRequest{
|
||||
@ -3696,7 +3694,7 @@ func TestCopyWorkflow(t *testing.T) {
|
||||
|
||||
_, err := appworkflow.GetWorkflowDomainSVC().Get(context.Background(), &vo.GetPolicy{
|
||||
ID: wid,
|
||||
QType: vo.FromDraft,
|
||||
QType: workflowModel.FromDraft,
|
||||
CommitID: "",
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
@ -3758,7 +3756,7 @@ func TestReleaseApplicationWorkflows(t *testing.T) {
|
||||
|
||||
wf, err = appworkflow.GetWorkflowDomainSVC().Get(context.Background(), &vo.GetPolicy{
|
||||
ID: 100100100100,
|
||||
QType: vo.FromSpecificVersion,
|
||||
QType: workflowModel.FromSpecificVersion,
|
||||
Version: version,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
@ -4038,7 +4036,7 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
|
||||
|
||||
mockey.PatchConvey("copy with subworkflow, subworkflow with external resource ", t, func() {
|
||||
var copiedIDs = make([]int64, 0)
|
||||
var mockPublishWorkflowResource func(ctx context.Context, OpType crosssearch.OpType, event *crosssearch.Resource) error
|
||||
var mockPublishWorkflowResource func(ctx context.Context, workflowID int64, mode *int32, op search.OpType, r *search.ResourceDocument) error
|
||||
var ignoreIDs = map[int64]bool{
|
||||
7515027325977624576: true,
|
||||
7515027249628708864: true,
|
||||
@ -4046,15 +4044,15 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
|
||||
7515027150387281920: true,
|
||||
7515027091302121472: true,
|
||||
}
|
||||
mockPublishWorkflowResource = func(ctx context.Context, OpType crosssearch.OpType, event *crosssearch.Resource) error {
|
||||
if ignoreIDs[event.WorkflowID] {
|
||||
mockPublishWorkflowResource = func(ctx context.Context, workflowID int64, mode *int32, op search.OpType, r *search.ResourceDocument) error {
|
||||
if ignoreIDs[workflowID] {
|
||||
return nil
|
||||
}
|
||||
wf, err := appworkflow.GetWorkflowDomainSVC().Get(ctx, &vo.GetPolicy{
|
||||
ID: event.WorkflowID,
|
||||
QType: vo.FromLatestVersion,
|
||||
ID: workflowID,
|
||||
QType: workflowModel.FromLatestVersion,
|
||||
})
|
||||
copiedIDs = append(copiedIDs, event.WorkflowID)
|
||||
copiedIDs = append(copiedIDs, workflowID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "v0.0.1", wf.Version)
|
||||
canvas := &vo.Canvas{}
|
||||
@ -4094,7 +4092,7 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
|
||||
|
||||
subWf, err := appworkflow.GetWorkflowDomainSVC().Get(ctx, &vo.GetPolicy{
|
||||
ID: wfId,
|
||||
QType: vo.FromLatestVersion,
|
||||
QType: workflowModel.FromLatestVersion,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
subworkflowCanvas := &vo.Canvas{}
|
||||
@ -4143,7 +4141,7 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
r.search.EXPECT().PublishWorkflowResource(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mockPublishWorkflowResource).AnyTimes()
|
||||
defer mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build().UnPatch()
|
||||
|
||||
appID := "7513788954458456064"
|
||||
appIDInt64, _ := strconv.ParseInt(appID, 10, 64)
|
||||
@ -4186,21 +4184,21 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
|
||||
|
||||
mockey.PatchConvey("copy only with external resource", t, func() {
|
||||
var copiedIDs = make([]int64, 0)
|
||||
var mockPublishWorkflowResource func(ctx context.Context, OpType crosssearch.OpType, event *crosssearch.Resource) error
|
||||
var mockPublishWorkflowResource func(ctx context.Context, workflowID int64, mode *int32, op search.OpType, r *search.ResourceDocument) error
|
||||
var ignoreIDs = map[int64]bool{
|
||||
7516518409656336384: true,
|
||||
7516516198096306176: true,
|
||||
}
|
||||
mockPublishWorkflowResource = func(ctx context.Context, OpType crosssearch.OpType, event *crosssearch.Resource) error {
|
||||
if ignoreIDs[event.WorkflowID] {
|
||||
mockPublishWorkflowResource = func(ctx context.Context, workflowID int64, mode *int32, op search.OpType, r *search.ResourceDocument) error {
|
||||
if ignoreIDs[workflowID] {
|
||||
return nil
|
||||
}
|
||||
wf, err := appworkflow.GetWorkflowDomainSVC().Get(ctx, &vo.GetPolicy{
|
||||
ID: event.WorkflowID,
|
||||
QType: vo.FromLatestVersion,
|
||||
ID: workflowID,
|
||||
QType: workflowModel.FromLatestVersion,
|
||||
})
|
||||
|
||||
copiedIDs = append(copiedIDs, event.WorkflowID)
|
||||
copiedIDs = append(copiedIDs, workflowID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "v0.0.1", wf.Version)
|
||||
canvas := &vo.Canvas{}
|
||||
@ -4255,7 +4253,7 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
r.search.EXPECT().PublishWorkflowResource(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mockPublishWorkflowResource).AnyTimes()
|
||||
defer mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build().UnPatch()
|
||||
|
||||
defer mockey.Mock((*appknowledge.KnowledgeApplicationService).CopyKnowledge).Return(&modelknowledge.CopyKnowledgeResponse{
|
||||
TargetKnowledgeID: 100100,
|
||||
@ -4296,6 +4294,7 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
|
||||
func TestMoveWorkflowAppToLibrary(t *testing.T) {
|
||||
mockey.PatchConvey("test move workflow", t, func() {
|
||||
r := newWfTestRunner(t)
|
||||
r.publishPatcher.UnPatch()
|
||||
defer r.closeFn()
|
||||
vars := map[string]*vo.TypeInfo{
|
||||
"app_v1": {
|
||||
@ -4315,21 +4314,21 @@ func TestMoveWorkflowAppToLibrary(t *testing.T) {
|
||||
r.varGetter.EXPECT().GetAppVariablesMeta(gomock.Any(), gomock.Any(), gomock.Any()).Return(vars, nil).AnyTimes()
|
||||
t.Run("move workflow", func(t *testing.T) {
|
||||
|
||||
var mockPublishWorkflowResource func(ctx context.Context, OpType crosssearch.OpType, event *crosssearch.Resource) error
|
||||
var mockPublishWorkflowResource func(ctx context.Context, workflowID int64, mode *int32, op search.OpType, r *search.ResourceDocument) error
|
||||
|
||||
named2Idx := []string{"c1", "c2", "cc1", "main"}
|
||||
callCount := 0
|
||||
initialWf2ID := map[string]int64{}
|
||||
old2newID := map[int64]int64{}
|
||||
mockPublishWorkflowResource = func(ctx context.Context, OpType crosssearch.OpType, event *crosssearch.Resource) error {
|
||||
mockPublishWorkflowResource = func(ctx context.Context, workflowID int64, mode *int32, op search.OpType, r *search.ResourceDocument) error {
|
||||
if callCount <= 3 {
|
||||
initialWf2ID[named2Idx[callCount]] = event.WorkflowID
|
||||
initialWf2ID[named2Idx[callCount]] = workflowID
|
||||
callCount++
|
||||
return nil
|
||||
}
|
||||
if OpType == crosssearch.Created {
|
||||
if oldID, ok := initialWf2ID[*event.Name]; ok {
|
||||
old2newID[oldID] = event.WorkflowID
|
||||
if op == search.Created {
|
||||
if oldID, ok := initialWf2ID[*r.Name]; ok {
|
||||
old2newID[oldID] = workflowID
|
||||
}
|
||||
}
|
||||
|
||||
@ -4337,7 +4336,7 @@ func TestMoveWorkflowAppToLibrary(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
r.search.EXPECT().PublishWorkflowResource(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mockPublishWorkflowResource).AnyTimes()
|
||||
defer mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build().UnPatch()
|
||||
|
||||
defer mockey.Mock((*appknowledge.KnowledgeApplicationService).MoveKnowledgeToLibrary).Return(nil).Build().UnPatch()
|
||||
defer mockey.Mock((*appmemory.DatabaseApplicationService).MoveDatabaseToLibrary).Return(&appmemory.MoveDatabaseToLibraryResponse{}, nil).Build().UnPatch()
|
||||
@ -4455,6 +4454,7 @@ func TestMoveWorkflowAppToLibrary(t *testing.T) {
|
||||
func TestDuplicateWorkflowsByAppID(t *testing.T) {
|
||||
mockey.PatchConvey("test duplicate work", t, func() {
|
||||
r := newWfTestRunner(t)
|
||||
r.publishPatcher.UnPatch()
|
||||
defer r.closeFn()
|
||||
|
||||
vars := map[string]*vo.TypeInfo{
|
||||
@ -4474,7 +4474,7 @@ func TestDuplicateWorkflowsByAppID(t *testing.T) {
|
||||
|
||||
r.varGetter.EXPECT().GetAppVariablesMeta(gomock.Any(), gomock.Any(), gomock.Any()).Return(vars, nil).AnyTimes()
|
||||
var copiedIDs = make([]int64, 0)
|
||||
var mockPublishWorkflowResource func(ctx context.Context, OpType crosssearch.OpType, event *crosssearch.Resource) error
|
||||
var mockPublishWorkflowResource func(ctx context.Context, workflowID int64, mode *int32, op search.OpType, r *search.ResourceDocument) error
|
||||
var ignoreIDs = map[int64]bool{
|
||||
7515027325977624576: true,
|
||||
7515027249628708864: true,
|
||||
@ -4483,16 +4483,16 @@ func TestDuplicateWorkflowsByAppID(t *testing.T) {
|
||||
7515027091302121472: true,
|
||||
7515027325977624579: true,
|
||||
}
|
||||
mockPublishWorkflowResource = func(ctx context.Context, OpType crosssearch.OpType, event *crosssearch.Resource) error {
|
||||
if ignoreIDs[event.WorkflowID] {
|
||||
mockPublishWorkflowResource = func(ctx context.Context, workflowID int64, mode *int32, op search.OpType, r *search.ResourceDocument) error {
|
||||
if ignoreIDs[workflowID] {
|
||||
return nil
|
||||
}
|
||||
copiedIDs = append(copiedIDs, event.WorkflowID)
|
||||
copiedIDs = append(copiedIDs, workflowID)
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
r.search.EXPECT().PublishWorkflowResource(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mockPublishWorkflowResource).AnyTimes()
|
||||
defer mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build().UnPatch()
|
||||
|
||||
appIDInt64 := int64(7513788954458456064)
|
||||
|
||||
@ -4660,7 +4660,7 @@ func TestJsonSerializationDeserializationWithWarning(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetAppVariablesFOrSubProcesses(t *testing.T) {
|
||||
func TestSetAppVariablesForSubProcesses(t *testing.T) {
|
||||
mockey.PatchConvey("app variables for sub_process", t, func() {
|
||||
r := newWfTestRunner(t)
|
||||
defer r.closeFn()
|
||||
@ -4677,3 +4677,79 @@ func TestSetAppVariablesFOrSubProcesses(t *testing.T) {
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
func TestHttpImplicitDependencies(t *testing.T) {
|
||||
mockey.PatchConvey("test http implicit dependencies", t, func() {
|
||||
r := newWfTestRunner(t)
|
||||
defer r.closeFn()
|
||||
|
||||
r.appVarS.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return("1.0", nil).AnyTimes()
|
||||
|
||||
idStr := r.load("httprequester/http_implicit_dependencies.json")
|
||||
|
||||
r.publish(idStr, "v0.0.1", true)
|
||||
|
||||
runner := mockcode.NewMockRunner(r.ctrl)
|
||||
runner.EXPECT().Run(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, request *coderunner.RunRequest) (*coderunner.RunResponse, error) {
|
||||
in := request.Params["input"]
|
||||
_ = in
|
||||
result := make(map[string]any)
|
||||
err := sonic.UnmarshalString(in.(string), &result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &coderunner.RunResponse{
|
||||
Result: result,
|
||||
}, nil
|
||||
}).AnyTimes()
|
||||
|
||||
code.SetCodeRunner(runner)
|
||||
|
||||
mockey.PatchConvey("test http node implicit dependencies", func() {
|
||||
input := map[string]string{
|
||||
"input": "a",
|
||||
}
|
||||
result, _ := r.openapiSyncRun(idStr, input)
|
||||
|
||||
batchRets := result["batch"].([]any)
|
||||
loopRets := result["loop"].([]any)
|
||||
|
||||
for _, r := range batchRets {
|
||||
assert.Contains(t, []any{
|
||||
"http://echo.apifox.com/anything?aa=1.0&cc=1",
|
||||
"http://echo.apifox.com/anything?aa=1.0&cc=2",
|
||||
}, r)
|
||||
}
|
||||
for _, r := range loopRets {
|
||||
assert.Contains(t, []any{
|
||||
"http://echo.apifox.com/anything?a=1&m=123",
|
||||
"http://echo.apifox.com/anything?a=2&m=123",
|
||||
}, r)
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
mockey.PatchConvey("node debug http node implicit dependencies", func() {
|
||||
exeID := r.nodeDebug(idStr, "109387",
|
||||
withNDInput(map[string]string{
|
||||
"__apiInfo_url_87fc7c69536cae843fa7f5113cf0067b": "m",
|
||||
"__apiInfo_url_ac86361e3cd503952e71986dc091fa6f": "a",
|
||||
"__body_bodyData_json_ac86361e3cd503952e71986dc091fa6f": "b",
|
||||
"__body_bodyData_json_f77817a7cf8441279e1cfd8af4eeb1da": "1",
|
||||
}))
|
||||
|
||||
e := r.getProcess(idStr, exeID, withSpecificNodeID("109387"))
|
||||
e.assertSuccess()
|
||||
|
||||
ret := make(map[string]any)
|
||||
err := sonic.UnmarshalString(e.output, &ret)
|
||||
assert.Nil(t, err)
|
||||
err = sonic.UnmarshalString(ret["body"].(string), &ret)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, ret["url"].(string), "http://echo.apifox.com/anything?a=a&m=m")
|
||||
|
||||
})
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
@ -6837,7 +6837,7 @@ type OpenMessageApi struct {
|
||||
//message content
|
||||
Content string `thrift:"content,4" form:"content" json:"content" query:"content"`
|
||||
//session id
|
||||
ConversationID int64 `thrift:"conversation_id,5" form:"conversation_id" json:"conversation_id" query:"conversation_id"`
|
||||
ConversationID int64 `thrift:"conversation_id,5" form:"conversation_id" json:"conversation_id,string" query:"conversation_id"`
|
||||
// custom field
|
||||
MetaData map[string]string `thrift:"meta_data,6" form:"meta_data" json:"meta_data" query:"meta_data"`
|
||||
//creation time
|
||||
@ -6845,7 +6845,7 @@ type OpenMessageApi struct {
|
||||
//update time
|
||||
UpdatedAt int64 `thrift:"updated_at,8" form:"updated_at" json:"updated_at" query:"updated_at"`
|
||||
// ID of a conversation
|
||||
ChatID int64 `thrift:"chat_id,9" form:"chat_id" json:"chat_id" query:"chat_id"`
|
||||
ChatID int64 `thrift:"chat_id,9" form:"chat_id" json:"chat_id,string" query:"chat_id"`
|
||||
// Content type, text/mix
|
||||
ContentType string `thrift:"content_type,10" form:"content_type" json:"content_type" query:"content_type"`
|
||||
//Message Type answer/question/function_call/tool_response
|
||||
|
||||
@ -204,3 +204,110 @@ type GetAllDatabaseByAppIDRequest struct {
|
||||
type GetAllDatabaseByAppIDResponse struct {
|
||||
Databases []*Database // online databases
|
||||
}
|
||||
|
||||
type SQLParam struct {
|
||||
Value string
|
||||
IsNull bool
|
||||
}
|
||||
type CustomSQLRequest struct {
|
||||
DatabaseInfoID int64
|
||||
SQL string
|
||||
Params []SQLParam
|
||||
IsDebugRun bool
|
||||
UserID string
|
||||
ConnectorID int64
|
||||
}
|
||||
|
||||
type Object = map[string]any
|
||||
|
||||
type Response struct {
|
||||
RowNumber *int64
|
||||
Objects []Object
|
||||
}
|
||||
|
||||
type Operator string
|
||||
type ClauseRelation string
|
||||
|
||||
const (
|
||||
ClauseRelationAND ClauseRelation = "and"
|
||||
ClauseRelationOR ClauseRelation = "or"
|
||||
)
|
||||
|
||||
const (
|
||||
OperatorEqual Operator = "="
|
||||
OperatorNotEqual Operator = "!="
|
||||
OperatorGreater Operator = ">"
|
||||
OperatorLesser Operator = "<"
|
||||
OperatorGreaterOrEqual Operator = ">="
|
||||
OperatorLesserOrEqual Operator = "<="
|
||||
OperatorIn Operator = "in"
|
||||
OperatorNotIn Operator = "not_in"
|
||||
OperatorIsNull Operator = "is_null"
|
||||
OperatorIsNotNull Operator = "is_not_null"
|
||||
OperatorLike Operator = "like"
|
||||
OperatorNotLike Operator = "not_like"
|
||||
)
|
||||
|
||||
type ClauseGroup struct {
|
||||
Single *Clause
|
||||
Multi *MultiClause
|
||||
}
|
||||
type Clause struct {
|
||||
Left string
|
||||
Operator Operator
|
||||
}
|
||||
type MultiClause struct {
|
||||
Clauses []*Clause
|
||||
Relation ClauseRelation
|
||||
}
|
||||
|
||||
type ConditionStr struct {
|
||||
Left string
|
||||
Operator Operator
|
||||
Right any
|
||||
}
|
||||
|
||||
type ConditionGroup struct {
|
||||
Conditions []*ConditionStr
|
||||
Relation ClauseRelation
|
||||
}
|
||||
|
||||
type DeleteRequest struct {
|
||||
DatabaseInfoID int64
|
||||
ConditionGroup *ConditionGroup
|
||||
IsDebugRun bool
|
||||
UserID string
|
||||
ConnectorID int64
|
||||
}
|
||||
|
||||
type QueryRequest struct {
|
||||
DatabaseInfoID int64
|
||||
SelectFields []string
|
||||
Limit int64
|
||||
ConditionGroup *ConditionGroup
|
||||
OrderClauses []*OrderClause
|
||||
IsDebugRun bool
|
||||
UserID string
|
||||
ConnectorID int64
|
||||
}
|
||||
|
||||
type OrderClause struct {
|
||||
FieldID string
|
||||
IsAsc bool
|
||||
}
|
||||
type UpdateRequest struct {
|
||||
DatabaseInfoID int64
|
||||
ConditionGroup *ConditionGroup
|
||||
Fields map[string]any
|
||||
IsDebugRun bool
|
||||
UserID string
|
||||
ConnectorID int64
|
||||
}
|
||||
|
||||
type InsertRequest struct {
|
||||
DatabaseInfoID int64
|
||||
Fields map[string]any
|
||||
IsDebugRun bool
|
||||
UserID string
|
||||
ConnectorID int64
|
||||
}
|
||||
|
||||
@ -23,6 +23,7 @@ import (
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
@ -124,6 +125,7 @@ type RetrievalStrategy struct {
|
||||
EnableQueryRewrite bool
|
||||
EnableRerank bool
|
||||
EnableNL2SQL bool
|
||||
IsPersonalOnly bool
|
||||
}
|
||||
|
||||
type SelectType int64
|
||||
@ -283,3 +285,69 @@ type CopyKnowledgeResponse struct {
|
||||
type MoveKnowledgeToLibraryRequest struct {
|
||||
KnowledgeID int64
|
||||
}
|
||||
|
||||
type ParseMode string
|
||||
|
||||
const (
|
||||
FastParseMode = "fast_mode"
|
||||
AccurateParseMode = "accurate_mode"
|
||||
)
|
||||
|
||||
type ChunkType string
|
||||
|
||||
const (
|
||||
ChunkTypeDefault ChunkType = "default"
|
||||
ChunkTypeCustom ChunkType = "custom"
|
||||
ChunkTypeLeveled ChunkType = "leveled"
|
||||
)
|
||||
|
||||
type ParsingStrategy struct {
|
||||
ParseMode ParseMode
|
||||
ExtractImage bool
|
||||
ExtractTable bool
|
||||
ImageOCR bool
|
||||
}
|
||||
type ChunkingStrategy struct {
|
||||
ChunkType ChunkType
|
||||
ChunkSize int64
|
||||
Separator string
|
||||
Overlap int64
|
||||
}
|
||||
|
||||
type CreateDocumentRequest struct {
|
||||
KnowledgeID int64
|
||||
ParsingStrategy *ParsingStrategy
|
||||
ChunkingStrategy *ChunkingStrategy
|
||||
FileURL string
|
||||
FileName string
|
||||
FileExtension parser.FileExtension
|
||||
}
|
||||
type CreateDocumentResponse struct {
|
||||
DocumentID int64
|
||||
FileName string
|
||||
FileURL string
|
||||
}
|
||||
|
||||
type DeleteDocumentRequest struct {
|
||||
DocumentID string
|
||||
}
|
||||
|
||||
type DeleteDocumentResponse struct {
|
||||
IsSuccess bool
|
||||
}
|
||||
|
||||
type KnowledgeDetail struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
IconURL string `json:"-"`
|
||||
FormatType int64 `json:"-"`
|
||||
}
|
||||
|
||||
type ListKnowledgeDetailRequest struct {
|
||||
KnowledgeIDs []int64
|
||||
}
|
||||
|
||||
type ListKnowledgeDetailResponse struct {
|
||||
KnowledgeDetails []*KnowledgeDetail
|
||||
}
|
||||
|
||||
24
backend/api/model/crossdomain/modelmgr/modelmgr.go
Normal file
24
backend/api/model/crossdomain/modelmgr/modelmgr.go
Normal file
@ -0,0 +1,24 @@
|
||||
package model
|
||||
|
||||
type LLMParams struct {
|
||||
ModelName string `json:"modelName"`
|
||||
ModelType int64 `json:"modelType"`
|
||||
Prompt string `json:"prompt"` // user prompt
|
||||
Temperature *float64 `json:"temperature"`
|
||||
FrequencyPenalty float64 `json:"frequencyPenalty"`
|
||||
PresencePenalty float64 `json:"presencePenalty"`
|
||||
MaxTokens int `json:"maxTokens"`
|
||||
TopP *float64 `json:"topP"`
|
||||
TopK *int `json:"topK"`
|
||||
EnableChatHistory bool `json:"enableChatHistory"`
|
||||
SystemPrompt string `json:"systemPrompt"`
|
||||
ResponseFormat ResponseFormat `json:"responseFormat"`
|
||||
}
|
||||
|
||||
type ResponseFormat int64
|
||||
|
||||
const (
|
||||
ResponseFormatText ResponseFormat = 0
|
||||
ResponseFormatMarkdown ResponseFormat = 1
|
||||
ResponseFormatJSON ResponseFormat = 2
|
||||
)
|
||||
@ -255,6 +255,9 @@ func (op *Openapi3Operation) ToEinoSchemaParameterInfo(ctx context.Context) (map
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if paramInfo == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := result[paramName]; ok {
|
||||
logs.CtxWarnf(ctx, "duplicate parameter name '%s'", paramName)
|
||||
75
backend/api/model/crossdomain/plugin/workflow.go
Normal file
75
backend/api/model/crossdomain/plugin/workflow.go
Normal file
@ -0,0 +1,75 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/workflow"
|
||||
)
|
||||
|
||||
type ToolsInfoRequest struct {
|
||||
PluginEntity PluginEntity
|
||||
ToolIDs []int64
|
||||
IsDraft bool
|
||||
}
|
||||
|
||||
type PluginEntity struct {
|
||||
PluginID int64
|
||||
PluginVersion *string // nil or "0" means draft, "" means latest/online version, otherwise is specific version
|
||||
}
|
||||
|
||||
type ToolsInfoResponse struct {
|
||||
PluginID int64
|
||||
SpaceID int64
|
||||
Version string
|
||||
PluginName string
|
||||
Description string
|
||||
IconURL string
|
||||
PluginType int64
|
||||
ToolInfoList map[int64]ToolInfoW
|
||||
LatestVersion *string
|
||||
IsOfficial bool
|
||||
AppID int64
|
||||
}
|
||||
|
||||
type ToolInfoW struct {
|
||||
ToolName string
|
||||
ToolID int64
|
||||
Description string
|
||||
DebugExample *DebugExample
|
||||
|
||||
Inputs []*workflow.APIParameter
|
||||
Outputs []*workflow.APIParameter
|
||||
}
|
||||
|
||||
type DebugExample struct {
|
||||
ReqExample string
|
||||
RespExample string
|
||||
}
|
||||
|
||||
type ToolsInvokableRequest struct {
|
||||
PluginEntity PluginEntity
|
||||
ToolsInvokableInfo map[int64]*ToolsInvokableInfo
|
||||
IsDraft bool
|
||||
}
|
||||
|
||||
type WorkflowAPIParameters = []*workflow.APIParameter
|
||||
|
||||
type ToolsInvokableInfo struct {
|
||||
ToolID int64
|
||||
RequestAPIParametersConfig WorkflowAPIParameters
|
||||
ResponseAPIParametersConfig WorkflowAPIParameters
|
||||
}
|
||||
@ -23,7 +23,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/app/bot_common"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossworkflow"
|
||||
crossworkflow "github.com/coze-dev/coze-studio/backend/crossdomain/contract/workflow"
|
||||
)
|
||||
|
||||
type AgentRuntime struct {
|
||||
|
||||
@ -14,7 +14,15 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package vo
|
||||
package workflow
|
||||
|
||||
type Locator uint8
|
||||
|
||||
const (
|
||||
FromDraft Locator = iota
|
||||
FromSpecificVersion
|
||||
FromLatestVersion
|
||||
)
|
||||
|
||||
type ExecuteConfig struct {
|
||||
ID int64
|
||||
@ -22,7 +22,7 @@ import (
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/application/openauth"
|
||||
"github.com/coze-dev/coze-studio/backend/application/template"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crosssearch"
|
||||
crosssearch "github.com/coze-dev/coze-studio/backend/crossdomain/contract/search"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/application/app"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/appinfra"
|
||||
@ -39,18 +39,19 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/application/upload"
|
||||
"github.com/coze-dev/coze-studio/backend/application/user"
|
||||
"github.com/coze-dev/coze-studio/backend/application/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossagent"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossagentrun"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossconnector"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossconversation"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossdatabase"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossdatacopy"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossknowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossmessage"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossplugin"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossuser"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossvariables"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossworkflow"
|
||||
crossagent "github.com/coze-dev/coze-studio/backend/crossdomain/contract/agent"
|
||||
crossagentrun "github.com/coze-dev/coze-studio/backend/crossdomain/contract/agentrun"
|
||||
crossconnector "github.com/coze-dev/coze-studio/backend/crossdomain/contract/connector"
|
||||
crossconversation "github.com/coze-dev/coze-studio/backend/crossdomain/contract/conversation"
|
||||
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
|
||||
crossdatacopy "github.com/coze-dev/coze-studio/backend/crossdomain/contract/datacopy"
|
||||
crossknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/contract/knowledge"
|
||||
crossmessage "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
|
||||
crossmodelmgr "github.com/coze-dev/coze-studio/backend/crossdomain/contract/modelmgr"
|
||||
crossplugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin"
|
||||
crossuser "github.com/coze-dev/coze-studio/backend/crossdomain/contract/user"
|
||||
crossvariables "github.com/coze-dev/coze-studio/backend/crossdomain/contract/variables"
|
||||
crossworkflow "github.com/coze-dev/coze-studio/backend/crossdomain/contract/workflow"
|
||||
agentrunImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/agentrun"
|
||||
connectorImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/connector"
|
||||
conversationImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/conversation"
|
||||
@ -59,12 +60,14 @@ import (
|
||||
dataCopyImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/datacopy"
|
||||
knowledgeImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/knowledge"
|
||||
messageImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/message"
|
||||
modelmgrImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/modelmgr"
|
||||
pluginImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/plugin"
|
||||
searchImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/search"
|
||||
singleagentImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/singleagent"
|
||||
variablesImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/variables"
|
||||
workflowImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/eventbus"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/chatmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/checkpoint"
|
||||
implEventbus "github.com/coze-dev/coze-studio/backend/infra/impl/eventbus"
|
||||
)
|
||||
@ -130,7 +133,7 @@ func Init(ctx context.Context) (err error) {
|
||||
crossconnector.SetDefaultSVC(connectorImpl.InitDomainService(basicServices.connectorSVC.DomainSVC))
|
||||
crossdatabase.SetDefaultSVC(databaseImpl.InitDomainService(primaryServices.memorySVC.DatabaseDomainSVC))
|
||||
crossknowledge.SetDefaultSVC(knowledgeImpl.InitDomainService(primaryServices.knowledgeSVC.DomainSVC))
|
||||
crossplugin.SetDefaultSVC(pluginImpl.InitDomainService(primaryServices.pluginSVC.DomainSVC))
|
||||
crossplugin.SetDefaultSVC(pluginImpl.InitDomainService(primaryServices.pluginSVC.DomainSVC, infra.TOSClient))
|
||||
crossvariables.SetDefaultSVC(variablesImpl.InitDomainService(primaryServices.memorySVC.VariablesDomainSVC))
|
||||
crossworkflow.SetDefaultSVC(workflowImpl.InitDomainService(primaryServices.workflowSVC.DomainSVC))
|
||||
crossconversation.SetDefaultSVC(conversationImpl.InitDomainService(complexServices.conversationSVC.ConversationDomainSVC))
|
||||
@ -140,6 +143,7 @@ func Init(ctx context.Context) (err error) {
|
||||
crossuser.SetDefaultSVC(crossuserImpl.InitDomainService(basicServices.userSVC.DomainSVC))
|
||||
crossdatacopy.SetDefaultSVC(dataCopyImpl.InitDomainService(basicServices.infra))
|
||||
crosssearch.SetDefaultSVC(searchImpl.InitDomainService(complexServices.searchSVC.DomainSVC))
|
||||
crossmodelmgr.SetDefaultSVC(modelmgrImpl.InitDomainService(infra.ModelMgr, nil))
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -188,7 +192,9 @@ func initPrimaryServices(ctx context.Context, basicServices *basicServices) (*pr
|
||||
|
||||
memorySVC := memory.InitService(basicServices.toMemoryServiceComponents())
|
||||
|
||||
knowledgeSVC, err := knowledge.InitService(basicServices.toKnowledgeServiceComponents(memorySVC))
|
||||
knowledgeSVC, err := knowledge.InitService(ctx,
|
||||
basicServices.toKnowledgeServiceComponents(memorySVC),
|
||||
basicServices.eventbus.resourceEventBus)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -252,14 +258,19 @@ func (b *basicServices) toPluginServiceComponents() *plugin.ServiceComponents {
|
||||
|
||||
func (b *basicServices) toKnowledgeServiceComponents(memoryService *memory.MemoryApplicationServices) *knowledge.ServiceComponents {
|
||||
return &knowledge.ServiceComponents{
|
||||
DB: b.infra.DB,
|
||||
IDGenSVC: b.infra.IDGenSVC,
|
||||
Storage: b.infra.TOSClient,
|
||||
RDB: memoryService.RDBDomainSVC,
|
||||
ImageX: b.infra.ImageXClient,
|
||||
ES: b.infra.ESClient,
|
||||
EventBus: b.eventbus.resourceEventBus,
|
||||
CacheCli: b.infra.CacheCli,
|
||||
DB: b.infra.DB,
|
||||
IDGen: b.infra.IDGenSVC,
|
||||
RDB: memoryService.RDBDomainSVC,
|
||||
Producer: b.infra.KnowledgeEventProducer,
|
||||
SearchStoreManagers: b.infra.SearchStoreManagers,
|
||||
ParseManager: b.infra.ParserManager,
|
||||
Storage: b.infra.TOSClient,
|
||||
Rewriter: b.infra.Rewriter,
|
||||
Reranker: b.infra.Reranker,
|
||||
NL2Sql: b.infra.NL2SQL,
|
||||
OCR: b.infra.OCR,
|
||||
CacheCli: b.infra.CacheCli,
|
||||
ModelFactory: chatmodel.NewDefaultFactory(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -276,19 +287,19 @@ func (b *basicServices) toMemoryServiceComponents() *memory.ServiceComponents {
|
||||
|
||||
func (b *basicServices) toWorkflowServiceComponents(pluginSVC *plugin.PluginApplicationService, memorySVC *memory.MemoryApplicationServices, knowledgeSVC *knowledge.KnowledgeApplicationService) *workflow.ServiceComponents {
|
||||
return &workflow.ServiceComponents{
|
||||
IDGen: b.infra.IDGenSVC,
|
||||
DB: b.infra.DB,
|
||||
Cache: b.infra.CacheCli,
|
||||
Tos: b.infra.TOSClient,
|
||||
ImageX: b.infra.ImageXClient,
|
||||
DatabaseDomainSVC: memorySVC.DatabaseDomainSVC,
|
||||
VariablesDomainSVC: memorySVC.VariablesDomainSVC,
|
||||
PluginDomainSVC: pluginSVC.DomainSVC,
|
||||
KnowledgeDomainSVC: knowledgeSVC.DomainSVC,
|
||||
ModelManager: b.infra.ModelMgr,
|
||||
DomainNotifier: b.eventbus.resourceEventBus,
|
||||
CPStore: checkpoint.NewRedisStore(b.infra.CacheCli),
|
||||
CodeRunner: b.infra.CodeRunner,
|
||||
IDGen: b.infra.IDGenSVC,
|
||||
DB: b.infra.DB,
|
||||
Cache: b.infra.CacheCli,
|
||||
Tos: b.infra.TOSClient,
|
||||
ImageX: b.infra.ImageXClient,
|
||||
DatabaseDomainSVC: memorySVC.DatabaseDomainSVC,
|
||||
VariablesDomainSVC: memorySVC.VariablesDomainSVC,
|
||||
PluginDomainSVC: pluginSVC.DomainSVC,
|
||||
KnowledgeDomainSVC: knowledgeSVC.DomainSVC,
|
||||
DomainNotifier: b.eventbus.resourceEventBus,
|
||||
CPStore: checkpoint.NewRedisStore(b.infra.CacheCli),
|
||||
CodeRunner: b.infra.CodeRunner,
|
||||
WorkflowBuildInChatModel: b.infra.WorkflowBuildInChatModel,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -18,40 +18,86 @@ package appinfra
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"google.golang.org/genai"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/embedding/gemini"
|
||||
"github.com/cloudwego/eino-ext/components/embedding/ollama"
|
||||
"github.com/cloudwego/eino-ext/components/embedding/openai"
|
||||
"github.com/cloudwego/eino/components/prompt"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/milvus-io/milvus/client/v2/milvusclient"
|
||||
"github.com/volcengine/volc-sdk-golang/service/visual"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/coderunner"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/nl2sql"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/ocr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/rerank"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/messages2query"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/cache/redis"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/coderunner/direct"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/coderunner/sandbox"
|
||||
builtinNL2SQL "github.com/coze-dev/coze-studio/backend/infra/impl/document/nl2sql/builtin"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/ocr/ppocr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/ocr/veocr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/builtin"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/ppstructure"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/rerank/rrf"
|
||||
vikingReranker "github.com/coze-dev/coze-studio/backend/infra/impl/document/rerank/vikingdb"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/elasticsearch"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/milvus"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/vikingdb"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/embedding/ark"
|
||||
embeddingHttp "github.com/coze-dev/coze-studio/backend/infra/impl/embedding/http"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/embedding/wrap"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/es"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/eventbus"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/imagex/veimagex"
|
||||
builtinM2Q "github.com/coze-dev/coze-studio/backend/infra/impl/messages2query/builtin"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/mysql"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
)
|
||||
|
||||
type AppDependencies struct {
|
||||
DB *gorm.DB
|
||||
CacheCli cache.Cmdable
|
||||
IDGenSVC idgen.IDGenerator
|
||||
ESClient es.Client
|
||||
ImageXClient imagex.ImageX
|
||||
TOSClient storage.Storage
|
||||
ResourceEventProducer eventbus.Producer
|
||||
AppEventProducer eventbus.Producer
|
||||
ModelMgr modelmgr.Manager
|
||||
CodeRunner coderunner.Runner
|
||||
DB *gorm.DB
|
||||
CacheCli cache.Cmdable
|
||||
IDGenSVC idgen.IDGenerator
|
||||
ESClient es.Client
|
||||
ImageXClient imagex.ImageX
|
||||
TOSClient storage.Storage
|
||||
ResourceEventProducer eventbus.Producer
|
||||
AppEventProducer eventbus.Producer
|
||||
KnowledgeEventProducer eventbus.Producer
|
||||
ModelMgr modelmgr.Manager
|
||||
CodeRunner coderunner.Runner
|
||||
OCR ocr.OCR
|
||||
ParserManager parser.Manager
|
||||
SearchStoreManagers []searchstore.Manager
|
||||
Reranker rerank.Reranker
|
||||
Rewriter messages2query.MessagesToQuery
|
||||
NL2SQL nl2sql.NL2SQL
|
||||
WorkflowBuildInChatModel chatmodel.BaseChatModel
|
||||
}
|
||||
|
||||
func Init(ctx context.Context) (*AppDependencies, error) {
|
||||
@ -60,55 +106,195 @@ func Init(ctx context.Context) (*AppDependencies, error) {
|
||||
|
||||
deps.DB, err = mysql.New()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("init db failed, err=%w", err)
|
||||
}
|
||||
|
||||
deps.CacheCli = redis.New()
|
||||
|
||||
deps.IDGenSVC, err = idgen.New(deps.CacheCli)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("init id gen svc failed, err=%w", err)
|
||||
}
|
||||
|
||||
deps.ESClient, err = es.New()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("init es client failed, err=%w", err)
|
||||
}
|
||||
|
||||
deps.ImageXClient, err = initImageX(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("init imagex client failed, err=%w", err)
|
||||
}
|
||||
|
||||
deps.TOSClient, err = initTOS(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("init tos client failed, err=%w", err)
|
||||
}
|
||||
|
||||
deps.ResourceEventProducer, err = initResourceEventBusProducer()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("init resource event bus producer failed, err=%w", err)
|
||||
}
|
||||
|
||||
deps.AppEventProducer, err = initAppEventProducer()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("init app event producer failed, err=%w", err)
|
||||
}
|
||||
|
||||
deps.KnowledgeEventProducer, err = initKnowledgeEventBusProducer()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init knowledge event bus producer failed, err=%w", err)
|
||||
}
|
||||
|
||||
deps.Reranker = initReranker()
|
||||
|
||||
deps.Rewriter, err = initRewriter(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init rewriter failed, err=%w", err)
|
||||
}
|
||||
|
||||
deps.NL2SQL, err = initNL2SQL(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init nl2sql failed, err=%w", err)
|
||||
}
|
||||
|
||||
deps.ModelMgr, err = initModelMgr()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("init model manager failed, err=%w", err)
|
||||
}
|
||||
|
||||
deps.CodeRunner = initCodeRunner()
|
||||
|
||||
deps.OCR = initOCR()
|
||||
|
||||
imageAnnotationModel, _, err := getBuiltinChatModel(ctx, "IA_")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get builtin chat model failed, err=%w", err)
|
||||
}
|
||||
|
||||
var ok bool
|
||||
deps.WorkflowBuildInChatModel, ok, err = getBuiltinChatModel(ctx, "WKR_")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get workflow builtin chat model failed, err=%w", err)
|
||||
}
|
||||
|
||||
if !ok {
|
||||
logs.CtxWarnf(ctx, "workflow builtin chat model for knowledge recall not configured")
|
||||
}
|
||||
|
||||
deps.ParserManager, err = initParserManager(deps.TOSClient, deps.OCR, imageAnnotationModel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init parser manager failed, err=%w", err)
|
||||
}
|
||||
|
||||
deps.SearchStoreManagers, err = initSearchStoreManagers(ctx, deps.ESClient)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init search store managers failed, err=%w", err)
|
||||
}
|
||||
|
||||
return deps, nil
|
||||
}
|
||||
|
||||
func initSearchStoreManagers(ctx context.Context, es es.Client) ([]searchstore.Manager, error) {
|
||||
// es full text search
|
||||
esSearchstoreManager := elasticsearch.NewManager(&elasticsearch.ManagerConfig{Client: es})
|
||||
|
||||
// vector search
|
||||
mgr, err := getVectorStore(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init vector store failed, err=%w", err)
|
||||
}
|
||||
|
||||
return []searchstore.Manager{esSearchstoreManager, mgr}, nil
|
||||
}
|
||||
|
||||
func initReranker() rerank.Reranker {
|
||||
rerankerType := os.Getenv("RERANK_TYPE")
|
||||
switch rerankerType {
|
||||
case "vikingdb":
|
||||
return vikingReranker.NewReranker(getVikingRerankerConfig())
|
||||
case "rrf":
|
||||
return rrf.NewRRFReranker(0)
|
||||
default:
|
||||
return rrf.NewRRFReranker(0)
|
||||
}
|
||||
}
|
||||
func getVikingRerankerConfig() *vikingReranker.Config {
|
||||
return &vikingReranker.Config{
|
||||
AK: os.Getenv("VIKINGDB_RERANK_AK"),
|
||||
SK: os.Getenv("VIKINGDB_RERANK_SK"),
|
||||
Domain: os.Getenv("VIKINGDB_RERANK_HOST"),
|
||||
Region: os.Getenv("VIKINGDB_RERANK_REGION"),
|
||||
Model: os.Getenv("VIKINGDB_RERANK_MODEL"),
|
||||
}
|
||||
}
|
||||
func initRewriter(ctx context.Context) (messages2query.MessagesToQuery, error) {
|
||||
rewriterChatModel, _, err := getBuiltinChatModel(ctx, "M2Q_")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
filePath := filepath.Join(getWorkingDirectory(), "resources/conf/prompt/messages_to_query_template_jinja2.json")
|
||||
rewriterTemplate, err := readJinja2PromptTemplate(filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rewriter, err := builtinM2Q.NewMessagesToQuery(ctx, rewriterChatModel, rewriterTemplate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return rewriter, nil
|
||||
}
|
||||
|
||||
func getWorkingDirectory() string {
|
||||
root, err := os.Getwd()
|
||||
if err != nil {
|
||||
logs.Warnf("[InitConfig] Failed to get current working directory: %v", err)
|
||||
root = os.Getenv("PWD")
|
||||
}
|
||||
return root
|
||||
}
|
||||
|
||||
func readJinja2PromptTemplate(jsonFilePath string) (prompt.ChatTemplate, error) {
|
||||
b, err := os.ReadFile(jsonFilePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var m2qMessages []*schema.Message
|
||||
if err = json.Unmarshal(b, &m2qMessages); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tpl := make([]schema.MessagesTemplate, len(m2qMessages))
|
||||
for i := range m2qMessages {
|
||||
tpl[i] = m2qMessages[i]
|
||||
}
|
||||
return prompt.FromMessages(schema.Jinja2, tpl...), nil
|
||||
}
|
||||
|
||||
func initNL2SQL(ctx context.Context) (nl2sql.NL2SQL, error) {
|
||||
n2sChatModel, _, err := getBuiltinChatModel(ctx, "NL2SQL_")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
filePath := filepath.Join(getWorkingDirectory(), "resources/conf/prompt/nl2sql_template_jinja2.json")
|
||||
n2sTemplate, err := readJinja2PromptTemplate(filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
n2s, err := builtinNL2SQL.NewNL2SQL(ctx, n2sChatModel, n2sTemplate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return n2s, nil
|
||||
}
|
||||
|
||||
func initImageX(ctx context.Context) (imagex.ImageX, error) {
|
||||
|
||||
uploadComponentType := os.Getenv(consts.FileUploadComponentType)
|
||||
|
||||
if uploadComponentType != consts.FileUploadComponentTypeImagex {
|
||||
return storage.NewImagex(ctx)
|
||||
}
|
||||
@ -147,6 +333,17 @@ func initAppEventProducer() (eventbus.Producer, error) {
|
||||
return appEventProducer, nil
|
||||
}
|
||||
|
||||
func initKnowledgeEventBusProducer() (eventbus.Producer, error) {
|
||||
nameServer := os.Getenv(consts.MQServer)
|
||||
|
||||
knowledgeProducer, err := eventbus.NewProducer(nameServer, consts.RMQTopicKnowledge, consts.RMQConsumeGroupKnowledge, 2)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init knowledge producer failed, err=%w", err)
|
||||
}
|
||||
|
||||
return knowledgeProducer, nil
|
||||
}
|
||||
|
||||
func initCodeRunner() coderunner.Runner {
|
||||
switch typ := os.Getenv(consts.CodeRunnerType); typ {
|
||||
case "sandbox":
|
||||
@ -183,3 +380,337 @@ func initCodeRunner() coderunner.Runner {
|
||||
return direct.NewRunner()
|
||||
}
|
||||
}
|
||||
|
||||
func initOCR() ocr.OCR {
|
||||
var ocr ocr.OCR
|
||||
switch os.Getenv(consts.OCRType) {
|
||||
case "ve":
|
||||
ocrAK := os.Getenv(consts.VeOCRAK)
|
||||
ocrSK := os.Getenv(consts.VeOCRSK)
|
||||
if ocrAK == "" || ocrSK == "" {
|
||||
logs.Warnf("[ve_ocr] ak / sk not configured, ocr might not work well")
|
||||
}
|
||||
inst := visual.NewInstance()
|
||||
inst.Client.SetAccessKey(ocrAK)
|
||||
inst.Client.SetSecretKey(ocrSK)
|
||||
ocr = veocr.NewOCR(&veocr.Config{Client: inst})
|
||||
case "paddleocr":
|
||||
url := os.Getenv(consts.PPOCRAPIURL)
|
||||
client := &http.Client{}
|
||||
ocr = ppocr.NewOCR(&ppocr.Config{Client: client, URL: url})
|
||||
default:
|
||||
// accept ocr not configured
|
||||
}
|
||||
|
||||
return ocr
|
||||
}
|
||||
|
||||
func initParserManager(storage storage.Storage, ocr ocr.OCR, imageAnnotationModel chatmodel.BaseChatModel) (parser.Manager, error) {
|
||||
var parserManager parser.Manager
|
||||
parserType := os.Getenv(consts.ParserType)
|
||||
switch parserType {
|
||||
case "builtin", "":
|
||||
parserManager = builtin.NewManager(storage, ocr, imageAnnotationModel)
|
||||
case "paddleocr":
|
||||
url := os.Getenv(consts.PPStructureAPIURL)
|
||||
client := &http.Client{}
|
||||
apiConfig := &ppstructure.APIConfig{
|
||||
Client: client,
|
||||
URL: url,
|
||||
}
|
||||
parserManager = ppstructure.NewManager(apiConfig, ocr, storage, imageAnnotationModel)
|
||||
default:
|
||||
return nil, fmt.Errorf("parser type %s not supported", parserType)
|
||||
}
|
||||
|
||||
return parserManager, nil
|
||||
}
|
||||
|
||||
func getVectorStore(ctx context.Context) (searchstore.Manager, error) {
|
||||
vsType := os.Getenv("VECTOR_STORE_TYPE")
|
||||
|
||||
switch vsType {
|
||||
case "milvus":
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
|
||||
defer cancel()
|
||||
|
||||
milvusAddr := os.Getenv("MILVUS_ADDR")
|
||||
user := os.Getenv("MILVUS_USER")
|
||||
password := os.Getenv("MILVUS_PASSWORD")
|
||||
mc, err := milvusclient.New(ctx, &milvusclient.ClientConfig{
|
||||
Address: milvusAddr,
|
||||
Username: user,
|
||||
Password: password,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init milvus client failed, err=%w", err)
|
||||
}
|
||||
|
||||
emb, err := getEmbedding(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init milvus embedding failed, err=%w", err)
|
||||
}
|
||||
|
||||
mgr, err := milvus.NewManager(&milvus.ManagerConfig{
|
||||
Client: mc,
|
||||
Embedding: emb,
|
||||
EnableHybrid: ptr.Of(true),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init milvus vector store failed, err=%w", err)
|
||||
}
|
||||
|
||||
return mgr, nil
|
||||
case "vikingdb":
|
||||
var (
|
||||
host = os.Getenv("VIKING_DB_HOST")
|
||||
region = os.Getenv("VIKING_DB_REGION")
|
||||
ak = os.Getenv("VIKING_DB_AK")
|
||||
sk = os.Getenv("VIKING_DB_SK")
|
||||
scheme = os.Getenv("VIKING_DB_SCHEME")
|
||||
modelName = os.Getenv("VIKING_DB_MODEL_NAME")
|
||||
)
|
||||
if ak == "" || sk == "" {
|
||||
return nil, fmt.Errorf("invalid vikingdb ak / sk")
|
||||
}
|
||||
if host == "" {
|
||||
host = "api-vikingdb.volces.com"
|
||||
}
|
||||
if region == "" {
|
||||
region = "cn-beijing"
|
||||
}
|
||||
if scheme == "" {
|
||||
scheme = "https"
|
||||
}
|
||||
|
||||
var embConfig *vikingdb.VikingEmbeddingConfig
|
||||
if modelName != "" {
|
||||
embName := vikingdb.VikingEmbeddingModelName(modelName)
|
||||
if embName.Dimensions() == 0 {
|
||||
return nil, fmt.Errorf("embedding model not support, model_name=%s", modelName)
|
||||
}
|
||||
embConfig = &vikingdb.VikingEmbeddingConfig{
|
||||
UseVikingEmbedding: true,
|
||||
EnableHybrid: embName.SupportStatus() == embedding.SupportDenseAndSparse,
|
||||
ModelName: embName,
|
||||
ModelVersion: embName.ModelVersion(),
|
||||
DenseWeight: ptr.Of(0.2),
|
||||
BuiltinEmbedding: nil,
|
||||
}
|
||||
} else {
|
||||
builtinEmbedding, err := getEmbedding(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("builtint embedding init failed, err=%w", err)
|
||||
}
|
||||
|
||||
embConfig = &vikingdb.VikingEmbeddingConfig{
|
||||
UseVikingEmbedding: false,
|
||||
EnableHybrid: false,
|
||||
BuiltinEmbedding: builtinEmbedding,
|
||||
}
|
||||
}
|
||||
|
||||
svc := vikingdb.NewVikingDBService(host, region, ak, sk, scheme)
|
||||
mgr, err := vikingdb.NewManager(&vikingdb.ManagerConfig{
|
||||
Service: svc,
|
||||
IndexingConfig: nil, // use default config
|
||||
EmbeddingConfig: embConfig,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init vikingdb manager failed, err=%w", err)
|
||||
}
|
||||
|
||||
return mgr, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected vector store type, type=%s", vsType)
|
||||
}
|
||||
}
|
||||
|
||||
func getEmbedding(ctx context.Context) (embedding.Embedder, error) {
|
||||
var batchSize int
|
||||
if bs, err := strconv.ParseInt(os.Getenv("EMBEDDING_MAX_BATCH_SIZE"), 10, 64); err != nil {
|
||||
logs.CtxWarnf(ctx, "EMBEDDING_MAX_BATCH_SIZE not set / invalid, using default batchSize=100")
|
||||
batchSize = 100
|
||||
} else {
|
||||
batchSize = int(bs)
|
||||
}
|
||||
|
||||
var emb embedding.Embedder
|
||||
|
||||
switch os.Getenv("EMBEDDING_TYPE") {
|
||||
case "openai":
|
||||
var (
|
||||
openAIEmbeddingBaseURL = os.Getenv("OPENAI_EMBEDDING_BASE_URL")
|
||||
openAIEmbeddingModel = os.Getenv("OPENAI_EMBEDDING_MODEL")
|
||||
openAIEmbeddingApiKey = os.Getenv("OPENAI_EMBEDDING_API_KEY")
|
||||
openAIEmbeddingByAzure = os.Getenv("OPENAI_EMBEDDING_BY_AZURE")
|
||||
openAIEmbeddingApiVersion = os.Getenv("OPENAI_EMBEDDING_API_VERSION")
|
||||
openAIEmbeddingDims = os.Getenv("OPENAI_EMBEDDING_DIMS")
|
||||
openAIRequestEmbeddingDims = os.Getenv("OPENAI_EMBEDDING_REQUEST_DIMS")
|
||||
)
|
||||
|
||||
byAzure, err := strconv.ParseBool(openAIEmbeddingByAzure)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init openai embedding by_azure failed, err=%w", err)
|
||||
}
|
||||
|
||||
dims, err := strconv.ParseInt(openAIEmbeddingDims, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init openai embedding dims failed, err=%w", err)
|
||||
}
|
||||
|
||||
openAICfg := &openai.EmbeddingConfig{
|
||||
APIKey: openAIEmbeddingApiKey,
|
||||
ByAzure: byAzure,
|
||||
BaseURL: openAIEmbeddingBaseURL,
|
||||
APIVersion: openAIEmbeddingApiVersion,
|
||||
Model: openAIEmbeddingModel,
|
||||
// Dimensions: ptr.Of(int(dims)),
|
||||
}
|
||||
reqDims := conv.StrToInt64D(openAIRequestEmbeddingDims, 0)
|
||||
if reqDims > 0 {
|
||||
// some openai model not support request dims
|
||||
openAICfg.Dimensions = ptr.Of(int(reqDims))
|
||||
}
|
||||
|
||||
emb, err = wrap.NewOpenAIEmbedder(ctx, openAICfg, dims, batchSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init openai embedding failed, err=%w", err)
|
||||
}
|
||||
|
||||
case "ark":
|
||||
var (
|
||||
arkEmbeddingBaseURL = os.Getenv("ARK_EMBEDDING_BASE_URL")
|
||||
arkEmbeddingModel = os.Getenv("ARK_EMBEDDING_MODEL")
|
||||
arkEmbeddingApiKey = os.Getenv("ARK_EMBEDDING_API_KEY")
|
||||
// deprecated: use ARK_EMBEDDING_API_KEY instead
|
||||
// ARK_EMBEDDING_AK will be removed in the future
|
||||
arkEmbeddingAK = os.Getenv("ARK_EMBEDDING_AK")
|
||||
arkEmbeddingDims = os.Getenv("ARK_EMBEDDING_DIMS")
|
||||
arkEmbeddingAPIType = os.Getenv("ARK_EMBEDDING_API_TYPE")
|
||||
)
|
||||
|
||||
dims, err := strconv.ParseInt(arkEmbeddingDims, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init ark embedding dims failed, err=%w", err)
|
||||
}
|
||||
|
||||
apiType := ark.APITypeText
|
||||
if arkEmbeddingAPIType != "" {
|
||||
if t := ark.APIType(arkEmbeddingAPIType); t != ark.APITypeText && t != ark.APITypeMultiModal {
|
||||
return nil, fmt.Errorf("init ark embedding api_type failed, invalid api_type=%s", t)
|
||||
} else {
|
||||
apiType = t
|
||||
}
|
||||
}
|
||||
|
||||
emb, err = ark.NewArkEmbedder(ctx, &ark.EmbeddingConfig{
|
||||
APIKey: func() string {
|
||||
if arkEmbeddingApiKey != "" {
|
||||
return arkEmbeddingApiKey
|
||||
}
|
||||
return arkEmbeddingAK
|
||||
}(),
|
||||
Model: arkEmbeddingModel,
|
||||
BaseURL: arkEmbeddingBaseURL,
|
||||
APIType: &apiType,
|
||||
}, dims, batchSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init ark embedding client failed, err=%w", err)
|
||||
}
|
||||
|
||||
case "ollama":
|
||||
var (
|
||||
ollamaEmbeddingBaseURL = os.Getenv("OLLAMA_EMBEDDING_BASE_URL")
|
||||
ollamaEmbeddingModel = os.Getenv("OLLAMA_EMBEDDING_MODEL")
|
||||
ollamaEmbeddingDims = os.Getenv("OLLAMA_EMBEDDING_DIMS")
|
||||
)
|
||||
|
||||
dims, err := strconv.ParseInt(ollamaEmbeddingDims, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init ollama embedding dims failed, err=%w", err)
|
||||
}
|
||||
|
||||
emb, err = wrap.NewOllamaEmbedder(ctx, &ollama.EmbeddingConfig{
|
||||
BaseURL: ollamaEmbeddingBaseURL,
|
||||
Model: ollamaEmbeddingModel,
|
||||
}, dims, batchSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init ollama embedding failed, err=%w", err)
|
||||
}
|
||||
case "gemini":
|
||||
var (
|
||||
geminiEmbeddingBaseURL = os.Getenv("GEMINI_EMBEDDING_BASE_URL")
|
||||
geminiEmbeddingModel = os.Getenv("GEMINI_EMBEDDING_MODEL")
|
||||
geminiEmbeddingApiKey = os.Getenv("GEMINI_EMBEDDING_API_KEY")
|
||||
geminiEmbeddingDims = os.Getenv("GEMINI_EMBEDDING_DIMS")
|
||||
geminiEmbeddingBackend = os.Getenv("GEMINI_EMBEDDING_BACKEND") // "1" for BackendGeminiAPI / "2" for BackendVertexAI
|
||||
geminiEmbeddingProject = os.Getenv("GEMINI_EMBEDDING_PROJECT")
|
||||
geminiEmbeddingLocation = os.Getenv("GEMINI_EMBEDDING_LOCATION")
|
||||
)
|
||||
|
||||
if len(geminiEmbeddingModel) == 0 {
|
||||
return nil, fmt.Errorf("GEMINI_EMBEDDING_MODEL environment variable is required")
|
||||
}
|
||||
if len(geminiEmbeddingApiKey) == 0 {
|
||||
return nil, fmt.Errorf("GEMINI_EMBEDDING_API_KEY environment variable is required")
|
||||
}
|
||||
if len(geminiEmbeddingDims) == 0 {
|
||||
return nil, fmt.Errorf("GEMINI_EMBEDDING_DIMS environment variable is required")
|
||||
}
|
||||
if len(geminiEmbeddingBackend) == 0 {
|
||||
return nil, fmt.Errorf("GEMINI_EMBEDDING_BACKEND environment variable is required")
|
||||
}
|
||||
|
||||
dims, convErr := strconv.ParseInt(geminiEmbeddingDims, 10, 64)
|
||||
if convErr != nil {
|
||||
return nil, fmt.Errorf("invalid GEMINI_EMBEDDING_DIMS value: %s, err=%w", geminiEmbeddingDims, convErr)
|
||||
}
|
||||
|
||||
backend, convErr := strconv.ParseInt(geminiEmbeddingBackend, 10, 64)
|
||||
if convErr != nil {
|
||||
return nil, fmt.Errorf("invalid GEMINI_EMBEDDING_BACKEND value: %s, err=%w", geminiEmbeddingBackend, convErr)
|
||||
}
|
||||
|
||||
geminiCli, err := genai.NewClient(ctx, &genai.ClientConfig{
|
||||
APIKey: geminiEmbeddingApiKey,
|
||||
Backend: genai.Backend(backend),
|
||||
Project: geminiEmbeddingProject,
|
||||
Location: geminiEmbeddingLocation,
|
||||
HTTPOptions: genai.HTTPOptions{
|
||||
BaseURL: geminiEmbeddingBaseURL,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init gemini client failed, err=%w", err)
|
||||
}
|
||||
|
||||
emb, err = wrap.NewGeminiEmbedder(ctx, &gemini.EmbeddingConfig{
|
||||
Client: geminiCli,
|
||||
Model: geminiEmbeddingModel,
|
||||
OutputDimensionality: ptr.Of(int32(dims)),
|
||||
}, dims, batchSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init gemini embedding failed, err=%w", err)
|
||||
}
|
||||
case "http":
|
||||
var (
|
||||
httpEmbeddingBaseURL = os.Getenv("HTTP_EMBEDDING_ADDR")
|
||||
httpEmbeddingDims = os.Getenv("HTTP_EMBEDDING_DIMS")
|
||||
)
|
||||
dims, err := strconv.ParseInt(httpEmbeddingDims, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init http embedding dims failed, err=%w", err)
|
||||
}
|
||||
emb, err = embeddingHttp.NewEmbedding(httpEmbeddingBaseURL, dims, batchSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init http embedding failed, err=%w", err)
|
||||
}
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("init knowledge embedding failed, type not configured")
|
||||
}
|
||||
|
||||
return emb, nil
|
||||
}
|
||||
|
||||
@ -14,7 +14,7 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package internal
|
||||
package appinfra
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -33,7 +33,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
|
||||
)
|
||||
|
||||
func GetBuiltinChatModel(ctx context.Context, envPrefix string) (bcm chatmodel.BaseChatModel, configured bool, err error) {
|
||||
func getBuiltinChatModel(ctx context.Context, envPrefix string) (bcm chatmodel.BaseChatModel, configured bool, err error) {
|
||||
getEnv := func(key string) string {
|
||||
if val := os.Getenv(envPrefix + key); val != "" {
|
||||
return val
|
||||
@ -99,7 +99,7 @@ func GetBuiltinChatModel(ctx context.Context, envPrefix string) (bcm chatmodel.B
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("knowledge init openai chat mode failed, %w", err)
|
||||
return nil, false, fmt.Errorf("builtin %s chat model init failed, %w", envPrefix, err)
|
||||
}
|
||||
if bcm != nil {
|
||||
configured = true
|
||||
@ -53,6 +53,10 @@ func (m *OpenapiMessageApplication) GetApiMessageList(ctx context.Context, mr *m
|
||||
return nil, errorx.New(errno.ErrConversationPermissionCode, errorx.KV("msg", "permission denied"))
|
||||
}
|
||||
|
||||
if mr.Limit == nil {
|
||||
mr.Limit = ptr.Of(int64(50))
|
||||
}
|
||||
|
||||
msgListMeta := &entity.ListMeta{
|
||||
ConversationID: currentConversation.ID,
|
||||
AgentID: currentConversation.AgentID,
|
||||
|
||||
@ -18,418 +18,27 @@ package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
netHTTP "net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/embedding/ark"
|
||||
ollamaEmb "github.com/cloudwego/eino-ext/components/embedding/ollama"
|
||||
"github.com/cloudwego/eino-ext/components/embedding/openai"
|
||||
"github.com/cloudwego/eino/components/prompt"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/milvus-io/milvus/client/v2/milvusclient"
|
||||
"github.com/volcengine/volc-sdk-golang/service/vikingdb"
|
||||
"github.com/volcengine/volc-sdk-golang/service/visual"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/application/internal"
|
||||
"github.com/coze-dev/coze-studio/backend/application/search"
|
||||
knowledgeImpl "github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/nl2sql"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/ocr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/es"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/messages2query"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/rdb"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
chatmodelImpl "github.com/coze-dev/coze-studio/backend/infra/impl/chatmodel"
|
||||
builtinNL2SQL "github.com/coze-dev/coze-studio/backend/infra/impl/document/nl2sql/builtin"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/ocr/veocr"
|
||||
builtinParser "github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/builtin"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/rerank/rrf"
|
||||
sses "github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/elasticsearch"
|
||||
ssmilvus "github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/milvus"
|
||||
ssvikingdb "github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/vikingdb"
|
||||
arkemb "github.com/coze-dev/coze-studio/backend/infra/impl/embedding/ark"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/embedding/http"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/embedding/wrap"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/eventbus"
|
||||
builtinM2Q "github.com/coze-dev/coze-studio/backend/infra/impl/messages2query/builtin"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
)
|
||||
|
||||
type ServiceComponents struct {
|
||||
DB *gorm.DB
|
||||
IDGenSVC idgen.IDGenerator
|
||||
Storage storage.Storage
|
||||
RDB rdb.RDB
|
||||
ImageX imagex.ImageX
|
||||
ES es.Client
|
||||
EventBus search.ResourceEventBus
|
||||
CacheCli cache.Cmdable
|
||||
}
|
||||
type ServiceComponents = knowledgeImpl.KnowledgeSVCConfig
|
||||
|
||||
func InitService(c *ServiceComponents) (*KnowledgeApplicationService, error) {
|
||||
ctx := context.Background()
|
||||
func InitService(ctx context.Context, c *ServiceComponents, bus search.ResourceEventBus) (*KnowledgeApplicationService, error) {
|
||||
knowledgeDomainSVC, knowledgeEventHandler := knowledgeImpl.NewKnowledgeSVC(c)
|
||||
|
||||
nameServer := os.Getenv(consts.MQServer)
|
||||
|
||||
knowledgeProducer, err := eventbus.NewProducer(nameServer, consts.RMQTopicKnowledge, consts.RMQConsumeGroupKnowledge, 2)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init knowledge producer failed, err=%w", err)
|
||||
}
|
||||
|
||||
var sManagers []searchstore.Manager
|
||||
|
||||
// es full text search
|
||||
sManagers = append(sManagers, sses.NewManager(&sses.ManagerConfig{Client: c.ES}))
|
||||
|
||||
// vector search
|
||||
mgr, err := getVectorStore(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init vector store failed, err=%w", err)
|
||||
}
|
||||
sManagers = append(sManagers, mgr)
|
||||
|
||||
var ocrImpl ocr.OCR
|
||||
switch os.Getenv("OCR_TYPE") {
|
||||
case "ve":
|
||||
ocrAK := os.Getenv("VE_OCR_AK")
|
||||
ocrSK := os.Getenv("VE_OCR_SK")
|
||||
if ocrAK == "" || ocrSK == "" {
|
||||
logs.Warnf("[ve_ocr] ak / sk not configured, ocr might not work well")
|
||||
}
|
||||
inst := visual.NewInstance()
|
||||
inst.Client.SetAccessKey(ocrAK)
|
||||
inst.Client.SetSecretKey(ocrSK)
|
||||
ocrImpl = veocr.NewOCR(&veocr.Config{Client: inst})
|
||||
case "paddleocr":
|
||||
ppocrURL := os.Getenv("PADDLEOCR_OCR_API_URL")
|
||||
client := &netHTTP.Client{}
|
||||
ocrImpl = veocr.NewPPOCR(&veocr.PPOCRConfig{Client: client, URL: ppocrURL})
|
||||
default:
|
||||
// accept ocr not configured
|
||||
}
|
||||
|
||||
root, err := os.Getwd()
|
||||
if err != nil {
|
||||
logs.Warnf("[InitConfig] Failed to get current working directory: %v", err)
|
||||
root = os.Getenv("PWD")
|
||||
}
|
||||
|
||||
var rewriter messages2query.MessagesToQuery
|
||||
if rewriterChatModel, _, err := internal.GetBuiltinChatModel(ctx, "M2Q_"); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
filePath := filepath.Join(root, "resources/conf/prompt/messages_to_query_template_jinja2.json")
|
||||
rewriterTemplate, err := readJinja2PromptTemplate(filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rewriter, err = builtinM2Q.NewMessagesToQuery(ctx, rewriterChatModel, rewriterTemplate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var n2s nl2sql.NL2SQL
|
||||
if n2sChatModel, _, err := internal.GetBuiltinChatModel(ctx, "NL2SQL_"); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
filePath := filepath.Join(root, "resources/conf/prompt/nl2sql_template_jinja2.json")
|
||||
n2sTemplate, err := readJinja2PromptTemplate(filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
n2s, err = builtinNL2SQL.NewNL2SQL(ctx, n2sChatModel, n2sTemplate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
imageAnnoChatModel, configured, err := internal.GetBuiltinChatModel(ctx, "IA_")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
knowledgeDomainSVC, knowledgeEventHandler := knowledgeImpl.NewKnowledgeSVC(&knowledgeImpl.KnowledgeSVCConfig{
|
||||
DB: c.DB,
|
||||
IDGen: c.IDGenSVC,
|
||||
RDB: c.RDB,
|
||||
Producer: knowledgeProducer,
|
||||
SearchStoreManagers: sManagers,
|
||||
ParseManager: builtinParser.NewManager(c.Storage, ocrImpl, imageAnnoChatModel), // default builtin
|
||||
Storage: c.Storage,
|
||||
Rewriter: rewriter,
|
||||
Reranker: rrf.NewRRFReranker(0), // default rrf
|
||||
NL2Sql: n2s,
|
||||
OCR: ocrImpl,
|
||||
CacheCli: c.CacheCli,
|
||||
IsAutoAnnotationSupported: configured,
|
||||
ModelFactory: chatmodelImpl.NewDefaultFactory(),
|
||||
})
|
||||
|
||||
if err = eventbus.DefaultSVC().RegisterConsumer(nameServer, consts.RMQTopicKnowledge, consts.RMQConsumeGroupKnowledge, knowledgeEventHandler); err != nil {
|
||||
if err := eventbus.DefaultSVC().RegisterConsumer(nameServer, consts.RMQTopicKnowledge, consts.RMQConsumeGroupKnowledge, knowledgeEventHandler); err != nil {
|
||||
return nil, fmt.Errorf("register knowledge consumer failed, err=%w", err)
|
||||
}
|
||||
|
||||
KnowledgeSVC.DomainSVC = knowledgeDomainSVC
|
||||
KnowledgeSVC.eventBus = c.EventBus
|
||||
KnowledgeSVC.eventBus = bus
|
||||
KnowledgeSVC.storage = c.Storage
|
||||
return KnowledgeSVC, nil
|
||||
}
|
||||
|
||||
func getVectorStore(ctx context.Context) (searchstore.Manager, error) {
|
||||
vsType := os.Getenv("VECTOR_STORE_TYPE")
|
||||
|
||||
switch vsType {
|
||||
case "milvus":
|
||||
cctx, cancel := context.WithTimeout(ctx, time.Second*5)
|
||||
defer cancel()
|
||||
|
||||
milvusAddr := os.Getenv("MILVUS_ADDR")
|
||||
mc, err := milvusclient.New(cctx, &milvusclient.ClientConfig{Address: milvusAddr})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init milvus client failed, err=%w", err)
|
||||
}
|
||||
|
||||
emb, err := getEmbedding(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init milvus embedding failed, err=%w", err)
|
||||
}
|
||||
|
||||
mgr, err := ssmilvus.NewManager(&ssmilvus.ManagerConfig{
|
||||
Client: mc,
|
||||
Embedding: emb,
|
||||
EnableHybrid: ptr.Of(true),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init milvus vector store failed, err=%w", err)
|
||||
}
|
||||
|
||||
return mgr, nil
|
||||
case "vikingdb":
|
||||
var (
|
||||
host = os.Getenv("VIKING_DB_HOST")
|
||||
region = os.Getenv("VIKING_DB_REGION")
|
||||
ak = os.Getenv("VIKING_DB_AK")
|
||||
sk = os.Getenv("VIKING_DB_SK")
|
||||
scheme = os.Getenv("VIKING_DB_SCHEME")
|
||||
modelName = os.Getenv("VIKING_DB_MODEL_NAME")
|
||||
)
|
||||
if ak == "" || sk == "" {
|
||||
return nil, fmt.Errorf("invalid vikingdb ak / sk")
|
||||
}
|
||||
if host == "" {
|
||||
host = "api-vikingdb.volces.com"
|
||||
}
|
||||
if region == "" {
|
||||
region = "cn-beijing"
|
||||
}
|
||||
if scheme == "" {
|
||||
scheme = "https"
|
||||
}
|
||||
|
||||
var embConfig *ssvikingdb.VikingEmbeddingConfig
|
||||
if modelName != "" {
|
||||
embName := ssvikingdb.VikingEmbeddingModelName(modelName)
|
||||
if embName.Dimensions() == 0 {
|
||||
return nil, fmt.Errorf("embedding model not support, model_name=%s", modelName)
|
||||
}
|
||||
embConfig = &ssvikingdb.VikingEmbeddingConfig{
|
||||
UseVikingEmbedding: true,
|
||||
EnableHybrid: embName.SupportStatus() == embedding.SupportDenseAndSparse,
|
||||
ModelName: embName,
|
||||
ModelVersion: embName.ModelVersion(),
|
||||
DenseWeight: ptr.Of(0.2),
|
||||
BuiltinEmbedding: nil,
|
||||
}
|
||||
} else {
|
||||
builtinEmbedding, err := getEmbedding(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("builtint embedding init failed, err=%w", err)
|
||||
}
|
||||
|
||||
embConfig = &ssvikingdb.VikingEmbeddingConfig{
|
||||
UseVikingEmbedding: false,
|
||||
EnableHybrid: false,
|
||||
BuiltinEmbedding: builtinEmbedding,
|
||||
}
|
||||
}
|
||||
svc := vikingdb.NewVikingDBService(host, region, ak, sk, scheme)
|
||||
mgr, err := ssvikingdb.NewManager(&ssvikingdb.ManagerConfig{
|
||||
Service: svc,
|
||||
IndexingConfig: nil, // use default config
|
||||
EmbeddingConfig: embConfig,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init vikingdb manager failed, err=%w", err)
|
||||
}
|
||||
|
||||
return mgr, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected vector store type, type=%s", vsType)
|
||||
}
|
||||
}
|
||||
|
||||
func getEmbedding(ctx context.Context) (embedding.Embedder, error) {
|
||||
var batchSize int
|
||||
if bs, err := strconv.ParseInt(os.Getenv("EMBEDDING_MAX_BATCH_SIZE"), 10, 64); err != nil {
|
||||
logs.CtxWarnf(ctx, "EMBEDDING_MAX_BATCH_SIZE not set / invalid, using default batchSize=100")
|
||||
batchSize = 100
|
||||
} else {
|
||||
batchSize = int(bs)
|
||||
}
|
||||
|
||||
var emb embedding.Embedder
|
||||
|
||||
switch os.Getenv("EMBEDDING_TYPE") {
|
||||
case "openai":
|
||||
var (
|
||||
openAIEmbeddingBaseURL = os.Getenv("OPENAI_EMBEDDING_BASE_URL")
|
||||
openAIEmbeddingModel = os.Getenv("OPENAI_EMBEDDING_MODEL")
|
||||
openAIEmbeddingApiKey = os.Getenv("OPENAI_EMBEDDING_API_KEY")
|
||||
openAIEmbeddingByAzure = os.Getenv("OPENAI_EMBEDDING_BY_AZURE")
|
||||
openAIEmbeddingApiVersion = os.Getenv("OPENAI_EMBEDDING_API_VERSION")
|
||||
openAIEmbeddingDims = os.Getenv("OPENAI_EMBEDDING_DIMS")
|
||||
openAIRequestEmbeddingDims = os.Getenv("OPENAI_EMBEDDING_REQUEST_DIMS")
|
||||
)
|
||||
|
||||
byAzure, err := strconv.ParseBool(openAIEmbeddingByAzure)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init openai embedding by_azure failed, err=%w", err)
|
||||
}
|
||||
|
||||
dims, err := strconv.ParseInt(openAIEmbeddingDims, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init openai embedding dims failed, err=%w", err)
|
||||
}
|
||||
|
||||
openAICfg := &openai.EmbeddingConfig{
|
||||
APIKey: openAIEmbeddingApiKey,
|
||||
ByAzure: byAzure,
|
||||
BaseURL: openAIEmbeddingBaseURL,
|
||||
APIVersion: openAIEmbeddingApiVersion,
|
||||
Model: openAIEmbeddingModel,
|
||||
// Dimensions: ptr.Of(int(dims)),
|
||||
}
|
||||
reqDims := conv.StrToInt64D(openAIRequestEmbeddingDims, 0)
|
||||
if reqDims > 0 {
|
||||
// some openai model not support request dims
|
||||
openAICfg.Dimensions = ptr.Of(int(reqDims))
|
||||
}
|
||||
|
||||
emb, err = wrap.NewOpenAIEmbedder(ctx, openAICfg, dims, batchSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init openai embedding failed, err=%w", err)
|
||||
}
|
||||
|
||||
case "ark":
|
||||
var (
|
||||
arkEmbeddingBaseURL = os.Getenv("ARK_EMBEDDING_BASE_URL")
|
||||
arkEmbeddingModel = os.Getenv("ARK_EMBEDDING_MODEL")
|
||||
arkEmbeddingApiKey = os.Getenv("ARK_EMBEDDING_API_KEY")
|
||||
// deprecated: use ARK_EMBEDDING_API_KEY instead
|
||||
// ARK_EMBEDDING_AK will be removed in the future
|
||||
arkEmbeddingAK = os.Getenv("ARK_EMBEDDING_AK")
|
||||
arkEmbeddingDims = os.Getenv("ARK_EMBEDDING_DIMS")
|
||||
arkEmbeddingAPIType = os.Getenv("ARK_EMBEDDING_API_TYPE")
|
||||
)
|
||||
|
||||
dims, err := strconv.ParseInt(arkEmbeddingDims, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init ark embedding dims failed, err=%w", err)
|
||||
}
|
||||
|
||||
apiType := ark.APITypeText
|
||||
if arkEmbeddingAPIType != "" {
|
||||
if t := ark.APIType(arkEmbeddingAPIType); t != ark.APITypeText && t != ark.APITypeMultiModal {
|
||||
return nil, fmt.Errorf("init ark embedding api_type failed, invalid api_type=%s", t)
|
||||
} else {
|
||||
apiType = t
|
||||
}
|
||||
}
|
||||
|
||||
emb, err = arkemb.NewArkEmbedder(ctx, &ark.EmbeddingConfig{
|
||||
APIKey: func() string {
|
||||
if arkEmbeddingApiKey != "" {
|
||||
return arkEmbeddingApiKey
|
||||
}
|
||||
return arkEmbeddingAK
|
||||
}(),
|
||||
Model: arkEmbeddingModel,
|
||||
BaseURL: arkEmbeddingBaseURL,
|
||||
APIType: &apiType,
|
||||
}, dims, batchSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init ark embedding client failed, err=%w", err)
|
||||
}
|
||||
|
||||
case "ollama":
|
||||
var (
|
||||
ollamaEmbeddingBaseURL = os.Getenv("OLLAMA_EMBEDDING_BASE_URL")
|
||||
ollamaEmbeddingModel = os.Getenv("OLLAMA_EMBEDDING_MODEL")
|
||||
ollamaEmbeddingDims = os.Getenv("OLLAMA_EMBEDDING_DIMS")
|
||||
)
|
||||
|
||||
dims, err := strconv.ParseInt(ollamaEmbeddingDims, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init ollama embedding dims failed, err=%w", err)
|
||||
}
|
||||
|
||||
emb, err = wrap.NewOllamaEmbedder(ctx, &ollamaEmb.EmbeddingConfig{
|
||||
BaseURL: ollamaEmbeddingBaseURL,
|
||||
Model: ollamaEmbeddingModel,
|
||||
}, dims, batchSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init ollama embedding failed, err=%w", err)
|
||||
}
|
||||
|
||||
case "http":
|
||||
var (
|
||||
httpEmbeddingBaseURL = os.Getenv("HTTP_EMBEDDING_ADDR")
|
||||
httpEmbeddingDims = os.Getenv("HTTP_EMBEDDING_DIMS")
|
||||
)
|
||||
dims, err := strconv.ParseInt(httpEmbeddingDims, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init http embedding dims failed, err=%w", err)
|
||||
}
|
||||
emb, err = http.NewEmbedding(httpEmbeddingBaseURL, dims, batchSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init http embedding failed, err=%w", err)
|
||||
}
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("init knowledge embedding failed, type not configured")
|
||||
}
|
||||
|
||||
return emb, nil
|
||||
}
|
||||
|
||||
func readJinja2PromptTemplate(jsonFilePath string) (prompt.ChatTemplate, error) {
|
||||
b, err := os.ReadFile(jsonFilePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var m2qMessages []*schema.Message
|
||||
if err = json.Unmarshal(b, &m2qMessages); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tpl := make([]schema.MessagesTemplate, len(m2qMessages))
|
||||
for i := range m2qMessages {
|
||||
tpl[i] = m2qMessages[i]
|
||||
}
|
||||
return prompt.FromMessages(schema.Jinja2, tpl...), nil
|
||||
}
|
||||
|
||||
@ -28,7 +28,7 @@ import (
|
||||
resCommon "github.com/coze-dev/coze-studio/backend/api/model/resource/common"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
|
||||
"github.com/coze-dev/coze-studio/backend/application/search"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossuser"
|
||||
crossuser "github.com/coze-dev/coze-studio/backend/crossdomain/contract/user"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/memory/database/entity"
|
||||
databaseEntity "github.com/coze-dev/coze-studio/backend/domain/memory/database/entity"
|
||||
database "github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
|
||||
|
||||
@ -43,7 +43,7 @@ import (
|
||||
resCommon "github.com/coze-dev/coze-studio/backend/api/model/resource/common"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/pluginutil"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crosssearch"
|
||||
crosssearch "github.com/coze-dev/coze-studio/backend/crossdomain/contract/search"
|
||||
pluginConf "github.com/coze-dev/coze-studio/backend/domain/plugin/conf"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/encrypt"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
|
||||
|
||||
@ -199,11 +199,7 @@ func (p *PromptApplicationService) updatePromptResource(ctx context.Context, req
|
||||
return nil, errorx.New(errno.ErrPromptPermissionCode, errorx.KV("msg", "no permission"))
|
||||
}
|
||||
|
||||
promptResource.Name = req.Prompt.GetName()
|
||||
promptResource.Description = req.Prompt.GetDescription()
|
||||
promptResource.PromptText = req.Prompt.GetPromptText()
|
||||
|
||||
err = p.DomainSVC.UpdatePromptResource(ctx, promptResource)
|
||||
err = p.DomainSVC.UpdatePromptResource(ctx, promptID, req.Prompt.Name, req.Prompt.Description, req.Prompt.PromptText)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -354,8 +354,8 @@ func (s *SearchApplicationService) packProjectResource(ctx context.Context, reso
|
||||
logs.CtxErrorf(ctx, "GetDataInfo failed, resID=%d, resType=%d, err=%v",
|
||||
resource.ResID, resource.ResType, err)
|
||||
} else {
|
||||
info.BizResStatus = ptr.Of(*di.status)
|
||||
if *di.status == int32(knowledgeModel.KnowledgeStatusDisable) {
|
||||
info.BizResStatus = di.status
|
||||
if di.status != nil && *di.status == int32(knowledgeModel.KnowledgeStatusDisable) {
|
||||
actions := slices.Clone(info.Actions)
|
||||
for _, a := range actions {
|
||||
if a.Key == common.ProjectResourceActionKey_Disable {
|
||||
|
||||
@ -23,7 +23,7 @@ import (
|
||||
intelligence "github.com/coze-dev/coze-studio/backend/api/model/app/intelligence/common"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/data/variable/project_memory"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossplugin"
|
||||
crossplugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
|
||||
searchEntity "github.com/coze-dev/coze-studio/backend/domain/search/entity"
|
||||
shortcutCMDEntity "github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/entity"
|
||||
|
||||
@ -25,6 +25,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/app/bot_common"
|
||||
knowledgeModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/playground"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common"
|
||||
plugin_develop_common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common"
|
||||
@ -240,7 +241,7 @@ func (s *SingleAgentApplicationService) fetchWorkflowDetails(ctx context.Context
|
||||
return a.GetWorkflowId()
|
||||
}),
|
||||
},
|
||||
QType: vo.FromLatestVersion,
|
||||
QType: workflowModel.FromLatestVersion,
|
||||
}
|
||||
ret, _, err := s.appContext.WorkflowDomainSVC.MGet(ctx, policy)
|
||||
if err != nil {
|
||||
|
||||
@ -38,7 +38,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/data/database/table"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/playground"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossdatabase"
|
||||
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
|
||||
singleagent "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/service"
|
||||
variableEntity "github.com/coze-dev/coze-studio/backend/domain/memory/variables/entity"
|
||||
@ -524,8 +524,7 @@ func (s *SingleAgentApplicationService) GetAgentDraftDisplayInfo(ctx context.Con
|
||||
func (s *SingleAgentApplicationService) ValidateAgentDraftAccess(ctx context.Context, agentID int64) (*entity.SingleAgent, error) {
|
||||
uid := ctxutil.GetUIDFromCtx(ctx)
|
||||
if uid == nil {
|
||||
uid = ptr.Of(int64(888))
|
||||
// return nil, errorx.New(errno.ErrAgentPermissionCode, errorx.KV("msg", "session uid not found"))
|
||||
return nil, errorx.New(errno.ErrAgentPermissionCode, errorx.KV("msg", "session uid not found"))
|
||||
}
|
||||
|
||||
do, err := s.DomainSVC.GetSingleAgentDraft(ctx, agentID)
|
||||
@ -716,15 +715,18 @@ func (s *SingleAgentApplicationService) GetAgentOnlineInfo(ctx context.Context,
|
||||
AgentID: ptr.Of(si.ObjectID),
|
||||
Command: si.ShortcutCommand,
|
||||
Components: slices.Transform(si.Components, func(i *playground.Components) *bot_common.ShortcutCommandComponent {
|
||||
return &bot_common.ShortcutCommandComponent{
|
||||
sc := &bot_common.ShortcutCommandComponent{
|
||||
Name: i.Name,
|
||||
Description: i.Description,
|
||||
Type: i.InputType.String(),
|
||||
ToolParameter: ptr.Of(i.Parameter),
|
||||
Options: i.Options,
|
||||
DefaultValue: ptr.Of(i.DefaultValue.Value),
|
||||
IsHide: i.Hide,
|
||||
}
|
||||
if i.DefaultValue != nil {
|
||||
sc.DefaultValue = ptr.Of(i.DefaultValue.Value)
|
||||
}
|
||||
return sc
|
||||
}),
|
||||
}
|
||||
})
|
||||
|
||||
61
backend/application/workflow/eventbus.go
Normal file
61
backend/application/workflow/eventbus.go
Normal file
@ -0,0 +1,61 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/resource/common"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/search/entity"
|
||||
search "github.com/coze-dev/coze-studio/backend/domain/search/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/search/service"
|
||||
)
|
||||
|
||||
var eventBus service.ResourceEventBus
|
||||
|
||||
func setEventBus(bus service.ResourceEventBus) {
|
||||
eventBus = bus
|
||||
}
|
||||
|
||||
func PublishWorkflowResource(ctx context.Context, workflowID int64, mode *int32, op search.OpType, r *search.ResourceDocument) error {
|
||||
if r == nil {
|
||||
r = &search.ResourceDocument{}
|
||||
}
|
||||
|
||||
r.ResType = common.ResType_Workflow
|
||||
r.ResID = workflowID
|
||||
r.ResSubType = mode
|
||||
|
||||
event := &entity.ResourceDomainEvent{
|
||||
OpType: entity.OpType(op),
|
||||
Resource: r,
|
||||
}
|
||||
|
||||
if op == search.Created {
|
||||
event.Resource.CreateTimeMS = r.CreateTimeMS
|
||||
event.Resource.UpdateTimeMS = r.UpdateTimeMS
|
||||
} else if op == search.Updated {
|
||||
event.Resource.UpdateTimeMS = r.UpdateTimeMS
|
||||
}
|
||||
|
||||
err := eventBus.PublishResources(ctx, event)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -18,90 +18,88 @@ package workflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
"os"
|
||||
|
||||
"github.com/cloudwego/eino/callbacks"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/cloudwego/eino/callbacks"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/application/internal"
|
||||
|
||||
wfdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/database"
|
||||
wfknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/knowledge"
|
||||
wfmodel "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/model"
|
||||
wfplugin "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/plugin"
|
||||
wfsearch "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/search"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/workflow/variable"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/impl/code"
|
||||
knowledge "github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
|
||||
dbservice "github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
|
||||
variables "github.com/coze-dev/coze-studio/backend/domain/memory/variables/service"
|
||||
plugin "github.com/coze-dev/coze-studio/backend/domain/plugin/service"
|
||||
search "github.com/coze-dev/coze-studio/backend/domain/search/service"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
crosscode "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/code"
|
||||
crossdatabase "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
crossknowledge "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
|
||||
crossmodel "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
|
||||
crossplugin "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
|
||||
crosssearch "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/search"
|
||||
crossvariable "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/config"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/service"
|
||||
workflowservice "github.com/coze-dev/coze-studio/backend/domain/workflow/service"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/coderunner"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
type ServiceComponents struct {
|
||||
IDGen idgen.IDGenerator
|
||||
DB *gorm.DB
|
||||
Cache cache.Cmdable
|
||||
DatabaseDomainSVC dbservice.Database
|
||||
VariablesDomainSVC variables.Variables
|
||||
PluginDomainSVC plugin.PluginService
|
||||
KnowledgeDomainSVC knowledge.Knowledge
|
||||
ModelManager modelmgr.Manager
|
||||
DomainNotifier search.ResourceEventBus
|
||||
Tos storage.Storage
|
||||
ImageX imagex.ImageX
|
||||
CPStore compose.CheckPointStore
|
||||
CodeRunner coderunner.Runner
|
||||
IDGen idgen.IDGenerator
|
||||
DB *gorm.DB
|
||||
Cache cache.Cmdable
|
||||
DatabaseDomainSVC dbservice.Database
|
||||
VariablesDomainSVC variables.Variables
|
||||
PluginDomainSVC plugin.PluginService
|
||||
KnowledgeDomainSVC knowledge.Knowledge
|
||||
DomainNotifier search.ResourceEventBus
|
||||
Tos storage.Storage
|
||||
ImageX imagex.ImageX
|
||||
CPStore compose.CheckPointStore
|
||||
CodeRunner coderunner.Runner
|
||||
WorkflowBuildInChatModel chatmodel.BaseChatModel
|
||||
}
|
||||
|
||||
func InitService(ctx context.Context, components *ServiceComponents) (*ApplicationService, error) {
|
||||
bcm, ok, err := internal.GetBuiltinChatModel(ctx, "WKR_")
|
||||
func initWorkflowConfig() (workflow.WorkflowConfig, error) {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !ok {
|
||||
logs.CtxWarnf(ctx, "workflow builtin chat model for knowledge recall not configured")
|
||||
configBs, err := os.ReadFile(filepath.Join(wd, "resources/conf/workflow/config.yaml"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var cfg *config.WorkflowConfig
|
||||
err = yaml.Unmarshal(configBs, &cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func InitService(_ context.Context, components *ServiceComponents) (*ApplicationService, error) {
|
||||
service.RegisterAllNodeAdaptors()
|
||||
|
||||
cfg, err := initWorkflowConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
workflowRepo := service.NewWorkflowRepository(components.IDGen, components.DB, components.Cache,
|
||||
components.Tos, components.CPStore, bcm)
|
||||
components.Tos, components.CPStore, components.WorkflowBuildInChatModel, cfg)
|
||||
|
||||
workflow.SetRepository(workflowRepo)
|
||||
|
||||
workflowDomainSVC := service.NewWorkflowService(workflowRepo)
|
||||
crossdatabase.SetDatabaseOperator(wfdatabase.NewDatabaseRepository(components.DatabaseDomainSVC))
|
||||
crossvariable.SetVariableHandler(variable.NewVariableHandler(components.VariablesDomainSVC))
|
||||
crossvariable.SetVariablesMetaGetter(variable.NewVariablesMetaGetter(components.VariablesDomainSVC))
|
||||
crossplugin.SetPluginService(wfplugin.NewPluginService(components.PluginDomainSVC, components.Tos))
|
||||
crossknowledge.SetKnowledgeOperator(wfknowledge.NewKnowledgeRepository(components.KnowledgeDomainSVC, components.IDGen))
|
||||
crossmodel.SetManager(wfmodel.NewModelManager(components.ModelManager, nil))
|
||||
crosscode.SetCodeRunner(components.CodeRunner)
|
||||
crosssearch.SetNotifier(wfsearch.NewNotify(components.DomainNotifier))
|
||||
|
||||
code.SetCodeRunner(components.CodeRunner)
|
||||
callbacks.AppendGlobalHandlers(workflowservice.GetTokenCallbackHandler())
|
||||
setEventBus(components.DomainNotifier)
|
||||
|
||||
SVC.DomainSVC = workflowDomainSVC
|
||||
SVC.ImageX = components.ImageX
|
||||
SVC.TosClient = components.Tos
|
||||
SVC.IDGenerator = components.IDGen
|
||||
|
||||
return SVC, err
|
||||
return SVC, nil
|
||||
}
|
||||
|
||||
@ -25,28 +25,31 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
xmaps "golang.org/x/exp/maps"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
xmaps "golang.org/x/exp/maps"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/app/bot_common"
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
pluginmodel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/data/database/table"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/playground"
|
||||
pluginAPI "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop"
|
||||
common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common"
|
||||
resource "github.com/coze-dev/coze-studio/backend/api/model/resource/common"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
|
||||
appknowledge "github.com/coze-dev/coze-studio/backend/application/knowledge"
|
||||
appmemory "github.com/coze-dev/coze-studio/backend/application/memory"
|
||||
appplugin "github.com/coze-dev/coze-studio/backend/application/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/application/user"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossuser"
|
||||
crossknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/contract/knowledge"
|
||||
crossplugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin"
|
||||
crossuser "github.com/coze-dev/coze-studio/backend/crossdomain/contract/user"
|
||||
search "github.com/coze-dev/coze-studio/backend/domain/search/entity"
|
||||
domainWorkflow "github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
workflowDomain "github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
crossknowledge "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
|
||||
crossplugin "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
@ -54,6 +57,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/i18n"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/maps"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
@ -182,6 +186,18 @@ func (w *ApplicationService) CreateWorkflow(ctx context.Context, req *workflow.C
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = PublishWorkflowResource(ctx, id, ptr.Of(int32(wf.Mode)), search.Created, &search.ResourceDocument{
|
||||
Name: &wf.Name,
|
||||
APPID: wf.AppID,
|
||||
SpaceID: &wf.SpaceID,
|
||||
OwnerID: &wf.CreatorID,
|
||||
PublishStatus: ptr.Of(resource.PublishStatus_UnPublished),
|
||||
CreateTimeMS: ptr.Of(time.Now().UnixMilli()),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, vo.WrapError(errno.ErrNotifyWorkflowResourceChangeErr, err)
|
||||
}
|
||||
|
||||
return &workflow.CreateWorkflowResponse{
|
||||
Data: &workflow.CreateWorkflowData{
|
||||
WorkflowID: strconv.FormatInt(id, 10),
|
||||
@ -232,7 +248,8 @@ func (w *ApplicationService) UpdateWorkflowMeta(ctx context.Context, req *workfl
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = GetWorkflowDomainSVC().UpdateMeta(ctx, mustParseInt64(req.GetWorkflowID()), &vo.MetaUpdate{
|
||||
workflowID := mustParseInt64(req.GetWorkflowID())
|
||||
err = GetWorkflowDomainSVC().UpdateMeta(ctx, workflowID, &vo.MetaUpdate{
|
||||
Name: req.Name,
|
||||
Desc: req.Desc,
|
||||
IconURI: req.IconURI,
|
||||
@ -240,33 +257,31 @@ func (w *ApplicationService) UpdateWorkflowMeta(ctx context.Context, req *workfl
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
safego.Go(ctx, func() {
|
||||
err := PublishWorkflowResource(ctx, workflowID, nil, search.Updated, &search.ResourceDocument{
|
||||
Name: req.Name,
|
||||
UpdateTimeMS: ptr.Of(time.Now().UnixMilli()),
|
||||
})
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "publish update workflow resource failed, workflowID: %d, err: %v", workflowID, err)
|
||||
}
|
||||
})
|
||||
|
||||
return &workflow.UpdateWorkflowMetaResponse{}, nil
|
||||
}
|
||||
|
||||
func (w *ApplicationService) DeleteWorkflow(ctx context.Context, req *workflow.DeleteWorkflowRequest) (
|
||||
_ *workflow.DeleteWorkflowResponse, err error,
|
||||
) {
|
||||
defer func() {
|
||||
if panicErr := recover(); panicErr != nil {
|
||||
err = safego.NewPanicErr(panicErr, debug.Stack())
|
||||
}
|
||||
_, err = w.BatchDeleteWorkflow(ctx, &workflow.BatchDeleteWorkflowRequest{
|
||||
WorkflowIDList: []string{req.GetWorkflowID()},
|
||||
SpaceID: req.SpaceID,
|
||||
Action: req.Action,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
err = vo.WrapIfNeeded(errno.ErrWorkflowOperationFail, err, errorx.KV("cause", vo.UnwrapRootErr(err).Error()))
|
||||
}
|
||||
}()
|
||||
|
||||
if err := checkUserSpace(ctx, ctxutil.MustGetUIDFromCtx(ctx), mustParseInt64(req.GetSpaceID())); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = GetWorkflowDomainSVC().Delete(ctx, &vo.DeletePolicy{ID: ptr.Of(mustParseInt64(req.GetWorkflowID()))})
|
||||
if err != nil {
|
||||
return &workflow.DeleteWorkflowResponse{
|
||||
Data: &workflow.DeleteWorkflowData{
|
||||
Status: workflow.DeleteStatus_FAIL,
|
||||
},
|
||||
}, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &workflow.DeleteWorkflowResponse{
|
||||
@ -276,19 +291,25 @@ func (w *ApplicationService) DeleteWorkflow(ctx context.Context, req *workflow.D
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (w *ApplicationService) deleteWorkflowResource(ctx context.Context, policy *vo.DeletePolicy) error {
|
||||
ids, err := GetWorkflowDomainSVC().Delete(ctx, policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
safego.Go(ctx, func() {
|
||||
for _, id := range ids {
|
||||
if err = PublishWorkflowResource(ctx, id, nil, search.Deleted, &search.ResourceDocument{}); err != nil {
|
||||
logs.CtxErrorf(ctx, "publish delete workflow event resource failed, workflowID: %d, err: %v", id, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *ApplicationService) BatchDeleteWorkflow(ctx context.Context, req *workflow.BatchDeleteWorkflowRequest) (
|
||||
_ *workflow.BatchDeleteWorkflowResponse, err error,
|
||||
) {
|
||||
defer func() {
|
||||
if panicErr := recover(); panicErr != nil {
|
||||
err = safego.NewPanicErr(panicErr, debug.Stack())
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
err = vo.WrapIfNeeded(errno.ErrWorkflowOperationFail, err, errorx.KV("cause", vo.UnwrapRootErr(err).Error()))
|
||||
}
|
||||
}()
|
||||
|
||||
_ *workflow.BatchDeleteWorkflowResponse, err error) {
|
||||
if err := checkUserSpace(ctx, ctxutil.MustGetUIDFromCtx(ctx), mustParseInt64(req.GetSpaceID())); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -300,7 +321,7 @@ func (w *ApplicationService) BatchDeleteWorkflow(ctx context.Context, req *workf
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = GetWorkflowDomainSVC().Delete(ctx, &vo.DeletePolicy{
|
||||
err = w.deleteWorkflowResource(ctx, &vo.DeletePolicy{
|
||||
IDs: ids,
|
||||
})
|
||||
if err != nil {
|
||||
@ -335,7 +356,7 @@ func (w *ApplicationService) GetCanvasInfo(ctx context.Context, req *workflow.Ge
|
||||
|
||||
wf, err := GetWorkflowDomainSVC().Get(ctx, &vo.GetPolicy{
|
||||
ID: mustParseInt64(req.GetWorkflowID()),
|
||||
QType: vo.FromDraft,
|
||||
QType: workflowModel.FromDraft,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -431,19 +452,19 @@ func (w *ApplicationService) TestRun(ctx context.Context, req *workflow.WorkFlow
|
||||
agentID = ptr.Of(mustParseInt64(req.GetBotID()))
|
||||
}
|
||||
|
||||
exeCfg := vo.ExecuteConfig{
|
||||
exeCfg := workflowModel.ExecuteConfig{
|
||||
ID: mustParseInt64(req.GetWorkflowID()),
|
||||
From: vo.FromDraft,
|
||||
From: workflowModel.FromDraft,
|
||||
CommitID: req.GetCommitID(),
|
||||
Operator: uID,
|
||||
Mode: vo.ExecuteModeDebug,
|
||||
Mode: workflowModel.ExecuteModeDebug,
|
||||
AppID: appID,
|
||||
AgentID: agentID,
|
||||
ConnectorID: consts.CozeConnectorID,
|
||||
ConnectorUID: strconv.FormatInt(uID, 10),
|
||||
TaskType: vo.TaskTypeForeground,
|
||||
SyncPattern: vo.SyncPatternAsync,
|
||||
BizType: vo.BizTypeWorkflow,
|
||||
TaskType: workflowModel.TaskTypeForeground,
|
||||
SyncPattern: workflowModel.SyncPatternAsync,
|
||||
BizType: workflowModel.BizTypeWorkflow,
|
||||
Cancellable: true,
|
||||
}
|
||||
|
||||
@ -503,18 +524,18 @@ func (w *ApplicationService) NodeDebug(ctx context.Context, req *workflow.Workfl
|
||||
agentID = ptr.Of(mustParseInt64(req.GetBotID()))
|
||||
}
|
||||
|
||||
exeCfg := vo.ExecuteConfig{
|
||||
exeCfg := workflowModel.ExecuteConfig{
|
||||
ID: mustParseInt64(req.GetWorkflowID()),
|
||||
From: vo.FromDraft,
|
||||
From: workflowModel.FromDraft,
|
||||
Operator: uID,
|
||||
Mode: vo.ExecuteModeNodeDebug,
|
||||
Mode: workflowModel.ExecuteModeNodeDebug,
|
||||
AppID: appID,
|
||||
AgentID: agentID,
|
||||
ConnectorID: consts.CozeConnectorID,
|
||||
ConnectorUID: strconv.FormatInt(uID, 10),
|
||||
TaskType: vo.TaskTypeForeground,
|
||||
SyncPattern: vo.SyncPatternAsync,
|
||||
BizType: vo.BizTypeWorkflow,
|
||||
TaskType: workflowModel.TaskTypeForeground,
|
||||
SyncPattern: workflowModel.SyncPatternAsync,
|
||||
BizType: workflowModel.BizTypeWorkflow,
|
||||
Cancellable: true,
|
||||
}
|
||||
|
||||
@ -832,17 +853,7 @@ func (w *ApplicationService) GetNodeExecuteHistory(ctx context.Context, req *wor
|
||||
}
|
||||
|
||||
func (w *ApplicationService) DeleteWorkflowsByAppID(ctx context.Context, appID int64) (err error) {
|
||||
defer func() {
|
||||
if panicErr := recover(); panicErr != nil {
|
||||
err = safego.NewPanicErr(panicErr, debug.Stack())
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
err = vo.WrapIfNeeded(errno.ErrWorkflowOperationFail, err, errorx.KV("cause", vo.UnwrapRootErr(err).Error()))
|
||||
}
|
||||
}()
|
||||
|
||||
return GetWorkflowDomainSVC().Delete(ctx, &vo.DeletePolicy{
|
||||
return w.deleteWorkflowResource(ctx, &vo.DeletePolicy{
|
||||
AppID: ptr.Of(appID),
|
||||
})
|
||||
}
|
||||
@ -866,7 +877,7 @@ func (w *ApplicationService) CheckWorkflowsExistByAppID(ctx context.Context, app
|
||||
Page: 0,
|
||||
},
|
||||
},
|
||||
QType: vo.FromDraft,
|
||||
QType: workflowModel.FromDraft,
|
||||
MetaOnly: true,
|
||||
})
|
||||
|
||||
@ -874,8 +885,7 @@ func (w *ApplicationService) CheckWorkflowsExistByAppID(ctx context.Context, app
|
||||
}
|
||||
|
||||
func (w *ApplicationService) CopyWorkflowFromAppToLibrary(ctx context.Context, workflowID int64, spaceID, appID int64) (
|
||||
_ int64, _ []*vo.ValidateIssue, err error,
|
||||
) {
|
||||
_ int64, _ []*vo.ValidateIssue, err error) {
|
||||
defer func() {
|
||||
if panicErr := recover(); panicErr != nil {
|
||||
err = safego.NewPanicErr(panicErr, debug.Stack())
|
||||
@ -891,7 +901,7 @@ func (w *ApplicationService) CopyWorkflowFromAppToLibrary(ctx context.Context, w
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
pluginMap := make(map[int64]*vo.PluginEntity)
|
||||
pluginMap := make(map[int64]*plugin.PluginEntity)
|
||||
pluginToolMap := make(map[int64]int64)
|
||||
|
||||
if len(ds.PluginIDs) > 0 {
|
||||
@ -906,7 +916,7 @@ func (w *ApplicationService) CopyWorkflowFromAppToLibrary(ctx context.Context, w
|
||||
return 0, nil, err
|
||||
}
|
||||
pInfo := response.Plugin
|
||||
pluginMap[id] = &vo.PluginEntity{
|
||||
pluginMap[id] = &plugin.PluginEntity{
|
||||
PluginID: pInfo.ID,
|
||||
PluginVersion: pInfo.Version,
|
||||
}
|
||||
@ -958,7 +968,7 @@ func (w *ApplicationService) CopyWorkflowFromAppToLibrary(ctx context.Context, w
|
||||
|
||||
}
|
||||
|
||||
relatedWorkflows, vIssues, err := GetWorkflowDomainSVC().CopyWorkflowFromAppToLibrary(ctx, workflowID, appID, vo.ExternalResourceRelated{
|
||||
relatedWorkflows, vIssues, err := w.copyWorkflowFromAppToLibrary(ctx, workflowID, appID, vo.ExternalResourceRelated{
|
||||
PluginMap: pluginMap,
|
||||
PluginToolMap: pluginToolMap,
|
||||
KnowledgeMap: relatedKnowledgeMap,
|
||||
@ -980,6 +990,32 @@ func (w *ApplicationService) CopyWorkflowFromAppToLibrary(ctx context.Context, w
|
||||
return copiedWf.ID, vIssues, nil
|
||||
}
|
||||
|
||||
func (w *ApplicationService) copyWorkflowFromAppToLibrary(ctx context.Context, workflowID int64, appID int64, related vo.ExternalResourceRelated) (map[int64]entity.IDVersionPair, []*vo.ValidateIssue, error) {
|
||||
resp, err := GetWorkflowDomainSVC().CopyWorkflowFromAppToLibrary(ctx, workflowID, appID, related)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
for index := range resp.CopiedWorkflows {
|
||||
wf := resp.CopiedWorkflows[index]
|
||||
|
||||
err = PublishWorkflowResource(ctx, wf.ID, ptr.Of(int32(wf.Meta.Mode)), search.Created, &search.ResourceDocument{
|
||||
Name: &wf.Name,
|
||||
SpaceID: &wf.SpaceID,
|
||||
OwnerID: &wf.CreatorID,
|
||||
|
||||
PublishStatus: ptr.Of(resource.PublishStatus_UnPublished),
|
||||
CreateTimeMS: ptr.Of(time.Now().UnixMilli()),
|
||||
})
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "failed to publish workflow resource, workflow id=%d, err=%v", wf.ID, err)
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return resp.WorkflowIDVersionMap, resp.ValidateIssues, nil
|
||||
}
|
||||
|
||||
type ExternalResource struct {
|
||||
PluginMap map[int64]int64
|
||||
PluginToolMap map[int64]int64
|
||||
@ -998,9 +1034,9 @@ func (w *ApplicationService) DuplicateWorkflowsByAppID(ctx context.Context, sour
|
||||
}
|
||||
}()
|
||||
|
||||
pluginMap := make(map[int64]*vo.PluginEntity)
|
||||
pluginMap := make(map[int64]*plugin.PluginEntity)
|
||||
for o, n := range externalResource.PluginMap {
|
||||
pluginMap[o] = &vo.PluginEntity{
|
||||
pluginMap[o] = &plugin.PluginEntity{
|
||||
PluginID: n,
|
||||
}
|
||||
}
|
||||
@ -1011,7 +1047,29 @@ func (w *ApplicationService) DuplicateWorkflowsByAppID(ctx context.Context, sour
|
||||
DatabaseMap: externalResource.DatabaseMap,
|
||||
}
|
||||
|
||||
return GetWorkflowDomainSVC().DuplicateWorkflowsByAppID(ctx, sourceAppID, targetAppID, externalResourceRelated)
|
||||
copiedWorkflowArray, err := GetWorkflowDomainSVC().DuplicateWorkflowsByAppID(ctx, sourceAppID, targetAppID, externalResourceRelated)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logs.CtxInfof(ctx, "[DuplicateWorkflowsByAppID] %s", conv.DebugJsonToStr(copiedWorkflowArray))
|
||||
|
||||
for index := range copiedWorkflowArray {
|
||||
wf := copiedWorkflowArray[index]
|
||||
err = PublishWorkflowResource(ctx, wf.ID, ptr.Of(int32(wf.Meta.Mode)), search.Created, &search.ResourceDocument{
|
||||
Name: &wf.Name,
|
||||
SpaceID: &wf.SpaceID,
|
||||
OwnerID: &wf.CreatorID,
|
||||
APPID: &targetAppID,
|
||||
PublishStatus: ptr.Of(resource.PublishStatus_UnPublished),
|
||||
CreateTimeMS: ptr.Of(time.Now().UnixMilli()),
|
||||
})
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "failed to publish workflow resource, workflow id=%d, err=%v", wf.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *ApplicationService) CopyWorkflowFromLibraryToApp(ctx context.Context, workflowID int64, appID int64) (
|
||||
@ -1027,7 +1085,7 @@ func (w *ApplicationService) CopyWorkflowFromLibraryToApp(ctx context.Context, w
|
||||
}
|
||||
}()
|
||||
|
||||
wf, err := GetWorkflowDomainSVC().CopyWorkflow(ctx, workflowID, vo.CopyWorkflowPolicy{
|
||||
wf, err := w.copyWorkflow(ctx, workflowID, vo.CopyWorkflowPolicy{
|
||||
TargetAppID: &appID,
|
||||
})
|
||||
if err != nil {
|
||||
@ -1037,6 +1095,28 @@ func (w *ApplicationService) CopyWorkflowFromLibraryToApp(ctx context.Context, w
|
||||
return wf.ID, nil
|
||||
}
|
||||
|
||||
func (w *ApplicationService) copyWorkflow(ctx context.Context, workflowID int64, policy vo.CopyWorkflowPolicy) (*entity.Workflow, error) {
|
||||
wf, err := GetWorkflowDomainSVC().CopyWorkflow(ctx, workflowID, policy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = PublishWorkflowResource(ctx, wf.ID, ptr.Of(int32(wf.Meta.Mode)), search.Created, &search.ResourceDocument{
|
||||
Name: &wf.Name,
|
||||
APPID: wf.AppID,
|
||||
SpaceID: &wf.SpaceID,
|
||||
OwnerID: &wf.CreatorID,
|
||||
PublishStatus: ptr.Of(resource.PublishStatus_UnPublished),
|
||||
CreateTimeMS: ptr.Of(time.Now().UnixMilli()),
|
||||
})
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "public copy workflow event failed, workflowID=%d, err=%v", wf.ID, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return wf, nil
|
||||
}
|
||||
|
||||
func (w *ApplicationService) MoveWorkflowFromAppToLibrary(ctx context.Context, workflowID int64, spaceID, /*not used for now*/
|
||||
appID int64) (_ int64, _ []*vo.ValidateIssue, err error) {
|
||||
defer func() {
|
||||
@ -1054,7 +1134,7 @@ func (w *ApplicationService) MoveWorkflowFromAppToLibrary(ctx context.Context, w
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
pluginMap := make(map[int64]*vo.PluginEntity)
|
||||
pluginMap := make(map[int64]*plugin.PluginEntity)
|
||||
if len(ds.PluginIDs) > 0 {
|
||||
for idx := range ds.PluginIDs {
|
||||
id := ds.PluginIDs[idx]
|
||||
@ -1062,7 +1142,7 @@ func (w *ApplicationService) MoveWorkflowFromAppToLibrary(ctx context.Context, w
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
pluginMap[id] = &vo.PluginEntity{
|
||||
pluginMap[id] = &plugin.PluginEntity{
|
||||
PluginID: pInfo.ID,
|
||||
PluginVersion: pInfo.Version,
|
||||
}
|
||||
@ -1091,7 +1171,7 @@ func (w *ApplicationService) MoveWorkflowFromAppToLibrary(ctx context.Context, w
|
||||
}
|
||||
}
|
||||
|
||||
relatedWorkflows, vIssues, err := GetWorkflowDomainSVC().CopyWorkflowFromAppToLibrary(ctx, workflowID, appID, vo.ExternalResourceRelated{
|
||||
relatedWorkflows, vIssues, err := w.copyWorkflowFromAppToLibrary(ctx, workflowID, appID, vo.ExternalResourceRelated{
|
||||
PluginMap: pluginMap,
|
||||
})
|
||||
if err != nil {
|
||||
@ -1109,7 +1189,7 @@ func (w *ApplicationService) MoveWorkflowFromAppToLibrary(ctx context.Context, w
|
||||
}
|
||||
|
||||
deleteWorkflowIDs := xmaps.Keys(relatedWorkflows)
|
||||
err = GetWorkflowDomainSVC().Delete(ctx, &vo.DeletePolicy{
|
||||
err = w.deleteWorkflowResource(ctx, &vo.DeletePolicy{
|
||||
IDs: deleteWorkflowIDs,
|
||||
})
|
||||
if err != nil {
|
||||
@ -1250,7 +1330,7 @@ func convertStreamRunEvent(workflowID int64) func(msg *entity.Message) (res *wor
|
||||
return &workflow.OpenAPIStreamRunFlowResponse{
|
||||
ID: strconv.Itoa(messageID),
|
||||
Event: string(DoneEvent),
|
||||
DebugUrl: ptr.Of(fmt.Sprintf(vo.DebugURLTpl, executeID, spaceID, workflowID)),
|
||||
DebugUrl: ptr.Of(fmt.Sprintf(workflowModel.DebugURLTpl, executeID, spaceID, workflowID)),
|
||||
}, nil
|
||||
case entity.WorkflowFailed, entity.WorkflowCancel:
|
||||
var wfe vo.WorkflowError
|
||||
@ -1260,7 +1340,7 @@ func convertStreamRunEvent(workflowID int64) func(msg *entity.Message) (res *wor
|
||||
return &workflow.OpenAPIStreamRunFlowResponse{
|
||||
ID: strconv.Itoa(messageID),
|
||||
Event: string(ErrEvent),
|
||||
DebugUrl: ptr.Of(fmt.Sprintf(vo.DebugURLTpl, executeID, spaceID, workflowID)),
|
||||
DebugUrl: ptr.Of(fmt.Sprintf(workflowModel.DebugURLTpl, executeID, spaceID, workflowID)),
|
||||
ErrorCode: ptr.Of(int64(wfe.Code())),
|
||||
ErrorMessage: ptr.Of(wfe.Msg()),
|
||||
}, nil
|
||||
@ -1269,7 +1349,7 @@ func convertStreamRunEvent(workflowID int64) func(msg *entity.Message) (res *wor
|
||||
return &workflow.OpenAPIStreamRunFlowResponse{
|
||||
ID: strconv.Itoa(messageID),
|
||||
Event: string(InterruptEvent),
|
||||
DebugUrl: ptr.Of(fmt.Sprintf(vo.DebugURLTpl, executeID, spaceID, workflowID)),
|
||||
DebugUrl: ptr.Of(fmt.Sprintf(workflowModel.DebugURLTpl, executeID, spaceID, workflowID)),
|
||||
InterruptData: &workflow.Interrupt{
|
||||
EventID: fmt.Sprintf("%d/%d", executeID, msg.InterruptEvent.ID),
|
||||
Type: workflow.InterruptType(msg.InterruptEvent.EventType),
|
||||
@ -1281,7 +1361,7 @@ func convertStreamRunEvent(workflowID int64) func(msg *entity.Message) (res *wor
|
||||
return &workflow.OpenAPIStreamRunFlowResponse{
|
||||
ID: strconv.Itoa(messageID),
|
||||
Event: string(InterruptEvent),
|
||||
DebugUrl: ptr.Of(fmt.Sprintf(vo.DebugURLTpl, executeID, spaceID, workflowID)),
|
||||
DebugUrl: ptr.Of(fmt.Sprintf(workflowModel.DebugURLTpl, executeID, spaceID, workflowID)),
|
||||
InterruptData: &workflow.Interrupt{
|
||||
EventID: fmt.Sprintf("%d/%d", executeID, msg.InterruptEvent.ID),
|
||||
Type: workflow.InterruptType(msg.InterruptEvent.ToolInterruptEvent.EventType),
|
||||
@ -1397,20 +1477,20 @@ func (w *ApplicationService) OpenAPIStreamRun(ctx context.Context, req *workflow
|
||||
connectorID = apiKeyInfo.ConnectorID
|
||||
}
|
||||
|
||||
exeCfg := vo.ExecuteConfig{
|
||||
exeCfg := workflowModel.ExecuteConfig{
|
||||
ID: meta.ID,
|
||||
From: vo.FromSpecificVersion,
|
||||
From: workflowModel.FromSpecificVersion,
|
||||
Version: *meta.LatestPublishedVersion,
|
||||
Operator: userID,
|
||||
Mode: vo.ExecuteModeRelease,
|
||||
Mode: workflowModel.ExecuteModeRelease,
|
||||
AppID: appID,
|
||||
AgentID: agentID,
|
||||
ConnectorID: connectorID,
|
||||
ConnectorUID: strconv.FormatInt(userID, 10),
|
||||
TaskType: vo.TaskTypeForeground,
|
||||
SyncPattern: vo.SyncPatternStream,
|
||||
TaskType: workflowModel.TaskTypeForeground,
|
||||
SyncPattern: workflowModel.SyncPatternStream,
|
||||
InputFailFast: true,
|
||||
BizType: vo.BizTypeWorkflow,
|
||||
BizType: workflowModel.BizTypeWorkflow,
|
||||
}
|
||||
|
||||
if exeCfg.AppID != nil && exeCfg.AgentID != nil {
|
||||
@ -1471,12 +1551,12 @@ func (w *ApplicationService) OpenAPIStreamResume(ctx context.Context, req *workf
|
||||
connectorID = mustParseInt64(req.GetConnectorID())
|
||||
}
|
||||
|
||||
sr, err := GetWorkflowDomainSVC().StreamResume(ctx, resumeReq, vo.ExecuteConfig{
|
||||
sr, err := GetWorkflowDomainSVC().StreamResume(ctx, resumeReq, workflowModel.ExecuteConfig{
|
||||
Operator: userID,
|
||||
Mode: vo.ExecuteModeRelease,
|
||||
Mode: workflowModel.ExecuteModeRelease,
|
||||
ConnectorID: connectorID,
|
||||
ConnectorUID: strconv.FormatInt(userID, 10),
|
||||
BizType: vo.BizTypeWorkflow,
|
||||
BizType: workflowModel.BizTypeWorkflow,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -1546,18 +1626,18 @@ func (w *ApplicationService) OpenAPIRun(ctx context.Context, req *workflow.OpenA
|
||||
connectorID = apiKeyInfo.ConnectorID
|
||||
}
|
||||
|
||||
exeCfg := vo.ExecuteConfig{
|
||||
exeCfg := workflowModel.ExecuteConfig{
|
||||
ID: meta.ID,
|
||||
From: vo.FromSpecificVersion,
|
||||
From: workflowModel.FromSpecificVersion,
|
||||
Version: *meta.LatestPublishedVersion,
|
||||
Operator: userID,
|
||||
Mode: vo.ExecuteModeRelease,
|
||||
Mode: workflowModel.ExecuteModeRelease,
|
||||
AppID: appID,
|
||||
AgentID: agentID,
|
||||
ConnectorID: connectorID,
|
||||
ConnectorUID: strconv.FormatInt(userID, 10),
|
||||
InputFailFast: true,
|
||||
BizType: vo.BizTypeWorkflow,
|
||||
BizType: workflowModel.BizTypeWorkflow,
|
||||
}
|
||||
|
||||
if exeCfg.AppID != nil && exeCfg.AgentID != nil {
|
||||
@ -1565,8 +1645,8 @@ func (w *ApplicationService) OpenAPIRun(ctx context.Context, req *workflow.OpenA
|
||||
}
|
||||
|
||||
if req.GetIsAsync() {
|
||||
exeCfg.SyncPattern = vo.SyncPatternAsync
|
||||
exeCfg.TaskType = vo.TaskTypeBackground
|
||||
exeCfg.SyncPattern = workflowModel.SyncPatternAsync
|
||||
exeCfg.TaskType = workflowModel.TaskTypeBackground
|
||||
exeID, err := GetWorkflowDomainSVC().AsyncExecute(ctx, exeCfg, parameters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -1574,12 +1654,12 @@ func (w *ApplicationService) OpenAPIRun(ctx context.Context, req *workflow.OpenA
|
||||
|
||||
return &workflow.OpenAPIRunFlowResponse{
|
||||
ExecuteID: ptr.Of(strconv.FormatInt(exeID, 10)),
|
||||
DebugUrl: ptr.Of(fmt.Sprintf(vo.DebugURLTpl, exeID, meta.SpaceID, meta.ID)),
|
||||
DebugUrl: ptr.Of(fmt.Sprintf(workflowModel.DebugURLTpl, exeID, meta.SpaceID, meta.ID)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
exeCfg.SyncPattern = vo.SyncPatternSync
|
||||
exeCfg.TaskType = vo.TaskTypeForeground
|
||||
exeCfg.SyncPattern = workflowModel.SyncPatternSync
|
||||
exeCfg.TaskType = workflowModel.TaskTypeForeground
|
||||
wfExe, tPlan, err := GetWorkflowDomainSVC().SyncExecute(ctx, exeCfg, parameters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -1610,7 +1690,7 @@ func (w *ApplicationService) OpenAPIRun(ctx context.Context, req *workflow.OpenA
|
||||
return &workflow.OpenAPIRunFlowResponse{
|
||||
Data: data,
|
||||
ExecuteID: ptr.Of(strconv.FormatInt(wfExe.ID, 10)),
|
||||
DebugUrl: ptr.Of(fmt.Sprintf(vo.DebugURLTpl, wfExe.ID, wfExe.SpaceID, meta.ID)),
|
||||
DebugUrl: ptr.Of(fmt.Sprintf(workflowModel.DebugURLTpl, wfExe.ID, wfExe.SpaceID, meta.ID)),
|
||||
Token: ptr.Of(wfExe.TokenInfo.InputTokens + wfExe.TokenInfo.OutputTokens),
|
||||
Cost: ptr.Of("0.00000"),
|
||||
}, nil
|
||||
@ -1650,11 +1730,11 @@ func (w *ApplicationService) OpenAPIGetWorkflowRunHistory(ctx context.Context, r
|
||||
|
||||
var runMode *workflow.WorkflowRunMode
|
||||
switch exe.SyncPattern {
|
||||
case vo.SyncPatternSync:
|
||||
case workflowModel.SyncPatternSync:
|
||||
runMode = ptr.Of(workflow.WorkflowRunMode_Sync)
|
||||
case vo.SyncPatternAsync:
|
||||
case workflowModel.SyncPatternAsync:
|
||||
runMode = ptr.Of(workflow.WorkflowRunMode_Async)
|
||||
case vo.SyncPatternStream:
|
||||
case workflowModel.SyncPatternStream:
|
||||
runMode = ptr.Of(workflow.WorkflowRunMode_Stream)
|
||||
default:
|
||||
}
|
||||
@ -1671,7 +1751,7 @@ func (w *ApplicationService) OpenAPIGetWorkflowRunHistory(ctx context.Context, r
|
||||
LogID: ptr.Of(exe.LogID),
|
||||
CreateTime: ptr.Of(exe.CreatedAt.Unix()),
|
||||
UpdateTime: updateTime,
|
||||
DebugUrl: ptr.Of(fmt.Sprintf(vo.DebugURLTpl, exe.ID, exe.SpaceID, exe.WorkflowID)),
|
||||
DebugUrl: ptr.Of(fmt.Sprintf(workflowModel.DebugURLTpl, exe.ID, exe.SpaceID, exe.WorkflowID)),
|
||||
Input: exe.Input,
|
||||
Output: exe.Output,
|
||||
Token: ptr.Of(exe.TokenInfo.InputTokens + exe.TokenInfo.OutputTokens),
|
||||
@ -1805,10 +1885,10 @@ func (w *ApplicationService) TestResume(ctx context.Context, req *workflow.Workf
|
||||
EventID: mustParseInt64(req.GetEventID()),
|
||||
ResumeData: req.GetData(),
|
||||
}
|
||||
err = GetWorkflowDomainSVC().AsyncResume(ctx, resumeReq, vo.ExecuteConfig{
|
||||
err = GetWorkflowDomainSVC().AsyncResume(ctx, resumeReq, workflowModel.ExecuteConfig{
|
||||
Operator: ptr.FromOrDefault(ctxutil.GetUIDFromCtx(ctx), 0),
|
||||
Mode: vo.ExecuteModeDebug, // at this stage it could be debug or node debug, we will decide it within AsyncResume
|
||||
BizType: vo.BizTypeWorkflow,
|
||||
Mode: workflowModel.ExecuteModeDebug, // at this stage it could be debug or node debug, we will decide it within AsyncResume
|
||||
BizType: workflowModel.BizTypeWorkflow,
|
||||
Cancellable: true,
|
||||
})
|
||||
if err != nil {
|
||||
@ -1948,7 +2028,7 @@ func (w *ApplicationService) PublishWorkflow(ctx context.Context, req *workflow.
|
||||
Force: req.GetForce(),
|
||||
}
|
||||
|
||||
err = GetWorkflowDomainSVC().Publish(ctx, info)
|
||||
err = w.publishWorkflowResource(ctx, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -2003,13 +2083,13 @@ func (w *ApplicationService) ListWorkflow(ctx context.Context, req *workflow.Get
|
||||
}
|
||||
|
||||
status := req.GetStatus()
|
||||
var qType vo.Locator
|
||||
var qType workflowModel.Locator
|
||||
if status == workflow.WorkFlowListStatus_UnPublished {
|
||||
option.PublishStatus = ptr.Of(vo.UnPublished)
|
||||
qType = vo.FromDraft
|
||||
qType = workflowModel.FromDraft
|
||||
} else if status == workflow.WorkFlowListStatus_HadPublished {
|
||||
option.PublishStatus = ptr.Of(vo.HasPublished)
|
||||
qType = vo.FromLatestVersion
|
||||
qType = workflowModel.FromLatestVersion
|
||||
}
|
||||
|
||||
if len(req.GetName()) > 0 {
|
||||
@ -2077,9 +2157,9 @@ func (w *ApplicationService) ListWorkflow(ctx context.Context, req *workflow.Get
|
||||
},
|
||||
}
|
||||
|
||||
if qType == vo.FromDraft {
|
||||
if qType == workflowModel.FromDraft {
|
||||
ww.UpdateTime = w.DraftMeta.Timestamp.Unix()
|
||||
} else if qType == vo.FromLatestVersion || qType == vo.FromSpecificVersion {
|
||||
} else if qType == workflowModel.FromLatestVersion || qType == workflowModel.FromSpecificVersion {
|
||||
ww.UpdateTime = w.VersionMeta.VersionCreatedAt.Unix()
|
||||
} else if w.UpdatedAt != nil {
|
||||
ww.UpdateTime = w.UpdatedAt.Unix()
|
||||
@ -2164,7 +2244,7 @@ func (w *ApplicationService) GetWorkflowDetail(ctx context.Context, req *workflo
|
||||
MetaQuery: vo.MetaQuery{
|
||||
IDs: ids,
|
||||
},
|
||||
QType: vo.FromDraft,
|
||||
QType: workflowModel.FromDraft,
|
||||
MetaOnly: false,
|
||||
})
|
||||
if err != nil {
|
||||
@ -2270,7 +2350,7 @@ func (w *ApplicationService) GetWorkflowDetailInfo(ctx context.Context, req *wor
|
||||
MetaQuery: vo.MetaQuery{
|
||||
IDs: draftIDs,
|
||||
},
|
||||
QType: vo.FromDraft,
|
||||
QType: workflowModel.FromDraft,
|
||||
MetaOnly: false,
|
||||
})
|
||||
if err != nil {
|
||||
@ -2283,7 +2363,7 @@ func (w *ApplicationService) GetWorkflowDetailInfo(ctx context.Context, req *wor
|
||||
MetaQuery: vo.MetaQuery{
|
||||
IDs: versionIDs,
|
||||
},
|
||||
QType: vo.FromSpecificVersion,
|
||||
QType: workflowModel.FromSpecificVersion,
|
||||
MetaOnly: false,
|
||||
Versions: id2Version,
|
||||
})
|
||||
@ -2484,8 +2564,8 @@ func (w *ApplicationService) GetApiDetail(ctx context.Context, req *workflow.Get
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toolInfoResponse, err := crossplugin.GetPluginService().GetPluginToolsInfo(ctx, &crossplugin.ToolsInfoRequest{
|
||||
PluginEntity: crossplugin.Entity{
|
||||
toolInfoResponse, err := crossplugin.DefaultSVC().GetPluginToolsInfo(ctx, &plugin.ToolsInfoRequest{
|
||||
PluginEntity: plugin.PluginEntity{
|
||||
PluginID: pluginID,
|
||||
PluginVersion: req.PluginVersion,
|
||||
},
|
||||
@ -2551,8 +2631,8 @@ func (w *ApplicationService) GetLLMNodeFCSettingDetail(ctx context.Context, req
|
||||
}
|
||||
|
||||
var (
|
||||
pluginSvc = crossplugin.GetPluginService()
|
||||
pluginToolsInfoReqs = make(map[int64]*crossplugin.ToolsInfoRequest)
|
||||
pluginSvc = crossplugin.DefaultSVC()
|
||||
pluginToolsInfoReqs = make(map[int64]*plugin.ToolsInfoRequest)
|
||||
pluginDetailMap = make(map[string]*workflow.PluginDetail)
|
||||
toolsDetailInfo = make(map[string]*workflow.APIDetail)
|
||||
workflowDetailMap = make(map[string]*workflow.WorkflowDetail)
|
||||
@ -2574,8 +2654,8 @@ func (w *ApplicationService) GetLLMNodeFCSettingDetail(ctx context.Context, req
|
||||
if r, ok := pluginToolsInfoReqs[pluginID]; ok {
|
||||
r.ToolIDs = append(r.ToolIDs, toolID)
|
||||
} else {
|
||||
pluginToolsInfoReqs[pluginID] = &crossplugin.ToolsInfoRequest{
|
||||
PluginEntity: crossplugin.Entity{
|
||||
pluginToolsInfoReqs[pluginID] = &plugin.ToolsInfoRequest{
|
||||
PluginEntity: plugin.PluginEntity{
|
||||
PluginID: pluginID,
|
||||
PluginVersion: pl.PluginVersion,
|
||||
},
|
||||
@ -2656,7 +2736,7 @@ func (w *ApplicationService) GetLLMNodeFCSettingDetail(ctx context.Context, req
|
||||
MetaQuery: vo.MetaQuery{
|
||||
IDs: draftIDs,
|
||||
},
|
||||
QType: vo.FromDraft,
|
||||
QType: workflowModel.FromDraft,
|
||||
MetaOnly: false,
|
||||
})
|
||||
if err != nil {
|
||||
@ -2669,7 +2749,7 @@ func (w *ApplicationService) GetLLMNodeFCSettingDetail(ctx context.Context, req
|
||||
MetaQuery: vo.MetaQuery{
|
||||
IDs: versionIDs,
|
||||
},
|
||||
QType: vo.FromSpecificVersion,
|
||||
QType: workflowModel.FromSpecificVersion,
|
||||
MetaOnly: false,
|
||||
Versions: id2Version,
|
||||
})
|
||||
@ -2705,18 +2785,18 @@ func (w *ApplicationService) GetLLMNodeFCSettingDetail(ctx context.Context, req
|
||||
}
|
||||
|
||||
if len(req.GetDatasetList()) > 0 {
|
||||
knowledgeOperator := crossknowledge.GetKnowledgeOperator()
|
||||
knowledgeOperator := crossknowledge.DefaultSVC()
|
||||
knowledgeIDs, err := slices.TransformWithErrorCheck(req.GetDatasetList(), func(a *workflow.DatasetFCItem) (int64, error) {
|
||||
return strconv.ParseInt(a.GetDatasetID(), 10, 64)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
details, err := knowledgeOperator.ListKnowledgeDetail(ctx, &crossknowledge.ListKnowledgeDetailRequest{KnowledgeIDs: knowledgeIDs})
|
||||
details, err := knowledgeOperator.ListKnowledgeDetail(ctx, &model.ListKnowledgeDetailRequest{KnowledgeIDs: knowledgeIDs})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
knowledgeDetailMap = slices.ToMap(details.KnowledgeDetails, func(kd *crossknowledge.KnowledgeDetail) (string, *workflow.DatasetDetail) {
|
||||
knowledgeDetailMap = slices.ToMap(details.KnowledgeDetails, func(kd *model.KnowledgeDetail) (string, *workflow.DatasetDetail) {
|
||||
return strconv.FormatInt(kd.ID, 10), &workflow.DatasetDetail{
|
||||
ID: strconv.FormatInt(kd.ID, 10),
|
||||
Name: kd.Name,
|
||||
@ -2759,7 +2839,7 @@ func (w *ApplicationService) GetLLMNodeFCSettingsMerged(ctx context.Context, req
|
||||
var fcPluginSetting *workflow.FCPluginSetting
|
||||
if req.GetPluginFcSetting() != nil {
|
||||
var (
|
||||
pluginSvc = crossplugin.GetPluginService()
|
||||
pluginSvc = crossplugin.DefaultSVC()
|
||||
pluginFcSetting = req.GetPluginFcSetting()
|
||||
isDraft = pluginFcSetting.GetIsDraft()
|
||||
)
|
||||
@ -2774,8 +2854,8 @@ func (w *ApplicationService) GetLLMNodeFCSettingsMerged(ctx context.Context, req
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pluginReq := &crossplugin.ToolsInfoRequest{
|
||||
PluginEntity: vo.PluginEntity{
|
||||
pluginReq := &plugin.ToolsInfoRequest{
|
||||
PluginEntity: plugin.PluginEntity{
|
||||
PluginID: pluginID,
|
||||
},
|
||||
ToolIDs: []int64{toolID},
|
||||
@ -2816,7 +2896,7 @@ func (w *ApplicationService) GetLLMNodeFCSettingsMerged(ctx context.Context, req
|
||||
|
||||
policy := &vo.GetPolicy{
|
||||
ID: wID,
|
||||
QType: ternary.IFElse(len(setting.WorkflowVersion) == 0, vo.FromDraft, vo.FromSpecificVersion),
|
||||
QType: ternary.IFElse(len(setting.WorkflowVersion) == 0, workflowModel.FromDraft, workflowModel.FromSpecificVersion),
|
||||
Version: setting.WorkflowVersion,
|
||||
}
|
||||
|
||||
@ -2891,7 +2971,7 @@ func (w *ApplicationService) GetPlaygroundPluginList(ctx context.Context, req *p
|
||||
SpaceID: ptr.Of(req.GetSpaceID()),
|
||||
PublishStatus: ptr.Of(vo.HasPublished),
|
||||
},
|
||||
QType: vo.FromLatestVersion,
|
||||
QType: workflowModel.FromLatestVersion,
|
||||
})
|
||||
} else if req.GetPage() > 0 && req.GetSize() > 0 {
|
||||
wfs, _, err = GetWorkflowDomainSVC().MGet(ctx, &vo.MGetPolicy{
|
||||
@ -2903,7 +2983,7 @@ func (w *ApplicationService) GetPlaygroundPluginList(ctx context.Context, req *p
|
||||
SpaceID: ptr.Of(req.GetSpaceID()),
|
||||
PublishStatus: ptr.Of(vo.HasPublished),
|
||||
},
|
||||
QType: vo.FromLatestVersion,
|
||||
QType: workflowModel.FromLatestVersion,
|
||||
})
|
||||
}
|
||||
|
||||
@ -2977,7 +3057,7 @@ func (w *ApplicationService) CopyWorkflow(ctx context.Context, req *workflow.Cop
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wf, err := GetWorkflowDomainSVC().CopyWorkflow(ctx, workflowID, vo.CopyWorkflowPolicy{
|
||||
wf, err := w.copyWorkflow(ctx, workflowID, vo.CopyWorkflowPolicy{
|
||||
ShouldModifyWorkflowName: true,
|
||||
})
|
||||
if err != nil {
|
||||
@ -3043,7 +3123,7 @@ func (w *ApplicationService) GetHistorySchema(ctx context.Context, req *workflow
|
||||
// get the workflow entity for that workflowID and commitID
|
||||
policy := &vo.GetPolicy{
|
||||
ID: workflowID,
|
||||
QType: ternary.IFElse(len(exe.Version) > 0, vo.FromSpecificVersion, vo.FromDraft),
|
||||
QType: ternary.IFElse(len(exe.Version) > 0, workflowModel.FromSpecificVersion, workflowModel.FromDraft),
|
||||
Version: exe.Version,
|
||||
CommitID: exe.CommitID,
|
||||
}
|
||||
@ -3101,7 +3181,7 @@ func (w *ApplicationService) GetExampleWorkFlowList(ctx context.Context, req *wo
|
||||
|
||||
wfs, _, err := GetWorkflowDomainSVC().MGet(ctx, &vo.MGetPolicy{
|
||||
MetaQuery: option,
|
||||
QType: vo.FromDraft,
|
||||
QType: workflowModel.FromDraft,
|
||||
MetaOnly: false,
|
||||
})
|
||||
if err != nil {
|
||||
@ -3180,7 +3260,7 @@ func (w *ApplicationService) CopyWkTemplateApi(ctx context.Context, req *workflo
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
wf, err := GetWorkflowDomainSVC().CopyWorkflow(ctx, wid, vo.CopyWorkflowPolicy{
|
||||
wf, err := w.copyWorkflow(ctx, wid, vo.CopyWorkflowPolicy{
|
||||
ShouldModifyWorkflowName: true,
|
||||
TargetSpaceID: ptr.Of(req.GetTargetSpaceID()),
|
||||
TargetAppID: ptr.Of(int64(0)),
|
||||
@ -3189,7 +3269,7 @@ func (w *ApplicationService) CopyWkTemplateApi(ctx context.Context, req *workflo
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = GetWorkflowDomainSVC().Publish(ctx, &vo.PublishPolicy{
|
||||
err = w.publishWorkflowResource(ctx, &vo.PublishPolicy{
|
||||
ID: wf.ID,
|
||||
Version: "v0.0.0",
|
||||
CommitID: wf.CommitID,
|
||||
@ -3270,6 +3350,26 @@ func (w *ApplicationService) CopyWkTemplateApi(ctx context.Context, req *workflo
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (w *ApplicationService) publishWorkflowResource(ctx context.Context, policy *vo.PublishPolicy) error {
|
||||
err := GetWorkflowDomainSVC().Publish(ctx, policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
safego.Go(ctx, func() {
|
||||
now := time.Now().UnixMilli()
|
||||
if err := PublishWorkflowResource(ctx, policy.ID, nil, search.Updated, &search.ResourceDocument{
|
||||
PublishStatus: ptr.Of(resource.PublishStatus_Published),
|
||||
UpdateTimeMS: ptr.Of(now),
|
||||
PublishTimeMS: ptr.Of(now),
|
||||
}); err != nil {
|
||||
logs.CtxErrorf(ctx, "publish workflow resource failed workflowID = %d, err: %v", policy.ID, err)
|
||||
}
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func mustParseInt64(s string) int64 {
|
||||
i, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil {
|
||||
@ -3518,13 +3618,13 @@ func toVariable(p *workflow.APIParameter) (*vo.Variable, error) {
|
||||
v.Type = vo.VariableTypeBoolean
|
||||
case workflow.ParameterType_Array:
|
||||
v.Type = vo.VariableTypeList
|
||||
if len(p.SubParameters) == 1 {
|
||||
if len(p.SubParameters) == 1 && p.SubType != nil && *p.SubType != workflow.ParameterType_Object {
|
||||
av, err := toVariable(p.SubParameters[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
v.Schema = &av
|
||||
} else if len(p.SubParameters) > 1 {
|
||||
} else {
|
||||
subVs := make([]any, 0)
|
||||
for _, ap := range p.SubParameters {
|
||||
av, err := toVariable(ap)
|
||||
|
||||
4
backend/conf/workflow/config.yaml
Normal file
4
backend/conf/workflow/config.yaml
Normal file
@ -0,0 +1,4 @@
|
||||
NodeOfCodeConfig:
|
||||
SupportThirdPartModules:
|
||||
- httpx
|
||||
- numpy
|
||||
@ -14,7 +14,7 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package crossagent
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -14,7 +14,7 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package crossconnector
|
||||
package connector
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -14,7 +14,7 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package crossconversation
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -14,7 +14,7 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package crossdatabase
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -22,6 +22,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination databasemock/database_mock.go --package databasemock -source database.go
|
||||
type Database interface {
|
||||
ExecuteSQL(ctx context.Context, req *database.ExecuteSQLRequest) (*database.ExecuteSQLResponse, error)
|
||||
PublishDatabase(ctx context.Context, req *database.PublishDatabaseRequest) (resp *database.PublishDatabaseResponse, err error)
|
||||
@ -30,6 +31,12 @@ type Database interface {
|
||||
UnBindDatabase(ctx context.Context, req *database.UnBindDatabaseToAgentRequest) error
|
||||
MGetDatabase(ctx context.Context, req *database.MGetDatabaseRequest) (*database.MGetDatabaseResponse, error)
|
||||
GetAllDatabaseByAppID(ctx context.Context, req *database.GetAllDatabaseByAppIDRequest) (*database.GetAllDatabaseByAppIDResponse, error)
|
||||
|
||||
Execute(ctx context.Context, request *database.CustomSQLRequest) (*database.Response, error)
|
||||
Query(ctx context.Context, request *database.QueryRequest) (*database.Response, error)
|
||||
Update(context.Context, *database.UpdateRequest) (*database.Response, error)
|
||||
Insert(ctx context.Context, request *database.InsertRequest) (*database.Response, error)
|
||||
Delete(context.Context, *database.DeleteRequest) (*database.Response, error)
|
||||
}
|
||||
|
||||
var defaultSVC Database
|
||||
@ -0,0 +1,219 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: database.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -destination databasemock/database_mock.go --package databasemock -source database.go
|
||||
//
|
||||
|
||||
// Package databasemock is a generated GoMock package.
|
||||
package databasemock
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
database "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockDatabase is a mock of Database interface.
|
||||
type MockDatabase struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockDatabaseMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockDatabaseMockRecorder is the mock recorder for MockDatabase.
|
||||
type MockDatabaseMockRecorder struct {
|
||||
mock *MockDatabase
|
||||
}
|
||||
|
||||
// NewMockDatabase creates a new mock instance.
|
||||
func NewMockDatabase(ctrl *gomock.Controller) *MockDatabase {
|
||||
mock := &MockDatabase{ctrl: ctrl}
|
||||
mock.recorder = &MockDatabaseMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockDatabase) EXPECT() *MockDatabaseMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// BindDatabase mocks base method.
|
||||
func (m *MockDatabase) BindDatabase(ctx context.Context, req *database.BindDatabaseToAgentRequest) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "BindDatabase", ctx, req)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// BindDatabase indicates an expected call of BindDatabase.
|
||||
func (mr *MockDatabaseMockRecorder) BindDatabase(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BindDatabase", reflect.TypeOf((*MockDatabase)(nil).BindDatabase), ctx, req)
|
||||
}
|
||||
|
||||
// Delete mocks base method.
|
||||
func (m *MockDatabase) Delete(arg0 context.Context, arg1 *database.DeleteRequest) (*database.Response, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Delete", arg0, arg1)
|
||||
ret0, _ := ret[0].(*database.Response)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Delete indicates an expected call of Delete.
|
||||
func (mr *MockDatabaseMockRecorder) Delete(arg0, arg1 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockDatabase)(nil).Delete), arg0, arg1)
|
||||
}
|
||||
|
||||
// DeleteDatabase mocks base method.
|
||||
func (m *MockDatabase) DeleteDatabase(ctx context.Context, req *database.DeleteDatabaseRequest) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteDatabase", ctx, req)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteDatabase indicates an expected call of DeleteDatabase.
|
||||
func (mr *MockDatabaseMockRecorder) DeleteDatabase(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteDatabase", reflect.TypeOf((*MockDatabase)(nil).DeleteDatabase), ctx, req)
|
||||
}
|
||||
|
||||
// Execute mocks base method.
|
||||
func (m *MockDatabase) Execute(ctx context.Context, request *database.CustomSQLRequest) (*database.Response, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Execute", ctx, request)
|
||||
ret0, _ := ret[0].(*database.Response)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Execute indicates an expected call of Execute.
|
||||
func (mr *MockDatabaseMockRecorder) Execute(ctx, request any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Execute", reflect.TypeOf((*MockDatabase)(nil).Execute), ctx, request)
|
||||
}
|
||||
|
||||
// ExecuteSQL mocks base method.
|
||||
func (m *MockDatabase) ExecuteSQL(ctx context.Context, req *database.ExecuteSQLRequest) (*database.ExecuteSQLResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ExecuteSQL", ctx, req)
|
||||
ret0, _ := ret[0].(*database.ExecuteSQLResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ExecuteSQL indicates an expected call of ExecuteSQL.
|
||||
func (mr *MockDatabaseMockRecorder) ExecuteSQL(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExecuteSQL", reflect.TypeOf((*MockDatabase)(nil).ExecuteSQL), ctx, req)
|
||||
}
|
||||
|
||||
// GetAllDatabaseByAppID mocks base method.
|
||||
func (m *MockDatabase) GetAllDatabaseByAppID(ctx context.Context, req *database.GetAllDatabaseByAppIDRequest) (*database.GetAllDatabaseByAppIDResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAllDatabaseByAppID", ctx, req)
|
||||
ret0, _ := ret[0].(*database.GetAllDatabaseByAppIDResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAllDatabaseByAppID indicates an expected call of GetAllDatabaseByAppID.
|
||||
func (mr *MockDatabaseMockRecorder) GetAllDatabaseByAppID(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllDatabaseByAppID", reflect.TypeOf((*MockDatabase)(nil).GetAllDatabaseByAppID), ctx, req)
|
||||
}
|
||||
|
||||
// Insert mocks base method.
|
||||
func (m *MockDatabase) Insert(ctx context.Context, request *database.InsertRequest) (*database.Response, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Insert", ctx, request)
|
||||
ret0, _ := ret[0].(*database.Response)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Insert indicates an expected call of Insert.
|
||||
func (mr *MockDatabaseMockRecorder) Insert(ctx, request any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insert", reflect.TypeOf((*MockDatabase)(nil).Insert), ctx, request)
|
||||
}
|
||||
|
||||
// MGetDatabase mocks base method.
|
||||
func (m *MockDatabase) MGetDatabase(ctx context.Context, req *database.MGetDatabaseRequest) (*database.MGetDatabaseResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "MGetDatabase", ctx, req)
|
||||
ret0, _ := ret[0].(*database.MGetDatabaseResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// MGetDatabase indicates an expected call of MGetDatabase.
|
||||
func (mr *MockDatabaseMockRecorder) MGetDatabase(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetDatabase", reflect.TypeOf((*MockDatabase)(nil).MGetDatabase), ctx, req)
|
||||
}
|
||||
|
||||
// PublishDatabase mocks base method.
|
||||
func (m *MockDatabase) PublishDatabase(ctx context.Context, req *database.PublishDatabaseRequest) (*database.PublishDatabaseResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "PublishDatabase", ctx, req)
|
||||
ret0, _ := ret[0].(*database.PublishDatabaseResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// PublishDatabase indicates an expected call of PublishDatabase.
|
||||
func (mr *MockDatabaseMockRecorder) PublishDatabase(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishDatabase", reflect.TypeOf((*MockDatabase)(nil).PublishDatabase), ctx, req)
|
||||
}
|
||||
|
||||
// Query mocks base method.
|
||||
func (m *MockDatabase) Query(ctx context.Context, request *database.QueryRequest) (*database.Response, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Query", ctx, request)
|
||||
ret0, _ := ret[0].(*database.Response)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Query indicates an expected call of Query.
|
||||
func (mr *MockDatabaseMockRecorder) Query(ctx, request any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockDatabase)(nil).Query), ctx, request)
|
||||
}
|
||||
|
||||
// UnBindDatabase mocks base method.
|
||||
func (m *MockDatabase) UnBindDatabase(ctx context.Context, req *database.UnBindDatabaseToAgentRequest) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UnBindDatabase", ctx, req)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UnBindDatabase indicates an expected call of UnBindDatabase.
|
||||
func (mr *MockDatabaseMockRecorder) UnBindDatabase(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnBindDatabase", reflect.TypeOf((*MockDatabase)(nil).UnBindDatabase), ctx, req)
|
||||
}
|
||||
|
||||
// Update mocks base method.
|
||||
func (m *MockDatabase) Update(arg0 context.Context, arg1 *database.UpdateRequest) (*database.Response, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Update", arg0, arg1)
|
||||
ret0, _ := ret[0].(*database.Response)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Update indicates an expected call of Update.
|
||||
func (mr *MockDatabaseMockRecorder) Update(arg0, arg1 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockDatabase)(nil).Update), arg0, arg1)
|
||||
}
|
||||
@ -14,7 +14,7 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package crossdatacopy
|
||||
package datacopy
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -14,7 +14,7 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package crossknowledge
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -22,11 +22,16 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination knowledgemock/knowledge_mock.go --package knowledgemock -source knowledge.go
|
||||
type Knowledge interface {
|
||||
ListKnowledge(ctx context.Context, request *knowledge.ListKnowledgeRequest) (response *knowledge.ListKnowledgeResponse, err error)
|
||||
GetKnowledgeByID(ctx context.Context, request *knowledge.GetKnowledgeByIDRequest) (response *knowledge.GetKnowledgeByIDResponse, err error)
|
||||
Retrieve(ctx context.Context, req *knowledge.RetrieveRequest) (*knowledge.RetrieveResponse, error)
|
||||
DeleteKnowledge(ctx context.Context, request *knowledge.DeleteKnowledgeRequest) error
|
||||
MGetKnowledgeByID(ctx context.Context, request *knowledge.MGetKnowledgeByIDRequest) (response *knowledge.MGetKnowledgeByIDResponse, err error)
|
||||
Store(ctx context.Context, document *knowledge.CreateDocumentRequest) (*knowledge.CreateDocumentResponse, error)
|
||||
Delete(ctx context.Context, r *knowledge.DeleteDocumentRequest) (*knowledge.DeleteDocumentResponse, error)
|
||||
ListKnowledgeDetail(ctx context.Context, req *knowledge.ListKnowledgeDetailRequest) (*knowledge.ListKnowledgeDetailResponse, error)
|
||||
}
|
||||
|
||||
var defaultSVC Knowledge
|
||||
@ -0,0 +1,161 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: knowledge.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -destination knowledgemock/knowledge_mock.go --package knowledgemock -source knowledge.go
|
||||
//
|
||||
|
||||
// Package knowledgemock is a generated GoMock package.
|
||||
package knowledgemock
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
knowledge "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockKnowledge is a mock of Knowledge interface.
|
||||
type MockKnowledge struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockKnowledgeMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockKnowledgeMockRecorder is the mock recorder for MockKnowledge.
|
||||
type MockKnowledgeMockRecorder struct {
|
||||
mock *MockKnowledge
|
||||
}
|
||||
|
||||
// NewMockKnowledge creates a new mock instance.
|
||||
func NewMockKnowledge(ctrl *gomock.Controller) *MockKnowledge {
|
||||
mock := &MockKnowledge{ctrl: ctrl}
|
||||
mock.recorder = &MockKnowledgeMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockKnowledge) EXPECT() *MockKnowledgeMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Delete mocks base method.
|
||||
func (m *MockKnowledge) Delete(ctx context.Context, r *knowledge.DeleteDocumentRequest) (*knowledge.DeleteDocumentResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Delete", ctx, r)
|
||||
ret0, _ := ret[0].(*knowledge.DeleteDocumentResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Delete indicates an expected call of Delete.
|
||||
func (mr *MockKnowledgeMockRecorder) Delete(ctx, r any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockKnowledge)(nil).Delete), ctx, r)
|
||||
}
|
||||
|
||||
// DeleteKnowledge mocks base method.
|
||||
func (m *MockKnowledge) DeleteKnowledge(ctx context.Context, request *knowledge.DeleteKnowledgeRequest) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteKnowledge", ctx, request)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteKnowledge indicates an expected call of DeleteKnowledge.
|
||||
func (mr *MockKnowledgeMockRecorder) DeleteKnowledge(ctx, request any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteKnowledge", reflect.TypeOf((*MockKnowledge)(nil).DeleteKnowledge), ctx, request)
|
||||
}
|
||||
|
||||
// GetKnowledgeByID mocks base method.
|
||||
func (m *MockKnowledge) GetKnowledgeByID(ctx context.Context, request *knowledge.GetKnowledgeByIDRequest) (*knowledge.GetKnowledgeByIDResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetKnowledgeByID", ctx, request)
|
||||
ret0, _ := ret[0].(*knowledge.GetKnowledgeByIDResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetKnowledgeByID indicates an expected call of GetKnowledgeByID.
|
||||
func (mr *MockKnowledgeMockRecorder) GetKnowledgeByID(ctx, request any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKnowledgeByID", reflect.TypeOf((*MockKnowledge)(nil).GetKnowledgeByID), ctx, request)
|
||||
}
|
||||
|
||||
// ListKnowledge mocks base method.
|
||||
func (m *MockKnowledge) ListKnowledge(ctx context.Context, request *knowledge.ListKnowledgeRequest) (*knowledge.ListKnowledgeResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListKnowledge", ctx, request)
|
||||
ret0, _ := ret[0].(*knowledge.ListKnowledgeResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListKnowledge indicates an expected call of ListKnowledge.
|
||||
func (mr *MockKnowledgeMockRecorder) ListKnowledge(ctx, request any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListKnowledge", reflect.TypeOf((*MockKnowledge)(nil).ListKnowledge), ctx, request)
|
||||
}
|
||||
|
||||
// ListKnowledgeDetail mocks base method.
|
||||
func (m *MockKnowledge) ListKnowledgeDetail(ctx context.Context, req *knowledge.ListKnowledgeDetailRequest) (*knowledge.ListKnowledgeDetailResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListKnowledgeDetail", ctx, req)
|
||||
ret0, _ := ret[0].(*knowledge.ListKnowledgeDetailResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListKnowledgeDetail indicates an expected call of ListKnowledgeDetail.
|
||||
func (mr *MockKnowledgeMockRecorder) ListKnowledgeDetail(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListKnowledgeDetail", reflect.TypeOf((*MockKnowledge)(nil).ListKnowledgeDetail), ctx, req)
|
||||
}
|
||||
|
||||
// MGetKnowledgeByID mocks base method.
|
||||
func (m *MockKnowledge) MGetKnowledgeByID(ctx context.Context, request *knowledge.MGetKnowledgeByIDRequest) (*knowledge.MGetKnowledgeByIDResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "MGetKnowledgeByID", ctx, request)
|
||||
ret0, _ := ret[0].(*knowledge.MGetKnowledgeByIDResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// MGetKnowledgeByID indicates an expected call of MGetKnowledgeByID.
|
||||
func (mr *MockKnowledgeMockRecorder) MGetKnowledgeByID(ctx, request any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetKnowledgeByID", reflect.TypeOf((*MockKnowledge)(nil).MGetKnowledgeByID), ctx, request)
|
||||
}
|
||||
|
||||
// Retrieve mocks base method.
|
||||
func (m *MockKnowledge) Retrieve(ctx context.Context, req *knowledge.RetrieveRequest) (*knowledge.RetrieveResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Retrieve", ctx, req)
|
||||
ret0, _ := ret[0].(*knowledge.RetrieveResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Retrieve indicates an expected call of Retrieve.
|
||||
func (mr *MockKnowledgeMockRecorder) Retrieve(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Retrieve", reflect.TypeOf((*MockKnowledge)(nil).Retrieve), ctx, req)
|
||||
}
|
||||
|
||||
// Store mocks base method.
|
||||
func (m *MockKnowledge) Store(ctx context.Context, document *knowledge.CreateDocumentRequest) (*knowledge.CreateDocumentResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Store", ctx, document)
|
||||
ret0, _ := ret[0].(*knowledge.CreateDocumentResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Store indicates an expected call of Store.
|
||||
func (mr *MockKnowledgeMockRecorder) Store(ctx, document any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Store", reflect.TypeOf((*MockKnowledge)(nil).Store), ctx, document)
|
||||
}
|
||||
@ -14,7 +14,7 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package crossmessage
|
||||
package message
|
||||
|
||||
import (
|
||||
"context"
|
||||
41
backend/crossdomain/contract/modelmgr/modelmgr.go
Normal file
41
backend/crossdomain/contract/modelmgr/modelmgr.go
Normal file
@ -0,0 +1,41 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package modelmgr
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
eino "github.com/cloudwego/eino/components/model"
|
||||
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination modelmock/model_mock.go --package mockmodel -source modelmgr.go
|
||||
type Manager interface {
|
||||
GetModel(ctx context.Context, params *model.LLMParams) (eino.BaseChatModel, *modelmgr.Model, error)
|
||||
}
|
||||
|
||||
var defaultSVC Manager
|
||||
|
||||
func DefaultSVC() Manager {
|
||||
return defaultSVC
|
||||
}
|
||||
|
||||
func SetDefaultSVC(svc Manager) {
|
||||
defaultSVC = svc
|
||||
}
|
||||
@ -1,9 +1,9 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: model.go
|
||||
// Source: modelmgr.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -destination modelmock/model_mock.go --package mockmodel -source model.go
|
||||
// mockgen -destination modelmock/model_mock.go --package mockmodel -source modelmgr.go
|
||||
//
|
||||
|
||||
// Package mockmodel is a generated GoMock package.
|
||||
@ -14,7 +14,7 @@ import (
|
||||
reflect "reflect"
|
||||
|
||||
model "github.com/cloudwego/eino/components/model"
|
||||
model0 "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
|
||||
model0 "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
|
||||
modelmgr "github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
@ -23,6 +23,7 @@ import (
|
||||
type MockManager struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockManagerMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockManagerMockRecorder is the mock recorder for MockManager.
|
||||
@ -14,14 +14,18 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package crossplugin
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination pluginmock/plugin_mock.go --package pluginmock -source plugin.go
|
||||
type PluginService interface {
|
||||
MGetVersionPlugins(ctx context.Context, versionPlugins []model.VersionPlugin) (plugins []*model.PluginInfo, err error)
|
||||
MGetPluginLatestVersion(ctx context.Context, pluginIDs []int64) (resp *model.MGetPluginLatestVersionResponse, err error)
|
||||
@ -35,6 +39,14 @@ type PluginService interface {
|
||||
PublishAPPPlugins(ctx context.Context, req *model.PublishAPPPluginsRequest) (resp *model.PublishAPPPluginsResponse, err error)
|
||||
GetAPPAllPlugins(ctx context.Context, appID int64) (plugins []*model.PluginInfo, err error)
|
||||
MGetVersionTools(ctx context.Context, versionTools []model.VersionTool) (tools []*model.ToolInfo, err error)
|
||||
GetPluginToolsInfo(ctx context.Context, req *model.ToolsInfoRequest) (*model.ToolsInfoResponse, error)
|
||||
GetPluginInvokableTools(ctx context.Context, req *model.ToolsInvokableRequest) (map[int64]InvokableTool, error)
|
||||
ExecutePlugin(ctx context.Context, input map[string]any, pe *model.PluginEntity, toolID int64, cfg workflow.ExecuteConfig) (map[string]any, error)
|
||||
}
|
||||
|
||||
type InvokableTool interface {
|
||||
Info(ctx context.Context) (*schema.ToolInfo, error)
|
||||
PluginInvoke(ctx context.Context, argumentsInJSON string, cfg workflow.ExecuteConfig) (string, error)
|
||||
}
|
||||
|
||||
var defaultSVC PluginService
|
||||
324
backend/crossdomain/contract/plugin/pluginmock/plugin_mock.go
Normal file
324
backend/crossdomain/contract/plugin/pluginmock/plugin_mock.go
Normal file
@ -0,0 +1,324 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: plugin.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -destination pluginmock/plugin_mock.go --package pluginmock -source plugin.go
|
||||
//
|
||||
|
||||
// Package pluginmock is a generated GoMock package.
|
||||
package pluginmock
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
schema "github.com/cloudwego/eino/schema"
|
||||
plugin "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
workflow "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
|
||||
plugin0 "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockPluginService is a mock of PluginService interface.
|
||||
type MockPluginService struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockPluginServiceMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockPluginServiceMockRecorder is the mock recorder for MockPluginService.
|
||||
type MockPluginServiceMockRecorder struct {
|
||||
mock *MockPluginService
|
||||
}
|
||||
|
||||
// NewMockPluginService creates a new mock instance.
|
||||
func NewMockPluginService(ctrl *gomock.Controller) *MockPluginService {
|
||||
mock := &MockPluginService{ctrl: ctrl}
|
||||
mock.recorder = &MockPluginServiceMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockPluginService) EXPECT() *MockPluginServiceMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// BindAgentTools mocks base method.
|
||||
func (m *MockPluginService) BindAgentTools(ctx context.Context, agentID int64, toolIDs []int64) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "BindAgentTools", ctx, agentID, toolIDs)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// BindAgentTools indicates an expected call of BindAgentTools.
|
||||
func (mr *MockPluginServiceMockRecorder) BindAgentTools(ctx, agentID, toolIDs any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BindAgentTools", reflect.TypeOf((*MockPluginService)(nil).BindAgentTools), ctx, agentID, toolIDs)
|
||||
}
|
||||
|
||||
// DeleteDraftPlugin mocks base method.
|
||||
func (m *MockPluginService) DeleteDraftPlugin(ctx context.Context, PluginID int64) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteDraftPlugin", ctx, PluginID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteDraftPlugin indicates an expected call of DeleteDraftPlugin.
|
||||
func (mr *MockPluginServiceMockRecorder) DeleteDraftPlugin(ctx, PluginID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteDraftPlugin", reflect.TypeOf((*MockPluginService)(nil).DeleteDraftPlugin), ctx, PluginID)
|
||||
}
|
||||
|
||||
// DuplicateDraftAgentTools mocks base method.
|
||||
func (m *MockPluginService) DuplicateDraftAgentTools(ctx context.Context, fromAgentID, toAgentID int64) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DuplicateDraftAgentTools", ctx, fromAgentID, toAgentID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DuplicateDraftAgentTools indicates an expected call of DuplicateDraftAgentTools.
|
||||
func (mr *MockPluginServiceMockRecorder) DuplicateDraftAgentTools(ctx, fromAgentID, toAgentID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DuplicateDraftAgentTools", reflect.TypeOf((*MockPluginService)(nil).DuplicateDraftAgentTools), ctx, fromAgentID, toAgentID)
|
||||
}
|
||||
|
||||
// ExecutePlugin mocks base method.
|
||||
func (m *MockPluginService) ExecutePlugin(ctx context.Context, input map[string]any, pe *plugin.PluginEntity, toolID int64, cfg workflow.ExecuteConfig) (map[string]any, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ExecutePlugin", ctx, input, pe, toolID, cfg)
|
||||
ret0, _ := ret[0].(map[string]any)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ExecutePlugin indicates an expected call of ExecutePlugin.
|
||||
func (mr *MockPluginServiceMockRecorder) ExecutePlugin(ctx, input, pe, toolID, cfg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExecutePlugin", reflect.TypeOf((*MockPluginService)(nil).ExecutePlugin), ctx, input, pe, toolID, cfg)
|
||||
}
|
||||
|
||||
// ExecuteTool mocks base method.
|
||||
func (m *MockPluginService) ExecuteTool(ctx context.Context, req *plugin.ExecuteToolRequest, opts ...plugin.ExecuteToolOpt) (*plugin.ExecuteToolResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, req}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "ExecuteTool", varargs...)
|
||||
ret0, _ := ret[0].(*plugin.ExecuteToolResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ExecuteTool indicates an expected call of ExecuteTool.
|
||||
func (mr *MockPluginServiceMockRecorder) ExecuteTool(ctx, req any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, req}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExecuteTool", reflect.TypeOf((*MockPluginService)(nil).ExecuteTool), varargs...)
|
||||
}
|
||||
|
||||
// GetAPPAllPlugins mocks base method.
|
||||
func (m *MockPluginService) GetAPPAllPlugins(ctx context.Context, appID int64) ([]*plugin.PluginInfo, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAPPAllPlugins", ctx, appID)
|
||||
ret0, _ := ret[0].([]*plugin.PluginInfo)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAPPAllPlugins indicates an expected call of GetAPPAllPlugins.
|
||||
func (mr *MockPluginServiceMockRecorder) GetAPPAllPlugins(ctx, appID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAPPAllPlugins", reflect.TypeOf((*MockPluginService)(nil).GetAPPAllPlugins), ctx, appID)
|
||||
}
|
||||
|
||||
// GetPluginInvokableTools mocks base method.
|
||||
func (m *MockPluginService) GetPluginInvokableTools(ctx context.Context, req *plugin.ToolsInvokableRequest) (map[int64]plugin0.InvokableTool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPluginInvokableTools", ctx, req)
|
||||
ret0, _ := ret[0].(map[int64]plugin0.InvokableTool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPluginInvokableTools indicates an expected call of GetPluginInvokableTools.
|
||||
func (mr *MockPluginServiceMockRecorder) GetPluginInvokableTools(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPluginInvokableTools", reflect.TypeOf((*MockPluginService)(nil).GetPluginInvokableTools), ctx, req)
|
||||
}
|
||||
|
||||
// GetPluginToolsInfo mocks base method.
|
||||
func (m *MockPluginService) GetPluginToolsInfo(ctx context.Context, req *plugin.ToolsInfoRequest) (*plugin.ToolsInfoResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPluginToolsInfo", ctx, req)
|
||||
ret0, _ := ret[0].(*plugin.ToolsInfoResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPluginToolsInfo indicates an expected call of GetPluginToolsInfo.
|
||||
func (mr *MockPluginServiceMockRecorder) GetPluginToolsInfo(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPluginToolsInfo", reflect.TypeOf((*MockPluginService)(nil).GetPluginToolsInfo), ctx, req)
|
||||
}
|
||||
|
||||
// MGetAgentTools mocks base method.
|
||||
func (m *MockPluginService) MGetAgentTools(ctx context.Context, req *plugin.MGetAgentToolsRequest) ([]*plugin.ToolInfo, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "MGetAgentTools", ctx, req)
|
||||
ret0, _ := ret[0].([]*plugin.ToolInfo)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// MGetAgentTools indicates an expected call of MGetAgentTools.
|
||||
func (mr *MockPluginServiceMockRecorder) MGetAgentTools(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetAgentTools", reflect.TypeOf((*MockPluginService)(nil).MGetAgentTools), ctx, req)
|
||||
}
|
||||
|
||||
// MGetPluginLatestVersion mocks base method.
|
||||
func (m *MockPluginService) MGetPluginLatestVersion(ctx context.Context, pluginIDs []int64) (*plugin.MGetPluginLatestVersionResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "MGetPluginLatestVersion", ctx, pluginIDs)
|
||||
ret0, _ := ret[0].(*plugin.MGetPluginLatestVersionResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// MGetPluginLatestVersion indicates an expected call of MGetPluginLatestVersion.
|
||||
func (mr *MockPluginServiceMockRecorder) MGetPluginLatestVersion(ctx, pluginIDs any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetPluginLatestVersion", reflect.TypeOf((*MockPluginService)(nil).MGetPluginLatestVersion), ctx, pluginIDs)
|
||||
}
|
||||
|
||||
// MGetVersionPlugins mocks base method.
|
||||
func (m *MockPluginService) MGetVersionPlugins(ctx context.Context, versionPlugins []plugin.VersionPlugin) ([]*plugin.PluginInfo, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "MGetVersionPlugins", ctx, versionPlugins)
|
||||
ret0, _ := ret[0].([]*plugin.PluginInfo)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// MGetVersionPlugins indicates an expected call of MGetVersionPlugins.
|
||||
func (mr *MockPluginServiceMockRecorder) MGetVersionPlugins(ctx, versionPlugins any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetVersionPlugins", reflect.TypeOf((*MockPluginService)(nil).MGetVersionPlugins), ctx, versionPlugins)
|
||||
}
|
||||
|
||||
// MGetVersionTools mocks base method.
|
||||
func (m *MockPluginService) MGetVersionTools(ctx context.Context, versionTools []plugin.VersionTool) ([]*plugin.ToolInfo, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "MGetVersionTools", ctx, versionTools)
|
||||
ret0, _ := ret[0].([]*plugin.ToolInfo)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// MGetVersionTools indicates an expected call of MGetVersionTools.
|
||||
func (mr *MockPluginServiceMockRecorder) MGetVersionTools(ctx, versionTools any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetVersionTools", reflect.TypeOf((*MockPluginService)(nil).MGetVersionTools), ctx, versionTools)
|
||||
}
|
||||
|
||||
// PublishAPPPlugins mocks base method.
|
||||
func (m *MockPluginService) PublishAPPPlugins(ctx context.Context, req *plugin.PublishAPPPluginsRequest) (*plugin.PublishAPPPluginsResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "PublishAPPPlugins", ctx, req)
|
||||
ret0, _ := ret[0].(*plugin.PublishAPPPluginsResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// PublishAPPPlugins indicates an expected call of PublishAPPPlugins.
|
||||
func (mr *MockPluginServiceMockRecorder) PublishAPPPlugins(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishAPPPlugins", reflect.TypeOf((*MockPluginService)(nil).PublishAPPPlugins), ctx, req)
|
||||
}
|
||||
|
||||
// PublishAgentTools mocks base method.
|
||||
func (m *MockPluginService) PublishAgentTools(ctx context.Context, agentID int64, agentVersion string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "PublishAgentTools", ctx, agentID, agentVersion)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// PublishAgentTools indicates an expected call of PublishAgentTools.
|
||||
func (mr *MockPluginServiceMockRecorder) PublishAgentTools(ctx, agentID, agentVersion any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishAgentTools", reflect.TypeOf((*MockPluginService)(nil).PublishAgentTools), ctx, agentID, agentVersion)
|
||||
}
|
||||
|
||||
// PublishPlugin mocks base method.
|
||||
func (m *MockPluginService) PublishPlugin(ctx context.Context, req *plugin.PublishPluginRequest) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "PublishPlugin", ctx, req)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// PublishPlugin indicates an expected call of PublishPlugin.
|
||||
func (mr *MockPluginServiceMockRecorder) PublishPlugin(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishPlugin", reflect.TypeOf((*MockPluginService)(nil).PublishPlugin), ctx, req)
|
||||
}
|
||||
|
||||
// MockInvokableTool is a mock of InvokableTool interface.
|
||||
type MockInvokableTool struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockInvokableToolMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockInvokableToolMockRecorder is the mock recorder for MockInvokableTool.
|
||||
type MockInvokableToolMockRecorder struct {
|
||||
mock *MockInvokableTool
|
||||
}
|
||||
|
||||
// NewMockInvokableTool creates a new mock instance.
|
||||
func NewMockInvokableTool(ctrl *gomock.Controller) *MockInvokableTool {
|
||||
mock := &MockInvokableTool{ctrl: ctrl}
|
||||
mock.recorder = &MockInvokableToolMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockInvokableTool) EXPECT() *MockInvokableToolMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Info mocks base method.
|
||||
func (m *MockInvokableTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Info", ctx)
|
||||
ret0, _ := ret[0].(*schema.ToolInfo)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Info indicates an expected call of Info.
|
||||
func (mr *MockInvokableToolMockRecorder) Info(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockInvokableTool)(nil).Info), ctx)
|
||||
}
|
||||
|
||||
// PluginInvoke mocks base method.
|
||||
func (m *MockInvokableTool) PluginInvoke(ctx context.Context, argumentsInJSON string, cfg workflow.ExecuteConfig) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "PluginInvoke", ctx, argumentsInJSON, cfg)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// PluginInvoke indicates an expected call of PluginInvoke.
|
||||
func (mr *MockInvokableToolMockRecorder) PluginInvoke(ctx, argumentsInJSON, cfg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PluginInvoke", reflect.TypeOf((*MockInvokableTool)(nil).PluginInvoke), ctx, argumentsInJSON, cfg)
|
||||
}
|
||||
@ -14,7 +14,7 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package crosssearch
|
||||
package search
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -14,19 +14,25 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package crossvariables
|
||||
package variables
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/variables"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/data/variable/kvmemory"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/data/variable/project_memory"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/memory/variables/entity"
|
||||
)
|
||||
|
||||
// TODO (@fanlv): Parameter references need to be modified.
|
||||
type Variables interface {
|
||||
GetVariableInstance(ctx context.Context, e *variables.UserVariableMeta, keywords []string) ([]*kvmemory.KVItem, error)
|
||||
SetVariableInstance(ctx context.Context, e *variables.UserVariableMeta, items []*kvmemory.KVItem) ([]string, error)
|
||||
DecryptSysUUIDKey(ctx context.Context, encryptSysUUIDKey string) *variables.UserVariableMeta
|
||||
GetVariableChannelInstance(ctx context.Context, e *variables.UserVariableMeta, keywords []string, varChannel *project_memory.VariableChannel) ([]*kvmemory.KVItem, error)
|
||||
GetProjectVariablesMeta(ctx context.Context, projectID, version string) (*entity.VariablesMeta, error)
|
||||
GetAgentVariableMeta(ctx context.Context, agentID int64, version string) (*entity.VariablesMeta, error)
|
||||
}
|
||||
|
||||
var defaultSVC Variables
|
||||
@ -23,28 +23,28 @@ import (
|
||||
einoCompose "github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
workflowEntity "github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
)
|
||||
|
||||
// TODO (@fanlv): Parameter references need to be modified.
|
||||
type Workflow interface {
|
||||
WorkflowAsModelTool(ctx context.Context, policies []*vo.GetPolicy) ([]workflow.ToolFromWorkflow, error)
|
||||
DeleteWorkflow(ctx context.Context, id int64) error
|
||||
PublishWorkflow(ctx context.Context, info *vo.PublishPolicy) (err error)
|
||||
WithResumeToolWorkflow(resumingEvent *workflowEntity.ToolInterruptEvent, resumeData string,
|
||||
allInterruptEvents map[string]*workflowEntity.ToolInterruptEvent) einoCompose.Option
|
||||
ReleaseApplicationWorkflows(ctx context.Context, appID int64, config *ReleaseWorkflowConfig) ([]*vo.ValidateIssue, error)
|
||||
GetWorkflowIDsByAppID(ctx context.Context, appID int64) ([]int64, error)
|
||||
SyncExecuteWorkflow(ctx context.Context, config vo.ExecuteConfig, input map[string]any) (*workflowEntity.WorkflowExecution, vo.TerminatePlan, error)
|
||||
WithExecuteConfig(cfg vo.ExecuteConfig) einoCompose.Option
|
||||
WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message])
|
||||
SyncExecuteWorkflow(ctx context.Context, config workflowModel.ExecuteConfig, input map[string]any) (*workflowEntity.WorkflowExecution, vo.TerminatePlan, error)
|
||||
WithExecuteConfig(cfg workflowModel.ExecuteConfig) einoCompose.Option
|
||||
WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message], func())
|
||||
}
|
||||
|
||||
type ExecuteConfig = vo.ExecuteConfig
|
||||
type ExecuteMode = vo.ExecuteMode
|
||||
type ExecuteConfig = workflowModel.ExecuteConfig
|
||||
type ExecuteMode = workflowModel.ExecuteMode
|
||||
type NodeType = entity.NodeType
|
||||
|
||||
type WorkflowMessage = entity.Message
|
||||
@ -59,14 +59,14 @@ const (
|
||||
ExecuteModeNodeDebug ExecuteMode = "node_debug"
|
||||
)
|
||||
|
||||
type TaskType = vo.TaskType
|
||||
type TaskType = workflowModel.TaskType
|
||||
|
||||
const (
|
||||
TaskTypeForeground TaskType = "foreground"
|
||||
TaskTypeBackground TaskType = "background"
|
||||
)
|
||||
|
||||
type BizType = vo.BizType
|
||||
type BizType = workflowModel.BizType
|
||||
|
||||
const (
|
||||
BizTypeAgent BizType = "agent"
|
||||
@ -19,7 +19,7 @@ package agentrun
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossagentrun"
|
||||
crossagentrun "github.com/coze-dev/coze-studio/backend/crossdomain/contract/agentrun"
|
||||
agentrun "github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/service"
|
||||
)
|
||||
|
||||
|
||||
@ -20,7 +20,7 @@ import (
|
||||
"context"
|
||||
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/connector"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossconnector"
|
||||
crossconnector "github.com/coze-dev/coze-studio/backend/crossdomain/contract/connector"
|
||||
connector "github.com/coze-dev/coze-studio/backend/domain/connector/service"
|
||||
)
|
||||
|
||||
|
||||
@ -20,7 +20,7 @@ import (
|
||||
"context"
|
||||
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/conversation"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossconversation"
|
||||
crossconversation "github.com/coze-dev/coze-studio/backend/crossdomain/contract/conversation"
|
||||
conversation "github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/service"
|
||||
)
|
||||
|
||||
|
||||
@ -19,7 +19,7 @@ package crossuser
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossuser"
|
||||
crossuser "github.com/coze-dev/coze-studio/backend/crossdomain/contract/user"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/user/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/user/service"
|
||||
)
|
||||
|
||||
@ -18,10 +18,20 @@ package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cast"
|
||||
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossdatabase"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/data/database/table"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
|
||||
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
|
||||
database "github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
|
||||
)
|
||||
|
||||
var defaultSVC crossdatabase.Database
|
||||
@ -65,3 +75,406 @@ func (c *databaseImpl) MGetDatabase(ctx context.Context, req *model.MGetDatabase
|
||||
func (c *databaseImpl) GetAllDatabaseByAppID(ctx context.Context, req *model.GetAllDatabaseByAppIDRequest) (*model.GetAllDatabaseByAppIDResponse, error) {
|
||||
return c.DomainSVC.GetAllDatabaseByAppID(ctx, req)
|
||||
}
|
||||
|
||||
func (d *databaseImpl) Execute(ctx context.Context, request *model.CustomSQLRequest) (*model.Response, error) {
|
||||
var (
|
||||
err error
|
||||
databaseInfoID = request.DatabaseInfoID
|
||||
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
|
||||
)
|
||||
|
||||
if request.IsDebugRun {
|
||||
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
req := &service.ExecuteSQLRequest{
|
||||
DatabaseID: databaseInfoID,
|
||||
OperateType: model.OperateType_Custom,
|
||||
SQL: &request.SQL,
|
||||
TableType: tableType,
|
||||
UserID: request.UserID,
|
||||
ConnectorID: ptr.Of(request.ConnectorID),
|
||||
}
|
||||
|
||||
req.SQLParams = make([]*model.SQLParamVal, 0, len(request.Params))
|
||||
for i := range request.Params {
|
||||
param := request.Params[i]
|
||||
req.SQLParams = append(req.SQLParams, &model.SQLParamVal{
|
||||
ValueType: table.FieldItemType_Text,
|
||||
Value: ¶m.Value,
|
||||
ISNull: param.IsNull,
|
||||
})
|
||||
}
|
||||
response, err := d.DomainSVC.ExecuteSQL(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// if rows affected is nil use 0 instead
|
||||
if response.RowsAffected == nil {
|
||||
response.RowsAffected = ptr.Of(int64(0))
|
||||
}
|
||||
return toNodeDateBaseResponse(response), nil
|
||||
}
|
||||
|
||||
func (d *databaseImpl) Delete(ctx context.Context, request *model.DeleteRequest) (*model.Response, error) {
|
||||
var (
|
||||
err error
|
||||
databaseInfoID = request.DatabaseInfoID
|
||||
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
|
||||
)
|
||||
|
||||
if request.IsDebugRun {
|
||||
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
req := &service.ExecuteSQLRequest{
|
||||
DatabaseID: databaseInfoID,
|
||||
OperateType: model.OperateType_Delete,
|
||||
TableType: tableType,
|
||||
UserID: request.UserID,
|
||||
ConnectorID: ptr.Of(request.ConnectorID),
|
||||
}
|
||||
|
||||
if request.ConditionGroup != nil {
|
||||
req.Condition, req.SQLParams, err = buildComplexCondition(request.ConditionGroup)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
response, err := d.DomainSVC.ExecuteSQL(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return toNodeDateBaseResponse(response), nil
|
||||
}
|
||||
|
||||
func (d *databaseImpl) Query(ctx context.Context, request *model.QueryRequest) (*model.Response, error) {
|
||||
var (
|
||||
err error
|
||||
databaseInfoID = request.DatabaseInfoID
|
||||
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
|
||||
)
|
||||
|
||||
if request.IsDebugRun {
|
||||
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
req := &service.ExecuteSQLRequest{
|
||||
DatabaseID: databaseInfoID,
|
||||
OperateType: model.OperateType_Select,
|
||||
TableType: tableType,
|
||||
UserID: request.UserID,
|
||||
ConnectorID: ptr.Of(request.ConnectorID),
|
||||
}
|
||||
|
||||
req.SelectFieldList = &model.SelectFieldList{FieldID: make([]string, 0, len(request.SelectFields))}
|
||||
for i := range request.SelectFields {
|
||||
req.SelectFieldList.FieldID = append(req.SelectFieldList.FieldID, request.SelectFields[i])
|
||||
}
|
||||
|
||||
req.OrderByList = make([]model.OrderBy, 0)
|
||||
for i := range request.OrderClauses {
|
||||
clause := request.OrderClauses[i]
|
||||
req.OrderByList = append(req.OrderByList, model.OrderBy{
|
||||
Field: clause.FieldID,
|
||||
Direction: toOrderDirection(clause.IsAsc),
|
||||
})
|
||||
}
|
||||
|
||||
if request.ConditionGroup != nil {
|
||||
req.Condition, req.SQLParams, err = buildComplexCondition(request.ConditionGroup)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
limit := request.Limit
|
||||
req.Limit = &limit
|
||||
|
||||
response, err := d.DomainSVC.ExecuteSQL(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return toNodeDateBaseResponse(response), nil
|
||||
}
|
||||
|
||||
func (d *databaseImpl) Update(ctx context.Context, request *model.UpdateRequest) (*model.Response, error) {
|
||||
|
||||
var (
|
||||
err error
|
||||
condition *model.ComplexCondition
|
||||
params []*model.SQLParamVal
|
||||
databaseInfoID = request.DatabaseInfoID
|
||||
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
|
||||
)
|
||||
|
||||
if request.IsDebugRun {
|
||||
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
req := &service.ExecuteSQLRequest{
|
||||
DatabaseID: databaseInfoID,
|
||||
OperateType: model.OperateType_Update,
|
||||
SQLParams: make([]*model.SQLParamVal, 0),
|
||||
TableType: tableType,
|
||||
UserID: request.UserID,
|
||||
ConnectorID: ptr.Of(request.ConnectorID),
|
||||
}
|
||||
|
||||
uid := ctxutil.GetUIDFromCtx(ctx)
|
||||
if uid != nil {
|
||||
req.UserID = conv.Int64ToStr(*uid)
|
||||
}
|
||||
req.UpsertRows, req.SQLParams, err = resolveUpsertRow(request.Fields)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if request.ConditionGroup != nil {
|
||||
condition, params, err = buildComplexCondition(request.ConditionGroup)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Condition = condition
|
||||
req.SQLParams = append(req.SQLParams, params...)
|
||||
}
|
||||
|
||||
response, err := d.DomainSVC.ExecuteSQL(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return toNodeDateBaseResponse(response), nil
|
||||
}
|
||||
|
||||
func (d *databaseImpl) Insert(ctx context.Context, request *model.InsertRequest) (*model.Response, error) {
|
||||
var (
|
||||
err error
|
||||
databaseInfoID = request.DatabaseInfoID
|
||||
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
|
||||
)
|
||||
|
||||
if request.IsDebugRun {
|
||||
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
req := &service.ExecuteSQLRequest{
|
||||
DatabaseID: databaseInfoID,
|
||||
OperateType: model.OperateType_Insert,
|
||||
TableType: tableType,
|
||||
UserID: request.UserID,
|
||||
ConnectorID: ptr.Of(request.ConnectorID),
|
||||
}
|
||||
|
||||
req.UpsertRows, req.SQLParams, err = resolveUpsertRow(request.Fields)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response, err := d.DomainSVC.ExecuteSQL(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return toNodeDateBaseResponse(response), nil
|
||||
}
|
||||
|
||||
func (d *databaseImpl) getDraftTableID(ctx context.Context, onlineID int64) (int64, error) {
|
||||
resp, err := d.DomainSVC.GetDraftDatabaseByOnlineID(ctx, &service.GetDraftDatabaseByOnlineIDRequest{OnlineID: onlineID})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return resp.Database.ID, nil
|
||||
}
|
||||
|
||||
func buildComplexCondition(conditionGroup *model.ConditionGroup) (*model.ComplexCondition, []*model.SQLParamVal, error) {
|
||||
condition := &model.ComplexCondition{}
|
||||
logic, err := toLogic(conditionGroup.Relation)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
condition.Logic = logic
|
||||
|
||||
params := make([]*model.SQLParamVal, 0)
|
||||
for i := range conditionGroup.Conditions {
|
||||
var (
|
||||
nCond = conditionGroup.Conditions[i]
|
||||
vals []*model.SQLParamVal
|
||||
dCond = &model.Condition{
|
||||
Left: nCond.Left,
|
||||
}
|
||||
)
|
||||
opt, err := toOperation(nCond.Operator)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
dCond.Operation = opt
|
||||
|
||||
if isNullOrNotNull(opt) {
|
||||
condition.Conditions = append(condition.Conditions, dCond)
|
||||
continue
|
||||
}
|
||||
dCond.Right, vals, err = resolveRightValue(opt, nCond.Right)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
condition.Conditions = append(condition.Conditions, dCond)
|
||||
|
||||
params = append(params, vals...)
|
||||
|
||||
}
|
||||
return condition, params, nil
|
||||
}
|
||||
|
||||
func toMapStringAny(m map[string]string) map[string]any {
|
||||
ret := make(map[string]any, len(m))
|
||||
for k, v := range m {
|
||||
ret[k] = v
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func toOperation(operator model.Operator) (model.Operation, error) {
|
||||
switch operator {
|
||||
case model.OperatorEqual:
|
||||
return model.Operation_EQUAL, nil
|
||||
case model.OperatorNotEqual:
|
||||
return model.Operation_NOT_EQUAL, nil
|
||||
case model.OperatorGreater:
|
||||
return model.Operation_GREATER_THAN, nil
|
||||
case model.OperatorGreaterOrEqual:
|
||||
return model.Operation_GREATER_EQUAL, nil
|
||||
case model.OperatorLesser:
|
||||
return model.Operation_LESS_THAN, nil
|
||||
case model.OperatorLesserOrEqual:
|
||||
return model.Operation_LESS_EQUAL, nil
|
||||
case model.OperatorIn:
|
||||
return model.Operation_IN, nil
|
||||
case model.OperatorNotIn:
|
||||
return model.Operation_NOT_IN, nil
|
||||
case model.OperatorIsNotNull:
|
||||
return model.Operation_IS_NOT_NULL, nil
|
||||
case model.OperatorIsNull:
|
||||
return model.Operation_IS_NULL, nil
|
||||
case model.OperatorLike:
|
||||
return model.Operation_LIKE, nil
|
||||
case model.OperatorNotLike:
|
||||
return model.Operation_NOT_LIKE, nil
|
||||
default:
|
||||
return model.Operation(0), fmt.Errorf("invalid operator %v", operator)
|
||||
}
|
||||
}
|
||||
|
||||
func resolveRightValue(operator model.Operation, right any) (string, []*model.SQLParamVal, error) {
|
||||
|
||||
if isInOrNotIn(operator) {
|
||||
var (
|
||||
vals = make([]*model.SQLParamVal, 0)
|
||||
anyVals = make([]any, 0)
|
||||
commas = make([]string, 0, len(anyVals))
|
||||
)
|
||||
|
||||
anyVals = right.([]any)
|
||||
for i := range anyVals {
|
||||
v := cast.ToString(anyVals[i])
|
||||
vals = append(vals, &model.SQLParamVal{ValueType: table.FieldItemType_Text, Value: &v})
|
||||
commas = append(commas, "?")
|
||||
}
|
||||
value := "(" + strings.Join(commas, ",") + ")"
|
||||
return value, vals, nil
|
||||
}
|
||||
|
||||
rightValue, err := cast.ToStringE(right)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if isLikeOrNotLike(operator) {
|
||||
var (
|
||||
value = "?"
|
||||
v = "%s" + rightValue + "%s"
|
||||
)
|
||||
return value, []*model.SQLParamVal{{ValueType: table.FieldItemType_Text, Value: &v}}, nil
|
||||
}
|
||||
|
||||
return "?", []*model.SQLParamVal{{ValueType: table.FieldItemType_Text, Value: &rightValue}}, nil
|
||||
}
|
||||
|
||||
func resolveUpsertRow(fields map[string]any) ([]*model.UpsertRow, []*model.SQLParamVal, error) {
|
||||
upsertRow := &model.UpsertRow{Records: make([]*model.Record, 0, len(fields))}
|
||||
params := make([]*model.SQLParamVal, 0)
|
||||
for key, value := range fields {
|
||||
val, err := cast.ToStringE(value)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
record := &model.Record{
|
||||
FieldId: key,
|
||||
FieldValue: "?",
|
||||
}
|
||||
upsertRow.Records = append(upsertRow.Records, record)
|
||||
params = append(params, &model.SQLParamVal{
|
||||
ValueType: table.FieldItemType_Text,
|
||||
Value: &val,
|
||||
})
|
||||
}
|
||||
return []*model.UpsertRow{upsertRow}, params, nil
|
||||
}
|
||||
|
||||
func isNullOrNotNull(opt model.Operation) bool {
|
||||
return opt == model.Operation_IS_NOT_NULL || opt == model.Operation_IS_NULL
|
||||
}
|
||||
|
||||
func isLikeOrNotLike(opt model.Operation) bool {
|
||||
return opt == model.Operation_LIKE || opt == model.Operation_NOT_LIKE
|
||||
}
|
||||
|
||||
func isInOrNotIn(opt model.Operation) bool {
|
||||
return opt == model.Operation_IN || opt == model.Operation_NOT_IN
|
||||
}
|
||||
|
||||
func toOrderDirection(isAsc bool) table.SortDirection {
|
||||
if isAsc {
|
||||
return table.SortDirection_ASC
|
||||
}
|
||||
return table.SortDirection_Desc
|
||||
}
|
||||
|
||||
func toLogic(relation model.ClauseRelation) (model.Logic, error) {
|
||||
switch relation {
|
||||
case model.ClauseRelationOR:
|
||||
return model.Logic_Or, nil
|
||||
case model.ClauseRelationAND:
|
||||
return model.Logic_And, nil
|
||||
default:
|
||||
return model.Logic(0), fmt.Errorf("invalid relation %v", relation)
|
||||
}
|
||||
}
|
||||
|
||||
func toNodeDateBaseResponse(response *service.ExecuteSQLResponse) *model.Response {
|
||||
objects := make([]model.Object, 0, len(response.Records))
|
||||
for i := range response.Records {
|
||||
objects = append(objects, response.Records[i])
|
||||
}
|
||||
return &model.Response{
|
||||
Objects: objects,
|
||||
RowNumber: response.RowsAffected,
|
||||
}
|
||||
}
|
||||
|
||||
@ -22,7 +22,7 @@ import (
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/appinfra"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossdatacopy"
|
||||
crossdatacopy "github.com/coze-dev/coze-studio/backend/crossdomain/contract/datacopy"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/datacopy"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/datacopy/service"
|
||||
)
|
||||
|
||||
@ -18,10 +18,18 @@ package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossknowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
|
||||
crossknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/contract/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/knowledge/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
)
|
||||
|
||||
var defaultSVC crossknowledge.Knowledge
|
||||
@ -57,3 +65,120 @@ func (i *impl) GetKnowledgeByID(ctx context.Context, request *model.GetKnowledge
|
||||
func (i *impl) MGetKnowledgeByID(ctx context.Context, request *model.MGetKnowledgeByIDRequest) (response *model.MGetKnowledgeByIDResponse, err error) {
|
||||
return i.DomainSVC.MGetKnowledgeByID(ctx, request)
|
||||
}
|
||||
|
||||
func (i *impl) Store(ctx context.Context, document *model.CreateDocumentRequest) (*model.CreateDocumentResponse, error) {
|
||||
var (
|
||||
ps *entity.ParsingStrategy
|
||||
cs = &entity.ChunkingStrategy{}
|
||||
)
|
||||
|
||||
if document.ParsingStrategy == nil {
|
||||
return nil, errors.New("document parsing strategy is required")
|
||||
}
|
||||
|
||||
if document.ChunkingStrategy == nil {
|
||||
return nil, errors.New("document chunking strategy is required")
|
||||
}
|
||||
|
||||
if document.ParsingStrategy.ParseMode == model.AccurateParseMode {
|
||||
ps = &entity.ParsingStrategy{}
|
||||
ps.ExtractImage = document.ParsingStrategy.ExtractImage
|
||||
ps.ExtractTable = document.ParsingStrategy.ExtractTable
|
||||
ps.ImageOCR = document.ParsingStrategy.ImageOCR
|
||||
}
|
||||
|
||||
chunkType, err := toChunkType(document.ChunkingStrategy.ChunkType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cs.ChunkType = chunkType
|
||||
cs.Separator = document.ChunkingStrategy.Separator
|
||||
cs.ChunkSize = document.ChunkingStrategy.ChunkSize
|
||||
cs.Overlap = document.ChunkingStrategy.Overlap
|
||||
|
||||
req := &entity.Document{
|
||||
Info: knowledge.Info{
|
||||
Name: document.FileName,
|
||||
},
|
||||
KnowledgeID: document.KnowledgeID,
|
||||
Type: knowledge.DocumentTypeText,
|
||||
URL: document.FileURL,
|
||||
Source: entity.DocumentSourceLocal,
|
||||
ParsingStrategy: ps,
|
||||
ChunkingStrategy: cs,
|
||||
FileExtension: document.FileExtension,
|
||||
}
|
||||
|
||||
uid := ctxutil.GetUIDFromCtx(ctx)
|
||||
if uid != nil {
|
||||
req.Info.CreatorID = *uid
|
||||
}
|
||||
|
||||
response, err := i.DomainSVC.CreateDocument(ctx, &service.CreateDocumentRequest{
|
||||
Documents: []*entity.Document{req},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
kCResponse := &model.CreateDocumentResponse{
|
||||
FileURL: document.FileURL,
|
||||
DocumentID: response.Documents[0].Info.ID,
|
||||
FileName: response.Documents[0].Info.Name,
|
||||
}
|
||||
|
||||
return kCResponse, nil
|
||||
}
|
||||
|
||||
func (i *impl) Delete(ctx context.Context, r *model.DeleteDocumentRequest) (*model.DeleteDocumentResponse, error) {
|
||||
docID, err := strconv.ParseInt(r.DocumentID, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid document id: %s", r.DocumentID)
|
||||
}
|
||||
|
||||
err = i.DomainSVC.DeleteDocument(ctx, &service.DeleteDocumentRequest{
|
||||
DocumentID: docID,
|
||||
})
|
||||
if err != nil {
|
||||
return &model.DeleteDocumentResponse{IsSuccess: false}, err
|
||||
}
|
||||
|
||||
return &model.DeleteDocumentResponse{IsSuccess: true}, nil
|
||||
}
|
||||
|
||||
func (i *impl) ListKnowledgeDetail(ctx context.Context, req *model.ListKnowledgeDetailRequest) (*model.ListKnowledgeDetailResponse, error) {
|
||||
response, err := i.DomainSVC.MGetKnowledgeByID(ctx, &service.MGetKnowledgeByIDRequest{
|
||||
KnowledgeIDs: req.KnowledgeIDs,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp := &model.ListKnowledgeDetailResponse{
|
||||
KnowledgeDetails: slices.Transform(response.Knowledge, func(a *knowledge.Knowledge) *model.KnowledgeDetail {
|
||||
return &model.KnowledgeDetail{
|
||||
ID: a.ID,
|
||||
Name: a.Name,
|
||||
Description: a.Description,
|
||||
IconURL: a.IconURL,
|
||||
FormatType: int64(a.Type),
|
||||
}
|
||||
}),
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func toChunkType(typ model.ChunkType) (parser.ChunkType, error) {
|
||||
switch typ {
|
||||
case model.ChunkTypeDefault:
|
||||
return parser.ChunkTypeDefault, nil
|
||||
case model.ChunkTypeCustom:
|
||||
return parser.ChunkTypeCustom, nil
|
||||
case model.ChunkTypeLeveled:
|
||||
return parser.ChunkTypeLeveled, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown chunk type: %v", typ)
|
||||
}
|
||||
}
|
||||
|
||||
@ -20,7 +20,7 @@ import (
|
||||
"context"
|
||||
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossmessage"
|
||||
crossmessage "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
|
||||
message "github.com/coze-dev/coze-studio/backend/domain/conversation/message/service"
|
||||
)
|
||||
|
||||
|
||||
@ -14,37 +14,38 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package model
|
||||
package modelmgr
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
model2 "github.com/cloudwego/eino/components/model"
|
||||
eino "github.com/cloudwego/eino/components/model"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
|
||||
crossmodelmgr "github.com/coze-dev/coze-studio/backend/crossdomain/contract/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||
chatmodel2 "github.com/coze-dev/coze-studio/backend/infra/impl/chatmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
type ModelManager struct {
|
||||
type modelManager struct {
|
||||
modelMgr modelmgr.Manager
|
||||
factory chatmodel.Factory
|
||||
}
|
||||
|
||||
func NewModelManager(m modelmgr.Manager, f chatmodel.Factory) *ModelManager {
|
||||
func InitDomainService(m modelmgr.Manager, f chatmodel.Factory) crossmodelmgr.Manager {
|
||||
if f == nil {
|
||||
f = chatmodel2.NewDefaultFactory()
|
||||
}
|
||||
return &ModelManager{
|
||||
return &modelManager{
|
||||
modelMgr: m,
|
||||
factory: f,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *ModelManager) GetModel(ctx context.Context, params *model.LLMParams) (model2.BaseChatModel, *modelmgr.Model, error) {
|
||||
func (m *modelManager) GetModel(ctx context.Context, params *model.LLMParams) (eino.BaseChatModel, *modelmgr.Model, error) {
|
||||
modelID := params.ModelType
|
||||
models, err := m.modelMgr.MGetModelByID(ctx, &modelmgr.MGetModelRequest{
|
||||
IDs: []int64{modelID},
|
||||
@ -18,23 +18,46 @@ package plugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/getkin/kin-openapi/openapi3"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossplugin"
|
||||
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common"
|
||||
workflow3 "github.com/coze-dev/coze-studio/backend/api/model/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/pluginutil"
|
||||
crossplugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/service"
|
||||
plugin "github.com/coze-dev/coze-studio/backend/domain/plugin/service"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
entity2 "github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
var defaultSVC crossplugin.PluginService
|
||||
|
||||
type impl struct {
|
||||
DomainSVC plugin.PluginService
|
||||
tos storage.Storage
|
||||
}
|
||||
|
||||
func InitDomainService(c plugin.PluginService) crossplugin.PluginService {
|
||||
func InitDomainService(c plugin.PluginService, tos storage.Storage) crossplugin.PluginService {
|
||||
defaultSVC = &impl{
|
||||
DomainSVC: c,
|
||||
tos: tos,
|
||||
}
|
||||
|
||||
return defaultSVC
|
||||
@ -105,3 +128,479 @@ func (s *impl) GetAPPAllPlugins(ctx context.Context, appID int64) (plugins []*mo
|
||||
|
||||
return plugins, nil
|
||||
}
|
||||
|
||||
type pluginInfo struct {
|
||||
*entity.PluginInfo
|
||||
LatestVersion *string
|
||||
}
|
||||
|
||||
func (s *impl) getPluginsWithTools(ctx context.Context, pluginEntity *model.PluginEntity, toolIDs []int64, isDraft bool) (
|
||||
_ *pluginInfo, toolsInfo []*entity.ToolInfo, err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = vo.WrapIfNeeded(errno.ErrPluginAPIErr, err)
|
||||
}
|
||||
}()
|
||||
|
||||
var pluginsInfo []*entity.PluginInfo
|
||||
var latestPluginInfo *entity.PluginInfo
|
||||
pluginID := pluginEntity.PluginID
|
||||
if isDraft {
|
||||
plugins, err := s.DomainSVC.MGetDraftPlugins(ctx, []int64{pluginID})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
pluginsInfo = plugins
|
||||
} else if pluginEntity.PluginVersion == nil || (pluginEntity.PluginVersion != nil && *pluginEntity.PluginVersion == "") {
|
||||
plugins, err := s.DomainSVC.MGetOnlinePlugins(ctx, []int64{pluginID})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
pluginsInfo = plugins
|
||||
|
||||
} else {
|
||||
plugins, err := s.DomainSVC.MGetVersionPlugins(ctx, []entity.VersionPlugin{
|
||||
{PluginID: pluginID, Version: *pluginEntity.PluginVersion},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
pluginsInfo = plugins
|
||||
|
||||
onlinePlugins, err := s.DomainSVC.MGetOnlinePlugins(ctx, []int64{pluginID})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
for _, pi := range onlinePlugins {
|
||||
if pi.ID == pluginID {
|
||||
latestPluginInfo = pi
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var pInfo *entity.PluginInfo
|
||||
for _, p := range pluginsInfo {
|
||||
if p.ID == pluginID {
|
||||
pInfo = p
|
||||
break
|
||||
}
|
||||
}
|
||||
if pInfo == nil {
|
||||
return nil, nil, vo.NewError(errno.ErrPluginIDNotFound, errorx.KV("id", strconv.FormatInt(pluginID, 10)))
|
||||
}
|
||||
|
||||
if isDraft {
|
||||
tools, err := s.DomainSVC.MGetDraftTools(ctx, toolIDs)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
toolsInfo = tools
|
||||
} else if pluginEntity.PluginVersion == nil || (pluginEntity.PluginVersion != nil && *pluginEntity.PluginVersion == "") {
|
||||
tools, err := s.DomainSVC.MGetOnlineTools(ctx, toolIDs)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
toolsInfo = tools
|
||||
} else {
|
||||
eVersionTools := slices.Transform(toolIDs, func(tid int64) entity.VersionTool {
|
||||
return entity.VersionTool{
|
||||
ToolID: tid,
|
||||
Version: *pluginEntity.PluginVersion,
|
||||
}
|
||||
})
|
||||
tools, err := s.DomainSVC.MGetVersionTools(ctx, eVersionTools)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
toolsInfo = tools
|
||||
}
|
||||
|
||||
if latestPluginInfo != nil {
|
||||
return &pluginInfo{PluginInfo: pInfo, LatestVersion: latestPluginInfo.Version}, toolsInfo, nil
|
||||
}
|
||||
|
||||
return &pluginInfo{PluginInfo: pInfo}, toolsInfo, nil
|
||||
}
|
||||
|
||||
func (s *impl) GetPluginToolsInfo(ctx context.Context, req *model.ToolsInfoRequest) (
|
||||
_ *model.ToolsInfoResponse, err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = vo.WrapIfNeeded(errno.ErrPluginAPIErr, err)
|
||||
}
|
||||
}()
|
||||
|
||||
var toolsInfo []*entity.ToolInfo
|
||||
isDraft := req.IsDraft || (req.PluginEntity.PluginVersion != nil && *req.PluginEntity.PluginVersion == "0")
|
||||
pInfo, toolsInfo, err := s.getPluginsWithTools(ctx, &model.PluginEntity{PluginID: req.PluginEntity.PluginID, PluginVersion: req.PluginEntity.PluginVersion}, req.ToolIDs, isDraft)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
url, err := s.tos.GetObjectUrl(ctx, pInfo.GetIconURI())
|
||||
if err != nil {
|
||||
return nil, vo.WrapIfNeeded(errno.ErrTOSError, err)
|
||||
}
|
||||
|
||||
response := &model.ToolsInfoResponse{
|
||||
PluginID: pInfo.ID,
|
||||
SpaceID: pInfo.SpaceID,
|
||||
Version: pInfo.GetVersion(),
|
||||
PluginName: pInfo.GetName(),
|
||||
Description: pInfo.GetDesc(),
|
||||
IconURL: url,
|
||||
PluginType: int64(pInfo.PluginType),
|
||||
ToolInfoList: make(map[int64]model.ToolInfoW),
|
||||
LatestVersion: pInfo.LatestVersion,
|
||||
IsOfficial: pInfo.IsOfficial(),
|
||||
AppID: pInfo.GetAPPID(),
|
||||
}
|
||||
|
||||
for _, tf := range toolsInfo {
|
||||
inputs, err := tf.ToReqAPIParameter()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
outputs, err := tf.ToRespAPIParameter()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toolExample := pInfo.GetToolExample(ctx, tf.GetName())
|
||||
|
||||
var (
|
||||
requestExample string
|
||||
responseExample string
|
||||
)
|
||||
if toolExample != nil {
|
||||
requestExample = toolExample.RequestExample
|
||||
responseExample = toolExample.ResponseExample
|
||||
}
|
||||
|
||||
response.ToolInfoList[tf.ID] = model.ToolInfoW{
|
||||
ToolID: tf.ID,
|
||||
ToolName: tf.GetName(),
|
||||
Inputs: slices.Transform(inputs, toWorkflowAPIParameter),
|
||||
Outputs: slices.Transform(outputs, toWorkflowAPIParameter),
|
||||
Description: tf.GetDesc(),
|
||||
DebugExample: &model.DebugExample{
|
||||
ReqExample: requestExample,
|
||||
RespExample: responseExample,
|
||||
},
|
||||
}
|
||||
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *impl) GetPluginInvokableTools(ctx context.Context, req *model.ToolsInvokableRequest) (
|
||||
_ map[int64]crossplugin.InvokableTool, err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = vo.WrapIfNeeded(errno.ErrPluginAPIErr, err)
|
||||
}
|
||||
}()
|
||||
|
||||
var toolsInfo []*entity.ToolInfo
|
||||
isDraft := req.IsDraft || (req.PluginEntity.PluginVersion != nil && *req.PluginEntity.PluginVersion == "0")
|
||||
pInfo, toolsInfo, err := s.getPluginsWithTools(ctx, &model.PluginEntity{
|
||||
PluginID: req.PluginEntity.PluginID,
|
||||
PluginVersion: req.PluginEntity.PluginVersion,
|
||||
}, maps.Keys(req.ToolsInvokableInfo), isDraft)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := map[int64]crossplugin.InvokableTool{}
|
||||
for _, tf := range toolsInfo {
|
||||
tl := &pluginInvokeTool{
|
||||
pluginEntity: model.PluginEntity{
|
||||
PluginID: pInfo.ID,
|
||||
PluginVersion: pInfo.Version,
|
||||
},
|
||||
client: s.DomainSVC,
|
||||
toolInfo: tf,
|
||||
IsDraft: isDraft,
|
||||
}
|
||||
|
||||
if r, ok := req.ToolsInvokableInfo[tf.ID]; ok && (r.RequestAPIParametersConfig != nil && r.ResponseAPIParametersConfig != nil) {
|
||||
reqPluginCommonAPIParameters := slices.Transform(r.RequestAPIParametersConfig, toPluginCommonAPIParameter)
|
||||
respPluginCommonAPIParameters := slices.Transform(r.ResponseAPIParametersConfig, toPluginCommonAPIParameter)
|
||||
|
||||
tl.toolOperation, err = pluginutil.APIParamsToOpenapiOperation(reqPluginCommonAPIParameters, respPluginCommonAPIParameters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tl.toolOperation.OperationID = tf.Operation.OperationID
|
||||
tl.toolOperation.Summary = tf.Operation.Summary
|
||||
}
|
||||
|
||||
result[tf.ID] = tl
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *impl) ExecutePlugin(ctx context.Context, input map[string]any, pe *model.PluginEntity,
|
||||
toolID int64, cfg workflowModel.ExecuteConfig) (map[string]any, error) {
|
||||
args, err := sonic.MarshalString(input)
|
||||
if err != nil {
|
||||
return nil, vo.WrapError(errno.ErrSerializationDeserializationFail, err)
|
||||
}
|
||||
|
||||
var uID string
|
||||
if cfg.AgentID != nil {
|
||||
uID = cfg.ConnectorUID
|
||||
} else {
|
||||
uID = conv.Int64ToStr(cfg.Operator)
|
||||
}
|
||||
|
||||
req := &service.ExecuteToolRequest{
|
||||
UserID: uID,
|
||||
PluginID: pe.PluginID,
|
||||
ToolID: toolID,
|
||||
ExecScene: model.ExecSceneOfWorkflow,
|
||||
ArgumentsInJson: args,
|
||||
ExecDraftTool: pe.PluginVersion == nil || *pe.PluginVersion == "0",
|
||||
}
|
||||
execOpts := []entity.ExecuteToolOpt{
|
||||
model.WithInvalidRespProcessStrategy(model.InvalidResponseProcessStrategyOfReturnDefault),
|
||||
}
|
||||
|
||||
if pe.PluginVersion != nil {
|
||||
execOpts = append(execOpts, model.WithToolVersion(*pe.PluginVersion))
|
||||
}
|
||||
|
||||
r, err := s.DomainSVC.ExecuteTool(ctx, req, execOpts...)
|
||||
if err != nil {
|
||||
if extra, ok := compose.IsInterruptRerunError(err); ok {
|
||||
pluginTIE, ok := extra.(*model.ToolInterruptEvent)
|
||||
if !ok {
|
||||
return nil, vo.WrapError(errno.ErrPluginAPIErr, fmt.Errorf("expects ToolInterruptEvent, got %T", extra))
|
||||
}
|
||||
|
||||
var eventType workflow3.EventType
|
||||
switch pluginTIE.Event {
|
||||
case model.InterruptEventTypeOfToolNeedOAuth:
|
||||
eventType = workflow3.EventType_WorkflowOauthPlugin
|
||||
default:
|
||||
return nil, vo.WrapError(errno.ErrPluginAPIErr,
|
||||
fmt.Errorf("unsupported interrupt event type: %s", pluginTIE.Event))
|
||||
}
|
||||
|
||||
id, err := workflow.GetRepository().GenID(ctx)
|
||||
if err != nil {
|
||||
return nil, vo.WrapError(errno.ErrIDGenError, err)
|
||||
}
|
||||
|
||||
ie := &entity2.InterruptEvent{
|
||||
ID: id,
|
||||
InterruptData: pluginTIE.ToolNeedOAuth.Message,
|
||||
EventType: eventType,
|
||||
}
|
||||
|
||||
// temporarily replace interrupt with real error, until frontend can handle plugin oauth interrupt
|
||||
interruptData := ie.InterruptData
|
||||
return nil, vo.NewError(errno.ErrAuthorizationRequired, errorx.KV("extra", interruptData))
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var output map[string]any
|
||||
err = sonic.UnmarshalString(r.TrimmedResp, &output)
|
||||
if err != nil {
|
||||
return nil, vo.WrapError(errno.ErrSerializationDeserializationFail, err)
|
||||
}
|
||||
|
||||
return output, nil
|
||||
}
|
||||
|
||||
type pluginInvokeTool struct {
|
||||
pluginEntity model.PluginEntity
|
||||
client service.PluginService
|
||||
toolInfo *entity.ToolInfo
|
||||
toolOperation *openapi3.Operation
|
||||
IsDraft bool
|
||||
}
|
||||
|
||||
func (p *pluginInvokeTool) Info(ctx context.Context) (_ *schema.ToolInfo, err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = vo.WrapIfNeeded(errno.ErrPluginAPIErr, err)
|
||||
}
|
||||
}()
|
||||
|
||||
var parameterInfo map[string]*schema.ParameterInfo
|
||||
if p.toolOperation != nil {
|
||||
parameterInfo, err = model.NewOpenapi3Operation(p.toolOperation).ToEinoSchemaParameterInfo(ctx)
|
||||
} else {
|
||||
parameterInfo, err = p.toolInfo.Operation.ToEinoSchemaParameterInfo(ctx)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &schema.ToolInfo{
|
||||
Name: p.toolInfo.GetName(),
|
||||
Desc: p.toolInfo.GetDesc(),
|
||||
ParamsOneOf: schema.NewParamsOneOfByParams(parameterInfo),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *pluginInvokeTool) PluginInvoke(ctx context.Context, argumentsInJSON string, cfg workflowModel.ExecuteConfig) (string, error) {
|
||||
req := &service.ExecuteToolRequest{
|
||||
UserID: conv.Int64ToStr(cfg.Operator),
|
||||
PluginID: p.pluginEntity.PluginID,
|
||||
ToolID: p.toolInfo.ID,
|
||||
ExecScene: model.ExecSceneOfWorkflow,
|
||||
ArgumentsInJson: argumentsInJSON,
|
||||
ExecDraftTool: p.IsDraft,
|
||||
}
|
||||
execOpts := []entity.ExecuteToolOpt{
|
||||
model.WithInvalidRespProcessStrategy(model.InvalidResponseProcessStrategyOfReturnDefault),
|
||||
}
|
||||
|
||||
if p.pluginEntity.PluginVersion != nil {
|
||||
execOpts = append(execOpts, model.WithToolVersion(*p.pluginEntity.PluginVersion))
|
||||
}
|
||||
|
||||
if p.toolOperation != nil {
|
||||
execOpts = append(execOpts, model.WithOpenapiOperation(model.NewOpenapi3Operation(p.toolOperation)))
|
||||
}
|
||||
|
||||
r, err := p.client.ExecuteTool(ctx, req, execOpts...)
|
||||
if err != nil {
|
||||
if extra, ok := compose.IsInterruptRerunError(err); ok {
|
||||
pluginTIE, ok := extra.(*model.ToolInterruptEvent)
|
||||
if !ok {
|
||||
return "", vo.WrapError(errno.ErrPluginAPIErr, fmt.Errorf("expects ToolInterruptEvent, got %T", extra))
|
||||
}
|
||||
|
||||
var eventType workflow3.EventType
|
||||
switch pluginTIE.Event {
|
||||
case model.InterruptEventTypeOfToolNeedOAuth:
|
||||
eventType = workflow3.EventType_WorkflowOauthPlugin
|
||||
default:
|
||||
return "", vo.WrapError(errno.ErrPluginAPIErr,
|
||||
fmt.Errorf("unsupported interrupt event type: %s", pluginTIE.Event))
|
||||
}
|
||||
|
||||
id, err := workflow.GetRepository().GenID(ctx)
|
||||
if err != nil {
|
||||
return "", vo.WrapError(errno.ErrIDGenError, err)
|
||||
}
|
||||
|
||||
ie := &entity2.InterruptEvent{
|
||||
ID: id,
|
||||
InterruptData: pluginTIE.ToolNeedOAuth.Message,
|
||||
EventType: eventType,
|
||||
}
|
||||
|
||||
tie := &entity2.ToolInterruptEvent{
|
||||
ToolCallID: compose.GetToolCallID(ctx),
|
||||
ToolName: p.toolInfo.GetName(),
|
||||
InterruptEvent: ie,
|
||||
}
|
||||
|
||||
// temporarily replace interrupt with real error, until frontend can handle plugin oauth interrupt
|
||||
_ = tie
|
||||
interruptData := ie.InterruptData
|
||||
return "", vo.NewError(errno.ErrAuthorizationRequired, errorx.KV("extra", interruptData))
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
return r.TrimmedResp, nil
|
||||
}
|
||||
|
||||
func toPluginCommonAPIParameter(parameter *workflow3.APIParameter) *common.APIParameter {
|
||||
if parameter == nil {
|
||||
return nil
|
||||
}
|
||||
p := &common.APIParameter{
|
||||
ID: parameter.ID,
|
||||
Name: parameter.Name,
|
||||
Desc: parameter.Desc,
|
||||
Type: common.ParameterType(parameter.Type),
|
||||
Location: common.ParameterLocation(parameter.Location),
|
||||
IsRequired: parameter.IsRequired,
|
||||
GlobalDefault: parameter.GlobalDefault,
|
||||
GlobalDisable: parameter.GlobalDisable,
|
||||
LocalDefault: parameter.LocalDefault,
|
||||
LocalDisable: parameter.LocalDisable,
|
||||
VariableRef: parameter.VariableRef,
|
||||
}
|
||||
if parameter.SubType != nil {
|
||||
p.SubType = ptr.Of(common.ParameterType(*parameter.SubType))
|
||||
}
|
||||
|
||||
if parameter.DefaultParamSource != nil {
|
||||
p.DefaultParamSource = ptr.Of(common.DefaultParamSource(*parameter.DefaultParamSource))
|
||||
}
|
||||
if parameter.AssistType != nil {
|
||||
p.AssistType = ptr.Of(common.AssistParameterType(*parameter.AssistType))
|
||||
}
|
||||
|
||||
if len(parameter.SubParameters) > 0 {
|
||||
p.SubParameters = make([]*common.APIParameter, 0, len(parameter.SubParameters))
|
||||
for _, subParam := range parameter.SubParameters {
|
||||
p.SubParameters = append(p.SubParameters, toPluginCommonAPIParameter(subParam))
|
||||
}
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
func toWorkflowAPIParameter(parameter *common.APIParameter) *workflow3.APIParameter {
|
||||
if parameter == nil {
|
||||
return nil
|
||||
}
|
||||
p := &workflow3.APIParameter{
|
||||
ID: parameter.ID,
|
||||
Name: parameter.Name,
|
||||
Desc: parameter.Desc,
|
||||
Type: workflow3.ParameterType(parameter.Type),
|
||||
Location: workflow3.ParameterLocation(parameter.Location),
|
||||
IsRequired: parameter.IsRequired,
|
||||
GlobalDefault: parameter.GlobalDefault,
|
||||
GlobalDisable: parameter.GlobalDisable,
|
||||
LocalDefault: parameter.LocalDefault,
|
||||
LocalDisable: parameter.LocalDisable,
|
||||
VariableRef: parameter.VariableRef,
|
||||
}
|
||||
if parameter.SubType != nil {
|
||||
p.SubType = ptr.Of(workflow3.ParameterType(*parameter.SubType))
|
||||
}
|
||||
|
||||
if parameter.DefaultParamSource != nil {
|
||||
p.DefaultParamSource = ptr.Of(workflow3.DefaultParamSource(*parameter.DefaultParamSource))
|
||||
}
|
||||
if parameter.AssistType != nil {
|
||||
p.AssistType = ptr.Of(workflow3.AssistParameterType(*parameter.AssistType))
|
||||
}
|
||||
|
||||
// Check if it's an array that needs unwrapping.
|
||||
if parameter.Type == common.ParameterType_Array && len(parameter.SubParameters) == 1 && parameter.SubParameters[0].Name == "[Array Item]" {
|
||||
arrayItem := parameter.SubParameters[0]
|
||||
p.SubType = ptr.Of(workflow3.ParameterType(arrayItem.Type))
|
||||
// If the "[Array Item]" is an object, its sub-parameters become the array's sub-parameters.
|
||||
if arrayItem.Type == common.ParameterType_Object {
|
||||
p.SubParameters = make([]*workflow3.APIParameter, 0, len(arrayItem.SubParameters))
|
||||
for _, subParam := range arrayItem.SubParameters {
|
||||
p.SubParameters = append(p.SubParameters, toWorkflowAPIParameter(subParam))
|
||||
}
|
||||
} else {
|
||||
// The array's SubType is the Type of the "[Array Item]".
|
||||
p.SubParameters = make([]*workflow3.APIParameter, 0, 1)
|
||||
p.SubParameters = append(p.SubParameters, toWorkflowAPIParameter(arrayItem))
|
||||
p.SubParameters[0].Name = "" // Remove the "[Array Item]" name.
|
||||
}
|
||||
} else if len(parameter.SubParameters) > 0 { // A simple object or a non-wrapped array.
|
||||
p.SubParameters = make([]*workflow3.APIParameter, 0, len(parameter.SubParameters))
|
||||
for _, subParam := range parameter.SubParameters {
|
||||
p.SubParameters = append(p.SubParameters, toWorkflowAPIParameter(subParam))
|
||||
}
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
@ -20,7 +20,7 @@ import (
|
||||
"context"
|
||||
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/search"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crosssearch"
|
||||
crosssearch "github.com/coze-dev/coze-studio/backend/crossdomain/contract/search"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/search/service"
|
||||
)
|
||||
|
||||
|
||||
@ -25,7 +25,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossagent"
|
||||
crossagent "github.com/coze-dev/coze-studio/backend/crossdomain/contract/agent"
|
||||
singleagent "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/service"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
|
||||
|
||||
@ -21,8 +21,8 @@ import (
|
||||
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/variables"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/data/variable/kvmemory"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossvariables"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/data/variable/project_memory"
|
||||
crossvariables "github.com/coze-dev/coze-studio/backend/crossdomain/contract/variables"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/memory/variables/entity"
|
||||
variables "github.com/coze-dev/coze-studio/backend/domain/memory/variables/service"
|
||||
)
|
||||
@ -65,3 +65,16 @@ func (s *impl) DecryptSysUUIDKey(ctx context.Context, encryptSysUUIDKey string)
|
||||
ConnectorID: m.ConnectorID,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *impl) GetVariableChannelInstance(ctx context.Context, e *model.UserVariableMeta, keywords []string, varChannel *project_memory.VariableChannel) ([]*kvmemory.KVItem, error) {
|
||||
m := entity.NewUserVariableMeta(e)
|
||||
return s.DomainSVC.GetVariableChannelInstance(ctx, m, keywords, varChannel)
|
||||
}
|
||||
|
||||
func (s *impl) GetProjectVariablesMeta(ctx context.Context, projectID string, version string) (*entity.VariablesMeta, error) {
|
||||
return s.DomainSVC.GetProjectVariablesMeta(ctx, projectID, version)
|
||||
}
|
||||
|
||||
func (s *impl) GetAgentVariableMeta(ctx context.Context, agentID int64, version string) (*entity.VariablesMeta, error) {
|
||||
return s.DomainSVC.GetAgentVariableMeta(ctx, agentID, version)
|
||||
}
|
||||
|
||||
@ -23,7 +23,8 @@ import (
|
||||
einoCompose "github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossworkflow"
|
||||
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
|
||||
crossworkflow "github.com/coze-dev/coze-studio/backend/crossdomain/contract/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
workflowEntity "github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
@ -50,16 +51,6 @@ func (i *impl) WorkflowAsModelTool(ctx context.Context, policies []*vo.GetPolicy
|
||||
return i.DomainSVC.WorkflowAsModelTool(ctx, policies)
|
||||
}
|
||||
|
||||
func (i *impl) PublishWorkflow(ctx context.Context, info *vo.PublishPolicy) (err error) {
|
||||
return i.DomainSVC.Publish(ctx, info)
|
||||
}
|
||||
|
||||
func (i *impl) DeleteWorkflow(ctx context.Context, id int64) error {
|
||||
return i.DomainSVC.Delete(ctx, &vo.DeletePolicy{
|
||||
ID: ptr.Of(id),
|
||||
})
|
||||
}
|
||||
|
||||
func (i *impl) ReleaseApplicationWorkflows(ctx context.Context, appID int64, config *vo.ReleaseWorkflowConfig) ([]*vo.ValidateIssue, error) {
|
||||
return i.DomainSVC.ReleaseApplicationWorkflows(ctx, appID, config)
|
||||
}
|
||||
@ -67,15 +58,15 @@ func (i *impl) ReleaseApplicationWorkflows(ctx context.Context, appID int64, con
|
||||
func (i *impl) WithResumeToolWorkflow(resumingEvent *workflowEntity.ToolInterruptEvent, resumeData string, allInterruptEvents map[string]*workflowEntity.ToolInterruptEvent) einoCompose.Option {
|
||||
return i.DomainSVC.WithResumeToolWorkflow(resumingEvent, resumeData, allInterruptEvents)
|
||||
}
|
||||
func (i *impl) SyncExecuteWorkflow(ctx context.Context, config vo.ExecuteConfig, input map[string]any) (*workflowEntity.WorkflowExecution, vo.TerminatePlan, error) {
|
||||
func (i *impl) SyncExecuteWorkflow(ctx context.Context, config workflowModel.ExecuteConfig, input map[string]any) (*workflowEntity.WorkflowExecution, vo.TerminatePlan, error) {
|
||||
return i.DomainSVC.SyncExecute(ctx, config, input)
|
||||
}
|
||||
|
||||
func (i *impl) WithExecuteConfig(cfg vo.ExecuteConfig) einoCompose.Option {
|
||||
func (i *impl) WithExecuteConfig(cfg workflowModel.ExecuteConfig) einoCompose.Option {
|
||||
return i.DomainSVC.WithExecuteConfig(cfg)
|
||||
}
|
||||
|
||||
func (i *impl) WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message]) {
|
||||
func (i *impl) WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message], func()) {
|
||||
return i.DomainSVC.WithMessagePipe()
|
||||
}
|
||||
|
||||
|
||||
@ -1,447 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cast"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/data/database/table"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
|
||||
nodedatabase "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
|
||||
)
|
||||
|
||||
type DatabaseRepository struct {
|
||||
client service.Database
|
||||
}
|
||||
|
||||
func NewDatabaseRepository(client service.Database) *DatabaseRepository {
|
||||
return &DatabaseRepository{
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DatabaseRepository) Execute(ctx context.Context, request *nodedatabase.CustomSQLRequest) (*nodedatabase.Response, error) {
|
||||
var (
|
||||
err error
|
||||
databaseInfoID = request.DatabaseInfoID
|
||||
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
|
||||
)
|
||||
|
||||
if request.IsDebugRun {
|
||||
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
req := &service.ExecuteSQLRequest{
|
||||
DatabaseID: databaseInfoID,
|
||||
OperateType: database.OperateType_Custom,
|
||||
SQL: &request.SQL,
|
||||
TableType: tableType,
|
||||
UserID: request.UserID,
|
||||
ConnectorID: ptr.Of(request.ConnectorID),
|
||||
}
|
||||
|
||||
req.SQLParams = make([]*database.SQLParamVal, 0, len(request.Params))
|
||||
for i := range request.Params {
|
||||
param := request.Params[i]
|
||||
req.SQLParams = append(req.SQLParams, &database.SQLParamVal{
|
||||
ValueType: table.FieldItemType_Text,
|
||||
Value: ¶m.Value,
|
||||
ISNull: param.IsNull,
|
||||
})
|
||||
}
|
||||
response, err := d.client.ExecuteSQL(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// if rows affected is nil use 0 instead
|
||||
if response.RowsAffected == nil {
|
||||
response.RowsAffected = ptr.Of(int64(0))
|
||||
}
|
||||
return toNodeDateBaseResponse(response), nil
|
||||
}
|
||||
|
||||
func (d *DatabaseRepository) Delete(ctx context.Context, request *nodedatabase.DeleteRequest) (*nodedatabase.Response, error) {
|
||||
var (
|
||||
err error
|
||||
databaseInfoID = request.DatabaseInfoID
|
||||
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
|
||||
)
|
||||
|
||||
if request.IsDebugRun {
|
||||
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
req := &service.ExecuteSQLRequest{
|
||||
DatabaseID: databaseInfoID,
|
||||
OperateType: database.OperateType_Delete,
|
||||
TableType: tableType,
|
||||
UserID: request.UserID,
|
||||
ConnectorID: ptr.Of(request.ConnectorID),
|
||||
}
|
||||
|
||||
if request.ConditionGroup != nil {
|
||||
req.Condition, req.SQLParams, err = buildComplexCondition(request.ConditionGroup)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
response, err := d.client.ExecuteSQL(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return toNodeDateBaseResponse(response), nil
|
||||
}
|
||||
|
||||
func (d *DatabaseRepository) Query(ctx context.Context, request *nodedatabase.QueryRequest) (*nodedatabase.Response, error) {
|
||||
var (
|
||||
err error
|
||||
databaseInfoID = request.DatabaseInfoID
|
||||
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
|
||||
)
|
||||
|
||||
if request.IsDebugRun {
|
||||
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
req := &service.ExecuteSQLRequest{
|
||||
DatabaseID: databaseInfoID,
|
||||
OperateType: database.OperateType_Select,
|
||||
TableType: tableType,
|
||||
UserID: request.UserID,
|
||||
ConnectorID: ptr.Of(request.ConnectorID),
|
||||
}
|
||||
|
||||
req.SelectFieldList = &database.SelectFieldList{FieldID: make([]string, 0, len(request.SelectFields))}
|
||||
for i := range request.SelectFields {
|
||||
req.SelectFieldList.FieldID = append(req.SelectFieldList.FieldID, request.SelectFields[i])
|
||||
}
|
||||
|
||||
req.OrderByList = make([]database.OrderBy, 0)
|
||||
for i := range request.OrderClauses {
|
||||
clause := request.OrderClauses[i]
|
||||
req.OrderByList = append(req.OrderByList, database.OrderBy{
|
||||
Field: clause.FieldID,
|
||||
Direction: toOrderDirection(clause.IsAsc),
|
||||
})
|
||||
}
|
||||
|
||||
if request.ConditionGroup != nil {
|
||||
req.Condition, req.SQLParams, err = buildComplexCondition(request.ConditionGroup)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
limit := request.Limit
|
||||
req.Limit = &limit
|
||||
|
||||
response, err := d.client.ExecuteSQL(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return toNodeDateBaseResponse(response), nil
|
||||
}
|
||||
|
||||
// Update modifies the rows matched by request.ConditionGroup, setting the
// columns in request.Fields. Debug runs target the draft table, otherwise
// the online table. When a user ID is present on the context it overrides
// request.UserID.
//
// NOTE: SQL parameter ordering matters — the upsert (SET) values must come
// first, followed by the condition (WHERE) values, matching placeholder
// order in the generated statement.
func (d *DatabaseRepository) Update(ctx context.Context, request *nodedatabase.UpdateRequest) (*nodedatabase.Response, error) {

	var (
		err            error
		condition      *database.ComplexCondition
		params         []*database.SQLParamVal
		databaseInfoID = request.DatabaseInfoID
		tableType      = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
	)

	// Debug runs operate on the draft copy of the table.
	if request.IsDebugRun {
		databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
		if err != nil {
			return nil, err
		}
	}

	req := &service.ExecuteSQLRequest{
		DatabaseID:  databaseInfoID,
		OperateType: database.OperateType_Update,
		SQLParams:   make([]*database.SQLParamVal, 0),
		TableType:   tableType,
		UserID:      request.UserID,
		ConnectorID: ptr.Of(request.ConnectorID),
	}

	// Prefer the authenticated user from the context over the request value.
	uid := ctxutil.GetUIDFromCtx(ctx)
	if uid != nil {
		req.UserID = conv.Int64ToStr(*uid)
	}
	// SET-clause rows and their parameter values (replaces the placeholder
	// SQLParams initialized above).
	req.UpsertRows, req.SQLParams, err = resolveUpsertRow(request.Fields)
	if err != nil {
		return nil, err
	}

	if request.ConditionGroup != nil {
		condition, params, err = buildComplexCondition(request.ConditionGroup)
		if err != nil {
			return nil, err
		}

		// WHERE-clause parameters are appended after the SET parameters.
		req.Condition = condition
		req.SQLParams = append(req.SQLParams, params...)
	}

	response, err := d.client.ExecuteSQL(ctx, req)
	if err != nil {
		return nil, err
	}
	return toNodeDateBaseResponse(response), nil
}
|
||||
|
||||
func (d *DatabaseRepository) Insert(ctx context.Context, request *nodedatabase.InsertRequest) (*nodedatabase.Response, error) {
|
||||
var (
|
||||
err error
|
||||
databaseInfoID = request.DatabaseInfoID
|
||||
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
|
||||
)
|
||||
|
||||
if request.IsDebugRun {
|
||||
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
req := &service.ExecuteSQLRequest{
|
||||
DatabaseID: databaseInfoID,
|
||||
OperateType: database.OperateType_Insert,
|
||||
TableType: tableType,
|
||||
UserID: request.UserID,
|
||||
ConnectorID: ptr.Of(request.ConnectorID),
|
||||
}
|
||||
|
||||
req.UpsertRows, req.SQLParams, err = resolveUpsertRow(request.Fields)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response, err := d.client.ExecuteSQL(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return toNodeDateBaseResponse(response), nil
|
||||
}
|
||||
|
||||
func (d *DatabaseRepository) getDraftTableID(ctx context.Context, onlineID int64) (int64, error) {
|
||||
resp, err := d.client.GetDraftDatabaseByOnlineID(ctx, &service.GetDraftDatabaseByOnlineIDRequest{OnlineID: onlineID})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return resp.Database.ID, nil
|
||||
|
||||
}
|
||||
|
||||
func buildComplexCondition(conditionGroup *nodedatabase.ConditionGroup) (*database.ComplexCondition, []*database.SQLParamVal, error) {
|
||||
condition := &database.ComplexCondition{}
|
||||
logic, err := toLogic(conditionGroup.Relation)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
condition.Logic = logic
|
||||
|
||||
params := make([]*database.SQLParamVal, 0)
|
||||
for i := range conditionGroup.Conditions {
|
||||
var (
|
||||
nCond = conditionGroup.Conditions[i]
|
||||
vals []*database.SQLParamVal
|
||||
dCond = &database.Condition{
|
||||
Left: nCond.Left,
|
||||
}
|
||||
)
|
||||
opt, err := toOperation(nCond.Operator)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
dCond.Operation = opt
|
||||
|
||||
if isNullOrNotNull(opt) {
|
||||
condition.Conditions = append(condition.Conditions, dCond)
|
||||
continue
|
||||
}
|
||||
dCond.Right, vals, err = resolveRightValue(opt, nCond.Right)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
condition.Conditions = append(condition.Conditions, dCond)
|
||||
|
||||
params = append(params, vals...)
|
||||
|
||||
}
|
||||
return condition, params, nil
|
||||
}
|
||||
|
||||
// toMapStringAny widens a map[string]string into a map[string]any holding
// the same key/value pairs.
func toMapStringAny(m map[string]string) map[string]any {
	widened := make(map[string]any, len(m))
	for key, val := range m {
		widened[key] = val
	}
	return widened
}
|
||||
|
||||
func toOperation(operator nodedatabase.Operator) (database.Operation, error) {
|
||||
switch operator {
|
||||
case nodedatabase.OperatorEqual:
|
||||
return database.Operation_EQUAL, nil
|
||||
case nodedatabase.OperatorNotEqual:
|
||||
return database.Operation_NOT_EQUAL, nil
|
||||
case nodedatabase.OperatorGreater:
|
||||
return database.Operation_GREATER_THAN, nil
|
||||
case nodedatabase.OperatorGreaterOrEqual:
|
||||
return database.Operation_GREATER_EQUAL, nil
|
||||
case nodedatabase.OperatorLesser:
|
||||
return database.Operation_LESS_THAN, nil
|
||||
case nodedatabase.OperatorLesserOrEqual:
|
||||
return database.Operation_LESS_EQUAL, nil
|
||||
case nodedatabase.OperatorIn:
|
||||
return database.Operation_IN, nil
|
||||
case nodedatabase.OperatorNotIn:
|
||||
return database.Operation_NOT_IN, nil
|
||||
case nodedatabase.OperatorIsNotNull:
|
||||
return database.Operation_IS_NOT_NULL, nil
|
||||
case nodedatabase.OperatorIsNull:
|
||||
return database.Operation_IS_NULL, nil
|
||||
case nodedatabase.OperatorLike:
|
||||
return database.Operation_LIKE, nil
|
||||
case nodedatabase.OperatorNotLike:
|
||||
return database.Operation_NOT_LIKE, nil
|
||||
default:
|
||||
return database.Operation(0), fmt.Errorf("invalid operator %v", operator)
|
||||
}
|
||||
}
|
||||
|
||||
func resolveRightValue(operator database.Operation, right any) (string, []*database.SQLParamVal, error) {
|
||||
|
||||
if isInOrNotIn(operator) {
|
||||
var (
|
||||
vals = make([]*database.SQLParamVal, 0)
|
||||
anyVals = make([]any, 0)
|
||||
commas = make([]string, 0, len(anyVals))
|
||||
)
|
||||
|
||||
anyVals = right.([]any)
|
||||
for i := range anyVals {
|
||||
v := cast.ToString(anyVals[i])
|
||||
vals = append(vals, &database.SQLParamVal{ValueType: table.FieldItemType_Text, Value: &v})
|
||||
commas = append(commas, "?")
|
||||
}
|
||||
value := "(" + strings.Join(commas, ",") + ")"
|
||||
return value, vals, nil
|
||||
}
|
||||
|
||||
rightValue, err := cast.ToStringE(right)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if isLikeOrNotLike(operator) {
|
||||
var (
|
||||
value = "?"
|
||||
v = "%s" + rightValue + "%s"
|
||||
)
|
||||
return value, []*database.SQLParamVal{{ValueType: table.FieldItemType_Text, Value: &v}}, nil
|
||||
}
|
||||
|
||||
return "?", []*database.SQLParamVal{{ValueType: table.FieldItemType_Text, Value: &rightValue}}, nil
|
||||
}
|
||||
|
||||
func resolveUpsertRow(fields map[string]any) ([]*database.UpsertRow, []*database.SQLParamVal, error) {
|
||||
upsertRow := &database.UpsertRow{Records: make([]*database.Record, 0, len(fields))}
|
||||
params := make([]*database.SQLParamVal, 0)
|
||||
for key, value := range fields {
|
||||
val, err := cast.ToStringE(value)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
record := &database.Record{
|
||||
FieldId: key,
|
||||
FieldValue: "?",
|
||||
}
|
||||
upsertRow.Records = append(upsertRow.Records, record)
|
||||
params = append(params, &database.SQLParamVal{
|
||||
ValueType: table.FieldItemType_Text,
|
||||
Value: &val,
|
||||
})
|
||||
}
|
||||
return []*database.UpsertRow{upsertRow}, params, nil
|
||||
}
|
||||
|
||||
func isNullOrNotNull(opt database.Operation) bool {
|
||||
return opt == database.Operation_IS_NOT_NULL || opt == database.Operation_IS_NULL
|
||||
}
|
||||
|
||||
func isLikeOrNotLike(opt database.Operation) bool {
|
||||
return opt == database.Operation_LIKE || opt == database.Operation_NOT_LIKE
|
||||
}
|
||||
|
||||
func isInOrNotIn(opt database.Operation) bool {
|
||||
return opt == database.Operation_IN || opt == database.Operation_NOT_IN
|
||||
}
|
||||
|
||||
func toOrderDirection(isAsc bool) table.SortDirection {
|
||||
if isAsc {
|
||||
return table.SortDirection_ASC
|
||||
}
|
||||
return table.SortDirection_Desc
|
||||
}
|
||||
|
||||
func toLogic(relation nodedatabase.ClauseRelation) (database.Logic, error) {
|
||||
switch relation {
|
||||
case nodedatabase.ClauseRelationOR:
|
||||
return database.Logic_Or, nil
|
||||
case nodedatabase.ClauseRelationAND:
|
||||
return database.Logic_And, nil
|
||||
default:
|
||||
return database.Logic(0), fmt.Errorf("invalid relation %v", relation)
|
||||
}
|
||||
}
|
||||
|
||||
func toNodeDateBaseResponse(response *service.ExecuteSQLResponse) *nodedatabase.Response {
|
||||
objects := make([]nodedatabase.Object, 0, len(response.Records))
|
||||
for i := range response.Records {
|
||||
objects = append(objects, response.Records[i])
|
||||
}
|
||||
return &nodedatabase.Response{
|
||||
Objects: objects,
|
||||
RowNumber: response.RowsAffected,
|
||||
}
|
||||
}
|
||||
@ -1,224 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cast"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
|
||||
nodedatabase "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
mockDatabase "github.com/coze-dev/coze-studio/backend/internal/mock/domain/memory/database"
|
||||
)
|
||||
|
||||
// mockExecuteSQL returns a stub for Database.ExecuteSQL that asserts, per
// operation type, the exact request shape DatabaseRepository is expected to
// produce (SQL text, placeholder params, condition/ordering translation,
// upsert records) and replies with canned records.
func mockExecuteSQL(t *testing.T) func(ctx context.Context, request *service.ExecuteSQLRequest) (*service.ExecuteSQLResponse, error) {
	return func(ctx context.Context, request *service.ExecuteSQLRequest) (*service.ExecuteSQLResponse, error) {
		if request.OperateType == database.OperateType_Custom {
			// Custom SQL: raw statement is passed through, params in order.
			assert.Equal(t, *request.SQL, "select * from table where v1=? and v2=?")
			rs := make([]string, 0)
			for idx := range request.SQLParams {
				rs = append(rs, *request.SQLParams[idx].Value)
			}
			assert.Equal(t, rs, []string{"1", "2"})
			return &service.ExecuteSQLResponse{
				Records: []map[string]any{
					{"v1": "1", "v2": "2"},
				},
			}, nil
		}
		if request.OperateType == database.OperateType_Select {
			// Select: projection preserved; condition[1] is the IN clause
			// rendered as "(?,?)"; params[1..2] are the IN values and
			// params[3] is the LIKE pattern as currently produced by
			// resolveRightValue.
			sFields := []string{"v1", "v2", "v3", "v4"}
			assert.Equal(t, request.SelectFieldList.FieldID, sFields)
			cond := request.Condition.Conditions[1] // in
			assert.Equal(t, "(?,?)", cond.Right)
			assert.Equal(t, database.Operation_IN, cond.Operation)
			assert.Equal(t, "v2_1", *request.SQLParams[1].Value)
			assert.Equal(t, "v2_2", *request.SQLParams[2].Value)
			assert.Equal(t, "%sv4%s", *request.SQLParams[3].Value)
			rowsAffected := int64(10)
			return &service.ExecuteSQLResponse{
				Records: []map[string]any{
					{"v1": "1", "v2": "2", "v3": "3", "v4": "4"},
				},
				RowsAffected: &rowsAffected,
			}, nil

		}
		if request.OperateType == database.OperateType_Delete {
			// Delete: same condition translation checks, but NOT IN.
			cond := request.Condition.Conditions[1] // in
			assert.Equal(t, "(?,?)", cond.Right)
			assert.Equal(t, database.Operation_NOT_IN, cond.Operation)
			assert.Equal(t, "v2_1", *request.SQLParams[1].Value)
			assert.Equal(t, "v2_2", *request.SQLParams[2].Value)
			assert.Equal(t, "%sv4%s", *request.SQLParams[3].Value)
			rowsAffected := int64(10)
			return &service.ExecuteSQLResponse{
				Records: []map[string]any{
					{"v1": "1", "v2": "2", "v3": "3", "v4": "4"},
				},
				RowsAffected: &rowsAffected,
			}, nil
		}
		if request.OperateType == database.OperateType_Insert {
			// Insert: each upsert record's param value must equal the
			// stringified field value (record order drives param order).
			records := request.UpsertRows[0].Records
			ret := map[string]interface{}{
				"v1": "1",
				"v2": 2,
				"v3": 33,
				"v4": "44aacc",
			}
			for idx := range records {
				assert.Equal(t, *request.SQLParams[idx].Value, cast.ToString(ret[records[idx].FieldId]))
			}

		}

		if request.OperateType == database.OperateType_Update {

			// Update: SET params come first, in record order.
			records := request.UpsertRows[0].Records
			ret := map[string]interface{}{
				"v1": "1",
				"v2": 2,
				"v3": 33,
				"v4": "aabbcc",
			}
			for idx := range records {
				assert.Equal(t, *request.SQLParams[idx].Value, cast.ToString(ret[records[idx].FieldId]))
			}

			// Drop the SET params; the remainder are the WHERE params, which
			// must mirror the Select case.
			request.SQLParams = request.SQLParams[len(records):]
			cond := request.Condition.Conditions[1] // in
			assert.Equal(t, "(?,?)", cond.Right)
			assert.Equal(t, database.Operation_IN, cond.Operation)
			assert.Equal(t, "v2_1", *request.SQLParams[1].Value)
			assert.Equal(t, "v2_2", *request.SQLParams[2].Value)
			assert.Equal(t, "%sv4%s", *request.SQLParams[3].Value)

		}
		// Fallback for operation types without dedicated expectations.
		return &service.ExecuteSQLResponse{}, nil
	}
}
|
||||
|
||||
// TestDatabase_Database drives each DatabaseRepository operation against a
// mocked Database service; the mock (mockExecuteSQL) performs the detailed
// request-shape assertions, while the subtests check the converted response.
func TestDatabase_Database(t *testing.T) {
	ctrl := gomock.NewController(t)
	mockClient := mockDatabase.NewMockDatabase(ctrl)
	defer ctrl.Finish()
	ds := DatabaseRepository{
		client: mockClient,
	}
	mockClient.EXPECT().ExecuteSQL(gomock.Any(), gomock.Any()).DoAndReturn(mockExecuteSQL(t)).AnyTimes()

	// Custom SQL: params are forwarded positionally; records come back as-is.
	t.Run("execute", func(t *testing.T) {
		response, err := ds.Execute(context.Background(), &nodedatabase.CustomSQLRequest{
			DatabaseInfoID: 1,
			SQL:            "select * from table where v1=? and v2=?",
			Params: []nodedatabase.SQLParam{
				nodedatabase.SQLParam{Value: "1"},
				nodedatabase.SQLParam{Value: "2"},
			},
		})
		assert.Nil(t, err)
		assert.Equal(t, response.Objects, []nodedatabase.Object{
			{"v1": "1", "v2": "2"},
		})
	})

	// Select: exercises projection, ordering, and all condition operator
	// kinds (equal, IN, IS NULL, LIKE).
	t.Run("select", func(t *testing.T) {
		req := &nodedatabase.QueryRequest{
			DatabaseInfoID: 1,
			SelectFields:   []string{"v1", "v2", "v3", "v4"},
			Limit:          10,
			OrderClauses: []*nodedatabase.OrderClause{
				{FieldID: "v1", IsAsc: true},
				{FieldID: "v2", IsAsc: false},
			},
			ConditionGroup: &nodedatabase.ConditionGroup{
				Conditions: []*nodedatabase.Condition{
					{Left: "v1", Operator: nodedatabase.OperatorEqual, Right: "1"},
					{Left: "v2", Operator: nodedatabase.OperatorIn, Right: []any{"v2_1", "v2_2"}},
					{Left: "v3", Operator: nodedatabase.OperatorIsNull},
					{Left: "v4", Operator: nodedatabase.OperatorLike, Right: "v4"},
				},
				Relation: nodedatabase.ClauseRelationOR,
			},
		}

		response, err := ds.Query(context.Background(), req)
		assert.Nil(t, err)
		assert.Equal(t, *response.RowNumber, int64(10))
	})

	// Delete: mirrors select but with NOT IN / IS NOT NULL / NOT LIKE.
	t.Run("delete", func(t *testing.T) {
		req := &nodedatabase.DeleteRequest{
			DatabaseInfoID: 1,
			ConditionGroup: &nodedatabase.ConditionGroup{
				Conditions: []*nodedatabase.Condition{
					{Left: "v1", Operator: nodedatabase.OperatorEqual, Right: "1"},
					{Left: "v2", Operator: nodedatabase.OperatorNotIn, Right: []any{"v2_1", "v2_2"}},
					{Left: "v3", Operator: nodedatabase.OperatorIsNotNull},
					{Left: "v4", Operator: nodedatabase.OperatorNotLike, Right: "v4"},
				},
				Relation: nodedatabase.ClauseRelationOR,
			},
		}
		response, err := ds.Delete(context.Background(), req)
		assert.Nil(t, err)
		assert.Equal(t, *response.RowNumber, int64(10))
	})

	// Insert: field map is converted to upsert records with "?" values;
	// the mock checks record/param correspondence.
	t.Run("insert", func(t *testing.T) {
		req := &nodedatabase.InsertRequest{
			DatabaseInfoID: 1,
			Fields: map[string]interface{}{
				"v1": "1",
				"v2": 2,
				"v3": 33,
				"v4": "44aacc",
			},
		}
		_, err := ds.Insert(context.Background(), req)
		assert.Nil(t, err)
	})

	// Update: combines upsert fields (SET) with a condition group (WHERE).
	t.Run("update", func(t *testing.T) {
		req := &nodedatabase.UpdateRequest{
			DatabaseInfoID: 1,
			ConditionGroup: &nodedatabase.ConditionGroup{
				Conditions: []*nodedatabase.Condition{
					{Left: "v1", Operator: nodedatabase.OperatorEqual, Right: "1"},
					{Left: "v2", Operator: nodedatabase.OperatorIn, Right: []any{"v2_1", "v2_2"}},
					{Left: "v3", Operator: nodedatabase.OperatorIsNull},
					{Left: "v4", Operator: nodedatabase.OperatorLike, Right: "v4"},
				},
				Relation: nodedatabase.ClauseRelationOR,
			},
			Fields: map[string]interface{}{
				"v1": "1",
				"v2": 2,
				"v3": 33,
				"v4": "aabbcc",
			},
		}
		_, err := ds.Update(context.Background(), req)
		assert.Nil(t, err)
	})
}
|
||||
@ -1,217 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/knowledge/entity"
|
||||
domainknowledge "github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
|
||||
crossknowledge "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
)
|
||||
|
||||
// Knowledge adapts the knowledge domain service to the workflow
// cross-domain knowledge interface (store/retrieve/delete/list).
type Knowledge struct {
	client domainknowledge.Knowledge // knowledge domain service
	idGen  idgen.IDGenerator         // ID generator (injected; not used in the methods visible here)
}
|
||||
|
||||
func NewKnowledgeRepository(client domainknowledge.Knowledge, idGen idgen.IDGenerator) *Knowledge {
|
||||
return &Knowledge{
|
||||
client: client,
|
||||
idGen: idGen,
|
||||
}
|
||||
}
|
||||
|
||||
func (k *Knowledge) Store(ctx context.Context, document *crossknowledge.CreateDocumentRequest) (*crossknowledge.CreateDocumentResponse, error) {
|
||||
var (
|
||||
ps *entity.ParsingStrategy
|
||||
cs = &entity.ChunkingStrategy{}
|
||||
)
|
||||
|
||||
if document.ParsingStrategy == nil {
|
||||
return nil, errors.New("document parsing strategy is required")
|
||||
}
|
||||
|
||||
if document.ChunkingStrategy == nil {
|
||||
return nil, errors.New("document chunking strategy is required")
|
||||
}
|
||||
|
||||
if document.ParsingStrategy.ParseMode == crossknowledge.AccurateParseMode {
|
||||
ps = &entity.ParsingStrategy{}
|
||||
ps.ExtractImage = document.ParsingStrategy.ExtractImage
|
||||
ps.ExtractTable = document.ParsingStrategy.ExtractTable
|
||||
ps.ImageOCR = document.ParsingStrategy.ImageOCR
|
||||
}
|
||||
|
||||
chunkType, err := toChunkType(document.ChunkingStrategy.ChunkType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cs.ChunkType = chunkType
|
||||
cs.Separator = document.ChunkingStrategy.Separator
|
||||
cs.ChunkSize = document.ChunkingStrategy.ChunkSize
|
||||
cs.Overlap = document.ChunkingStrategy.Overlap
|
||||
|
||||
req := &entity.Document{
|
||||
Info: knowledge.Info{
|
||||
Name: document.FileName,
|
||||
},
|
||||
KnowledgeID: document.KnowledgeID,
|
||||
Type: knowledge.DocumentTypeText,
|
||||
URL: document.FileURL,
|
||||
Source: entity.DocumentSourceLocal,
|
||||
ParsingStrategy: ps,
|
||||
ChunkingStrategy: cs,
|
||||
FileExtension: document.FileExtension,
|
||||
}
|
||||
|
||||
uid := ctxutil.GetUIDFromCtx(ctx)
|
||||
if uid != nil {
|
||||
req.Info.CreatorID = *uid
|
||||
}
|
||||
|
||||
response, err := k.client.CreateDocument(ctx, &domainknowledge.CreateDocumentRequest{
|
||||
Documents: []*entity.Document{req},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
kCResponse := &crossknowledge.CreateDocumentResponse{
|
||||
FileURL: document.FileURL,
|
||||
DocumentID: response.Documents[0].Info.ID,
|
||||
FileName: response.Documents[0].Info.Name,
|
||||
}
|
||||
|
||||
return kCResponse, nil
|
||||
}
|
||||
|
||||
func (k *Knowledge) Retrieve(ctx context.Context, r *crossknowledge.RetrieveRequest) (*crossknowledge.RetrieveResponse, error) {
|
||||
rs := &entity.RetrievalStrategy{}
|
||||
if r.RetrievalStrategy != nil {
|
||||
rs.TopK = r.RetrievalStrategy.TopK
|
||||
rs.MinScore = r.RetrievalStrategy.MinScore
|
||||
searchType, err := toSearchType(r.RetrievalStrategy.SearchType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rs.SearchType = searchType
|
||||
rs.EnableQueryRewrite = r.RetrievalStrategy.EnableQueryRewrite
|
||||
rs.EnableRerank = r.RetrievalStrategy.EnableRerank
|
||||
rs.EnableNL2SQL = r.RetrievalStrategy.EnableNL2SQL
|
||||
}
|
||||
|
||||
req := &domainknowledge.RetrieveRequest{
|
||||
Query: r.Query,
|
||||
KnowledgeIDs: r.KnowledgeIDs,
|
||||
Strategy: rs,
|
||||
}
|
||||
|
||||
response, err := k.client.Retrieve(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ss := make([]*crossknowledge.Slice, 0, len(response.RetrieveSlices))
|
||||
for _, s := range response.RetrieveSlices {
|
||||
if s.Slice == nil {
|
||||
continue
|
||||
}
|
||||
ss = append(ss, &crossknowledge.Slice{
|
||||
DocumentID: strconv.FormatInt(s.Slice.DocumentID, 10),
|
||||
Output: s.Slice.GetSliceContent(),
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
return &crossknowledge.RetrieveResponse{
|
||||
Slices: ss,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (k *Knowledge) Delete(ctx context.Context, r *crossknowledge.DeleteDocumentRequest) (*crossknowledge.DeleteDocumentResponse, error) {
|
||||
docID, err := strconv.ParseInt(r.DocumentID, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid document id: %s", r.DocumentID)
|
||||
}
|
||||
|
||||
err = k.client.DeleteDocument(ctx, &domainknowledge.DeleteDocumentRequest{
|
||||
DocumentID: docID,
|
||||
})
|
||||
if err != nil {
|
||||
return &crossknowledge.DeleteDocumentResponse{IsSuccess: false}, err
|
||||
}
|
||||
|
||||
return &crossknowledge.DeleteDocumentResponse{IsSuccess: true}, nil
|
||||
}
|
||||
|
||||
func (k *Knowledge) ListKnowledgeDetail(ctx context.Context, req *crossknowledge.ListKnowledgeDetailRequest) (*crossknowledge.ListKnowledgeDetailResponse, error) {
|
||||
response, err := k.client.MGetKnowledgeByID(ctx, &domainknowledge.MGetKnowledgeByIDRequest{
|
||||
KnowledgeIDs: req.KnowledgeIDs,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp := &crossknowledge.ListKnowledgeDetailResponse{
|
||||
KnowledgeDetails: slices.Transform(response.Knowledge, func(a *knowledge.Knowledge) *crossknowledge.KnowledgeDetail {
|
||||
return &crossknowledge.KnowledgeDetail{
|
||||
ID: a.ID,
|
||||
Name: a.Name,
|
||||
Description: a.Description,
|
||||
IconURL: a.IconURL,
|
||||
FormatType: int64(a.Type),
|
||||
}
|
||||
}),
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func toSearchType(typ crossknowledge.SearchType) (knowledge.SearchType, error) {
|
||||
switch typ {
|
||||
case crossknowledge.SearchTypeSemantic:
|
||||
return knowledge.SearchTypeSemantic, nil
|
||||
case crossknowledge.SearchTypeFullText:
|
||||
return knowledge.SearchTypeFullText, nil
|
||||
case crossknowledge.SearchTypeHybrid:
|
||||
return knowledge.SearchTypeHybrid, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown search type: %v", typ)
|
||||
}
|
||||
}
|
||||
|
||||
func toChunkType(typ crossknowledge.ChunkType) (parser.ChunkType, error) {
|
||||
switch typ {
|
||||
case crossknowledge.ChunkTypeDefault:
|
||||
return parser.ChunkTypeDefault, nil
|
||||
case crossknowledge.ChunkTypeCustom:
|
||||
return parser.ChunkTypeCustom, nil
|
||||
case crossknowledge.ChunkTypeLeveled:
|
||||
return parser.ChunkTypeLeveled, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown chunk type: %v", typ)
|
||||
}
|
||||
}
|
||||
@ -1,533 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/getkin/kin-openapi/openapi3"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common"
|
||||
workflow3 "github.com/coze-dev/coze-studio/backend/api/model/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/pluginutil"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/service"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
crossplugin "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
|
||||
entity2 "github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
// pluginService implements the workflow cross-domain plugin Service on top
// of the plugin domain service and an object-storage backend.
type pluginService struct {
	client service.PluginService // plugin domain service
	tos    storage.Storage       // object storage backend
}
|
||||
|
||||
func NewPluginService(client service.PluginService, tos storage.Storage) crossplugin.Service {
|
||||
return &pluginService{client: client, tos: tos}
|
||||
}
|
||||
|
||||
// pluginInfo pairs a plugin with the latest online version, populated only
// when the caller requested a pinned version and a current online revision
// was also found (see getPluginsWithTools).
type pluginInfo struct {
	*entity.PluginInfo
	LatestVersion *string // nil unless a current online revision was looked up
}
|
||||
|
||||
func (t *pluginService) getPluginsWithTools(ctx context.Context, pluginEntity *crossplugin.Entity, toolIDs []int64, isDraft bool) (
|
||||
_ *pluginInfo, toolsInfo []*entity.ToolInfo, err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = vo.WrapIfNeeded(errno.ErrPluginAPIErr, err)
|
||||
}
|
||||
}()
|
||||
|
||||
var pluginsInfo []*entity.PluginInfo
|
||||
var latestPluginInfo *entity.PluginInfo
|
||||
pluginID := pluginEntity.PluginID
|
||||
if isDraft {
|
||||
plugins, err := t.client.MGetDraftPlugins(ctx, []int64{pluginID})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
pluginsInfo = plugins
|
||||
} else if pluginEntity.PluginVersion == nil || (pluginEntity.PluginVersion != nil && *pluginEntity.PluginVersion == "") {
|
||||
plugins, err := t.client.MGetOnlinePlugins(ctx, []int64{pluginID})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
pluginsInfo = plugins
|
||||
|
||||
} else {
|
||||
plugins, err := t.client.MGetVersionPlugins(ctx, []entity.VersionPlugin{
|
||||
{PluginID: pluginID, Version: *pluginEntity.PluginVersion},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
pluginsInfo = plugins
|
||||
|
||||
onlinePlugins, err := t.client.MGetOnlinePlugins(ctx, []int64{pluginID})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
for _, pi := range onlinePlugins {
|
||||
if pi.ID == pluginID {
|
||||
latestPluginInfo = pi
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var pInfo *entity.PluginInfo
|
||||
for _, p := range pluginsInfo {
|
||||
if p.ID == pluginID {
|
||||
pInfo = p
|
||||
break
|
||||
}
|
||||
}
|
||||
if pInfo == nil {
|
||||
return nil, nil, vo.NewError(errno.ErrPluginIDNotFound, errorx.KV("id", strconv.FormatInt(pluginID, 10)))
|
||||
}
|
||||
|
||||
if isDraft {
|
||||
tools, err := t.client.MGetDraftTools(ctx, toolIDs)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
toolsInfo = tools
|
||||
} else if pluginEntity.PluginVersion == nil || (pluginEntity.PluginVersion != nil && *pluginEntity.PluginVersion == "") {
|
||||
tools, err := t.client.MGetOnlineTools(ctx, toolIDs)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
toolsInfo = tools
|
||||
} else {
|
||||
eVersionTools := slices.Transform(toolIDs, func(tid int64) entity.VersionTool {
|
||||
return entity.VersionTool{
|
||||
ToolID: tid,
|
||||
Version: *pluginEntity.PluginVersion,
|
||||
}
|
||||
})
|
||||
tools, err := t.client.MGetVersionTools(ctx, eVersionTools)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
toolsInfo = tools
|
||||
}
|
||||
|
||||
if latestPluginInfo != nil {
|
||||
return &pluginInfo{PluginInfo: pInfo, LatestVersion: latestPluginInfo.Version}, toolsInfo, nil
|
||||
}
|
||||
|
||||
return &pluginInfo{PluginInfo: pInfo}, toolsInfo, nil
|
||||
}
|
||||
|
||||
func (t *pluginService) GetPluginToolsInfo(ctx context.Context, req *crossplugin.ToolsInfoRequest) (
|
||||
_ *crossplugin.ToolsInfoResponse, err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = vo.WrapIfNeeded(errno.ErrPluginAPIErr, err)
|
||||
}
|
||||
}()
|
||||
|
||||
var toolsInfo []*entity.ToolInfo
|
||||
isDraft := req.IsDraft || (req.PluginEntity.PluginVersion != nil && *req.PluginEntity.PluginVersion == "0")
|
||||
pInfo, toolsInfo, err := t.getPluginsWithTools(ctx, &crossplugin.Entity{PluginID: req.PluginEntity.PluginID, PluginVersion: req.PluginEntity.PluginVersion}, req.ToolIDs, isDraft)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
url, err := t.tos.GetObjectUrl(ctx, pInfo.GetIconURI())
|
||||
if err != nil {
|
||||
return nil, vo.WrapIfNeeded(errno.ErrTOSError, err)
|
||||
}
|
||||
|
||||
response := &crossplugin.ToolsInfoResponse{
|
||||
PluginID: pInfo.ID,
|
||||
SpaceID: pInfo.SpaceID,
|
||||
Version: pInfo.GetVersion(),
|
||||
PluginName: pInfo.GetName(),
|
||||
Description: pInfo.GetDesc(),
|
||||
IconURL: url,
|
||||
PluginType: int64(pInfo.PluginType),
|
||||
ToolInfoList: make(map[int64]crossplugin.ToolInfo),
|
||||
LatestVersion: pInfo.LatestVersion,
|
||||
IsOfficial: pInfo.IsOfficial(),
|
||||
AppID: pInfo.GetAPPID(),
|
||||
}
|
||||
|
||||
for _, tf := range toolsInfo {
|
||||
inputs, err := tf.ToReqAPIParameter()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
outputs, err := tf.ToRespAPIParameter()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toolExample := pInfo.GetToolExample(ctx, tf.GetName())
|
||||
|
||||
var (
|
||||
requestExample string
|
||||
responseExample string
|
||||
)
|
||||
if toolExample != nil {
|
||||
requestExample = toolExample.RequestExample
|
||||
responseExample = toolExample.ResponseExample
|
||||
}
|
||||
|
||||
response.ToolInfoList[tf.ID] = crossplugin.ToolInfo{
|
||||
ToolID: tf.ID,
|
||||
ToolName: tf.GetName(),
|
||||
Inputs: slices.Transform(inputs, toWorkflowAPIParameter),
|
||||
Outputs: slices.Transform(outputs, toWorkflowAPIParameter),
|
||||
Description: tf.GetDesc(),
|
||||
DebugExample: &crossplugin.DebugExample{
|
||||
ReqExample: requestExample,
|
||||
RespExample: responseExample,
|
||||
},
|
||||
}
|
||||
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (t *pluginService) GetPluginInvokableTools(ctx context.Context, req *crossplugin.ToolsInvokableRequest) (
|
||||
_ map[int64]crossplugin.InvokableTool, err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = vo.WrapIfNeeded(errno.ErrPluginAPIErr, err)
|
||||
}
|
||||
}()
|
||||
|
||||
var toolsInfo []*entity.ToolInfo
|
||||
isDraft := req.IsDraft || (req.PluginEntity.PluginVersion != nil && *req.PluginEntity.PluginVersion == "0")
|
||||
pInfo, toolsInfo, err := t.getPluginsWithTools(ctx, &crossplugin.Entity{
|
||||
PluginID: req.PluginEntity.PluginID,
|
||||
PluginVersion: req.PluginEntity.PluginVersion,
|
||||
}, maps.Keys(req.ToolsInvokableInfo), isDraft)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := map[int64]crossplugin.InvokableTool{}
|
||||
for _, tf := range toolsInfo {
|
||||
tl := &pluginInvokeTool{
|
||||
pluginEntity: crossplugin.Entity{
|
||||
PluginID: pInfo.ID,
|
||||
PluginVersion: pInfo.Version,
|
||||
},
|
||||
client: t.client,
|
||||
toolInfo: tf,
|
||||
IsDraft: isDraft,
|
||||
}
|
||||
|
||||
if r, ok := req.ToolsInvokableInfo[tf.ID]; ok && (r.RequestAPIParametersConfig != nil && r.ResponseAPIParametersConfig != nil) {
|
||||
reqPluginCommonAPIParameters := slices.Transform(r.RequestAPIParametersConfig, toPluginCommonAPIParameter)
|
||||
respPluginCommonAPIParameters := slices.Transform(r.ResponseAPIParametersConfig, toPluginCommonAPIParameter)
|
||||
|
||||
tl.toolOperation, err = pluginutil.APIParamsToOpenapiOperation(reqPluginCommonAPIParameters, respPluginCommonAPIParameters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tl.toolOperation.OperationID = tf.Operation.OperationID
|
||||
tl.toolOperation.Summary = tf.Operation.Summary
|
||||
}
|
||||
|
||||
result[tf.ID] = tl
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ExecutePlugin marshals input to JSON, executes the given tool of the plugin
// identified by pe through the plugin domain service, and unmarshals the
// trimmed response back into a map.
//
// An OAuth interrupt raised by the tool is currently converted into an
// ErrAuthorizationRequired error carrying the interrupt payload (see the
// in-body comment); any other interrupt type is an ErrPluginAPIErr.
func (t *pluginService) ExecutePlugin(ctx context.Context, input map[string]any, pe *crossplugin.Entity,
	toolID int64, cfg crossplugin.ExecConfig) (map[string]any, error) {
	args, err := sonic.MarshalString(input)
	if err != nil {
		return nil, vo.WrapError(errno.ErrSerializationDeserializationFail, err)
	}

	// Choose the executing user identity: with an agent context, use the
	// connector UID; otherwise fall back to the operator's numeric ID.
	// NOTE(review): presumably ConnectorUID is the end-user ID on the
	// connector side — confirm against ExecConfig's definition.
	var uID string
	if cfg.AgentID != nil {
		uID = cfg.ConnectorUID
	} else {
		uID = conv.Int64ToStr(cfg.Operator)
	}

	req := &service.ExecuteToolRequest{
		UserID:          uID,
		PluginID:        pe.PluginID,
		ToolID:          toolID,
		ExecScene:       plugin.ExecSceneOfWorkflow,
		ArgumentsInJson: args,
		// A nil version or the sentinel "0" selects the draft tool.
		ExecDraftTool: pe.PluginVersion == nil || *pe.PluginVersion == "0",
	}
	execOpts := []entity.ExecuteToolOpt{
		plugin.WithInvalidRespProcessStrategy(plugin.InvalidResponseProcessStrategyOfReturnDefault),
	}

	if pe.PluginVersion != nil {
		execOpts = append(execOpts, plugin.WithToolVersion(*pe.PluginVersion))
	}

	r, err := t.client.ExecuteTool(ctx, req, execOpts...)
	if err != nil {
		// An interrupt-rerun error signals the tool paused (e.g. waiting for
		// OAuth) rather than failed outright.
		if extra, ok := compose.IsInterruptRerunError(err); ok {
			pluginTIE, ok := extra.(*plugin.ToolInterruptEvent)
			if !ok {
				return nil, vo.WrapError(errno.ErrPluginAPIErr, fmt.Errorf("expects ToolInterruptEvent, got %T", extra))
			}

			// Only the OAuth interrupt is supported here.
			var eventType workflow3.EventType
			switch pluginTIE.Event {
			case plugin.InterruptEventTypeOfToolNeedOAuth:
				eventType = workflow3.EventType_WorkflowOauthPlugin
			default:
				return nil, vo.WrapError(errno.ErrPluginAPIErr,
					fmt.Errorf("unsupported interrupt event type: %s", pluginTIE.Event))
			}

			id, err := workflow.GetRepository().GenID(ctx)
			if err != nil {
				return nil, vo.WrapError(errno.ErrIDGenError, err)
			}

			ie := &entity2.InterruptEvent{
				ID:            id,
				InterruptData: pluginTIE.ToolNeedOAuth.Message,
				EventType:     eventType,
			}

			// temporarily replace interrupt with real error, until frontend can handle plugin oauth interrupt
			interruptData := ie.InterruptData
			return nil, vo.NewError(errno.ErrAuthorizationRequired, errorx.KV("extra", interruptData))
		}
		return nil, err
	}

	var output map[string]any
	err = sonic.UnmarshalString(r.TrimmedResp, &output)
	if err != nil {
		return nil, vo.WrapError(errno.ErrSerializationDeserializationFail, err)
	}

	return output, nil
}
|
||||
|
||||
// pluginInvokeTool is an InvokableTool implementation that executes a single
// plugin tool through the plugin domain service.
type pluginInvokeTool struct {
	pluginEntity  crossplugin.Entity    // plugin ID plus optional version to execute against
	client        service.PluginService // domain service used for tool execution
	toolInfo      *entity.ToolInfo      // the tool's metadata and its default operation
	toolOperation *openapi3.Operation   // optional override of the tool's operation; nil means use toolInfo.Operation
	IsDraft       bool                  // when true the draft version of the tool is executed
}
|
||||
|
||||
func (p *pluginInvokeTool) Info(ctx context.Context) (_ *schema.ToolInfo, err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = vo.WrapIfNeeded(errno.ErrPluginAPIErr, err)
|
||||
}
|
||||
}()
|
||||
|
||||
var parameterInfo map[string]*schema.ParameterInfo
|
||||
if p.toolOperation != nil {
|
||||
parameterInfo, err = plugin.NewOpenapi3Operation(p.toolOperation).ToEinoSchemaParameterInfo(ctx)
|
||||
} else {
|
||||
parameterInfo, err = p.toolInfo.Operation.ToEinoSchemaParameterInfo(ctx)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &schema.ToolInfo{
|
||||
Name: p.toolInfo.GetName(),
|
||||
Desc: p.toolInfo.GetDesc(),
|
||||
ParamsOneOf: schema.NewParamsOneOfByParams(parameterInfo),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// PluginInvoke executes the tool with the given JSON-encoded arguments and
// returns the trimmed response body as a string.
//
// An OAuth interrupt raised by the tool is currently converted into an
// ErrAuthorizationRequired error carrying the interrupt payload (see the
// in-body comment); any other interrupt type is an ErrPluginAPIErr.
func (p *pluginInvokeTool) PluginInvoke(ctx context.Context, argumentsInJSON string, cfg crossplugin.ExecConfig) (string, error) {
	req := &service.ExecuteToolRequest{
		UserID:          conv.Int64ToStr(cfg.Operator),
		PluginID:        p.pluginEntity.PluginID,
		ToolID:          p.toolInfo.ID,
		ExecScene:       plugin.ExecSceneOfWorkflow,
		ArgumentsInJson: argumentsInJSON,
		ExecDraftTool:   p.IsDraft,
	}
	execOpts := []entity.ExecuteToolOpt{
		plugin.WithInvalidRespProcessStrategy(plugin.InvalidResponseProcessStrategyOfReturnDefault),
	}

	if p.pluginEntity.PluginVersion != nil {
		execOpts = append(execOpts, plugin.WithToolVersion(*p.pluginEntity.PluginVersion))
	}

	// Use the override operation for execution when one was configured.
	if p.toolOperation != nil {
		execOpts = append(execOpts, plugin.WithOpenapiOperation(plugin.NewOpenapi3Operation(p.toolOperation)))
	}

	r, err := p.client.ExecuteTool(ctx, req, execOpts...)
	if err != nil {
		// An interrupt-rerun error signals the tool paused (e.g. waiting for
		// OAuth) rather than failed outright.
		if extra, ok := compose.IsInterruptRerunError(err); ok {
			pluginTIE, ok := extra.(*plugin.ToolInterruptEvent)
			if !ok {
				return "", vo.WrapError(errno.ErrPluginAPIErr, fmt.Errorf("expects ToolInterruptEvent, got %T", extra))
			}

			// Only the OAuth interrupt is supported here.
			var eventType workflow3.EventType
			switch pluginTIE.Event {
			case plugin.InterruptEventTypeOfToolNeedOAuth:
				eventType = workflow3.EventType_WorkflowOauthPlugin
			default:
				return "", vo.WrapError(errno.ErrPluginAPIErr,
					fmt.Errorf("unsupported interrupt event type: %s", pluginTIE.Event))
			}

			id, err := workflow.GetRepository().GenID(ctx)
			if err != nil {
				return "", vo.WrapError(errno.ErrIDGenError, err)
			}

			ie := &entity2.InterruptEvent{
				ID:            id,
				InterruptData: pluginTIE.ToolNeedOAuth.Message,
				EventType:     eventType,
			}

			// tie is built (and intentionally discarded below) so the
			// interrupt-event plumbing stays ready for when the frontend can
			// consume it.
			tie := &entity2.ToolInterruptEvent{
				ToolCallID:     compose.GetToolCallID(ctx),
				ToolName:       p.toolInfo.GetName(),
				InterruptEvent: ie,
			}

			// temporarily replace interrupt with real error, until frontend can handle plugin oauth interrupt
			_ = tie
			interruptData := ie.InterruptData
			return "", vo.NewError(errno.ErrAuthorizationRequired, errorx.KV("extra", interruptData))
		}
		return "", err
	}
	return r.TrimmedResp, nil
}
|
||||
|
||||
func toPluginCommonAPIParameter(parameter *workflow3.APIParameter) *common.APIParameter {
|
||||
if parameter == nil {
|
||||
return nil
|
||||
}
|
||||
p := &common.APIParameter{
|
||||
ID: parameter.ID,
|
||||
Name: parameter.Name,
|
||||
Desc: parameter.Desc,
|
||||
Type: common.ParameterType(parameter.Type),
|
||||
Location: common.ParameterLocation(parameter.Location),
|
||||
IsRequired: parameter.IsRequired,
|
||||
GlobalDefault: parameter.GlobalDefault,
|
||||
GlobalDisable: parameter.GlobalDisable,
|
||||
LocalDefault: parameter.LocalDefault,
|
||||
LocalDisable: parameter.LocalDisable,
|
||||
VariableRef: parameter.VariableRef,
|
||||
}
|
||||
if parameter.SubType != nil {
|
||||
p.SubType = ptr.Of(common.ParameterType(*parameter.SubType))
|
||||
}
|
||||
|
||||
if parameter.DefaultParamSource != nil {
|
||||
p.DefaultParamSource = ptr.Of(common.DefaultParamSource(*parameter.DefaultParamSource))
|
||||
}
|
||||
if parameter.AssistType != nil {
|
||||
p.AssistType = ptr.Of(common.AssistParameterType(*parameter.AssistType))
|
||||
}
|
||||
|
||||
if len(parameter.SubParameters) > 0 {
|
||||
p.SubParameters = make([]*common.APIParameter, 0, len(parameter.SubParameters))
|
||||
for _, subParam := range parameter.SubParameters {
|
||||
p.SubParameters = append(p.SubParameters, toPluginCommonAPIParameter(subParam))
|
||||
}
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// toWorkflowAPIParameter converts a plugin common.APIParameter tree into the
// workflow representation. Arrays whose single child is the synthetic
// "[Array Item]" wrapper are unwrapped: the wrapper's type becomes the
// array's SubType and, for object items, its children are hoisted one level
// up. A nil input yields a nil output.
func toWorkflowAPIParameter(parameter *common.APIParameter) *workflow3.APIParameter {
	if parameter == nil {
		return nil
	}
	p := &workflow3.APIParameter{
		ID:            parameter.ID,
		Name:          parameter.Name,
		Desc:          parameter.Desc,
		Type:          workflow3.ParameterType(parameter.Type),
		Location:      workflow3.ParameterLocation(parameter.Location),
		IsRequired:    parameter.IsRequired,
		GlobalDefault: parameter.GlobalDefault,
		GlobalDisable: parameter.GlobalDisable,
		LocalDefault:  parameter.LocalDefault,
		LocalDisable:  parameter.LocalDisable,
		VariableRef:   parameter.VariableRef,
	}
	// Optional enum-like fields are converted only when present; SubType may
	// be overwritten below by the array-unwrapping logic.
	if parameter.SubType != nil {
		p.SubType = ptr.Of(workflow3.ParameterType(*parameter.SubType))
	}

	if parameter.DefaultParamSource != nil {
		p.DefaultParamSource = ptr.Of(workflow3.DefaultParamSource(*parameter.DefaultParamSource))
	}
	if parameter.AssistType != nil {
		p.AssistType = ptr.Of(workflow3.AssistParameterType(*parameter.AssistType))
	}

	// Check if it's an array that needs unwrapping.
	if parameter.Type == common.ParameterType_Array && len(parameter.SubParameters) == 1 && parameter.SubParameters[0].Name == "[Array Item]" {
		arrayItem := parameter.SubParameters[0]
		p.SubType = ptr.Of(workflow3.ParameterType(arrayItem.Type))
		// If the "[Array Item]" is an object, its sub-parameters become the array's sub-parameters.
		if arrayItem.Type == common.ParameterType_Object {
			p.SubParameters = make([]*workflow3.APIParameter, 0, len(arrayItem.SubParameters))
			for _, subParam := range arrayItem.SubParameters {
				p.SubParameters = append(p.SubParameters, toWorkflowAPIParameter(subParam))
			}
		} else {
			// The array's SubType is the Type of the "[Array Item]".
			p.SubParameters = make([]*workflow3.APIParameter, 0, 1)
			p.SubParameters = append(p.SubParameters, toWorkflowAPIParameter(arrayItem))
			p.SubParameters[0].Name = "" // Remove the "[Array Item]" name.
		}
	} else if len(parameter.SubParameters) > 0 { // A simple object or a non-wrapped array.
		p.SubParameters = make([]*workflow3.APIParameter, 0, len(parameter.SubParameters))
		for _, subParam := range parameter.SubParameters {
			p.SubParameters = append(p.SubParameters, toWorkflowAPIParameter(subParam))
		}
	}

	return p
}
|
||||
@ -1,189 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common"
|
||||
workflow3 "github.com/coze-dev/coze-studio/backend/api/model/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
// TestToWorkflowAPIParameter covers the common->workflow parameter conversion,
// including the "[Array Item]" unwrapping rules: scalar array items set the
// array's SubType and keep a name-cleared child; object array items have their
// fields hoisted into the array's SubParameters; nested arrays unwrap at each
// level.
func TestToWorkflowAPIParameter(t *testing.T) {
	cases := []struct {
		name     string
		param    *common.APIParameter
		expected *workflow3.APIParameter
	}{
		{
			name:     "nil parameter",
			param:    nil,
			expected: nil,
		},
		{
			name: "simple string parameter",
			param: &common.APIParameter{
				Name: "prompt",
				Type: common.ParameterType_String,
				Desc: "User's prompt",
			},
			expected: &workflow3.APIParameter{
				Name: "prompt",
				Type: workflow3.ParameterType_String,
				Desc: "User's prompt",
			},
		},
		{
			name: "simple object parameter",
			param: &common.APIParameter{
				Name: "user_info",
				Type: common.ParameterType_Object,
				SubParameters: []*common.APIParameter{
					{
						Name: "name",
						Type: common.ParameterType_String,
					},
					{
						Name: "age",
						Type: common.ParameterType_Number,
					},
				},
			},
			expected: &workflow3.APIParameter{
				Name: "user_info",
				Type: workflow3.ParameterType_Object,
				SubParameters: []*workflow3.APIParameter{
					{
						Name: "name",
						Type: workflow3.ParameterType_String,
					},
					{
						Name: "age",
						Type: workflow3.ParameterType_Number,
					},
				},
			},
		},
		{
			name: "array of strings",
			param: &common.APIParameter{
				Name: "tags",
				Type: common.ParameterType_Array,
				SubParameters: []*common.APIParameter{
					{
						Name: "[Array Item]",
						Type: common.ParameterType_String,
					},
				},
			},
			expected: &workflow3.APIParameter{
				Name:    "tags",
				Type:    workflow3.ParameterType_Array,
				SubType: ptr.Of(workflow3.ParameterType_String),
				SubParameters: []*workflow3.APIParameter{
					{
						Type: workflow3.ParameterType_String,
					},
				},
			},
		},
		{
			name: "array of objects",
			param: &common.APIParameter{
				Name: "users",
				Type: common.ParameterType_Array,
				SubParameters: []*common.APIParameter{
					{
						Name: "[Array Item]",
						Type: common.ParameterType_Object,
						SubParameters: []*common.APIParameter{
							{
								Name: "name",
								Type: common.ParameterType_String,
							},
							{
								Name: "id",
								Type: common.ParameterType_Number,
							},
						},
					},
				},
			},
			expected: &workflow3.APIParameter{
				Name:    "users",
				Type:    workflow3.ParameterType_Array,
				SubType: ptr.Of(workflow3.ParameterType_Object),
				SubParameters: []*workflow3.APIParameter{
					{
						Name: "name",
						Type: workflow3.ParameterType_String,
					},
					{
						Name: "id",
						Type: workflow3.ParameterType_Number,
					},
				},
			},
		},
		{
			name: "array of array of strings",
			param: &common.APIParameter{
				Name: "matrix",
				Type: common.ParameterType_Array,
				SubParameters: []*common.APIParameter{
					{
						Name: "[Array Item]",
						Type: common.ParameterType_Array,
						SubParameters: []*common.APIParameter{
							{
								Name: "[Array Item]",
								Type: common.ParameterType_String,
							},
						},
					},
				},
			},
			expected: &workflow3.APIParameter{
				Name:    "matrix",
				Type:    workflow3.ParameterType_Array,
				SubType: ptr.Of(workflow3.ParameterType_Array),
				SubParameters: []*workflow3.APIParameter{
					{
						Name:    "", // Name is cleared
						Type:    workflow3.ParameterType_Array,
						SubType: ptr.Of(workflow3.ParameterType_String),
						SubParameters: []*workflow3.APIParameter{
							{
								Type: workflow3.ParameterType_String,
							},
						},
					},
				},
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			actual := toWorkflowAPIParameter(tc.param)
			assert.Equal(t, tc.expected, actual)
		})
	}
}
|
||||
@ -1,74 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package search
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/resource/common"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/search/entity"
|
||||
search "github.com/coze-dev/coze-studio/backend/domain/search/service"
|
||||
crosssearch "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/search"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
type Notifier interface {
|
||||
PublishWorkflowResource(ctx context.Context, OpType crosssearch.OpType, event *crosssearch.Resource) error
|
||||
}
|
||||
|
||||
// Notify implements workflow resource change notification by publishing
// events on the search resource event bus.
type Notify struct {
	client search.ResourceEventBus // event bus the resource events are published to
}
|
||||
|
||||
// NewNotify creates a Notify backed by the given resource event bus.
func NewNotify(client search.ResourceEventBus) *Notify {
	return &Notify{client: client}
}
|
||||
|
||||
func (n *Notify) PublishWorkflowResource(ctx context.Context, op crosssearch.OpType, r *crosssearch.Resource) error {
|
||||
entityResource := &entity.ResourceDocument{
|
||||
ResType: common.ResType_Workflow,
|
||||
ResID: r.WorkflowID,
|
||||
ResSubType: r.Mode,
|
||||
Name: r.Name,
|
||||
SpaceID: r.SpaceID,
|
||||
OwnerID: r.OwnerID,
|
||||
APPID: r.APPID,
|
||||
}
|
||||
if r.PublishStatus != nil {
|
||||
publishStatus := *r.PublishStatus
|
||||
entityResource.PublishStatus = ptr.Of(common.PublishStatus(publishStatus))
|
||||
entityResource.PublishTimeMS = r.PublishedAt
|
||||
}
|
||||
|
||||
resource := &entity.ResourceDomainEvent{
|
||||
OpType: entity.OpType(op),
|
||||
Resource: entityResource,
|
||||
}
|
||||
if op == crosssearch.Created {
|
||||
resource.Resource.CreateTimeMS = r.CreatedAt
|
||||
resource.Resource.UpdateTimeMS = r.UpdatedAt
|
||||
} else if op == crosssearch.Updated {
|
||||
resource.Resource.UpdateTimeMS = r.UpdatedAt
|
||||
}
|
||||
|
||||
err := n.client.PublishResources(ctx, resource)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -29,7 +29,7 @@ import (
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossworkflow"
|
||||
crossworkflow "github.com/coze-dev/coze-studio/backend/crossdomain/contract/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
@ -74,6 +74,7 @@ func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) (
|
||||
var composeOpts []compose.Option
|
||||
var pipeMsgOpt compose.Option
|
||||
var workflowMsgSr *schema.StreamReader[*crossworkflow.WorkflowMessage]
|
||||
var workflowMsgCloser func()
|
||||
if r.containWfTool {
|
||||
cfReq := crossworkflow.ExecuteConfig{
|
||||
AgentID: &req.Identity.AgentID,
|
||||
@ -88,7 +89,7 @@ func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) (
|
||||
}
|
||||
wfConfig := crossworkflow.DefaultSVC().WithExecuteConfig(cfReq)
|
||||
composeOpts = append(composeOpts, wfConfig)
|
||||
pipeMsgOpt, workflowMsgSr = crossworkflow.DefaultSVC().WithMessagePipe()
|
||||
pipeMsgOpt, workflowMsgSr, workflowMsgCloser = crossworkflow.DefaultSVC().WithMessagePipe()
|
||||
composeOpts = append(composeOpts, pipeMsgOpt)
|
||||
}
|
||||
|
||||
@ -120,6 +121,9 @@ func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) (
|
||||
|
||||
sw.Send(nil, errors.New("internal server error"))
|
||||
}
|
||||
if workflowMsgCloser != nil {
|
||||
workflowMsgCloser()
|
||||
}
|
||||
sw.Close()
|
||||
}()
|
||||
_, _ = r.runner.Stream(ctx, req, composeOpts...)
|
||||
@ -136,6 +140,7 @@ func (r *AgentRunner) processWfMidAnswerStream(_ context.Context, sw *schema.Str
|
||||
if swT != nil {
|
||||
swT.Close()
|
||||
}
|
||||
wfStream.Close()
|
||||
}()
|
||||
for {
|
||||
msg, err := wfStream.Recv()
|
||||
|
||||
@ -31,7 +31,7 @@ import (
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossworkflow"
|
||||
crossworkflow "github.com/coze-dev/coze-studio/backend/crossdomain/contract/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
|
||||
@ -26,7 +26,7 @@ import (
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/app/bot_common"
|
||||
knowledgeModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossknowledge"
|
||||
crossknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/contract/knowledge"
|
||||
knowledgeEntity "github.com/coze-dev/coze-studio/backend/domain/knowledge/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
|
||||
@ -56,6 +56,7 @@ func newSuggestGraph(_ context.Context, conf *Config, chatModel chatmodel.ToolCa
|
||||
}
|
||||
suggestPrompt := prompt.FromMessages(schema.Jinja2,
|
||||
schema.SystemMessage(SUGGESTION_PROMPT_JINJA2),
|
||||
schema.UserMessage("Based on the contextual information, provide three recommended questions"),
|
||||
)
|
||||
|
||||
suggestGraph := compose.NewGraph[[]*schema.Message, *schema.Message]()
|
||||
|
||||
@ -33,7 +33,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/app/bot_common"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/data/database/table"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossdatabase"
|
||||
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/sqlparser"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
|
||||
@ -28,7 +28,7 @@ import (
|
||||
"github.com/getkin/kin-openapi/openapi3"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/app/bot_common"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossknowledge"
|
||||
crossknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/contract/knowledge"
|
||||
knowledgeEntity "github.com/coze-dev/coze-studio/backend/domain/knowledge/entity"
|
||||
)
|
||||
|
||||
|
||||
@ -27,7 +27,7 @@ import (
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/app/bot_common"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossplugin"
|
||||
crossplugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin"
|
||||
pluginEntity "github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/service"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
|
||||
@ -26,11 +26,11 @@ import (
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossplugin"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossworkflow"
|
||||
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
|
||||
crossplugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin"
|
||||
crossworkflow "github.com/coze-dev/coze-studio/backend/crossdomain/contract/workflow"
|
||||
pluginEntity "github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/service"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
@ -89,7 +89,7 @@ func (pr *toolPreCallConf) toolPreRetrieve(ctx context.Context, ar *AgentRequest
|
||||
logs.CtxErrorf(ctx, "Failed to unmarshal json arguments: %s", item.Arguments)
|
||||
return nil, err
|
||||
}
|
||||
execResp, _, err := crossworkflow.DefaultSVC().SyncExecuteWorkflow(ctx, vo.ExecuteConfig{
|
||||
execResp, _, err := crossworkflow.DefaultSVC().SyncExecuteWorkflow(ctx, workflowModel.ExecuteConfig{
|
||||
ID: item.PluginID,
|
||||
ConnectorID: ar.Identity.ConnectorID,
|
||||
ConnectorUID: ar.UserID,
|
||||
|
||||
@ -25,7 +25,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/variables"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/data/variable/kvmemory"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/data/variable/project_memory"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossvariables"
|
||||
crossvariables "github.com/coze-dev/coze-studio/backend/crossdomain/contract/variables"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
|
||||
@ -20,7 +20,8 @@ import (
|
||||
"context"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/app/bot_common"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossworkflow"
|
||||
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
|
||||
crossworkflow "github.com/coze-dev/coze-studio/backend/crossdomain/contract/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
)
|
||||
@ -36,7 +37,7 @@ func newWorkflowTools(ctx context.Context, conf *workflowConfig) ([]workflow.Too
|
||||
id := info.GetWorkflowId()
|
||||
policies = append(policies, &vo.GetPolicy{
|
||||
ID: id,
|
||||
QType: vo.FromLatestVersion,
|
||||
QType: workflowModel.FromLatestVersion,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@ -102,6 +102,18 @@ func (sa *SingleAgentDraftDAO) MGet(ctx context.Context, agentIDs []int64) ([]*e
|
||||
return dos, nil
|
||||
}
|
||||
|
||||
func (sa *SingleAgentDraftDAO) Save(ctx context.Context, agentInfo *entity.SingleAgent) (err error) {
|
||||
po := sa.singleAgentDraftDo2Po(agentInfo)
|
||||
singleAgentDAOModel := sa.dbQuery.SingleAgentDraft
|
||||
|
||||
err = singleAgentDAOModel.WithContext(ctx).Where(singleAgentDAOModel.AgentID.Eq(agentInfo.AgentID)).Save(po)
|
||||
if err != nil {
|
||||
return errorx.WrapByCode(err, errno.ErrAgentUpdateCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sa *SingleAgentDraftDAO) Update(ctx context.Context, agentInfo *entity.SingleAgent) (err error) {
|
||||
po := sa.singleAgentDraftDo2Po(agentInfo)
|
||||
singleAgentDAOModel := sa.dbQuery.SingleAgentDraft
|
||||
|
||||
@ -46,6 +46,7 @@ type SingleAgentDraftRepo interface {
|
||||
MGet(ctx context.Context, agentIDs []int64) ([]*entity.SingleAgent, error)
|
||||
Delete(ctx context.Context, spaceID, agentID int64) (err error)
|
||||
Update(ctx context.Context, agentInfo *entity.SingleAgent) (err error)
|
||||
Save(ctx context.Context, agentInfo *entity.SingleAgent) (err error)
|
||||
|
||||
GetDisplayInfo(ctx context.Context, userID, agentID int64) (*entity.AgentDraftDisplayInfo, error)
|
||||
UpdateDisplayInfo(ctx context.Context, userID int64, e *entity.AgentDraftDisplayInfo) error
|
||||
|
||||
@ -22,7 +22,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/app/developer_api"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossconnector"
|
||||
crossconnector "github.com/coze-dev/coze-studio/backend/crossdomain/contract/connector"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
|
||||
@ -27,7 +27,7 @@ import (
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/app/bot_common"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossplugin"
|
||||
crossplugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/internal/agentflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/repository"
|
||||
@ -153,7 +153,7 @@ func (s *singleAgentImpl) UpdateSingleAgentDraft(ctx context.Context, agentInfo
|
||||
}
|
||||
}
|
||||
|
||||
return s.AgentDraftRepo.Update(ctx, agentInfo)
|
||||
return s.AgentDraftRepo.Save(ctx, agentInfo)
|
||||
}
|
||||
|
||||
func (s *singleAgentImpl) CreateSingleAgentDraftWithID(ctx context.Context, creatorID, agentID int64, draft *entity.SingleAgent) (int64, error) {
|
||||
|
||||
@ -22,8 +22,8 @@ import (
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
resourceCommon "github.com/coze-dev/coze-studio/backend/api/model/resource/common"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossplugin"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossworkflow"
|
||||
crossplugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin"
|
||||
crossworkflow "github.com/coze-dev/coze-studio/backend/crossdomain/contract/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/app/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/app/repository"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
|
||||
@ -25,10 +25,10 @@ import (
|
||||
connectorModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/connector"
|
||||
databaseModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
|
||||
knowledgeModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossconnector"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossdatabase"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossknowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossplugin"
|
||||
crossconnector "github.com/coze-dev/coze-studio/backend/crossdomain/contract/connector"
|
||||
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
|
||||
crossknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/contract/knowledge"
|
||||
crossplugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/app/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/app/repository"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
|
||||
@ -37,8 +37,8 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossagent"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossmessage"
|
||||
crossagent "github.com/coze-dev/coze-studio/backend/crossdomain/contract/agent"
|
||||
crossmessage "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/internal"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/internal/dal/model"
|
||||
|
||||
@ -27,8 +27,15 @@ type WhereSliceOpt struct {
|
||||
DocumentID int64
|
||||
DocumentIDs []int64
|
||||
Keyword *string
|
||||
Sequence int64
|
||||
PageSize int64
|
||||
Offset int64
|
||||
NotEmpty *bool
|
||||
}
|
||||
|
||||
type WherePhotoSliceOpt struct {
|
||||
KnowledgeID int64
|
||||
DocumentIDs []int64
|
||||
Limit *int
|
||||
Offset *int
|
||||
HasCaption *bool
|
||||
}
|
||||
|
||||
@ -25,6 +25,7 @@ import (
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/knowledge/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/knowledge/internal/dal/model"
|
||||
@ -52,7 +53,7 @@ func (dao *KnowledgeDocumentSliceDAO) Update(ctx context.Context, slice *model.K
|
||||
}
|
||||
|
||||
func (dao *KnowledgeDocumentSliceDAO) BatchCreate(ctx context.Context, slices []*model.KnowledgeDocumentSlice) error {
|
||||
return dao.Query.KnowledgeDocumentSlice.WithContext(ctx).CreateInBatches(slices, 100)
|
||||
return dao.Query.KnowledgeDocumentSlice.WithContext(ctx).Clauses(clause.OnConflict{UpdateAll: true}).CreateInBatches(slices, 100)
|
||||
}
|
||||
|
||||
func (dao *KnowledgeDocumentSliceDAO) BatchSetStatus(ctx context.Context, ids []int64, status int32, reason string) error {
|
||||
@ -236,8 +237,11 @@ func (dao *KnowledgeDocumentSliceDAO) FindSliceByCondition(ctx context.Context,
|
||||
|
||||
if opts.PageSize != 0 {
|
||||
do = do.Limit(int(opts.PageSize))
|
||||
do = do.Offset(int(opts.Sequence)).Order(s.Sequence.Asc())
|
||||
}
|
||||
if opts.Offset != 0 {
|
||||
do = do.Offset(int(opts.Offset))
|
||||
}
|
||||
do = do.Order(s.Sequence.Asc())
|
||||
if opts.NotEmpty != nil {
|
||||
if ptr.From(opts.NotEmpty) {
|
||||
do = do.Where(s.Content.Neq(""))
|
||||
@ -319,3 +323,44 @@ func (dao *KnowledgeDocumentSliceDAO) GetLastSequence(ctx context.Context, docum
|
||||
}
|
||||
return resp.Sequence, nil
|
||||
}
|
||||
|
||||
func (dao *KnowledgeDocumentSliceDAO) ListPhotoSlice(ctx context.Context, opts *entity.WherePhotoSliceOpt) ([]*model.KnowledgeDocumentSlice, int64, error) {
|
||||
s := dao.Query.KnowledgeDocumentSlice
|
||||
do := s.WithContext(ctx)
|
||||
if opts.KnowledgeID != 0 {
|
||||
do = do.Where(s.KnowledgeID.Eq(opts.KnowledgeID))
|
||||
}
|
||||
if len(opts.DocumentIDs) != 0 {
|
||||
do = do.Where(s.DocumentID.In(opts.DocumentIDs...))
|
||||
}
|
||||
if ptr.From(opts.Limit) != 0 {
|
||||
do = do.Limit(int(ptr.From(opts.Limit)))
|
||||
}
|
||||
if ptr.From(opts.Offset) != 0 {
|
||||
do = do.Offset(int(ptr.From(opts.Offset)))
|
||||
}
|
||||
if opts.HasCaption != nil {
|
||||
if ptr.From(opts.HasCaption) {
|
||||
do = do.Where(s.Content.Neq(""))
|
||||
} else {
|
||||
do = do.Where(s.Content.Eq(""))
|
||||
}
|
||||
}
|
||||
do = do.Order(s.UpdatedAt.Desc())
|
||||
pos, err := do.Find()
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
total, err := do.Limit(-1).Offset(-1).Count()
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return pos, total, nil
|
||||
}
|
||||
|
||||
func (dao *KnowledgeDocumentSliceDAO) BatchCreateWithTX(ctx context.Context, tx *gorm.DB, slices []*model.KnowledgeDocumentSlice) error {
|
||||
if len(slices) == 0 {
|
||||
return nil
|
||||
}
|
||||
return tx.WithContext(ctx).Debug().Model(&model.KnowledgeDocumentSlice{}).CreateInBatches(slices, 100).Error
|
||||
}
|
||||
|
||||
@ -49,8 +49,9 @@ type baseDocProcessor struct {
|
||||
documentSource *entity.DocumentSource
|
||||
|
||||
// Drop DB model
|
||||
TableName string
|
||||
docModels []*model.KnowledgeDocument
|
||||
TableName string
|
||||
docModels []*model.KnowledgeDocument
|
||||
imageSlices []*model.KnowledgeDocumentSlice
|
||||
|
||||
storage storage.Storage
|
||||
knowledgeRepo repository.KnowledgeRepo
|
||||
@ -69,14 +70,14 @@ func (p *baseDocProcessor) BeforeCreate() error {
|
||||
|
||||
func (p *baseDocProcessor) BuildDBModel() error {
|
||||
p.docModels = make([]*model.KnowledgeDocument, 0, len(p.Documents))
|
||||
ids, err := p.idgen.GenMultiIDs(p.ctx, len(p.Documents))
|
||||
if err != nil {
|
||||
logs.CtxErrorf(p.ctx, "gen ids failed, err: %v", err)
|
||||
return errorx.New(errno.ErrKnowledgeIDGenCode)
|
||||
}
|
||||
for i := range p.Documents {
|
||||
id, err := p.idgen.GenID(p.ctx)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(p.ctx, "gen id failed, err: %v", err)
|
||||
return errorx.New(errno.ErrKnowledgeIDGenCode)
|
||||
}
|
||||
docModel := &model.KnowledgeDocument{
|
||||
ID: ids[i],
|
||||
ID: id,
|
||||
KnowledgeID: p.Documents[i].KnowledgeID,
|
||||
Name: p.Documents[i].Name,
|
||||
FileExtension: string(p.Documents[i].FileExtension),
|
||||
@ -95,6 +96,23 @@ func (p *baseDocProcessor) BuildDBModel() error {
|
||||
}
|
||||
p.Documents[i].ID = docModel.ID
|
||||
p.docModels = append(p.docModels, docModel)
|
||||
if p.Documents[i].Type == knowledge.DocumentTypeImage {
|
||||
id, err := p.idgen.GenID(p.ctx)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(p.ctx, "gen id failed, err: %v", err)
|
||||
return errorx.New(errno.ErrKnowledgeIDGenCode)
|
||||
}
|
||||
p.imageSlices = append(p.imageSlices, &model.KnowledgeDocumentSlice{
|
||||
ID: id,
|
||||
KnowledgeID: p.Documents[i].KnowledgeID,
|
||||
DocumentID: p.Documents[i].ID,
|
||||
CreatedAt: time.Now().UnixMilli(),
|
||||
UpdatedAt: time.Now().UnixMilli(),
|
||||
CreatorID: p.UserID,
|
||||
SpaceID: p.SpaceID,
|
||||
Status: int32(knowledge.SliceStatusInit),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -142,6 +160,11 @@ func (p *baseDocProcessor) InsertDBModel() (err error) {
|
||||
logs.CtxErrorf(ctx, "create document failed, err: %v", err)
|
||||
return errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", err.Error()))
|
||||
}
|
||||
err = p.sliceRepo.BatchCreateWithTX(ctx, tx, p.imageSlices)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "update knowledge failed, err: %v", err)
|
||||
return errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", err.Error()))
|
||||
}
|
||||
err = p.knowledgeRepo.UpdateWithTx(ctx, tx, p.Documents[0].KnowledgeID, map[string]interface{}{
|
||||
"updated_at": time.Now().UnixMilli(),
|
||||
})
|
||||
|
||||
@ -84,12 +84,12 @@ type KnowledgeDocumentSliceRepo interface {
|
||||
Create(ctx context.Context, slice *model.KnowledgeDocumentSlice) error
|
||||
Update(ctx context.Context, slice *model.KnowledgeDocumentSlice) error
|
||||
Delete(ctx context.Context, slice *model.KnowledgeDocumentSlice) error
|
||||
|
||||
BatchCreateWithTX(ctx context.Context, tx *gorm.DB, slices []*model.KnowledgeDocumentSlice) error
|
||||
BatchCreate(ctx context.Context, slices []*model.KnowledgeDocumentSlice) error
|
||||
BatchSetStatus(ctx context.Context, ids []int64, status int32, reason string) error
|
||||
DeleteByDocument(ctx context.Context, documentID int64) error
|
||||
MGetSlices(ctx context.Context, sliceIDs []int64) ([]*model.KnowledgeDocumentSlice, error)
|
||||
|
||||
ListPhotoSlice(ctx context.Context, opts *entity.WherePhotoSliceOpt) ([]*model.KnowledgeDocumentSlice, int64, error)
|
||||
FindSliceByCondition(ctx context.Context, opts *entity.WhereSliceOpt) (
|
||||
[]*model.KnowledgeDocumentSlice, int64, error)
|
||||
GetDocumentSliceIDs(ctx context.Context, docIDs []int64) (sliceIDs []int64, err error)
|
||||
|
||||
@ -28,7 +28,7 @@ import (
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
knowledgeModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossdatacopy"
|
||||
crossdatacopy "github.com/coze-dev/coze-studio/backend/crossdomain/contract/datacopy"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/datacopy"
|
||||
copyEntity "github.com/coze-dev/coze-studio/backend/domain/datacopy/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/knowledge/entity"
|
||||
|
||||
@ -35,12 +35,15 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/knowledge/internal/dal/model"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/knowledge/internal/events"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
|
||||
progressbarContract "github.com/coze-dev/coze-studio/backend/infra/contract/document/progressbar"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/eventbus"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/rdb"
|
||||
rdbEntity "github.com/coze-dev/coze-studio/backend/infra/contract/rdb/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/progressbar"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
@ -139,96 +142,166 @@ func (k *knowledgeSVC) indexDocuments(ctx context.Context, event *entity.Event)
|
||||
return nil
|
||||
}
|
||||
|
||||
type indexDocCacheRecord struct {
|
||||
ProcessingIDs []int64
|
||||
LastProcessedNumber int64
|
||||
ParseUri string
|
||||
}
|
||||
|
||||
const (
|
||||
indexDocCacheKey = "index_doc_cache:%d:%d"
|
||||
)
|
||||
|
||||
// indexDocumentNew handles the indexing of a new document into the knowledge system
|
||||
func (k *knowledgeSVC) indexDocument(ctx context.Context, event *entity.Event) (err error) {
|
||||
doc := event.Document
|
||||
if doc == nil {
|
||||
return errorx.New(errno.ErrKnowledgeNonRetryableCode, errorx.KV("reason", "[indexDocument] document not provided"))
|
||||
}
|
||||
|
||||
// 1. The index operations on the same document in the retry queue and the ordinary queue are concurrent, and the same document data is written twice (generated when the backend bugfix is online)
|
||||
// 2. rebalance repeated consumption of the same message
|
||||
|
||||
// check knowledge and document status
|
||||
if valid, err := k.isWritableKnowledgeAndDocument(ctx, doc.KnowledgeID, doc.ID); err != nil {
|
||||
return err
|
||||
} else if !valid {
|
||||
return errorx.New(errno.ErrKnowledgeNonRetryableCode,
|
||||
errorx.KVf("reason", "[indexDocument] not writable, knowledge_id=%d, document_id=%d", event.KnowledgeID, doc.ID))
|
||||
errorx.KV("reason", "[indexDocument] document not provided"))
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
err = errorx.New(errno.ErrKnowledgeSystemCode, errorx.KV("msg", fmt.Sprintf("panic: %v", e)))
|
||||
logs.CtxErrorf(ctx, "[indexDocument] panic, err: %v", err)
|
||||
if setStatusErr := k.documentRepo.SetStatus(ctx, event.Document.ID, int32(entity.DocumentStatusFailed), err.Error()); setStatusErr != nil {
|
||||
logs.CtxErrorf(ctx, "[indexDocument] set document status failed, err: %v", setStatusErr)
|
||||
}
|
||||
// Validate document and knowledge status
|
||||
var valid bool
|
||||
if valid, err = k.validateDocumentStatus(ctx, doc); err != nil || !valid {
|
||||
return
|
||||
}
|
||||
|
||||
// Setup error handling and recovery
|
||||
defer k.handleIndexingErrors(ctx, event, &err)
|
||||
|
||||
// Start indexing process
|
||||
if err = k.beginIndexingProcess(ctx, doc); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Process document parsing and chunking
|
||||
var parseResult []*schema.Document
|
||||
var cacheRecord *indexDocCacheRecord
|
||||
parseResult, cacheRecord, err = k.processDocumentParsing(ctx, doc)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if cacheRecord.LastProcessedNumber == 0 {
|
||||
if err = k.cleanupPreviousProcessing(ctx, doc); err != nil {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
var errMsg string
|
||||
var statusError errorx.StatusError
|
||||
var status int32
|
||||
if errors.As(err, &statusError) {
|
||||
errMsg = errorx.ErrorWithoutStack(statusError)
|
||||
if statusError.Code() == errno.ErrKnowledgeNonRetryableCode {
|
||||
status = int32(entity.DocumentStatusFailed)
|
||||
} else {
|
||||
status = int32(entity.DocumentStatusChunking)
|
||||
}
|
||||
}
|
||||
// Handle table-type documents specially
|
||||
if doc.Type == knowledge.DocumentTypeTable {
|
||||
if err = k.handleTableDocument(ctx, doc, parseResult); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Process document chunks in batches
|
||||
if err = k.processDocumentChunks(ctx, doc, parseResult, cacheRecord); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Finalize document indexing
|
||||
err = k.finalizeDocumentIndexing(ctx, event.Document.KnowledgeID, event.Document.ID)
|
||||
return
|
||||
}
|
||||
|
||||
// validateDocumentStatus checks if the document can be indexed
|
||||
func (k *knowledgeSVC) validateDocumentStatus(ctx context.Context, doc *entity.Document) (bool, error) {
|
||||
valid, err := k.isWritableKnowledgeAndDocument(ctx, doc.KnowledgeID, doc.ID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if !valid {
|
||||
return false, errorx.New(errno.ErrKnowledgeNonRetryableCode,
|
||||
errorx.KVf("reason", "[indexDocument] not writable, knowledge_id=%d, document_id=%d",
|
||||
doc.KnowledgeID, doc.ID))
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// handleIndexingErrors manages errors and recovery during indexing
|
||||
func (k *knowledgeSVC) handleIndexingErrors(ctx context.Context, event *entity.Event, err *error) {
|
||||
if e := recover(); e != nil {
|
||||
err = ptr.Of(errorx.New(errno.ErrKnowledgeSystemCode,
|
||||
errorx.KV("msg", fmt.Sprintf("panic: %v", e))))
|
||||
logs.CtxErrorf(ctx, "[indexDocument] panic, err: %v", err)
|
||||
k.setDocumentStatus(ctx, event.Document.ID,
|
||||
int32(entity.DocumentStatusFailed), ptr.From(err).Error())
|
||||
return
|
||||
}
|
||||
|
||||
if ptr.From(err) != nil {
|
||||
var status int32
|
||||
var errMsg string
|
||||
|
||||
var statusError errorx.StatusError
|
||||
if errors.As(ptr.From(err), &statusError) {
|
||||
errMsg = errorx.ErrorWithoutStack(statusError)
|
||||
if statusError.Code() == errno.ErrKnowledgeNonRetryableCode {
|
||||
status = int32(entity.DocumentStatusFailed)
|
||||
} else {
|
||||
errMsg = err.Error()
|
||||
status = int32(entity.DocumentStatusChunking)
|
||||
}
|
||||
if setStatusErr := k.documentRepo.SetStatus(ctx, event.Document.ID, status, errMsg); setStatusErr != nil {
|
||||
logs.CtxErrorf(ctx, "[indexDocument] set document status failed, err: %v", setStatusErr)
|
||||
}
|
||||
} else {
|
||||
errMsg = ptr.From(err).Error()
|
||||
status = int32(entity.DocumentStatusChunking)
|
||||
}
|
||||
}()
|
||||
|
||||
// clear
|
||||
collectionName := getCollectionName(doc.KnowledgeID)
|
||||
|
||||
if !doc.IsAppend {
|
||||
ids, err := k.sliceRepo.GetDocumentSliceIDs(ctx, []int64{doc.ID})
|
||||
if err != nil {
|
||||
return errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", fmt.Sprintf("get document slice ids failed, err: %v", err)))
|
||||
}
|
||||
if len(ids) > 0 {
|
||||
if err = k.sliceRepo.DeleteByDocument(ctx, doc.ID); err != nil {
|
||||
return errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", fmt.Sprintf("delete document slice failed, err: %v", err)))
|
||||
}
|
||||
for _, manager := range k.searchStoreManagers {
|
||||
s, err := manager.GetSearchStore(ctx, collectionName)
|
||||
if err != nil {
|
||||
return errorx.New(errno.ErrKnowledgeSearchStoreCode, errorx.KV("msg", fmt.Sprintf("get search store failed, err: %v", err)))
|
||||
}
|
||||
if err := s.Delete(ctx, slices.Transform(event.SliceIDs, func(id int64) string {
|
||||
return strconv.FormatInt(id, 10)
|
||||
})); err != nil {
|
||||
logs.Errorf("[indexDocument] delete knowledge failed, err: %v", err)
|
||||
return errorx.New(errno.ErrKnowledgeSearchStoreCode, errorx.KV("msg", fmt.Sprintf("delete search store failed, err: %v", err)))
|
||||
}
|
||||
}
|
||||
}
|
||||
k.setDocumentStatus(ctx, event.Document.ID, status, errMsg)
|
||||
}
|
||||
}
|
||||
|
||||
// set chunk status
|
||||
if err = k.documentRepo.SetStatus(ctx, doc.ID, int32(entity.DocumentStatusChunking), ""); err != nil {
|
||||
// beginIndexingProcess starts the indexing process
|
||||
func (k *knowledgeSVC) beginIndexingProcess(ctx context.Context, doc *entity.Document) error {
|
||||
err := k.documentRepo.SetStatus(ctx, doc.ID, int32(entity.DocumentStatusChunking), "")
|
||||
if err != nil {
|
||||
return errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", fmt.Sprintf("set document status failed, err: %v", err)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// parse & chunk
|
||||
// processDocumentParsing handles document parsing and caching
|
||||
func (k *knowledgeSVC) processDocumentParsing(ctx context.Context, doc *entity.Document) (
|
||||
[]*schema.Document, *indexDocCacheRecord, error) {
|
||||
|
||||
cacheKey := fmt.Sprintf(indexDocCacheKey, doc.KnowledgeID, doc.ID)
|
||||
cacheRecord := &indexDocCacheRecord{}
|
||||
|
||||
// Try to get cached parse results
|
||||
val, err := k.cacheCli.Get(ctx, cacheKey).Result()
|
||||
if err == nil {
|
||||
if err = sonic.UnmarshalString(val, &cacheRecord); err != nil {
|
||||
return nil, nil, errorx.New(errno.ErrKnowledgeParseJSONCode,
|
||||
errorx.KV("msg", fmt.Sprintf("parse cache record failed, err: %v", err)))
|
||||
}
|
||||
}
|
||||
|
||||
// Parse document if not cached
|
||||
if err != nil || len(cacheRecord.ParseUri) == 0 {
|
||||
return k.parseAndCacheDocument(ctx, doc, cacheRecord, cacheKey)
|
||||
}
|
||||
|
||||
// Load parse results from cache
|
||||
return k.loadParsedDocument(ctx, cacheRecord)
|
||||
}
|
||||
|
||||
// parseAndCacheDocument parses the document and caches the results
|
||||
func (k *knowledgeSVC) parseAndCacheDocument(ctx context.Context, doc *entity.Document,
|
||||
cacheRecord *indexDocCacheRecord, cacheKey string) ([]*schema.Document, *indexDocCacheRecord, error) {
|
||||
|
||||
// Get document content from storage
|
||||
bodyBytes, err := k.storage.GetObject(ctx, doc.URI)
|
||||
if err != nil {
|
||||
return errorx.New(errno.ErrKnowledgeGetObjectFailCode, errorx.KV("msg", fmt.Sprintf("get object failed, err: %v", err)))
|
||||
return nil, nil, errorx.New(errno.ErrKnowledgeGetObjectFailCode,
|
||||
errorx.KV("msg", fmt.Sprintf("get object failed, err: %v", err)))
|
||||
}
|
||||
|
||||
// Get appropriate parser for document type
|
||||
docParser, err := k.parseManager.GetParser(convert.DocumentToParseConfig(doc))
|
||||
if err != nil {
|
||||
return errorx.New(errno.ErrKnowledgeGetParserFailCode, errorx.KV("msg", fmt.Sprintf("get parser failed, err: %v", err)))
|
||||
return nil, nil, errorx.New(errno.ErrKnowledgeGetParserFailCode,
|
||||
errorx.KV("msg", fmt.Sprintf("get parser failed, err: %v", err)))
|
||||
}
|
||||
|
||||
// Parse document content
|
||||
parseResult, err := docParser.Parse(ctx, bytes.NewReader(bodyBytes), parser.WithExtraMeta(map[string]any{
|
||||
document.MetaDataKeyCreatorID: doc.CreatorID,
|
||||
document.MetaDataKeyExternalStorage: map[string]any{
|
||||
@ -236,77 +309,303 @@ func (k *knowledgeSVC) indexDocument(ctx context.Context, event *entity.Event) (
|
||||
},
|
||||
}))
|
||||
if err != nil {
|
||||
return errorx.New(errno.ErrKnowledgeParserParseFailCode, errorx.KV("msg", fmt.Sprintf("parse document failed, err: %v", err)))
|
||||
return nil, nil, errorx.New(errno.ErrKnowledgeParserParseFailCode,
|
||||
errorx.KV("msg", fmt.Sprintf("parse document failed, err: %v", err)))
|
||||
}
|
||||
|
||||
if doc.Type == knowledge.DocumentTypeTable {
|
||||
noData, err := document.GetDocumentsColumnsOnly(parseResult)
|
||||
if err != nil { // unexpected
|
||||
return errorx.New(errno.ErrKnowledgeNonRetryableCode,
|
||||
errorx.KVf("reason", "[indexDocument] get table data status failed, err: %v", err))
|
||||
}
|
||||
if noData {
|
||||
parseResult = nil // clear parse result
|
||||
}
|
||||
// Cache parse results
|
||||
if err := k.cacheParseResults(ctx, doc, parseResult, cacheRecord, cacheKey); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// set id
|
||||
allIDs := make([]int64, 0, len(parseResult))
|
||||
for l := 0; l < len(parseResult); l += 100 {
|
||||
r := min(l+100, len(parseResult))
|
||||
batchSize := r - l
|
||||
ids, err := k.idgen.GenMultiIDs(ctx, batchSize)
|
||||
if err != nil {
|
||||
return errorx.New(errno.ErrKnowledgeIDGenCode, errorx.KV("msg", fmt.Sprintf("GenMultiIDs failed, err: %v", err)))
|
||||
}
|
||||
allIDs = append(allIDs, ids...)
|
||||
for i := 0; i < batchSize; i++ {
|
||||
id := ids[i]
|
||||
index := l + i
|
||||
parseResult[index].ID = strconv.FormatInt(id, 10)
|
||||
}
|
||||
}
|
||||
return parseResult, cacheRecord, nil
|
||||
}
|
||||
|
||||
convertFn := d2sMapping[doc.Type]
|
||||
if convertFn == nil {
|
||||
return errorx.New(errno.ErrKnowledgeSystemCode, errorx.KV("msg", "convertFn is empty"))
|
||||
}
|
||||
// cacheParseResults stores parse results in persistent storage and cache
|
||||
func (k *knowledgeSVC) cacheParseResults(ctx context.Context, doc *entity.Document,
|
||||
parseResult []*schema.Document, cacheRecord *indexDocCacheRecord, cacheKey string) error {
|
||||
|
||||
sliceEntities, err := slices.TransformWithErrorCheck(parseResult, func(a *schema.Document) (*entity.Slice, error) {
|
||||
return convertFn(a, doc.KnowledgeID, doc.ID, doc.CreatorID)
|
||||
})
|
||||
parseResultData, err := sonic.Marshal(parseResult)
|
||||
if err != nil {
|
||||
return errorx.New(errno.ErrKnowledgeSystemCode, errorx.KV("msg", fmt.Sprintf("convert document failed, err: %v", err)))
|
||||
return errorx.New(errno.ErrKnowledgeParseJSONCode,
|
||||
errorx.KV("msg", fmt.Sprintf("marshal parse result failed, err: %v", err)))
|
||||
}
|
||||
|
||||
// save slices
|
||||
if doc.Type == knowledge.DocumentTypeTable {
|
||||
// Table type to insert data into a database
|
||||
err = k.upsertDataToTable(ctx, &doc.TableInfo, sliceEntities)
|
||||
fileName := fmt.Sprintf("FileBizType.Knowledge/%d_%d.txt", doc.CreatorID, doc.ID)
|
||||
if err = k.storage.PutObject(ctx, fileName, parseResultData); err != nil {
|
||||
return errorx.New(errno.ErrKnowledgePutObjectFailCode,
|
||||
errorx.KV("msg", fmt.Sprintf("put object failed, err: %v", err)))
|
||||
}
|
||||
|
||||
cacheRecord.ParseUri = fileName
|
||||
return k.recordIndexDocumentStatus(ctx, cacheRecord, cacheKey)
|
||||
}
|
||||
|
||||
// loadParsedDocument loads previously parsed document from cache
|
||||
func (k *knowledgeSVC) loadParsedDocument(ctx context.Context,
|
||||
cacheRecord *indexDocCacheRecord) ([]*schema.Document, *indexDocCacheRecord, error) {
|
||||
|
||||
data, err := k.storage.GetObject(ctx, cacheRecord.ParseUri)
|
||||
if err != nil {
|
||||
return nil, nil, errorx.New(errno.ErrKnowledgeGetObjectFailCode,
|
||||
errorx.KV("msg", fmt.Sprintf("get object failed, err: %v", err)))
|
||||
}
|
||||
|
||||
var parseResult []*schema.Document
|
||||
if err = sonic.Unmarshal(data, &parseResult); err != nil {
|
||||
return nil, nil, errorx.New(errno.ErrKnowledgeParseJSONCode,
|
||||
errorx.KV("msg", fmt.Sprintf("marshal parse result failed, err: %v", err)))
|
||||
}
|
||||
|
||||
return parseResult, cacheRecord, nil
|
||||
}
|
||||
|
||||
// handleTableDocument handles special processing for table-type documents
|
||||
func (k *knowledgeSVC) handleTableDocument(ctx context.Context,
|
||||
doc *entity.Document, parseResult []*schema.Document) error {
|
||||
|
||||
noData, err := document.GetDocumentsColumnsOnly(parseResult)
|
||||
if err != nil {
|
||||
return errorx.New(errno.ErrKnowledgeNonRetryableCode,
|
||||
errorx.KVf("reason", "[indexDocument] get table data status failed, err: %v", err))
|
||||
}
|
||||
if noData {
|
||||
parseResult = nil // clear parse result
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// processDocumentChunks processes document chunks in batches
|
||||
func (k *knowledgeSVC) processDocumentChunks(ctx context.Context,
|
||||
doc *entity.Document, parseResult []*schema.Document, cacheRecord *indexDocCacheRecord) error {
|
||||
|
||||
batchSize := 100
|
||||
progressbar := progressbar.NewProgressBar(ctx, doc.ID,
|
||||
int64(len(parseResult)*len(k.searchStoreManagers)), k.cacheCli, true)
|
||||
|
||||
if err := progressbar.AddN(int(cacheRecord.LastProcessedNumber) * len(k.searchStoreManagers)); err != nil {
|
||||
return errorx.New(errno.ErrKnowledgeSystemCode,
|
||||
errorx.KV("msg", fmt.Sprintf("add progress bar failed, err: %v", err)))
|
||||
}
|
||||
|
||||
// Process chunks in batches
|
||||
for i := int(cacheRecord.LastProcessedNumber); i < len(parseResult); i += batchSize {
|
||||
chunks := parseResult[i:min(i+batchSize, len(parseResult))]
|
||||
if err := k.batchProcessSlice(ctx, doc, i, chunks, cacheRecord, progressbar); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// finalizeDocumentIndexing completes the document indexing process
|
||||
func (k *knowledgeSVC) finalizeDocumentIndexing(ctx context.Context, knowledgeID, documentID int64) error {
|
||||
if err := k.documentRepo.SetStatus(ctx, documentID, int32(entity.DocumentStatusEnable), ""); err != nil {
|
||||
return errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", fmt.Sprintf("set document status failed, err: %v", err)))
|
||||
}
|
||||
if err := k.documentRepo.UpdateDocumentSliceInfo(ctx, documentID); err != nil {
|
||||
return errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", fmt.Sprintf("update document slice info failed, err: %v", err)))
|
||||
}
|
||||
if err := k.cacheCli.Del(ctx, fmt.Sprintf(indexDocCacheKey, knowledgeID, documentID)).Err(); err != nil {
|
||||
return errorx.New(errno.ErrKnowledgeSystemCode, errorx.KV("msg", fmt.Sprintf("del cache failed, err: %v", err)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// batchProcessSlice processes a batch of document slices
|
||||
func (k *knowledgeSVC) batchProcessSlice(ctx context.Context, doc *entity.Document,
|
||||
startIdx int, parseResult []*schema.Document, cacheRecord *indexDocCacheRecord,
|
||||
progressBar progressbarContract.ProgressBar) error {
|
||||
|
||||
collectionName := getCollectionName(doc.KnowledgeID)
|
||||
length := len(parseResult)
|
||||
var ids []int64
|
||||
var err error
|
||||
// Generate IDs for this batch
|
||||
if len(cacheRecord.ProcessingIDs) == 0 {
|
||||
ids, err = k.genMultiIDs(ctx, length)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
ids = cacheRecord.ProcessingIDs
|
||||
}
|
||||
for idx := range parseResult {
|
||||
parseResult[idx].ID = strconv.FormatInt(ids[idx], 10)
|
||||
}
|
||||
// Update cache record with processing IDs
|
||||
cacheRecord.ProcessingIDs = ids
|
||||
if err := k.recordIndexDocumentStatus(ctx, cacheRecord,
|
||||
fmt.Sprintf(indexDocCacheKey, doc.KnowledgeID, doc.ID)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Convert documents to slices
|
||||
sliceEntities, err := k.convertToSlices(doc, parseResult)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Handle table-type documents
|
||||
if doc.Type == knowledge.DocumentTypeTable {
|
||||
if err := k.upsertDataToTable(ctx, &doc.TableInfo, sliceEntities); err != nil {
|
||||
logs.CtxErrorf(ctx, "[indexDocument] insert data to table failed, err: %v", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Store slices in database
|
||||
|
||||
if err := k.storeSlicesInDB(ctx, doc, parseResult, startIdx, ids); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Index slices in search stores
|
||||
if err := k.indexSlicesInSearchStores(ctx, doc, collectionName, sliceEntities,
|
||||
cacheRecord, progressBar); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update cache record after successful processing
|
||||
cacheRecord.LastProcessedNumber = int64(startIdx) + int64(length)
|
||||
cacheRecord.ProcessingIDs = nil
|
||||
|
||||
// Mark slices as done
|
||||
err = k.sliceRepo.BatchSetStatus(ctx, ids, int32(model.SliceStatusDone), "")
|
||||
if err != nil {
|
||||
return errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", fmt.Sprintf("batch set slice status failed, err: %v", err)))
|
||||
}
|
||||
|
||||
if err := k.recordIndexDocumentStatus(ctx, cacheRecord,
|
||||
fmt.Sprintf(indexDocCacheKey, doc.KnowledgeID, doc.ID)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// convertToSlices converts parsed documents to slice entities
|
||||
func (k *knowledgeSVC) convertToSlices(doc *entity.Document, parseResult []*schema.Document) ([]*entity.Slice, error) {
|
||||
|
||||
convertFn := d2sMapping[doc.Type]
|
||||
if convertFn == nil {
|
||||
return nil, errorx.New(errno.ErrKnowledgeSystemCode,
|
||||
errorx.KV("msg", "convertFn is empty"))
|
||||
}
|
||||
|
||||
return slices.TransformWithErrorCheck(parseResult, func(a *schema.Document) (*entity.Slice, error) {
|
||||
return convertFn(a, doc.KnowledgeID, doc.ID, doc.CreatorID)
|
||||
})
|
||||
}
|
||||
|
||||
// cleanupPreviousProcessing cleans up partially processed data from previous attempts
|
||||
func (k *knowledgeSVC) cleanupPreviousProcessing(ctx context.Context, doc *entity.Document) error {
|
||||
collectionName := getCollectionName(doc.KnowledgeID)
|
||||
if doc.IsAppend || doc.Type == knowledge.DocumentTypeImage {
|
||||
return nil
|
||||
}
|
||||
ids, err := k.sliceRepo.GetDocumentSliceIDs(ctx, []int64{doc.ID})
|
||||
if err != nil {
|
||||
return errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", fmt.Sprintf("get document slice ids failed, err: %v", err)))
|
||||
}
|
||||
if len(ids) > 0 {
|
||||
if err = k.sliceRepo.DeleteByDocument(ctx, doc.ID); err != nil {
|
||||
return errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", fmt.Sprintf("delete document slice failed, err: %v", err)))
|
||||
}
|
||||
|
||||
for _, manager := range k.searchStoreManagers {
|
||||
s, err := manager.GetSearchStore(ctx, collectionName)
|
||||
if err != nil {
|
||||
return errorx.New(errno.ErrKnowledgeSearchStoreCode, errorx.KV("msg", fmt.Sprintf("get search store failed, err: %v", err)))
|
||||
}
|
||||
if err := s.Delete(ctx, slices.Transform(ids, func(id int64) string {
|
||||
return strconv.FormatInt(id, 10)
|
||||
})); err != nil {
|
||||
logs.Errorf("[indexDocument] delete knowledge failed, err: %v", err)
|
||||
return errorx.New(errno.ErrKnowledgeSearchStoreCode, errorx.KV("msg", fmt.Sprintf("delete search store failed, err: %v", err)))
|
||||
}
|
||||
}
|
||||
}
|
||||
if doc.Type == knowledge.DocumentTypeTable {
|
||||
_, err := k.rdb.DeleteData(ctx, &rdb.DeleteDataRequest{
|
||||
TableName: doc.TableInfo.PhysicalTableName,
|
||||
Where: &rdb.ComplexCondition{
|
||||
Conditions: []*rdb.Condition{
|
||||
{
|
||||
Field: consts.RDBFieldID,
|
||||
Operator: rdbEntity.OperatorIn,
|
||||
Value: ids,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "delete data failed, err: %v", err)
|
||||
return errorx.New(errno.ErrKnowledgeCrossDomainCode, errorx.KV("msg", err.Error()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// storeSlicesInDB stores slice data in the database
|
||||
func (k *knowledgeSVC) storeSlicesInDB(ctx context.Context, doc *entity.Document,
|
||||
parseResult []*schema.Document, startIdx int, ids []int64) error {
|
||||
|
||||
var seqOffset float64
|
||||
var err error
|
||||
|
||||
if doc.IsAppend {
|
||||
seqOffset, err = k.sliceRepo.GetLastSequence(ctx, doc.ID)
|
||||
if err != nil {
|
||||
return errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", fmt.Sprintf("get last sequence failed, err: %v", err)))
|
||||
return errorx.New(errno.ErrKnowledgeDBCode,
|
||||
errorx.KV("msg", fmt.Sprintf("get last sequence failed, err: %v", err)))
|
||||
}
|
||||
seqOffset += 1
|
||||
}
|
||||
|
||||
if doc.Type == knowledge.DocumentTypeImage {
|
||||
if len(parseResult) != 0 {
|
||||
slices, _, err := k.sliceRepo.FindSliceByCondition(ctx, &entity.WhereSliceOpt{DocumentID: doc.ID})
|
||||
if err != nil {
|
||||
return errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", fmt.Sprintf("find slice failed, err: %v", err)))
|
||||
}
|
||||
var slice *model.KnowledgeDocumentSlice
|
||||
if len(slices) > 0 {
|
||||
slice = slices[0]
|
||||
slice.Content = parseResult[0].Content
|
||||
} else {
|
||||
id, err := k.idgen.GenID(ctx)
|
||||
if err != nil {
|
||||
return errorx.New(errno.ErrKnowledgeIDGenCode, errorx.KV("msg", fmt.Sprintf("GenID failed, err: %v", err)))
|
||||
}
|
||||
slice = &model.KnowledgeDocumentSlice{
|
||||
ID: id,
|
||||
KnowledgeID: doc.KnowledgeID,
|
||||
DocumentID: doc.ID,
|
||||
Content: parseResult[0].Content,
|
||||
CreatedAt: time.Now().UnixMilli(),
|
||||
UpdatedAt: time.Now().UnixMilli(),
|
||||
CreatorID: doc.CreatorID,
|
||||
SpaceID: doc.SpaceID,
|
||||
Status: int32(model.SliceStatusProcessing),
|
||||
FailReason: "",
|
||||
}
|
||||
}
|
||||
if err = k.sliceRepo.Update(ctx, slice); err != nil {
|
||||
return errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", fmt.Sprintf("update slice failed, err: %v", err)))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
sliceModels := make([]*model.KnowledgeDocumentSlice, 0, len(parseResult))
|
||||
for i, src := range parseResult {
|
||||
now := time.Now().UnixMilli()
|
||||
sliceModel := &model.KnowledgeDocumentSlice{
|
||||
ID: allIDs[i],
|
||||
ID: ids[i],
|
||||
KnowledgeID: doc.KnowledgeID,
|
||||
DocumentID: doc.ID,
|
||||
Content: parseResult[i].Content,
|
||||
Sequence: seqOffset + float64(i),
|
||||
Sequence: seqOffset + float64(i+startIdx),
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
CreatorID: doc.CreatorID,
|
||||
@ -314,83 +613,108 @@ func (k *knowledgeSVC) indexDocument(ctx context.Context, event *entity.Event) (
|
||||
Status: int32(model.SliceStatusProcessing),
|
||||
FailReason: "",
|
||||
}
|
||||
|
||||
if doc.Type == knowledge.DocumentTypeTable {
|
||||
convertFn := d2sMapping[doc.Type]
|
||||
sliceEntity, err := convertFn(src, doc.KnowledgeID, doc.ID, doc.CreatorID)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "[indexDocument] convert document failed, err: %v", err)
|
||||
return errorx.New(errno.ErrKnowledgeSystemCode, errorx.KV("msg", fmt.Sprintf("convert document failed, err: %v", err)))
|
||||
return errorx.New(errno.ErrKnowledgeSystemCode,
|
||||
errorx.KV("msg", fmt.Sprintf("convert document failed, err: %v", err)))
|
||||
}
|
||||
sliceModel.Content = sliceEntity.GetSliceContent()
|
||||
}
|
||||
|
||||
sliceModels = append(sliceModels, sliceModel)
|
||||
}
|
||||
if err = k.sliceRepo.BatchCreate(ctx, sliceModels); err != nil {
|
||||
return errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", fmt.Sprintf("batch create slice failed, err: %v", err)))
|
||||
|
||||
err = k.sliceRepo.BatchCreate(ctx, sliceModels)
|
||||
if err != nil {
|
||||
return errorx.New(errno.ErrKnowledgeDBCode,
|
||||
errorx.KV("msg", fmt.Sprintf("batch create slice failed, err: %v", err)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err != nil { // set slice status
|
||||
if setStatusErr := k.sliceRepo.BatchSetStatus(ctx, allIDs, int32(model.SliceStatusFailed), err.Error()); setStatusErr != nil {
|
||||
logs.CtxErrorf(ctx, "[indexDocument] set slice status failed, err: %v", setStatusErr)
|
||||
}
|
||||
}
|
||||
}()
|
||||
// indexSlicesInSearchStores indexes slices in appropriate search stores
|
||||
func (k *knowledgeSVC) indexSlicesInSearchStores(ctx context.Context, doc *entity.Document,
|
||||
collectionName string, sliceEntities []*entity.Slice, cacheRecord *indexDocCacheRecord,
|
||||
progressBar progressbarContract.ProgressBar) error {
|
||||
|
||||
// to vectorstore
|
||||
fields, err := k.mapSearchFields(doc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
indexingFields := getIndexingFields(fields)
|
||||
|
||||
// reformat docs, mainly for enableCompactTable
|
||||
// Convert slices to search documents
|
||||
ssDocs, err := slices.TransformWithErrorCheck(sliceEntities, func(a *entity.Slice) (*schema.Document, error) {
|
||||
return k.slice2Document(ctx, doc, a)
|
||||
})
|
||||
if err != nil {
|
||||
return errorx.New(errno.ErrKnowledgeSystemCode, errorx.KV("msg", fmt.Sprintf("reformat document failed, err: %v", err)))
|
||||
return errorx.New(errno.ErrKnowledgeSystemCode,
|
||||
errorx.KV("msg", fmt.Sprintf("reformat document failed, err: %v", err)))
|
||||
}
|
||||
progressbar := progressbar.NewProgressBar(ctx, doc.ID, int64(len(ssDocs)*len(k.searchStoreManagers)), k.cacheCli, true)
|
||||
|
||||
// Skip if it's an image document with empty content
|
||||
if doc.Type == knowledge.DocumentTypeImage && len(ssDocs) == 1 && len(ssDocs[0].Content) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Index in each search store manager
|
||||
for _, manager := range k.searchStoreManagers {
|
||||
now := time.Now()
|
||||
if err = manager.Create(ctx, &searchstore.CreateRequest{
|
||||
if err := manager.Create(ctx, &searchstore.CreateRequest{
|
||||
CollectionName: collectionName,
|
||||
Fields: fields,
|
||||
CollectionMeta: nil,
|
||||
}); err != nil {
|
||||
return errorx.New(errno.ErrKnowledgeSearchStoreCode, errorx.KV("msg", fmt.Sprintf("create search store failed, err: %v", err)))
|
||||
}
|
||||
// Picture knowledge base kn: doc: slice = 1: n: n, maybe the content is empty, no need to write
|
||||
if doc.Type == knowledge.DocumentTypeImage && len(ssDocs) == 1 && len(ssDocs[0].Content) == 0 {
|
||||
continue
|
||||
return errorx.New(errno.ErrKnowledgeSearchStoreCode,
|
||||
errorx.KV("msg", fmt.Sprintf("create search store failed, err: %v", err)))
|
||||
}
|
||||
|
||||
ss, err := manager.GetSearchStore(ctx, collectionName)
|
||||
if err != nil {
|
||||
return errorx.New(errno.ErrKnowledgeSearchStoreCode, errorx.KV("msg", fmt.Sprintf("get search store failed, err: %v", err)))
|
||||
return errorx.New(errno.ErrKnowledgeSearchStoreCode,
|
||||
errorx.KV("msg", fmt.Sprintf("get search store failed, err: %v", err)))
|
||||
}
|
||||
|
||||
if _, err = ss.Store(ctx, ssDocs,
|
||||
searchstore.WithIndexerPartitionKey(fieldNameDocumentID),
|
||||
searchstore.WithPartition(strconv.FormatInt(doc.ID, 10)),
|
||||
searchstore.WithIndexingFields(indexingFields),
|
||||
searchstore.WithProgressBar(progressbar),
|
||||
searchstore.WithProgressBar(progressBar),
|
||||
); err != nil {
|
||||
return errorx.New(errno.ErrKnowledgeSearchStoreCode, errorx.KV("msg", fmt.Sprintf("store search store failed, err: %v", err)))
|
||||
return errorx.New(errno.ErrKnowledgeSearchStoreCode,
|
||||
errorx.KV("msg", fmt.Sprintf("store search store failed, err: %v", err)))
|
||||
}
|
||||
|
||||
logs.CtxDebugf(ctx, "[indexDocument] ss type=%v, len(docs)=%d, finished after %d ms",
|
||||
manager.GetType(), len(ssDocs), time.Now().Sub(now).Milliseconds())
|
||||
}
|
||||
// set slice status
|
||||
if err = k.sliceRepo.BatchSetStatus(ctx, allIDs, int32(model.SliceStatusDone), ""); err != nil {
|
||||
return errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", fmt.Sprintf("batch set slice status failed, err: %v", err)))
|
||||
if err := k.recordIndexDocumentStatus(ctx, cacheRecord,
|
||||
fmt.Sprintf(indexDocCacheKey, doc.KnowledgeID, doc.ID)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// set document status
|
||||
return nil
|
||||
}
|
||||
|
||||
if err = k.documentRepo.SetStatus(ctx, doc.ID, int32(entity.DocumentStatusEnable), ""); err != nil {
|
||||
return errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", fmt.Sprintf("set document status failed, err: %v", err)))
|
||||
// setDocumentStatus updates document status with error handling
|
||||
func (k *knowledgeSVC) setDocumentStatus(ctx context.Context, docID int64, status int32, errMsg string) {
|
||||
if setStatusErr := k.documentRepo.SetStatus(ctx, docID, status, errMsg); setStatusErr != nil {
|
||||
logs.CtxErrorf(ctx, "[indexDocument] set document status failed, err: %v", setStatusErr)
|
||||
}
|
||||
if err = k.documentRepo.UpdateDocumentSliceInfo(ctx, event.Document.ID); err != nil {
|
||||
return errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", fmt.Sprintf("update document slice info failed, err: %v", err)))
|
||||
}
|
||||
|
||||
func (k *knowledgeSVC) recordIndexDocumentStatus(ctx context.Context, r *indexDocCacheRecord, cacheKey string) error {
|
||||
data, err := sonic.Marshal(r)
|
||||
if err != nil {
|
||||
return errorx.New(errno.ErrKnowledgeParseJSONCode, errorx.KV("msg", fmt.Sprintf("marshal parse result failed, err: %v", err)))
|
||||
}
|
||||
err = k.cacheCli.Set(ctx, cacheKey, data, time.Hour*2).Err()
|
||||
if err != nil {
|
||||
return errorx.New(errno.ErrKnowledgeCacheClientSetFailCode, errorx.KV("msg", fmt.Sprintf("set cache failed, err: %v", err)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -217,9 +217,10 @@ type RetrieveContext struct {
|
||||
}
|
||||
|
||||
type KnowledgeInfo struct {
|
||||
DocumentIDs []int64
|
||||
DocumentType knowledge.DocumentType
|
||||
TableColumns []*entity.TableColumn
|
||||
KnowledgeName string
|
||||
DocumentIDs []int64
|
||||
DocumentType knowledge.DocumentType
|
||||
TableColumns []*entity.TableColumn
|
||||
}
|
||||
type AlterTableSchemaRequest struct {
|
||||
DocumentID int64
|
||||
|
||||
@ -58,9 +58,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/rdb"
|
||||
rdbEntity "github.com/coze-dev/coze-studio/backend/infra/contract/rdb/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/builtin"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/progressbar"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/rerank/rrf"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
@ -70,50 +68,42 @@ import (
|
||||
|
||||
func NewKnowledgeSVC(config *KnowledgeSVCConfig) (Knowledge, eventbus.ConsumerHandler) {
|
||||
svc := &knowledgeSVC{
|
||||
knowledgeRepo: repository.NewKnowledgeDAO(config.DB),
|
||||
documentRepo: repository.NewKnowledgeDocumentDAO(config.DB),
|
||||
sliceRepo: repository.NewKnowledgeDocumentSliceDAO(config.DB),
|
||||
reviewRepo: repository.NewKnowledgeDocumentReviewDAO(config.DB),
|
||||
idgen: config.IDGen,
|
||||
rdb: config.RDB,
|
||||
producer: config.Producer,
|
||||
searchStoreManagers: config.SearchStoreManagers,
|
||||
parseManager: config.ParseManager,
|
||||
storage: config.Storage,
|
||||
reranker: config.Reranker,
|
||||
rewriter: config.Rewriter,
|
||||
nl2Sql: config.NL2Sql,
|
||||
enableCompactTable: ptr.FromOrDefault(config.EnableCompactTable, true),
|
||||
cacheCli: config.CacheCli,
|
||||
isAutoAnnotationSupported: config.IsAutoAnnotationSupported,
|
||||
modelFactory: config.ModelFactory,
|
||||
}
|
||||
if svc.reranker == nil {
|
||||
svc.reranker = rrf.NewRRFReranker(0)
|
||||
}
|
||||
if svc.parseManager == nil {
|
||||
svc.parseManager = builtin.NewManager(config.Storage, config.OCR, nil)
|
||||
knowledgeRepo: repository.NewKnowledgeDAO(config.DB),
|
||||
documentRepo: repository.NewKnowledgeDocumentDAO(config.DB),
|
||||
sliceRepo: repository.NewKnowledgeDocumentSliceDAO(config.DB),
|
||||
reviewRepo: repository.NewKnowledgeDocumentReviewDAO(config.DB),
|
||||
idgen: config.IDGen,
|
||||
rdb: config.RDB,
|
||||
producer: config.Producer,
|
||||
searchStoreManagers: config.SearchStoreManagers,
|
||||
parseManager: config.ParseManager,
|
||||
storage: config.Storage,
|
||||
reranker: config.Reranker,
|
||||
rewriter: config.Rewriter,
|
||||
nl2Sql: config.NL2Sql,
|
||||
enableCompactTable: ptr.FromOrDefault(config.EnableCompactTable, true),
|
||||
cacheCli: config.CacheCli,
|
||||
modelFactory: config.ModelFactory,
|
||||
}
|
||||
|
||||
return svc, svc
|
||||
}
|
||||
|
||||
type KnowledgeSVCConfig struct {
|
||||
DB *gorm.DB // required
|
||||
IDGen idgen.IDGenerator // required
|
||||
RDB rdb.RDB // Required: Form storage
|
||||
Producer eventbus.Producer // Required: Document indexing process goes through mq asynchronous processing
|
||||
SearchStoreManagers []searchstore.Manager // Required: Vector/Full Text
|
||||
ParseManager parser.Manager // Optional: document segmentation and processing capability, default builtin parser
|
||||
Storage storage.Storage // required: oss
|
||||
ModelFactory chatmodel.Factory // Required: Model factory
|
||||
Rewriter messages2query.MessagesToQuery // Optional: Do not overwrite when not configured
|
||||
Reranker rerank.Reranker // Optional: default rrf when not configured
|
||||
NL2Sql nl2sql.NL2SQL // Optional: Not supported by default when not configured
|
||||
EnableCompactTable *bool // Optional: Table data compression, default true
|
||||
OCR ocr.OCR // Optional: ocr, ocr function is not available when not provided
|
||||
CacheCli cache.Cmdable // Optional: cache implementation
|
||||
IsAutoAnnotationSupported bool // Does it support automatic image labeling?
|
||||
DB *gorm.DB // required
|
||||
IDGen idgen.IDGenerator // required
|
||||
RDB rdb.RDB // Required: Form storage
|
||||
Producer eventbus.Producer // Required: Document indexing process goes through mq asynchronous processing
|
||||
SearchStoreManagers []searchstore.Manager // Required: Vector/Full Text
|
||||
ParseManager parser.Manager // Optional: document segmentation and processing capability, default builtin parser
|
||||
Storage storage.Storage // required: oss
|
||||
ModelFactory chatmodel.Factory // Required: Model factory
|
||||
Rewriter messages2query.MessagesToQuery // Optional: Do not overwrite when not configured
|
||||
Reranker rerank.Reranker // Optional: default rrf when not configured
|
||||
NL2Sql nl2sql.NL2SQL // Optional: Not supported by default when not configured
|
||||
EnableCompactTable *bool // Optional: Table data compression, default true
|
||||
OCR ocr.OCR // Optional: ocr, ocr function is not available when not provided
|
||||
CacheCli cache.Cmdable // Optional: cache implementation
|
||||
}
|
||||
|
||||
type knowledgeSVC struct {
|
||||
@ -123,18 +113,17 @@ type knowledgeSVC struct {
|
||||
reviewRepo repository.KnowledgeDocumentReviewRepo
|
||||
modelFactory chatmodel.Factory
|
||||
|
||||
idgen idgen.IDGenerator
|
||||
rdb rdb.RDB
|
||||
producer eventbus.Producer
|
||||
searchStoreManagers []searchstore.Manager
|
||||
parseManager parser.Manager
|
||||
rewriter messages2query.MessagesToQuery
|
||||
reranker rerank.Reranker
|
||||
storage storage.Storage
|
||||
nl2Sql nl2sql.NL2SQL
|
||||
cacheCli cache.Cmdable
|
||||
enableCompactTable bool // Table data compression
|
||||
isAutoAnnotationSupported bool // Does it support automatic image labeling?
|
||||
idgen idgen.IDGenerator
|
||||
rdb rdb.RDB
|
||||
producer eventbus.Producer
|
||||
searchStoreManagers []searchstore.Manager
|
||||
parseManager parser.Manager
|
||||
rewriter messages2query.MessagesToQuery
|
||||
reranker rerank.Reranker
|
||||
storage storage.Storage
|
||||
nl2Sql nl2sql.NL2SQL
|
||||
cacheCli cache.Cmdable
|
||||
enableCompactTable bool // Table data compression
|
||||
}
|
||||
|
||||
func (k *knowledgeSVC) CreateKnowledge(ctx context.Context, request *CreateKnowledgeRequest) (response *CreateKnowledgeResponse, err error) {
|
||||
@ -318,7 +307,7 @@ func (k *knowledgeSVC) checkRequest(request *CreateDocumentRequest) error {
|
||||
}
|
||||
for i := range request.Documents {
|
||||
if request.Documents[i].Type == knowledgeModel.DocumentTypeImage && ptr.From(request.Documents[i].ParsingStrategy.CaptionType) == parser.ImageAnnotationTypeModel {
|
||||
if !k.isAutoAnnotationSupported {
|
||||
if !k.parseManager.IsAutoAnnotationSupported() {
|
||||
return errors.New("auto caption type is not supported")
|
||||
}
|
||||
}
|
||||
@ -887,9 +876,8 @@ func (k *knowledgeSVC) ListSlice(ctx context.Context, request *ListSliceRequest)
|
||||
KnowledgeID: ptr.From(request.KnowledgeID),
|
||||
DocumentID: ptr.From(request.DocumentID),
|
||||
Keyword: request.Keyword,
|
||||
Sequence: request.Sequence,
|
||||
Offset: request.Sequence,
|
||||
PageSize: request.Limit,
|
||||
Offset: request.Offset,
|
||||
})
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "list slice failed, err: %v", err)
|
||||
@ -1386,12 +1374,12 @@ func (k *knowledgeSVC) ListPhotoSlice(ctx context.Context, request *ListPhotoSli
|
||||
if request == nil {
|
||||
return nil, errorx.New(errno.ErrKnowledgeInvalidParamCode, errorx.KV("msg", "request is empty"))
|
||||
}
|
||||
sliceArr, total, err := k.sliceRepo.FindSliceByCondition(ctx, &entity.WhereSliceOpt{
|
||||
sliceArr, total, err := k.sliceRepo.ListPhotoSlice(ctx, &entity.WherePhotoSliceOpt{
|
||||
KnowledgeID: request.KnowledgeID,
|
||||
DocumentIDs: request.DocumentIDs,
|
||||
Offset: int64(ptr.From(request.Offset)),
|
||||
PageSize: int64(ptr.From(request.Limit)),
|
||||
NotEmpty: request.HasCaption,
|
||||
Offset: request.Offset,
|
||||
Limit: request.Limit,
|
||||
HasCaption: request.HasCaption,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", err.Error()))
|
||||
@ -1411,7 +1399,7 @@ func (k *knowledgeSVC) ExtractPhotoCaption(ctx context.Context, request *Extract
|
||||
if request == nil {
|
||||
return nil, errorx.New(errno.ErrKnowledgeInvalidParamCode, errorx.KV("msg", "request is empty"))
|
||||
}
|
||||
if !k.isAutoAnnotationSupported {
|
||||
if !k.parseManager.IsAutoAnnotationSupported() {
|
||||
return nil, errorx.New(errno.ErrKnowledgeAutoAnnotationNotSupportedCode, errorx.KV("msg", "auto annotation is not supported"))
|
||||
}
|
||||
docInfo, err := k.documentRepo.GetByID(ctx, request.DocumentID)
|
||||
|
||||
@ -41,6 +41,8 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/cache/redis"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/builtin"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/rerank/rrf"
|
||||
sses "github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/elasticsearch"
|
||||
ssmilvus "github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/milvus"
|
||||
hembed "github.com/coze-dev/coze-studio/backend/infra/impl/embedding/http"
|
||||
@ -169,10 +171,10 @@ func (suite *KnowledgeTestSuite) SetupSuite() {
|
||||
RDB: rdbService,
|
||||
Producer: knowledgeProducer,
|
||||
SearchStoreManagers: mgrs,
|
||||
ParseManager: nil, // default builtin
|
||||
ParseManager: builtin.NewManager(tosClient, nil, nil), // default builtin
|
||||
Storage: tosClient,
|
||||
Rewriter: nil,
|
||||
Reranker: nil, // default rrf
|
||||
Reranker: rrf.NewRRFReranker(0), // default rrf
|
||||
EnableCompactTable: ptr.Of(true),
|
||||
})
|
||||
|
||||
|
||||
@ -32,6 +32,8 @@ import (
|
||||
knowledgeModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/knowledge/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/builtin"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/document/rerank/rrf"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/rdb"
|
||||
producerMock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/eventbus"
|
||||
mock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/idgen"
|
||||
@ -98,11 +100,14 @@ func MockKnowledgeSVC(t *testing.T) Knowledge {
|
||||
mockStorage.EXPECT().PutObject(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
|
||||
rdb := rdb.NewService(db, mockIDGen)
|
||||
svc, _ := NewKnowledgeSVC(&KnowledgeSVCConfig{
|
||||
DB: db,
|
||||
IDGen: mockIDGen,
|
||||
Storage: mockStorage,
|
||||
Producer: producer,
|
||||
RDB: rdb,
|
||||
DB: db,
|
||||
IDGen: mockIDGen,
|
||||
Storage: mockStorage,
|
||||
Producer: producer,
|
||||
RDB: rdb,
|
||||
Reranker: rrf.NewRRFReranker(0),
|
||||
ParseManager: builtin.NewManager(mockStorage, nil, nil), // default builtin
|
||||
|
||||
})
|
||||
return svc
|
||||
}
|
||||
|
||||
@ -18,11 +18,9 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"unicode/utf8"
|
||||
|
||||
@ -50,6 +48,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/sets"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
@ -127,6 +126,7 @@ func (k *knowledgeSVC) newRetrieveContext(ctx context.Context, req *RetrieveRequ
|
||||
knowledgeInfoMap[kn.ID] = &KnowledgeInfo{}
|
||||
knowledgeInfoMap[kn.ID].DocumentType = knowledgeModel.DocumentType(kn.FormatType)
|
||||
knowledgeInfoMap[kn.ID].DocumentIDs = []int64{}
|
||||
knowledgeInfoMap[kn.ID].KnowledgeName = kn.Name
|
||||
}
|
||||
}
|
||||
for _, doc := range enableDocs {
|
||||
@ -189,7 +189,7 @@ func (k *knowledgeSVC) prepareRAGDocuments(ctx context.Context, documentIDs []in
|
||||
}
|
||||
|
||||
func (k *knowledgeSVC) queryRewriteNode(ctx context.Context, req *RetrieveContext) (newRetrieveContext *RetrieveContext, err error) {
|
||||
if len(req.ChatHistory) == 0 {
|
||||
if len(req.ChatHistory) == 1 {
|
||||
// No context, no rewriting.
|
||||
return req, nil
|
||||
}
|
||||
@ -390,6 +390,10 @@ func (k *knowledgeSVC) nl2SqlExec(ctx context.Context, doc *model.KnowledgeDocum
|
||||
}
|
||||
replaceMap[doc.Name].ColumnMap[doc.TableInfo.Columns[i].Name] = convert.ColumnIDToRDBField(doc.TableInfo.Columns[i].ID)
|
||||
}
|
||||
virtualColumnMap := map[string]*entity.TableColumn{}
|
||||
for i := range doc.TableInfo.Columns {
|
||||
virtualColumnMap[convert.ColumnIDToRDBField(doc.TableInfo.Columns[i].ID)] = doc.TableInfo.Columns[i]
|
||||
}
|
||||
parsedSQL, err := sqlparser.NewSQLParser().ParseAndModifySQL(sql, replaceMap)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "parse sql failed: %v", err)
|
||||
@ -404,15 +408,52 @@ func (k *knowledgeSVC) nl2SqlExec(ctx context.Context, doc *model.KnowledgeDocum
|
||||
return nil, err
|
||||
}
|
||||
for i := range resp.ResultSet.Rows {
|
||||
d := &schema.Document{
|
||||
Content: "",
|
||||
MetaData: map[string]any{
|
||||
"document_id": doc.ID,
|
||||
"document_name": doc.Name,
|
||||
"knowledge_id": doc.KnowledgeID,
|
||||
"knowledge_name": retrieveCtx.KnowledgeInfoMap[doc.KnowledgeID].KnowledgeName,
|
||||
},
|
||||
}
|
||||
id, ok := resp.ResultSet.Rows[i][consts.RDBFieldID].(int64)
|
||||
if !ok {
|
||||
logs.CtxWarnf(ctx, "convert id failed, row: %v", resp.ResultSet.Rows[i])
|
||||
return nil, errors.New("convert id failed")
|
||||
}
|
||||
d := &schema.Document{
|
||||
ID: strconv.FormatInt(id, 10),
|
||||
Content: "",
|
||||
MetaData: map[string]any{},
|
||||
byteData, err := sonic.Marshal(resp.ResultSet.Rows)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "marshal sql resp failed: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
prefix := "sql:" + sql + ";result:"
|
||||
d.Content = prefix + string(byteData)
|
||||
} else {
|
||||
transferMap := map[string]string{}
|
||||
for cName, val := range resp.ResultSet.Rows[i] {
|
||||
column, found := virtualColumnMap[cName]
|
||||
if !found {
|
||||
logs.CtxInfof(ctx, "column not found, name: %s", cName)
|
||||
continue
|
||||
}
|
||||
columnData, err := convert.ParseAnyData(column, val)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "parse any data failed: %v", err)
|
||||
return nil, errorx.New(errno.ErrKnowledgeColumnParseFailCode, errorx.KV("msg", err.Error()))
|
||||
}
|
||||
if columnData.Type == document.TableColumnTypeString {
|
||||
columnData.ValString = ptr.Of(k.formatSliceContent(ctx, columnData.GetStringValue()))
|
||||
}
|
||||
if columnData.Type == document.TableColumnTypeImage {
|
||||
columnData.ValImage = ptr.Of(k.formatSliceContent(ctx, columnData.GetStringValue()))
|
||||
}
|
||||
transferMap[column.Name] = columnData.GetNullableStringValue()
|
||||
}
|
||||
byteData, err := sonic.Marshal(transferMap)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "marshal sql resp failed: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
d.Content = string(byteData)
|
||||
d.ID = strconv.FormatInt(id, 10)
|
||||
}
|
||||
d.WithScore(1)
|
||||
retrieveResult = append(retrieveResult, d)
|
||||
@ -423,29 +464,13 @@ func (k *knowledgeSVC) nl2SqlExec(ctx context.Context, doc *model.KnowledgeDocum
|
||||
const pkID = "_knowledge_slice_id"
|
||||
|
||||
func addSliceIdColumn(originalSql string) string {
|
||||
lowerSql := strings.ToLower(originalSql)
|
||||
selectIndex := strings.Index(lowerSql, "select ")
|
||||
if selectIndex == -1 {
|
||||
sql, err := sqlparser.NewSQLParser().AddSelectFieldsToSelectSQL(originalSql, []string{pkID})
|
||||
if err != nil {
|
||||
logs.Errorf("add slice id column failed: %v", err)
|
||||
return originalSql
|
||||
}
|
||||
result := originalSql[:selectIndex+len("select ")] // Keep selected part
|
||||
remainder := originalSql[selectIndex+len("select "):]
|
||||
|
||||
lowerRemainder := strings.ToLower(remainder)
|
||||
fromIndex := strings.Index(lowerRemainder, " from")
|
||||
if fromIndex == -1 {
|
||||
return originalSql
|
||||
}
|
||||
|
||||
columns := strings.TrimSpace(remainder[:fromIndex])
|
||||
if columns != "*" {
|
||||
columns += ", " + pkID
|
||||
}
|
||||
|
||||
result += columns + remainder[fromIndex:]
|
||||
return result
|
||||
return sql
|
||||
}
|
||||
|
||||
func packNL2SqlRequest(doc *model.KnowledgeDocument) *document.TableSchema {
|
||||
res := &document.TableSchema{}
|
||||
if doc.TableInfo == nil {
|
||||
@ -561,18 +586,39 @@ func (k *knowledgeSVC) packResults(ctx context.Context, retrieveResult []*schema
|
||||
sliceIDs := make(sets.Set[int64])
|
||||
docIDs := make(sets.Set[int64])
|
||||
knowledgeIDs := make(sets.Set[int64])
|
||||
|
||||
results = []*knowledgeModel.RetrieveSlice{}
|
||||
documentMap := map[int64]*model.KnowledgeDocument{}
|
||||
knowledgeMap := map[int64]*model.Knowledge{}
|
||||
sliceScoreMap := map[int64]float64{}
|
||||
for _, doc := range retrieveResult {
|
||||
id, err := strconv.ParseInt(doc.ID, 10, 64)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "convert id failed: %v", err)
|
||||
return nil, errorx.New(errno.ErrKnowledgeSystemCode, errorx.KV("msg", "convert id failed"))
|
||||
if len(doc.ID) == 0 {
|
||||
results = append(results, &knowledgeModel.RetrieveSlice{
|
||||
Slice: &knowledgeModel.Slice{
|
||||
KnowledgeID: doc.MetaData["knowledge_id"].(int64),
|
||||
DocumentID: doc.MetaData["document_id"].(int64),
|
||||
DocumentName: doc.MetaData["document_name"].(string),
|
||||
RawContent: []*knowledgeModel.SliceContent{
|
||||
{
|
||||
Type: knowledgeModel.SliceContentTypeText,
|
||||
Text: ptr.Of(doc.Content),
|
||||
},
|
||||
},
|
||||
Extra: map[string]string{
|
||||
consts.KnowledgeName: doc.MetaData["knowledge_name"].(string),
|
||||
consts.DocumentURL: "",
|
||||
},
|
||||
},
|
||||
Score: 1,
|
||||
})
|
||||
} else {
|
||||
id, err := strconv.ParseInt(doc.ID, 10, 64)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "convert id failed: %v", err)
|
||||
return nil, errorx.New(errno.ErrKnowledgeSystemCode, errorx.KV("msg", "convert id failed"))
|
||||
}
|
||||
sliceIDs[id] = struct{}{}
|
||||
sliceScoreMap[id] = doc.Score()
|
||||
}
|
||||
sliceIDs[id] = struct{}{}
|
||||
sliceScoreMap[id] = doc.Score()
|
||||
}
|
||||
slices, err := k.sliceRepo.MGetSlices(ctx, sliceIDs.ToSlice())
|
||||
if err != nil {
|
||||
@ -625,7 +671,6 @@ func (k *knowledgeSVC) packResults(ctx context.Context, retrieveResult []*schema
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
results = []*knowledgeModel.RetrieveSlice{}
|
||||
for i := range slices {
|
||||
doc := documentMap[slices[i].DocumentID]
|
||||
kn := knowledgeMap[slices[i].KnowledgeID]
|
||||
|
||||
237
backend/domain/knowledge/service/retrieve_test.go
Normal file
237
backend/domain/knowledge/service/retrieve_test.go
Normal file
@ -0,0 +1,237 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/mock/gomock"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/knowledge/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/knowledge/internal/dal/model"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/knowledge/repository"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/nl2sql"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/rdb"
|
||||
rdb_entity "github.com/coze-dev/coze-studio/backend/infra/contract/rdb/entity"
|
||||
mock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/nl2sql"
|
||||
mock_db "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/rdb"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/sets"
|
||||
)
|
||||
|
||||
func TestAddSliceIdColumn(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "simple select",
|
||||
input: "SELECT name, age FROM users",
|
||||
expected: "SELECT `name`,`age`,`_knowledge_slice_id` FROM `users`",
|
||||
},
|
||||
{
|
||||
name: "select stmt wrong",
|
||||
input: "SELECT FROM users",
|
||||
expected: "SELECT FROM users",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
actual := addSliceIdColumn(tt.input)
|
||||
if actual != tt.expected {
|
||||
t.Errorf("AddSliceIdColumn() = %v, want %v", actual, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNL2sqlExec(t *testing.T) {
|
||||
svc := knowledgeSVC{}
|
||||
ctrl := gomock.NewController(t)
|
||||
db := mock_db.NewMockRDB(ctrl)
|
||||
nl2SQL := mock.NewMockNL2SQL(ctrl)
|
||||
nl2SQL.EXPECT().NL2SQL(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, messages []*schema.Message, tables []*document.TableSchema, opts ...nl2sql.Option) (sql string, err error) {
|
||||
return "select count(*) from users", nil
|
||||
})
|
||||
db.EXPECT().ExecuteSQL(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, req *rdb.ExecuteSQLRequest) (*rdb.ExecuteSQLResponse, error) {
|
||||
return &rdb.ExecuteSQLResponse{
|
||||
ResultSet: &rdb_entity.ResultSet{Rows: []map[string]interface{}{
|
||||
{
|
||||
"count(*)": 100,
|
||||
},
|
||||
}},
|
||||
}, nil
|
||||
})
|
||||
svc.nl2Sql = nl2SQL
|
||||
svc.rdb = db
|
||||
ctx := context.Background()
|
||||
docu := model.KnowledgeDocument{
|
||||
ID: 110,
|
||||
KnowledgeID: 111,
|
||||
Name: "users",
|
||||
FileExtension: "xlsx",
|
||||
DocumentType: 1,
|
||||
CreatorID: 666,
|
||||
SpaceID: 666,
|
||||
Status: 1,
|
||||
TableInfo: &entity.TableInfo{
|
||||
VirtualTableName: "users",
|
||||
PhysicalTableName: "table_111",
|
||||
TableDesc: "user table",
|
||||
Columns: []*entity.TableColumn{
|
||||
{
|
||||
ID: 1,
|
||||
Name: "_knowledge_slice_id",
|
||||
Type: document.TableColumnTypeInteger,
|
||||
Description: "id",
|
||||
Indexing: false,
|
||||
Sequence: 1,
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Name: "name",
|
||||
Type: document.TableColumnTypeString,
|
||||
Description: "name",
|
||||
Indexing: true,
|
||||
Sequence: 2,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
retrieveCtx := &RetrieveContext{
|
||||
Ctx: ctx,
|
||||
OriginQuery: "select count(*) from users",
|
||||
KnowledgeIDs: sets.FromSlice[int64]([]int64{111}),
|
||||
Documents: []*model.KnowledgeDocument{&docu},
|
||||
KnowledgeInfoMap: map[int64]*KnowledgeInfo{
|
||||
111: &KnowledgeInfo{
|
||||
KnowledgeName: "users",
|
||||
DocumentIDs: []int64{110},
|
||||
DocumentType: 1,
|
||||
TableColumns: []*entity.TableColumn{
|
||||
{
|
||||
ID: 1,
|
||||
Name: "_knowledge_slice_id",
|
||||
Type: document.TableColumnTypeInteger,
|
||||
Description: "id",
|
||||
Indexing: false,
|
||||
Sequence: 1,
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Name: "name",
|
||||
Type: document.TableColumnTypeString,
|
||||
Description: "name",
|
||||
Indexing: true,
|
||||
Sequence: 2,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
docs, err := svc.nl2SqlExec(ctx, &docu, retrieveCtx, nil)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, 1, len(docs))
|
||||
assert.Equal(t, "sql:select count(*) from users;result:[{\"count(*)\":100}]", docs[0].Content)
|
||||
nl2SQL.EXPECT().NL2SQL(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, messages []*schema.Message, tables []*document.TableSchema, opts ...nl2sql.Option) (sql string, err error) {
|
||||
return "", errors.New("nl2sql error")
|
||||
})
|
||||
_, err = svc.nl2SqlExec(ctx, &docu, retrieveCtx, nil)
|
||||
assert.Equal(t, "nl2sql error", err.Error())
|
||||
db.EXPECT().ExecuteSQL(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, req *rdb.ExecuteSQLRequest) (*rdb.ExecuteSQLResponse, error) {
|
||||
return nil, errors.New("rdb error")
|
||||
})
|
||||
nl2SQL.EXPECT().NL2SQL(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, messages []*schema.Message, tables []*document.TableSchema, opts ...nl2sql.Option) (sql string, err error) {
|
||||
return "select count(*) from users", nil
|
||||
})
|
||||
_, err = svc.nl2SqlExec(ctx, &docu, retrieveCtx, nil)
|
||||
assert.Equal(t, "rdb error", err.Error())
|
||||
db.EXPECT().ExecuteSQL(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, req *rdb.ExecuteSQLRequest) (*rdb.ExecuteSQLResponse, error) {
|
||||
return &rdb.ExecuteSQLResponse{
|
||||
ResultSet: &rdb_entity.ResultSet{Rows: []map[string]interface{}{
|
||||
{
|
||||
"name": "666",
|
||||
"_knowledge_document_slice_id": int64(999),
|
||||
},
|
||||
}},
|
||||
}, nil
|
||||
})
|
||||
nl2SQL.EXPECT().NL2SQL(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, messages []*schema.Message, tables []*document.TableSchema, opts ...nl2sql.Option) (sql string, err error) {
|
||||
return "select name from users", nil
|
||||
})
|
||||
docs, err = svc.nl2SqlExec(ctx, &docu, retrieveCtx, nil)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, 1, len(docs))
|
||||
assert.Equal(t, "999", docs[0].ID)
|
||||
|
||||
}
|
||||
|
||||
func TestPackResults(t *testing.T) {
|
||||
svc := knowledgeSVC{}
|
||||
ctx := context.Background()
|
||||
svc.packResults(ctx, []*schema.Document{})
|
||||
dsn := "root:root@tcp(127.0.0.1:3306)/opencoze?charset=utf8mb4&parseTime=True&loc=Local"
|
||||
if os.Getenv("CI_JOB_NAME") != "" {
|
||||
dsn = strings.ReplaceAll(dsn, "127.0.0.1", "mysql")
|
||||
}
|
||||
gormDB, err := gorm.Open(mysql.Open(dsn))
|
||||
assert.Equal(t, nil, err)
|
||||
svc.knowledgeRepo = repository.NewKnowledgeDAO(gormDB)
|
||||
svc.documentRepo = repository.NewKnowledgeDocumentDAO(gormDB)
|
||||
svc.sliceRepo = repository.NewKnowledgeDocumentSliceDAO(gormDB)
|
||||
docs := []*schema.Document{
|
||||
{
|
||||
ID: "",
|
||||
Content: "sql:select count(*) from users;result:[{\"count(*)\":100}]",
|
||||
MetaData: map[string]any{
|
||||
"knowledge_id": int64(111),
|
||||
"document_id": int64(110),
|
||||
"document_name": "users",
|
||||
"knowledge_name": "users",
|
||||
},
|
||||
},
|
||||
}
|
||||
res, err := svc.packResults(ctx, docs)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, 1, len(res))
|
||||
assert.Equal(t, "sql:select count(*) from users;result:[{\"count(*)\":100}]", ptr.From(res[0].Slice.RawContent[0].Text))
|
||||
docs = []*schema.Document{
|
||||
{
|
||||
ID: "10000",
|
||||
Content: "",
|
||||
MetaData: map[string]any{
|
||||
"knowledge_id": int64(111),
|
||||
"document_id": int64(110),
|
||||
"document_name": "users",
|
||||
"knowledge_name": "users",
|
||||
},
|
||||
},
|
||||
}
|
||||
res, err = svc.packResults(ctx, docs)
|
||||
assert.Equal(t, 0, len(res))
|
||||
assert.Equal(t, nil, err)
|
||||
}
|
||||
@ -36,7 +36,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/app/bot_common"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/data/database/table"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossvariables"
|
||||
crossvariables "github.com/coze-dev/coze-studio/backend/crossdomain/contract/variables"
|
||||
entity2 "github.com/coze-dev/coze-studio/backend/domain/memory/database/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/memory/database/internal/convertor"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/memory/database/internal/dal/query"
|
||||
|
||||
@ -263,28 +263,28 @@ func MustString(value any) string {
|
||||
}
|
||||
}
|
||||
|
||||
func TryFixValueType(paramName string, schemaRef *openapi3.SchemaRef, value any) (any, error) {
|
||||
func TryCorrectValueType(paramName string, schemaRef *openapi3.SchemaRef, value any) (any, error) {
|
||||
if value == nil {
|
||||
return "", fmt.Errorf("value of '%s' is nil", paramName)
|
||||
}
|
||||
|
||||
switch schemaRef.Value.Type {
|
||||
case openapi3.TypeString:
|
||||
return tryString(value)
|
||||
return tryCorrectString(value)
|
||||
case openapi3.TypeNumber:
|
||||
return tryFloat64(value)
|
||||
return tryCorrectFloat64(value)
|
||||
case openapi3.TypeInteger:
|
||||
return tryInt64(value)
|
||||
return tryCorrectInt64(value)
|
||||
case openapi3.TypeBoolean:
|
||||
return tryBool(value)
|
||||
return tryCorrectBool(value)
|
||||
case openapi3.TypeArray:
|
||||
arrVal, ok := value.([]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("[TryFixValueType] value '%s' is not array", paramName)
|
||||
return nil, fmt.Errorf("[TryCorrectValueType] value '%s' is not array", paramName)
|
||||
}
|
||||
|
||||
for i, v := range arrVal {
|
||||
_v, err := TryFixValueType(paramName, schemaRef.Value.Items, v)
|
||||
_v, err := TryCorrectValueType(paramName, schemaRef.Value.Items, v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -296,7 +296,7 @@ func TryFixValueType(paramName string, schemaRef *openapi3.SchemaRef, value any)
|
||||
case openapi3.TypeObject:
|
||||
mapVal, ok := value.(map[string]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("[TryFixValueType] value '%s' is not object", paramName)
|
||||
return nil, fmt.Errorf("[TryCorrectValueType] value '%s' is not object", paramName)
|
||||
}
|
||||
|
||||
for k, v := range mapVal {
|
||||
@ -305,7 +305,7 @@ func TryFixValueType(paramName string, schemaRef *openapi3.SchemaRef, value any)
|
||||
continue
|
||||
}
|
||||
|
||||
_v, err := TryFixValueType(k, p, v)
|
||||
_v, err := TryCorrectValueType(k, p, v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -315,11 +315,11 @@ func TryFixValueType(paramName string, schemaRef *openapi3.SchemaRef, value any)
|
||||
|
||||
return mapVal, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("[TryFixValueType] unsupported schema type '%s'", schemaRef.Value.Type)
|
||||
return nil, fmt.Errorf("[TryCorrectValueType] unsupported schema type '%s'", schemaRef.Value.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func tryString(value any) (string, error) {
|
||||
func tryCorrectString(value any) (string, error) {
|
||||
switch val := value.(type) {
|
||||
case string:
|
||||
return val, nil
|
||||
@ -331,11 +331,15 @@ func tryString(value any) (string, error) {
|
||||
case json.Number:
|
||||
return val.String(), nil
|
||||
default:
|
||||
return "", fmt.Errorf("cannot convert type from '%T' to string", val)
|
||||
b, err := sonic.MarshalString(value)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("tryCorrectString failed, err=%w", err)
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
}
|
||||
|
||||
func tryInt64(value any) (int64, error) {
|
||||
func tryCorrectInt64(value any) (int64, error) {
|
||||
switch val := value.(type) {
|
||||
case string:
|
||||
vi64, _ := strconv.ParseInt(val, 10, 64)
|
||||
@ -352,7 +356,7 @@ func tryInt64(value any) (int64, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func tryBool(value any) (bool, error) {
|
||||
func tryCorrectBool(value any) (bool, error) {
|
||||
switch val := value.(type) {
|
||||
case string:
|
||||
return strconv.ParseBool(val)
|
||||
@ -363,7 +367,7 @@ func tryBool(value any) (bool, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func tryFloat64(value any) (float64, error) {
|
||||
func tryCorrectFloat64(value any) (float64, error) {
|
||||
switch val := value.(type) {
|
||||
case string:
|
||||
return strconv.ParseFloat(val, 64)
|
||||
|
||||
@ -37,7 +37,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/variables"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/data/variable/project_memory"
|
||||
common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossvariables"
|
||||
crossvariables "github.com/coze-dev/coze-studio/backend/crossdomain/contract/variables"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/plugin/internal/encoder"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
@ -1100,15 +1100,15 @@ func (t *toolExecutor) processWithInvalidRespProcessStrategyOfReturnErr(_ contex
|
||||
processor = func(paramName string, paramVal any, schemaVal *openapi3.Schema) (any, error) {
|
||||
switch schemaVal.Type {
|
||||
case openapi3.TypeObject:
|
||||
newParamValMap := map[string]any{}
|
||||
paramValMap, ok := paramVal.(map[string]any)
|
||||
if !ok {
|
||||
return nil, errorx.New(errno.ErrPluginExecuteToolFailed, errorx.KVf(errno.PluginMsgKey,
|
||||
"expected '%s' to be of type 'object', but got '%T'", paramName, paramVal))
|
||||
}
|
||||
|
||||
newParamValMap := map[string]any{}
|
||||
for paramName_, paramVal_ := range paramValMap {
|
||||
paramSchema_, ok := schemaVal.Properties[paramName]
|
||||
paramSchema_, ok := schemaVal.Properties[paramName_]
|
||||
if !ok || t.disabledParam(paramSchema_.Value) { // Only the object field can be disabled, and the top level of request and response must be the object structure
|
||||
continue
|
||||
}
|
||||
@ -1122,13 +1122,13 @@ func (t *toolExecutor) processWithInvalidRespProcessStrategyOfReturnErr(_ contex
|
||||
return newParamValMap, nil
|
||||
|
||||
case openapi3.TypeArray:
|
||||
newParamValSlice := []any{}
|
||||
paramValSlice, ok := paramVal.([]any)
|
||||
if !ok {
|
||||
return nil, errorx.New(errno.ErrPluginExecuteToolFailed, errorx.KVf(errno.PluginMsgKey,
|
||||
"expected '%s' to be of type 'array', but got '%T'", paramName, paramVal))
|
||||
}
|
||||
|
||||
newParamValSlice := []any{}
|
||||
for _, paramVal_ := range paramValSlice {
|
||||
newParamVal, err := processor(paramName, paramVal_, schemaVal.Items.Value)
|
||||
if err != nil {
|
||||
@ -1426,7 +1426,7 @@ func (t *toolExecutor) buildRequestBody(ctx context.Context, op *model.Openapi3O
|
||||
continue
|
||||
}
|
||||
|
||||
_value, err := encoder.TryFixValueType(paramName, prop, value)
|
||||
_value, err := encoder.TryCorrectValueType(paramName, prop, value)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user