Compare commits

..

55 Commits

Author SHA1 Message Date
24e8dfdb45 fix(singleagent): shortcmd without default value 2025-09-01 10:34:10 +08:00
Zhj
aae865dafb feat: read workflow config file should join workspace dir (#1850) 2025-08-27 09:49:43 +00:00
Zhj
77e1931494 feat: add config for workflow domain (#1847) 2025-08-27 09:04:42 +00:00
5562800958 fix: workflow tool closes stream writer correctly (#1839) 2025-08-27 08:29:42 +00:00
Ryo
263a75b1c0 feat(infra): add file listing support (#1836) 2025-08-27 04:06:55 +00:00
901d0252e8 fix(singleagent): append the user message when generating suggestions (#1821) 2025-08-26 12:52:05 +00:00
Ryo
5ecdddbacb fix(prompt): disallow update if prompt is empty (#1816) 2025-08-26 08:04:49 +00:00
2a704fc873 fix: workflow tool in react agent resume once in one agent run (#1801) 2025-08-26 02:57:37 +00:00
Ryo
f19761fa31 fix(agent): disallow update if prompt is empty (#1802) 2025-08-25 12:22:45 +00:00
Ryo
14ce6bc112 feat(infra): support uploading files via io.Reader (#1793) 2025-08-25 10:25:12 +00:00
Zhj
6fa2acf05a fix: workflow as tool, not be serialized where has not parameters (#1777) 2025-08-25 06:30:33 +00:00
fc47e4096c chore: remove redundant module (#864) 2025-08-22 10:35:54 +00:00
035ed2450b chore: clean up go mod & fix mail link (#861) 2025-08-22 10:01:57 +00:00
Ryo
0fef5a1634 fix(infra): repleace minio:9000 for SSE api (#854) 2025-08-22 07:15:59 +00:00
f09c624988 optimize(infra): remove tautological condition (#856) 2025-08-22 07:00:24 +00:00
59c1d9aa03 feat(knowledge): Support ark rerank (#852) 2025-08-22 06:41:58 +00:00
Zhj
19c63a1150 fix: context cancel not working during node runner execution (#819) 2025-08-21 09:59:01 +00:00
09d00c26cb optimize(knowledge): Optimize the index logic of the knowledge base (#841) 2025-08-21 09:44:26 +00:00
3d53aaa785 fix(docker): add missing MINIO_BUCKET environment variable (#839) 2025-08-21 06:31:22 +00:00
Ryo
5044cb2b85 chore(infra): remove all docker export port (#840) 2025-08-21 03:49:12 +00:00
e7070b419c fix(knowledge): Fix the issue of ineffective pagination parameters in the image-based knowledge base (#831) 2025-08-20 09:49:38 +00:00
f956c18a09 fix: correctly transform single-field Array<object> variables (#702) 2025-08-20 07:57:49 +00:00
a4b11729a6 fix(plugin): ToEinoSchemaParameterInfo nil panic (#832) 2025-08-20 07:21:15 +00:00
Ryo
1dc00e4df8 refactor(infra): es_index_schemas => es_index_schema (#808) 2025-08-19 09:17:51 +00:00
5e9740c047 fix(knowledge): Fix the issue where knowledge cannot execute aggregated SQL (#794) 2025-08-19 07:03:30 +00:00
Ryo
f940edf585 refactor(knowledge): Move the all dependent components to app infra (#795) 2025-08-18 10:18:39 +00:00
23a468c72c fix: openapi message list limit default (#803) 2025-08-18 10:16:21 +00:00
85e6926a14 feat: intergrate gemini embedding (#783) 2025-08-18 08:44:19 +00:00
a9b87c188b fix(ci): correct workflow unit test (#780) 2025-08-15 12:30:13 +00:00
Ryo
ee03b41ad5 fix(infra): remove duplicate init code (#779) 2025-08-15 10:32:39 +00:00
18e45b333f feat(ci): enable unit test for backend (#552) 2025-08-15 09:33:23 +00:00
f040a511e4 fix(singleagent): openapi message list field type (#775) 2025-08-15 08:11:26 +00:00
dfa9eb44e1 fix(singleagent): remove code (#774) 2025-08-15 07:55:11 +00:00
Zhj
4ff734f15f fix: where HTTP node URL, JSON text, and raw text template rendering could not find the corresponding rendering variables (#745) 2025-08-15 03:25:27 +00:00
Ryo
ff00dcb31b refactor(knowledge): Move the searchstore manager to app infra (#764) 2025-08-15 02:46:09 +00:00
710bbbff2b docs: readme security risk tips (#763) 2025-08-14 11:53:09 +00:00
a734d9d8af feat: milvus support to use username+password as auth (#751) 2025-08-14 10:49:03 +00:00
Ryo
174da78c78 chore(ci): 1. Update the Docker image names for both the server and the web.2. Set the default port access to localhost only. (#760) 2025-08-14 09:58:38 +00:00
d58783b11c feat(plugin): supports using json marshal to correct string types by … (#758) 2025-08-14 09:24:36 +00:00
3030d4d627 fix: checkpoint store correctly initialize in multi-layered sub-workf… (#755) 2025-08-14 08:34:39 +00:00
c79ee64fe8 fix: display name for custom SQL node (#747) 2025-08-14 04:09:35 +00:00
8994cec367 fix: Elasticsearch OpIn query append to Must instead of MustNot (#744) 2025-08-14 03:48:43 +00:00
Ryo
dce313b8e3 refactor(workflow): Move the variable component in the Workflow package into the common crossdomain package (#738) 2025-08-14 02:41:14 +00:00
Ryo
5d98e8ef93 refactor(workflow): Move domain resources events into the application layer (#729) 2025-08-13 13:06:56 +00:00
8c3ae99643 fix(frontend): extend image extension to support tos key attribute (#723) 2025-08-13 12:55:57 +00:00
e0800abb99 fix: do not throw error when encountering unknown param for LLM (#735) 2025-08-13 10:15:46 +00:00
Zhj
ffbc108875 fix: copy or move app workflow to library, dependencies on other comp… (#720) 2025-08-13 08:44:44 +00:00
6b60c07c22 feat(infra): integrate PaddleOCR's PP-StructureV3 as a document parser backend (#714) 2025-08-13 08:37:42 +00:00
708a6ed0c0 fix: app workflow publish panic (#719) 2025-08-13 03:40:57 +00:00
Ryo
99c759addc refactor(workflow): Move the plugin component in the Workflow package into the common crossdomain package (#717) 2025-08-13 03:06:53 +00:00
Ryo
b38ab95623 refactor(workflow): Move the knowledge component in the Workflow package into the common crossdomain package (#708) 2025-08-12 09:10:36 +00:00
Ryo
9ff065cebd refactor(workflow): Move the database component in the Workflow package into the common crossdomain package (#704) 2025-08-12 07:42:58 +00:00
e7011f2549 fix(app): avoid nil panic (#693) 2025-08-12 02:03:34 +00:00
643a448157 docs: security risks tips & model configuration with byteplus (#692) 2025-08-11 11:46:09 +00:00
Ryo
e03cf4cc87 feat: optimize the package name for cross-domain functionality (#690) 2025-08-11 11:04:06 +00:00
249 changed files with 8077 additions and 6855 deletions

121
.github/workflows/ci@backend.yml vendored Normal file
View File

@ -0,0 +1,121 @@
name: Backend Tests
on:
pull_request:
paths:
- 'backend/**'
- 'docker/atlas/**'
- '.github/workflows/ci@backend.yml'
push:
branches:
- main
paths:
- 'backend/**'
- 'docker/atlas/**'
- '.github/workflows/ci@backend.yml'
# Allows you to run this workflow manually from the Actions tab
workflow_dispatch:
env:
DEFAULT_GO_VERSION: "1.24"
jobs:
backend-unit-test:
name: backend-unit-test
runs-on: ubuntu-latest
permissions:
contents: write
pull-requests: write
repository-projects: write
env:
COVERAGE_FILE: coverage.out
BREAKDOWN_FILE: main.breakdown
steps:
- uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: ${{ env.DEFAULT_GO_VERSION }}
# - name: Shutdown Ubuntu MySQL
# run: service mysql stop
- name: Set Up MySQL
uses: mirromutth/mysql-action@v1.1
with:
host port: 3306
container port: 3306
character set server: 'utf8mb4'
collation server: 'utf8mb4_general_ci'
mysql version: '8.4.5'
mysql database: 'opencoze'
mysql root password: 'root'
- name: Verify MySQL Startup
run: |
echo "Waiting for MySQL to be ready..."
for i in {1..60}; do
if cat /proc/net/tcp | grep 0CEA; then
echo "MySQL port 3306 is listening!"
break
fi
echo "Waiting for MySQL port... ($i/60)"
sleep 1
done
echo "Final verification: MySQL port 3306 is accessible"
- name: Install MySQL Client
run: sudo apt-get update && sudo apt-get install -y mysql-client
- name: Initialize Database
run: mysql -h 127.0.0.1 -P 3306 -u root -proot opencoze < docker/volumes/mysql/schema.sql
- name: Run Go Test
run: |
modules=`find . -name "go.mod" -exec dirname {} \;`
echo $modules
list=""
coverpkg=""
if [[ ! -f "go.work" ]];then go work init;fi
for module in $modules; do go work use $module; list=$module"/... "$list; coverpkg=$module"/...,"$coverpkg; done
go work sync
go test -race -v -coverprofile=${{ env.COVERAGE_FILE }} -gcflags="all=-l -N" -coverpkg=$coverpkg $list
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
name: coze-studio-backend
env_vars: GOLANG,Coze-Studio,BACKEND
fail_ci_if_error: 'false'
files: ${{ env.COVERAGE_FILE }}
token: ${{ secrets.CODECOV_TOKEN }}
- name: Shutdown MySQL
if: always()
continue-on-error: true
run: docker rm -f $(docker ps -q --filter "ancestor=mysql:8.4.5")
benchmark-test:
runs-on: ubuntu-latest
permissions:
contents: write
pull-requests: write
repository-projects: write
steps:
- uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: ${{ env.DEFAULT_GO_VERSION }}
- name: Run Go Benchmark
run: |
modules=`find . -name "go.mod" -exec dirname {} \;`
echo $modules
list=""
coverpkg=""
if [[ ! -f "go.work" ]];then go work init;fi
for module in $modules; do go work use $module; list=$module"/... "$list; coverpkg=$module"/...,"$coverpkg; done
go work sync
go test -race -v -bench=. -benchmem -run=none -gcflags="all=-l -N" $list

View File

@ -116,6 +116,7 @@ help:
@echo " middleware - Setup middlewares docker environment, but exclude the server app."
@echo " web - Setup web docker environment, include middlewares docker."
@echo " down - Stop the docker containers."
@echo " down_web - Stop the web docker containers."
@echo " clean - Stop the docker containers and clean volumes."
@echo " python - Setup python environment."
@echo " atlas-hash - Rehash atlas migration files."

View File

@ -37,7 +37,6 @@ The backend of Coze Studio is developed using Golang, the frontend uses React +
## Quickstart
Learn how to obtain and deploy the open-source version of Coze Studio, quickly build projects, and experience Coze Studio's open-source version.
> Detailed steps and deployment requirements can be found in [Quickstart](https://github.com/coze-dev/coze-studio/wiki/2.-Quickstart).
Environment requirements:
@ -63,9 +62,10 @@ Deployment steps:
2. Modify the template file in the configuration file directory.
1. Enter the directory `backend/conf/model`. Open the file `ark_doubao-seed-1.6.yaml`.
2. Set the fields `id`, `meta.conn_config.api_key`, `meta.conn_config.model`, and save the file.
* **id**: The model ID in Coze Studio, defined by the developers themselves, must be a non-zero integer and globally unique. Do not modify the model ID after the model goes online.
* **meta.conn_config.api_key**: The API Key for the model service, which in this example is the API Key for Volcengine Ark. Refer to [Retrieve Volcengine Ark API Key](https://www.volcengine.com/docs/82379/1541594) for the acquisition method.
* **meta.conn_config.model**: The model ID of the model service, which in this example is the Endpoint ID of the Volcengine Ark doubao-seed-1.6 model access point. For retrieval methods, refer to [Retrieve Endpoint ID](https://www.volcengine.com/docs/82379/1099522).
* **id**: The model ID in Coze Studio, defined by the developer, must be a non-zero integer and globally unique. Agents or workflows call models based on model IDs. For models that have already been launched, do not modify their IDs; otherwise, it may result in model call failures.
* **meta.conn_config.api_key**: The API Key for the model service. In this example, it is the API Key for Ark API Key. For more information, see [Get Volcengine Ark API Key](https://www.volcengine.com/docs/82379/1541594) or [Get BytePlus ModelArk API Key](https://docs.byteplus.com/en/docs/ModelArk/1361424?utm_source=github&utm_medium=readme&utm_campaign=coze_open_source).
* **meta.conn_config.model**: The Model name for the model service. In this example, it refers to the Model ID or Endpoint ID of Ark. For more information, see [Get Volcengine Ark Model ID](https://www.volcengine.com/docs/82379/1513689) / [Get Volcengine Ark Endpoint ID](https://www.volcengine.com/docs/82379/1099522) or [Get BytePlus ModelArk Model ID](https://docs.byteplus.com/en/docs/ModelArk/model_id?utm_source=github&utm_medium=readme&utm_campaign=coze_open_source) / [Get BytePlus ModelArk Endpoint ID](https://docs.byteplus.com/en/docs/ModelArk/1099522?utm_source=github&utm_medium=readme&utm_campaign=coze_open_source).
> For users in China, you may use Volcengine Ark; for users outside China, you may use BytePlus ModelArk instead.
3. Deploy and start the service.
When deploying and starting Coze Studio for the first time, it may take a while to retrieve images and build local images. Please be patient. During deployment, you will see the following log information. If you see the message "Container coze-server Started," it means the Coze Studio service has started successfully.
```Bash
@ -77,6 +77,8 @@ Deployment steps:
For common startup failure issues, **please refer to the [FAQ](https://github.com/coze-dev/coze-studio/wiki/9.-FAQ)**.
4. After starting the service, you can open Coze Studio by accessing `http://localhost:8888/` through your browser.
> [!WARNING]
> If you want to deploy Coze Studio in a public network environment, it is recommended to assess security risks before you begin, and take corresponding protection measures. Possible security risks include account registration functions, Python execution environments in workflow code nodes, Coze Server listening address configurations, SSRF (Server - Side Request Forgery), and some horizontal privilege escalations in APIs. For more details, refer to [Quickstart](https://github.com/coze-dev/coze-studio/wiki/2.-Quickstart#security-risks-in-public-networks).
## Developer Guide
@ -106,7 +108,7 @@ This project uses the Apache 2.0 license. For details, please refer to the [LICE
## Community contributions
We welcome community contributions. For contribution guidelines, please refer to [CONTRIBUTING](https://github.com/coze-dev/coze-studio/blob/main/CONTRIBUTING.md) and [Code of conduct](https://github.com/coze-dev/coze-studio/blob/main/CODE_OF_CONDUCT.md). We look forward to your contributions!
## Security and privacy
If you discover potential security issues in the project, or believe you may have found a security issue, please notify the ByteDance security team through our [security center](https://security.bytedance.com/src) or [vulnerability reporting email](sec@bytedance.com).
If you discover potential security issues in the project, or believe you may have found a security issue, please notify the ByteDance security team through our [security center](https://security.bytedance.com/src) or [vulnerability reporting email](mailto:sec@bytedance.com).
Please **do not** create public GitHub Issues.
## Join Community

View File

@ -37,7 +37,6 @@ Coze Studio 的后端采用 Golang 开发,前端使用 React + TypeScript
| API 与 SDK | * 创建会话、发起对话等 OpenAPI <br> * 通过 Chat SDK 将智能体或应用集成到自己的应用 |
## 快速开始
了解如何获取并部署 Coze Studio 开源版,快速构建项目、体验 Coze Studio 开源版。
> 详细步骤及部署要求可参考[快速开始](https://github.com/coze-dev/coze-studio/wiki/2.-快速开始)。
环境要求:
@ -63,9 +62,10 @@ Coze Studio 的后端采用 Golang 开发,前端使用 React + TypeScript
2. 在配置文件目录下,修改模版文件。
1. 进入目录 `backend/conf/model`。打开复制后的文件`ark_doubao-seed-1.6.yaml`。
2. 设置 `id`、`meta.conn_config.api_key`、`meta.conn_config.model` 字段,并保存文件。
* **id**Coze Studio 中的模型 ID由开发者自行定义必须是非 0 的整数,且全局唯一。模型上线后请勿修改模型 id
* **meta.conn_config.api_key**:模型服务的 API Key在本示例中为火山方舟的 API Key获取方式可参考[获取火山方舟 API Key](https://www.volcengine.com/docs/82379/1541594)。
* **meta.conn_config.model**:模型服务的 model ID,在本示例中为火山方舟 doubao-seed-1.6 模型接入点的 Endpoint ID获取方式可参考[获取 Endpoint ID](https://www.volcengine.com/docs/82379/1099522)。
* **id**Coze Studio 中的模型 ID由开发者自行定义必须是非 0 的整数,且全局唯一。智能体或工作流根据模型 ID 来调用模型。对于已上线的模型,请勿修改模型 ID否则可能导致模型调用失败
* **meta.conn_config.api_key**:模型服务的 API Key在本示例中为火山方舟的 API Key获取方式可参考[获取火山方舟 API Key](https://www.volcengine.com/docs/82379/1541594) 或[获取 Byteplus ModelArk API Key](https://docs.byteplus.com/en/docs/ModelArk/1361424?utm_source=github&utm_medium=readme&utm_campaign=coze_open_source)
* **meta.conn_config.model**:模型服务的 Model name,在本示例中为火山方舟的 Model ID 或 Endpoint ID获取方式可参考 [获取火山方舟 Model ID](https://www.volcengine.com/docs/82379/1513689) / [获取火山方舟 Endpoint ID](https://www.volcengine.com/docs/82379/1099522),或者参考[获取 BytePlus ModelArk Model ID](https://docs.byteplus.com/en/docs/ModelArk/model_id?utm_source=github&utm_medium=readme&utm_campaign=coze_open_source) / [获取 BytePlus ModelArk Endpoint ID](https://docs.byteplus.com/en/docs/ModelArk/1099522?utm_source=github&utm_medium=readme&utm_campaign=coze_open_source)。
> 中国境内用户可选用火山方舟Volcengine Ark非中国境内的用户则可用 BytePlus ModelArk。
3. 部署并启动服务。
首次部署并启动 Coze Studio 需要拉取镜像、构建本地镜像,可能耗时较久,请耐心等待。部署过程中,你会看到以下日志信息。如果看到提示 "Container coze-server Started",表示 Coze Studio 服务已成功启动。
```Bash
@ -78,6 +78,8 @@ Coze Studio 的后端采用 Golang 开发,前端使用 React + TypeScript
4. 启动服务后,通过浏览器访问 `http://localhost:8888/` 即可打开 Coze Studio。
> [!WARNING]
> 如果要将 Coze Studio 部署到公网环境,建议在部署前评估整体评估安全风险,例如账号注册功能、工作流代码节点 Python执行环境、Coze Server 监听地址配置、SSRF 和部分 API 水平越权的风险,并采取相应防护措施。详细信息可参考[快速开始](https://github.com/coze-dev/coze-studio/wiki/2.-%E5%BF%AB%E9%80%9F%E5%BC%80%E5%A7%8B#%E5%85%AC%E7%BD%91%E5%AE%89%E5%85%A8%E9%A3%8E%E9%99%A9)。
## 开发指南
* **项目配置**
@ -106,7 +108,7 @@ Coze Studio 的后端采用 Golang 开发,前端使用 React + TypeScript
## 社区贡献
我们欢迎社区贡献,贡献指南参见 [CONTRIBUTING](https://github.com/coze-dev/coze-studio/blob/main/CONTRIBUTING.md) 和 [Code of conduct](https://github.com/coze-dev/coze-studio/blob/main/CODE_OF_CONDUCT.md),期待您的贡献!
## 安全与隐私
如果你在该项目中发现潜在的安全问题,或你认为可能发现了安全问题,请通过我们的[安全中心](https://security.bytedance.com/src) 或[漏洞报告邮箱](sec@bytedance.com)通知字节跳动安全团队。
如果你在该项目中发现潜在的安全问题,或你认为可能发现了安全问题,请通过我们的[安全中心](https://security.bytedance.com/src) 或[漏洞报告邮箱](mailto:sec@bytedance.com)通知字节跳动安全团队。
请**不要**创建公开的 GitHub Issue。
## 加入社区

View File

@ -31,6 +31,7 @@ import (
"time"
"github.com/alicebob/miniredis/v2"
"github.com/bytedance/mockey"
"github.com/cloudwego/eino/callbacks"
model2 "github.com/cloudwego/eino/components/model"
@ -47,9 +48,12 @@ import (
"gorm.io/driver/mysql"
"gorm.io/gorm"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
modelknowledge "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
plugin2 "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
pluginmodel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
"github.com/coze-dev/coze-studio/backend/api/model/playground"
pluginAPI "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop"
"github.com/coze-dev/coze-studio/backend/api/model/workflow"
@ -59,30 +63,29 @@ import (
appplugin "github.com/coze-dev/coze-studio/backend/application/plugin"
"github.com/coze-dev/coze-studio/backend/application/user"
appworkflow "github.com/coze-dev/coze-studio/backend/application/workflow"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossuser"
plugin3 "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/plugin"
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/database/databasemock"
crossknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/contract/knowledge"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/knowledge/knowledgemock"
crossmodelmgr "github.com/coze-dev/coze-studio/backend/crossdomain/contract/modelmgr"
mockmodel "github.com/coze-dev/coze-studio/backend/crossdomain/contract/modelmgr/modelmock"
crossplugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin/pluginmock"
crossuser "github.com/coze-dev/coze-studio/backend/crossdomain/contract/user"
"github.com/coze-dev/coze-studio/backend/crossdomain/impl/code"
pluginImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/plugin"
entity4 "github.com/coze-dev/coze-studio/backend/domain/memory/database/entity"
entity2 "github.com/coze-dev/coze-studio/backend/domain/openauth/openapiauth/entity"
entity3 "github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
entity5 "github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
search "github.com/coze-dev/coze-studio/backend/domain/search/entity"
userentity "github.com/coze-dev/coze-studio/backend/domain/user/entity"
workflow2 "github.com/coze-dev/coze-studio/backend/domain/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/code"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database/databasemock"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge/knowledgemock"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
mockmodel "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model/modelmock"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin/pluginmock"
crosssearch "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/search"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/search/searchmock"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
mockvar "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable/varmock"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/service"
"github.com/coze-dev/coze-studio/backend/domain/workflow/variable"
mockvar "github.com/coze-dev/coze-studio/backend/domain/workflow/variable/varmock"
"github.com/coze-dev/coze-studio/backend/infra/contract/coderunner"
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
"github.com/coze-dev/coze-studio/backend/infra/impl/cache/redis"
@ -110,23 +113,23 @@ func TestMain(m *testing.M) {
}
type wfTestRunner struct {
t *testing.T
h *server.Hertz
ctrl *gomock.Controller
idGen *mock.MockIDGenerator
search *searchmock.MockNotifier
appVarS *mockvar.MockStore
userVarS *mockvar.MockStore
varGetter *mockvar.MockVariablesMetaGetter
modelManage *mockmodel.MockManager
plugin *mockPlugin.MockPluginService
tos *storageMock.MockStorage
knowledge *knowledgemock.MockKnowledgeOperator
database *databasemock.MockDatabaseOperator
pluginSrv *pluginmock.MockService
internalModel *testutil.UTChatModel
ctx context.Context
closeFn func()
t *testing.T
h *server.Hertz
ctrl *gomock.Controller
idGen *mock.MockIDGenerator
appVarS *mockvar.MockStore
userVarS *mockvar.MockStore
varGetter *mockvar.MockVariablesMetaGetter
modelManage *mockmodel.MockManager
plugin *mockPlugin.MockPluginService
tos *storageMock.MockStorage
knowledge *knowledgemock.MockKnowledge
database *databasemock.MockDatabase
pluginSrv *pluginmock.MockPluginService
internalModel *testutil.UTChatModel
publishPatcher *mockey.Mocker
ctx context.Context
closeFn func()
}
var req2URL = map[reflect.Type]string{
@ -247,13 +250,10 @@ func newWfTestRunner(t *testing.T) *wfTestRunner {
mockTos := storageMock.NewMockStorage(ctrl)
mockTos.EXPECT().GetObjectUrl(gomock.Any(), gomock.Any(), gomock.Any()).Return("", nil).AnyTimes()
workflowRepo := service.NewWorkflowRepository(mockIDGen, db, redisClient, mockTos, cpStore, utChatModel)
workflowRepo := service.NewWorkflowRepository(mockIDGen, db, redisClient, mockTos, cpStore, utChatModel, nil)
mockey.Mock(appworkflow.GetWorkflowDomainSVC).Return(service.NewWorkflowService(workflowRepo)).Build()
mockey.Mock(workflow2.GetRepository).Return(workflowRepo).Build()
mockSearchNotify := searchmock.NewMockNotifier(ctrl)
mockey.Mock(crosssearch.GetNotifier).Return(mockSearchNotify).Build()
mockSearchNotify.EXPECT().PublishWorkflowResource(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
publishPatcher := mockey.Mock(appworkflow.PublishWorkflowResource).Return(nil).Build()
mockCU := mockCrossUser.NewMockUser(ctrl)
mockCU.EXPECT().GetUserSpaceList(gomock.Any(), gomock.Any()).Return([]*crossuser.EntitySpace{
@ -276,12 +276,12 @@ func newWfTestRunner(t *testing.T) *wfTestRunner {
mPlugin := mockPlugin.NewMockPluginService(ctrl)
mockKwOperator := knowledgemock.NewMockKnowledgeOperator(ctrl)
knowledge.SetKnowledgeOperator(mockKwOperator)
mockKwOperator := knowledgemock.NewMockKnowledge(ctrl)
crossknowledge.SetDefaultSVC(mockKwOperator)
mockModelManage := mockmodel.NewMockManager(ctrl)
mockModelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(nil, nil, nil).AnyTimes()
m3 := mockey.Mock(model.GetManager).Return(mockModelManage).Build()
m3 := mockey.Mock(crossmodelmgr.DefaultSVC).Return(mockModelManage).Build()
m := mockey.Mock(crossuser.DefaultSVC).Return(mockCU).Build()
m1 := mockey.Mock(ctxutil.GetApiAuthFromCtx).Return(&entity2.ApiKey{
@ -291,17 +291,18 @@ func newWfTestRunner(t *testing.T) *wfTestRunner {
m4 := mockey.Mock(ctxutil.MustGetUIDFromCtx).Return(int64(1)).Build()
m5 := mockey.Mock(ctxutil.GetUIDFromCtx).Return(ptr.Of(int64(1))).Build()
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
database.SetDatabaseOperator(mockDatabaseOperator)
mockDatabaseOperator := databasemock.NewMockDatabase(ctrl)
crossdatabase.SetDefaultSVC(mockDatabaseOperator)
mockPluginSrv := pluginmock.NewMockService(ctrl)
plugin.SetPluginService(mockPluginSrv)
mockPluginSrv := pluginmock.NewMockPluginService(ctrl)
crossplugin.SetDefaultSVC(mockPluginSrv)
mockey.Mock((*user.UserApplicationService).MGetUserBasicInfo).Return(&playground.MGetUserBasicInfoResponse{
UserBasicInfoMap: make(map[string]*playground.UserBasicInfo),
}, nil).Build()
f := func() {
publishPatcher.UnPatch()
m.UnPatch()
m1.UnPatch()
m2.UnPatch()
@ -314,23 +315,23 @@ func newWfTestRunner(t *testing.T) *wfTestRunner {
}
return &wfTestRunner{
t: t,
h: h,
ctrl: ctrl,
idGen: mockIDGen,
search: mockSearchNotify,
appVarS: mockGlobalAppVarStore,
userVarS: mockGlobalUserVarStore,
varGetter: mockVarGetter,
modelManage: mockModelManage,
plugin: mPlugin,
tos: mockTos,
knowledge: mockKwOperator,
database: mockDatabaseOperator,
internalModel: utChatModel,
ctx: context.Background(),
closeFn: f,
pluginSrv: mockPluginSrv,
t: t,
h: h,
ctrl: ctrl,
idGen: mockIDGen,
appVarS: mockGlobalAppVarStore,
userVarS: mockGlobalUserVarStore,
varGetter: mockVarGetter,
modelManage: mockModelManage,
plugin: mPlugin,
tos: mockTos,
knowledge: mockKwOperator,
database: mockDatabaseOperator,
internalModel: utChatModel,
ctx: context.Background(),
closeFn: f,
pluginSrv: mockPluginSrv,
publishPatcher: publishPatcher,
}
}
@ -2253,24 +2254,22 @@ func TestNodeWithBatchEnabled(t *testing.T) {
})
e := r.getProcess(id, exeID)
e.assertSuccess()
assert.Equal(t, map[string]any{
outputMap := mustUnmarshalToMap(t, e.output)
assert.Contains(t, outputMap["output"], map[string]any{
"output": []any{
map[string]any{
"output": []any{
"answer",
"for index 0",
},
"input": "answer。for index 0",
},
map[string]any{
"output": []any{
"answer",
"for index 1",
},
"input": "answerfor index 1",
},
"answer",
"for index 0",
},
}, mustUnmarshalToMap(t, e.output))
"input": "answer。for index 0",
})
assert.Contains(t, outputMap["output"], map[string]any{
"output": []any{
"answer",
"for index 1",
},
"input": "answerfor index 1",
})
assert.Equal(t, 2, len(outputMap["output"].([]any)))
e.tokenEqual(10, 12)
// verify this workflow has previously succeeded a test run
@ -2874,9 +2873,8 @@ func TestLLMWithSkills(t *testing.T) {
{ID: int64(7509353598782816256), Operation: operation},
}, nil).AnyTimes()
pluginSrv := plugin3.NewPluginService(r.plugin, r.tos)
plugin.SetPluginService(pluginSrv)
pluginSrv := pluginImpl.InitDomainService(r.plugin, r.tos)
crossplugin.SetDefaultSVC(pluginSrv)
t.Run("llm with plugin tool", func(t *testing.T) {
id := r.load("llm_node_with_skills/llm_node_with_plugin_tool.json")
@ -2998,22 +2996,22 @@ func TestLLMWithSkills(t *testing.T) {
},
}, nil).AnyTimes()
r.knowledge.EXPECT().Retrieve(gomock.Any(), gomock.Any()).Return(&knowledge.RetrieveResponse{
Slices: []*knowledge.Slice{
{DocumentID: "1", Output: "天安门广场 ‌:中国政治文化中心,见证了近现代重大历史事件‌"},
{DocumentID: "2", Output: "八达岭长城 ‌:明代长城的精华段,被誉为“不到长城非好汉"},
},
}, nil).AnyTimes()
// r.knowledge.EXPECT().Retrieve(gomock.Any(), gomock.Any()).Return(&knowledge.RetrieveResponse{
// RetrieveSlices: []*knowledge.RetrieveSlice{
// {Slice: &knowledge.Slice{DocumentID: 1, Output: "天安门广场 ‌:中国政治文化中心,见证了近现代重大历史事件‌"}, Score: 0.9},
// {Slice: &knowledge.Slice{DocumentID: 2, Output: "八达岭长城 ‌:明代长城的精华段,被誉为“不到长城非好汉"}, Score: 0.8},
// },
// }, nil).AnyTimes()
t.Run("llm node with knowledge skill", func(t *testing.T) {
id := r.load("llm_node_with_skills/llm_with_knowledge_skill.json")
exeID := r.testRun(id, map[string]string{
"input": "北京有哪些著名的景点",
})
e := r.getProcess(id, exeID)
e.assertSuccess()
assert.Equal(t, `{"output":"八达岭长城 ‌:明代长城的精华段,被誉为“不到长城非好汉‌"}`, e.output)
})
// t.Run("llm node with knowledge skill", func(t *testing.T) {
// id := r.load("llm_node_with_skills/llm_with_knowledge_skill.json")
// exeID := r.testRun(id, map[string]string{
// "input": "北京有哪些著名的景点",
// })
// e := r.getProcess(id, exeID)
// e.assertSuccess()
// assert.Equal(t, `{"output":"八达岭长城 ‌:明代长城的精华段,被誉为“不到长城非好汉‌"}`, e.output)
// })
})
}
@ -3412,8 +3410,8 @@ func TestGetLLMNodeFCSettingsDetailAndMerged(t *testing.T) {
{ID: 123, Operation: operation},
}, nil).AnyTimes()
pluginSrv := plugin3.NewPluginService(r.plugin, r.tos)
plugin.SetPluginService(pluginSrv)
pluginSrv := pluginImpl.InitDomainService(r.plugin, r.tos)
crossplugin.SetDefaultSVC(pluginSrv)
t.Run("plugin tool info ", func(t *testing.T) {
fcSettingDetailReq := &workflow.GetLLMNodeFCSettingDetailRequest{
@ -3529,8 +3527,8 @@ func TestGetLLMNodeFCSettingsDetailAndMerged(t *testing.T) {
{ID: 123, Operation: operation},
}, nil).AnyTimes()
pluginSrv := plugin3.NewPluginService(r.plugin, r.tos)
plugin.SetPluginService(pluginSrv)
pluginSrv := pluginImpl.InitDomainService(r.plugin, r.tos)
crossplugin.SetDefaultSVC(pluginSrv)
t.Run("plugin merge", func(t *testing.T) {
fcSettingMergedReq := &workflow.GetLLMNodeFCSettingsMergedRequest{
@ -3696,7 +3694,7 @@ func TestCopyWorkflow(t *testing.T) {
_, err := appworkflow.GetWorkflowDomainSVC().Get(context.Background(), &vo.GetPolicy{
ID: wid,
QType: vo.FromDraft,
QType: workflowModel.FromDraft,
CommitID: "",
})
assert.NotNil(t, err)
@ -3758,7 +3756,7 @@ func TestReleaseApplicationWorkflows(t *testing.T) {
wf, err = appworkflow.GetWorkflowDomainSVC().Get(context.Background(), &vo.GetPolicy{
ID: 100100100100,
QType: vo.FromSpecificVersion,
QType: workflowModel.FromSpecificVersion,
Version: version,
})
assert.NoError(t, err)
@ -4038,7 +4036,7 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
mockey.PatchConvey("copy with subworkflow, subworkflow with external resource ", t, func() {
var copiedIDs = make([]int64, 0)
var mockPublishWorkflowResource func(ctx context.Context, OpType crosssearch.OpType, event *crosssearch.Resource) error
var mockPublishWorkflowResource func(ctx context.Context, workflowID int64, mode *int32, op search.OpType, r *search.ResourceDocument) error
var ignoreIDs = map[int64]bool{
7515027325977624576: true,
7515027249628708864: true,
@ -4046,15 +4044,15 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
7515027150387281920: true,
7515027091302121472: true,
}
mockPublishWorkflowResource = func(ctx context.Context, OpType crosssearch.OpType, event *crosssearch.Resource) error {
if ignoreIDs[event.WorkflowID] {
mockPublishWorkflowResource = func(ctx context.Context, workflowID int64, mode *int32, op search.OpType, r *search.ResourceDocument) error {
if ignoreIDs[workflowID] {
return nil
}
wf, err := appworkflow.GetWorkflowDomainSVC().Get(ctx, &vo.GetPolicy{
ID: event.WorkflowID,
QType: vo.FromLatestVersion,
ID: workflowID,
QType: workflowModel.FromLatestVersion,
})
copiedIDs = append(copiedIDs, event.WorkflowID)
copiedIDs = append(copiedIDs, workflowID)
assert.NoError(t, err)
assert.Equal(t, "v0.0.1", wf.Version)
canvas := &vo.Canvas{}
@ -4094,7 +4092,7 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
subWf, err := appworkflow.GetWorkflowDomainSVC().Get(ctx, &vo.GetPolicy{
ID: wfId,
QType: vo.FromLatestVersion,
QType: workflowModel.FromLatestVersion,
})
assert.NoError(t, err)
subworkflowCanvas := &vo.Canvas{}
@ -4143,7 +4141,7 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
}
r.search.EXPECT().PublishWorkflowResource(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mockPublishWorkflowResource).AnyTimes()
defer mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build().UnPatch()
appID := "7513788954458456064"
appIDInt64, _ := strconv.ParseInt(appID, 10, 64)
@ -4186,21 +4184,21 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
mockey.PatchConvey("copy only with external resource", t, func() {
var copiedIDs = make([]int64, 0)
var mockPublishWorkflowResource func(ctx context.Context, OpType crosssearch.OpType, event *crosssearch.Resource) error
var mockPublishWorkflowResource func(ctx context.Context, workflowID int64, mode *int32, op search.OpType, r *search.ResourceDocument) error
var ignoreIDs = map[int64]bool{
7516518409656336384: true,
7516516198096306176: true,
}
mockPublishWorkflowResource = func(ctx context.Context, OpType crosssearch.OpType, event *crosssearch.Resource) error {
if ignoreIDs[event.WorkflowID] {
mockPublishWorkflowResource = func(ctx context.Context, workflowID int64, mode *int32, op search.OpType, r *search.ResourceDocument) error {
if ignoreIDs[workflowID] {
return nil
}
wf, err := appworkflow.GetWorkflowDomainSVC().Get(ctx, &vo.GetPolicy{
ID: event.WorkflowID,
QType: vo.FromLatestVersion,
ID: workflowID,
QType: workflowModel.FromLatestVersion,
})
copiedIDs = append(copiedIDs, event.WorkflowID)
copiedIDs = append(copiedIDs, workflowID)
assert.NoError(t, err)
assert.Equal(t, "v0.0.1", wf.Version)
canvas := &vo.Canvas{}
@ -4255,7 +4253,7 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
}
r.search.EXPECT().PublishWorkflowResource(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mockPublishWorkflowResource).AnyTimes()
defer mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build().UnPatch()
defer mockey.Mock((*appknowledge.KnowledgeApplicationService).CopyKnowledge).Return(&modelknowledge.CopyKnowledgeResponse{
TargetKnowledgeID: 100100,
@ -4296,6 +4294,7 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
func TestMoveWorkflowAppToLibrary(t *testing.T) {
mockey.PatchConvey("test move workflow", t, func() {
r := newWfTestRunner(t)
r.publishPatcher.UnPatch()
defer r.closeFn()
vars := map[string]*vo.TypeInfo{
"app_v1": {
@ -4315,21 +4314,21 @@ func TestMoveWorkflowAppToLibrary(t *testing.T) {
r.varGetter.EXPECT().GetAppVariablesMeta(gomock.Any(), gomock.Any(), gomock.Any()).Return(vars, nil).AnyTimes()
t.Run("move workflow", func(t *testing.T) {
var mockPublishWorkflowResource func(ctx context.Context, OpType crosssearch.OpType, event *crosssearch.Resource) error
var mockPublishWorkflowResource func(ctx context.Context, workflowID int64, mode *int32, op search.OpType, r *search.ResourceDocument) error
named2Idx := []string{"c1", "c2", "cc1", "main"}
callCount := 0
initialWf2ID := map[string]int64{}
old2newID := map[int64]int64{}
mockPublishWorkflowResource = func(ctx context.Context, OpType crosssearch.OpType, event *crosssearch.Resource) error {
mockPublishWorkflowResource = func(ctx context.Context, workflowID int64, mode *int32, op search.OpType, r *search.ResourceDocument) error {
if callCount <= 3 {
initialWf2ID[named2Idx[callCount]] = event.WorkflowID
initialWf2ID[named2Idx[callCount]] = workflowID
callCount++
return nil
}
if OpType == crosssearch.Created {
if oldID, ok := initialWf2ID[*event.Name]; ok {
old2newID[oldID] = event.WorkflowID
if op == search.Created {
if oldID, ok := initialWf2ID[*r.Name]; ok {
old2newID[oldID] = workflowID
}
}
@ -4337,7 +4336,7 @@ func TestMoveWorkflowAppToLibrary(t *testing.T) {
}
r.search.EXPECT().PublishWorkflowResource(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mockPublishWorkflowResource).AnyTimes()
defer mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build().UnPatch()
defer mockey.Mock((*appknowledge.KnowledgeApplicationService).MoveKnowledgeToLibrary).Return(nil).Build().UnPatch()
defer mockey.Mock((*appmemory.DatabaseApplicationService).MoveDatabaseToLibrary).Return(&appmemory.MoveDatabaseToLibraryResponse{}, nil).Build().UnPatch()
@ -4455,6 +4454,7 @@ func TestMoveWorkflowAppToLibrary(t *testing.T) {
func TestDuplicateWorkflowsByAppID(t *testing.T) {
mockey.PatchConvey("test duplicate work", t, func() {
r := newWfTestRunner(t)
r.publishPatcher.UnPatch()
defer r.closeFn()
vars := map[string]*vo.TypeInfo{
@ -4474,7 +4474,7 @@ func TestDuplicateWorkflowsByAppID(t *testing.T) {
r.varGetter.EXPECT().GetAppVariablesMeta(gomock.Any(), gomock.Any(), gomock.Any()).Return(vars, nil).AnyTimes()
var copiedIDs = make([]int64, 0)
var mockPublishWorkflowResource func(ctx context.Context, OpType crosssearch.OpType, event *crosssearch.Resource) error
var mockPublishWorkflowResource func(ctx context.Context, workflowID int64, mode *int32, op search.OpType, r *search.ResourceDocument) error
var ignoreIDs = map[int64]bool{
7515027325977624576: true,
7515027249628708864: true,
@ -4483,16 +4483,16 @@ func TestDuplicateWorkflowsByAppID(t *testing.T) {
7515027091302121472: true,
7515027325977624579: true,
}
mockPublishWorkflowResource = func(ctx context.Context, OpType crosssearch.OpType, event *crosssearch.Resource) error {
if ignoreIDs[event.WorkflowID] {
mockPublishWorkflowResource = func(ctx context.Context, workflowID int64, mode *int32, op search.OpType, r *search.ResourceDocument) error {
if ignoreIDs[workflowID] {
return nil
}
copiedIDs = append(copiedIDs, event.WorkflowID)
copiedIDs = append(copiedIDs, workflowID)
return nil
}
r.search.EXPECT().PublishWorkflowResource(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mockPublishWorkflowResource).AnyTimes()
defer mockey.Mock(appworkflow.PublishWorkflowResource).To(mockPublishWorkflowResource).Build().UnPatch()
appIDInt64 := int64(7513788954458456064)
@ -4660,7 +4660,7 @@ func TestJsonSerializationDeserializationWithWarning(t *testing.T) {
})
}
func TestSetAppVariablesFOrSubProcesses(t *testing.T) {
func TestSetAppVariablesForSubProcesses(t *testing.T) {
mockey.PatchConvey("app variables for sub_process", t, func() {
r := newWfTestRunner(t)
defer r.closeFn()
@ -4677,3 +4677,79 @@ func TestSetAppVariablesFOrSubProcesses(t *testing.T) {
})
}
func TestHttpImplicitDependencies(t *testing.T) {
mockey.PatchConvey("test http implicit dependencies", t, func() {
r := newWfTestRunner(t)
defer r.closeFn()
r.appVarS.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return("1.0", nil).AnyTimes()
idStr := r.load("httprequester/http_implicit_dependencies.json")
r.publish(idStr, "v0.0.1", true)
runner := mockcode.NewMockRunner(r.ctrl)
runner.EXPECT().Run(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, request *coderunner.RunRequest) (*coderunner.RunResponse, error) {
in := request.Params["input"]
_ = in
result := make(map[string]any)
err := sonic.UnmarshalString(in.(string), &result)
if err != nil {
return nil, err
}
return &coderunner.RunResponse{
Result: result,
}, nil
}).AnyTimes()
code.SetCodeRunner(runner)
mockey.PatchConvey("test http node implicit dependencies", func() {
input := map[string]string{
"input": "a",
}
result, _ := r.openapiSyncRun(idStr, input)
batchRets := result["batch"].([]any)
loopRets := result["loop"].([]any)
for _, r := range batchRets {
assert.Contains(t, []any{
"http://echo.apifox.com/anything?aa=1.0&cc=1",
"http://echo.apifox.com/anything?aa=1.0&cc=2",
}, r)
}
for _, r := range loopRets {
assert.Contains(t, []any{
"http://echo.apifox.com/anything?a=1&m=123",
"http://echo.apifox.com/anything?a=2&m=123",
}, r)
}
})
mockey.PatchConvey("node debug http node implicit dependencies", func() {
exeID := r.nodeDebug(idStr, "109387",
withNDInput(map[string]string{
"__apiInfo_url_87fc7c69536cae843fa7f5113cf0067b": "m",
"__apiInfo_url_ac86361e3cd503952e71986dc091fa6f": "a",
"__body_bodyData_json_ac86361e3cd503952e71986dc091fa6f": "b",
"__body_bodyData_json_f77817a7cf8441279e1cfd8af4eeb1da": "1",
}))
e := r.getProcess(idStr, exeID, withSpecificNodeID("109387"))
e.assertSuccess()
ret := make(map[string]any)
err := sonic.UnmarshalString(e.output, &ret)
assert.Nil(t, err)
err = sonic.UnmarshalString(ret["body"].(string), &ret)
assert.Nil(t, err)
assert.Equal(t, ret["url"].(string), "http://echo.apifox.com/anything?a=a&m=m")
})
})
}

View File

@ -6837,7 +6837,7 @@ type OpenMessageApi struct {
//message content
Content string `thrift:"content,4" form:"content" json:"content" query:"content"`
//session id
ConversationID int64 `thrift:"conversation_id,5" form:"conversation_id" json:"conversation_id" query:"conversation_id"`
ConversationID int64 `thrift:"conversation_id,5" form:"conversation_id" json:"conversation_id,string" query:"conversation_id"`
// custom field
MetaData map[string]string `thrift:"meta_data,6" form:"meta_data" json:"meta_data" query:"meta_data"`
//creation time
@ -6845,7 +6845,7 @@ type OpenMessageApi struct {
//update time
UpdatedAt int64 `thrift:"updated_at,8" form:"updated_at" json:"updated_at" query:"updated_at"`
// ID of a conversation
ChatID int64 `thrift:"chat_id,9" form:"chat_id" json:"chat_id" query:"chat_id"`
ChatID int64 `thrift:"chat_id,9" form:"chat_id" json:"chat_id,string" query:"chat_id"`
// Content type, text/mix
ContentType string `thrift:"content_type,10" form:"content_type" json:"content_type" query:"content_type"`
//Message Type answer/question/function_call/tool_response

View File

@ -204,3 +204,110 @@ type GetAllDatabaseByAppIDRequest struct {
type GetAllDatabaseByAppIDResponse struct {
Databases []*Database // online databases
}
type SQLParam struct {
Value string
IsNull bool
}
type CustomSQLRequest struct {
DatabaseInfoID int64
SQL string
Params []SQLParam
IsDebugRun bool
UserID string
ConnectorID int64
}
type Object = map[string]any
type Response struct {
RowNumber *int64
Objects []Object
}
type Operator string
type ClauseRelation string
const (
ClauseRelationAND ClauseRelation = "and"
ClauseRelationOR ClauseRelation = "or"
)
const (
OperatorEqual Operator = "="
OperatorNotEqual Operator = "!="
OperatorGreater Operator = ">"
OperatorLesser Operator = "<"
OperatorGreaterOrEqual Operator = ">="
OperatorLesserOrEqual Operator = "<="
OperatorIn Operator = "in"
OperatorNotIn Operator = "not_in"
OperatorIsNull Operator = "is_null"
OperatorIsNotNull Operator = "is_not_null"
OperatorLike Operator = "like"
OperatorNotLike Operator = "not_like"
)
type ClauseGroup struct {
Single *Clause
Multi *MultiClause
}
type Clause struct {
Left string
Operator Operator
}
type MultiClause struct {
Clauses []*Clause
Relation ClauseRelation
}
type ConditionStr struct {
Left string
Operator Operator
Right any
}
type ConditionGroup struct {
Conditions []*ConditionStr
Relation ClauseRelation
}
type DeleteRequest struct {
DatabaseInfoID int64
ConditionGroup *ConditionGroup
IsDebugRun bool
UserID string
ConnectorID int64
}
type QueryRequest struct {
DatabaseInfoID int64
SelectFields []string
Limit int64
ConditionGroup *ConditionGroup
OrderClauses []*OrderClause
IsDebugRun bool
UserID string
ConnectorID int64
}
type OrderClause struct {
FieldID string
IsAsc bool
}
type UpdateRequest struct {
DatabaseInfoID int64
ConditionGroup *ConditionGroup
Fields map[string]any
IsDebugRun bool
UserID string
ConnectorID int64
}
type InsertRequest struct {
DatabaseInfoID int64
Fields map[string]any
IsDebugRun bool
UserID string
ConnectorID int64
}

View File

@ -23,6 +23,7 @@ import (
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
@ -124,6 +125,7 @@ type RetrievalStrategy struct {
EnableQueryRewrite bool
EnableRerank bool
EnableNL2SQL bool
IsPersonalOnly bool
}
type SelectType int64
@ -283,3 +285,69 @@ type CopyKnowledgeResponse struct {
type MoveKnowledgeToLibraryRequest struct {
KnowledgeID int64
}
type ParseMode string
const (
FastParseMode = "fast_mode"
AccurateParseMode = "accurate_mode"
)
type ChunkType string
const (
ChunkTypeDefault ChunkType = "default"
ChunkTypeCustom ChunkType = "custom"
ChunkTypeLeveled ChunkType = "leveled"
)
type ParsingStrategy struct {
ParseMode ParseMode
ExtractImage bool
ExtractTable bool
ImageOCR bool
}
type ChunkingStrategy struct {
ChunkType ChunkType
ChunkSize int64
Separator string
Overlap int64
}
type CreateDocumentRequest struct {
KnowledgeID int64
ParsingStrategy *ParsingStrategy
ChunkingStrategy *ChunkingStrategy
FileURL string
FileName string
FileExtension parser.FileExtension
}
type CreateDocumentResponse struct {
DocumentID int64
FileName string
FileURL string
}
type DeleteDocumentRequest struct {
DocumentID string
}
type DeleteDocumentResponse struct {
IsSuccess bool
}
type KnowledgeDetail struct {
ID int64 `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
IconURL string `json:"-"`
FormatType int64 `json:"-"`
}
type ListKnowledgeDetailRequest struct {
KnowledgeIDs []int64
}
type ListKnowledgeDetailResponse struct {
KnowledgeDetails []*KnowledgeDetail
}

View File

@ -0,0 +1,24 @@
package model
type LLMParams struct {
ModelName string `json:"modelName"`
ModelType int64 `json:"modelType"`
Prompt string `json:"prompt"` // user prompt
Temperature *float64 `json:"temperature"`
FrequencyPenalty float64 `json:"frequencyPenalty"`
PresencePenalty float64 `json:"presencePenalty"`
MaxTokens int `json:"maxTokens"`
TopP *float64 `json:"topP"`
TopK *int `json:"topK"`
EnableChatHistory bool `json:"enableChatHistory"`
SystemPrompt string `json:"systemPrompt"`
ResponseFormat ResponseFormat `json:"responseFormat"`
}
type ResponseFormat int64
const (
ResponseFormatText ResponseFormat = 0
ResponseFormatMarkdown ResponseFormat = 1
ResponseFormatJSON ResponseFormat = 2
)

View File

@ -255,6 +255,9 @@ func (op *Openapi3Operation) ToEinoSchemaParameterInfo(ctx context.Context) (map
if err != nil {
return nil, err
}
if paramInfo == nil {
continue
}
if _, ok := result[paramName]; ok {
logs.CtxWarnf(ctx, "duplicate parameter name '%s'", paramName)

View File

@ -0,0 +1,75 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package plugin
import (
"github.com/coze-dev/coze-studio/backend/api/model/workflow"
)
type ToolsInfoRequest struct {
PluginEntity PluginEntity
ToolIDs []int64
IsDraft bool
}
type PluginEntity struct {
PluginID int64
PluginVersion *string // nil or "0" means draft, "" means latest/online version, otherwise is specific version
}
type ToolsInfoResponse struct {
PluginID int64
SpaceID int64
Version string
PluginName string
Description string
IconURL string
PluginType int64
ToolInfoList map[int64]ToolInfoW
LatestVersion *string
IsOfficial bool
AppID int64
}
type ToolInfoW struct {
ToolName string
ToolID int64
Description string
DebugExample *DebugExample
Inputs []*workflow.APIParameter
Outputs []*workflow.APIParameter
}
type DebugExample struct {
ReqExample string
RespExample string
}
type ToolsInvokableRequest struct {
PluginEntity PluginEntity
ToolsInvokableInfo map[int64]*ToolsInvokableInfo
IsDraft bool
}
type WorkflowAPIParameters = []*workflow.APIParameter
type ToolsInvokableInfo struct {
ToolID int64
RequestAPIParametersConfig WorkflowAPIParameters
ResponseAPIParametersConfig WorkflowAPIParameters
}

View File

@ -23,7 +23,7 @@ import (
"github.com/coze-dev/coze-studio/backend/api/model/app/bot_common"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossworkflow"
crossworkflow "github.com/coze-dev/coze-studio/backend/crossdomain/contract/workflow"
)
type AgentRuntime struct {

View File

@ -14,7 +14,15 @@
* limitations under the License.
*/
package vo
package workflow
type Locator uint8
const (
FromDraft Locator = iota
FromSpecificVersion
FromLatestVersion
)
type ExecuteConfig struct {
ID int64

View File

@ -22,7 +22,7 @@ import (
"github.com/coze-dev/coze-studio/backend/application/openauth"
"github.com/coze-dev/coze-studio/backend/application/template"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crosssearch"
crosssearch "github.com/coze-dev/coze-studio/backend/crossdomain/contract/search"
"github.com/coze-dev/coze-studio/backend/application/app"
"github.com/coze-dev/coze-studio/backend/application/base/appinfra"
@ -39,18 +39,19 @@ import (
"github.com/coze-dev/coze-studio/backend/application/upload"
"github.com/coze-dev/coze-studio/backend/application/user"
"github.com/coze-dev/coze-studio/backend/application/workflow"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossagent"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossagentrun"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossconnector"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossconversation"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossdatabase"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossdatacopy"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossknowledge"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossmessage"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossplugin"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossuser"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossvariables"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossworkflow"
crossagent "github.com/coze-dev/coze-studio/backend/crossdomain/contract/agent"
crossagentrun "github.com/coze-dev/coze-studio/backend/crossdomain/contract/agentrun"
crossconnector "github.com/coze-dev/coze-studio/backend/crossdomain/contract/connector"
crossconversation "github.com/coze-dev/coze-studio/backend/crossdomain/contract/conversation"
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
crossdatacopy "github.com/coze-dev/coze-studio/backend/crossdomain/contract/datacopy"
crossknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/contract/knowledge"
crossmessage "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
crossmodelmgr "github.com/coze-dev/coze-studio/backend/crossdomain/contract/modelmgr"
crossplugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin"
crossuser "github.com/coze-dev/coze-studio/backend/crossdomain/contract/user"
crossvariables "github.com/coze-dev/coze-studio/backend/crossdomain/contract/variables"
crossworkflow "github.com/coze-dev/coze-studio/backend/crossdomain/contract/workflow"
agentrunImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/agentrun"
connectorImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/connector"
conversationImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/conversation"
@ -59,12 +60,14 @@ import (
dataCopyImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/datacopy"
knowledgeImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/knowledge"
messageImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/message"
modelmgrImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/modelmgr"
pluginImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/plugin"
searchImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/search"
singleagentImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/singleagent"
variablesImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/variables"
workflowImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/workflow"
"github.com/coze-dev/coze-studio/backend/infra/contract/eventbus"
"github.com/coze-dev/coze-studio/backend/infra/impl/chatmodel"
"github.com/coze-dev/coze-studio/backend/infra/impl/checkpoint"
implEventbus "github.com/coze-dev/coze-studio/backend/infra/impl/eventbus"
)
@ -130,7 +133,7 @@ func Init(ctx context.Context) (err error) {
crossconnector.SetDefaultSVC(connectorImpl.InitDomainService(basicServices.connectorSVC.DomainSVC))
crossdatabase.SetDefaultSVC(databaseImpl.InitDomainService(primaryServices.memorySVC.DatabaseDomainSVC))
crossknowledge.SetDefaultSVC(knowledgeImpl.InitDomainService(primaryServices.knowledgeSVC.DomainSVC))
crossplugin.SetDefaultSVC(pluginImpl.InitDomainService(primaryServices.pluginSVC.DomainSVC))
crossplugin.SetDefaultSVC(pluginImpl.InitDomainService(primaryServices.pluginSVC.DomainSVC, infra.TOSClient))
crossvariables.SetDefaultSVC(variablesImpl.InitDomainService(primaryServices.memorySVC.VariablesDomainSVC))
crossworkflow.SetDefaultSVC(workflowImpl.InitDomainService(primaryServices.workflowSVC.DomainSVC))
crossconversation.SetDefaultSVC(conversationImpl.InitDomainService(complexServices.conversationSVC.ConversationDomainSVC))
@ -140,6 +143,7 @@ func Init(ctx context.Context) (err error) {
crossuser.SetDefaultSVC(crossuserImpl.InitDomainService(basicServices.userSVC.DomainSVC))
crossdatacopy.SetDefaultSVC(dataCopyImpl.InitDomainService(basicServices.infra))
crosssearch.SetDefaultSVC(searchImpl.InitDomainService(complexServices.searchSVC.DomainSVC))
crossmodelmgr.SetDefaultSVC(modelmgrImpl.InitDomainService(infra.ModelMgr, nil))
return nil
}
@ -188,7 +192,9 @@ func initPrimaryServices(ctx context.Context, basicServices *basicServices) (*pr
memorySVC := memory.InitService(basicServices.toMemoryServiceComponents())
knowledgeSVC, err := knowledge.InitService(basicServices.toKnowledgeServiceComponents(memorySVC))
knowledgeSVC, err := knowledge.InitService(ctx,
basicServices.toKnowledgeServiceComponents(memorySVC),
basicServices.eventbus.resourceEventBus)
if err != nil {
return nil, err
}
@ -252,14 +258,19 @@ func (b *basicServices) toPluginServiceComponents() *plugin.ServiceComponents {
func (b *basicServices) toKnowledgeServiceComponents(memoryService *memory.MemoryApplicationServices) *knowledge.ServiceComponents {
return &knowledge.ServiceComponents{
DB: b.infra.DB,
IDGenSVC: b.infra.IDGenSVC,
Storage: b.infra.TOSClient,
RDB: memoryService.RDBDomainSVC,
ImageX: b.infra.ImageXClient,
ES: b.infra.ESClient,
EventBus: b.eventbus.resourceEventBus,
CacheCli: b.infra.CacheCli,
DB: b.infra.DB,
IDGen: b.infra.IDGenSVC,
RDB: memoryService.RDBDomainSVC,
Producer: b.infra.KnowledgeEventProducer,
SearchStoreManagers: b.infra.SearchStoreManagers,
ParseManager: b.infra.ParserManager,
Storage: b.infra.TOSClient,
Rewriter: b.infra.Rewriter,
Reranker: b.infra.Reranker,
NL2Sql: b.infra.NL2SQL,
OCR: b.infra.OCR,
CacheCli: b.infra.CacheCli,
ModelFactory: chatmodel.NewDefaultFactory(),
}
}
@ -276,19 +287,19 @@ func (b *basicServices) toMemoryServiceComponents() *memory.ServiceComponents {
func (b *basicServices) toWorkflowServiceComponents(pluginSVC *plugin.PluginApplicationService, memorySVC *memory.MemoryApplicationServices, knowledgeSVC *knowledge.KnowledgeApplicationService) *workflow.ServiceComponents {
return &workflow.ServiceComponents{
IDGen: b.infra.IDGenSVC,
DB: b.infra.DB,
Cache: b.infra.CacheCli,
Tos: b.infra.TOSClient,
ImageX: b.infra.ImageXClient,
DatabaseDomainSVC: memorySVC.DatabaseDomainSVC,
VariablesDomainSVC: memorySVC.VariablesDomainSVC,
PluginDomainSVC: pluginSVC.DomainSVC,
KnowledgeDomainSVC: knowledgeSVC.DomainSVC,
ModelManager: b.infra.ModelMgr,
DomainNotifier: b.eventbus.resourceEventBus,
CPStore: checkpoint.NewRedisStore(b.infra.CacheCli),
CodeRunner: b.infra.CodeRunner,
IDGen: b.infra.IDGenSVC,
DB: b.infra.DB,
Cache: b.infra.CacheCli,
Tos: b.infra.TOSClient,
ImageX: b.infra.ImageXClient,
DatabaseDomainSVC: memorySVC.DatabaseDomainSVC,
VariablesDomainSVC: memorySVC.VariablesDomainSVC,
PluginDomainSVC: pluginSVC.DomainSVC,
KnowledgeDomainSVC: knowledgeSVC.DomainSVC,
DomainNotifier: b.eventbus.resourceEventBus,
CPStore: checkpoint.NewRedisStore(b.infra.CacheCli),
CodeRunner: b.infra.CodeRunner,
WorkflowBuildInChatModel: b.infra.WorkflowBuildInChatModel,
}
}

View File

@ -18,40 +18,86 @@ package appinfra
import (
"context"
"encoding/json"
"fmt"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"google.golang.org/genai"
"gorm.io/gorm"
"github.com/cloudwego/eino-ext/components/embedding/gemini"
"github.com/cloudwego/eino-ext/components/embedding/ollama"
"github.com/cloudwego/eino-ext/components/embedding/openai"
"github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/schema"
"github.com/milvus-io/milvus/client/v2/milvusclient"
"github.com/volcengine/volc-sdk-golang/service/visual"
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
"github.com/coze-dev/coze-studio/backend/infra/contract/coderunner"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/nl2sql"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/ocr"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/rerank"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
"github.com/coze-dev/coze-studio/backend/infra/contract/messages2query"
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
"github.com/coze-dev/coze-studio/backend/infra/impl/cache/redis"
"github.com/coze-dev/coze-studio/backend/infra/impl/coderunner/direct"
"github.com/coze-dev/coze-studio/backend/infra/impl/coderunner/sandbox"
builtinNL2SQL "github.com/coze-dev/coze-studio/backend/infra/impl/document/nl2sql/builtin"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/ocr/ppocr"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/ocr/veocr"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/builtin"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/ppstructure"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/rerank/rrf"
vikingReranker "github.com/coze-dev/coze-studio/backend/infra/impl/document/rerank/vikingdb"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/elasticsearch"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/milvus"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/vikingdb"
"github.com/coze-dev/coze-studio/backend/infra/impl/embedding/ark"
embeddingHttp "github.com/coze-dev/coze-studio/backend/infra/impl/embedding/http"
"github.com/coze-dev/coze-studio/backend/infra/impl/embedding/wrap"
"github.com/coze-dev/coze-studio/backend/infra/impl/es"
"github.com/coze-dev/coze-studio/backend/infra/impl/eventbus"
"github.com/coze-dev/coze-studio/backend/infra/impl/idgen"
"github.com/coze-dev/coze-studio/backend/infra/impl/imagex/veimagex"
builtinM2Q "github.com/coze-dev/coze-studio/backend/infra/impl/messages2query/builtin"
"github.com/coze-dev/coze-studio/backend/infra/impl/mysql"
"github.com/coze-dev/coze-studio/backend/infra/impl/storage"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/types/consts"
)
type AppDependencies struct {
DB *gorm.DB
CacheCli cache.Cmdable
IDGenSVC idgen.IDGenerator
ESClient es.Client
ImageXClient imagex.ImageX
TOSClient storage.Storage
ResourceEventProducer eventbus.Producer
AppEventProducer eventbus.Producer
ModelMgr modelmgr.Manager
CodeRunner coderunner.Runner
DB *gorm.DB
CacheCli cache.Cmdable
IDGenSVC idgen.IDGenerator
ESClient es.Client
ImageXClient imagex.ImageX
TOSClient storage.Storage
ResourceEventProducer eventbus.Producer
AppEventProducer eventbus.Producer
KnowledgeEventProducer eventbus.Producer
ModelMgr modelmgr.Manager
CodeRunner coderunner.Runner
OCR ocr.OCR
ParserManager parser.Manager
SearchStoreManagers []searchstore.Manager
Reranker rerank.Reranker
Rewriter messages2query.MessagesToQuery
NL2SQL nl2sql.NL2SQL
WorkflowBuildInChatModel chatmodel.BaseChatModel
}
func Init(ctx context.Context) (*AppDependencies, error) {
@ -60,55 +106,195 @@ func Init(ctx context.Context) (*AppDependencies, error) {
deps.DB, err = mysql.New()
if err != nil {
return nil, err
return nil, fmt.Errorf("init db failed, err=%w", err)
}
deps.CacheCli = redis.New()
deps.IDGenSVC, err = idgen.New(deps.CacheCli)
if err != nil {
return nil, err
return nil, fmt.Errorf("init id gen svc failed, err=%w", err)
}
deps.ESClient, err = es.New()
if err != nil {
return nil, err
return nil, fmt.Errorf("init es client failed, err=%w", err)
}
deps.ImageXClient, err = initImageX(ctx)
if err != nil {
return nil, err
return nil, fmt.Errorf("init imagex client failed, err=%w", err)
}
deps.TOSClient, err = initTOS(ctx)
if err != nil {
return nil, err
return nil, fmt.Errorf("init tos client failed, err=%w", err)
}
deps.ResourceEventProducer, err = initResourceEventBusProducer()
if err != nil {
return nil, err
return nil, fmt.Errorf("init resource event bus producer failed, err=%w", err)
}
deps.AppEventProducer, err = initAppEventProducer()
if err != nil {
return nil, err
return nil, fmt.Errorf("init app event producer failed, err=%w", err)
}
deps.KnowledgeEventProducer, err = initKnowledgeEventBusProducer()
if err != nil {
return nil, fmt.Errorf("init knowledge event bus producer failed, err=%w", err)
}
deps.Reranker = initReranker()
deps.Rewriter, err = initRewriter(ctx)
if err != nil {
return nil, fmt.Errorf("init rewriter failed, err=%w", err)
}
deps.NL2SQL, err = initNL2SQL(ctx)
if err != nil {
return nil, fmt.Errorf("init nl2sql failed, err=%w", err)
}
deps.ModelMgr, err = initModelMgr()
if err != nil {
return nil, err
return nil, fmt.Errorf("init model manager failed, err=%w", err)
}
deps.CodeRunner = initCodeRunner()
deps.OCR = initOCR()
imageAnnotationModel, _, err := getBuiltinChatModel(ctx, "IA_")
if err != nil {
return nil, fmt.Errorf("get builtin chat model failed, err=%w", err)
}
var ok bool
deps.WorkflowBuildInChatModel, ok, err = getBuiltinChatModel(ctx, "WKR_")
if err != nil {
return nil, fmt.Errorf("get workflow builtin chat model failed, err=%w", err)
}
if !ok {
logs.CtxWarnf(ctx, "workflow builtin chat model for knowledge recall not configured")
}
deps.ParserManager, err = initParserManager(deps.TOSClient, deps.OCR, imageAnnotationModel)
if err != nil {
return nil, fmt.Errorf("init parser manager failed, err=%w", err)
}
deps.SearchStoreManagers, err = initSearchStoreManagers(ctx, deps.ESClient)
if err != nil {
return nil, fmt.Errorf("init search store managers failed, err=%w", err)
}
return deps, nil
}
func initSearchStoreManagers(ctx context.Context, es es.Client) ([]searchstore.Manager, error) {
// es full text search
esSearchstoreManager := elasticsearch.NewManager(&elasticsearch.ManagerConfig{Client: es})
// vector search
mgr, err := getVectorStore(ctx)
if err != nil {
return nil, fmt.Errorf("init vector store failed, err=%w", err)
}
return []searchstore.Manager{esSearchstoreManager, mgr}, nil
}
func initReranker() rerank.Reranker {
rerankerType := os.Getenv("RERANK_TYPE")
switch rerankerType {
case "vikingdb":
return vikingReranker.NewReranker(getVikingRerankerConfig())
case "rrf":
return rrf.NewRRFReranker(0)
default:
return rrf.NewRRFReranker(0)
}
}
func getVikingRerankerConfig() *vikingReranker.Config {
return &vikingReranker.Config{
AK: os.Getenv("VIKINGDB_RERANK_AK"),
SK: os.Getenv("VIKINGDB_RERANK_SK"),
Domain: os.Getenv("VIKINGDB_RERANK_HOST"),
Region: os.Getenv("VIKINGDB_RERANK_REGION"),
Model: os.Getenv("VIKINGDB_RERANK_MODEL"),
}
}
func initRewriter(ctx context.Context) (messages2query.MessagesToQuery, error) {
rewriterChatModel, _, err := getBuiltinChatModel(ctx, "M2Q_")
if err != nil {
return nil, err
}
filePath := filepath.Join(getWorkingDirectory(), "resources/conf/prompt/messages_to_query_template_jinja2.json")
rewriterTemplate, err := readJinja2PromptTemplate(filePath)
if err != nil {
return nil, err
}
rewriter, err := builtinM2Q.NewMessagesToQuery(ctx, rewriterChatModel, rewriterTemplate)
if err != nil {
return nil, err
}
return rewriter, nil
}
func getWorkingDirectory() string {
root, err := os.Getwd()
if err != nil {
logs.Warnf("[InitConfig] Failed to get current working directory: %v", err)
root = os.Getenv("PWD")
}
return root
}
func readJinja2PromptTemplate(jsonFilePath string) (prompt.ChatTemplate, error) {
b, err := os.ReadFile(jsonFilePath)
if err != nil {
return nil, err
}
var m2qMessages []*schema.Message
if err = json.Unmarshal(b, &m2qMessages); err != nil {
return nil, err
}
tpl := make([]schema.MessagesTemplate, len(m2qMessages))
for i := range m2qMessages {
tpl[i] = m2qMessages[i]
}
return prompt.FromMessages(schema.Jinja2, tpl...), nil
}
func initNL2SQL(ctx context.Context) (nl2sql.NL2SQL, error) {
n2sChatModel, _, err := getBuiltinChatModel(ctx, "NL2SQL_")
if err != nil {
return nil, err
}
filePath := filepath.Join(getWorkingDirectory(), "resources/conf/prompt/nl2sql_template_jinja2.json")
n2sTemplate, err := readJinja2PromptTemplate(filePath)
if err != nil {
return nil, err
}
n2s, err := builtinNL2SQL.NewNL2SQL(ctx, n2sChatModel, n2sTemplate)
if err != nil {
return nil, err
}
return n2s, nil
}
func initImageX(ctx context.Context) (imagex.ImageX, error) {
uploadComponentType := os.Getenv(consts.FileUploadComponentType)
if uploadComponentType != consts.FileUploadComponentTypeImagex {
return storage.NewImagex(ctx)
}
@ -147,6 +333,17 @@ func initAppEventProducer() (eventbus.Producer, error) {
return appEventProducer, nil
}
func initKnowledgeEventBusProducer() (eventbus.Producer, error) {
nameServer := os.Getenv(consts.MQServer)
knowledgeProducer, err := eventbus.NewProducer(nameServer, consts.RMQTopicKnowledge, consts.RMQConsumeGroupKnowledge, 2)
if err != nil {
return nil, fmt.Errorf("init knowledge producer failed, err=%w", err)
}
return knowledgeProducer, nil
}
func initCodeRunner() coderunner.Runner {
switch typ := os.Getenv(consts.CodeRunnerType); typ {
case "sandbox":
@ -183,3 +380,337 @@ func initCodeRunner() coderunner.Runner {
return direct.NewRunner()
}
}
func initOCR() ocr.OCR {
var ocr ocr.OCR
switch os.Getenv(consts.OCRType) {
case "ve":
ocrAK := os.Getenv(consts.VeOCRAK)
ocrSK := os.Getenv(consts.VeOCRSK)
if ocrAK == "" || ocrSK == "" {
logs.Warnf("[ve_ocr] ak / sk not configured, ocr might not work well")
}
inst := visual.NewInstance()
inst.Client.SetAccessKey(ocrAK)
inst.Client.SetSecretKey(ocrSK)
ocr = veocr.NewOCR(&veocr.Config{Client: inst})
case "paddleocr":
url := os.Getenv(consts.PPOCRAPIURL)
client := &http.Client{}
ocr = ppocr.NewOCR(&ppocr.Config{Client: client, URL: url})
default:
// accept ocr not configured
}
return ocr
}
func initParserManager(storage storage.Storage, ocr ocr.OCR, imageAnnotationModel chatmodel.BaseChatModel) (parser.Manager, error) {
var parserManager parser.Manager
parserType := os.Getenv(consts.ParserType)
switch parserType {
case "builtin", "":
parserManager = builtin.NewManager(storage, ocr, imageAnnotationModel)
case "paddleocr":
url := os.Getenv(consts.PPStructureAPIURL)
client := &http.Client{}
apiConfig := &ppstructure.APIConfig{
Client: client,
URL: url,
}
parserManager = ppstructure.NewManager(apiConfig, ocr, storage, imageAnnotationModel)
default:
return nil, fmt.Errorf("parser type %s not supported", parserType)
}
return parserManager, nil
}
func getVectorStore(ctx context.Context) (searchstore.Manager, error) {
vsType := os.Getenv("VECTOR_STORE_TYPE")
switch vsType {
case "milvus":
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()
milvusAddr := os.Getenv("MILVUS_ADDR")
user := os.Getenv("MILVUS_USER")
password := os.Getenv("MILVUS_PASSWORD")
mc, err := milvusclient.New(ctx, &milvusclient.ClientConfig{
Address: milvusAddr,
Username: user,
Password: password,
})
if err != nil {
return nil, fmt.Errorf("init milvus client failed, err=%w", err)
}
emb, err := getEmbedding(ctx)
if err != nil {
return nil, fmt.Errorf("init milvus embedding failed, err=%w", err)
}
mgr, err := milvus.NewManager(&milvus.ManagerConfig{
Client: mc,
Embedding: emb,
EnableHybrid: ptr.Of(true),
})
if err != nil {
return nil, fmt.Errorf("init milvus vector store failed, err=%w", err)
}
return mgr, nil
case "vikingdb":
var (
host = os.Getenv("VIKING_DB_HOST")
region = os.Getenv("VIKING_DB_REGION")
ak = os.Getenv("VIKING_DB_AK")
sk = os.Getenv("VIKING_DB_SK")
scheme = os.Getenv("VIKING_DB_SCHEME")
modelName = os.Getenv("VIKING_DB_MODEL_NAME")
)
if ak == "" || sk == "" {
return nil, fmt.Errorf("invalid vikingdb ak / sk")
}
if host == "" {
host = "api-vikingdb.volces.com"
}
if region == "" {
region = "cn-beijing"
}
if scheme == "" {
scheme = "https"
}
var embConfig *vikingdb.VikingEmbeddingConfig
if modelName != "" {
embName := vikingdb.VikingEmbeddingModelName(modelName)
if embName.Dimensions() == 0 {
return nil, fmt.Errorf("embedding model not support, model_name=%s", modelName)
}
embConfig = &vikingdb.VikingEmbeddingConfig{
UseVikingEmbedding: true,
EnableHybrid: embName.SupportStatus() == embedding.SupportDenseAndSparse,
ModelName: embName,
ModelVersion: embName.ModelVersion(),
DenseWeight: ptr.Of(0.2),
BuiltinEmbedding: nil,
}
} else {
builtinEmbedding, err := getEmbedding(ctx)
if err != nil {
return nil, fmt.Errorf("builtint embedding init failed, err=%w", err)
}
embConfig = &vikingdb.VikingEmbeddingConfig{
UseVikingEmbedding: false,
EnableHybrid: false,
BuiltinEmbedding: builtinEmbedding,
}
}
svc := vikingdb.NewVikingDBService(host, region, ak, sk, scheme)
mgr, err := vikingdb.NewManager(&vikingdb.ManagerConfig{
Service: svc,
IndexingConfig: nil, // use default config
EmbeddingConfig: embConfig,
})
if err != nil {
return nil, fmt.Errorf("init vikingdb manager failed, err=%w", err)
}
return mgr, nil
default:
return nil, fmt.Errorf("unexpected vector store type, type=%s", vsType)
}
}
func getEmbedding(ctx context.Context) (embedding.Embedder, error) {
var batchSize int
if bs, err := strconv.ParseInt(os.Getenv("EMBEDDING_MAX_BATCH_SIZE"), 10, 64); err != nil {
logs.CtxWarnf(ctx, "EMBEDDING_MAX_BATCH_SIZE not set / invalid, using default batchSize=100")
batchSize = 100
} else {
batchSize = int(bs)
}
var emb embedding.Embedder
switch os.Getenv("EMBEDDING_TYPE") {
case "openai":
var (
openAIEmbeddingBaseURL = os.Getenv("OPENAI_EMBEDDING_BASE_URL")
openAIEmbeddingModel = os.Getenv("OPENAI_EMBEDDING_MODEL")
openAIEmbeddingApiKey = os.Getenv("OPENAI_EMBEDDING_API_KEY")
openAIEmbeddingByAzure = os.Getenv("OPENAI_EMBEDDING_BY_AZURE")
openAIEmbeddingApiVersion = os.Getenv("OPENAI_EMBEDDING_API_VERSION")
openAIEmbeddingDims = os.Getenv("OPENAI_EMBEDDING_DIMS")
openAIRequestEmbeddingDims = os.Getenv("OPENAI_EMBEDDING_REQUEST_DIMS")
)
byAzure, err := strconv.ParseBool(openAIEmbeddingByAzure)
if err != nil {
return nil, fmt.Errorf("init openai embedding by_azure failed, err=%w", err)
}
dims, err := strconv.ParseInt(openAIEmbeddingDims, 10, 64)
if err != nil {
return nil, fmt.Errorf("init openai embedding dims failed, err=%w", err)
}
openAICfg := &openai.EmbeddingConfig{
APIKey: openAIEmbeddingApiKey,
ByAzure: byAzure,
BaseURL: openAIEmbeddingBaseURL,
APIVersion: openAIEmbeddingApiVersion,
Model: openAIEmbeddingModel,
// Dimensions: ptr.Of(int(dims)),
}
reqDims := conv.StrToInt64D(openAIRequestEmbeddingDims, 0)
if reqDims > 0 {
// some openai model not support request dims
openAICfg.Dimensions = ptr.Of(int(reqDims))
}
emb, err = wrap.NewOpenAIEmbedder(ctx, openAICfg, dims, batchSize)
if err != nil {
return nil, fmt.Errorf("init openai embedding failed, err=%w", err)
}
case "ark":
var (
arkEmbeddingBaseURL = os.Getenv("ARK_EMBEDDING_BASE_URL")
arkEmbeddingModel = os.Getenv("ARK_EMBEDDING_MODEL")
arkEmbeddingApiKey = os.Getenv("ARK_EMBEDDING_API_KEY")
// deprecated: use ARK_EMBEDDING_API_KEY instead
// ARK_EMBEDDING_AK will be removed in the future
arkEmbeddingAK = os.Getenv("ARK_EMBEDDING_AK")
arkEmbeddingDims = os.Getenv("ARK_EMBEDDING_DIMS")
arkEmbeddingAPIType = os.Getenv("ARK_EMBEDDING_API_TYPE")
)
dims, err := strconv.ParseInt(arkEmbeddingDims, 10, 64)
if err != nil {
return nil, fmt.Errorf("init ark embedding dims failed, err=%w", err)
}
apiType := ark.APITypeText
if arkEmbeddingAPIType != "" {
if t := ark.APIType(arkEmbeddingAPIType); t != ark.APITypeText && t != ark.APITypeMultiModal {
return nil, fmt.Errorf("init ark embedding api_type failed, invalid api_type=%s", t)
} else {
apiType = t
}
}
emb, err = ark.NewArkEmbedder(ctx, &ark.EmbeddingConfig{
APIKey: func() string {
if arkEmbeddingApiKey != "" {
return arkEmbeddingApiKey
}
return arkEmbeddingAK
}(),
Model: arkEmbeddingModel,
BaseURL: arkEmbeddingBaseURL,
APIType: &apiType,
}, dims, batchSize)
if err != nil {
return nil, fmt.Errorf("init ark embedding client failed, err=%w", err)
}
case "ollama":
var (
ollamaEmbeddingBaseURL = os.Getenv("OLLAMA_EMBEDDING_BASE_URL")
ollamaEmbeddingModel = os.Getenv("OLLAMA_EMBEDDING_MODEL")
ollamaEmbeddingDims = os.Getenv("OLLAMA_EMBEDDING_DIMS")
)
dims, err := strconv.ParseInt(ollamaEmbeddingDims, 10, 64)
if err != nil {
return nil, fmt.Errorf("init ollama embedding dims failed, err=%w", err)
}
emb, err = wrap.NewOllamaEmbedder(ctx, &ollama.EmbeddingConfig{
BaseURL: ollamaEmbeddingBaseURL,
Model: ollamaEmbeddingModel,
}, dims, batchSize)
if err != nil {
return nil, fmt.Errorf("init ollama embedding failed, err=%w", err)
}
case "gemini":
var (
geminiEmbeddingBaseURL = os.Getenv("GEMINI_EMBEDDING_BASE_URL")
geminiEmbeddingModel = os.Getenv("GEMINI_EMBEDDING_MODEL")
geminiEmbeddingApiKey = os.Getenv("GEMINI_EMBEDDING_API_KEY")
geminiEmbeddingDims = os.Getenv("GEMINI_EMBEDDING_DIMS")
geminiEmbeddingBackend = os.Getenv("GEMINI_EMBEDDING_BACKEND") // "1" for BackendGeminiAPI / "2" for BackendVertexAI
geminiEmbeddingProject = os.Getenv("GEMINI_EMBEDDING_PROJECT")
geminiEmbeddingLocation = os.Getenv("GEMINI_EMBEDDING_LOCATION")
)
if len(geminiEmbeddingModel) == 0 {
return nil, fmt.Errorf("GEMINI_EMBEDDING_MODEL environment variable is required")
}
if len(geminiEmbeddingApiKey) == 0 {
return nil, fmt.Errorf("GEMINI_EMBEDDING_API_KEY environment variable is required")
}
if len(geminiEmbeddingDims) == 0 {
return nil, fmt.Errorf("GEMINI_EMBEDDING_DIMS environment variable is required")
}
if len(geminiEmbeddingBackend) == 0 {
return nil, fmt.Errorf("GEMINI_EMBEDDING_BACKEND environment variable is required")
}
dims, convErr := strconv.ParseInt(geminiEmbeddingDims, 10, 64)
if convErr != nil {
return nil, fmt.Errorf("invalid GEMINI_EMBEDDING_DIMS value: %s, err=%w", geminiEmbeddingDims, convErr)
}
backend, convErr := strconv.ParseInt(geminiEmbeddingBackend, 10, 64)
if convErr != nil {
return nil, fmt.Errorf("invalid GEMINI_EMBEDDING_BACKEND value: %s, err=%w", geminiEmbeddingBackend, convErr)
}
geminiCli, err := genai.NewClient(ctx, &genai.ClientConfig{
APIKey: geminiEmbeddingApiKey,
Backend: genai.Backend(backend),
Project: geminiEmbeddingProject,
Location: geminiEmbeddingLocation,
HTTPOptions: genai.HTTPOptions{
BaseURL: geminiEmbeddingBaseURL,
},
})
if err != nil {
return nil, fmt.Errorf("init gemini client failed, err=%w", err)
}
emb, err = wrap.NewGeminiEmbedder(ctx, &gemini.EmbeddingConfig{
Client: geminiCli,
Model: geminiEmbeddingModel,
OutputDimensionality: ptr.Of(int32(dims)),
}, dims, batchSize)
if err != nil {
return nil, fmt.Errorf("init gemini embedding failed, err=%w", err)
}
case "http":
var (
httpEmbeddingBaseURL = os.Getenv("HTTP_EMBEDDING_ADDR")
httpEmbeddingDims = os.Getenv("HTTP_EMBEDDING_DIMS")
)
dims, err := strconv.ParseInt(httpEmbeddingDims, 10, 64)
if err != nil {
return nil, fmt.Errorf("init http embedding dims failed, err=%w", err)
}
emb, err = embeddingHttp.NewEmbedding(httpEmbeddingBaseURL, dims, batchSize)
if err != nil {
return nil, fmt.Errorf("init http embedding failed, err=%w", err)
}
default:
return nil, fmt.Errorf("init knowledge embedding failed, type not configured")
}
return emb, nil
}

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
package internal
package appinfra
import (
"context"
@ -33,7 +33,7 @@ import (
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
)
func GetBuiltinChatModel(ctx context.Context, envPrefix string) (bcm chatmodel.BaseChatModel, configured bool, err error) {
func getBuiltinChatModel(ctx context.Context, envPrefix string) (bcm chatmodel.BaseChatModel, configured bool, err error) {
getEnv := func(key string) string {
if val := os.Getenv(envPrefix + key); val != "" {
return val
@ -99,7 +99,7 @@ func GetBuiltinChatModel(ctx context.Context, envPrefix string) (bcm chatmodel.B
}
if err != nil {
return nil, false, fmt.Errorf("knowledge init openai chat mode failed, %w", err)
return nil, false, fmt.Errorf("builtin %s chat model init failed, %w", envPrefix, err)
}
if bcm != nil {
configured = true

View File

@ -53,6 +53,10 @@ func (m *OpenapiMessageApplication) GetApiMessageList(ctx context.Context, mr *m
return nil, errorx.New(errno.ErrConversationPermissionCode, errorx.KV("msg", "permission denied"))
}
if mr.Limit == nil {
mr.Limit = ptr.Of(int64(50))
}
msgListMeta := &entity.ListMeta{
ConversationID: currentConversation.ID,
AgentID: currentConversation.AgentID,

View File

@ -18,418 +18,27 @@ package knowledge
import (
"context"
"encoding/json"
"fmt"
netHTTP "net/http"
"os"
"path/filepath"
"strconv"
"time"
"github.com/cloudwego/eino-ext/components/embedding/ark"
ollamaEmb "github.com/cloudwego/eino-ext/components/embedding/ollama"
"github.com/cloudwego/eino-ext/components/embedding/openai"
"github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/schema"
"github.com/milvus-io/milvus/client/v2/milvusclient"
"github.com/volcengine/volc-sdk-golang/service/vikingdb"
"github.com/volcengine/volc-sdk-golang/service/visual"
"gorm.io/gorm"
"github.com/coze-dev/coze-studio/backend/application/internal"
"github.com/coze-dev/coze-studio/backend/application/search"
knowledgeImpl "github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/nl2sql"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/ocr"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
"github.com/coze-dev/coze-studio/backend/infra/contract/es"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
"github.com/coze-dev/coze-studio/backend/infra/contract/messages2query"
"github.com/coze-dev/coze-studio/backend/infra/contract/rdb"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
chatmodelImpl "github.com/coze-dev/coze-studio/backend/infra/impl/chatmodel"
builtinNL2SQL "github.com/coze-dev/coze-studio/backend/infra/impl/document/nl2sql/builtin"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/ocr/veocr"
builtinParser "github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/builtin"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/rerank/rrf"
sses "github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/elasticsearch"
ssmilvus "github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/milvus"
ssvikingdb "github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/vikingdb"
arkemb "github.com/coze-dev/coze-studio/backend/infra/impl/embedding/ark"
"github.com/coze-dev/coze-studio/backend/infra/impl/embedding/http"
"github.com/coze-dev/coze-studio/backend/infra/impl/embedding/wrap"
"github.com/coze-dev/coze-studio/backend/infra/impl/eventbus"
builtinM2Q "github.com/coze-dev/coze-studio/backend/infra/impl/messages2query/builtin"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/types/consts"
)
type ServiceComponents struct {
DB *gorm.DB
IDGenSVC idgen.IDGenerator
Storage storage.Storage
RDB rdb.RDB
ImageX imagex.ImageX
ES es.Client
EventBus search.ResourceEventBus
CacheCli cache.Cmdable
}
type ServiceComponents = knowledgeImpl.KnowledgeSVCConfig
func InitService(c *ServiceComponents) (*KnowledgeApplicationService, error) {
ctx := context.Background()
func InitService(ctx context.Context, c *ServiceComponents, bus search.ResourceEventBus) (*KnowledgeApplicationService, error) {
knowledgeDomainSVC, knowledgeEventHandler := knowledgeImpl.NewKnowledgeSVC(c)
nameServer := os.Getenv(consts.MQServer)
knowledgeProducer, err := eventbus.NewProducer(nameServer, consts.RMQTopicKnowledge, consts.RMQConsumeGroupKnowledge, 2)
if err != nil {
return nil, fmt.Errorf("init knowledge producer failed, err=%w", err)
}
var sManagers []searchstore.Manager
// es full text search
sManagers = append(sManagers, sses.NewManager(&sses.ManagerConfig{Client: c.ES}))
// vector search
mgr, err := getVectorStore(ctx)
if err != nil {
return nil, fmt.Errorf("init vector store failed, err=%w", err)
}
sManagers = append(sManagers, mgr)
var ocrImpl ocr.OCR
switch os.Getenv("OCR_TYPE") {
case "ve":
ocrAK := os.Getenv("VE_OCR_AK")
ocrSK := os.Getenv("VE_OCR_SK")
if ocrAK == "" || ocrSK == "" {
logs.Warnf("[ve_ocr] ak / sk not configured, ocr might not work well")
}
inst := visual.NewInstance()
inst.Client.SetAccessKey(ocrAK)
inst.Client.SetSecretKey(ocrSK)
ocrImpl = veocr.NewOCR(&veocr.Config{Client: inst})
case "paddleocr":
ppocrURL := os.Getenv("PADDLEOCR_OCR_API_URL")
client := &netHTTP.Client{}
ocrImpl = veocr.NewPPOCR(&veocr.PPOCRConfig{Client: client, URL: ppocrURL})
default:
// accept ocr not configured
}
root, err := os.Getwd()
if err != nil {
logs.Warnf("[InitConfig] Failed to get current working directory: %v", err)
root = os.Getenv("PWD")
}
var rewriter messages2query.MessagesToQuery
if rewriterChatModel, _, err := internal.GetBuiltinChatModel(ctx, "M2Q_"); err != nil {
return nil, err
} else {
filePath := filepath.Join(root, "resources/conf/prompt/messages_to_query_template_jinja2.json")
rewriterTemplate, err := readJinja2PromptTemplate(filePath)
if err != nil {
return nil, err
}
rewriter, err = builtinM2Q.NewMessagesToQuery(ctx, rewriterChatModel, rewriterTemplate)
if err != nil {
return nil, err
}
}
var n2s nl2sql.NL2SQL
if n2sChatModel, _, err := internal.GetBuiltinChatModel(ctx, "NL2SQL_"); err != nil {
return nil, err
} else {
filePath := filepath.Join(root, "resources/conf/prompt/nl2sql_template_jinja2.json")
n2sTemplate, err := readJinja2PromptTemplate(filePath)
if err != nil {
return nil, err
}
n2s, err = builtinNL2SQL.NewNL2SQL(ctx, n2sChatModel, n2sTemplate)
if err != nil {
return nil, err
}
}
imageAnnoChatModel, configured, err := internal.GetBuiltinChatModel(ctx, "IA_")
if err != nil {
return nil, err
}
knowledgeDomainSVC, knowledgeEventHandler := knowledgeImpl.NewKnowledgeSVC(&knowledgeImpl.KnowledgeSVCConfig{
DB: c.DB,
IDGen: c.IDGenSVC,
RDB: c.RDB,
Producer: knowledgeProducer,
SearchStoreManagers: sManagers,
ParseManager: builtinParser.NewManager(c.Storage, ocrImpl, imageAnnoChatModel), // default builtin
Storage: c.Storage,
Rewriter: rewriter,
Reranker: rrf.NewRRFReranker(0), // default rrf
NL2Sql: n2s,
OCR: ocrImpl,
CacheCli: c.CacheCli,
IsAutoAnnotationSupported: configured,
ModelFactory: chatmodelImpl.NewDefaultFactory(),
})
if err = eventbus.DefaultSVC().RegisterConsumer(nameServer, consts.RMQTopicKnowledge, consts.RMQConsumeGroupKnowledge, knowledgeEventHandler); err != nil {
if err := eventbus.DefaultSVC().RegisterConsumer(nameServer, consts.RMQTopicKnowledge, consts.RMQConsumeGroupKnowledge, knowledgeEventHandler); err != nil {
return nil, fmt.Errorf("register knowledge consumer failed, err=%w", err)
}
KnowledgeSVC.DomainSVC = knowledgeDomainSVC
KnowledgeSVC.eventBus = c.EventBus
KnowledgeSVC.eventBus = bus
KnowledgeSVC.storage = c.Storage
return KnowledgeSVC, nil
}
func getVectorStore(ctx context.Context) (searchstore.Manager, error) {
vsType := os.Getenv("VECTOR_STORE_TYPE")
switch vsType {
case "milvus":
cctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()
milvusAddr := os.Getenv("MILVUS_ADDR")
mc, err := milvusclient.New(cctx, &milvusclient.ClientConfig{Address: milvusAddr})
if err != nil {
return nil, fmt.Errorf("init milvus client failed, err=%w", err)
}
emb, err := getEmbedding(ctx)
if err != nil {
return nil, fmt.Errorf("init milvus embedding failed, err=%w", err)
}
mgr, err := ssmilvus.NewManager(&ssmilvus.ManagerConfig{
Client: mc,
Embedding: emb,
EnableHybrid: ptr.Of(true),
})
if err != nil {
return nil, fmt.Errorf("init milvus vector store failed, err=%w", err)
}
return mgr, nil
case "vikingdb":
var (
host = os.Getenv("VIKING_DB_HOST")
region = os.Getenv("VIKING_DB_REGION")
ak = os.Getenv("VIKING_DB_AK")
sk = os.Getenv("VIKING_DB_SK")
scheme = os.Getenv("VIKING_DB_SCHEME")
modelName = os.Getenv("VIKING_DB_MODEL_NAME")
)
if ak == "" || sk == "" {
return nil, fmt.Errorf("invalid vikingdb ak / sk")
}
if host == "" {
host = "api-vikingdb.volces.com"
}
if region == "" {
region = "cn-beijing"
}
if scheme == "" {
scheme = "https"
}
var embConfig *ssvikingdb.VikingEmbeddingConfig
if modelName != "" {
embName := ssvikingdb.VikingEmbeddingModelName(modelName)
if embName.Dimensions() == 0 {
return nil, fmt.Errorf("embedding model not support, model_name=%s", modelName)
}
embConfig = &ssvikingdb.VikingEmbeddingConfig{
UseVikingEmbedding: true,
EnableHybrid: embName.SupportStatus() == embedding.SupportDenseAndSparse,
ModelName: embName,
ModelVersion: embName.ModelVersion(),
DenseWeight: ptr.Of(0.2),
BuiltinEmbedding: nil,
}
} else {
builtinEmbedding, err := getEmbedding(ctx)
if err != nil {
return nil, fmt.Errorf("builtint embedding init failed, err=%w", err)
}
embConfig = &ssvikingdb.VikingEmbeddingConfig{
UseVikingEmbedding: false,
EnableHybrid: false,
BuiltinEmbedding: builtinEmbedding,
}
}
svc := vikingdb.NewVikingDBService(host, region, ak, sk, scheme)
mgr, err := ssvikingdb.NewManager(&ssvikingdb.ManagerConfig{
Service: svc,
IndexingConfig: nil, // use default config
EmbeddingConfig: embConfig,
})
if err != nil {
return nil, fmt.Errorf("init vikingdb manager failed, err=%w", err)
}
return mgr, nil
default:
return nil, fmt.Errorf("unexpected vector store type, type=%s", vsType)
}
}
func getEmbedding(ctx context.Context) (embedding.Embedder, error) {
var batchSize int
if bs, err := strconv.ParseInt(os.Getenv("EMBEDDING_MAX_BATCH_SIZE"), 10, 64); err != nil {
logs.CtxWarnf(ctx, "EMBEDDING_MAX_BATCH_SIZE not set / invalid, using default batchSize=100")
batchSize = 100
} else {
batchSize = int(bs)
}
var emb embedding.Embedder
switch os.Getenv("EMBEDDING_TYPE") {
case "openai":
var (
openAIEmbeddingBaseURL = os.Getenv("OPENAI_EMBEDDING_BASE_URL")
openAIEmbeddingModel = os.Getenv("OPENAI_EMBEDDING_MODEL")
openAIEmbeddingApiKey = os.Getenv("OPENAI_EMBEDDING_API_KEY")
openAIEmbeddingByAzure = os.Getenv("OPENAI_EMBEDDING_BY_AZURE")
openAIEmbeddingApiVersion = os.Getenv("OPENAI_EMBEDDING_API_VERSION")
openAIEmbeddingDims = os.Getenv("OPENAI_EMBEDDING_DIMS")
openAIRequestEmbeddingDims = os.Getenv("OPENAI_EMBEDDING_REQUEST_DIMS")
)
byAzure, err := strconv.ParseBool(openAIEmbeddingByAzure)
if err != nil {
return nil, fmt.Errorf("init openai embedding by_azure failed, err=%w", err)
}
dims, err := strconv.ParseInt(openAIEmbeddingDims, 10, 64)
if err != nil {
return nil, fmt.Errorf("init openai embedding dims failed, err=%w", err)
}
openAICfg := &openai.EmbeddingConfig{
APIKey: openAIEmbeddingApiKey,
ByAzure: byAzure,
BaseURL: openAIEmbeddingBaseURL,
APIVersion: openAIEmbeddingApiVersion,
Model: openAIEmbeddingModel,
// Dimensions: ptr.Of(int(dims)),
}
reqDims := conv.StrToInt64D(openAIRequestEmbeddingDims, 0)
if reqDims > 0 {
// some openai model not support request dims
openAICfg.Dimensions = ptr.Of(int(reqDims))
}
emb, err = wrap.NewOpenAIEmbedder(ctx, openAICfg, dims, batchSize)
if err != nil {
return nil, fmt.Errorf("init openai embedding failed, err=%w", err)
}
case "ark":
var (
arkEmbeddingBaseURL = os.Getenv("ARK_EMBEDDING_BASE_URL")
arkEmbeddingModel = os.Getenv("ARK_EMBEDDING_MODEL")
arkEmbeddingApiKey = os.Getenv("ARK_EMBEDDING_API_KEY")
// deprecated: use ARK_EMBEDDING_API_KEY instead
// ARK_EMBEDDING_AK will be removed in the future
arkEmbeddingAK = os.Getenv("ARK_EMBEDDING_AK")
arkEmbeddingDims = os.Getenv("ARK_EMBEDDING_DIMS")
arkEmbeddingAPIType = os.Getenv("ARK_EMBEDDING_API_TYPE")
)
dims, err := strconv.ParseInt(arkEmbeddingDims, 10, 64)
if err != nil {
return nil, fmt.Errorf("init ark embedding dims failed, err=%w", err)
}
apiType := ark.APITypeText
if arkEmbeddingAPIType != "" {
if t := ark.APIType(arkEmbeddingAPIType); t != ark.APITypeText && t != ark.APITypeMultiModal {
return nil, fmt.Errorf("init ark embedding api_type failed, invalid api_type=%s", t)
} else {
apiType = t
}
}
emb, err = arkemb.NewArkEmbedder(ctx, &ark.EmbeddingConfig{
APIKey: func() string {
if arkEmbeddingApiKey != "" {
return arkEmbeddingApiKey
}
return arkEmbeddingAK
}(),
Model: arkEmbeddingModel,
BaseURL: arkEmbeddingBaseURL,
APIType: &apiType,
}, dims, batchSize)
if err != nil {
return nil, fmt.Errorf("init ark embedding client failed, err=%w", err)
}
case "ollama":
var (
ollamaEmbeddingBaseURL = os.Getenv("OLLAMA_EMBEDDING_BASE_URL")
ollamaEmbeddingModel = os.Getenv("OLLAMA_EMBEDDING_MODEL")
ollamaEmbeddingDims = os.Getenv("OLLAMA_EMBEDDING_DIMS")
)
dims, err := strconv.ParseInt(ollamaEmbeddingDims, 10, 64)
if err != nil {
return nil, fmt.Errorf("init ollama embedding dims failed, err=%w", err)
}
emb, err = wrap.NewOllamaEmbedder(ctx, &ollamaEmb.EmbeddingConfig{
BaseURL: ollamaEmbeddingBaseURL,
Model: ollamaEmbeddingModel,
}, dims, batchSize)
if err != nil {
return nil, fmt.Errorf("init ollama embedding failed, err=%w", err)
}
case "http":
var (
httpEmbeddingBaseURL = os.Getenv("HTTP_EMBEDDING_ADDR")
httpEmbeddingDims = os.Getenv("HTTP_EMBEDDING_DIMS")
)
dims, err := strconv.ParseInt(httpEmbeddingDims, 10, 64)
if err != nil {
return nil, fmt.Errorf("init http embedding dims failed, err=%w", err)
}
emb, err = http.NewEmbedding(httpEmbeddingBaseURL, dims, batchSize)
if err != nil {
return nil, fmt.Errorf("init http embedding failed, err=%w", err)
}
default:
return nil, fmt.Errorf("init knowledge embedding failed, type not configured")
}
return emb, nil
}
func readJinja2PromptTemplate(jsonFilePath string) (prompt.ChatTemplate, error) {
b, err := os.ReadFile(jsonFilePath)
if err != nil {
return nil, err
}
var m2qMessages []*schema.Message
if err = json.Unmarshal(b, &m2qMessages); err != nil {
return nil, err
}
tpl := make([]schema.MessagesTemplate, len(m2qMessages))
for i := range m2qMessages {
tpl[i] = m2qMessages[i]
}
return prompt.FromMessages(schema.Jinja2, tpl...), nil
}

View File

@ -28,7 +28,7 @@ import (
resCommon "github.com/coze-dev/coze-studio/backend/api/model/resource/common"
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
"github.com/coze-dev/coze-studio/backend/application/search"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossuser"
crossuser "github.com/coze-dev/coze-studio/backend/crossdomain/contract/user"
"github.com/coze-dev/coze-studio/backend/domain/memory/database/entity"
databaseEntity "github.com/coze-dev/coze-studio/backend/domain/memory/database/entity"
database "github.com/coze-dev/coze-studio/backend/domain/memory/database/service"

View File

@ -43,7 +43,7 @@ import (
resCommon "github.com/coze-dev/coze-studio/backend/api/model/resource/common"
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
"github.com/coze-dev/coze-studio/backend/application/base/pluginutil"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crosssearch"
crosssearch "github.com/coze-dev/coze-studio/backend/crossdomain/contract/search"
pluginConf "github.com/coze-dev/coze-studio/backend/domain/plugin/conf"
"github.com/coze-dev/coze-studio/backend/domain/plugin/encrypt"
"github.com/coze-dev/coze-studio/backend/domain/plugin/entity"

View File

@ -199,11 +199,7 @@ func (p *PromptApplicationService) updatePromptResource(ctx context.Context, req
return nil, errorx.New(errno.ErrPromptPermissionCode, errorx.KV("msg", "no permission"))
}
promptResource.Name = req.Prompt.GetName()
promptResource.Description = req.Prompt.GetDescription()
promptResource.PromptText = req.Prompt.GetPromptText()
err = p.DomainSVC.UpdatePromptResource(ctx, promptResource)
err = p.DomainSVC.UpdatePromptResource(ctx, promptID, req.Prompt.Name, req.Prompt.Description, req.Prompt.PromptText)
if err != nil {
return nil, err
}

View File

@ -354,8 +354,8 @@ func (s *SearchApplicationService) packProjectResource(ctx context.Context, reso
logs.CtxErrorf(ctx, "GetDataInfo failed, resID=%d, resType=%d, err=%v",
resource.ResID, resource.ResType, err)
} else {
info.BizResStatus = ptr.Of(*di.status)
if *di.status == int32(knowledgeModel.KnowledgeStatusDisable) {
info.BizResStatus = di.status
if di.status != nil && *di.status == int32(knowledgeModel.KnowledgeStatusDisable) {
actions := slices.Clone(info.Actions)
for _, a := range actions {
if a.Key == common.ProjectResourceActionKey_Disable {

View File

@ -23,7 +23,7 @@ import (
intelligence "github.com/coze-dev/coze-studio/backend/api/model/app/intelligence/common"
"github.com/coze-dev/coze-studio/backend/api/model/data/variable/project_memory"
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossplugin"
crossplugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin"
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
searchEntity "github.com/coze-dev/coze-studio/backend/domain/search/entity"
shortcutCMDEntity "github.com/coze-dev/coze-studio/backend/domain/shortcutcmd/entity"

View File

@ -25,6 +25,7 @@ import (
"github.com/coze-dev/coze-studio/backend/api/model/app/bot_common"
knowledgeModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
"github.com/coze-dev/coze-studio/backend/api/model/playground"
"github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common"
plugin_develop_common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common"
@ -240,7 +241,7 @@ func (s *SingleAgentApplicationService) fetchWorkflowDetails(ctx context.Context
return a.GetWorkflowId()
}),
},
QType: vo.FromLatestVersion,
QType: workflowModel.FromLatestVersion,
}
ret, _, err := s.appContext.WorkflowDomainSVC.MGet(ctx, policy)
if err != nil {

View File

@ -38,7 +38,7 @@ import (
"github.com/coze-dev/coze-studio/backend/api/model/data/database/table"
"github.com/coze-dev/coze-studio/backend/api/model/playground"
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossdatabase"
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
singleagent "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/service"
variableEntity "github.com/coze-dev/coze-studio/backend/domain/memory/variables/entity"
@ -524,8 +524,7 @@ func (s *SingleAgentApplicationService) GetAgentDraftDisplayInfo(ctx context.Con
func (s *SingleAgentApplicationService) ValidateAgentDraftAccess(ctx context.Context, agentID int64) (*entity.SingleAgent, error) {
uid := ctxutil.GetUIDFromCtx(ctx)
if uid == nil {
uid = ptr.Of(int64(888))
// return nil, errorx.New(errno.ErrAgentPermissionCode, errorx.KV("msg", "session uid not found"))
return nil, errorx.New(errno.ErrAgentPermissionCode, errorx.KV("msg", "session uid not found"))
}
do, err := s.DomainSVC.GetSingleAgentDraft(ctx, agentID)
@ -716,15 +715,18 @@ func (s *SingleAgentApplicationService) GetAgentOnlineInfo(ctx context.Context,
AgentID: ptr.Of(si.ObjectID),
Command: si.ShortcutCommand,
Components: slices.Transform(si.Components, func(i *playground.Components) *bot_common.ShortcutCommandComponent {
return &bot_common.ShortcutCommandComponent{
sc := &bot_common.ShortcutCommandComponent{
Name: i.Name,
Description: i.Description,
Type: i.InputType.String(),
ToolParameter: ptr.Of(i.Parameter),
Options: i.Options,
DefaultValue: ptr.Of(i.DefaultValue.Value),
IsHide: i.Hide,
}
if i.DefaultValue != nil {
sc.DefaultValue = ptr.Of(i.DefaultValue.Value)
}
return sc
}),
}
})

View File

@ -0,0 +1,61 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package workflow
import (
"context"
"github.com/coze-dev/coze-studio/backend/api/model/resource/common"
"github.com/coze-dev/coze-studio/backend/domain/search/entity"
search "github.com/coze-dev/coze-studio/backend/domain/search/entity"
"github.com/coze-dev/coze-studio/backend/domain/search/service"
)
var eventBus service.ResourceEventBus
func setEventBus(bus service.ResourceEventBus) {
eventBus = bus
}
func PublishWorkflowResource(ctx context.Context, workflowID int64, mode *int32, op search.OpType, r *search.ResourceDocument) error {
if r == nil {
r = &search.ResourceDocument{}
}
r.ResType = common.ResType_Workflow
r.ResID = workflowID
r.ResSubType = mode
event := &entity.ResourceDomainEvent{
OpType: entity.OpType(op),
Resource: r,
}
if op == search.Created {
event.Resource.CreateTimeMS = r.CreateTimeMS
event.Resource.UpdateTimeMS = r.UpdateTimeMS
} else if op == search.Updated {
event.Resource.UpdateTimeMS = r.UpdateTimeMS
}
err := eventBus.PublishResources(ctx, event)
if err != nil {
return err
}
return nil
}

View File

@ -18,90 +18,88 @@ package workflow
import (
"context"
"path/filepath"
"gopkg.in/yaml.v3"
"os"
"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/compose"
"gorm.io/gorm"
"github.com/cloudwego/eino/callbacks"
"github.com/coze-dev/coze-studio/backend/application/internal"
wfdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/database"
wfknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/knowledge"
wfmodel "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/model"
wfplugin "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/plugin"
wfsearch "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/search"
"github.com/coze-dev/coze-studio/backend/crossdomain/workflow/variable"
"github.com/coze-dev/coze-studio/backend/crossdomain/impl/code"
knowledge "github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
dbservice "github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
variables "github.com/coze-dev/coze-studio/backend/domain/memory/variables/service"
plugin "github.com/coze-dev/coze-studio/backend/domain/plugin/service"
search "github.com/coze-dev/coze-studio/backend/domain/search/service"
"github.com/coze-dev/coze-studio/backend/domain/workflow"
crosscode "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/code"
crossdatabase "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
crossknowledge "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
crossmodel "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
crossplugin "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
crosssearch "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/search"
crossvariable "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
"github.com/coze-dev/coze-studio/backend/domain/workflow/config"
"github.com/coze-dev/coze-studio/backend/domain/workflow/service"
workflowservice "github.com/coze-dev/coze-studio/backend/domain/workflow/service"
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
"github.com/coze-dev/coze-studio/backend/infra/contract/coderunner"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
)
type ServiceComponents struct {
IDGen idgen.IDGenerator
DB *gorm.DB
Cache cache.Cmdable
DatabaseDomainSVC dbservice.Database
VariablesDomainSVC variables.Variables
PluginDomainSVC plugin.PluginService
KnowledgeDomainSVC knowledge.Knowledge
ModelManager modelmgr.Manager
DomainNotifier search.ResourceEventBus
Tos storage.Storage
ImageX imagex.ImageX
CPStore compose.CheckPointStore
CodeRunner coderunner.Runner
IDGen idgen.IDGenerator
DB *gorm.DB
Cache cache.Cmdable
DatabaseDomainSVC dbservice.Database
VariablesDomainSVC variables.Variables
PluginDomainSVC plugin.PluginService
KnowledgeDomainSVC knowledge.Knowledge
DomainNotifier search.ResourceEventBus
Tos storage.Storage
ImageX imagex.ImageX
CPStore compose.CheckPointStore
CodeRunner coderunner.Runner
WorkflowBuildInChatModel chatmodel.BaseChatModel
}
func InitService(ctx context.Context, components *ServiceComponents) (*ApplicationService, error) {
bcm, ok, err := internal.GetBuiltinChatModel(ctx, "WKR_")
func initWorkflowConfig() (workflow.WorkflowConfig, error) {
wd, err := os.Getwd()
if err != nil {
return nil, err
}
if !ok {
logs.CtxWarnf(ctx, "workflow builtin chat model for knowledge recall not configured")
configBs, err := os.ReadFile(filepath.Join(wd, "resources/conf/workflow/config.yaml"))
if err != nil {
return nil, err
}
var cfg *config.WorkflowConfig
err = yaml.Unmarshal(configBs, &cfg)
if err != nil {
return nil, err
}
return cfg, nil
}
func InitService(_ context.Context, components *ServiceComponents) (*ApplicationService, error) {
service.RegisterAllNodeAdaptors()
cfg, err := initWorkflowConfig()
if err != nil {
return nil, err
}
workflowRepo := service.NewWorkflowRepository(components.IDGen, components.DB, components.Cache,
components.Tos, components.CPStore, bcm)
components.Tos, components.CPStore, components.WorkflowBuildInChatModel, cfg)
workflow.SetRepository(workflowRepo)
workflowDomainSVC := service.NewWorkflowService(workflowRepo)
crossdatabase.SetDatabaseOperator(wfdatabase.NewDatabaseRepository(components.DatabaseDomainSVC))
crossvariable.SetVariableHandler(variable.NewVariableHandler(components.VariablesDomainSVC))
crossvariable.SetVariablesMetaGetter(variable.NewVariablesMetaGetter(components.VariablesDomainSVC))
crossplugin.SetPluginService(wfplugin.NewPluginService(components.PluginDomainSVC, components.Tos))
crossknowledge.SetKnowledgeOperator(wfknowledge.NewKnowledgeRepository(components.KnowledgeDomainSVC, components.IDGen))
crossmodel.SetManager(wfmodel.NewModelManager(components.ModelManager, nil))
crosscode.SetCodeRunner(components.CodeRunner)
crosssearch.SetNotifier(wfsearch.NewNotify(components.DomainNotifier))
code.SetCodeRunner(components.CodeRunner)
callbacks.AppendGlobalHandlers(workflowservice.GetTokenCallbackHandler())
setEventBus(components.DomainNotifier)
SVC.DomainSVC = workflowDomainSVC
SVC.ImageX = components.ImageX
SVC.TosClient = components.Tos
SVC.IDGenerator = components.IDGen
return SVC, err
return SVC, nil
}

View File

@ -25,28 +25,31 @@ import (
"strings"
"time"
xmaps "golang.org/x/exp/maps"
"github.com/cloudwego/eino/schema"
xmaps "golang.org/x/exp/maps"
"github.com/coze-dev/coze-studio/backend/api/model/app/bot_common"
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
pluginmodel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
"github.com/coze-dev/coze-studio/backend/api/model/data/database/table"
"github.com/coze-dev/coze-studio/backend/api/model/playground"
pluginAPI "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop"
common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common"
resource "github.com/coze-dev/coze-studio/backend/api/model/resource/common"
"github.com/coze-dev/coze-studio/backend/api/model/workflow"
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
appknowledge "github.com/coze-dev/coze-studio/backend/application/knowledge"
appmemory "github.com/coze-dev/coze-studio/backend/application/memory"
appplugin "github.com/coze-dev/coze-studio/backend/application/plugin"
"github.com/coze-dev/coze-studio/backend/application/user"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossuser"
crossknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/contract/knowledge"
crossplugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin"
crossuser "github.com/coze-dev/coze-studio/backend/crossdomain/contract/user"
search "github.com/coze-dev/coze-studio/backend/domain/search/entity"
domainWorkflow "github.com/coze-dev/coze-studio/backend/domain/workflow"
workflowDomain "github.com/coze-dev/coze-studio/backend/domain/workflow"
crossknowledge "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
crossplugin "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
@ -54,6 +57,7 @@ import (
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/i18n"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/lang/maps"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
@ -182,6 +186,18 @@ func (w *ApplicationService) CreateWorkflow(ctx context.Context, req *workflow.C
return nil, err
}
err = PublishWorkflowResource(ctx, id, ptr.Of(int32(wf.Mode)), search.Created, &search.ResourceDocument{
Name: &wf.Name,
APPID: wf.AppID,
SpaceID: &wf.SpaceID,
OwnerID: &wf.CreatorID,
PublishStatus: ptr.Of(resource.PublishStatus_UnPublished),
CreateTimeMS: ptr.Of(time.Now().UnixMilli()),
})
if err != nil {
return nil, vo.WrapError(errno.ErrNotifyWorkflowResourceChangeErr, err)
}
return &workflow.CreateWorkflowResponse{
Data: &workflow.CreateWorkflowData{
WorkflowID: strconv.FormatInt(id, 10),
@ -232,7 +248,8 @@ func (w *ApplicationService) UpdateWorkflowMeta(ctx context.Context, req *workfl
return nil, err
}
err = GetWorkflowDomainSVC().UpdateMeta(ctx, mustParseInt64(req.GetWorkflowID()), &vo.MetaUpdate{
workflowID := mustParseInt64(req.GetWorkflowID())
err = GetWorkflowDomainSVC().UpdateMeta(ctx, workflowID, &vo.MetaUpdate{
Name: req.Name,
Desc: req.Desc,
IconURI: req.IconURI,
@ -240,33 +257,31 @@ func (w *ApplicationService) UpdateWorkflowMeta(ctx context.Context, req *workfl
if err != nil {
return nil, err
}
safego.Go(ctx, func() {
err := PublishWorkflowResource(ctx, workflowID, nil, search.Updated, &search.ResourceDocument{
Name: req.Name,
UpdateTimeMS: ptr.Of(time.Now().UnixMilli()),
})
if err != nil {
logs.CtxErrorf(ctx, "publish update workflow resource failed, workflowID: %d, err: %v", workflowID, err)
}
})
return &workflow.UpdateWorkflowMetaResponse{}, nil
}
func (w *ApplicationService) DeleteWorkflow(ctx context.Context, req *workflow.DeleteWorkflowRequest) (
_ *workflow.DeleteWorkflowResponse, err error,
) {
defer func() {
if panicErr := recover(); panicErr != nil {
err = safego.NewPanicErr(panicErr, debug.Stack())
}
_, err = w.BatchDeleteWorkflow(ctx, &workflow.BatchDeleteWorkflowRequest{
WorkflowIDList: []string{req.GetWorkflowID()},
SpaceID: req.SpaceID,
Action: req.Action,
})
if err != nil {
err = vo.WrapIfNeeded(errno.ErrWorkflowOperationFail, err, errorx.KV("cause", vo.UnwrapRootErr(err).Error()))
}
}()
if err := checkUserSpace(ctx, ctxutil.MustGetUIDFromCtx(ctx), mustParseInt64(req.GetSpaceID())); err != nil {
return nil, err
}
err = GetWorkflowDomainSVC().Delete(ctx, &vo.DeletePolicy{ID: ptr.Of(mustParseInt64(req.GetWorkflowID()))})
if err != nil {
return &workflow.DeleteWorkflowResponse{
Data: &workflow.DeleteWorkflowData{
Status: workflow.DeleteStatus_FAIL,
},
}, err
return nil, err
}
return &workflow.DeleteWorkflowResponse{
@ -276,19 +291,25 @@ func (w *ApplicationService) DeleteWorkflow(ctx context.Context, req *workflow.D
}, nil
}
func (w *ApplicationService) deleteWorkflowResource(ctx context.Context, policy *vo.DeletePolicy) error {
ids, err := GetWorkflowDomainSVC().Delete(ctx, policy)
if err != nil {
return err
}
safego.Go(ctx, func() {
for _, id := range ids {
if err = PublishWorkflowResource(ctx, id, nil, search.Deleted, &search.ResourceDocument{}); err != nil {
logs.CtxErrorf(ctx, "publish delete workflow event resource failed, workflowID: %d, err: %v", id, err)
}
}
})
return nil
}
func (w *ApplicationService) BatchDeleteWorkflow(ctx context.Context, req *workflow.BatchDeleteWorkflowRequest) (
_ *workflow.BatchDeleteWorkflowResponse, err error,
) {
defer func() {
if panicErr := recover(); panicErr != nil {
err = safego.NewPanicErr(panicErr, debug.Stack())
}
if err != nil {
err = vo.WrapIfNeeded(errno.ErrWorkflowOperationFail, err, errorx.KV("cause", vo.UnwrapRootErr(err).Error()))
}
}()
_ *workflow.BatchDeleteWorkflowResponse, err error) {
if err := checkUserSpace(ctx, ctxutil.MustGetUIDFromCtx(ctx), mustParseInt64(req.GetSpaceID())); err != nil {
return nil, err
}
@ -300,7 +321,7 @@ func (w *ApplicationService) BatchDeleteWorkflow(ctx context.Context, req *workf
return nil, err
}
err = GetWorkflowDomainSVC().Delete(ctx, &vo.DeletePolicy{
err = w.deleteWorkflowResource(ctx, &vo.DeletePolicy{
IDs: ids,
})
if err != nil {
@ -335,7 +356,7 @@ func (w *ApplicationService) GetCanvasInfo(ctx context.Context, req *workflow.Ge
wf, err := GetWorkflowDomainSVC().Get(ctx, &vo.GetPolicy{
ID: mustParseInt64(req.GetWorkflowID()),
QType: vo.FromDraft,
QType: workflowModel.FromDraft,
})
if err != nil {
return nil, err
@ -431,19 +452,19 @@ func (w *ApplicationService) TestRun(ctx context.Context, req *workflow.WorkFlow
agentID = ptr.Of(mustParseInt64(req.GetBotID()))
}
exeCfg := vo.ExecuteConfig{
exeCfg := workflowModel.ExecuteConfig{
ID: mustParseInt64(req.GetWorkflowID()),
From: vo.FromDraft,
From: workflowModel.FromDraft,
CommitID: req.GetCommitID(),
Operator: uID,
Mode: vo.ExecuteModeDebug,
Mode: workflowModel.ExecuteModeDebug,
AppID: appID,
AgentID: agentID,
ConnectorID: consts.CozeConnectorID,
ConnectorUID: strconv.FormatInt(uID, 10),
TaskType: vo.TaskTypeForeground,
SyncPattern: vo.SyncPatternAsync,
BizType: vo.BizTypeWorkflow,
TaskType: workflowModel.TaskTypeForeground,
SyncPattern: workflowModel.SyncPatternAsync,
BizType: workflowModel.BizTypeWorkflow,
Cancellable: true,
}
@ -503,18 +524,18 @@ func (w *ApplicationService) NodeDebug(ctx context.Context, req *workflow.Workfl
agentID = ptr.Of(mustParseInt64(req.GetBotID()))
}
exeCfg := vo.ExecuteConfig{
exeCfg := workflowModel.ExecuteConfig{
ID: mustParseInt64(req.GetWorkflowID()),
From: vo.FromDraft,
From: workflowModel.FromDraft,
Operator: uID,
Mode: vo.ExecuteModeNodeDebug,
Mode: workflowModel.ExecuteModeNodeDebug,
AppID: appID,
AgentID: agentID,
ConnectorID: consts.CozeConnectorID,
ConnectorUID: strconv.FormatInt(uID, 10),
TaskType: vo.TaskTypeForeground,
SyncPattern: vo.SyncPatternAsync,
BizType: vo.BizTypeWorkflow,
TaskType: workflowModel.TaskTypeForeground,
SyncPattern: workflowModel.SyncPatternAsync,
BizType: workflowModel.BizTypeWorkflow,
Cancellable: true,
}
@ -832,17 +853,7 @@ func (w *ApplicationService) GetNodeExecuteHistory(ctx context.Context, req *wor
}
func (w *ApplicationService) DeleteWorkflowsByAppID(ctx context.Context, appID int64) (err error) {
defer func() {
if panicErr := recover(); panicErr != nil {
err = safego.NewPanicErr(panicErr, debug.Stack())
}
if err != nil {
err = vo.WrapIfNeeded(errno.ErrWorkflowOperationFail, err, errorx.KV("cause", vo.UnwrapRootErr(err).Error()))
}
}()
return GetWorkflowDomainSVC().Delete(ctx, &vo.DeletePolicy{
return w.deleteWorkflowResource(ctx, &vo.DeletePolicy{
AppID: ptr.Of(appID),
})
}
@ -866,7 +877,7 @@ func (w *ApplicationService) CheckWorkflowsExistByAppID(ctx context.Context, app
Page: 0,
},
},
QType: vo.FromDraft,
QType: workflowModel.FromDraft,
MetaOnly: true,
})
@ -874,8 +885,7 @@ func (w *ApplicationService) CheckWorkflowsExistByAppID(ctx context.Context, app
}
func (w *ApplicationService) CopyWorkflowFromAppToLibrary(ctx context.Context, workflowID int64, spaceID, appID int64) (
_ int64, _ []*vo.ValidateIssue, err error,
) {
_ int64, _ []*vo.ValidateIssue, err error) {
defer func() {
if panicErr := recover(); panicErr != nil {
err = safego.NewPanicErr(panicErr, debug.Stack())
@ -891,7 +901,7 @@ func (w *ApplicationService) CopyWorkflowFromAppToLibrary(ctx context.Context, w
return 0, nil, err
}
pluginMap := make(map[int64]*vo.PluginEntity)
pluginMap := make(map[int64]*plugin.PluginEntity)
pluginToolMap := make(map[int64]int64)
if len(ds.PluginIDs) > 0 {
@ -906,7 +916,7 @@ func (w *ApplicationService) CopyWorkflowFromAppToLibrary(ctx context.Context, w
return 0, nil, err
}
pInfo := response.Plugin
pluginMap[id] = &vo.PluginEntity{
pluginMap[id] = &plugin.PluginEntity{
PluginID: pInfo.ID,
PluginVersion: pInfo.Version,
}
@ -958,7 +968,7 @@ func (w *ApplicationService) CopyWorkflowFromAppToLibrary(ctx context.Context, w
}
relatedWorkflows, vIssues, err := GetWorkflowDomainSVC().CopyWorkflowFromAppToLibrary(ctx, workflowID, appID, vo.ExternalResourceRelated{
relatedWorkflows, vIssues, err := w.copyWorkflowFromAppToLibrary(ctx, workflowID, appID, vo.ExternalResourceRelated{
PluginMap: pluginMap,
PluginToolMap: pluginToolMap,
KnowledgeMap: relatedKnowledgeMap,
@ -980,6 +990,32 @@ func (w *ApplicationService) CopyWorkflowFromAppToLibrary(ctx context.Context, w
return copiedWf.ID, vIssues, nil
}
func (w *ApplicationService) copyWorkflowFromAppToLibrary(ctx context.Context, workflowID int64, appID int64, related vo.ExternalResourceRelated) (map[int64]entity.IDVersionPair, []*vo.ValidateIssue, error) {
resp, err := GetWorkflowDomainSVC().CopyWorkflowFromAppToLibrary(ctx, workflowID, appID, related)
if err != nil {
return nil, nil, err
}
for index := range resp.CopiedWorkflows {
wf := resp.CopiedWorkflows[index]
err = PublishWorkflowResource(ctx, wf.ID, ptr.Of(int32(wf.Meta.Mode)), search.Created, &search.ResourceDocument{
Name: &wf.Name,
SpaceID: &wf.SpaceID,
OwnerID: &wf.CreatorID,
PublishStatus: ptr.Of(resource.PublishStatus_UnPublished),
CreateTimeMS: ptr.Of(time.Now().UnixMilli()),
})
if err != nil {
logs.CtxErrorf(ctx, "failed to publish workflow resource, workflow id=%d, err=%v", wf.ID, err)
return nil, nil, err
}
}
return resp.WorkflowIDVersionMap, resp.ValidateIssues, nil
}
type ExternalResource struct {
PluginMap map[int64]int64
PluginToolMap map[int64]int64
@ -998,9 +1034,9 @@ func (w *ApplicationService) DuplicateWorkflowsByAppID(ctx context.Context, sour
}
}()
pluginMap := make(map[int64]*vo.PluginEntity)
pluginMap := make(map[int64]*plugin.PluginEntity)
for o, n := range externalResource.PluginMap {
pluginMap[o] = &vo.PluginEntity{
pluginMap[o] = &plugin.PluginEntity{
PluginID: n,
}
}
@ -1011,7 +1047,29 @@ func (w *ApplicationService) DuplicateWorkflowsByAppID(ctx context.Context, sour
DatabaseMap: externalResource.DatabaseMap,
}
return GetWorkflowDomainSVC().DuplicateWorkflowsByAppID(ctx, sourceAppID, targetAppID, externalResourceRelated)
copiedWorkflowArray, err := GetWorkflowDomainSVC().DuplicateWorkflowsByAppID(ctx, sourceAppID, targetAppID, externalResourceRelated)
if err != nil {
return err
}
logs.CtxInfof(ctx, "[DuplicateWorkflowsByAppID] %s", conv.DebugJsonToStr(copiedWorkflowArray))
for index := range copiedWorkflowArray {
wf := copiedWorkflowArray[index]
err = PublishWorkflowResource(ctx, wf.ID, ptr.Of(int32(wf.Meta.Mode)), search.Created, &search.ResourceDocument{
Name: &wf.Name,
SpaceID: &wf.SpaceID,
OwnerID: &wf.CreatorID,
APPID: &targetAppID,
PublishStatus: ptr.Of(resource.PublishStatus_UnPublished),
CreateTimeMS: ptr.Of(time.Now().UnixMilli()),
})
if err != nil {
logs.CtxErrorf(ctx, "failed to publish workflow resource, workflow id=%d, err=%v", wf.ID, err)
}
}
return nil
}
func (w *ApplicationService) CopyWorkflowFromLibraryToApp(ctx context.Context, workflowID int64, appID int64) (
@ -1027,7 +1085,7 @@ func (w *ApplicationService) CopyWorkflowFromLibraryToApp(ctx context.Context, w
}
}()
wf, err := GetWorkflowDomainSVC().CopyWorkflow(ctx, workflowID, vo.CopyWorkflowPolicy{
wf, err := w.copyWorkflow(ctx, workflowID, vo.CopyWorkflowPolicy{
TargetAppID: &appID,
})
if err != nil {
@ -1037,6 +1095,28 @@ func (w *ApplicationService) CopyWorkflowFromLibraryToApp(ctx context.Context, w
return wf.ID, nil
}
func (w *ApplicationService) copyWorkflow(ctx context.Context, workflowID int64, policy vo.CopyWorkflowPolicy) (*entity.Workflow, error) {
wf, err := GetWorkflowDomainSVC().CopyWorkflow(ctx, workflowID, policy)
if err != nil {
return nil, err
}
err = PublishWorkflowResource(ctx, wf.ID, ptr.Of(int32(wf.Meta.Mode)), search.Created, &search.ResourceDocument{
Name: &wf.Name,
APPID: wf.AppID,
SpaceID: &wf.SpaceID,
OwnerID: &wf.CreatorID,
PublishStatus: ptr.Of(resource.PublishStatus_UnPublished),
CreateTimeMS: ptr.Of(time.Now().UnixMilli()),
})
if err != nil {
logs.CtxErrorf(ctx, "public copy workflow event failed, workflowID=%d, err=%v", wf.ID, err)
return nil, err
}
return wf, nil
}
func (w *ApplicationService) MoveWorkflowFromAppToLibrary(ctx context.Context, workflowID int64, spaceID, /*not used for now*/
appID int64) (_ int64, _ []*vo.ValidateIssue, err error) {
defer func() {
@ -1054,7 +1134,7 @@ func (w *ApplicationService) MoveWorkflowFromAppToLibrary(ctx context.Context, w
return 0, nil, err
}
pluginMap := make(map[int64]*vo.PluginEntity)
pluginMap := make(map[int64]*plugin.PluginEntity)
if len(ds.PluginIDs) > 0 {
for idx := range ds.PluginIDs {
id := ds.PluginIDs[idx]
@ -1062,7 +1142,7 @@ func (w *ApplicationService) MoveWorkflowFromAppToLibrary(ctx context.Context, w
if err != nil {
return 0, nil, err
}
pluginMap[id] = &vo.PluginEntity{
pluginMap[id] = &plugin.PluginEntity{
PluginID: pInfo.ID,
PluginVersion: pInfo.Version,
}
@ -1091,7 +1171,7 @@ func (w *ApplicationService) MoveWorkflowFromAppToLibrary(ctx context.Context, w
}
}
relatedWorkflows, vIssues, err := GetWorkflowDomainSVC().CopyWorkflowFromAppToLibrary(ctx, workflowID, appID, vo.ExternalResourceRelated{
relatedWorkflows, vIssues, err := w.copyWorkflowFromAppToLibrary(ctx, workflowID, appID, vo.ExternalResourceRelated{
PluginMap: pluginMap,
})
if err != nil {
@ -1109,7 +1189,7 @@ func (w *ApplicationService) MoveWorkflowFromAppToLibrary(ctx context.Context, w
}
deleteWorkflowIDs := xmaps.Keys(relatedWorkflows)
err = GetWorkflowDomainSVC().Delete(ctx, &vo.DeletePolicy{
err = w.deleteWorkflowResource(ctx, &vo.DeletePolicy{
IDs: deleteWorkflowIDs,
})
if err != nil {
@ -1250,7 +1330,7 @@ func convertStreamRunEvent(workflowID int64) func(msg *entity.Message) (res *wor
return &workflow.OpenAPIStreamRunFlowResponse{
ID: strconv.Itoa(messageID),
Event: string(DoneEvent),
DebugUrl: ptr.Of(fmt.Sprintf(vo.DebugURLTpl, executeID, spaceID, workflowID)),
DebugUrl: ptr.Of(fmt.Sprintf(workflowModel.DebugURLTpl, executeID, spaceID, workflowID)),
}, nil
case entity.WorkflowFailed, entity.WorkflowCancel:
var wfe vo.WorkflowError
@ -1260,7 +1340,7 @@ func convertStreamRunEvent(workflowID int64) func(msg *entity.Message) (res *wor
return &workflow.OpenAPIStreamRunFlowResponse{
ID: strconv.Itoa(messageID),
Event: string(ErrEvent),
DebugUrl: ptr.Of(fmt.Sprintf(vo.DebugURLTpl, executeID, spaceID, workflowID)),
DebugUrl: ptr.Of(fmt.Sprintf(workflowModel.DebugURLTpl, executeID, spaceID, workflowID)),
ErrorCode: ptr.Of(int64(wfe.Code())),
ErrorMessage: ptr.Of(wfe.Msg()),
}, nil
@ -1269,7 +1349,7 @@ func convertStreamRunEvent(workflowID int64) func(msg *entity.Message) (res *wor
return &workflow.OpenAPIStreamRunFlowResponse{
ID: strconv.Itoa(messageID),
Event: string(InterruptEvent),
DebugUrl: ptr.Of(fmt.Sprintf(vo.DebugURLTpl, executeID, spaceID, workflowID)),
DebugUrl: ptr.Of(fmt.Sprintf(workflowModel.DebugURLTpl, executeID, spaceID, workflowID)),
InterruptData: &workflow.Interrupt{
EventID: fmt.Sprintf("%d/%d", executeID, msg.InterruptEvent.ID),
Type: workflow.InterruptType(msg.InterruptEvent.EventType),
@ -1281,7 +1361,7 @@ func convertStreamRunEvent(workflowID int64) func(msg *entity.Message) (res *wor
return &workflow.OpenAPIStreamRunFlowResponse{
ID: strconv.Itoa(messageID),
Event: string(InterruptEvent),
DebugUrl: ptr.Of(fmt.Sprintf(vo.DebugURLTpl, executeID, spaceID, workflowID)),
DebugUrl: ptr.Of(fmt.Sprintf(workflowModel.DebugURLTpl, executeID, spaceID, workflowID)),
InterruptData: &workflow.Interrupt{
EventID: fmt.Sprintf("%d/%d", executeID, msg.InterruptEvent.ID),
Type: workflow.InterruptType(msg.InterruptEvent.ToolInterruptEvent.EventType),
@ -1397,20 +1477,20 @@ func (w *ApplicationService) OpenAPIStreamRun(ctx context.Context, req *workflow
connectorID = apiKeyInfo.ConnectorID
}
exeCfg := vo.ExecuteConfig{
exeCfg := workflowModel.ExecuteConfig{
ID: meta.ID,
From: vo.FromSpecificVersion,
From: workflowModel.FromSpecificVersion,
Version: *meta.LatestPublishedVersion,
Operator: userID,
Mode: vo.ExecuteModeRelease,
Mode: workflowModel.ExecuteModeRelease,
AppID: appID,
AgentID: agentID,
ConnectorID: connectorID,
ConnectorUID: strconv.FormatInt(userID, 10),
TaskType: vo.TaskTypeForeground,
SyncPattern: vo.SyncPatternStream,
TaskType: workflowModel.TaskTypeForeground,
SyncPattern: workflowModel.SyncPatternStream,
InputFailFast: true,
BizType: vo.BizTypeWorkflow,
BizType: workflowModel.BizTypeWorkflow,
}
if exeCfg.AppID != nil && exeCfg.AgentID != nil {
@ -1471,12 +1551,12 @@ func (w *ApplicationService) OpenAPIStreamResume(ctx context.Context, req *workf
connectorID = mustParseInt64(req.GetConnectorID())
}
sr, err := GetWorkflowDomainSVC().StreamResume(ctx, resumeReq, vo.ExecuteConfig{
sr, err := GetWorkflowDomainSVC().StreamResume(ctx, resumeReq, workflowModel.ExecuteConfig{
Operator: userID,
Mode: vo.ExecuteModeRelease,
Mode: workflowModel.ExecuteModeRelease,
ConnectorID: connectorID,
ConnectorUID: strconv.FormatInt(userID, 10),
BizType: vo.BizTypeWorkflow,
BizType: workflowModel.BizTypeWorkflow,
})
if err != nil {
return nil, err
@ -1546,18 +1626,18 @@ func (w *ApplicationService) OpenAPIRun(ctx context.Context, req *workflow.OpenA
connectorID = apiKeyInfo.ConnectorID
}
exeCfg := vo.ExecuteConfig{
exeCfg := workflowModel.ExecuteConfig{
ID: meta.ID,
From: vo.FromSpecificVersion,
From: workflowModel.FromSpecificVersion,
Version: *meta.LatestPublishedVersion,
Operator: userID,
Mode: vo.ExecuteModeRelease,
Mode: workflowModel.ExecuteModeRelease,
AppID: appID,
AgentID: agentID,
ConnectorID: connectorID,
ConnectorUID: strconv.FormatInt(userID, 10),
InputFailFast: true,
BizType: vo.BizTypeWorkflow,
BizType: workflowModel.BizTypeWorkflow,
}
if exeCfg.AppID != nil && exeCfg.AgentID != nil {
@ -1565,8 +1645,8 @@ func (w *ApplicationService) OpenAPIRun(ctx context.Context, req *workflow.OpenA
}
if req.GetIsAsync() {
exeCfg.SyncPattern = vo.SyncPatternAsync
exeCfg.TaskType = vo.TaskTypeBackground
exeCfg.SyncPattern = workflowModel.SyncPatternAsync
exeCfg.TaskType = workflowModel.TaskTypeBackground
exeID, err := GetWorkflowDomainSVC().AsyncExecute(ctx, exeCfg, parameters)
if err != nil {
return nil, err
@ -1574,12 +1654,12 @@ func (w *ApplicationService) OpenAPIRun(ctx context.Context, req *workflow.OpenA
return &workflow.OpenAPIRunFlowResponse{
ExecuteID: ptr.Of(strconv.FormatInt(exeID, 10)),
DebugUrl: ptr.Of(fmt.Sprintf(vo.DebugURLTpl, exeID, meta.SpaceID, meta.ID)),
DebugUrl: ptr.Of(fmt.Sprintf(workflowModel.DebugURLTpl, exeID, meta.SpaceID, meta.ID)),
}, nil
}
exeCfg.SyncPattern = vo.SyncPatternSync
exeCfg.TaskType = vo.TaskTypeForeground
exeCfg.SyncPattern = workflowModel.SyncPatternSync
exeCfg.TaskType = workflowModel.TaskTypeForeground
wfExe, tPlan, err := GetWorkflowDomainSVC().SyncExecute(ctx, exeCfg, parameters)
if err != nil {
return nil, err
@ -1610,7 +1690,7 @@ func (w *ApplicationService) OpenAPIRun(ctx context.Context, req *workflow.OpenA
return &workflow.OpenAPIRunFlowResponse{
Data: data,
ExecuteID: ptr.Of(strconv.FormatInt(wfExe.ID, 10)),
DebugUrl: ptr.Of(fmt.Sprintf(vo.DebugURLTpl, wfExe.ID, wfExe.SpaceID, meta.ID)),
DebugUrl: ptr.Of(fmt.Sprintf(workflowModel.DebugURLTpl, wfExe.ID, wfExe.SpaceID, meta.ID)),
Token: ptr.Of(wfExe.TokenInfo.InputTokens + wfExe.TokenInfo.OutputTokens),
Cost: ptr.Of("0.00000"),
}, nil
@ -1650,11 +1730,11 @@ func (w *ApplicationService) OpenAPIGetWorkflowRunHistory(ctx context.Context, r
var runMode *workflow.WorkflowRunMode
switch exe.SyncPattern {
case vo.SyncPatternSync:
case workflowModel.SyncPatternSync:
runMode = ptr.Of(workflow.WorkflowRunMode_Sync)
case vo.SyncPatternAsync:
case workflowModel.SyncPatternAsync:
runMode = ptr.Of(workflow.WorkflowRunMode_Async)
case vo.SyncPatternStream:
case workflowModel.SyncPatternStream:
runMode = ptr.Of(workflow.WorkflowRunMode_Stream)
default:
}
@ -1671,7 +1751,7 @@ func (w *ApplicationService) OpenAPIGetWorkflowRunHistory(ctx context.Context, r
LogID: ptr.Of(exe.LogID),
CreateTime: ptr.Of(exe.CreatedAt.Unix()),
UpdateTime: updateTime,
DebugUrl: ptr.Of(fmt.Sprintf(vo.DebugURLTpl, exe.ID, exe.SpaceID, exe.WorkflowID)),
DebugUrl: ptr.Of(fmt.Sprintf(workflowModel.DebugURLTpl, exe.ID, exe.SpaceID, exe.WorkflowID)),
Input: exe.Input,
Output: exe.Output,
Token: ptr.Of(exe.TokenInfo.InputTokens + exe.TokenInfo.OutputTokens),
@ -1805,10 +1885,10 @@ func (w *ApplicationService) TestResume(ctx context.Context, req *workflow.Workf
EventID: mustParseInt64(req.GetEventID()),
ResumeData: req.GetData(),
}
err = GetWorkflowDomainSVC().AsyncResume(ctx, resumeReq, vo.ExecuteConfig{
err = GetWorkflowDomainSVC().AsyncResume(ctx, resumeReq, workflowModel.ExecuteConfig{
Operator: ptr.FromOrDefault(ctxutil.GetUIDFromCtx(ctx), 0),
Mode: vo.ExecuteModeDebug, // at this stage it could be debug or node debug, we will decide it within AsyncResume
BizType: vo.BizTypeWorkflow,
Mode: workflowModel.ExecuteModeDebug, // at this stage it could be debug or node debug, we will decide it within AsyncResume
BizType: workflowModel.BizTypeWorkflow,
Cancellable: true,
})
if err != nil {
@ -1948,7 +2028,7 @@ func (w *ApplicationService) PublishWorkflow(ctx context.Context, req *workflow.
Force: req.GetForce(),
}
err = GetWorkflowDomainSVC().Publish(ctx, info)
err = w.publishWorkflowResource(ctx, info)
if err != nil {
return nil, err
}
@ -2003,13 +2083,13 @@ func (w *ApplicationService) ListWorkflow(ctx context.Context, req *workflow.Get
}
status := req.GetStatus()
var qType vo.Locator
var qType workflowModel.Locator
if status == workflow.WorkFlowListStatus_UnPublished {
option.PublishStatus = ptr.Of(vo.UnPublished)
qType = vo.FromDraft
qType = workflowModel.FromDraft
} else if status == workflow.WorkFlowListStatus_HadPublished {
option.PublishStatus = ptr.Of(vo.HasPublished)
qType = vo.FromLatestVersion
qType = workflowModel.FromLatestVersion
}
if len(req.GetName()) > 0 {
@ -2077,9 +2157,9 @@ func (w *ApplicationService) ListWorkflow(ctx context.Context, req *workflow.Get
},
}
if qType == vo.FromDraft {
if qType == workflowModel.FromDraft {
ww.UpdateTime = w.DraftMeta.Timestamp.Unix()
} else if qType == vo.FromLatestVersion || qType == vo.FromSpecificVersion {
} else if qType == workflowModel.FromLatestVersion || qType == workflowModel.FromSpecificVersion {
ww.UpdateTime = w.VersionMeta.VersionCreatedAt.Unix()
} else if w.UpdatedAt != nil {
ww.UpdateTime = w.UpdatedAt.Unix()
@ -2164,7 +2244,7 @@ func (w *ApplicationService) GetWorkflowDetail(ctx context.Context, req *workflo
MetaQuery: vo.MetaQuery{
IDs: ids,
},
QType: vo.FromDraft,
QType: workflowModel.FromDraft,
MetaOnly: false,
})
if err != nil {
@ -2270,7 +2350,7 @@ func (w *ApplicationService) GetWorkflowDetailInfo(ctx context.Context, req *wor
MetaQuery: vo.MetaQuery{
IDs: draftIDs,
},
QType: vo.FromDraft,
QType: workflowModel.FromDraft,
MetaOnly: false,
})
if err != nil {
@ -2283,7 +2363,7 @@ func (w *ApplicationService) GetWorkflowDetailInfo(ctx context.Context, req *wor
MetaQuery: vo.MetaQuery{
IDs: versionIDs,
},
QType: vo.FromSpecificVersion,
QType: workflowModel.FromSpecificVersion,
MetaOnly: false,
Versions: id2Version,
})
@ -2484,8 +2564,8 @@ func (w *ApplicationService) GetApiDetail(ctx context.Context, req *workflow.Get
return nil, err
}
toolInfoResponse, err := crossplugin.GetPluginService().GetPluginToolsInfo(ctx, &crossplugin.ToolsInfoRequest{
PluginEntity: crossplugin.Entity{
toolInfoResponse, err := crossplugin.DefaultSVC().GetPluginToolsInfo(ctx, &plugin.ToolsInfoRequest{
PluginEntity: plugin.PluginEntity{
PluginID: pluginID,
PluginVersion: req.PluginVersion,
},
@ -2551,8 +2631,8 @@ func (w *ApplicationService) GetLLMNodeFCSettingDetail(ctx context.Context, req
}
var (
pluginSvc = crossplugin.GetPluginService()
pluginToolsInfoReqs = make(map[int64]*crossplugin.ToolsInfoRequest)
pluginSvc = crossplugin.DefaultSVC()
pluginToolsInfoReqs = make(map[int64]*plugin.ToolsInfoRequest)
pluginDetailMap = make(map[string]*workflow.PluginDetail)
toolsDetailInfo = make(map[string]*workflow.APIDetail)
workflowDetailMap = make(map[string]*workflow.WorkflowDetail)
@ -2574,8 +2654,8 @@ func (w *ApplicationService) GetLLMNodeFCSettingDetail(ctx context.Context, req
if r, ok := pluginToolsInfoReqs[pluginID]; ok {
r.ToolIDs = append(r.ToolIDs, toolID)
} else {
pluginToolsInfoReqs[pluginID] = &crossplugin.ToolsInfoRequest{
PluginEntity: crossplugin.Entity{
pluginToolsInfoReqs[pluginID] = &plugin.ToolsInfoRequest{
PluginEntity: plugin.PluginEntity{
PluginID: pluginID,
PluginVersion: pl.PluginVersion,
},
@ -2656,7 +2736,7 @@ func (w *ApplicationService) GetLLMNodeFCSettingDetail(ctx context.Context, req
MetaQuery: vo.MetaQuery{
IDs: draftIDs,
},
QType: vo.FromDraft,
QType: workflowModel.FromDraft,
MetaOnly: false,
})
if err != nil {
@ -2669,7 +2749,7 @@ func (w *ApplicationService) GetLLMNodeFCSettingDetail(ctx context.Context, req
MetaQuery: vo.MetaQuery{
IDs: versionIDs,
},
QType: vo.FromSpecificVersion,
QType: workflowModel.FromSpecificVersion,
MetaOnly: false,
Versions: id2Version,
})
@ -2705,18 +2785,18 @@ func (w *ApplicationService) GetLLMNodeFCSettingDetail(ctx context.Context, req
}
if len(req.GetDatasetList()) > 0 {
knowledgeOperator := crossknowledge.GetKnowledgeOperator()
knowledgeOperator := crossknowledge.DefaultSVC()
knowledgeIDs, err := slices.TransformWithErrorCheck(req.GetDatasetList(), func(a *workflow.DatasetFCItem) (int64, error) {
return strconv.ParseInt(a.GetDatasetID(), 10, 64)
})
if err != nil {
return nil, err
}
details, err := knowledgeOperator.ListKnowledgeDetail(ctx, &crossknowledge.ListKnowledgeDetailRequest{KnowledgeIDs: knowledgeIDs})
details, err := knowledgeOperator.ListKnowledgeDetail(ctx, &model.ListKnowledgeDetailRequest{KnowledgeIDs: knowledgeIDs})
if err != nil {
return nil, err
}
knowledgeDetailMap = slices.ToMap(details.KnowledgeDetails, func(kd *crossknowledge.KnowledgeDetail) (string, *workflow.DatasetDetail) {
knowledgeDetailMap = slices.ToMap(details.KnowledgeDetails, func(kd *model.KnowledgeDetail) (string, *workflow.DatasetDetail) {
return strconv.FormatInt(kd.ID, 10), &workflow.DatasetDetail{
ID: strconv.FormatInt(kd.ID, 10),
Name: kd.Name,
@ -2759,7 +2839,7 @@ func (w *ApplicationService) GetLLMNodeFCSettingsMerged(ctx context.Context, req
var fcPluginSetting *workflow.FCPluginSetting
if req.GetPluginFcSetting() != nil {
var (
pluginSvc = crossplugin.GetPluginService()
pluginSvc = crossplugin.DefaultSVC()
pluginFcSetting = req.GetPluginFcSetting()
isDraft = pluginFcSetting.GetIsDraft()
)
@ -2774,8 +2854,8 @@ func (w *ApplicationService) GetLLMNodeFCSettingsMerged(ctx context.Context, req
return nil, err
}
pluginReq := &crossplugin.ToolsInfoRequest{
PluginEntity: vo.PluginEntity{
pluginReq := &plugin.ToolsInfoRequest{
PluginEntity: plugin.PluginEntity{
PluginID: pluginID,
},
ToolIDs: []int64{toolID},
@ -2816,7 +2896,7 @@ func (w *ApplicationService) GetLLMNodeFCSettingsMerged(ctx context.Context, req
policy := &vo.GetPolicy{
ID: wID,
QType: ternary.IFElse(len(setting.WorkflowVersion) == 0, vo.FromDraft, vo.FromSpecificVersion),
QType: ternary.IFElse(len(setting.WorkflowVersion) == 0, workflowModel.FromDraft, workflowModel.FromSpecificVersion),
Version: setting.WorkflowVersion,
}
@ -2891,7 +2971,7 @@ func (w *ApplicationService) GetPlaygroundPluginList(ctx context.Context, req *p
SpaceID: ptr.Of(req.GetSpaceID()),
PublishStatus: ptr.Of(vo.HasPublished),
},
QType: vo.FromLatestVersion,
QType: workflowModel.FromLatestVersion,
})
} else if req.GetPage() > 0 && req.GetSize() > 0 {
wfs, _, err = GetWorkflowDomainSVC().MGet(ctx, &vo.MGetPolicy{
@ -2903,7 +2983,7 @@ func (w *ApplicationService) GetPlaygroundPluginList(ctx context.Context, req *p
SpaceID: ptr.Of(req.GetSpaceID()),
PublishStatus: ptr.Of(vo.HasPublished),
},
QType: vo.FromLatestVersion,
QType: workflowModel.FromLatestVersion,
})
}
@ -2977,7 +3057,7 @@ func (w *ApplicationService) CopyWorkflow(ctx context.Context, req *workflow.Cop
return nil, err
}
wf, err := GetWorkflowDomainSVC().CopyWorkflow(ctx, workflowID, vo.CopyWorkflowPolicy{
wf, err := w.copyWorkflow(ctx, workflowID, vo.CopyWorkflowPolicy{
ShouldModifyWorkflowName: true,
})
if err != nil {
@ -3043,7 +3123,7 @@ func (w *ApplicationService) GetHistorySchema(ctx context.Context, req *workflow
// get the workflow entity for that workflowID and commitID
policy := &vo.GetPolicy{
ID: workflowID,
QType: ternary.IFElse(len(exe.Version) > 0, vo.FromSpecificVersion, vo.FromDraft),
QType: ternary.IFElse(len(exe.Version) > 0, workflowModel.FromSpecificVersion, workflowModel.FromDraft),
Version: exe.Version,
CommitID: exe.CommitID,
}
@ -3101,7 +3181,7 @@ func (w *ApplicationService) GetExampleWorkFlowList(ctx context.Context, req *wo
wfs, _, err := GetWorkflowDomainSVC().MGet(ctx, &vo.MGetPolicy{
MetaQuery: option,
QType: vo.FromDraft,
QType: workflowModel.FromDraft,
MetaOnly: false,
})
if err != nil {
@ -3180,7 +3260,7 @@ func (w *ApplicationService) CopyWkTemplateApi(ctx context.Context, req *workflo
if err != nil {
return nil, err
}
wf, err := GetWorkflowDomainSVC().CopyWorkflow(ctx, wid, vo.CopyWorkflowPolicy{
wf, err := w.copyWorkflow(ctx, wid, vo.CopyWorkflowPolicy{
ShouldModifyWorkflowName: true,
TargetSpaceID: ptr.Of(req.GetTargetSpaceID()),
TargetAppID: ptr.Of(int64(0)),
@ -3189,7 +3269,7 @@ func (w *ApplicationService) CopyWkTemplateApi(ctx context.Context, req *workflo
return nil, err
}
err = GetWorkflowDomainSVC().Publish(ctx, &vo.PublishPolicy{
err = w.publishWorkflowResource(ctx, &vo.PublishPolicy{
ID: wf.ID,
Version: "v0.0.0",
CommitID: wf.CommitID,
@ -3270,6 +3350,26 @@ func (w *ApplicationService) CopyWkTemplateApi(ctx context.Context, req *workflo
return resp, err
}
func (w *ApplicationService) publishWorkflowResource(ctx context.Context, policy *vo.PublishPolicy) error {
err := GetWorkflowDomainSVC().Publish(ctx, policy)
if err != nil {
return err
}
safego.Go(ctx, func() {
now := time.Now().UnixMilli()
if err := PublishWorkflowResource(ctx, policy.ID, nil, search.Updated, &search.ResourceDocument{
PublishStatus: ptr.Of(resource.PublishStatus_Published),
UpdateTimeMS: ptr.Of(now),
PublishTimeMS: ptr.Of(now),
}); err != nil {
logs.CtxErrorf(ctx, "publish workflow resource failed workflowID = %d, err: %v", policy.ID, err)
}
})
return nil
}
func mustParseInt64(s string) int64 {
i, err := strconv.ParseInt(s, 10, 64)
if err != nil {
@ -3518,13 +3618,13 @@ func toVariable(p *workflow.APIParameter) (*vo.Variable, error) {
v.Type = vo.VariableTypeBoolean
case workflow.ParameterType_Array:
v.Type = vo.VariableTypeList
if len(p.SubParameters) == 1 {
if len(p.SubParameters) == 1 && p.SubType != nil && *p.SubType != workflow.ParameterType_Object {
av, err := toVariable(p.SubParameters[0])
if err != nil {
return nil, err
}
v.Schema = &av
} else if len(p.SubParameters) > 1 {
} else {
subVs := make([]any, 0)
for _, ap := range p.SubParameters {
av, err := toVariable(ap)

View File

@ -0,0 +1,4 @@
NodeOfCodeConfig:
SupportThirdPartModules:
- httpx
- numpy

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
package crossagent
package agent
import (
"context"

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
package crossconnector
package connector
import (
"context"

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
package crossconversation
package conversation
import (
"context"

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
package crossdatabase
package database
import (
"context"
@ -22,6 +22,7 @@ import (
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
)
//go:generate mockgen -destination databasemock/database_mock.go --package databasemock -source database.go
type Database interface {
ExecuteSQL(ctx context.Context, req *database.ExecuteSQLRequest) (*database.ExecuteSQLResponse, error)
PublishDatabase(ctx context.Context, req *database.PublishDatabaseRequest) (resp *database.PublishDatabaseResponse, err error)
@ -30,6 +31,12 @@ type Database interface {
UnBindDatabase(ctx context.Context, req *database.UnBindDatabaseToAgentRequest) error
MGetDatabase(ctx context.Context, req *database.MGetDatabaseRequest) (*database.MGetDatabaseResponse, error)
GetAllDatabaseByAppID(ctx context.Context, req *database.GetAllDatabaseByAppIDRequest) (*database.GetAllDatabaseByAppIDResponse, error)
Execute(ctx context.Context, request *database.CustomSQLRequest) (*database.Response, error)
Query(ctx context.Context, request *database.QueryRequest) (*database.Response, error)
Update(context.Context, *database.UpdateRequest) (*database.Response, error)
Insert(ctx context.Context, request *database.InsertRequest) (*database.Response, error)
Delete(context.Context, *database.DeleteRequest) (*database.Response, error)
}
var defaultSVC Database

View File

@ -0,0 +1,219 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: database.go
//
// Generated by this command:
//
// mockgen -destination databasemock/database_mock.go --package databasemock -source database.go
//
// Package databasemock is a generated GoMock package.
package databasemock
import (
context "context"
reflect "reflect"
database "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
gomock "go.uber.org/mock/gomock"
)
// MockDatabase is a mock of Database interface.
type MockDatabase struct {
ctrl *gomock.Controller
recorder *MockDatabaseMockRecorder
isgomock struct{}
}
// MockDatabaseMockRecorder is the mock recorder for MockDatabase.
type MockDatabaseMockRecorder struct {
mock *MockDatabase
}
// NewMockDatabase creates a new mock instance.
func NewMockDatabase(ctrl *gomock.Controller) *MockDatabase {
mock := &MockDatabase{ctrl: ctrl}
mock.recorder = &MockDatabaseMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockDatabase) EXPECT() *MockDatabaseMockRecorder {
return m.recorder
}
// BindDatabase mocks base method.
func (m *MockDatabase) BindDatabase(ctx context.Context, req *database.BindDatabaseToAgentRequest) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BindDatabase", ctx, req)
ret0, _ := ret[0].(error)
return ret0
}
// BindDatabase indicates an expected call of BindDatabase.
func (mr *MockDatabaseMockRecorder) BindDatabase(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BindDatabase", reflect.TypeOf((*MockDatabase)(nil).BindDatabase), ctx, req)
}
// Delete mocks base method.
func (m *MockDatabase) Delete(arg0 context.Context, arg1 *database.DeleteRequest) (*database.Response, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Delete", arg0, arg1)
ret0, _ := ret[0].(*database.Response)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Delete indicates an expected call of Delete.
func (mr *MockDatabaseMockRecorder) Delete(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockDatabase)(nil).Delete), arg0, arg1)
}
// DeleteDatabase mocks base method.
func (m *MockDatabase) DeleteDatabase(ctx context.Context, req *database.DeleteDatabaseRequest) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteDatabase", ctx, req)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteDatabase indicates an expected call of DeleteDatabase.
func (mr *MockDatabaseMockRecorder) DeleteDatabase(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteDatabase", reflect.TypeOf((*MockDatabase)(nil).DeleteDatabase), ctx, req)
}
// Execute mocks base method.
func (m *MockDatabase) Execute(ctx context.Context, request *database.CustomSQLRequest) (*database.Response, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Execute", ctx, request)
ret0, _ := ret[0].(*database.Response)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Execute indicates an expected call of Execute.
func (mr *MockDatabaseMockRecorder) Execute(ctx, request any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Execute", reflect.TypeOf((*MockDatabase)(nil).Execute), ctx, request)
}
// ExecuteSQL mocks base method.
func (m *MockDatabase) ExecuteSQL(ctx context.Context, req *database.ExecuteSQLRequest) (*database.ExecuteSQLResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ExecuteSQL", ctx, req)
ret0, _ := ret[0].(*database.ExecuteSQLResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ExecuteSQL indicates an expected call of ExecuteSQL.
func (mr *MockDatabaseMockRecorder) ExecuteSQL(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExecuteSQL", reflect.TypeOf((*MockDatabase)(nil).ExecuteSQL), ctx, req)
}
// GetAllDatabaseByAppID mocks base method.
func (m *MockDatabase) GetAllDatabaseByAppID(ctx context.Context, req *database.GetAllDatabaseByAppIDRequest) (*database.GetAllDatabaseByAppIDResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAllDatabaseByAppID", ctx, req)
ret0, _ := ret[0].(*database.GetAllDatabaseByAppIDResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAllDatabaseByAppID indicates an expected call of GetAllDatabaseByAppID.
func (mr *MockDatabaseMockRecorder) GetAllDatabaseByAppID(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllDatabaseByAppID", reflect.TypeOf((*MockDatabase)(nil).GetAllDatabaseByAppID), ctx, req)
}
// Insert mocks base method.
func (m *MockDatabase) Insert(ctx context.Context, request *database.InsertRequest) (*database.Response, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Insert", ctx, request)
ret0, _ := ret[0].(*database.Response)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Insert indicates an expected call of Insert.
func (mr *MockDatabaseMockRecorder) Insert(ctx, request any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insert", reflect.TypeOf((*MockDatabase)(nil).Insert), ctx, request)
}
// MGetDatabase mocks base method.
func (m *MockDatabase) MGetDatabase(ctx context.Context, req *database.MGetDatabaseRequest) (*database.MGetDatabaseResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MGetDatabase", ctx, req)
ret0, _ := ret[0].(*database.MGetDatabaseResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// MGetDatabase indicates an expected call of MGetDatabase.
func (mr *MockDatabaseMockRecorder) MGetDatabase(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetDatabase", reflect.TypeOf((*MockDatabase)(nil).MGetDatabase), ctx, req)
}
// PublishDatabase mocks base method.
func (m *MockDatabase) PublishDatabase(ctx context.Context, req *database.PublishDatabaseRequest) (*database.PublishDatabaseResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "PublishDatabase", ctx, req)
ret0, _ := ret[0].(*database.PublishDatabaseResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// PublishDatabase indicates an expected call of PublishDatabase.
func (mr *MockDatabaseMockRecorder) PublishDatabase(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishDatabase", reflect.TypeOf((*MockDatabase)(nil).PublishDatabase), ctx, req)
}
// Query mocks base method.
func (m *MockDatabase) Query(ctx context.Context, request *database.QueryRequest) (*database.Response, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Query", ctx, request)
ret0, _ := ret[0].(*database.Response)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Query indicates an expected call of Query.
func (mr *MockDatabaseMockRecorder) Query(ctx, request any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockDatabase)(nil).Query), ctx, request)
}
// UnBindDatabase mocks base method.
func (m *MockDatabase) UnBindDatabase(ctx context.Context, req *database.UnBindDatabaseToAgentRequest) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UnBindDatabase", ctx, req)
ret0, _ := ret[0].(error)
return ret0
}
// UnBindDatabase indicates an expected call of UnBindDatabase.
func (mr *MockDatabaseMockRecorder) UnBindDatabase(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnBindDatabase", reflect.TypeOf((*MockDatabase)(nil).UnBindDatabase), ctx, req)
}
// Update mocks base method.
func (m *MockDatabase) Update(arg0 context.Context, arg1 *database.UpdateRequest) (*database.Response, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Update", arg0, arg1)
ret0, _ := ret[0].(*database.Response)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Update indicates an expected call of Update.
func (mr *MockDatabaseMockRecorder) Update(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockDatabase)(nil).Update), arg0, arg1)
}

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
package crossdatacopy
package datacopy
import (
"context"

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
package crossknowledge
package knowledge
import (
"context"
@ -22,11 +22,16 @@ import (
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
)
//go:generate mockgen -destination knowledgemock/knowledge_mock.go --package knowledgemock -source knowledge.go
type Knowledge interface {
ListKnowledge(ctx context.Context, request *knowledge.ListKnowledgeRequest) (response *knowledge.ListKnowledgeResponse, err error)
GetKnowledgeByID(ctx context.Context, request *knowledge.GetKnowledgeByIDRequest) (response *knowledge.GetKnowledgeByIDResponse, err error)
Retrieve(ctx context.Context, req *knowledge.RetrieveRequest) (*knowledge.RetrieveResponse, error)
DeleteKnowledge(ctx context.Context, request *knowledge.DeleteKnowledgeRequest) error
MGetKnowledgeByID(ctx context.Context, request *knowledge.MGetKnowledgeByIDRequest) (response *knowledge.MGetKnowledgeByIDResponse, err error)
Store(ctx context.Context, document *knowledge.CreateDocumentRequest) (*knowledge.CreateDocumentResponse, error)
Delete(ctx context.Context, r *knowledge.DeleteDocumentRequest) (*knowledge.DeleteDocumentResponse, error)
ListKnowledgeDetail(ctx context.Context, req *knowledge.ListKnowledgeDetailRequest) (*knowledge.ListKnowledgeDetailResponse, error)
}
var defaultSVC Knowledge

View File

@ -0,0 +1,161 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: knowledge.go
//
// Generated by this command:
//
// mockgen -destination knowledgemock/knowledge_mock.go --package knowledgemock -source knowledge.go
//
// Package knowledgemock is a generated GoMock package.
package knowledgemock
import (
context "context"
reflect "reflect"
knowledge "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
gomock "go.uber.org/mock/gomock"
)
// MockKnowledge is a mock of Knowledge interface.
type MockKnowledge struct {
ctrl *gomock.Controller
recorder *MockKnowledgeMockRecorder
isgomock struct{}
}
// MockKnowledgeMockRecorder is the mock recorder for MockKnowledge.
type MockKnowledgeMockRecorder struct {
mock *MockKnowledge
}
// NewMockKnowledge creates a new mock instance.
func NewMockKnowledge(ctrl *gomock.Controller) *MockKnowledge {
mock := &MockKnowledge{ctrl: ctrl}
mock.recorder = &MockKnowledgeMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockKnowledge) EXPECT() *MockKnowledgeMockRecorder {
return m.recorder
}
// Delete mocks base method.
func (m *MockKnowledge) Delete(ctx context.Context, r *knowledge.DeleteDocumentRequest) (*knowledge.DeleteDocumentResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Delete", ctx, r)
ret0, _ := ret[0].(*knowledge.DeleteDocumentResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Delete indicates an expected call of Delete.
func (mr *MockKnowledgeMockRecorder) Delete(ctx, r any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockKnowledge)(nil).Delete), ctx, r)
}
// DeleteKnowledge mocks base method.
func (m *MockKnowledge) DeleteKnowledge(ctx context.Context, request *knowledge.DeleteKnowledgeRequest) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteKnowledge", ctx, request)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteKnowledge indicates an expected call of DeleteKnowledge.
func (mr *MockKnowledgeMockRecorder) DeleteKnowledge(ctx, request any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteKnowledge", reflect.TypeOf((*MockKnowledge)(nil).DeleteKnowledge), ctx, request)
}
// GetKnowledgeByID mocks base method.
func (m *MockKnowledge) GetKnowledgeByID(ctx context.Context, request *knowledge.GetKnowledgeByIDRequest) (*knowledge.GetKnowledgeByIDResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetKnowledgeByID", ctx, request)
ret0, _ := ret[0].(*knowledge.GetKnowledgeByIDResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetKnowledgeByID indicates an expected call of GetKnowledgeByID.
func (mr *MockKnowledgeMockRecorder) GetKnowledgeByID(ctx, request any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKnowledgeByID", reflect.TypeOf((*MockKnowledge)(nil).GetKnowledgeByID), ctx, request)
}
// ListKnowledge mocks base method.
func (m *MockKnowledge) ListKnowledge(ctx context.Context, request *knowledge.ListKnowledgeRequest) (*knowledge.ListKnowledgeResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListKnowledge", ctx, request)
ret0, _ := ret[0].(*knowledge.ListKnowledgeResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListKnowledge indicates an expected call of ListKnowledge.
func (mr *MockKnowledgeMockRecorder) ListKnowledge(ctx, request any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListKnowledge", reflect.TypeOf((*MockKnowledge)(nil).ListKnowledge), ctx, request)
}
// ListKnowledgeDetail mocks base method.
func (m *MockKnowledge) ListKnowledgeDetail(ctx context.Context, req *knowledge.ListKnowledgeDetailRequest) (*knowledge.ListKnowledgeDetailResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListKnowledgeDetail", ctx, req)
ret0, _ := ret[0].(*knowledge.ListKnowledgeDetailResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListKnowledgeDetail indicates an expected call of ListKnowledgeDetail.
func (mr *MockKnowledgeMockRecorder) ListKnowledgeDetail(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListKnowledgeDetail", reflect.TypeOf((*MockKnowledge)(nil).ListKnowledgeDetail), ctx, req)
}
// MGetKnowledgeByID mocks base method.
func (m *MockKnowledge) MGetKnowledgeByID(ctx context.Context, request *knowledge.MGetKnowledgeByIDRequest) (*knowledge.MGetKnowledgeByIDResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MGetKnowledgeByID", ctx, request)
ret0, _ := ret[0].(*knowledge.MGetKnowledgeByIDResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// MGetKnowledgeByID indicates an expected call of MGetKnowledgeByID.
func (mr *MockKnowledgeMockRecorder) MGetKnowledgeByID(ctx, request any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetKnowledgeByID", reflect.TypeOf((*MockKnowledge)(nil).MGetKnowledgeByID), ctx, request)
}
// Retrieve mocks base method.
func (m *MockKnowledge) Retrieve(ctx context.Context, req *knowledge.RetrieveRequest) (*knowledge.RetrieveResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Retrieve", ctx, req)
ret0, _ := ret[0].(*knowledge.RetrieveResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Retrieve indicates an expected call of Retrieve.
func (mr *MockKnowledgeMockRecorder) Retrieve(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Retrieve", reflect.TypeOf((*MockKnowledge)(nil).Retrieve), ctx, req)
}
// Store mocks base method.
func (m *MockKnowledge) Store(ctx context.Context, document *knowledge.CreateDocumentRequest) (*knowledge.CreateDocumentResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Store", ctx, document)
ret0, _ := ret[0].(*knowledge.CreateDocumentResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Store indicates an expected call of Store.
func (mr *MockKnowledgeMockRecorder) Store(ctx, document any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Store", reflect.TypeOf((*MockKnowledge)(nil).Store), ctx, document)
}

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
package crossmessage
package message
import (
"context"

View File

@ -0,0 +1,41 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package modelmgr
import (
"context"
eino "github.com/cloudwego/eino/components/model"
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
)
//go:generate mockgen -destination modelmock/model_mock.go --package mockmodel -source modelmgr.go
type Manager interface {
GetModel(ctx context.Context, params *model.LLMParams) (eino.BaseChatModel, *modelmgr.Model, error)
}
var defaultSVC Manager
func DefaultSVC() Manager {
return defaultSVC
}
func SetDefaultSVC(svc Manager) {
defaultSVC = svc
}

View File

@ -1,9 +1,9 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: model.go
// Source: modelmgr.go
//
// Generated by this command:
//
// mockgen -destination modelmock/model_mock.go --package mockmodel -source model.go
// mockgen -destination modelmock/model_mock.go --package mockmodel -source modelmgr.go
//
// Package mockmodel is a generated GoMock package.
@ -14,7 +14,7 @@ import (
reflect "reflect"
model "github.com/cloudwego/eino/components/model"
model0 "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
model0 "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
modelmgr "github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
gomock "go.uber.org/mock/gomock"
)
@ -23,6 +23,7 @@ import (
type MockManager struct {
ctrl *gomock.Controller
recorder *MockManagerMockRecorder
isgomock struct{}
}
// MockManagerMockRecorder is the mock recorder for MockManager.

View File

@ -14,14 +14,18 @@
* limitations under the License.
*/
package crossplugin
package plugin
import (
"context"
"github.com/cloudwego/eino/schema"
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
)
//go:generate mockgen -destination pluginmock/plugin_mock.go --package pluginmock -source plugin.go
type PluginService interface {
MGetVersionPlugins(ctx context.Context, versionPlugins []model.VersionPlugin) (plugins []*model.PluginInfo, err error)
MGetPluginLatestVersion(ctx context.Context, pluginIDs []int64) (resp *model.MGetPluginLatestVersionResponse, err error)
@ -35,6 +39,14 @@ type PluginService interface {
PublishAPPPlugins(ctx context.Context, req *model.PublishAPPPluginsRequest) (resp *model.PublishAPPPluginsResponse, err error)
GetAPPAllPlugins(ctx context.Context, appID int64) (plugins []*model.PluginInfo, err error)
MGetVersionTools(ctx context.Context, versionTools []model.VersionTool) (tools []*model.ToolInfo, err error)
GetPluginToolsInfo(ctx context.Context, req *model.ToolsInfoRequest) (*model.ToolsInfoResponse, error)
GetPluginInvokableTools(ctx context.Context, req *model.ToolsInvokableRequest) (map[int64]InvokableTool, error)
ExecutePlugin(ctx context.Context, input map[string]any, pe *model.PluginEntity, toolID int64, cfg workflow.ExecuteConfig) (map[string]any, error)
}
type InvokableTool interface {
Info(ctx context.Context) (*schema.ToolInfo, error)
PluginInvoke(ctx context.Context, argumentsInJSON string, cfg workflow.ExecuteConfig) (string, error)
}
var defaultSVC PluginService

View File

@ -0,0 +1,324 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: plugin.go
//
// Generated by this command:
//
// mockgen -destination pluginmock/plugin_mock.go --package pluginmock -source plugin.go
//
// Package pluginmock is a generated GoMock package.
package pluginmock
import (
context "context"
reflect "reflect"
schema "github.com/cloudwego/eino/schema"
plugin "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
workflow "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
plugin0 "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin"
gomock "go.uber.org/mock/gomock"
)
// MockPluginService is a mock of PluginService interface.
type MockPluginService struct {
ctrl *gomock.Controller
recorder *MockPluginServiceMockRecorder
isgomock struct{}
}
// MockPluginServiceMockRecorder is the mock recorder for MockPluginService.
type MockPluginServiceMockRecorder struct {
mock *MockPluginService
}
// NewMockPluginService creates a new mock instance.
func NewMockPluginService(ctrl *gomock.Controller) *MockPluginService {
mock := &MockPluginService{ctrl: ctrl}
mock.recorder = &MockPluginServiceMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockPluginService) EXPECT() *MockPluginServiceMockRecorder {
return m.recorder
}
// BindAgentTools mocks base method.
func (m *MockPluginService) BindAgentTools(ctx context.Context, agentID int64, toolIDs []int64) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BindAgentTools", ctx, agentID, toolIDs)
ret0, _ := ret[0].(error)
return ret0
}
// BindAgentTools indicates an expected call of BindAgentTools.
func (mr *MockPluginServiceMockRecorder) BindAgentTools(ctx, agentID, toolIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BindAgentTools", reflect.TypeOf((*MockPluginService)(nil).BindAgentTools), ctx, agentID, toolIDs)
}
// DeleteDraftPlugin mocks base method.
func (m *MockPluginService) DeleteDraftPlugin(ctx context.Context, PluginID int64) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteDraftPlugin", ctx, PluginID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteDraftPlugin indicates an expected call of DeleteDraftPlugin.
func (mr *MockPluginServiceMockRecorder) DeleteDraftPlugin(ctx, PluginID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteDraftPlugin", reflect.TypeOf((*MockPluginService)(nil).DeleteDraftPlugin), ctx, PluginID)
}
// DuplicateDraftAgentTools mocks base method.
func (m *MockPluginService) DuplicateDraftAgentTools(ctx context.Context, fromAgentID, toAgentID int64) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DuplicateDraftAgentTools", ctx, fromAgentID, toAgentID)
ret0, _ := ret[0].(error)
return ret0
}
// DuplicateDraftAgentTools indicates an expected call of DuplicateDraftAgentTools.
func (mr *MockPluginServiceMockRecorder) DuplicateDraftAgentTools(ctx, fromAgentID, toAgentID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DuplicateDraftAgentTools", reflect.TypeOf((*MockPluginService)(nil).DuplicateDraftAgentTools), ctx, fromAgentID, toAgentID)
}
// ExecutePlugin mocks base method.
func (m *MockPluginService) ExecutePlugin(ctx context.Context, input map[string]any, pe *plugin.PluginEntity, toolID int64, cfg workflow.ExecuteConfig) (map[string]any, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ExecutePlugin", ctx, input, pe, toolID, cfg)
ret0, _ := ret[0].(map[string]any)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ExecutePlugin indicates an expected call of ExecutePlugin.
func (mr *MockPluginServiceMockRecorder) ExecutePlugin(ctx, input, pe, toolID, cfg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExecutePlugin", reflect.TypeOf((*MockPluginService)(nil).ExecutePlugin), ctx, input, pe, toolID, cfg)
}
// ExecuteTool mocks base method.
func (m *MockPluginService) ExecuteTool(ctx context.Context, req *plugin.ExecuteToolRequest, opts ...plugin.ExecuteToolOpt) (*plugin.ExecuteToolResponse, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, req}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "ExecuteTool", varargs...)
ret0, _ := ret[0].(*plugin.ExecuteToolResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ExecuteTool indicates an expected call of ExecuteTool.
func (mr *MockPluginServiceMockRecorder) ExecuteTool(ctx, req any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, req}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExecuteTool", reflect.TypeOf((*MockPluginService)(nil).ExecuteTool), varargs...)
}
// GetAPPAllPlugins mocks base method.
func (m *MockPluginService) GetAPPAllPlugins(ctx context.Context, appID int64) ([]*plugin.PluginInfo, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAPPAllPlugins", ctx, appID)
ret0, _ := ret[0].([]*plugin.PluginInfo)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAPPAllPlugins indicates an expected call of GetAPPAllPlugins.
func (mr *MockPluginServiceMockRecorder) GetAPPAllPlugins(ctx, appID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAPPAllPlugins", reflect.TypeOf((*MockPluginService)(nil).GetAPPAllPlugins), ctx, appID)
}
// GetPluginInvokableTools mocks base method.
func (m *MockPluginService) GetPluginInvokableTools(ctx context.Context, req *plugin.ToolsInvokableRequest) (map[int64]plugin0.InvokableTool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPluginInvokableTools", ctx, req)
ret0, _ := ret[0].(map[int64]plugin0.InvokableTool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPluginInvokableTools indicates an expected call of GetPluginInvokableTools.
func (mr *MockPluginServiceMockRecorder) GetPluginInvokableTools(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPluginInvokableTools", reflect.TypeOf((*MockPluginService)(nil).GetPluginInvokableTools), ctx, req)
}
// GetPluginToolsInfo mocks base method.
func (m *MockPluginService) GetPluginToolsInfo(ctx context.Context, req *plugin.ToolsInfoRequest) (*plugin.ToolsInfoResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPluginToolsInfo", ctx, req)
ret0, _ := ret[0].(*plugin.ToolsInfoResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPluginToolsInfo indicates an expected call of GetPluginToolsInfo.
func (mr *MockPluginServiceMockRecorder) GetPluginToolsInfo(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPluginToolsInfo", reflect.TypeOf((*MockPluginService)(nil).GetPluginToolsInfo), ctx, req)
}
// MGetAgentTools mocks base method.
func (m *MockPluginService) MGetAgentTools(ctx context.Context, req *plugin.MGetAgentToolsRequest) ([]*plugin.ToolInfo, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MGetAgentTools", ctx, req)
ret0, _ := ret[0].([]*plugin.ToolInfo)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// MGetAgentTools indicates an expected call of MGetAgentTools.
func (mr *MockPluginServiceMockRecorder) MGetAgentTools(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetAgentTools", reflect.TypeOf((*MockPluginService)(nil).MGetAgentTools), ctx, req)
}
// MGetPluginLatestVersion mocks base method.
func (m *MockPluginService) MGetPluginLatestVersion(ctx context.Context, pluginIDs []int64) (*plugin.MGetPluginLatestVersionResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MGetPluginLatestVersion", ctx, pluginIDs)
ret0, _ := ret[0].(*plugin.MGetPluginLatestVersionResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// MGetPluginLatestVersion indicates an expected call of MGetPluginLatestVersion.
func (mr *MockPluginServiceMockRecorder) MGetPluginLatestVersion(ctx, pluginIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetPluginLatestVersion", reflect.TypeOf((*MockPluginService)(nil).MGetPluginLatestVersion), ctx, pluginIDs)
}
// MGetVersionPlugins mocks base method.
func (m *MockPluginService) MGetVersionPlugins(ctx context.Context, versionPlugins []plugin.VersionPlugin) ([]*plugin.PluginInfo, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MGetVersionPlugins", ctx, versionPlugins)
ret0, _ := ret[0].([]*plugin.PluginInfo)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// MGetVersionPlugins indicates an expected call of MGetVersionPlugins.
func (mr *MockPluginServiceMockRecorder) MGetVersionPlugins(ctx, versionPlugins any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetVersionPlugins", reflect.TypeOf((*MockPluginService)(nil).MGetVersionPlugins), ctx, versionPlugins)
}
// MGetVersionTools mocks base method.
func (m *MockPluginService) MGetVersionTools(ctx context.Context, versionTools []plugin.VersionTool) ([]*plugin.ToolInfo, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MGetVersionTools", ctx, versionTools)
ret0, _ := ret[0].([]*plugin.ToolInfo)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// MGetVersionTools indicates an expected call of MGetVersionTools.
func (mr *MockPluginServiceMockRecorder) MGetVersionTools(ctx, versionTools any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetVersionTools", reflect.TypeOf((*MockPluginService)(nil).MGetVersionTools), ctx, versionTools)
}
// PublishAPPPlugins mocks base method.
func (m *MockPluginService) PublishAPPPlugins(ctx context.Context, req *plugin.PublishAPPPluginsRequest) (*plugin.PublishAPPPluginsResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "PublishAPPPlugins", ctx, req)
ret0, _ := ret[0].(*plugin.PublishAPPPluginsResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// PublishAPPPlugins indicates an expected call of PublishAPPPlugins.
func (mr *MockPluginServiceMockRecorder) PublishAPPPlugins(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishAPPPlugins", reflect.TypeOf((*MockPluginService)(nil).PublishAPPPlugins), ctx, req)
}
// PublishAgentTools mocks base method.
func (m *MockPluginService) PublishAgentTools(ctx context.Context, agentID int64, agentVersion string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "PublishAgentTools", ctx, agentID, agentVersion)
ret0, _ := ret[0].(error)
return ret0
}
// PublishAgentTools indicates an expected call of PublishAgentTools.
func (mr *MockPluginServiceMockRecorder) PublishAgentTools(ctx, agentID, agentVersion any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishAgentTools", reflect.TypeOf((*MockPluginService)(nil).PublishAgentTools), ctx, agentID, agentVersion)
}
// PublishPlugin mocks base method.
func (m *MockPluginService) PublishPlugin(ctx context.Context, req *plugin.PublishPluginRequest) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "PublishPlugin", ctx, req)
ret0, _ := ret[0].(error)
return ret0
}
// PublishPlugin indicates an expected call of PublishPlugin.
func (mr *MockPluginServiceMockRecorder) PublishPlugin(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishPlugin", reflect.TypeOf((*MockPluginService)(nil).PublishPlugin), ctx, req)
}
// MockInvokableTool is a mock of InvokableTool interface.
type MockInvokableTool struct {
ctrl *gomock.Controller
recorder *MockInvokableToolMockRecorder
isgomock struct{}
}
// MockInvokableToolMockRecorder is the mock recorder for MockInvokableTool.
type MockInvokableToolMockRecorder struct {
mock *MockInvokableTool
}
// NewMockInvokableTool creates a new mock instance.
func NewMockInvokableTool(ctrl *gomock.Controller) *MockInvokableTool {
mock := &MockInvokableTool{ctrl: ctrl}
mock.recorder = &MockInvokableToolMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockInvokableTool) EXPECT() *MockInvokableToolMockRecorder {
return m.recorder
}
// Info mocks base method.
func (m *MockInvokableTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Info", ctx)
ret0, _ := ret[0].(*schema.ToolInfo)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Info indicates an expected call of Info.
func (mr *MockInvokableToolMockRecorder) Info(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockInvokableTool)(nil).Info), ctx)
}
// PluginInvoke mocks base method.
func (m *MockInvokableTool) PluginInvoke(ctx context.Context, argumentsInJSON string, cfg workflow.ExecuteConfig) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "PluginInvoke", ctx, argumentsInJSON, cfg)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// PluginInvoke indicates an expected call of PluginInvoke.
func (mr *MockInvokableToolMockRecorder) PluginInvoke(ctx, argumentsInJSON, cfg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PluginInvoke", reflect.TypeOf((*MockInvokableTool)(nil).PluginInvoke), ctx, argumentsInJSON, cfg)
}

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
package crosssearch
package search
import (
"context"

View File

@ -14,19 +14,25 @@
* limitations under the License.
*/
package crossvariables
package variables
import (
"context"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/variables"
"github.com/coze-dev/coze-studio/backend/api/model/data/variable/kvmemory"
"github.com/coze-dev/coze-studio/backend/api/model/data/variable/project_memory"
"github.com/coze-dev/coze-studio/backend/domain/memory/variables/entity"
)
// TODO (@fanlv): Parameter references need to be modified.
type Variables interface {
GetVariableInstance(ctx context.Context, e *variables.UserVariableMeta, keywords []string) ([]*kvmemory.KVItem, error)
SetVariableInstance(ctx context.Context, e *variables.UserVariableMeta, items []*kvmemory.KVItem) ([]string, error)
DecryptSysUUIDKey(ctx context.Context, encryptSysUUIDKey string) *variables.UserVariableMeta
GetVariableChannelInstance(ctx context.Context, e *variables.UserVariableMeta, keywords []string, varChannel *project_memory.VariableChannel) ([]*kvmemory.KVItem, error)
GetProjectVariablesMeta(ctx context.Context, projectID, version string) (*entity.VariablesMeta, error)
GetAgentVariableMeta(ctx context.Context, agentID int64, version string) (*entity.VariablesMeta, error)
}
var defaultSVC Variables

View File

@ -23,28 +23,28 @@ import (
einoCompose "github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
workflowEntity "github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
)
// TODO (@fanlv): Parameter references need to be modified.
type Workflow interface {
WorkflowAsModelTool(ctx context.Context, policies []*vo.GetPolicy) ([]workflow.ToolFromWorkflow, error)
DeleteWorkflow(ctx context.Context, id int64) error
PublishWorkflow(ctx context.Context, info *vo.PublishPolicy) (err error)
WithResumeToolWorkflow(resumingEvent *workflowEntity.ToolInterruptEvent, resumeData string,
allInterruptEvents map[string]*workflowEntity.ToolInterruptEvent) einoCompose.Option
ReleaseApplicationWorkflows(ctx context.Context, appID int64, config *ReleaseWorkflowConfig) ([]*vo.ValidateIssue, error)
GetWorkflowIDsByAppID(ctx context.Context, appID int64) ([]int64, error)
SyncExecuteWorkflow(ctx context.Context, config vo.ExecuteConfig, input map[string]any) (*workflowEntity.WorkflowExecution, vo.TerminatePlan, error)
WithExecuteConfig(cfg vo.ExecuteConfig) einoCompose.Option
WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message])
SyncExecuteWorkflow(ctx context.Context, config workflowModel.ExecuteConfig, input map[string]any) (*workflowEntity.WorkflowExecution, vo.TerminatePlan, error)
WithExecuteConfig(cfg workflowModel.ExecuteConfig) einoCompose.Option
WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message], func())
}
type ExecuteConfig = vo.ExecuteConfig
type ExecuteMode = vo.ExecuteMode
type ExecuteConfig = workflowModel.ExecuteConfig
type ExecuteMode = workflowModel.ExecuteMode
type NodeType = entity.NodeType
type WorkflowMessage = entity.Message
@ -59,14 +59,14 @@ const (
ExecuteModeNodeDebug ExecuteMode = "node_debug"
)
type TaskType = vo.TaskType
type TaskType = workflowModel.TaskType
const (
TaskTypeForeground TaskType = "foreground"
TaskTypeBackground TaskType = "background"
)
type BizType = vo.BizType
type BizType = workflowModel.BizType
const (
BizTypeAgent BizType = "agent"

View File

@ -19,7 +19,7 @@ package agentrun
import (
"context"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossagentrun"
crossagentrun "github.com/coze-dev/coze-studio/backend/crossdomain/contract/agentrun"
agentrun "github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/service"
)

View File

@ -20,7 +20,7 @@ import (
"context"
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/connector"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossconnector"
crossconnector "github.com/coze-dev/coze-studio/backend/crossdomain/contract/connector"
connector "github.com/coze-dev/coze-studio/backend/domain/connector/service"
)

View File

@ -20,7 +20,7 @@ import (
"context"
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/conversation"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossconversation"
crossconversation "github.com/coze-dev/coze-studio/backend/crossdomain/contract/conversation"
conversation "github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/service"
)

View File

@ -19,7 +19,7 @@ package crossuser
import (
"context"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossuser"
crossuser "github.com/coze-dev/coze-studio/backend/crossdomain/contract/user"
"github.com/coze-dev/coze-studio/backend/domain/user/entity"
"github.com/coze-dev/coze-studio/backend/domain/user/service"
)

View File

@ -18,10 +18,20 @@ package database
import (
"context"
"fmt"
"strings"
"github.com/spf13/cast"
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossdatabase"
"github.com/coze-dev/coze-studio/backend/api/model/data/database/table"
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
"github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
database "github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
)
var defaultSVC crossdatabase.Database
@ -65,3 +75,406 @@ func (c *databaseImpl) MGetDatabase(ctx context.Context, req *model.MGetDatabase
func (c *databaseImpl) GetAllDatabaseByAppID(ctx context.Context, req *model.GetAllDatabaseByAppIDRequest) (*model.GetAllDatabaseByAppIDResponse, error) {
return c.DomainSVC.GetAllDatabaseByAppID(ctx, req)
}
func (d *databaseImpl) Execute(ctx context.Context, request *model.CustomSQLRequest) (*model.Response, error) {
var (
err error
databaseInfoID = request.DatabaseInfoID
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
)
if request.IsDebugRun {
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
if err != nil {
return nil, err
}
}
req := &service.ExecuteSQLRequest{
DatabaseID: databaseInfoID,
OperateType: model.OperateType_Custom,
SQL: &request.SQL,
TableType: tableType,
UserID: request.UserID,
ConnectorID: ptr.Of(request.ConnectorID),
}
req.SQLParams = make([]*model.SQLParamVal, 0, len(request.Params))
for i := range request.Params {
param := request.Params[i]
req.SQLParams = append(req.SQLParams, &model.SQLParamVal{
ValueType: table.FieldItemType_Text,
Value: &param.Value,
ISNull: param.IsNull,
})
}
response, err := d.DomainSVC.ExecuteSQL(ctx, req)
if err != nil {
return nil, err
}
// if rows affected is nil use 0 instead
if response.RowsAffected == nil {
response.RowsAffected = ptr.Of(int64(0))
}
return toNodeDateBaseResponse(response), nil
}
func (d *databaseImpl) Delete(ctx context.Context, request *model.DeleteRequest) (*model.Response, error) {
var (
err error
databaseInfoID = request.DatabaseInfoID
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
)
if request.IsDebugRun {
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
if err != nil {
return nil, err
}
}
req := &service.ExecuteSQLRequest{
DatabaseID: databaseInfoID,
OperateType: model.OperateType_Delete,
TableType: tableType,
UserID: request.UserID,
ConnectorID: ptr.Of(request.ConnectorID),
}
if request.ConditionGroup != nil {
req.Condition, req.SQLParams, err = buildComplexCondition(request.ConditionGroup)
if err != nil {
return nil, err
}
}
response, err := d.DomainSVC.ExecuteSQL(ctx, req)
if err != nil {
return nil, err
}
return toNodeDateBaseResponse(response), nil
}
func (d *databaseImpl) Query(ctx context.Context, request *model.QueryRequest) (*model.Response, error) {
var (
err error
databaseInfoID = request.DatabaseInfoID
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
)
if request.IsDebugRun {
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
if err != nil {
return nil, err
}
}
req := &service.ExecuteSQLRequest{
DatabaseID: databaseInfoID,
OperateType: model.OperateType_Select,
TableType: tableType,
UserID: request.UserID,
ConnectorID: ptr.Of(request.ConnectorID),
}
req.SelectFieldList = &model.SelectFieldList{FieldID: make([]string, 0, len(request.SelectFields))}
for i := range request.SelectFields {
req.SelectFieldList.FieldID = append(req.SelectFieldList.FieldID, request.SelectFields[i])
}
req.OrderByList = make([]model.OrderBy, 0)
for i := range request.OrderClauses {
clause := request.OrderClauses[i]
req.OrderByList = append(req.OrderByList, model.OrderBy{
Field: clause.FieldID,
Direction: toOrderDirection(clause.IsAsc),
})
}
if request.ConditionGroup != nil {
req.Condition, req.SQLParams, err = buildComplexCondition(request.ConditionGroup)
if err != nil {
return nil, err
}
}
limit := request.Limit
req.Limit = &limit
response, err := d.DomainSVC.ExecuteSQL(ctx, req)
if err != nil {
return nil, err
}
return toNodeDateBaseResponse(response), nil
}
func (d *databaseImpl) Update(ctx context.Context, request *model.UpdateRequest) (*model.Response, error) {
var (
err error
condition *model.ComplexCondition
params []*model.SQLParamVal
databaseInfoID = request.DatabaseInfoID
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
)
if request.IsDebugRun {
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
if err != nil {
return nil, err
}
}
req := &service.ExecuteSQLRequest{
DatabaseID: databaseInfoID,
OperateType: model.OperateType_Update,
SQLParams: make([]*model.SQLParamVal, 0),
TableType: tableType,
UserID: request.UserID,
ConnectorID: ptr.Of(request.ConnectorID),
}
uid := ctxutil.GetUIDFromCtx(ctx)
if uid != nil {
req.UserID = conv.Int64ToStr(*uid)
}
req.UpsertRows, req.SQLParams, err = resolveUpsertRow(request.Fields)
if err != nil {
return nil, err
}
if request.ConditionGroup != nil {
condition, params, err = buildComplexCondition(request.ConditionGroup)
if err != nil {
return nil, err
}
req.Condition = condition
req.SQLParams = append(req.SQLParams, params...)
}
response, err := d.DomainSVC.ExecuteSQL(ctx, req)
if err != nil {
return nil, err
}
return toNodeDateBaseResponse(response), nil
}
func (d *databaseImpl) Insert(ctx context.Context, request *model.InsertRequest) (*model.Response, error) {
var (
err error
databaseInfoID = request.DatabaseInfoID
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
)
if request.IsDebugRun {
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
if err != nil {
return nil, err
}
}
req := &service.ExecuteSQLRequest{
DatabaseID: databaseInfoID,
OperateType: model.OperateType_Insert,
TableType: tableType,
UserID: request.UserID,
ConnectorID: ptr.Of(request.ConnectorID),
}
req.UpsertRows, req.SQLParams, err = resolveUpsertRow(request.Fields)
if err != nil {
return nil, err
}
response, err := d.DomainSVC.ExecuteSQL(ctx, req)
if err != nil {
return nil, err
}
return toNodeDateBaseResponse(response), nil
}
func (d *databaseImpl) getDraftTableID(ctx context.Context, onlineID int64) (int64, error) {
resp, err := d.DomainSVC.GetDraftDatabaseByOnlineID(ctx, &service.GetDraftDatabaseByOnlineIDRequest{OnlineID: onlineID})
if err != nil {
return 0, err
}
return resp.Database.ID, nil
}
func buildComplexCondition(conditionGroup *model.ConditionGroup) (*model.ComplexCondition, []*model.SQLParamVal, error) {
condition := &model.ComplexCondition{}
logic, err := toLogic(conditionGroup.Relation)
if err != nil {
return nil, nil, err
}
condition.Logic = logic
params := make([]*model.SQLParamVal, 0)
for i := range conditionGroup.Conditions {
var (
nCond = conditionGroup.Conditions[i]
vals []*model.SQLParamVal
dCond = &model.Condition{
Left: nCond.Left,
}
)
opt, err := toOperation(nCond.Operator)
if err != nil {
return nil, nil, err
}
dCond.Operation = opt
if isNullOrNotNull(opt) {
condition.Conditions = append(condition.Conditions, dCond)
continue
}
dCond.Right, vals, err = resolveRightValue(opt, nCond.Right)
if err != nil {
return nil, nil, err
}
condition.Conditions = append(condition.Conditions, dCond)
params = append(params, vals...)
}
return condition, params, nil
}
func toMapStringAny(m map[string]string) map[string]any {
ret := make(map[string]any, len(m))
for k, v := range m {
ret[k] = v
}
return ret
}
func toOperation(operator model.Operator) (model.Operation, error) {
switch operator {
case model.OperatorEqual:
return model.Operation_EQUAL, nil
case model.OperatorNotEqual:
return model.Operation_NOT_EQUAL, nil
case model.OperatorGreater:
return model.Operation_GREATER_THAN, nil
case model.OperatorGreaterOrEqual:
return model.Operation_GREATER_EQUAL, nil
case model.OperatorLesser:
return model.Operation_LESS_THAN, nil
case model.OperatorLesserOrEqual:
return model.Operation_LESS_EQUAL, nil
case model.OperatorIn:
return model.Operation_IN, nil
case model.OperatorNotIn:
return model.Operation_NOT_IN, nil
case model.OperatorIsNotNull:
return model.Operation_IS_NOT_NULL, nil
case model.OperatorIsNull:
return model.Operation_IS_NULL, nil
case model.OperatorLike:
return model.Operation_LIKE, nil
case model.OperatorNotLike:
return model.Operation_NOT_LIKE, nil
default:
return model.Operation(0), fmt.Errorf("invalid operator %v", operator)
}
}
func resolveRightValue(operator model.Operation, right any) (string, []*model.SQLParamVal, error) {
if isInOrNotIn(operator) {
var (
vals = make([]*model.SQLParamVal, 0)
anyVals = make([]any, 0)
commas = make([]string, 0, len(anyVals))
)
anyVals = right.([]any)
for i := range anyVals {
v := cast.ToString(anyVals[i])
vals = append(vals, &model.SQLParamVal{ValueType: table.FieldItemType_Text, Value: &v})
commas = append(commas, "?")
}
value := "(" + strings.Join(commas, ",") + ")"
return value, vals, nil
}
rightValue, err := cast.ToStringE(right)
if err != nil {
return "", nil, err
}
if isLikeOrNotLike(operator) {
var (
value = "?"
v = "%s" + rightValue + "%s"
)
return value, []*model.SQLParamVal{{ValueType: table.FieldItemType_Text, Value: &v}}, nil
}
return "?", []*model.SQLParamVal{{ValueType: table.FieldItemType_Text, Value: &rightValue}}, nil
}
func resolveUpsertRow(fields map[string]any) ([]*model.UpsertRow, []*model.SQLParamVal, error) {
upsertRow := &model.UpsertRow{Records: make([]*model.Record, 0, len(fields))}
params := make([]*model.SQLParamVal, 0)
for key, value := range fields {
val, err := cast.ToStringE(value)
if err != nil {
return nil, nil, err
}
record := &model.Record{
FieldId: key,
FieldValue: "?",
}
upsertRow.Records = append(upsertRow.Records, record)
params = append(params, &model.SQLParamVal{
ValueType: table.FieldItemType_Text,
Value: &val,
})
}
return []*model.UpsertRow{upsertRow}, params, nil
}
func isNullOrNotNull(opt model.Operation) bool {
return opt == model.Operation_IS_NOT_NULL || opt == model.Operation_IS_NULL
}
func isLikeOrNotLike(opt model.Operation) bool {
return opt == model.Operation_LIKE || opt == model.Operation_NOT_LIKE
}
func isInOrNotIn(opt model.Operation) bool {
return opt == model.Operation_IN || opt == model.Operation_NOT_IN
}
func toOrderDirection(isAsc bool) table.SortDirection {
if isAsc {
return table.SortDirection_ASC
}
return table.SortDirection_Desc
}
func toLogic(relation model.ClauseRelation) (model.Logic, error) {
switch relation {
case model.ClauseRelationOR:
return model.Logic_Or, nil
case model.ClauseRelationAND:
return model.Logic_And, nil
default:
return model.Logic(0), fmt.Errorf("invalid relation %v", relation)
}
}
func toNodeDateBaseResponse(response *service.ExecuteSQLResponse) *model.Response {
objects := make([]model.Object, 0, len(response.Records))
for i := range response.Records {
objects = append(objects, response.Records[i])
}
return &model.Response{
Objects: objects,
RowNumber: response.RowsAffected,
}
}

View File

@ -22,7 +22,7 @@ import (
"gorm.io/gorm"
"github.com/coze-dev/coze-studio/backend/application/base/appinfra"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossdatacopy"
crossdatacopy "github.com/coze-dev/coze-studio/backend/crossdomain/contract/datacopy"
"github.com/coze-dev/coze-studio/backend/domain/datacopy"
"github.com/coze-dev/coze-studio/backend/domain/datacopy/service"
)

View File

@ -18,10 +18,18 @@ package knowledge
import (
"context"
"errors"
"fmt"
"strconv"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossknowledge"
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
crossknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/contract/knowledge"
"github.com/coze-dev/coze-studio/backend/domain/knowledge/entity"
"github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
)
var defaultSVC crossknowledge.Knowledge
@ -57,3 +65,120 @@ func (i *impl) GetKnowledgeByID(ctx context.Context, request *model.GetKnowledge
func (i *impl) MGetKnowledgeByID(ctx context.Context, request *model.MGetKnowledgeByIDRequest) (response *model.MGetKnowledgeByIDResponse, err error) {
return i.DomainSVC.MGetKnowledgeByID(ctx, request)
}
func (i *impl) Store(ctx context.Context, document *model.CreateDocumentRequest) (*model.CreateDocumentResponse, error) {
var (
ps *entity.ParsingStrategy
cs = &entity.ChunkingStrategy{}
)
if document.ParsingStrategy == nil {
return nil, errors.New("document parsing strategy is required")
}
if document.ChunkingStrategy == nil {
return nil, errors.New("document chunking strategy is required")
}
if document.ParsingStrategy.ParseMode == model.AccurateParseMode {
ps = &entity.ParsingStrategy{}
ps.ExtractImage = document.ParsingStrategy.ExtractImage
ps.ExtractTable = document.ParsingStrategy.ExtractTable
ps.ImageOCR = document.ParsingStrategy.ImageOCR
}
chunkType, err := toChunkType(document.ChunkingStrategy.ChunkType)
if err != nil {
return nil, err
}
cs.ChunkType = chunkType
cs.Separator = document.ChunkingStrategy.Separator
cs.ChunkSize = document.ChunkingStrategy.ChunkSize
cs.Overlap = document.ChunkingStrategy.Overlap
req := &entity.Document{
Info: knowledge.Info{
Name: document.FileName,
},
KnowledgeID: document.KnowledgeID,
Type: knowledge.DocumentTypeText,
URL: document.FileURL,
Source: entity.DocumentSourceLocal,
ParsingStrategy: ps,
ChunkingStrategy: cs,
FileExtension: document.FileExtension,
}
uid := ctxutil.GetUIDFromCtx(ctx)
if uid != nil {
req.Info.CreatorID = *uid
}
response, err := i.DomainSVC.CreateDocument(ctx, &service.CreateDocumentRequest{
Documents: []*entity.Document{req},
})
if err != nil {
return nil, err
}
kCResponse := &model.CreateDocumentResponse{
FileURL: document.FileURL,
DocumentID: response.Documents[0].Info.ID,
FileName: response.Documents[0].Info.Name,
}
return kCResponse, nil
}
func (i *impl) Delete(ctx context.Context, r *model.DeleteDocumentRequest) (*model.DeleteDocumentResponse, error) {
docID, err := strconv.ParseInt(r.DocumentID, 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid document id: %s", r.DocumentID)
}
err = i.DomainSVC.DeleteDocument(ctx, &service.DeleteDocumentRequest{
DocumentID: docID,
})
if err != nil {
return &model.DeleteDocumentResponse{IsSuccess: false}, err
}
return &model.DeleteDocumentResponse{IsSuccess: true}, nil
}
func (i *impl) ListKnowledgeDetail(ctx context.Context, req *model.ListKnowledgeDetailRequest) (*model.ListKnowledgeDetailResponse, error) {
response, err := i.DomainSVC.MGetKnowledgeByID(ctx, &service.MGetKnowledgeByIDRequest{
KnowledgeIDs: req.KnowledgeIDs,
})
if err != nil {
return nil, err
}
resp := &model.ListKnowledgeDetailResponse{
KnowledgeDetails: slices.Transform(response.Knowledge, func(a *knowledge.Knowledge) *model.KnowledgeDetail {
return &model.KnowledgeDetail{
ID: a.ID,
Name: a.Name,
Description: a.Description,
IconURL: a.IconURL,
FormatType: int64(a.Type),
}
}),
}
return resp, nil
}
func toChunkType(typ model.ChunkType) (parser.ChunkType, error) {
switch typ {
case model.ChunkTypeDefault:
return parser.ChunkTypeDefault, nil
case model.ChunkTypeCustom:
return parser.ChunkTypeCustom, nil
case model.ChunkTypeLeveled:
return parser.ChunkTypeLeveled, nil
default:
return 0, fmt.Errorf("unknown chunk type: %v", typ)
}
}

View File

@ -20,7 +20,7 @@ import (
"context"
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossmessage"
crossmessage "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
message "github.com/coze-dev/coze-studio/backend/domain/conversation/message/service"
)

View File

@ -14,37 +14,38 @@
* limitations under the License.
*/
package model
package modelmgr
import (
"context"
"fmt"
model2 "github.com/cloudwego/eino/components/model"
eino "github.com/cloudwego/eino/components/model"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
crossmodelmgr "github.com/coze-dev/coze-studio/backend/crossdomain/contract/modelmgr"
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
chatmodel2 "github.com/coze-dev/coze-studio/backend/infra/impl/chatmodel"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
type ModelManager struct {
type modelManager struct {
modelMgr modelmgr.Manager
factory chatmodel.Factory
}
func NewModelManager(m modelmgr.Manager, f chatmodel.Factory) *ModelManager {
func InitDomainService(m modelmgr.Manager, f chatmodel.Factory) crossmodelmgr.Manager {
if f == nil {
f = chatmodel2.NewDefaultFactory()
}
return &ModelManager{
return &modelManager{
modelMgr: m,
factory: f,
}
}
func (m *ModelManager) GetModel(ctx context.Context, params *model.LLMParams) (model2.BaseChatModel, *modelmgr.Model, error) {
func (m *modelManager) GetModel(ctx context.Context, params *model.LLMParams) (eino.BaseChatModel, *modelmgr.Model, error) {
modelID := params.ModelType
models, err := m.modelMgr.MGetModelByID(ctx, &modelmgr.MGetModelRequest{
IDs: []int64{modelID},

View File

@ -18,23 +18,46 @@ package plugin
import (
"context"
"fmt"
"strconv"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
"github.com/getkin/kin-openapi/openapi3"
"golang.org/x/exp/maps"
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossplugin"
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
"github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common"
workflow3 "github.com/coze-dev/coze-studio/backend/api/model/workflow"
"github.com/coze-dev/coze-studio/backend/application/base/pluginutil"
crossplugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin"
"github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
"github.com/coze-dev/coze-studio/backend/domain/plugin/service"
plugin "github.com/coze-dev/coze-studio/backend/domain/plugin/service"
"github.com/coze-dev/coze-studio/backend/domain/workflow"
entity2 "github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
var defaultSVC crossplugin.PluginService
type impl struct {
DomainSVC plugin.PluginService
tos storage.Storage
}
func InitDomainService(c plugin.PluginService) crossplugin.PluginService {
func InitDomainService(c plugin.PluginService, tos storage.Storage) crossplugin.PluginService {
defaultSVC = &impl{
DomainSVC: c,
tos: tos,
}
return defaultSVC
@ -105,3 +128,479 @@ func (s *impl) GetAPPAllPlugins(ctx context.Context, appID int64) (plugins []*mo
return plugins, nil
}
type pluginInfo struct {
*entity.PluginInfo
LatestVersion *string
}
func (s *impl) getPluginsWithTools(ctx context.Context, pluginEntity *model.PluginEntity, toolIDs []int64, isDraft bool) (
_ *pluginInfo, toolsInfo []*entity.ToolInfo, err error) {
defer func() {
if err != nil {
err = vo.WrapIfNeeded(errno.ErrPluginAPIErr, err)
}
}()
var pluginsInfo []*entity.PluginInfo
var latestPluginInfo *entity.PluginInfo
pluginID := pluginEntity.PluginID
if isDraft {
plugins, err := s.DomainSVC.MGetDraftPlugins(ctx, []int64{pluginID})
if err != nil {
return nil, nil, err
}
pluginsInfo = plugins
} else if pluginEntity.PluginVersion == nil || (pluginEntity.PluginVersion != nil && *pluginEntity.PluginVersion == "") {
plugins, err := s.DomainSVC.MGetOnlinePlugins(ctx, []int64{pluginID})
if err != nil {
return nil, nil, err
}
pluginsInfo = plugins
} else {
plugins, err := s.DomainSVC.MGetVersionPlugins(ctx, []entity.VersionPlugin{
{PluginID: pluginID, Version: *pluginEntity.PluginVersion},
})
if err != nil {
return nil, nil, err
}
pluginsInfo = plugins
onlinePlugins, err := s.DomainSVC.MGetOnlinePlugins(ctx, []int64{pluginID})
if err != nil {
return nil, nil, err
}
for _, pi := range onlinePlugins {
if pi.ID == pluginID {
latestPluginInfo = pi
break
}
}
}
var pInfo *entity.PluginInfo
for _, p := range pluginsInfo {
if p.ID == pluginID {
pInfo = p
break
}
}
if pInfo == nil {
return nil, nil, vo.NewError(errno.ErrPluginIDNotFound, errorx.KV("id", strconv.FormatInt(pluginID, 10)))
}
if isDraft {
tools, err := s.DomainSVC.MGetDraftTools(ctx, toolIDs)
if err != nil {
return nil, nil, err
}
toolsInfo = tools
} else if pluginEntity.PluginVersion == nil || (pluginEntity.PluginVersion != nil && *pluginEntity.PluginVersion == "") {
tools, err := s.DomainSVC.MGetOnlineTools(ctx, toolIDs)
if err != nil {
return nil, nil, err
}
toolsInfo = tools
} else {
eVersionTools := slices.Transform(toolIDs, func(tid int64) entity.VersionTool {
return entity.VersionTool{
ToolID: tid,
Version: *pluginEntity.PluginVersion,
}
})
tools, err := s.DomainSVC.MGetVersionTools(ctx, eVersionTools)
if err != nil {
return nil, nil, err
}
toolsInfo = tools
}
if latestPluginInfo != nil {
return &pluginInfo{PluginInfo: pInfo, LatestVersion: latestPluginInfo.Version}, toolsInfo, nil
}
return &pluginInfo{PluginInfo: pInfo}, toolsInfo, nil
}
func (s *impl) GetPluginToolsInfo(ctx context.Context, req *model.ToolsInfoRequest) (
_ *model.ToolsInfoResponse, err error) {
defer func() {
if err != nil {
err = vo.WrapIfNeeded(errno.ErrPluginAPIErr, err)
}
}()
var toolsInfo []*entity.ToolInfo
isDraft := req.IsDraft || (req.PluginEntity.PluginVersion != nil && *req.PluginEntity.PluginVersion == "0")
pInfo, toolsInfo, err := s.getPluginsWithTools(ctx, &model.PluginEntity{PluginID: req.PluginEntity.PluginID, PluginVersion: req.PluginEntity.PluginVersion}, req.ToolIDs, isDraft)
if err != nil {
return nil, err
}
url, err := s.tos.GetObjectUrl(ctx, pInfo.GetIconURI())
if err != nil {
return nil, vo.WrapIfNeeded(errno.ErrTOSError, err)
}
response := &model.ToolsInfoResponse{
PluginID: pInfo.ID,
SpaceID: pInfo.SpaceID,
Version: pInfo.GetVersion(),
PluginName: pInfo.GetName(),
Description: pInfo.GetDesc(),
IconURL: url,
PluginType: int64(pInfo.PluginType),
ToolInfoList: make(map[int64]model.ToolInfoW),
LatestVersion: pInfo.LatestVersion,
IsOfficial: pInfo.IsOfficial(),
AppID: pInfo.GetAPPID(),
}
for _, tf := range toolsInfo {
inputs, err := tf.ToReqAPIParameter()
if err != nil {
return nil, err
}
outputs, err := tf.ToRespAPIParameter()
if err != nil {
return nil, err
}
toolExample := pInfo.GetToolExample(ctx, tf.GetName())
var (
requestExample string
responseExample string
)
if toolExample != nil {
requestExample = toolExample.RequestExample
responseExample = toolExample.ResponseExample
}
response.ToolInfoList[tf.ID] = model.ToolInfoW{
ToolID: tf.ID,
ToolName: tf.GetName(),
Inputs: slices.Transform(inputs, toWorkflowAPIParameter),
Outputs: slices.Transform(outputs, toWorkflowAPIParameter),
Description: tf.GetDesc(),
DebugExample: &model.DebugExample{
ReqExample: requestExample,
RespExample: responseExample,
},
}
}
return response, nil
}
func (s *impl) GetPluginInvokableTools(ctx context.Context, req *model.ToolsInvokableRequest) (
_ map[int64]crossplugin.InvokableTool, err error) {
defer func() {
if err != nil {
err = vo.WrapIfNeeded(errno.ErrPluginAPIErr, err)
}
}()
var toolsInfo []*entity.ToolInfo
isDraft := req.IsDraft || (req.PluginEntity.PluginVersion != nil && *req.PluginEntity.PluginVersion == "0")
pInfo, toolsInfo, err := s.getPluginsWithTools(ctx, &model.PluginEntity{
PluginID: req.PluginEntity.PluginID,
PluginVersion: req.PluginEntity.PluginVersion,
}, maps.Keys(req.ToolsInvokableInfo), isDraft)
if err != nil {
return nil, err
}
result := map[int64]crossplugin.InvokableTool{}
for _, tf := range toolsInfo {
tl := &pluginInvokeTool{
pluginEntity: model.PluginEntity{
PluginID: pInfo.ID,
PluginVersion: pInfo.Version,
},
client: s.DomainSVC,
toolInfo: tf,
IsDraft: isDraft,
}
if r, ok := req.ToolsInvokableInfo[tf.ID]; ok && (r.RequestAPIParametersConfig != nil && r.ResponseAPIParametersConfig != nil) {
reqPluginCommonAPIParameters := slices.Transform(r.RequestAPIParametersConfig, toPluginCommonAPIParameter)
respPluginCommonAPIParameters := slices.Transform(r.ResponseAPIParametersConfig, toPluginCommonAPIParameter)
tl.toolOperation, err = pluginutil.APIParamsToOpenapiOperation(reqPluginCommonAPIParameters, respPluginCommonAPIParameters)
if err != nil {
return nil, err
}
tl.toolOperation.OperationID = tf.Operation.OperationID
tl.toolOperation.Summary = tf.Operation.Summary
}
result[tf.ID] = tl
}
return result, nil
}
func (s *impl) ExecutePlugin(ctx context.Context, input map[string]any, pe *model.PluginEntity,
toolID int64, cfg workflowModel.ExecuteConfig) (map[string]any, error) {
args, err := sonic.MarshalString(input)
if err != nil {
return nil, vo.WrapError(errno.ErrSerializationDeserializationFail, err)
}
var uID string
if cfg.AgentID != nil {
uID = cfg.ConnectorUID
} else {
uID = conv.Int64ToStr(cfg.Operator)
}
req := &service.ExecuteToolRequest{
UserID: uID,
PluginID: pe.PluginID,
ToolID: toolID,
ExecScene: model.ExecSceneOfWorkflow,
ArgumentsInJson: args,
ExecDraftTool: pe.PluginVersion == nil || *pe.PluginVersion == "0",
}
execOpts := []entity.ExecuteToolOpt{
model.WithInvalidRespProcessStrategy(model.InvalidResponseProcessStrategyOfReturnDefault),
}
if pe.PluginVersion != nil {
execOpts = append(execOpts, model.WithToolVersion(*pe.PluginVersion))
}
r, err := s.DomainSVC.ExecuteTool(ctx, req, execOpts...)
if err != nil {
if extra, ok := compose.IsInterruptRerunError(err); ok {
pluginTIE, ok := extra.(*model.ToolInterruptEvent)
if !ok {
return nil, vo.WrapError(errno.ErrPluginAPIErr, fmt.Errorf("expects ToolInterruptEvent, got %T", extra))
}
var eventType workflow3.EventType
switch pluginTIE.Event {
case model.InterruptEventTypeOfToolNeedOAuth:
eventType = workflow3.EventType_WorkflowOauthPlugin
default:
return nil, vo.WrapError(errno.ErrPluginAPIErr,
fmt.Errorf("unsupported interrupt event type: %s", pluginTIE.Event))
}
id, err := workflow.GetRepository().GenID(ctx)
if err != nil {
return nil, vo.WrapError(errno.ErrIDGenError, err)
}
ie := &entity2.InterruptEvent{
ID: id,
InterruptData: pluginTIE.ToolNeedOAuth.Message,
EventType: eventType,
}
// temporarily replace interrupt with real error, until frontend can handle plugin oauth interrupt
interruptData := ie.InterruptData
return nil, vo.NewError(errno.ErrAuthorizationRequired, errorx.KV("extra", interruptData))
}
return nil, err
}
var output map[string]any
err = sonic.UnmarshalString(r.TrimmedResp, &output)
if err != nil {
return nil, vo.WrapError(errno.ErrSerializationDeserializationFail, err)
}
return output, nil
}
type pluginInvokeTool struct {
pluginEntity model.PluginEntity
client service.PluginService
toolInfo *entity.ToolInfo
toolOperation *openapi3.Operation
IsDraft bool
}
func (p *pluginInvokeTool) Info(ctx context.Context) (_ *schema.ToolInfo, err error) {
defer func() {
if err != nil {
err = vo.WrapIfNeeded(errno.ErrPluginAPIErr, err)
}
}()
var parameterInfo map[string]*schema.ParameterInfo
if p.toolOperation != nil {
parameterInfo, err = model.NewOpenapi3Operation(p.toolOperation).ToEinoSchemaParameterInfo(ctx)
} else {
parameterInfo, err = p.toolInfo.Operation.ToEinoSchemaParameterInfo(ctx)
}
if err != nil {
return nil, err
}
return &schema.ToolInfo{
Name: p.toolInfo.GetName(),
Desc: p.toolInfo.GetDesc(),
ParamsOneOf: schema.NewParamsOneOfByParams(parameterInfo),
}, nil
}
func (p *pluginInvokeTool) PluginInvoke(ctx context.Context, argumentsInJSON string, cfg workflowModel.ExecuteConfig) (string, error) {
req := &service.ExecuteToolRequest{
UserID: conv.Int64ToStr(cfg.Operator),
PluginID: p.pluginEntity.PluginID,
ToolID: p.toolInfo.ID,
ExecScene: model.ExecSceneOfWorkflow,
ArgumentsInJson: argumentsInJSON,
ExecDraftTool: p.IsDraft,
}
execOpts := []entity.ExecuteToolOpt{
model.WithInvalidRespProcessStrategy(model.InvalidResponseProcessStrategyOfReturnDefault),
}
if p.pluginEntity.PluginVersion != nil {
execOpts = append(execOpts, model.WithToolVersion(*p.pluginEntity.PluginVersion))
}
if p.toolOperation != nil {
execOpts = append(execOpts, model.WithOpenapiOperation(model.NewOpenapi3Operation(p.toolOperation)))
}
r, err := p.client.ExecuteTool(ctx, req, execOpts...)
if err != nil {
if extra, ok := compose.IsInterruptRerunError(err); ok {
pluginTIE, ok := extra.(*model.ToolInterruptEvent)
if !ok {
return "", vo.WrapError(errno.ErrPluginAPIErr, fmt.Errorf("expects ToolInterruptEvent, got %T", extra))
}
var eventType workflow3.EventType
switch pluginTIE.Event {
case model.InterruptEventTypeOfToolNeedOAuth:
eventType = workflow3.EventType_WorkflowOauthPlugin
default:
return "", vo.WrapError(errno.ErrPluginAPIErr,
fmt.Errorf("unsupported interrupt event type: %s", pluginTIE.Event))
}
id, err := workflow.GetRepository().GenID(ctx)
if err != nil {
return "", vo.WrapError(errno.ErrIDGenError, err)
}
ie := &entity2.InterruptEvent{
ID: id,
InterruptData: pluginTIE.ToolNeedOAuth.Message,
EventType: eventType,
}
tie := &entity2.ToolInterruptEvent{
ToolCallID: compose.GetToolCallID(ctx),
ToolName: p.toolInfo.GetName(),
InterruptEvent: ie,
}
// temporarily replace interrupt with real error, until frontend can handle plugin oauth interrupt
_ = tie
interruptData := ie.InterruptData
return "", vo.NewError(errno.ErrAuthorizationRequired, errorx.KV("extra", interruptData))
}
return "", err
}
return r.TrimmedResp, nil
}
func toPluginCommonAPIParameter(parameter *workflow3.APIParameter) *common.APIParameter {
if parameter == nil {
return nil
}
p := &common.APIParameter{
ID: parameter.ID,
Name: parameter.Name,
Desc: parameter.Desc,
Type: common.ParameterType(parameter.Type),
Location: common.ParameterLocation(parameter.Location),
IsRequired: parameter.IsRequired,
GlobalDefault: parameter.GlobalDefault,
GlobalDisable: parameter.GlobalDisable,
LocalDefault: parameter.LocalDefault,
LocalDisable: parameter.LocalDisable,
VariableRef: parameter.VariableRef,
}
if parameter.SubType != nil {
p.SubType = ptr.Of(common.ParameterType(*parameter.SubType))
}
if parameter.DefaultParamSource != nil {
p.DefaultParamSource = ptr.Of(common.DefaultParamSource(*parameter.DefaultParamSource))
}
if parameter.AssistType != nil {
p.AssistType = ptr.Of(common.AssistParameterType(*parameter.AssistType))
}
if len(parameter.SubParameters) > 0 {
p.SubParameters = make([]*common.APIParameter, 0, len(parameter.SubParameters))
for _, subParam := range parameter.SubParameters {
p.SubParameters = append(p.SubParameters, toPluginCommonAPIParameter(subParam))
}
}
return p
}
func toWorkflowAPIParameter(parameter *common.APIParameter) *workflow3.APIParameter {
if parameter == nil {
return nil
}
p := &workflow3.APIParameter{
ID: parameter.ID,
Name: parameter.Name,
Desc: parameter.Desc,
Type: workflow3.ParameterType(parameter.Type),
Location: workflow3.ParameterLocation(parameter.Location),
IsRequired: parameter.IsRequired,
GlobalDefault: parameter.GlobalDefault,
GlobalDisable: parameter.GlobalDisable,
LocalDefault: parameter.LocalDefault,
LocalDisable: parameter.LocalDisable,
VariableRef: parameter.VariableRef,
}
if parameter.SubType != nil {
p.SubType = ptr.Of(workflow3.ParameterType(*parameter.SubType))
}
if parameter.DefaultParamSource != nil {
p.DefaultParamSource = ptr.Of(workflow3.DefaultParamSource(*parameter.DefaultParamSource))
}
if parameter.AssistType != nil {
p.AssistType = ptr.Of(workflow3.AssistParameterType(*parameter.AssistType))
}
// Check if it's an array that needs unwrapping.
if parameter.Type == common.ParameterType_Array && len(parameter.SubParameters) == 1 && parameter.SubParameters[0].Name == "[Array Item]" {
arrayItem := parameter.SubParameters[0]
p.SubType = ptr.Of(workflow3.ParameterType(arrayItem.Type))
// If the "[Array Item]" is an object, its sub-parameters become the array's sub-parameters.
if arrayItem.Type == common.ParameterType_Object {
p.SubParameters = make([]*workflow3.APIParameter, 0, len(arrayItem.SubParameters))
for _, subParam := range arrayItem.SubParameters {
p.SubParameters = append(p.SubParameters, toWorkflowAPIParameter(subParam))
}
} else {
// The array's SubType is the Type of the "[Array Item]".
p.SubParameters = make([]*workflow3.APIParameter, 0, 1)
p.SubParameters = append(p.SubParameters, toWorkflowAPIParameter(arrayItem))
p.SubParameters[0].Name = "" // Remove the "[Array Item]" name.
}
} else if len(parameter.SubParameters) > 0 { // A simple object or a non-wrapped array.
p.SubParameters = make([]*workflow3.APIParameter, 0, len(parameter.SubParameters))
for _, subParam := range parameter.SubParameters {
p.SubParameters = append(p.SubParameters, toWorkflowAPIParameter(subParam))
}
}
return p
}

View File

@ -20,7 +20,7 @@ import (
"context"
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/search"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crosssearch"
crosssearch "github.com/coze-dev/coze-studio/backend/crossdomain/contract/search"
"github.com/coze-dev/coze-studio/backend/domain/search/service"
)

View File

@ -25,7 +25,7 @@ import (
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossagent"
crossagent "github.com/coze-dev/coze-studio/backend/crossdomain/contract/agent"
singleagent "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/service"
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"

View File

@ -21,8 +21,8 @@ import (
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/variables"
"github.com/coze-dev/coze-studio/backend/api/model/data/variable/kvmemory"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossvariables"
"github.com/coze-dev/coze-studio/backend/api/model/data/variable/project_memory"
crossvariables "github.com/coze-dev/coze-studio/backend/crossdomain/contract/variables"
"github.com/coze-dev/coze-studio/backend/domain/memory/variables/entity"
variables "github.com/coze-dev/coze-studio/backend/domain/memory/variables/service"
)
@ -65,3 +65,16 @@ func (s *impl) DecryptSysUUIDKey(ctx context.Context, encryptSysUUIDKey string)
ConnectorID: m.ConnectorID,
}
}
func (s *impl) GetVariableChannelInstance(ctx context.Context, e *model.UserVariableMeta, keywords []string, varChannel *project_memory.VariableChannel) ([]*kvmemory.KVItem, error) {
m := entity.NewUserVariableMeta(e)
return s.DomainSVC.GetVariableChannelInstance(ctx, m, keywords, varChannel)
}
func (s *impl) GetProjectVariablesMeta(ctx context.Context, projectID string, version string) (*entity.VariablesMeta, error) {
return s.DomainSVC.GetProjectVariablesMeta(ctx, projectID, version)
}
func (s *impl) GetAgentVariableMeta(ctx context.Context, agentID int64, version string) (*entity.VariablesMeta, error) {
return s.DomainSVC.GetAgentVariableMeta(ctx, agentID, version)
}

View File

@ -23,7 +23,8 @@ import (
einoCompose "github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossworkflow"
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
crossworkflow "github.com/coze-dev/coze-studio/backend/crossdomain/contract/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
workflowEntity "github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
@ -50,16 +51,6 @@ func (i *impl) WorkflowAsModelTool(ctx context.Context, policies []*vo.GetPolicy
return i.DomainSVC.WorkflowAsModelTool(ctx, policies)
}
func (i *impl) PublishWorkflow(ctx context.Context, info *vo.PublishPolicy) (err error) {
return i.DomainSVC.Publish(ctx, info)
}
func (i *impl) DeleteWorkflow(ctx context.Context, id int64) error {
return i.DomainSVC.Delete(ctx, &vo.DeletePolicy{
ID: ptr.Of(id),
})
}
func (i *impl) ReleaseApplicationWorkflows(ctx context.Context, appID int64, config *vo.ReleaseWorkflowConfig) ([]*vo.ValidateIssue, error) {
return i.DomainSVC.ReleaseApplicationWorkflows(ctx, appID, config)
}
@ -67,15 +58,15 @@ func (i *impl) ReleaseApplicationWorkflows(ctx context.Context, appID int64, con
func (i *impl) WithResumeToolWorkflow(resumingEvent *workflowEntity.ToolInterruptEvent, resumeData string, allInterruptEvents map[string]*workflowEntity.ToolInterruptEvent) einoCompose.Option {
return i.DomainSVC.WithResumeToolWorkflow(resumingEvent, resumeData, allInterruptEvents)
}
func (i *impl) SyncExecuteWorkflow(ctx context.Context, config vo.ExecuteConfig, input map[string]any) (*workflowEntity.WorkflowExecution, vo.TerminatePlan, error) {
func (i *impl) SyncExecuteWorkflow(ctx context.Context, config workflowModel.ExecuteConfig, input map[string]any) (*workflowEntity.WorkflowExecution, vo.TerminatePlan, error) {
return i.DomainSVC.SyncExecute(ctx, config, input)
}
func (i *impl) WithExecuteConfig(cfg vo.ExecuteConfig) einoCompose.Option {
func (i *impl) WithExecuteConfig(cfg workflowModel.ExecuteConfig) einoCompose.Option {
return i.DomainSVC.WithExecuteConfig(cfg)
}
func (i *impl) WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message]) {
func (i *impl) WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message], func()) {
return i.DomainSVC.WithMessagePipe()
}

View File

@ -1,447 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package database
import (
"context"
"fmt"
"strings"
"github.com/spf13/cast"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/api/model/data/database/table"
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
"github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
nodedatabase "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
)
type DatabaseRepository struct {
client service.Database
}
func NewDatabaseRepository(client service.Database) *DatabaseRepository {
return &DatabaseRepository{
client: client,
}
}
func (d *DatabaseRepository) Execute(ctx context.Context, request *nodedatabase.CustomSQLRequest) (*nodedatabase.Response, error) {
var (
err error
databaseInfoID = request.DatabaseInfoID
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
)
if request.IsDebugRun {
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
if err != nil {
return nil, err
}
}
req := &service.ExecuteSQLRequest{
DatabaseID: databaseInfoID,
OperateType: database.OperateType_Custom,
SQL: &request.SQL,
TableType: tableType,
UserID: request.UserID,
ConnectorID: ptr.Of(request.ConnectorID),
}
req.SQLParams = make([]*database.SQLParamVal, 0, len(request.Params))
for i := range request.Params {
param := request.Params[i]
req.SQLParams = append(req.SQLParams, &database.SQLParamVal{
ValueType: table.FieldItemType_Text,
Value: &param.Value,
ISNull: param.IsNull,
})
}
response, err := d.client.ExecuteSQL(ctx, req)
if err != nil {
return nil, err
}
// if rows affected is nil use 0 instead
if response.RowsAffected == nil {
response.RowsAffected = ptr.Of(int64(0))
}
return toNodeDateBaseResponse(response), nil
}
func (d *DatabaseRepository) Delete(ctx context.Context, request *nodedatabase.DeleteRequest) (*nodedatabase.Response, error) {
var (
err error
databaseInfoID = request.DatabaseInfoID
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
)
if request.IsDebugRun {
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
if err != nil {
return nil, err
}
}
req := &service.ExecuteSQLRequest{
DatabaseID: databaseInfoID,
OperateType: database.OperateType_Delete,
TableType: tableType,
UserID: request.UserID,
ConnectorID: ptr.Of(request.ConnectorID),
}
if request.ConditionGroup != nil {
req.Condition, req.SQLParams, err = buildComplexCondition(request.ConditionGroup)
if err != nil {
return nil, err
}
}
response, err := d.client.ExecuteSQL(ctx, req)
if err != nil {
return nil, err
}
return toNodeDateBaseResponse(response), nil
}
func (d *DatabaseRepository) Query(ctx context.Context, request *nodedatabase.QueryRequest) (*nodedatabase.Response, error) {
var (
err error
databaseInfoID = request.DatabaseInfoID
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
)
if request.IsDebugRun {
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
if err != nil {
return nil, err
}
}
req := &service.ExecuteSQLRequest{
DatabaseID: databaseInfoID,
OperateType: database.OperateType_Select,
TableType: tableType,
UserID: request.UserID,
ConnectorID: ptr.Of(request.ConnectorID),
}
req.SelectFieldList = &database.SelectFieldList{FieldID: make([]string, 0, len(request.SelectFields))}
for i := range request.SelectFields {
req.SelectFieldList.FieldID = append(req.SelectFieldList.FieldID, request.SelectFields[i])
}
req.OrderByList = make([]database.OrderBy, 0)
for i := range request.OrderClauses {
clause := request.OrderClauses[i]
req.OrderByList = append(req.OrderByList, database.OrderBy{
Field: clause.FieldID,
Direction: toOrderDirection(clause.IsAsc),
})
}
if request.ConditionGroup != nil {
req.Condition, req.SQLParams, err = buildComplexCondition(request.ConditionGroup)
if err != nil {
return nil, err
}
}
limit := request.Limit
req.Limit = &limit
response, err := d.client.ExecuteSQL(ctx, req)
if err != nil {
return nil, err
}
return toNodeDateBaseResponse(response), nil
}
func (d *DatabaseRepository) Update(ctx context.Context, request *nodedatabase.UpdateRequest) (*nodedatabase.Response, error) {
var (
err error
condition *database.ComplexCondition
params []*database.SQLParamVal
databaseInfoID = request.DatabaseInfoID
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
)
if request.IsDebugRun {
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
if err != nil {
return nil, err
}
}
req := &service.ExecuteSQLRequest{
DatabaseID: databaseInfoID,
OperateType: database.OperateType_Update,
SQLParams: make([]*database.SQLParamVal, 0),
TableType: tableType,
UserID: request.UserID,
ConnectorID: ptr.Of(request.ConnectorID),
}
uid := ctxutil.GetUIDFromCtx(ctx)
if uid != nil {
req.UserID = conv.Int64ToStr(*uid)
}
req.UpsertRows, req.SQLParams, err = resolveUpsertRow(request.Fields)
if err != nil {
return nil, err
}
if request.ConditionGroup != nil {
condition, params, err = buildComplexCondition(request.ConditionGroup)
if err != nil {
return nil, err
}
req.Condition = condition
req.SQLParams = append(req.SQLParams, params...)
}
response, err := d.client.ExecuteSQL(ctx, req)
if err != nil {
return nil, err
}
return toNodeDateBaseResponse(response), nil
}
func (d *DatabaseRepository) Insert(ctx context.Context, request *nodedatabase.InsertRequest) (*nodedatabase.Response, error) {
var (
err error
databaseInfoID = request.DatabaseInfoID
tableType = ternary.IFElse[table.TableType](request.IsDebugRun, table.TableType_DraftTable, table.TableType_OnlineTable)
)
if request.IsDebugRun {
databaseInfoID, err = d.getDraftTableID(ctx, databaseInfoID)
if err != nil {
return nil, err
}
}
req := &service.ExecuteSQLRequest{
DatabaseID: databaseInfoID,
OperateType: database.OperateType_Insert,
TableType: tableType,
UserID: request.UserID,
ConnectorID: ptr.Of(request.ConnectorID),
}
req.UpsertRows, req.SQLParams, err = resolveUpsertRow(request.Fields)
if err != nil {
return nil, err
}
response, err := d.client.ExecuteSQL(ctx, req)
if err != nil {
return nil, err
}
return toNodeDateBaseResponse(response), nil
}
func (d *DatabaseRepository) getDraftTableID(ctx context.Context, onlineID int64) (int64, error) {
resp, err := d.client.GetDraftDatabaseByOnlineID(ctx, &service.GetDraftDatabaseByOnlineIDRequest{OnlineID: onlineID})
if err != nil {
return 0, err
}
return resp.Database.ID, nil
}
func buildComplexCondition(conditionGroup *nodedatabase.ConditionGroup) (*database.ComplexCondition, []*database.SQLParamVal, error) {
condition := &database.ComplexCondition{}
logic, err := toLogic(conditionGroup.Relation)
if err != nil {
return nil, nil, err
}
condition.Logic = logic
params := make([]*database.SQLParamVal, 0)
for i := range conditionGroup.Conditions {
var (
nCond = conditionGroup.Conditions[i]
vals []*database.SQLParamVal
dCond = &database.Condition{
Left: nCond.Left,
}
)
opt, err := toOperation(nCond.Operator)
if err != nil {
return nil, nil, err
}
dCond.Operation = opt
if isNullOrNotNull(opt) {
condition.Conditions = append(condition.Conditions, dCond)
continue
}
dCond.Right, vals, err = resolveRightValue(opt, nCond.Right)
if err != nil {
return nil, nil, err
}
condition.Conditions = append(condition.Conditions, dCond)
params = append(params, vals...)
}
return condition, params, nil
}
func toMapStringAny(m map[string]string) map[string]any {
ret := make(map[string]any, len(m))
for k, v := range m {
ret[k] = v
}
return ret
}
func toOperation(operator nodedatabase.Operator) (database.Operation, error) {
switch operator {
case nodedatabase.OperatorEqual:
return database.Operation_EQUAL, nil
case nodedatabase.OperatorNotEqual:
return database.Operation_NOT_EQUAL, nil
case nodedatabase.OperatorGreater:
return database.Operation_GREATER_THAN, nil
case nodedatabase.OperatorGreaterOrEqual:
return database.Operation_GREATER_EQUAL, nil
case nodedatabase.OperatorLesser:
return database.Operation_LESS_THAN, nil
case nodedatabase.OperatorLesserOrEqual:
return database.Operation_LESS_EQUAL, nil
case nodedatabase.OperatorIn:
return database.Operation_IN, nil
case nodedatabase.OperatorNotIn:
return database.Operation_NOT_IN, nil
case nodedatabase.OperatorIsNotNull:
return database.Operation_IS_NOT_NULL, nil
case nodedatabase.OperatorIsNull:
return database.Operation_IS_NULL, nil
case nodedatabase.OperatorLike:
return database.Operation_LIKE, nil
case nodedatabase.OperatorNotLike:
return database.Operation_NOT_LIKE, nil
default:
return database.Operation(0), fmt.Errorf("invalid operator %v", operator)
}
}
func resolveRightValue(operator database.Operation, right any) (string, []*database.SQLParamVal, error) {
if isInOrNotIn(operator) {
var (
vals = make([]*database.SQLParamVal, 0)
anyVals = make([]any, 0)
commas = make([]string, 0, len(anyVals))
)
anyVals = right.([]any)
for i := range anyVals {
v := cast.ToString(anyVals[i])
vals = append(vals, &database.SQLParamVal{ValueType: table.FieldItemType_Text, Value: &v})
commas = append(commas, "?")
}
value := "(" + strings.Join(commas, ",") + ")"
return value, vals, nil
}
rightValue, err := cast.ToStringE(right)
if err != nil {
return "", nil, err
}
if isLikeOrNotLike(operator) {
var (
value = "?"
v = "%s" + rightValue + "%s"
)
return value, []*database.SQLParamVal{{ValueType: table.FieldItemType_Text, Value: &v}}, nil
}
return "?", []*database.SQLParamVal{{ValueType: table.FieldItemType_Text, Value: &rightValue}}, nil
}
func resolveUpsertRow(fields map[string]any) ([]*database.UpsertRow, []*database.SQLParamVal, error) {
upsertRow := &database.UpsertRow{Records: make([]*database.Record, 0, len(fields))}
params := make([]*database.SQLParamVal, 0)
for key, value := range fields {
val, err := cast.ToStringE(value)
if err != nil {
return nil, nil, err
}
record := &database.Record{
FieldId: key,
FieldValue: "?",
}
upsertRow.Records = append(upsertRow.Records, record)
params = append(params, &database.SQLParamVal{
ValueType: table.FieldItemType_Text,
Value: &val,
})
}
return []*database.UpsertRow{upsertRow}, params, nil
}
func isNullOrNotNull(opt database.Operation) bool {
return opt == database.Operation_IS_NOT_NULL || opt == database.Operation_IS_NULL
}
func isLikeOrNotLike(opt database.Operation) bool {
return opt == database.Operation_LIKE || opt == database.Operation_NOT_LIKE
}
func isInOrNotIn(opt database.Operation) bool {
return opt == database.Operation_IN || opt == database.Operation_NOT_IN
}
func toOrderDirection(isAsc bool) table.SortDirection {
if isAsc {
return table.SortDirection_ASC
}
return table.SortDirection_Desc
}
func toLogic(relation nodedatabase.ClauseRelation) (database.Logic, error) {
switch relation {
case nodedatabase.ClauseRelationOR:
return database.Logic_Or, nil
case nodedatabase.ClauseRelationAND:
return database.Logic_And, nil
default:
return database.Logic(0), fmt.Errorf("invalid relation %v", relation)
}
}
func toNodeDateBaseResponse(response *service.ExecuteSQLResponse) *nodedatabase.Response {
objects := make([]nodedatabase.Object, 0, len(response.Records))
for i := range response.Records {
objects = append(objects, response.Records[i])
}
return &nodedatabase.Response{
Objects: objects,
RowNumber: response.RowsAffected,
}
}

View File

@ -1,224 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package database
import (
"context"
"testing"
"github.com/spf13/cast"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
nodedatabase "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
mockDatabase "github.com/coze-dev/coze-studio/backend/internal/mock/domain/memory/database"
)
func mockExecuteSQL(t *testing.T) func(ctx context.Context, request *service.ExecuteSQLRequest) (*service.ExecuteSQLResponse, error) {
return func(ctx context.Context, request *service.ExecuteSQLRequest) (*service.ExecuteSQLResponse, error) {
if request.OperateType == database.OperateType_Custom {
assert.Equal(t, *request.SQL, "select * from table where v1=? and v2=?")
rs := make([]string, 0)
for idx := range request.SQLParams {
rs = append(rs, *request.SQLParams[idx].Value)
}
assert.Equal(t, rs, []string{"1", "2"})
return &service.ExecuteSQLResponse{
Records: []map[string]any{
{"v1": "1", "v2": "2"},
},
}, nil
}
if request.OperateType == database.OperateType_Select {
sFields := []string{"v1", "v2", "v3", "v4"}
assert.Equal(t, request.SelectFieldList.FieldID, sFields)
cond := request.Condition.Conditions[1] // in
assert.Equal(t, "(?,?)", cond.Right)
assert.Equal(t, database.Operation_IN, cond.Operation)
assert.Equal(t, "v2_1", *request.SQLParams[1].Value)
assert.Equal(t, "v2_2", *request.SQLParams[2].Value)
assert.Equal(t, "%sv4%s", *request.SQLParams[3].Value)
rowsAffected := int64(10)
return &service.ExecuteSQLResponse{
Records: []map[string]any{
{"v1": "1", "v2": "2", "v3": "3", "v4": "4"},
},
RowsAffected: &rowsAffected,
}, nil
}
if request.OperateType == database.OperateType_Delete {
cond := request.Condition.Conditions[1] // in
assert.Equal(t, "(?,?)", cond.Right)
assert.Equal(t, database.Operation_NOT_IN, cond.Operation)
assert.Equal(t, "v2_1", *request.SQLParams[1].Value)
assert.Equal(t, "v2_2", *request.SQLParams[2].Value)
assert.Equal(t, "%sv4%s", *request.SQLParams[3].Value)
rowsAffected := int64(10)
return &service.ExecuteSQLResponse{
Records: []map[string]any{
{"v1": "1", "v2": "2", "v3": "3", "v4": "4"},
},
RowsAffected: &rowsAffected,
}, nil
}
if request.OperateType == database.OperateType_Insert {
records := request.UpsertRows[0].Records
ret := map[string]interface{}{
"v1": "1",
"v2": 2,
"v3": 33,
"v4": "44aacc",
}
for idx := range records {
assert.Equal(t, *request.SQLParams[idx].Value, cast.ToString(ret[records[idx].FieldId]))
}
}
if request.OperateType == database.OperateType_Update {
records := request.UpsertRows[0].Records
ret := map[string]interface{}{
"v1": "1",
"v2": 2,
"v3": 33,
"v4": "aabbcc",
}
for idx := range records {
assert.Equal(t, *request.SQLParams[idx].Value, cast.ToString(ret[records[idx].FieldId]))
}
request.SQLParams = request.SQLParams[len(records):]
cond := request.Condition.Conditions[1] // in
assert.Equal(t, "(?,?)", cond.Right)
assert.Equal(t, database.Operation_IN, cond.Operation)
assert.Equal(t, "v2_1", *request.SQLParams[1].Value)
assert.Equal(t, "v2_2", *request.SQLParams[2].Value)
assert.Equal(t, "%sv4%s", *request.SQLParams[3].Value)
}
return &service.ExecuteSQLResponse{}, nil
}
}
func TestDatabase_Database(t *testing.T) {
ctrl := gomock.NewController(t)
mockClient := mockDatabase.NewMockDatabase(ctrl)
defer ctrl.Finish()
ds := DatabaseRepository{
client: mockClient,
}
mockClient.EXPECT().ExecuteSQL(gomock.Any(), gomock.Any()).DoAndReturn(mockExecuteSQL(t)).AnyTimes()
t.Run("execute", func(t *testing.T) {
response, err := ds.Execute(context.Background(), &nodedatabase.CustomSQLRequest{
DatabaseInfoID: 1,
SQL: "select * from table where v1=? and v2=?",
Params: []nodedatabase.SQLParam{
nodedatabase.SQLParam{Value: "1"},
nodedatabase.SQLParam{Value: "2"},
},
})
assert.Nil(t, err)
assert.Equal(t, response.Objects, []nodedatabase.Object{
{"v1": "1", "v2": "2"},
})
})
t.Run("select", func(t *testing.T) {
req := &nodedatabase.QueryRequest{
DatabaseInfoID: 1,
SelectFields: []string{"v1", "v2", "v3", "v4"},
Limit: 10,
OrderClauses: []*nodedatabase.OrderClause{
{FieldID: "v1", IsAsc: true},
{FieldID: "v2", IsAsc: false},
},
ConditionGroup: &nodedatabase.ConditionGroup{
Conditions: []*nodedatabase.Condition{
{Left: "v1", Operator: nodedatabase.OperatorEqual, Right: "1"},
{Left: "v2", Operator: nodedatabase.OperatorIn, Right: []any{"v2_1", "v2_2"}},
{Left: "v3", Operator: nodedatabase.OperatorIsNull},
{Left: "v4", Operator: nodedatabase.OperatorLike, Right: "v4"},
},
Relation: nodedatabase.ClauseRelationOR,
},
}
response, err := ds.Query(context.Background(), req)
assert.Nil(t, err)
assert.Equal(t, *response.RowNumber, int64(10))
})
t.Run("delete", func(t *testing.T) {
req := &nodedatabase.DeleteRequest{
DatabaseInfoID: 1,
ConditionGroup: &nodedatabase.ConditionGroup{
Conditions: []*nodedatabase.Condition{
{Left: "v1", Operator: nodedatabase.OperatorEqual, Right: "1"},
{Left: "v2", Operator: nodedatabase.OperatorNotIn, Right: []any{"v2_1", "v2_2"}},
{Left: "v3", Operator: nodedatabase.OperatorIsNotNull},
{Left: "v4", Operator: nodedatabase.OperatorNotLike, Right: "v4"},
},
Relation: nodedatabase.ClauseRelationOR,
},
}
response, err := ds.Delete(context.Background(), req)
assert.Nil(t, err)
assert.Equal(t, *response.RowNumber, int64(10))
})
t.Run("insert", func(t *testing.T) {
req := &nodedatabase.InsertRequest{
DatabaseInfoID: 1,
Fields: map[string]interface{}{
"v1": "1",
"v2": 2,
"v3": 33,
"v4": "44aacc",
},
}
_, err := ds.Insert(context.Background(), req)
assert.Nil(t, err)
})
t.Run("update", func(t *testing.T) {
req := &nodedatabase.UpdateRequest{
DatabaseInfoID: 1,
ConditionGroup: &nodedatabase.ConditionGroup{
Conditions: []*nodedatabase.Condition{
{Left: "v1", Operator: nodedatabase.OperatorEqual, Right: "1"},
{Left: "v2", Operator: nodedatabase.OperatorIn, Right: []any{"v2_1", "v2_2"}},
{Left: "v3", Operator: nodedatabase.OperatorIsNull},
{Left: "v4", Operator: nodedatabase.OperatorLike, Right: "v4"},
},
Relation: nodedatabase.ClauseRelationOR,
},
Fields: map[string]interface{}{
"v1": "1",
"v2": 2,
"v3": 33,
"v4": "aabbcc",
},
}
_, err := ds.Update(context.Background(), req)
assert.Nil(t, err)
})
}

View File

@ -1,217 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package knowledge
import (
"context"
"errors"
"fmt"
"strconv"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
"github.com/coze-dev/coze-studio/backend/domain/knowledge/entity"
domainknowledge "github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
crossknowledge "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
)
type Knowledge struct {
client domainknowledge.Knowledge
idGen idgen.IDGenerator
}
func NewKnowledgeRepository(client domainknowledge.Knowledge, idGen idgen.IDGenerator) *Knowledge {
return &Knowledge{
client: client,
idGen: idGen,
}
}
func (k *Knowledge) Store(ctx context.Context, document *crossknowledge.CreateDocumentRequest) (*crossknowledge.CreateDocumentResponse, error) {
var (
ps *entity.ParsingStrategy
cs = &entity.ChunkingStrategy{}
)
if document.ParsingStrategy == nil {
return nil, errors.New("document parsing strategy is required")
}
if document.ChunkingStrategy == nil {
return nil, errors.New("document chunking strategy is required")
}
if document.ParsingStrategy.ParseMode == crossknowledge.AccurateParseMode {
ps = &entity.ParsingStrategy{}
ps.ExtractImage = document.ParsingStrategy.ExtractImage
ps.ExtractTable = document.ParsingStrategy.ExtractTable
ps.ImageOCR = document.ParsingStrategy.ImageOCR
}
chunkType, err := toChunkType(document.ChunkingStrategy.ChunkType)
if err != nil {
return nil, err
}
cs.ChunkType = chunkType
cs.Separator = document.ChunkingStrategy.Separator
cs.ChunkSize = document.ChunkingStrategy.ChunkSize
cs.Overlap = document.ChunkingStrategy.Overlap
req := &entity.Document{
Info: knowledge.Info{
Name: document.FileName,
},
KnowledgeID: document.KnowledgeID,
Type: knowledge.DocumentTypeText,
URL: document.FileURL,
Source: entity.DocumentSourceLocal,
ParsingStrategy: ps,
ChunkingStrategy: cs,
FileExtension: document.FileExtension,
}
uid := ctxutil.GetUIDFromCtx(ctx)
if uid != nil {
req.Info.CreatorID = *uid
}
response, err := k.client.CreateDocument(ctx, &domainknowledge.CreateDocumentRequest{
Documents: []*entity.Document{req},
})
if err != nil {
return nil, err
}
kCResponse := &crossknowledge.CreateDocumentResponse{
FileURL: document.FileURL,
DocumentID: response.Documents[0].Info.ID,
FileName: response.Documents[0].Info.Name,
}
return kCResponse, nil
}
func (k *Knowledge) Retrieve(ctx context.Context, r *crossknowledge.RetrieveRequest) (*crossknowledge.RetrieveResponse, error) {
rs := &entity.RetrievalStrategy{}
if r.RetrievalStrategy != nil {
rs.TopK = r.RetrievalStrategy.TopK
rs.MinScore = r.RetrievalStrategy.MinScore
searchType, err := toSearchType(r.RetrievalStrategy.SearchType)
if err != nil {
return nil, err
}
rs.SearchType = searchType
rs.EnableQueryRewrite = r.RetrievalStrategy.EnableQueryRewrite
rs.EnableRerank = r.RetrievalStrategy.EnableRerank
rs.EnableNL2SQL = r.RetrievalStrategy.EnableNL2SQL
}
req := &domainknowledge.RetrieveRequest{
Query: r.Query,
KnowledgeIDs: r.KnowledgeIDs,
Strategy: rs,
}
response, err := k.client.Retrieve(ctx, req)
if err != nil {
return nil, err
}
ss := make([]*crossknowledge.Slice, 0, len(response.RetrieveSlices))
for _, s := range response.RetrieveSlices {
if s.Slice == nil {
continue
}
ss = append(ss, &crossknowledge.Slice{
DocumentID: strconv.FormatInt(s.Slice.DocumentID, 10),
Output: s.Slice.GetSliceContent(),
})
}
return &crossknowledge.RetrieveResponse{
Slices: ss,
}, nil
}
func (k *Knowledge) Delete(ctx context.Context, r *crossknowledge.DeleteDocumentRequest) (*crossknowledge.DeleteDocumentResponse, error) {
docID, err := strconv.ParseInt(r.DocumentID, 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid document id: %s", r.DocumentID)
}
err = k.client.DeleteDocument(ctx, &domainknowledge.DeleteDocumentRequest{
DocumentID: docID,
})
if err != nil {
return &crossknowledge.DeleteDocumentResponse{IsSuccess: false}, err
}
return &crossknowledge.DeleteDocumentResponse{IsSuccess: true}, nil
}
func (k *Knowledge) ListKnowledgeDetail(ctx context.Context, req *crossknowledge.ListKnowledgeDetailRequest) (*crossknowledge.ListKnowledgeDetailResponse, error) {
response, err := k.client.MGetKnowledgeByID(ctx, &domainknowledge.MGetKnowledgeByIDRequest{
KnowledgeIDs: req.KnowledgeIDs,
})
if err != nil {
return nil, err
}
resp := &crossknowledge.ListKnowledgeDetailResponse{
KnowledgeDetails: slices.Transform(response.Knowledge, func(a *knowledge.Knowledge) *crossknowledge.KnowledgeDetail {
return &crossknowledge.KnowledgeDetail{
ID: a.ID,
Name: a.Name,
Description: a.Description,
IconURL: a.IconURL,
FormatType: int64(a.Type),
}
}),
}
return resp, nil
}
func toSearchType(typ crossknowledge.SearchType) (knowledge.SearchType, error) {
switch typ {
case crossknowledge.SearchTypeSemantic:
return knowledge.SearchTypeSemantic, nil
case crossknowledge.SearchTypeFullText:
return knowledge.SearchTypeFullText, nil
case crossknowledge.SearchTypeHybrid:
return knowledge.SearchTypeHybrid, nil
default:
return 0, fmt.Errorf("unknown search type: %v", typ)
}
}
func toChunkType(typ crossknowledge.ChunkType) (parser.ChunkType, error) {
switch typ {
case crossknowledge.ChunkTypeDefault:
return parser.ChunkTypeDefault, nil
case crossknowledge.ChunkTypeCustom:
return parser.ChunkTypeCustom, nil
case crossknowledge.ChunkTypeLeveled:
return parser.ChunkTypeLeveled, nil
default:
return 0, fmt.Errorf("unknown chunk type: %v", typ)
}
}

View File

@ -1,533 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package plugin
import (
"context"
"fmt"
"strconv"
"github.com/getkin/kin-openapi/openapi3"
"golang.org/x/exp/maps"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common"
workflow3 "github.com/coze-dev/coze-studio/backend/api/model/workflow"
"github.com/coze-dev/coze-studio/backend/application/base/pluginutil"
"github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
"github.com/coze-dev/coze-studio/backend/domain/plugin/service"
"github.com/coze-dev/coze-studio/backend/domain/workflow"
crossplugin "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
entity2 "github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
type pluginService struct {
client service.PluginService
tos storage.Storage
}
func NewPluginService(client service.PluginService, tos storage.Storage) crossplugin.Service {
return &pluginService{client: client, tos: tos}
}
type pluginInfo struct {
*entity.PluginInfo
LatestVersion *string
}
func (t *pluginService) getPluginsWithTools(ctx context.Context, pluginEntity *crossplugin.Entity, toolIDs []int64, isDraft bool) (
_ *pluginInfo, toolsInfo []*entity.ToolInfo, err error) {
defer func() {
if err != nil {
err = vo.WrapIfNeeded(errno.ErrPluginAPIErr, err)
}
}()
var pluginsInfo []*entity.PluginInfo
var latestPluginInfo *entity.PluginInfo
pluginID := pluginEntity.PluginID
if isDraft {
plugins, err := t.client.MGetDraftPlugins(ctx, []int64{pluginID})
if err != nil {
return nil, nil, err
}
pluginsInfo = plugins
} else if pluginEntity.PluginVersion == nil || (pluginEntity.PluginVersion != nil && *pluginEntity.PluginVersion == "") {
plugins, err := t.client.MGetOnlinePlugins(ctx, []int64{pluginID})
if err != nil {
return nil, nil, err
}
pluginsInfo = plugins
} else {
plugins, err := t.client.MGetVersionPlugins(ctx, []entity.VersionPlugin{
{PluginID: pluginID, Version: *pluginEntity.PluginVersion},
})
if err != nil {
return nil, nil, err
}
pluginsInfo = plugins
onlinePlugins, err := t.client.MGetOnlinePlugins(ctx, []int64{pluginID})
if err != nil {
return nil, nil, err
}
for _, pi := range onlinePlugins {
if pi.ID == pluginID {
latestPluginInfo = pi
break
}
}
}
var pInfo *entity.PluginInfo
for _, p := range pluginsInfo {
if p.ID == pluginID {
pInfo = p
break
}
}
if pInfo == nil {
return nil, nil, vo.NewError(errno.ErrPluginIDNotFound, errorx.KV("id", strconv.FormatInt(pluginID, 10)))
}
if isDraft {
tools, err := t.client.MGetDraftTools(ctx, toolIDs)
if err != nil {
return nil, nil, err
}
toolsInfo = tools
} else if pluginEntity.PluginVersion == nil || (pluginEntity.PluginVersion != nil && *pluginEntity.PluginVersion == "") {
tools, err := t.client.MGetOnlineTools(ctx, toolIDs)
if err != nil {
return nil, nil, err
}
toolsInfo = tools
} else {
eVersionTools := slices.Transform(toolIDs, func(tid int64) entity.VersionTool {
return entity.VersionTool{
ToolID: tid,
Version: *pluginEntity.PluginVersion,
}
})
tools, err := t.client.MGetVersionTools(ctx, eVersionTools)
if err != nil {
return nil, nil, err
}
toolsInfo = tools
}
if latestPluginInfo != nil {
return &pluginInfo{PluginInfo: pInfo, LatestVersion: latestPluginInfo.Version}, toolsInfo, nil
}
return &pluginInfo{PluginInfo: pInfo}, toolsInfo, nil
}
func (t *pluginService) GetPluginToolsInfo(ctx context.Context, req *crossplugin.ToolsInfoRequest) (
_ *crossplugin.ToolsInfoResponse, err error) {
defer func() {
if err != nil {
err = vo.WrapIfNeeded(errno.ErrPluginAPIErr, err)
}
}()
var toolsInfo []*entity.ToolInfo
isDraft := req.IsDraft || (req.PluginEntity.PluginVersion != nil && *req.PluginEntity.PluginVersion == "0")
pInfo, toolsInfo, err := t.getPluginsWithTools(ctx, &crossplugin.Entity{PluginID: req.PluginEntity.PluginID, PluginVersion: req.PluginEntity.PluginVersion}, req.ToolIDs, isDraft)
if err != nil {
return nil, err
}
url, err := t.tos.GetObjectUrl(ctx, pInfo.GetIconURI())
if err != nil {
return nil, vo.WrapIfNeeded(errno.ErrTOSError, err)
}
response := &crossplugin.ToolsInfoResponse{
PluginID: pInfo.ID,
SpaceID: pInfo.SpaceID,
Version: pInfo.GetVersion(),
PluginName: pInfo.GetName(),
Description: pInfo.GetDesc(),
IconURL: url,
PluginType: int64(pInfo.PluginType),
ToolInfoList: make(map[int64]crossplugin.ToolInfo),
LatestVersion: pInfo.LatestVersion,
IsOfficial: pInfo.IsOfficial(),
AppID: pInfo.GetAPPID(),
}
for _, tf := range toolsInfo {
inputs, err := tf.ToReqAPIParameter()
if err != nil {
return nil, err
}
outputs, err := tf.ToRespAPIParameter()
if err != nil {
return nil, err
}
toolExample := pInfo.GetToolExample(ctx, tf.GetName())
var (
requestExample string
responseExample string
)
if toolExample != nil {
requestExample = toolExample.RequestExample
responseExample = toolExample.ResponseExample
}
response.ToolInfoList[tf.ID] = crossplugin.ToolInfo{
ToolID: tf.ID,
ToolName: tf.GetName(),
Inputs: slices.Transform(inputs, toWorkflowAPIParameter),
Outputs: slices.Transform(outputs, toWorkflowAPIParameter),
Description: tf.GetDesc(),
DebugExample: &crossplugin.DebugExample{
ReqExample: requestExample,
RespExample: responseExample,
},
}
}
return response, nil
}
func (t *pluginService) GetPluginInvokableTools(ctx context.Context, req *crossplugin.ToolsInvokableRequest) (
_ map[int64]crossplugin.InvokableTool, err error) {
defer func() {
if err != nil {
err = vo.WrapIfNeeded(errno.ErrPluginAPIErr, err)
}
}()
var toolsInfo []*entity.ToolInfo
isDraft := req.IsDraft || (req.PluginEntity.PluginVersion != nil && *req.PluginEntity.PluginVersion == "0")
pInfo, toolsInfo, err := t.getPluginsWithTools(ctx, &crossplugin.Entity{
PluginID: req.PluginEntity.PluginID,
PluginVersion: req.PluginEntity.PluginVersion,
}, maps.Keys(req.ToolsInvokableInfo), isDraft)
if err != nil {
return nil, err
}
result := map[int64]crossplugin.InvokableTool{}
for _, tf := range toolsInfo {
tl := &pluginInvokeTool{
pluginEntity: crossplugin.Entity{
PluginID: pInfo.ID,
PluginVersion: pInfo.Version,
},
client: t.client,
toolInfo: tf,
IsDraft: isDraft,
}
if r, ok := req.ToolsInvokableInfo[tf.ID]; ok && (r.RequestAPIParametersConfig != nil && r.ResponseAPIParametersConfig != nil) {
reqPluginCommonAPIParameters := slices.Transform(r.RequestAPIParametersConfig, toPluginCommonAPIParameter)
respPluginCommonAPIParameters := slices.Transform(r.ResponseAPIParametersConfig, toPluginCommonAPIParameter)
tl.toolOperation, err = pluginutil.APIParamsToOpenapiOperation(reqPluginCommonAPIParameters, respPluginCommonAPIParameters)
if err != nil {
return nil, err
}
tl.toolOperation.OperationID = tf.Operation.OperationID
tl.toolOperation.Summary = tf.Operation.Summary
}
result[tf.ID] = tl
}
return result, nil
}
func (t *pluginService) ExecutePlugin(ctx context.Context, input map[string]any, pe *crossplugin.Entity,
toolID int64, cfg crossplugin.ExecConfig) (map[string]any, error) {
args, err := sonic.MarshalString(input)
if err != nil {
return nil, vo.WrapError(errno.ErrSerializationDeserializationFail, err)
}
var uID string
if cfg.AgentID != nil {
uID = cfg.ConnectorUID
} else {
uID = conv.Int64ToStr(cfg.Operator)
}
req := &service.ExecuteToolRequest{
UserID: uID,
PluginID: pe.PluginID,
ToolID: toolID,
ExecScene: plugin.ExecSceneOfWorkflow,
ArgumentsInJson: args,
ExecDraftTool: pe.PluginVersion == nil || *pe.PluginVersion == "0",
}
execOpts := []entity.ExecuteToolOpt{
plugin.WithInvalidRespProcessStrategy(plugin.InvalidResponseProcessStrategyOfReturnDefault),
}
if pe.PluginVersion != nil {
execOpts = append(execOpts, plugin.WithToolVersion(*pe.PluginVersion))
}
r, err := t.client.ExecuteTool(ctx, req, execOpts...)
if err != nil {
if extra, ok := compose.IsInterruptRerunError(err); ok {
pluginTIE, ok := extra.(*plugin.ToolInterruptEvent)
if !ok {
return nil, vo.WrapError(errno.ErrPluginAPIErr, fmt.Errorf("expects ToolInterruptEvent, got %T", extra))
}
var eventType workflow3.EventType
switch pluginTIE.Event {
case plugin.InterruptEventTypeOfToolNeedOAuth:
eventType = workflow3.EventType_WorkflowOauthPlugin
default:
return nil, vo.WrapError(errno.ErrPluginAPIErr,
fmt.Errorf("unsupported interrupt event type: %s", pluginTIE.Event))
}
id, err := workflow.GetRepository().GenID(ctx)
if err != nil {
return nil, vo.WrapError(errno.ErrIDGenError, err)
}
ie := &entity2.InterruptEvent{
ID: id,
InterruptData: pluginTIE.ToolNeedOAuth.Message,
EventType: eventType,
}
// temporarily replace interrupt with real error, until frontend can handle plugin oauth interrupt
interruptData := ie.InterruptData
return nil, vo.NewError(errno.ErrAuthorizationRequired, errorx.KV("extra", interruptData))
}
return nil, err
}
var output map[string]any
err = sonic.UnmarshalString(r.TrimmedResp, &output)
if err != nil {
return nil, vo.WrapError(errno.ErrSerializationDeserializationFail, err)
}
return output, nil
}
type pluginInvokeTool struct {
pluginEntity crossplugin.Entity
client service.PluginService
toolInfo *entity.ToolInfo
toolOperation *openapi3.Operation
IsDraft bool
}
func (p *pluginInvokeTool) Info(ctx context.Context) (_ *schema.ToolInfo, err error) {
defer func() {
if err != nil {
err = vo.WrapIfNeeded(errno.ErrPluginAPIErr, err)
}
}()
var parameterInfo map[string]*schema.ParameterInfo
if p.toolOperation != nil {
parameterInfo, err = plugin.NewOpenapi3Operation(p.toolOperation).ToEinoSchemaParameterInfo(ctx)
} else {
parameterInfo, err = p.toolInfo.Operation.ToEinoSchemaParameterInfo(ctx)
}
if err != nil {
return nil, err
}
return &schema.ToolInfo{
Name: p.toolInfo.GetName(),
Desc: p.toolInfo.GetDesc(),
ParamsOneOf: schema.NewParamsOneOfByParams(parameterInfo),
}, nil
}
func (p *pluginInvokeTool) PluginInvoke(ctx context.Context, argumentsInJSON string, cfg crossplugin.ExecConfig) (string, error) {
req := &service.ExecuteToolRequest{
UserID: conv.Int64ToStr(cfg.Operator),
PluginID: p.pluginEntity.PluginID,
ToolID: p.toolInfo.ID,
ExecScene: plugin.ExecSceneOfWorkflow,
ArgumentsInJson: argumentsInJSON,
ExecDraftTool: p.IsDraft,
}
execOpts := []entity.ExecuteToolOpt{
plugin.WithInvalidRespProcessStrategy(plugin.InvalidResponseProcessStrategyOfReturnDefault),
}
if p.pluginEntity.PluginVersion != nil {
execOpts = append(execOpts, plugin.WithToolVersion(*p.pluginEntity.PluginVersion))
}
if p.toolOperation != nil {
execOpts = append(execOpts, plugin.WithOpenapiOperation(plugin.NewOpenapi3Operation(p.toolOperation)))
}
r, err := p.client.ExecuteTool(ctx, req, execOpts...)
if err != nil {
if extra, ok := compose.IsInterruptRerunError(err); ok {
pluginTIE, ok := extra.(*plugin.ToolInterruptEvent)
if !ok {
return "", vo.WrapError(errno.ErrPluginAPIErr, fmt.Errorf("expects ToolInterruptEvent, got %T", extra))
}
var eventType workflow3.EventType
switch pluginTIE.Event {
case plugin.InterruptEventTypeOfToolNeedOAuth:
eventType = workflow3.EventType_WorkflowOauthPlugin
default:
return "", vo.WrapError(errno.ErrPluginAPIErr,
fmt.Errorf("unsupported interrupt event type: %s", pluginTIE.Event))
}
id, err := workflow.GetRepository().GenID(ctx)
if err != nil {
return "", vo.WrapError(errno.ErrIDGenError, err)
}
ie := &entity2.InterruptEvent{
ID: id,
InterruptData: pluginTIE.ToolNeedOAuth.Message,
EventType: eventType,
}
tie := &entity2.ToolInterruptEvent{
ToolCallID: compose.GetToolCallID(ctx),
ToolName: p.toolInfo.GetName(),
InterruptEvent: ie,
}
// temporarily replace interrupt with real error, until frontend can handle plugin oauth interrupt
_ = tie
interruptData := ie.InterruptData
return "", vo.NewError(errno.ErrAuthorizationRequired, errorx.KV("extra", interruptData))
}
return "", err
}
return r.TrimmedResp, nil
}
func toPluginCommonAPIParameter(parameter *workflow3.APIParameter) *common.APIParameter {
if parameter == nil {
return nil
}
p := &common.APIParameter{
ID: parameter.ID,
Name: parameter.Name,
Desc: parameter.Desc,
Type: common.ParameterType(parameter.Type),
Location: common.ParameterLocation(parameter.Location),
IsRequired: parameter.IsRequired,
GlobalDefault: parameter.GlobalDefault,
GlobalDisable: parameter.GlobalDisable,
LocalDefault: parameter.LocalDefault,
LocalDisable: parameter.LocalDisable,
VariableRef: parameter.VariableRef,
}
if parameter.SubType != nil {
p.SubType = ptr.Of(common.ParameterType(*parameter.SubType))
}
if parameter.DefaultParamSource != nil {
p.DefaultParamSource = ptr.Of(common.DefaultParamSource(*parameter.DefaultParamSource))
}
if parameter.AssistType != nil {
p.AssistType = ptr.Of(common.AssistParameterType(*parameter.AssistType))
}
if len(parameter.SubParameters) > 0 {
p.SubParameters = make([]*common.APIParameter, 0, len(parameter.SubParameters))
for _, subParam := range parameter.SubParameters {
p.SubParameters = append(p.SubParameters, toPluginCommonAPIParameter(subParam))
}
}
return p
}
func toWorkflowAPIParameter(parameter *common.APIParameter) *workflow3.APIParameter {
if parameter == nil {
return nil
}
p := &workflow3.APIParameter{
ID: parameter.ID,
Name: parameter.Name,
Desc: parameter.Desc,
Type: workflow3.ParameterType(parameter.Type),
Location: workflow3.ParameterLocation(parameter.Location),
IsRequired: parameter.IsRequired,
GlobalDefault: parameter.GlobalDefault,
GlobalDisable: parameter.GlobalDisable,
LocalDefault: parameter.LocalDefault,
LocalDisable: parameter.LocalDisable,
VariableRef: parameter.VariableRef,
}
if parameter.SubType != nil {
p.SubType = ptr.Of(workflow3.ParameterType(*parameter.SubType))
}
if parameter.DefaultParamSource != nil {
p.DefaultParamSource = ptr.Of(workflow3.DefaultParamSource(*parameter.DefaultParamSource))
}
if parameter.AssistType != nil {
p.AssistType = ptr.Of(workflow3.AssistParameterType(*parameter.AssistType))
}
// Check if it's an array that needs unwrapping.
if parameter.Type == common.ParameterType_Array && len(parameter.SubParameters) == 1 && parameter.SubParameters[0].Name == "[Array Item]" {
arrayItem := parameter.SubParameters[0]
p.SubType = ptr.Of(workflow3.ParameterType(arrayItem.Type))
// If the "[Array Item]" is an object, its sub-parameters become the array's sub-parameters.
if arrayItem.Type == common.ParameterType_Object {
p.SubParameters = make([]*workflow3.APIParameter, 0, len(arrayItem.SubParameters))
for _, subParam := range arrayItem.SubParameters {
p.SubParameters = append(p.SubParameters, toWorkflowAPIParameter(subParam))
}
} else {
// The array's SubType is the Type of the "[Array Item]".
p.SubParameters = make([]*workflow3.APIParameter, 0, 1)
p.SubParameters = append(p.SubParameters, toWorkflowAPIParameter(arrayItem))
p.SubParameters[0].Name = "" // Remove the "[Array Item]" name.
}
} else if len(parameter.SubParameters) > 0 { // A simple object or a non-wrapped array.
p.SubParameters = make([]*workflow3.APIParameter, 0, len(parameter.SubParameters))
for _, subParam := range parameter.SubParameters {
p.SubParameters = append(p.SubParameters, toWorkflowAPIParameter(subParam))
}
}
return p
}

View File

@ -1,189 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package plugin
import (
"testing"
"github.com/stretchr/testify/assert"
common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common"
workflow3 "github.com/coze-dev/coze-studio/backend/api/model/workflow"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
func TestToWorkflowAPIParameter(t *testing.T) {
cases := []struct {
name string
param *common.APIParameter
expected *workflow3.APIParameter
}{
{
name: "nil parameter",
param: nil,
expected: nil,
},
{
name: "simple string parameter",
param: &common.APIParameter{
Name: "prompt",
Type: common.ParameterType_String,
Desc: "User's prompt",
},
expected: &workflow3.APIParameter{
Name: "prompt",
Type: workflow3.ParameterType_String,
Desc: "User's prompt",
},
},
{
name: "simple object parameter",
param: &common.APIParameter{
Name: "user_info",
Type: common.ParameterType_Object,
SubParameters: []*common.APIParameter{
{
Name: "name",
Type: common.ParameterType_String,
},
{
Name: "age",
Type: common.ParameterType_Number,
},
},
},
expected: &workflow3.APIParameter{
Name: "user_info",
Type: workflow3.ParameterType_Object,
SubParameters: []*workflow3.APIParameter{
{
Name: "name",
Type: workflow3.ParameterType_String,
},
{
Name: "age",
Type: workflow3.ParameterType_Number,
},
},
},
},
{
name: "array of strings",
param: &common.APIParameter{
Name: "tags",
Type: common.ParameterType_Array,
SubParameters: []*common.APIParameter{
{
Name: "[Array Item]",
Type: common.ParameterType_String,
},
},
},
expected: &workflow3.APIParameter{
Name: "tags",
Type: workflow3.ParameterType_Array,
SubType: ptr.Of(workflow3.ParameterType_String),
SubParameters: []*workflow3.APIParameter{
{
Type: workflow3.ParameterType_String,
},
},
},
},
{
name: "array of objects",
param: &common.APIParameter{
Name: "users",
Type: common.ParameterType_Array,
SubParameters: []*common.APIParameter{
{
Name: "[Array Item]",
Type: common.ParameterType_Object,
SubParameters: []*common.APIParameter{
{
Name: "name",
Type: common.ParameterType_String,
},
{
Name: "id",
Type: common.ParameterType_Number,
},
},
},
},
},
expected: &workflow3.APIParameter{
Name: "users",
Type: workflow3.ParameterType_Array,
SubType: ptr.Of(workflow3.ParameterType_Object),
SubParameters: []*workflow3.APIParameter{
{
Name: "name",
Type: workflow3.ParameterType_String,
},
{
Name: "id",
Type: workflow3.ParameterType_Number,
},
},
},
},
{
name: "array of array of strings",
param: &common.APIParameter{
Name: "matrix",
Type: common.ParameterType_Array,
SubParameters: []*common.APIParameter{
{
Name: "[Array Item]",
Type: common.ParameterType_Array,
SubParameters: []*common.APIParameter{
{
Name: "[Array Item]",
Type: common.ParameterType_String,
},
},
},
},
},
expected: &workflow3.APIParameter{
Name: "matrix",
Type: workflow3.ParameterType_Array,
SubType: ptr.Of(workflow3.ParameterType_Array),
SubParameters: []*workflow3.APIParameter{
{
Name: "", // Name is cleared
Type: workflow3.ParameterType_Array,
SubType: ptr.Of(workflow3.ParameterType_String),
SubParameters: []*workflow3.APIParameter{
{
Type: workflow3.ParameterType_String,
},
},
},
},
},
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
actual := toWorkflowAPIParameter(tc.param)
assert.Equal(t, tc.expected, actual)
})
}
}

View File

@ -1,74 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package search
import (
"context"
"github.com/coze-dev/coze-studio/backend/api/model/resource/common"
"github.com/coze-dev/coze-studio/backend/domain/search/entity"
search "github.com/coze-dev/coze-studio/backend/domain/search/service"
crosssearch "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/search"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
type Notifier interface {
PublishWorkflowResource(ctx context.Context, OpType crosssearch.OpType, event *crosssearch.Resource) error
}
type Notify struct {
client search.ResourceEventBus
}
func NewNotify(client search.ResourceEventBus) *Notify {
return &Notify{client: client}
}
func (n *Notify) PublishWorkflowResource(ctx context.Context, op crosssearch.OpType, r *crosssearch.Resource) error {
entityResource := &entity.ResourceDocument{
ResType: common.ResType_Workflow,
ResID: r.WorkflowID,
ResSubType: r.Mode,
Name: r.Name,
SpaceID: r.SpaceID,
OwnerID: r.OwnerID,
APPID: r.APPID,
}
if r.PublishStatus != nil {
publishStatus := *r.PublishStatus
entityResource.PublishStatus = ptr.Of(common.PublishStatus(publishStatus))
entityResource.PublishTimeMS = r.PublishedAt
}
resource := &entity.ResourceDomainEvent{
OpType: entity.OpType(op),
Resource: entityResource,
}
if op == crosssearch.Created {
resource.Resource.CreateTimeMS = r.CreatedAt
resource.Resource.UpdateTimeMS = r.UpdatedAt
} else if op == crosssearch.Updated {
resource.Resource.UpdateTimeMS = r.UpdatedAt
}
err := n.client.PublishResources(ctx, resource)
if err != nil {
return err
}
return nil
}

View File

@ -29,7 +29,7 @@ import (
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossworkflow"
crossworkflow "github.com/coze-dev/coze-studio/backend/crossdomain/contract/workflow"
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
@ -74,6 +74,7 @@ func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) (
var composeOpts []compose.Option
var pipeMsgOpt compose.Option
var workflowMsgSr *schema.StreamReader[*crossworkflow.WorkflowMessage]
var workflowMsgCloser func()
if r.containWfTool {
cfReq := crossworkflow.ExecuteConfig{
AgentID: &req.Identity.AgentID,
@ -88,7 +89,7 @@ func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) (
}
wfConfig := crossworkflow.DefaultSVC().WithExecuteConfig(cfReq)
composeOpts = append(composeOpts, wfConfig)
pipeMsgOpt, workflowMsgSr = crossworkflow.DefaultSVC().WithMessagePipe()
pipeMsgOpt, workflowMsgSr, workflowMsgCloser = crossworkflow.DefaultSVC().WithMessagePipe()
composeOpts = append(composeOpts, pipeMsgOpt)
}
@ -120,6 +121,9 @@ func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) (
sw.Send(nil, errors.New("internal server error"))
}
if workflowMsgCloser != nil {
workflowMsgCloser()
}
sw.Close()
}()
_, _ = r.runner.Stream(ctx, req, composeOpts...)
@ -136,6 +140,7 @@ func (r *AgentRunner) processWfMidAnswerStream(_ context.Context, sw *schema.Str
if swT != nil {
swT.Close()
}
wfStream.Close()
}()
for {
msg, err := wfStream.Recv()

View File

@ -31,7 +31,7 @@ import (
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossworkflow"
crossworkflow "github.com/coze-dev/coze-studio/backend/crossdomain/contract/workflow"
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/logs"

View File

@ -26,7 +26,7 @@ import (
"github.com/coze-dev/coze-studio/backend/api/model/app/bot_common"
knowledgeModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossknowledge"
crossknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/contract/knowledge"
knowledgeEntity "github.com/coze-dev/coze-studio/backend/domain/knowledge/entity"
"github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"

View File

@ -56,6 +56,7 @@ func newSuggestGraph(_ context.Context, conf *Config, chatModel chatmodel.ToolCa
}
suggestPrompt := prompt.FromMessages(schema.Jinja2,
schema.SystemMessage(SUGGESTION_PROMPT_JINJA2),
schema.UserMessage("Based on the contextual information, provide three recommended questions"),
)
suggestGraph := compose.NewGraph[[]*schema.Message, *schema.Message]()

View File

@ -33,7 +33,7 @@ import (
"github.com/coze-dev/coze-studio/backend/api/model/app/bot_common"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/api/model/data/database/table"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossdatabase"
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
"github.com/coze-dev/coze-studio/backend/domain/memory/database/service"
"github.com/coze-dev/coze-studio/backend/infra/impl/sqlparser"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"

View File

@ -28,7 +28,7 @@ import (
"github.com/getkin/kin-openapi/openapi3"
"github.com/coze-dev/coze-studio/backend/api/model/app/bot_common"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossknowledge"
crossknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/contract/knowledge"
knowledgeEntity "github.com/coze-dev/coze-studio/backend/domain/knowledge/entity"
)

View File

@ -27,7 +27,7 @@ import (
"github.com/coze-dev/coze-studio/backend/api/model/app/bot_common"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossplugin"
crossplugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin"
pluginEntity "github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
"github.com/coze-dev/coze-studio/backend/domain/plugin/service"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"

View File

@ -26,11 +26,11 @@ import (
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossplugin"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossworkflow"
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
crossplugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin"
crossworkflow "github.com/coze-dev/coze-studio/backend/crossdomain/contract/workflow"
pluginEntity "github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
"github.com/coze-dev/coze-studio/backend/domain/plugin/service"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
)
@ -89,7 +89,7 @@ func (pr *toolPreCallConf) toolPreRetrieve(ctx context.Context, ar *AgentRequest
logs.CtxErrorf(ctx, "Failed to unmarshal json arguments: %s", item.Arguments)
return nil, err
}
execResp, _, err := crossworkflow.DefaultSVC().SyncExecuteWorkflow(ctx, vo.ExecuteConfig{
execResp, _, err := crossworkflow.DefaultSVC().SyncExecuteWorkflow(ctx, workflowModel.ExecuteConfig{
ID: item.PluginID,
ConnectorID: ar.Identity.ConnectorID,
ConnectorUID: ar.UserID,

View File

@ -25,7 +25,7 @@ import (
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/variables"
"github.com/coze-dev/coze-studio/backend/api/model/data/variable/kvmemory"
"github.com/coze-dev/coze-studio/backend/api/model/data/variable/project_memory"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossvariables"
crossvariables "github.com/coze-dev/coze-studio/backend/crossdomain/contract/variables"
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/logs"

View File

@ -20,7 +20,8 @@ import (
"context"
"github.com/coze-dev/coze-studio/backend/api/model/app/bot_common"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossworkflow"
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
crossworkflow "github.com/coze-dev/coze-studio/backend/crossdomain/contract/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
)
@ -36,7 +37,7 @@ func newWorkflowTools(ctx context.Context, conf *workflowConfig) ([]workflow.Too
id := info.GetWorkflowId()
policies = append(policies, &vo.GetPolicy{
ID: id,
QType: vo.FromLatestVersion,
QType: workflowModel.FromLatestVersion,
})
}

View File

@ -102,6 +102,18 @@ func (sa *SingleAgentDraftDAO) MGet(ctx context.Context, agentIDs []int64) ([]*e
return dos, nil
}
func (sa *SingleAgentDraftDAO) Save(ctx context.Context, agentInfo *entity.SingleAgent) (err error) {
po := sa.singleAgentDraftDo2Po(agentInfo)
singleAgentDAOModel := sa.dbQuery.SingleAgentDraft
err = singleAgentDAOModel.WithContext(ctx).Where(singleAgentDAOModel.AgentID.Eq(agentInfo.AgentID)).Save(po)
if err != nil {
return errorx.WrapByCode(err, errno.ErrAgentUpdateCode)
}
return nil
}
func (sa *SingleAgentDraftDAO) Update(ctx context.Context, agentInfo *entity.SingleAgent) (err error) {
po := sa.singleAgentDraftDo2Po(agentInfo)
singleAgentDAOModel := sa.dbQuery.SingleAgentDraft

View File

@ -46,6 +46,7 @@ type SingleAgentDraftRepo interface {
MGet(ctx context.Context, agentIDs []int64) ([]*entity.SingleAgent, error)
Delete(ctx context.Context, spaceID, agentID int64) (err error)
Update(ctx context.Context, agentInfo *entity.SingleAgent) (err error)
Save(ctx context.Context, agentInfo *entity.SingleAgent) (err error)
GetDisplayInfo(ctx context.Context, userID, agentID int64) (*entity.AgentDraftDisplayInfo, error)
UpdateDisplayInfo(ctx context.Context, userID int64, e *entity.AgentDraftDisplayInfo) error

View File

@ -22,7 +22,7 @@ import (
"time"
"github.com/coze-dev/coze-studio/backend/api/model/app/developer_api"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossconnector"
crossconnector "github.com/coze-dev/coze-studio/backend/crossdomain/contract/connector"
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/logs"

View File

@ -27,7 +27,7 @@ import (
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/api/model/app/bot_common"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossplugin"
crossplugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin"
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/internal/agentflow"
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/repository"
@ -153,7 +153,7 @@ func (s *singleAgentImpl) UpdateSingleAgentDraft(ctx context.Context, agentInfo
}
}
return s.AgentDraftRepo.Update(ctx, agentInfo)
return s.AgentDraftRepo.Save(ctx, agentInfo)
}
func (s *singleAgentImpl) CreateSingleAgentDraftWithID(ctx context.Context, creatorID, agentID int64, draft *entity.SingleAgent) (int64, error) {

View File

@ -22,8 +22,8 @@ import (
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
resourceCommon "github.com/coze-dev/coze-studio/backend/api/model/resource/common"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossplugin"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossworkflow"
crossplugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin"
crossworkflow "github.com/coze-dev/coze-studio/backend/crossdomain/contract/workflow"
"github.com/coze-dev/coze-studio/backend/domain/app/entity"
"github.com/coze-dev/coze-studio/backend/domain/app/repository"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"

View File

@ -25,10 +25,10 @@ import (
connectorModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/connector"
databaseModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
knowledgeModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossconnector"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossdatabase"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossknowledge"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossplugin"
crossconnector "github.com/coze-dev/coze-studio/backend/crossdomain/contract/connector"
crossdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/contract/database"
crossknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/contract/knowledge"
crossplugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin"
"github.com/coze-dev/coze-studio/backend/domain/app/entity"
"github.com/coze-dev/coze-studio/backend/domain/app/repository"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"

View File

@ -37,8 +37,8 @@ import (
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossagent"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossmessage"
crossagent "github.com/coze-dev/coze-studio/backend/crossdomain/contract/agent"
crossmessage "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/internal"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/internal/dal/model"

View File

@ -27,8 +27,15 @@ type WhereSliceOpt struct {
DocumentID int64
DocumentIDs []int64
Keyword *string
Sequence int64
PageSize int64
Offset int64
NotEmpty *bool
}
type WherePhotoSliceOpt struct {
KnowledgeID int64
DocumentIDs []int64
Limit *int
Offset *int
HasCaption *bool
}

View File

@ -25,6 +25,7 @@ import (
"golang.org/x/sync/errgroup"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"github.com/coze-dev/coze-studio/backend/domain/knowledge/entity"
"github.com/coze-dev/coze-studio/backend/domain/knowledge/internal/dal/model"
@ -52,7 +53,7 @@ func (dao *KnowledgeDocumentSliceDAO) Update(ctx context.Context, slice *model.K
}
func (dao *KnowledgeDocumentSliceDAO) BatchCreate(ctx context.Context, slices []*model.KnowledgeDocumentSlice) error {
return dao.Query.KnowledgeDocumentSlice.WithContext(ctx).CreateInBatches(slices, 100)
return dao.Query.KnowledgeDocumentSlice.WithContext(ctx).Clauses(clause.OnConflict{UpdateAll: true}).CreateInBatches(slices, 100)
}
func (dao *KnowledgeDocumentSliceDAO) BatchSetStatus(ctx context.Context, ids []int64, status int32, reason string) error {
@ -236,8 +237,11 @@ func (dao *KnowledgeDocumentSliceDAO) FindSliceByCondition(ctx context.Context,
if opts.PageSize != 0 {
do = do.Limit(int(opts.PageSize))
do = do.Offset(int(opts.Sequence)).Order(s.Sequence.Asc())
}
if opts.Offset != 0 {
do = do.Offset(int(opts.Offset))
}
do = do.Order(s.Sequence.Asc())
if opts.NotEmpty != nil {
if ptr.From(opts.NotEmpty) {
do = do.Where(s.Content.Neq(""))
@ -319,3 +323,44 @@ func (dao *KnowledgeDocumentSliceDAO) GetLastSequence(ctx context.Context, docum
}
return resp.Sequence, nil
}
func (dao *KnowledgeDocumentSliceDAO) ListPhotoSlice(ctx context.Context, opts *entity.WherePhotoSliceOpt) ([]*model.KnowledgeDocumentSlice, int64, error) {
s := dao.Query.KnowledgeDocumentSlice
do := s.WithContext(ctx)
if opts.KnowledgeID != 0 {
do = do.Where(s.KnowledgeID.Eq(opts.KnowledgeID))
}
if len(opts.DocumentIDs) != 0 {
do = do.Where(s.DocumentID.In(opts.DocumentIDs...))
}
if ptr.From(opts.Limit) != 0 {
do = do.Limit(int(ptr.From(opts.Limit)))
}
if ptr.From(opts.Offset) != 0 {
do = do.Offset(int(ptr.From(opts.Offset)))
}
if opts.HasCaption != nil {
if ptr.From(opts.HasCaption) {
do = do.Where(s.Content.Neq(""))
} else {
do = do.Where(s.Content.Eq(""))
}
}
do = do.Order(s.UpdatedAt.Desc())
pos, err := do.Find()
if err != nil {
return nil, 0, err
}
total, err := do.Limit(-1).Offset(-1).Count()
if err != nil {
return nil, 0, err
}
return pos, total, nil
}
func (dao *KnowledgeDocumentSliceDAO) BatchCreateWithTX(ctx context.Context, tx *gorm.DB, slices []*model.KnowledgeDocumentSlice) error {
if len(slices) == 0 {
return nil
}
return tx.WithContext(ctx).Debug().Model(&model.KnowledgeDocumentSlice{}).CreateInBatches(slices, 100).Error
}

View File

@ -49,8 +49,9 @@ type baseDocProcessor struct {
documentSource *entity.DocumentSource
// Drop DB model
TableName string
docModels []*model.KnowledgeDocument
TableName string
docModels []*model.KnowledgeDocument
imageSlices []*model.KnowledgeDocumentSlice
storage storage.Storage
knowledgeRepo repository.KnowledgeRepo
@ -69,14 +70,14 @@ func (p *baseDocProcessor) BeforeCreate() error {
func (p *baseDocProcessor) BuildDBModel() error {
p.docModels = make([]*model.KnowledgeDocument, 0, len(p.Documents))
ids, err := p.idgen.GenMultiIDs(p.ctx, len(p.Documents))
if err != nil {
logs.CtxErrorf(p.ctx, "gen ids failed, err: %v", err)
return errorx.New(errno.ErrKnowledgeIDGenCode)
}
for i := range p.Documents {
id, err := p.idgen.GenID(p.ctx)
if err != nil {
logs.CtxErrorf(p.ctx, "gen id failed, err: %v", err)
return errorx.New(errno.ErrKnowledgeIDGenCode)
}
docModel := &model.KnowledgeDocument{
ID: ids[i],
ID: id,
KnowledgeID: p.Documents[i].KnowledgeID,
Name: p.Documents[i].Name,
FileExtension: string(p.Documents[i].FileExtension),
@ -95,6 +96,23 @@ func (p *baseDocProcessor) BuildDBModel() error {
}
p.Documents[i].ID = docModel.ID
p.docModels = append(p.docModels, docModel)
if p.Documents[i].Type == knowledge.DocumentTypeImage {
id, err := p.idgen.GenID(p.ctx)
if err != nil {
logs.CtxErrorf(p.ctx, "gen id failed, err: %v", err)
return errorx.New(errno.ErrKnowledgeIDGenCode)
}
p.imageSlices = append(p.imageSlices, &model.KnowledgeDocumentSlice{
ID: id,
KnowledgeID: p.Documents[i].KnowledgeID,
DocumentID: p.Documents[i].ID,
CreatedAt: time.Now().UnixMilli(),
UpdatedAt: time.Now().UnixMilli(),
CreatorID: p.UserID,
SpaceID: p.SpaceID,
Status: int32(knowledge.SliceStatusInit),
})
}
}
return nil
@ -142,6 +160,11 @@ func (p *baseDocProcessor) InsertDBModel() (err error) {
logs.CtxErrorf(ctx, "create document failed, err: %v", err)
return errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", err.Error()))
}
err = p.sliceRepo.BatchCreateWithTX(ctx, tx, p.imageSlices)
if err != nil {
logs.CtxErrorf(ctx, "update knowledge failed, err: %v", err)
return errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", err.Error()))
}
err = p.knowledgeRepo.UpdateWithTx(ctx, tx, p.Documents[0].KnowledgeID, map[string]interface{}{
"updated_at": time.Now().UnixMilli(),
})

View File

@ -84,12 +84,12 @@ type KnowledgeDocumentSliceRepo interface {
Create(ctx context.Context, slice *model.KnowledgeDocumentSlice) error
Update(ctx context.Context, slice *model.KnowledgeDocumentSlice) error
Delete(ctx context.Context, slice *model.KnowledgeDocumentSlice) error
BatchCreateWithTX(ctx context.Context, tx *gorm.DB, slices []*model.KnowledgeDocumentSlice) error
BatchCreate(ctx context.Context, slices []*model.KnowledgeDocumentSlice) error
BatchSetStatus(ctx context.Context, ids []int64, status int32, reason string) error
DeleteByDocument(ctx context.Context, documentID int64) error
MGetSlices(ctx context.Context, sliceIDs []int64) ([]*model.KnowledgeDocumentSlice, error)
ListPhotoSlice(ctx context.Context, opts *entity.WherePhotoSliceOpt) ([]*model.KnowledgeDocumentSlice, int64, error)
FindSliceByCondition(ctx context.Context, opts *entity.WhereSliceOpt) (
[]*model.KnowledgeDocumentSlice, int64, error)
GetDocumentSliceIDs(ctx context.Context, docIDs []int64) (sliceIDs []int64, err error)

View File

@ -28,7 +28,7 @@ import (
"golang.org/x/sync/errgroup"
knowledgeModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossdatacopy"
crossdatacopy "github.com/coze-dev/coze-studio/backend/crossdomain/contract/datacopy"
"github.com/coze-dev/coze-studio/backend/domain/datacopy"
copyEntity "github.com/coze-dev/coze-studio/backend/domain/datacopy/entity"
"github.com/coze-dev/coze-studio/backend/domain/knowledge/entity"

View File

@ -35,12 +35,15 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/knowledge/internal/dal/model"
"github.com/coze-dev/coze-studio/backend/domain/knowledge/internal/events"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
progressbarContract "github.com/coze-dev/coze-studio/backend/infra/contract/document/progressbar"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
"github.com/coze-dev/coze-studio/backend/infra/contract/eventbus"
"github.com/coze-dev/coze-studio/backend/infra/contract/rdb"
rdbEntity "github.com/coze-dev/coze-studio/backend/infra/contract/rdb/entity"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/progressbar"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/types/errno"
@ -139,96 +142,166 @@ func (k *knowledgeSVC) indexDocuments(ctx context.Context, event *entity.Event)
return nil
}
type indexDocCacheRecord struct {
ProcessingIDs []int64
LastProcessedNumber int64
ParseUri string
}
const (
indexDocCacheKey = "index_doc_cache:%d:%d"
)
// indexDocumentNew handles the indexing of a new document into the knowledge system
func (k *knowledgeSVC) indexDocument(ctx context.Context, event *entity.Event) (err error) {
doc := event.Document
if doc == nil {
return errorx.New(errno.ErrKnowledgeNonRetryableCode, errorx.KV("reason", "[indexDocument] document not provided"))
}
// 1. The index operations on the same document in the retry queue and the ordinary queue are concurrent, and the same document data is written twice (generated when the backend bugfix is online)
// 2. rebalance repeated consumption of the same message
// check knowledge and document status
if valid, err := k.isWritableKnowledgeAndDocument(ctx, doc.KnowledgeID, doc.ID); err != nil {
return err
} else if !valid {
return errorx.New(errno.ErrKnowledgeNonRetryableCode,
errorx.KVf("reason", "[indexDocument] not writable, knowledge_id=%d, document_id=%d", event.KnowledgeID, doc.ID))
errorx.KV("reason", "[indexDocument] document not provided"))
}
defer func() {
if e := recover(); e != nil {
err = errorx.New(errno.ErrKnowledgeSystemCode, errorx.KV("msg", fmt.Sprintf("panic: %v", e)))
logs.CtxErrorf(ctx, "[indexDocument] panic, err: %v", err)
if setStatusErr := k.documentRepo.SetStatus(ctx, event.Document.ID, int32(entity.DocumentStatusFailed), err.Error()); setStatusErr != nil {
logs.CtxErrorf(ctx, "[indexDocument] set document status failed, err: %v", setStatusErr)
}
// Validate document and knowledge status
var valid bool
if valid, err = k.validateDocumentStatus(ctx, doc); err != nil || !valid {
return
}
// Setup error handling and recovery
defer k.handleIndexingErrors(ctx, event, &err)
// Start indexing process
if err = k.beginIndexingProcess(ctx, doc); err != nil {
return
}
// Process document parsing and chunking
var parseResult []*schema.Document
var cacheRecord *indexDocCacheRecord
parseResult, cacheRecord, err = k.processDocumentParsing(ctx, doc)
if err != nil {
return
}
if cacheRecord.LastProcessedNumber == 0 {
if err = k.cleanupPreviousProcessing(ctx, doc); err != nil {
return
}
if err != nil {
var errMsg string
var statusError errorx.StatusError
var status int32
if errors.As(err, &statusError) {
errMsg = errorx.ErrorWithoutStack(statusError)
if statusError.Code() == errno.ErrKnowledgeNonRetryableCode {
status = int32(entity.DocumentStatusFailed)
} else {
status = int32(entity.DocumentStatusChunking)
}
}
// Handle table-type documents specially
if doc.Type == knowledge.DocumentTypeTable {
if err = k.handleTableDocument(ctx, doc, parseResult); err != nil {
return
}
}
// Process document chunks in batches
if err = k.processDocumentChunks(ctx, doc, parseResult, cacheRecord); err != nil {
return
}
// Finalize document indexing
err = k.finalizeDocumentIndexing(ctx, event.Document.KnowledgeID, event.Document.ID)
return
}
// validateDocumentStatus checks if the document can be indexed
func (k *knowledgeSVC) validateDocumentStatus(ctx context.Context, doc *entity.Document) (bool, error) {
valid, err := k.isWritableKnowledgeAndDocument(ctx, doc.KnowledgeID, doc.ID)
if err != nil {
return false, err
}
if !valid {
return false, errorx.New(errno.ErrKnowledgeNonRetryableCode,
errorx.KVf("reason", "[indexDocument] not writable, knowledge_id=%d, document_id=%d",
doc.KnowledgeID, doc.ID))
}
return true, nil
}
// handleIndexingErrors manages errors and recovery during indexing
func (k *knowledgeSVC) handleIndexingErrors(ctx context.Context, event *entity.Event, err *error) {
if e := recover(); e != nil {
err = ptr.Of(errorx.New(errno.ErrKnowledgeSystemCode,
errorx.KV("msg", fmt.Sprintf("panic: %v", e))))
logs.CtxErrorf(ctx, "[indexDocument] panic, err: %v", err)
k.setDocumentStatus(ctx, event.Document.ID,
int32(entity.DocumentStatusFailed), ptr.From(err).Error())
return
}
if ptr.From(err) != nil {
var status int32
var errMsg string
var statusError errorx.StatusError
if errors.As(ptr.From(err), &statusError) {
errMsg = errorx.ErrorWithoutStack(statusError)
if statusError.Code() == errno.ErrKnowledgeNonRetryableCode {
status = int32(entity.DocumentStatusFailed)
} else {
errMsg = err.Error()
status = int32(entity.DocumentStatusChunking)
}
if setStatusErr := k.documentRepo.SetStatus(ctx, event.Document.ID, status, errMsg); setStatusErr != nil {
logs.CtxErrorf(ctx, "[indexDocument] set document status failed, err: %v", setStatusErr)
}
} else {
errMsg = ptr.From(err).Error()
status = int32(entity.DocumentStatusChunking)
}
}()
// clear
collectionName := getCollectionName(doc.KnowledgeID)
if !doc.IsAppend {
ids, err := k.sliceRepo.GetDocumentSliceIDs(ctx, []int64{doc.ID})
if err != nil {
return errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", fmt.Sprintf("get document slice ids failed, err: %v", err)))
}
if len(ids) > 0 {
if err = k.sliceRepo.DeleteByDocument(ctx, doc.ID); err != nil {
return errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", fmt.Sprintf("delete document slice failed, err: %v", err)))
}
for _, manager := range k.searchStoreManagers {
s, err := manager.GetSearchStore(ctx, collectionName)
if err != nil {
return errorx.New(errno.ErrKnowledgeSearchStoreCode, errorx.KV("msg", fmt.Sprintf("get search store failed, err: %v", err)))
}
if err := s.Delete(ctx, slices.Transform(event.SliceIDs, func(id int64) string {
return strconv.FormatInt(id, 10)
})); err != nil {
logs.Errorf("[indexDocument] delete knowledge failed, err: %v", err)
return errorx.New(errno.ErrKnowledgeSearchStoreCode, errorx.KV("msg", fmt.Sprintf("delete search store failed, err: %v", err)))
}
}
}
k.setDocumentStatus(ctx, event.Document.ID, status, errMsg)
}
}
// set chunk status
if err = k.documentRepo.SetStatus(ctx, doc.ID, int32(entity.DocumentStatusChunking), ""); err != nil {
// beginIndexingProcess starts the indexing process
func (k *knowledgeSVC) beginIndexingProcess(ctx context.Context, doc *entity.Document) error {
err := k.documentRepo.SetStatus(ctx, doc.ID, int32(entity.DocumentStatusChunking), "")
if err != nil {
return errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", fmt.Sprintf("set document status failed, err: %v", err)))
}
return nil
}
// parse & chunk
// processDocumentParsing handles document parsing and caching
func (k *knowledgeSVC) processDocumentParsing(ctx context.Context, doc *entity.Document) (
[]*schema.Document, *indexDocCacheRecord, error) {
cacheKey := fmt.Sprintf(indexDocCacheKey, doc.KnowledgeID, doc.ID)
cacheRecord := &indexDocCacheRecord{}
// Try to get cached parse results
val, err := k.cacheCli.Get(ctx, cacheKey).Result()
if err == nil {
if err = sonic.UnmarshalString(val, &cacheRecord); err != nil {
return nil, nil, errorx.New(errno.ErrKnowledgeParseJSONCode,
errorx.KV("msg", fmt.Sprintf("parse cache record failed, err: %v", err)))
}
}
// Parse document if not cached
if err != nil || len(cacheRecord.ParseUri) == 0 {
return k.parseAndCacheDocument(ctx, doc, cacheRecord, cacheKey)
}
// Load parse results from cache
return k.loadParsedDocument(ctx, cacheRecord)
}
// parseAndCacheDocument parses the document and caches the results
func (k *knowledgeSVC) parseAndCacheDocument(ctx context.Context, doc *entity.Document,
cacheRecord *indexDocCacheRecord, cacheKey string) ([]*schema.Document, *indexDocCacheRecord, error) {
// Get document content from storage
bodyBytes, err := k.storage.GetObject(ctx, doc.URI)
if err != nil {
return errorx.New(errno.ErrKnowledgeGetObjectFailCode, errorx.KV("msg", fmt.Sprintf("get object failed, err: %v", err)))
return nil, nil, errorx.New(errno.ErrKnowledgeGetObjectFailCode,
errorx.KV("msg", fmt.Sprintf("get object failed, err: %v", err)))
}
// Get appropriate parser for document type
docParser, err := k.parseManager.GetParser(convert.DocumentToParseConfig(doc))
if err != nil {
return errorx.New(errno.ErrKnowledgeGetParserFailCode, errorx.KV("msg", fmt.Sprintf("get parser failed, err: %v", err)))
return nil, nil, errorx.New(errno.ErrKnowledgeGetParserFailCode,
errorx.KV("msg", fmt.Sprintf("get parser failed, err: %v", err)))
}
// Parse document content
parseResult, err := docParser.Parse(ctx, bytes.NewReader(bodyBytes), parser.WithExtraMeta(map[string]any{
document.MetaDataKeyCreatorID: doc.CreatorID,
document.MetaDataKeyExternalStorage: map[string]any{
@ -236,77 +309,303 @@ func (k *knowledgeSVC) indexDocument(ctx context.Context, event *entity.Event) (
},
}))
if err != nil {
return errorx.New(errno.ErrKnowledgeParserParseFailCode, errorx.KV("msg", fmt.Sprintf("parse document failed, err: %v", err)))
return nil, nil, errorx.New(errno.ErrKnowledgeParserParseFailCode,
errorx.KV("msg", fmt.Sprintf("parse document failed, err: %v", err)))
}
if doc.Type == knowledge.DocumentTypeTable {
noData, err := document.GetDocumentsColumnsOnly(parseResult)
if err != nil { // unexpected
return errorx.New(errno.ErrKnowledgeNonRetryableCode,
errorx.KVf("reason", "[indexDocument] get table data status failed, err: %v", err))
}
if noData {
parseResult = nil // clear parse result
}
// Cache parse results
if err := k.cacheParseResults(ctx, doc, parseResult, cacheRecord, cacheKey); err != nil {
return nil, nil, err
}
// set id
allIDs := make([]int64, 0, len(parseResult))
for l := 0; l < len(parseResult); l += 100 {
r := min(l+100, len(parseResult))
batchSize := r - l
ids, err := k.idgen.GenMultiIDs(ctx, batchSize)
if err != nil {
return errorx.New(errno.ErrKnowledgeIDGenCode, errorx.KV("msg", fmt.Sprintf("GenMultiIDs failed, err: %v", err)))
}
allIDs = append(allIDs, ids...)
for i := 0; i < batchSize; i++ {
id := ids[i]
index := l + i
parseResult[index].ID = strconv.FormatInt(id, 10)
}
}
return parseResult, cacheRecord, nil
}
convertFn := d2sMapping[doc.Type]
if convertFn == nil {
return errorx.New(errno.ErrKnowledgeSystemCode, errorx.KV("msg", "convertFn is empty"))
}
// cacheParseResults stores parse results in persistent storage and cache
func (k *knowledgeSVC) cacheParseResults(ctx context.Context, doc *entity.Document,
parseResult []*schema.Document, cacheRecord *indexDocCacheRecord, cacheKey string) error {
sliceEntities, err := slices.TransformWithErrorCheck(parseResult, func(a *schema.Document) (*entity.Slice, error) {
return convertFn(a, doc.KnowledgeID, doc.ID, doc.CreatorID)
})
parseResultData, err := sonic.Marshal(parseResult)
if err != nil {
return errorx.New(errno.ErrKnowledgeSystemCode, errorx.KV("msg", fmt.Sprintf("convert document failed, err: %v", err)))
return errorx.New(errno.ErrKnowledgeParseJSONCode,
errorx.KV("msg", fmt.Sprintf("marshal parse result failed, err: %v", err)))
}
// save slices
if doc.Type == knowledge.DocumentTypeTable {
// Table type to insert data into a database
err = k.upsertDataToTable(ctx, &doc.TableInfo, sliceEntities)
fileName := fmt.Sprintf("FileBizType.Knowledge/%d_%d.txt", doc.CreatorID, doc.ID)
if err = k.storage.PutObject(ctx, fileName, parseResultData); err != nil {
return errorx.New(errno.ErrKnowledgePutObjectFailCode,
errorx.KV("msg", fmt.Sprintf("put object failed, err: %v", err)))
}
cacheRecord.ParseUri = fileName
return k.recordIndexDocumentStatus(ctx, cacheRecord, cacheKey)
}
// loadParsedDocument loads previously parsed document from cache
func (k *knowledgeSVC) loadParsedDocument(ctx context.Context,
cacheRecord *indexDocCacheRecord) ([]*schema.Document, *indexDocCacheRecord, error) {
data, err := k.storage.GetObject(ctx, cacheRecord.ParseUri)
if err != nil {
return nil, nil, errorx.New(errno.ErrKnowledgeGetObjectFailCode,
errorx.KV("msg", fmt.Sprintf("get object failed, err: %v", err)))
}
var parseResult []*schema.Document
if err = sonic.Unmarshal(data, &parseResult); err != nil {
return nil, nil, errorx.New(errno.ErrKnowledgeParseJSONCode,
errorx.KV("msg", fmt.Sprintf("marshal parse result failed, err: %v", err)))
}
return parseResult, cacheRecord, nil
}
// handleTableDocument handles special processing for table-type documents
func (k *knowledgeSVC) handleTableDocument(ctx context.Context,
doc *entity.Document, parseResult []*schema.Document) error {
noData, err := document.GetDocumentsColumnsOnly(parseResult)
if err != nil {
return errorx.New(errno.ErrKnowledgeNonRetryableCode,
errorx.KVf("reason", "[indexDocument] get table data status failed, err: %v", err))
}
if noData {
parseResult = nil // clear parse result
}
return nil
}
// processDocumentChunks processes document chunks in batches
func (k *knowledgeSVC) processDocumentChunks(ctx context.Context,
doc *entity.Document, parseResult []*schema.Document, cacheRecord *indexDocCacheRecord) error {
batchSize := 100
progressbar := progressbar.NewProgressBar(ctx, doc.ID,
int64(len(parseResult)*len(k.searchStoreManagers)), k.cacheCli, true)
if err := progressbar.AddN(int(cacheRecord.LastProcessedNumber) * len(k.searchStoreManagers)); err != nil {
return errorx.New(errno.ErrKnowledgeSystemCode,
errorx.KV("msg", fmt.Sprintf("add progress bar failed, err: %v", err)))
}
// Process chunks in batches
for i := int(cacheRecord.LastProcessedNumber); i < len(parseResult); i += batchSize {
chunks := parseResult[i:min(i+batchSize, len(parseResult))]
if err := k.batchProcessSlice(ctx, doc, i, chunks, cacheRecord, progressbar); err != nil {
return err
}
}
return nil
}
// finalizeDocumentIndexing completes the document indexing process
func (k *knowledgeSVC) finalizeDocumentIndexing(ctx context.Context, knowledgeID, documentID int64) error {
if err := k.documentRepo.SetStatus(ctx, documentID, int32(entity.DocumentStatusEnable), ""); err != nil {
return errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", fmt.Sprintf("set document status failed, err: %v", err)))
}
if err := k.documentRepo.UpdateDocumentSliceInfo(ctx, documentID); err != nil {
return errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", fmt.Sprintf("update document slice info failed, err: %v", err)))
}
if err := k.cacheCli.Del(ctx, fmt.Sprintf(indexDocCacheKey, knowledgeID, documentID)).Err(); err != nil {
return errorx.New(errno.ErrKnowledgeSystemCode, errorx.KV("msg", fmt.Sprintf("del cache failed, err: %v", err)))
}
return nil
}
// batchProcessSlice processes a batch of document slices
func (k *knowledgeSVC) batchProcessSlice(ctx context.Context, doc *entity.Document,
startIdx int, parseResult []*schema.Document, cacheRecord *indexDocCacheRecord,
progressBar progressbarContract.ProgressBar) error {
collectionName := getCollectionName(doc.KnowledgeID)
length := len(parseResult)
var ids []int64
var err error
// Generate IDs for this batch
if len(cacheRecord.ProcessingIDs) == 0 {
ids, err = k.genMultiIDs(ctx, length)
if err != nil {
return err
}
} else {
ids = cacheRecord.ProcessingIDs
}
for idx := range parseResult {
parseResult[idx].ID = strconv.FormatInt(ids[idx], 10)
}
// Update cache record with processing IDs
cacheRecord.ProcessingIDs = ids
if err := k.recordIndexDocumentStatus(ctx, cacheRecord,
fmt.Sprintf(indexDocCacheKey, doc.KnowledgeID, doc.ID)); err != nil {
return err
}
// Convert documents to slices
sliceEntities, err := k.convertToSlices(doc, parseResult)
if err != nil {
return err
}
// Handle table-type documents
if doc.Type == knowledge.DocumentTypeTable {
if err := k.upsertDataToTable(ctx, &doc.TableInfo, sliceEntities); err != nil {
logs.CtxErrorf(ctx, "[indexDocument] insert data to table failed, err: %v", err)
return err
}
}
// Store slices in database
if err := k.storeSlicesInDB(ctx, doc, parseResult, startIdx, ids); err != nil {
return err
}
// Index slices in search stores
if err := k.indexSlicesInSearchStores(ctx, doc, collectionName, sliceEntities,
cacheRecord, progressBar); err != nil {
return err
}
// Update cache record after successful processing
cacheRecord.LastProcessedNumber = int64(startIdx) + int64(length)
cacheRecord.ProcessingIDs = nil
// Mark slices as done
err = k.sliceRepo.BatchSetStatus(ctx, ids, int32(model.SliceStatusDone), "")
if err != nil {
return errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", fmt.Sprintf("batch set slice status failed, err: %v", err)))
}
if err := k.recordIndexDocumentStatus(ctx, cacheRecord,
fmt.Sprintf(indexDocCacheKey, doc.KnowledgeID, doc.ID)); err != nil {
return err
}
return nil
}
// convertToSlices converts parsed documents to slice entities
func (k *knowledgeSVC) convertToSlices(doc *entity.Document, parseResult []*schema.Document) ([]*entity.Slice, error) {
convertFn := d2sMapping[doc.Type]
if convertFn == nil {
return nil, errorx.New(errno.ErrKnowledgeSystemCode,
errorx.KV("msg", "convertFn is empty"))
}
return slices.TransformWithErrorCheck(parseResult, func(a *schema.Document) (*entity.Slice, error) {
return convertFn(a, doc.KnowledgeID, doc.ID, doc.CreatorID)
})
}
// cleanupPreviousProcessing cleans up partially processed data from previous attempts
func (k *knowledgeSVC) cleanupPreviousProcessing(ctx context.Context, doc *entity.Document) error {
collectionName := getCollectionName(doc.KnowledgeID)
if doc.IsAppend || doc.Type == knowledge.DocumentTypeImage {
return nil
}
ids, err := k.sliceRepo.GetDocumentSliceIDs(ctx, []int64{doc.ID})
if err != nil {
return errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", fmt.Sprintf("get document slice ids failed, err: %v", err)))
}
if len(ids) > 0 {
if err = k.sliceRepo.DeleteByDocument(ctx, doc.ID); err != nil {
return errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", fmt.Sprintf("delete document slice failed, err: %v", err)))
}
for _, manager := range k.searchStoreManagers {
s, err := manager.GetSearchStore(ctx, collectionName)
if err != nil {
return errorx.New(errno.ErrKnowledgeSearchStoreCode, errorx.KV("msg", fmt.Sprintf("get search store failed, err: %v", err)))
}
if err := s.Delete(ctx, slices.Transform(ids, func(id int64) string {
return strconv.FormatInt(id, 10)
})); err != nil {
logs.Errorf("[indexDocument] delete knowledge failed, err: %v", err)
return errorx.New(errno.ErrKnowledgeSearchStoreCode, errorx.KV("msg", fmt.Sprintf("delete search store failed, err: %v", err)))
}
}
}
if doc.Type == knowledge.DocumentTypeTable {
_, err := k.rdb.DeleteData(ctx, &rdb.DeleteDataRequest{
TableName: doc.TableInfo.PhysicalTableName,
Where: &rdb.ComplexCondition{
Conditions: []*rdb.Condition{
{
Field: consts.RDBFieldID,
Operator: rdbEntity.OperatorIn,
Value: ids,
},
},
},
})
if err != nil {
logs.CtxErrorf(ctx, "delete data failed, err: %v", err)
return errorx.New(errno.ErrKnowledgeCrossDomainCode, errorx.KV("msg", err.Error()))
}
}
return nil
}
// storeSlicesInDB stores slice data in the database
func (k *knowledgeSVC) storeSlicesInDB(ctx context.Context, doc *entity.Document,
parseResult []*schema.Document, startIdx int, ids []int64) error {
var seqOffset float64
var err error
if doc.IsAppend {
seqOffset, err = k.sliceRepo.GetLastSequence(ctx, doc.ID)
if err != nil {
return errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", fmt.Sprintf("get last sequence failed, err: %v", err)))
return errorx.New(errno.ErrKnowledgeDBCode,
errorx.KV("msg", fmt.Sprintf("get last sequence failed, err: %v", err)))
}
seqOffset += 1
}
if doc.Type == knowledge.DocumentTypeImage {
if len(parseResult) != 0 {
slices, _, err := k.sliceRepo.FindSliceByCondition(ctx, &entity.WhereSliceOpt{DocumentID: doc.ID})
if err != nil {
return errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", fmt.Sprintf("find slice failed, err: %v", err)))
}
var slice *model.KnowledgeDocumentSlice
if len(slices) > 0 {
slice = slices[0]
slice.Content = parseResult[0].Content
} else {
id, err := k.idgen.GenID(ctx)
if err != nil {
return errorx.New(errno.ErrKnowledgeIDGenCode, errorx.KV("msg", fmt.Sprintf("GenID failed, err: %v", err)))
}
slice = &model.KnowledgeDocumentSlice{
ID: id,
KnowledgeID: doc.KnowledgeID,
DocumentID: doc.ID,
Content: parseResult[0].Content,
CreatedAt: time.Now().UnixMilli(),
UpdatedAt: time.Now().UnixMilli(),
CreatorID: doc.CreatorID,
SpaceID: doc.SpaceID,
Status: int32(model.SliceStatusProcessing),
FailReason: "",
}
}
if err = k.sliceRepo.Update(ctx, slice); err != nil {
return errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", fmt.Sprintf("update slice failed, err: %v", err)))
}
}
return nil
}
sliceModels := make([]*model.KnowledgeDocumentSlice, 0, len(parseResult))
for i, src := range parseResult {
now := time.Now().UnixMilli()
sliceModel := &model.KnowledgeDocumentSlice{
ID: allIDs[i],
ID: ids[i],
KnowledgeID: doc.KnowledgeID,
DocumentID: doc.ID,
Content: parseResult[i].Content,
Sequence: seqOffset + float64(i),
Sequence: seqOffset + float64(i+startIdx),
CreatedAt: now,
UpdatedAt: now,
CreatorID: doc.CreatorID,
@ -314,83 +613,108 @@ func (k *knowledgeSVC) indexDocument(ctx context.Context, event *entity.Event) (
Status: int32(model.SliceStatusProcessing),
FailReason: "",
}
if doc.Type == knowledge.DocumentTypeTable {
convertFn := d2sMapping[doc.Type]
sliceEntity, err := convertFn(src, doc.KnowledgeID, doc.ID, doc.CreatorID)
if err != nil {
logs.CtxErrorf(ctx, "[indexDocument] convert document failed, err: %v", err)
return errorx.New(errno.ErrKnowledgeSystemCode, errorx.KV("msg", fmt.Sprintf("convert document failed, err: %v", err)))
return errorx.New(errno.ErrKnowledgeSystemCode,
errorx.KV("msg", fmt.Sprintf("convert document failed, err: %v", err)))
}
sliceModel.Content = sliceEntity.GetSliceContent()
}
sliceModels = append(sliceModels, sliceModel)
}
if err = k.sliceRepo.BatchCreate(ctx, sliceModels); err != nil {
return errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", fmt.Sprintf("batch create slice failed, err: %v", err)))
err = k.sliceRepo.BatchCreate(ctx, sliceModels)
if err != nil {
return errorx.New(errno.ErrKnowledgeDBCode,
errorx.KV("msg", fmt.Sprintf("batch create slice failed, err: %v", err)))
}
return nil
}
defer func() {
if err != nil { // set slice status
if setStatusErr := k.sliceRepo.BatchSetStatus(ctx, allIDs, int32(model.SliceStatusFailed), err.Error()); setStatusErr != nil {
logs.CtxErrorf(ctx, "[indexDocument] set slice status failed, err: %v", setStatusErr)
}
}
}()
// indexSlicesInSearchStores indexes slices in appropriate search stores
func (k *knowledgeSVC) indexSlicesInSearchStores(ctx context.Context, doc *entity.Document,
collectionName string, sliceEntities []*entity.Slice, cacheRecord *indexDocCacheRecord,
progressBar progressbarContract.ProgressBar) error {
// to vectorstore
fields, err := k.mapSearchFields(doc)
if err != nil {
return err
}
indexingFields := getIndexingFields(fields)
// reformat docs, mainly for enableCompactTable
// Convert slices to search documents
ssDocs, err := slices.TransformWithErrorCheck(sliceEntities, func(a *entity.Slice) (*schema.Document, error) {
return k.slice2Document(ctx, doc, a)
})
if err != nil {
return errorx.New(errno.ErrKnowledgeSystemCode, errorx.KV("msg", fmt.Sprintf("reformat document failed, err: %v", err)))
return errorx.New(errno.ErrKnowledgeSystemCode,
errorx.KV("msg", fmt.Sprintf("reformat document failed, err: %v", err)))
}
progressbar := progressbar.NewProgressBar(ctx, doc.ID, int64(len(ssDocs)*len(k.searchStoreManagers)), k.cacheCli, true)
// Skip if it's an image document with empty content
if doc.Type == knowledge.DocumentTypeImage && len(ssDocs) == 1 && len(ssDocs[0].Content) == 0 {
return nil
}
// Index in each search store manager
for _, manager := range k.searchStoreManagers {
now := time.Now()
if err = manager.Create(ctx, &searchstore.CreateRequest{
if err := manager.Create(ctx, &searchstore.CreateRequest{
CollectionName: collectionName,
Fields: fields,
CollectionMeta: nil,
}); err != nil {
return errorx.New(errno.ErrKnowledgeSearchStoreCode, errorx.KV("msg", fmt.Sprintf("create search store failed, err: %v", err)))
}
// Picture knowledge base kn: doc: slice = 1: n: n, maybe the content is empty, no need to write
if doc.Type == knowledge.DocumentTypeImage && len(ssDocs) == 1 && len(ssDocs[0].Content) == 0 {
continue
return errorx.New(errno.ErrKnowledgeSearchStoreCode,
errorx.KV("msg", fmt.Sprintf("create search store failed, err: %v", err)))
}
ss, err := manager.GetSearchStore(ctx, collectionName)
if err != nil {
return errorx.New(errno.ErrKnowledgeSearchStoreCode, errorx.KV("msg", fmt.Sprintf("get search store failed, err: %v", err)))
return errorx.New(errno.ErrKnowledgeSearchStoreCode,
errorx.KV("msg", fmt.Sprintf("get search store failed, err: %v", err)))
}
if _, err = ss.Store(ctx, ssDocs,
searchstore.WithIndexerPartitionKey(fieldNameDocumentID),
searchstore.WithPartition(strconv.FormatInt(doc.ID, 10)),
searchstore.WithIndexingFields(indexingFields),
searchstore.WithProgressBar(progressbar),
searchstore.WithProgressBar(progressBar),
); err != nil {
return errorx.New(errno.ErrKnowledgeSearchStoreCode, errorx.KV("msg", fmt.Sprintf("store search store failed, err: %v", err)))
return errorx.New(errno.ErrKnowledgeSearchStoreCode,
errorx.KV("msg", fmt.Sprintf("store search store failed, err: %v", err)))
}
logs.CtxDebugf(ctx, "[indexDocument] ss type=%v, len(docs)=%d, finished after %d ms",
manager.GetType(), len(ssDocs), time.Now().Sub(now).Milliseconds())
}
// set slice status
if err = k.sliceRepo.BatchSetStatus(ctx, allIDs, int32(model.SliceStatusDone), ""); err != nil {
return errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", fmt.Sprintf("batch set slice status failed, err: %v", err)))
if err := k.recordIndexDocumentStatus(ctx, cacheRecord,
fmt.Sprintf(indexDocCacheKey, doc.KnowledgeID, doc.ID)); err != nil {
return err
}
}
// set document status
return nil
}
if err = k.documentRepo.SetStatus(ctx, doc.ID, int32(entity.DocumentStatusEnable), ""); err != nil {
return errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", fmt.Sprintf("set document status failed, err: %v", err)))
// setDocumentStatus updates document status with error handling
func (k *knowledgeSVC) setDocumentStatus(ctx context.Context, docID int64, status int32, errMsg string) {
if setStatusErr := k.documentRepo.SetStatus(ctx, docID, status, errMsg); setStatusErr != nil {
logs.CtxErrorf(ctx, "[indexDocument] set document status failed, err: %v", setStatusErr)
}
if err = k.documentRepo.UpdateDocumentSliceInfo(ctx, event.Document.ID); err != nil {
return errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", fmt.Sprintf("update document slice info failed, err: %v", err)))
}
func (k *knowledgeSVC) recordIndexDocumentStatus(ctx context.Context, r *indexDocCacheRecord, cacheKey string) error {
data, err := sonic.Marshal(r)
if err != nil {
return errorx.New(errno.ErrKnowledgeParseJSONCode, errorx.KV("msg", fmt.Sprintf("marshal parse result failed, err: %v", err)))
}
err = k.cacheCli.Set(ctx, cacheKey, data, time.Hour*2).Err()
if err != nil {
return errorx.New(errno.ErrKnowledgeCacheClientSetFailCode, errorx.KV("msg", fmt.Sprintf("set cache failed, err: %v", err)))
}
return nil
}

View File

@ -217,9 +217,10 @@ type RetrieveContext struct {
}
type KnowledgeInfo struct {
DocumentIDs []int64
DocumentType knowledge.DocumentType
TableColumns []*entity.TableColumn
KnowledgeName string
DocumentIDs []int64
DocumentType knowledge.DocumentType
TableColumns []*entity.TableColumn
}
type AlterTableSchemaRequest struct {
DocumentID int64

View File

@ -58,9 +58,7 @@ import (
"github.com/coze-dev/coze-studio/backend/infra/contract/rdb"
rdbEntity "github.com/coze-dev/coze-studio/backend/infra/contract/rdb/entity"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/builtin"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/progressbar"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/rerank/rrf"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
@ -70,50 +68,42 @@ import (
func NewKnowledgeSVC(config *KnowledgeSVCConfig) (Knowledge, eventbus.ConsumerHandler) {
svc := &knowledgeSVC{
knowledgeRepo: repository.NewKnowledgeDAO(config.DB),
documentRepo: repository.NewKnowledgeDocumentDAO(config.DB),
sliceRepo: repository.NewKnowledgeDocumentSliceDAO(config.DB),
reviewRepo: repository.NewKnowledgeDocumentReviewDAO(config.DB),
idgen: config.IDGen,
rdb: config.RDB,
producer: config.Producer,
searchStoreManagers: config.SearchStoreManagers,
parseManager: config.ParseManager,
storage: config.Storage,
reranker: config.Reranker,
rewriter: config.Rewriter,
nl2Sql: config.NL2Sql,
enableCompactTable: ptr.FromOrDefault(config.EnableCompactTable, true),
cacheCli: config.CacheCli,
isAutoAnnotationSupported: config.IsAutoAnnotationSupported,
modelFactory: config.ModelFactory,
}
if svc.reranker == nil {
svc.reranker = rrf.NewRRFReranker(0)
}
if svc.parseManager == nil {
svc.parseManager = builtin.NewManager(config.Storage, config.OCR, nil)
knowledgeRepo: repository.NewKnowledgeDAO(config.DB),
documentRepo: repository.NewKnowledgeDocumentDAO(config.DB),
sliceRepo: repository.NewKnowledgeDocumentSliceDAO(config.DB),
reviewRepo: repository.NewKnowledgeDocumentReviewDAO(config.DB),
idgen: config.IDGen,
rdb: config.RDB,
producer: config.Producer,
searchStoreManagers: config.SearchStoreManagers,
parseManager: config.ParseManager,
storage: config.Storage,
reranker: config.Reranker,
rewriter: config.Rewriter,
nl2Sql: config.NL2Sql,
enableCompactTable: ptr.FromOrDefault(config.EnableCompactTable, true),
cacheCli: config.CacheCli,
modelFactory: config.ModelFactory,
}
return svc, svc
}
type KnowledgeSVCConfig struct {
DB *gorm.DB // required
IDGen idgen.IDGenerator // required
RDB rdb.RDB // Required: Form storage
Producer eventbus.Producer // Required: Document indexing process goes through mq asynchronous processing
SearchStoreManagers []searchstore.Manager // Required: Vector/Full Text
ParseManager parser.Manager // Optional: document segmentation and processing capability, default builtin parser
Storage storage.Storage // required: oss
ModelFactory chatmodel.Factory // Required: Model factory
Rewriter messages2query.MessagesToQuery // Optional: Do not overwrite when not configured
Reranker rerank.Reranker // Optional: default rrf when not configured
NL2Sql nl2sql.NL2SQL // Optional: Not supported by default when not configured
EnableCompactTable *bool // Optional: Table data compression, default true
OCR ocr.OCR // Optional: ocr, ocr function is not available when not provided
CacheCli cache.Cmdable // Optional: cache implementation
IsAutoAnnotationSupported bool // Does it support automatic image labeling?
DB *gorm.DB // required
IDGen idgen.IDGenerator // required
RDB rdb.RDB // Required: Form storage
Producer eventbus.Producer // Required: Document indexing process goes through mq asynchronous processing
SearchStoreManagers []searchstore.Manager // Required: Vector/Full Text
ParseManager parser.Manager // Optional: document segmentation and processing capability, default builtin parser
Storage storage.Storage // required: oss
ModelFactory chatmodel.Factory // Required: Model factory
Rewriter messages2query.MessagesToQuery // Optional: Do not overwrite when not configured
Reranker rerank.Reranker // Optional: default rrf when not configured
NL2Sql nl2sql.NL2SQL // Optional: Not supported by default when not configured
EnableCompactTable *bool // Optional: Table data compression, default true
OCR ocr.OCR // Optional: ocr, ocr function is not available when not provided
CacheCli cache.Cmdable // Optional: cache implementation
}
type knowledgeSVC struct {
@ -123,18 +113,17 @@ type knowledgeSVC struct {
reviewRepo repository.KnowledgeDocumentReviewRepo
modelFactory chatmodel.Factory
idgen idgen.IDGenerator
rdb rdb.RDB
producer eventbus.Producer
searchStoreManagers []searchstore.Manager
parseManager parser.Manager
rewriter messages2query.MessagesToQuery
reranker rerank.Reranker
storage storage.Storage
nl2Sql nl2sql.NL2SQL
cacheCli cache.Cmdable
enableCompactTable bool // Table data compression
isAutoAnnotationSupported bool // Does it support automatic image labeling?
idgen idgen.IDGenerator
rdb rdb.RDB
producer eventbus.Producer
searchStoreManagers []searchstore.Manager
parseManager parser.Manager
rewriter messages2query.MessagesToQuery
reranker rerank.Reranker
storage storage.Storage
nl2Sql nl2sql.NL2SQL
cacheCli cache.Cmdable
enableCompactTable bool // Table data compression
}
func (k *knowledgeSVC) CreateKnowledge(ctx context.Context, request *CreateKnowledgeRequest) (response *CreateKnowledgeResponse, err error) {
@ -318,7 +307,7 @@ func (k *knowledgeSVC) checkRequest(request *CreateDocumentRequest) error {
}
for i := range request.Documents {
if request.Documents[i].Type == knowledgeModel.DocumentTypeImage && ptr.From(request.Documents[i].ParsingStrategy.CaptionType) == parser.ImageAnnotationTypeModel {
if !k.isAutoAnnotationSupported {
if !k.parseManager.IsAutoAnnotationSupported() {
return errors.New("auto caption type is not supported")
}
}
@ -887,9 +876,8 @@ func (k *knowledgeSVC) ListSlice(ctx context.Context, request *ListSliceRequest)
KnowledgeID: ptr.From(request.KnowledgeID),
DocumentID: ptr.From(request.DocumentID),
Keyword: request.Keyword,
Sequence: request.Sequence,
Offset: request.Sequence,
PageSize: request.Limit,
Offset: request.Offset,
})
if err != nil {
logs.CtxErrorf(ctx, "list slice failed, err: %v", err)
@ -1386,12 +1374,12 @@ func (k *knowledgeSVC) ListPhotoSlice(ctx context.Context, request *ListPhotoSli
if request == nil {
return nil, errorx.New(errno.ErrKnowledgeInvalidParamCode, errorx.KV("msg", "request is empty"))
}
sliceArr, total, err := k.sliceRepo.FindSliceByCondition(ctx, &entity.WhereSliceOpt{
sliceArr, total, err := k.sliceRepo.ListPhotoSlice(ctx, &entity.WherePhotoSliceOpt{
KnowledgeID: request.KnowledgeID,
DocumentIDs: request.DocumentIDs,
Offset: int64(ptr.From(request.Offset)),
PageSize: int64(ptr.From(request.Limit)),
NotEmpty: request.HasCaption,
Offset: request.Offset,
Limit: request.Limit,
HasCaption: request.HasCaption,
})
if err != nil {
return nil, errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", err.Error()))
@ -1411,7 +1399,7 @@ func (k *knowledgeSVC) ExtractPhotoCaption(ctx context.Context, request *Extract
if request == nil {
return nil, errorx.New(errno.ErrKnowledgeInvalidParamCode, errorx.KV("msg", "request is empty"))
}
if !k.isAutoAnnotationSupported {
if !k.parseManager.IsAutoAnnotationSupported() {
return nil, errorx.New(errno.ErrKnowledgeAutoAnnotationNotSupportedCode, errorx.KV("msg", "auto annotation is not supported"))
}
docInfo, err := k.documentRepo.GetByID(ctx, request.DocumentID)

View File

@ -41,6 +41,8 @@ import (
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
"github.com/coze-dev/coze-studio/backend/infra/impl/cache/redis"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/builtin"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/rerank/rrf"
sses "github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/elasticsearch"
ssmilvus "github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/milvus"
hembed "github.com/coze-dev/coze-studio/backend/infra/impl/embedding/http"
@ -169,10 +171,10 @@ func (suite *KnowledgeTestSuite) SetupSuite() {
RDB: rdbService,
Producer: knowledgeProducer,
SearchStoreManagers: mgrs,
ParseManager: nil, // default builtin
ParseManager: builtin.NewManager(tosClient, nil, nil), // default builtin
Storage: tosClient,
Rewriter: nil,
Reranker: nil, // default rrf
Reranker: rrf.NewRRFReranker(0), // default rrf
EnableCompactTable: ptr.Of(true),
})

View File

@ -32,6 +32,8 @@ import (
knowledgeModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/domain/knowledge/entity"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/builtin"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/rerank/rrf"
"github.com/coze-dev/coze-studio/backend/infra/impl/rdb"
producerMock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/eventbus"
mock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/idgen"
@ -98,11 +100,14 @@ func MockKnowledgeSVC(t *testing.T) Knowledge {
mockStorage.EXPECT().PutObject(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
rdb := rdb.NewService(db, mockIDGen)
svc, _ := NewKnowledgeSVC(&KnowledgeSVCConfig{
DB: db,
IDGen: mockIDGen,
Storage: mockStorage,
Producer: producer,
RDB: rdb,
DB: db,
IDGen: mockIDGen,
Storage: mockStorage,
Producer: producer,
RDB: rdb,
Reranker: rrf.NewRRFReranker(0),
ParseManager: builtin.NewManager(mockStorage, nil, nil), // default builtin
})
return svc
}

View File

@ -18,11 +18,9 @@ package service
import (
"context"
"errors"
"fmt"
"regexp"
"strconv"
"strings"
"sync"
"unicode/utf8"
@ -50,6 +48,7 @@ import (
"github.com/coze-dev/coze-studio/backend/pkg/lang/sets"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
@ -127,6 +126,7 @@ func (k *knowledgeSVC) newRetrieveContext(ctx context.Context, req *RetrieveRequ
knowledgeInfoMap[kn.ID] = &KnowledgeInfo{}
knowledgeInfoMap[kn.ID].DocumentType = knowledgeModel.DocumentType(kn.FormatType)
knowledgeInfoMap[kn.ID].DocumentIDs = []int64{}
knowledgeInfoMap[kn.ID].KnowledgeName = kn.Name
}
}
for _, doc := range enableDocs {
@ -189,7 +189,7 @@ func (k *knowledgeSVC) prepareRAGDocuments(ctx context.Context, documentIDs []in
}
func (k *knowledgeSVC) queryRewriteNode(ctx context.Context, req *RetrieveContext) (newRetrieveContext *RetrieveContext, err error) {
if len(req.ChatHistory) == 0 {
if len(req.ChatHistory) == 1 {
// No context, no rewriting.
return req, nil
}
@ -390,6 +390,10 @@ func (k *knowledgeSVC) nl2SqlExec(ctx context.Context, doc *model.KnowledgeDocum
}
replaceMap[doc.Name].ColumnMap[doc.TableInfo.Columns[i].Name] = convert.ColumnIDToRDBField(doc.TableInfo.Columns[i].ID)
}
virtualColumnMap := map[string]*entity.TableColumn{}
for i := range doc.TableInfo.Columns {
virtualColumnMap[convert.ColumnIDToRDBField(doc.TableInfo.Columns[i].ID)] = doc.TableInfo.Columns[i]
}
parsedSQL, err := sqlparser.NewSQLParser().ParseAndModifySQL(sql, replaceMap)
if err != nil {
logs.CtxErrorf(ctx, "parse sql failed: %v", err)
@ -404,15 +408,52 @@ func (k *knowledgeSVC) nl2SqlExec(ctx context.Context, doc *model.KnowledgeDocum
return nil, err
}
for i := range resp.ResultSet.Rows {
d := &schema.Document{
Content: "",
MetaData: map[string]any{
"document_id": doc.ID,
"document_name": doc.Name,
"knowledge_id": doc.KnowledgeID,
"knowledge_name": retrieveCtx.KnowledgeInfoMap[doc.KnowledgeID].KnowledgeName,
},
}
id, ok := resp.ResultSet.Rows[i][consts.RDBFieldID].(int64)
if !ok {
logs.CtxWarnf(ctx, "convert id failed, row: %v", resp.ResultSet.Rows[i])
return nil, errors.New("convert id failed")
}
d := &schema.Document{
ID: strconv.FormatInt(id, 10),
Content: "",
MetaData: map[string]any{},
byteData, err := sonic.Marshal(resp.ResultSet.Rows)
if err != nil {
logs.CtxErrorf(ctx, "marshal sql resp failed: %v", err)
return nil, err
}
prefix := "sql:" + sql + ";result:"
d.Content = prefix + string(byteData)
} else {
transferMap := map[string]string{}
for cName, val := range resp.ResultSet.Rows[i] {
column, found := virtualColumnMap[cName]
if !found {
logs.CtxInfof(ctx, "column not found, name: %s", cName)
continue
}
columnData, err := convert.ParseAnyData(column, val)
if err != nil {
logs.CtxErrorf(ctx, "parse any data failed: %v", err)
return nil, errorx.New(errno.ErrKnowledgeColumnParseFailCode, errorx.KV("msg", err.Error()))
}
if columnData.Type == document.TableColumnTypeString {
columnData.ValString = ptr.Of(k.formatSliceContent(ctx, columnData.GetStringValue()))
}
if columnData.Type == document.TableColumnTypeImage {
columnData.ValImage = ptr.Of(k.formatSliceContent(ctx, columnData.GetStringValue()))
}
transferMap[column.Name] = columnData.GetNullableStringValue()
}
byteData, err := sonic.Marshal(transferMap)
if err != nil {
logs.CtxErrorf(ctx, "marshal sql resp failed: %v", err)
return nil, err
}
d.Content = string(byteData)
d.ID = strconv.FormatInt(id, 10)
}
d.WithScore(1)
retrieveResult = append(retrieveResult, d)
@ -423,29 +464,13 @@ func (k *knowledgeSVC) nl2SqlExec(ctx context.Context, doc *model.KnowledgeDocum
const pkID = "_knowledge_slice_id"
func addSliceIdColumn(originalSql string) string {
lowerSql := strings.ToLower(originalSql)
selectIndex := strings.Index(lowerSql, "select ")
if selectIndex == -1 {
sql, err := sqlparser.NewSQLParser().AddSelectFieldsToSelectSQL(originalSql, []string{pkID})
if err != nil {
logs.Errorf("add slice id column failed: %v", err)
return originalSql
}
result := originalSql[:selectIndex+len("select ")] // Keep selected part
remainder := originalSql[selectIndex+len("select "):]
lowerRemainder := strings.ToLower(remainder)
fromIndex := strings.Index(lowerRemainder, " from")
if fromIndex == -1 {
return originalSql
}
columns := strings.TrimSpace(remainder[:fromIndex])
if columns != "*" {
columns += ", " + pkID
}
result += columns + remainder[fromIndex:]
return result
return sql
}
func packNL2SqlRequest(doc *model.KnowledgeDocument) *document.TableSchema {
res := &document.TableSchema{}
if doc.TableInfo == nil {
@ -561,18 +586,39 @@ func (k *knowledgeSVC) packResults(ctx context.Context, retrieveResult []*schema
sliceIDs := make(sets.Set[int64])
docIDs := make(sets.Set[int64])
knowledgeIDs := make(sets.Set[int64])
results = []*knowledgeModel.RetrieveSlice{}
documentMap := map[int64]*model.KnowledgeDocument{}
knowledgeMap := map[int64]*model.Knowledge{}
sliceScoreMap := map[int64]float64{}
for _, doc := range retrieveResult {
id, err := strconv.ParseInt(doc.ID, 10, 64)
if err != nil {
logs.CtxErrorf(ctx, "convert id failed: %v", err)
return nil, errorx.New(errno.ErrKnowledgeSystemCode, errorx.KV("msg", "convert id failed"))
if len(doc.ID) == 0 {
results = append(results, &knowledgeModel.RetrieveSlice{
Slice: &knowledgeModel.Slice{
KnowledgeID: doc.MetaData["knowledge_id"].(int64),
DocumentID: doc.MetaData["document_id"].(int64),
DocumentName: doc.MetaData["document_name"].(string),
RawContent: []*knowledgeModel.SliceContent{
{
Type: knowledgeModel.SliceContentTypeText,
Text: ptr.Of(doc.Content),
},
},
Extra: map[string]string{
consts.KnowledgeName: doc.MetaData["knowledge_name"].(string),
consts.DocumentURL: "",
},
},
Score: 1,
})
} else {
id, err := strconv.ParseInt(doc.ID, 10, 64)
if err != nil {
logs.CtxErrorf(ctx, "convert id failed: %v", err)
return nil, errorx.New(errno.ErrKnowledgeSystemCode, errorx.KV("msg", "convert id failed"))
}
sliceIDs[id] = struct{}{}
sliceScoreMap[id] = doc.Score()
}
sliceIDs[id] = struct{}{}
sliceScoreMap[id] = doc.Score()
}
slices, err := k.sliceRepo.MGetSlices(ctx, sliceIDs.ToSlice())
if err != nil {
@ -625,7 +671,6 @@ func (k *knowledgeSVC) packResults(ctx context.Context, retrieveResult []*schema
return nil, err
}
}
results = []*knowledgeModel.RetrieveSlice{}
for i := range slices {
doc := documentMap[slices[i].DocumentID]
kn := knowledgeMap[slices[i].KnowledgeID]

View File

@ -0,0 +1,237 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package service
import (
"context"
"errors"
"os"
"strings"
"testing"
"github.com/cloudwego/eino/schema"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"github.com/coze-dev/coze-studio/backend/domain/knowledge/entity"
"github.com/coze-dev/coze-studio/backend/domain/knowledge/internal/dal/model"
"github.com/coze-dev/coze-studio/backend/domain/knowledge/repository"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/nl2sql"
"github.com/coze-dev/coze-studio/backend/infra/contract/rdb"
rdb_entity "github.com/coze-dev/coze-studio/backend/infra/contract/rdb/entity"
mock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/nl2sql"
mock_db "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/rdb"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/sets"
)
func TestAddSliceIdColumn(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "simple select",
input: "SELECT name, age FROM users",
expected: "SELECT `name`,`age`,`_knowledge_slice_id` FROM `users`",
},
{
name: "select stmt wrong",
input: "SELECT FROM users",
expected: "SELECT FROM users",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := addSliceIdColumn(tt.input)
if actual != tt.expected {
t.Errorf("AddSliceIdColumn() = %v, want %v", actual, tt.expected)
}
})
}
}
func TestNL2sqlExec(t *testing.T) {
svc := knowledgeSVC{}
ctrl := gomock.NewController(t)
db := mock_db.NewMockRDB(ctrl)
nl2SQL := mock.NewMockNL2SQL(ctrl)
nl2SQL.EXPECT().NL2SQL(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, messages []*schema.Message, tables []*document.TableSchema, opts ...nl2sql.Option) (sql string, err error) {
return "select count(*) from users", nil
})
db.EXPECT().ExecuteSQL(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, req *rdb.ExecuteSQLRequest) (*rdb.ExecuteSQLResponse, error) {
return &rdb.ExecuteSQLResponse{
ResultSet: &rdb_entity.ResultSet{Rows: []map[string]interface{}{
{
"count(*)": 100,
},
}},
}, nil
})
svc.nl2Sql = nl2SQL
svc.rdb = db
ctx := context.Background()
docu := model.KnowledgeDocument{
ID: 110,
KnowledgeID: 111,
Name: "users",
FileExtension: "xlsx",
DocumentType: 1,
CreatorID: 666,
SpaceID: 666,
Status: 1,
TableInfo: &entity.TableInfo{
VirtualTableName: "users",
PhysicalTableName: "table_111",
TableDesc: "user table",
Columns: []*entity.TableColumn{
{
ID: 1,
Name: "_knowledge_slice_id",
Type: document.TableColumnTypeInteger,
Description: "id",
Indexing: false,
Sequence: 1,
},
{
ID: 2,
Name: "name",
Type: document.TableColumnTypeString,
Description: "name",
Indexing: true,
Sequence: 2,
},
},
},
}
retrieveCtx := &RetrieveContext{
Ctx: ctx,
OriginQuery: "select count(*) from users",
KnowledgeIDs: sets.FromSlice[int64]([]int64{111}),
Documents: []*model.KnowledgeDocument{&docu},
KnowledgeInfoMap: map[int64]*KnowledgeInfo{
111: &KnowledgeInfo{
KnowledgeName: "users",
DocumentIDs: []int64{110},
DocumentType: 1,
TableColumns: []*entity.TableColumn{
{
ID: 1,
Name: "_knowledge_slice_id",
Type: document.TableColumnTypeInteger,
Description: "id",
Indexing: false,
Sequence: 1,
},
{
ID: 2,
Name: "name",
Type: document.TableColumnTypeString,
Description: "name",
Indexing: true,
Sequence: 2,
},
},
},
},
}
docs, err := svc.nl2SqlExec(ctx, &docu, retrieveCtx, nil)
assert.Equal(t, nil, err)
assert.Equal(t, 1, len(docs))
assert.Equal(t, "sql:select count(*) from users;result:[{\"count(*)\":100}]", docs[0].Content)
nl2SQL.EXPECT().NL2SQL(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, messages []*schema.Message, tables []*document.TableSchema, opts ...nl2sql.Option) (sql string, err error) {
return "", errors.New("nl2sql error")
})
_, err = svc.nl2SqlExec(ctx, &docu, retrieveCtx, nil)
assert.Equal(t, "nl2sql error", err.Error())
db.EXPECT().ExecuteSQL(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, req *rdb.ExecuteSQLRequest) (*rdb.ExecuteSQLResponse, error) {
return nil, errors.New("rdb error")
})
nl2SQL.EXPECT().NL2SQL(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, messages []*schema.Message, tables []*document.TableSchema, opts ...nl2sql.Option) (sql string, err error) {
return "select count(*) from users", nil
})
_, err = svc.nl2SqlExec(ctx, &docu, retrieveCtx, nil)
assert.Equal(t, "rdb error", err.Error())
db.EXPECT().ExecuteSQL(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, req *rdb.ExecuteSQLRequest) (*rdb.ExecuteSQLResponse, error) {
return &rdb.ExecuteSQLResponse{
ResultSet: &rdb_entity.ResultSet{Rows: []map[string]interface{}{
{
"name": "666",
"_knowledge_document_slice_id": int64(999),
},
}},
}, nil
})
nl2SQL.EXPECT().NL2SQL(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, messages []*schema.Message, tables []*document.TableSchema, opts ...nl2sql.Option) (sql string, err error) {
return "select name from users", nil
})
docs, err = svc.nl2SqlExec(ctx, &docu, retrieveCtx, nil)
assert.Equal(t, nil, err)
assert.Equal(t, 1, len(docs))
assert.Equal(t, "999", docs[0].ID)
}
func TestPackResults(t *testing.T) {
svc := knowledgeSVC{}
ctx := context.Background()
svc.packResults(ctx, []*schema.Document{})
dsn := "root:root@tcp(127.0.0.1:3306)/opencoze?charset=utf8mb4&parseTime=True&loc=Local"
if os.Getenv("CI_JOB_NAME") != "" {
dsn = strings.ReplaceAll(dsn, "127.0.0.1", "mysql")
}
gormDB, err := gorm.Open(mysql.Open(dsn))
assert.Equal(t, nil, err)
svc.knowledgeRepo = repository.NewKnowledgeDAO(gormDB)
svc.documentRepo = repository.NewKnowledgeDocumentDAO(gormDB)
svc.sliceRepo = repository.NewKnowledgeDocumentSliceDAO(gormDB)
docs := []*schema.Document{
{
ID: "",
Content: "sql:select count(*) from users;result:[{\"count(*)\":100}]",
MetaData: map[string]any{
"knowledge_id": int64(111),
"document_id": int64(110),
"document_name": "users",
"knowledge_name": "users",
},
},
}
res, err := svc.packResults(ctx, docs)
assert.Equal(t, nil, err)
assert.Equal(t, 1, len(res))
assert.Equal(t, "sql:select count(*) from users;result:[{\"count(*)\":100}]", ptr.From(res[0].Slice.RawContent[0].Text))
docs = []*schema.Document{
{
ID: "10000",
Content: "",
MetaData: map[string]any{
"knowledge_id": int64(111),
"document_id": int64(110),
"document_name": "users",
"knowledge_name": "users",
},
},
}
res, err = svc.packResults(ctx, docs)
assert.Equal(t, 0, len(res))
assert.Equal(t, nil, err)
}

View File

@ -36,7 +36,7 @@ import (
"github.com/coze-dev/coze-studio/backend/api/model/app/bot_common"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/api/model/data/database/table"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossvariables"
crossvariables "github.com/coze-dev/coze-studio/backend/crossdomain/contract/variables"
entity2 "github.com/coze-dev/coze-studio/backend/domain/memory/database/entity"
"github.com/coze-dev/coze-studio/backend/domain/memory/database/internal/convertor"
"github.com/coze-dev/coze-studio/backend/domain/memory/database/internal/dal/query"

View File

@ -263,28 +263,28 @@ func MustString(value any) string {
}
}
func TryFixValueType(paramName string, schemaRef *openapi3.SchemaRef, value any) (any, error) {
func TryCorrectValueType(paramName string, schemaRef *openapi3.SchemaRef, value any) (any, error) {
if value == nil {
return "", fmt.Errorf("value of '%s' is nil", paramName)
}
switch schemaRef.Value.Type {
case openapi3.TypeString:
return tryString(value)
return tryCorrectString(value)
case openapi3.TypeNumber:
return tryFloat64(value)
return tryCorrectFloat64(value)
case openapi3.TypeInteger:
return tryInt64(value)
return tryCorrectInt64(value)
case openapi3.TypeBoolean:
return tryBool(value)
return tryCorrectBool(value)
case openapi3.TypeArray:
arrVal, ok := value.([]any)
if !ok {
return nil, fmt.Errorf("[TryFixValueType] value '%s' is not array", paramName)
return nil, fmt.Errorf("[TryCorrectValueType] value '%s' is not array", paramName)
}
for i, v := range arrVal {
_v, err := TryFixValueType(paramName, schemaRef.Value.Items, v)
_v, err := TryCorrectValueType(paramName, schemaRef.Value.Items, v)
if err != nil {
return nil, err
}
@ -296,7 +296,7 @@ func TryFixValueType(paramName string, schemaRef *openapi3.SchemaRef, value any)
case openapi3.TypeObject:
mapVal, ok := value.(map[string]any)
if !ok {
return nil, fmt.Errorf("[TryFixValueType] value '%s' is not object", paramName)
return nil, fmt.Errorf("[TryCorrectValueType] value '%s' is not object", paramName)
}
for k, v := range mapVal {
@ -305,7 +305,7 @@ func TryFixValueType(paramName string, schemaRef *openapi3.SchemaRef, value any)
continue
}
_v, err := TryFixValueType(k, p, v)
_v, err := TryCorrectValueType(k, p, v)
if err != nil {
return nil, err
}
@ -315,11 +315,11 @@ func TryFixValueType(paramName string, schemaRef *openapi3.SchemaRef, value any)
return mapVal, nil
default:
return nil, fmt.Errorf("[TryFixValueType] unsupported schema type '%s'", schemaRef.Value.Type)
return nil, fmt.Errorf("[TryCorrectValueType] unsupported schema type '%s'", schemaRef.Value.Type)
}
}
func tryString(value any) (string, error) {
func tryCorrectString(value any) (string, error) {
switch val := value.(type) {
case string:
return val, nil
@ -331,11 +331,15 @@ func tryString(value any) (string, error) {
case json.Number:
return val.String(), nil
default:
return "", fmt.Errorf("cannot convert type from '%T' to string", val)
b, err := sonic.MarshalString(value)
if err != nil {
return "", fmt.Errorf("tryCorrectString failed, err=%w", err)
}
return b, nil
}
}
func tryInt64(value any) (int64, error) {
func tryCorrectInt64(value any) (int64, error) {
switch val := value.(type) {
case string:
vi64, _ := strconv.ParseInt(val, 10, 64)
@ -352,7 +356,7 @@ func tryInt64(value any) (int64, error) {
}
}
func tryBool(value any) (bool, error) {
func tryCorrectBool(value any) (bool, error) {
switch val := value.(type) {
case string:
return strconv.ParseBool(val)
@ -363,7 +367,7 @@ func tryBool(value any) (bool, error) {
}
}
func tryFloat64(value any) (float64, error) {
func tryCorrectFloat64(value any) (float64, error) {
switch val := value.(type) {
case string:
return strconv.ParseFloat(val, 64)

View File

@ -37,7 +37,7 @@ import (
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/variables"
"github.com/coze-dev/coze-studio/backend/api/model/data/variable/project_memory"
common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossvariables"
crossvariables "github.com/coze-dev/coze-studio/backend/crossdomain/contract/variables"
"github.com/coze-dev/coze-studio/backend/domain/plugin/entity"
"github.com/coze-dev/coze-studio/backend/domain/plugin/internal/encoder"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
@ -1100,15 +1100,15 @@ func (t *toolExecutor) processWithInvalidRespProcessStrategyOfReturnErr(_ contex
processor = func(paramName string, paramVal any, schemaVal *openapi3.Schema) (any, error) {
switch schemaVal.Type {
case openapi3.TypeObject:
newParamValMap := map[string]any{}
paramValMap, ok := paramVal.(map[string]any)
if !ok {
return nil, errorx.New(errno.ErrPluginExecuteToolFailed, errorx.KVf(errno.PluginMsgKey,
"expected '%s' to be of type 'object', but got '%T'", paramName, paramVal))
}
newParamValMap := map[string]any{}
for paramName_, paramVal_ := range paramValMap {
paramSchema_, ok := schemaVal.Properties[paramName]
paramSchema_, ok := schemaVal.Properties[paramName_]
if !ok || t.disabledParam(paramSchema_.Value) { // Only the object field can be disabled, and the top level of request and response must be the object structure
continue
}
@ -1122,13 +1122,13 @@ func (t *toolExecutor) processWithInvalidRespProcessStrategyOfReturnErr(_ contex
return newParamValMap, nil
case openapi3.TypeArray:
newParamValSlice := []any{}
paramValSlice, ok := paramVal.([]any)
if !ok {
return nil, errorx.New(errno.ErrPluginExecuteToolFailed, errorx.KVf(errno.PluginMsgKey,
"expected '%s' to be of type 'array', but got '%T'", paramName, paramVal))
}
newParamValSlice := []any{}
for _, paramVal_ := range paramValSlice {
newParamVal, err := processor(paramName, paramVal_, schemaVal.Items.Value)
if err != nil {
@ -1426,7 +1426,7 @@ func (t *toolExecutor) buildRequestBody(ctx context.Context, op *model.Openapi3O
continue
}
_value, err := encoder.TryFixValueType(paramName, prop, value)
_value, err := encoder.TryCorrectValueType(paramName, prop, value)
if err != nil {
return nil, "", err
}

Some files were not shown because too many files have changed in this diff Show More